io-actions #1
|
@ -2,6 +2,7 @@ import asyncio
|
|||
|
||||
from falyx import Action, ActionGroup, ChainedAction
|
||||
|
||||
|
||||
# Actions can be defined as synchronous functions
|
||||
# Falyx will automatically convert them to async functions
|
||||
def hello() -> None:
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import asyncio
|
||||
import random
|
||||
|
||||
from falyx import Falyx, Action, ChainedAction
|
||||
from falyx import Action, ChainedAction, Falyx
|
||||
from falyx.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
|
|
|
@ -8,7 +8,7 @@ import logging
|
|||
|
||||
from .action import Action, ActionGroup, ChainedAction, ProcessAction
|
||||
from .command import Command
|
||||
from .context import ExecutionContext, ResultsContext
|
||||
from .context import ExecutionContext, SharedContext
|
||||
from .execution_registry import ExecutionRegistry
|
||||
from .falyx import Falyx
|
||||
|
||||
|
@ -24,6 +24,6 @@ __all__ = [
|
|||
"Falyx",
|
||||
"Command",
|
||||
"ExecutionContext",
|
||||
"ResultsContext",
|
||||
"SharedContext",
|
||||
"ExecutionRegistry",
|
||||
]
|
||||
|
|
|
@ -38,7 +38,7 @@ from typing import Any, Callable
|
|||
from rich.console import Console
|
||||
from rich.tree import Tree
|
||||
|
||||
from falyx.context import ExecutionContext, ResultsContext
|
||||
from falyx.context import ExecutionContext, SharedContext
|
||||
from falyx.debug import register_debug_hooks
|
||||
from falyx.exceptions import EmptyChainError
|
||||
from falyx.execution_registry import ExecutionRegistry as er
|
||||
|
@ -47,8 +47,6 @@ from falyx.retry import RetryHandler, RetryPolicy
|
|||
from falyx.themes.colors import OneColors
|
||||
from falyx.utils import ensure_async, logger
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
class BaseAction(ABC):
|
||||
"""
|
||||
|
@ -72,11 +70,12 @@ class BaseAction(ABC):
|
|||
self.name = name
|
||||
self.hooks = hooks or HookManager()
|
||||
self.is_retryable: bool = False
|
||||
self.results_context: ResultsContext | None = None
|
||||
self.shared_context: SharedContext | None = None
|
||||
self.inject_last_result: bool = inject_last_result
|
||||
self.inject_last_result_as: str = inject_last_result_as
|
||||
self._requires_injection: bool = False
|
||||
self._skip_in_chain: bool = False
|
||||
self.console = Console(color_system="auto")
|
||||
|
||||
if logging_hooks:
|
||||
register_debug_hooks(self.hooks)
|
||||
|
@ -92,32 +91,32 @@ class BaseAction(ABC):
|
|||
async def preview(self, parent: Tree | None = None):
|
||||
raise NotImplementedError("preview must be implemented by subclasses")
|
||||
|
||||
def set_results_context(self, results_context: ResultsContext):
|
||||
self.results_context = results_context
|
||||
def set_shared_context(self, shared_context: SharedContext):
|
||||
self.shared_context = shared_context
|
||||
|
||||
def prepare_for_chain(self, results_context: ResultsContext) -> BaseAction:
|
||||
def prepare_for_chain(self, shared_context: SharedContext) -> BaseAction:
|
||||
"""
|
||||
Prepare the action specifically for sequential (ChainedAction) execution.
|
||||
Can be overridden for chain-specific logic.
|
||||
"""
|
||||
self.set_results_context(results_context)
|
||||
self.set_shared_context(shared_context)
|
||||
return self
|
||||
|
||||
def prepare_for_group(self, results_context: ResultsContext) -> BaseAction:
|
||||
def prepare_for_group(self, shared_context: SharedContext) -> BaseAction:
|
||||
"""
|
||||
Prepare the action specifically for parallel (ActionGroup) execution.
|
||||
Can be overridden for group-specific logic.
|
||||
"""
|
||||
self.set_results_context(results_context)
|
||||
self.set_shared_context(shared_context)
|
||||
return self
|
||||
|
||||
def _maybe_inject_last_result(self, kwargs: dict[str, Any]) -> dict[str, Any]:
|
||||
if self.inject_last_result and self.results_context:
|
||||
if self.inject_last_result and self.shared_context:
|
||||
key = self.inject_last_result_as
|
||||
if key in kwargs:
|
||||
logger.warning("[%s] ⚠️ Overriding '%s' with last_result", self.name, key)
|
||||
kwargs = dict(kwargs)
|
||||
kwargs[key] = self.results_context.last_result()
|
||||
kwargs[key] = self.shared_context.last_result()
|
||||
return kwargs
|
||||
|
||||
def register_hooks_recursively(self, hook_type: HookType, hook: Hook):
|
||||
|
@ -146,7 +145,7 @@ class BaseAction(ABC):
|
|||
return self._requires_injection
|
||||
|
||||
def __str__(self):
|
||||
return f"<{self.__class__.__name__} '{self.name}'>"
|
||||
return f"{self.__class__.__name__}('{self.name}')"
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
@ -261,7 +260,7 @@ class Action(BaseAction):
|
|||
if parent:
|
||||
parent.add("".join(label))
|
||||
else:
|
||||
console.print(Tree("".join(label)))
|
||||
self.console.print(Tree("".join(label)))
|
||||
|
||||
def __str__(self):
|
||||
return f"Action(name={self.name}, action={self.action.__name__})"
|
||||
|
@ -405,9 +404,9 @@ class ChainedAction(BaseAction, ActionListMixin):
|
|||
if not self.actions:
|
||||
raise EmptyChainError(f"[{self.name}] No actions to execute.")
|
||||
|
||||
results_context = ResultsContext(name=self.name)
|
||||
if self.results_context:
|
||||
results_context.add_result(self.results_context.last_result())
|
||||
shared_context = SharedContext(name=self.name)
|
||||
if self.shared_context:
|
||||
shared_context.add_result(self.shared_context.last_result())
|
||||
updated_kwargs = self._maybe_inject_last_result(kwargs)
|
||||
context = ExecutionContext(
|
||||
name=self.name,
|
||||
|
@ -415,6 +414,7 @@ class ChainedAction(BaseAction, ActionListMixin):
|
|||
kwargs=updated_kwargs,
|
||||
action=self,
|
||||
extra={"results": [], "rollback_stack": []},
|
||||
shared_context=shared_context,
|
||||
)
|
||||
context.start_timer()
|
||||
try:
|
||||
|
@ -424,9 +424,9 @@ class ChainedAction(BaseAction, ActionListMixin):
|
|||
if action._skip_in_chain:
|
||||
logger.debug("[%s] ⚠️ Skipping consumed action '%s'", self.name, action.name)
|
||||
continue
|
||||
results_context.current_index = index
|
||||
prepared = action.prepare_for_chain(results_context)
|
||||
last_result = results_context.last_result()
|
||||
shared_context.current_index = index
|
||||
prepared = action.prepare_for_chain(shared_context)
|
||||
last_result = shared_context.last_result()
|
||||
try:
|
||||
if self.requires_io_injection() and last_result is not None:
|
||||
result = await prepared(**{prepared.inject_last_result_as: last_result})
|
||||
|
@ -436,14 +436,14 @@ class ChainedAction(BaseAction, ActionListMixin):
|
|||
if index + 1 < len(self.actions) and isinstance(self.actions[index + 1], FallbackAction):
|
||||
logger.warning("[%s] ⚠️ Fallback triggered: %s, recovering with fallback '%s'.",
|
||||
self.name, error, self.actions[index + 1].name)
|
||||
results_context.add_result(None)
|
||||
shared_context.add_result(None)
|
||||
context.extra["results"].append(None)
|
||||
fallback = self.actions[index + 1].prepare_for_chain(results_context)
|
||||
fallback = self.actions[index + 1].prepare_for_chain(shared_context)
|
||||
result = await fallback()
|
||||
fallback._skip_in_chain = True
|
||||
else:
|
||||
raise
|
||||
results_context.add_result(result)
|
||||
shared_context.add_result(result)
|
||||
context.extra["results"].append(result)
|
||||
context.extra["rollback_stack"].append(prepared)
|
||||
|
||||
|
@ -455,7 +455,7 @@ class ChainedAction(BaseAction, ActionListMixin):
|
|||
|
||||
except Exception as error:
|
||||
context.exception = error
|
||||
results_context.errors.append((results_context.current_index, error))
|
||||
shared_context.add_error(shared_context.current_index, error)
|
||||
await self._rollback(context.extra["rollback_stack"], *args, **kwargs)
|
||||
await self.hooks.trigger(HookType.ON_ERROR, context)
|
||||
raise
|
||||
|
@ -495,7 +495,7 @@ class ChainedAction(BaseAction, ActionListMixin):
|
|||
for action in self.actions:
|
||||
await action.preview(parent=tree)
|
||||
if not parent:
|
||||
console.print(tree)
|
||||
self.console.print(tree)
|
||||
|
||||
def register_hooks_recursively(self, hook_type: HookType, hook: Hook):
|
||||
"""Register a hook for all actions and sub-actions."""
|
||||
|
@ -503,6 +503,9 @@ class ChainedAction(BaseAction, ActionListMixin):
|
|||
for action in self.actions:
|
||||
action.register_hooks_recursively(hook_type, hook)
|
||||
|
||||
def __str__(self):
|
||||
return f"ChainedAction(name={self.name}, actions={self.actions})"
|
||||
|
||||
|
||||
class ActionGroup(BaseAction, ActionListMixin):
|
||||
"""
|
||||
|
@ -550,9 +553,9 @@ class ActionGroup(BaseAction, ActionListMixin):
|
|||
self.set_actions(actions)
|
||||
|
||||
async def _run(self, *args, **kwargs) -> list[tuple[str, Any]]:
|
||||
results_context = ResultsContext(name=self.name, is_parallel=True)
|
||||
if self.results_context:
|
||||
results_context.set_shared_result(self.results_context.last_result())
|
||||
shared_context = SharedContext(name=self.name, is_parallel=True)
|
||||
if self.shared_context:
|
||||
shared_context.set_shared_result(self.shared_context.last_result())
|
||||
updated_kwargs = self._maybe_inject_last_result(kwargs)
|
||||
context = ExecutionContext(
|
||||
name=self.name,
|
||||
|
@ -560,15 +563,16 @@ class ActionGroup(BaseAction, ActionListMixin):
|
|||
kwargs=updated_kwargs,
|
||||
action=self,
|
||||
extra={"results": [], "errors": []},
|
||||
shared_context=shared_context,
|
||||
)
|
||||
async def run_one(action: BaseAction):
|
||||
try:
|
||||
prepared = action.prepare_for_group(results_context)
|
||||
prepared = action.prepare_for_group(shared_context)
|
||||
result = await prepared(*args, **updated_kwargs)
|
||||
results_context.add_result((action.name, result))
|
||||
shared_context.add_result((action.name, result))
|
||||
context.extra["results"].append((action.name, result))
|
||||
except Exception as error:
|
||||
results_context.errors.append((results_context.current_index, error))
|
||||
shared_context.add_error(shared_context.current_index, error)
|
||||
context.extra["errors"].append((action.name, error))
|
||||
|
||||
context.start_timer()
|
||||
|
@ -606,7 +610,7 @@ class ActionGroup(BaseAction, ActionListMixin):
|
|||
random.shuffle(actions)
|
||||
await asyncio.gather(*(action.preview(parent=tree) for action in actions))
|
||||
if not parent:
|
||||
console.print(tree)
|
||||
self.console.print(tree)
|
||||
|
||||
def register_hooks_recursively(self, hook_type: HookType, hook: Hook):
|
||||
"""Register a hook for all actions and sub-actions."""
|
||||
|
@ -614,6 +618,9 @@ class ActionGroup(BaseAction, ActionListMixin):
|
|||
for action in self.actions:
|
||||
action.register_hooks_recursively(hook_type, hook)
|
||||
|
||||
def __str__(self):
|
||||
return f"ActionGroup(name={self.name}, actions={self.actions})"
|
||||
|
||||
|
||||
class ProcessAction(BaseAction):
|
||||
"""
|
||||
|
@ -655,7 +662,7 @@ class ProcessAction(BaseAction):
|
|||
|
||||
async def _run(self, *args, **kwargs):
|
||||
if self.inject_last_result:
|
||||
last_result = self.results_context.last_result()
|
||||
last_result = self.shared_context.last_result()
|
||||
if not self._validate_pickleable(last_result):
|
||||
raise ValueError(
|
||||
f"Cannot inject last result into {self.name}: "
|
||||
|
@ -699,7 +706,7 @@ class ProcessAction(BaseAction):
|
|||
if parent:
|
||||
parent.add("".join(label))
|
||||
else:
|
||||
console.print(Tree("".join(label)))
|
||||
self.console.print(Tree("".join(label)))
|
||||
|
||||
def _validate_pickleable(self, obj: Any) -> bool:
|
||||
try:
|
||||
|
@ -708,3 +715,4 @@ class ProcessAction(BaseAction):
|
|||
return True
|
||||
except (pickle.PicklingError, TypeError):
|
||||
return False
|
||||
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
# Falyx CLI Framework — (c) 2025 rtj.dev LLC — MIT Licensed
|
||||
"""context.py"""
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
@ -24,6 +26,8 @@ class ExecutionContext(BaseModel):
|
|||
extra: dict[str, Any] = Field(default_factory=dict)
|
||||
console: Console = Field(default_factory=lambda: Console(color_system="auto"))
|
||||
|
||||
shared_context: SharedContext | None = None
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def start_timer(self):
|
||||
|
@ -34,6 +38,9 @@ class ExecutionContext(BaseModel):
|
|||
self.end_time = time.perf_counter()
|
||||
self.end_wall = datetime.now()
|
||||
|
||||
def get_shared_context(self) -> SharedContext:
|
||||
return self.shared_context or SharedContext(name="default")
|
||||
|
||||
@property
|
||||
def duration(self) -> float | None:
|
||||
if self.start_time is None:
|
||||
|
@ -104,7 +111,7 @@ class ExecutionContext(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class ResultsContext(BaseModel):
|
||||
class SharedContext(BaseModel):
|
||||
name: str
|
||||
results: list[Any] = Field(default_factory=list)
|
||||
errors: list[tuple[int, Exception]] = Field(default_factory=list)
|
||||
|
@ -112,11 +119,16 @@ class ResultsContext(BaseModel):
|
|||
is_parallel: bool = False
|
||||
shared_result: Any | None = None
|
||||
|
||||
share: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def add_result(self, result: Any) -> None:
|
||||
self.results.append(result)
|
||||
|
||||
def add_error(self, index: int, error: Exception) -> None:
|
||||
self.errors.append((index, error))
|
||||
|
||||
def set_shared_result(self, result: Any) -> None:
|
||||
self.shared_result = result
|
||||
if self.is_parallel:
|
||||
|
@ -127,10 +139,16 @@ class ResultsContext(BaseModel):
|
|||
return self.shared_result
|
||||
return self.results[-1] if self.results else None
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
return self.share.get(key, default)
|
||||
|
||||
def set(self, key: str, value: Any) -> None:
|
||||
self.share[key] = value
|
||||
|
||||
def __str__(self) -> str:
|
||||
parallel_label = "Parallel" if self.is_parallel else "Sequential"
|
||||
return (
|
||||
f"<{parallel_label}ResultsContext '{self.name}' | "
|
||||
f"<{parallel_label}SharedContext '{self.name}' | "
|
||||
f"Results: {self.results} | "
|
||||
f"Errors: {self.errors}>"
|
||||
)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# Falyx CLI Framework — (c) 2025 rtj.dev LLC — MIT Licensed
|
||||
"""hooks.py"""
|
||||
import time
|
||||
from typing import Callable
|
||||
|
||||
from falyx.context import ExecutionContext
|
||||
from falyx.exceptions import CircuitBreakerOpen
|
||||
|
@ -8,8 +9,9 @@ from falyx.themes.colors import OneColors
|
|||
from falyx.utils import logger
|
||||
|
||||
|
||||
|
||||
class ResultReporter:
|
||||
def __init__(self, formatter: callable = None):
|
||||
def __init__(self, formatter: Callable[[], str] | None = None):
|
||||
"""
|
||||
Optional result formatter. If not provided, uses repr(result).
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue