io-actions #1

Merged
roland merged 3 commits from io-actions into main 2025-04-30 22:26:27 -04:00
6 changed files with 69 additions and 40 deletions
Showing only changes of commit bc1637143c - Show all commits

View File

@ -2,6 +2,7 @@ import asyncio
from falyx import Action, ActionGroup, ChainedAction
# Actions can be defined as synchronous functions
# Falyx will automatically convert them to async functions
def hello() -> None:

View File

@ -1,7 +1,7 @@
import asyncio
import random
from falyx import Falyx, Action, ChainedAction
from falyx import Action, ChainedAction, Falyx
from falyx.utils import setup_logging
setup_logging()

View File

@ -8,7 +8,7 @@ import logging
from .action import Action, ActionGroup, ChainedAction, ProcessAction
from .command import Command
from .context import ExecutionContext, ResultsContext
from .context import ExecutionContext, SharedContext
from .execution_registry import ExecutionRegistry
from .falyx import Falyx
@ -24,6 +24,6 @@ __all__ = [
"Falyx",
"Command",
"ExecutionContext",
"ResultsContext",
"SharedContext",
"ExecutionRegistry",
]

View File

@ -38,7 +38,7 @@ from typing import Any, Callable
from rich.console import Console
from rich.tree import Tree
from falyx.context import ExecutionContext, ResultsContext
from falyx.context import ExecutionContext, SharedContext
from falyx.debug import register_debug_hooks
from falyx.exceptions import EmptyChainError
from falyx.execution_registry import ExecutionRegistry as er
@ -47,8 +47,6 @@ from falyx.retry import RetryHandler, RetryPolicy
from falyx.themes.colors import OneColors
from falyx.utils import ensure_async, logger
console = Console()
class BaseAction(ABC):
"""
@ -72,11 +70,12 @@ class BaseAction(ABC):
self.name = name
self.hooks = hooks or HookManager()
self.is_retryable: bool = False
self.results_context: ResultsContext | None = None
self.shared_context: SharedContext | None = None
self.inject_last_result: bool = inject_last_result
self.inject_last_result_as: str = inject_last_result_as
self._requires_injection: bool = False
self._skip_in_chain: bool = False
self.console = Console(color_system="auto")
if logging_hooks:
register_debug_hooks(self.hooks)
@ -92,32 +91,32 @@ class BaseAction(ABC):
async def preview(self, parent: Tree | None = None):
raise NotImplementedError("preview must be implemented by subclasses")
def set_results_context(self, results_context: ResultsContext):
self.results_context = results_context
def set_shared_context(self, shared_context: SharedContext):
self.shared_context = shared_context
def prepare_for_chain(self, results_context: ResultsContext) -> BaseAction:
def prepare_for_chain(self, shared_context: SharedContext) -> BaseAction:
"""
Prepare the action specifically for sequential (ChainedAction) execution.
Can be overridden for chain-specific logic.
"""
self.set_results_context(results_context)
self.set_shared_context(shared_context)
return self
def prepare_for_group(self, results_context: ResultsContext) -> BaseAction:
def prepare_for_group(self, shared_context: SharedContext) -> BaseAction:
"""
Prepare the action specifically for parallel (ActionGroup) execution.
Can be overridden for group-specific logic.
"""
self.set_results_context(results_context)
self.set_shared_context(shared_context)
return self
def _maybe_inject_last_result(self, kwargs: dict[str, Any]) -> dict[str, Any]:
if self.inject_last_result and self.results_context:
if self.inject_last_result and self.shared_context:
key = self.inject_last_result_as
if key in kwargs:
logger.warning("[%s] ⚠️ Overriding '%s' with last_result", self.name, key)
kwargs = dict(kwargs)
kwargs[key] = self.results_context.last_result()
kwargs[key] = self.shared_context.last_result()
return kwargs
def register_hooks_recursively(self, hook_type: HookType, hook: Hook):
@ -146,7 +145,7 @@ class BaseAction(ABC):
return self._requires_injection
def __str__(self):
return f"<{self.__class__.__name__} '{self.name}'>"
return f"{self.__class__.__name__}('{self.name}')"
def __repr__(self):
return str(self)
@ -261,7 +260,7 @@ class Action(BaseAction):
if parent:
parent.add("".join(label))
else:
console.print(Tree("".join(label)))
self.console.print(Tree("".join(label)))
def __str__(self):
return f"Action(name={self.name}, action={self.action.__name__})"
@ -405,9 +404,9 @@ class ChainedAction(BaseAction, ActionListMixin):
if not self.actions:
raise EmptyChainError(f"[{self.name}] No actions to execute.")
results_context = ResultsContext(name=self.name)
if self.results_context:
results_context.add_result(self.results_context.last_result())
shared_context = SharedContext(name=self.name)
if self.shared_context:
shared_context.add_result(self.shared_context.last_result())
updated_kwargs = self._maybe_inject_last_result(kwargs)
context = ExecutionContext(
name=self.name,
@ -415,6 +414,7 @@ class ChainedAction(BaseAction, ActionListMixin):
kwargs=updated_kwargs,
action=self,
extra={"results": [], "rollback_stack": []},
shared_context=shared_context,
)
context.start_timer()
try:
@ -424,9 +424,9 @@ class ChainedAction(BaseAction, ActionListMixin):
if action._skip_in_chain:
logger.debug("[%s] ⚠️ Skipping consumed action '%s'", self.name, action.name)
continue
results_context.current_index = index
prepared = action.prepare_for_chain(results_context)
last_result = results_context.last_result()
shared_context.current_index = index
prepared = action.prepare_for_chain(shared_context)
last_result = shared_context.last_result()
try:
if self.requires_io_injection() and last_result is not None:
result = await prepared(**{prepared.inject_last_result_as: last_result})
@ -436,14 +436,14 @@ class ChainedAction(BaseAction, ActionListMixin):
if index + 1 < len(self.actions) and isinstance(self.actions[index + 1], FallbackAction):
logger.warning("[%s] ⚠️ Fallback triggered: %s, recovering with fallback '%s'.",
self.name, error, self.actions[index + 1].name)
results_context.add_result(None)
shared_context.add_result(None)
context.extra["results"].append(None)
fallback = self.actions[index + 1].prepare_for_chain(results_context)
fallback = self.actions[index + 1].prepare_for_chain(shared_context)
result = await fallback()
fallback._skip_in_chain = True
else:
raise
results_context.add_result(result)
shared_context.add_result(result)
context.extra["results"].append(result)
context.extra["rollback_stack"].append(prepared)
@ -455,7 +455,7 @@ class ChainedAction(BaseAction, ActionListMixin):
except Exception as error:
context.exception = error
results_context.errors.append((results_context.current_index, error))
shared_context.add_error(shared_context.current_index, error)
await self._rollback(context.extra["rollback_stack"], *args, **kwargs)
await self.hooks.trigger(HookType.ON_ERROR, context)
raise
@ -495,7 +495,7 @@ class ChainedAction(BaseAction, ActionListMixin):
for action in self.actions:
await action.preview(parent=tree)
if not parent:
console.print(tree)
self.console.print(tree)
def register_hooks_recursively(self, hook_type: HookType, hook: Hook):
"""Register a hook for all actions and sub-actions."""
@ -503,6 +503,9 @@ class ChainedAction(BaseAction, ActionListMixin):
for action in self.actions:
action.register_hooks_recursively(hook_type, hook)
def __str__(self):
return f"ChainedAction(name={self.name}, actions={self.actions})"
class ActionGroup(BaseAction, ActionListMixin):
"""
@ -550,9 +553,9 @@ class ActionGroup(BaseAction, ActionListMixin):
self.set_actions(actions)
async def _run(self, *args, **kwargs) -> list[tuple[str, Any]]:
results_context = ResultsContext(name=self.name, is_parallel=True)
if self.results_context:
results_context.set_shared_result(self.results_context.last_result())
shared_context = SharedContext(name=self.name, is_parallel=True)
if self.shared_context:
shared_context.set_shared_result(self.shared_context.last_result())
updated_kwargs = self._maybe_inject_last_result(kwargs)
context = ExecutionContext(
name=self.name,
@ -560,15 +563,16 @@ class ActionGroup(BaseAction, ActionListMixin):
kwargs=updated_kwargs,
action=self,
extra={"results": [], "errors": []},
shared_context=shared_context,
)
async def run_one(action: BaseAction):
try:
prepared = action.prepare_for_group(results_context)
prepared = action.prepare_for_group(shared_context)
result = await prepared(*args, **updated_kwargs)
results_context.add_result((action.name, result))
shared_context.add_result((action.name, result))
context.extra["results"].append((action.name, result))
except Exception as error:
results_context.errors.append((results_context.current_index, error))
shared_context.add_error(shared_context.current_index, error)
context.extra["errors"].append((action.name, error))
context.start_timer()
@ -606,7 +610,7 @@ class ActionGroup(BaseAction, ActionListMixin):
random.shuffle(actions)
await asyncio.gather(*(action.preview(parent=tree) for action in actions))
if not parent:
console.print(tree)
self.console.print(tree)
def register_hooks_recursively(self, hook_type: HookType, hook: Hook):
"""Register a hook for all actions and sub-actions."""
@ -614,6 +618,9 @@ class ActionGroup(BaseAction, ActionListMixin):
for action in self.actions:
action.register_hooks_recursively(hook_type, hook)
def __str__(self):
return f"ActionGroup(name={self.name}, actions={self.actions})"
class ProcessAction(BaseAction):
"""
@ -655,7 +662,7 @@ class ProcessAction(BaseAction):
async def _run(self, *args, **kwargs):
if self.inject_last_result:
last_result = self.results_context.last_result()
last_result = self.shared_context.last_result()
if not self._validate_pickleable(last_result):
raise ValueError(
f"Cannot inject last result into {self.name}: "
@ -699,7 +706,7 @@ class ProcessAction(BaseAction):
if parent:
parent.add("".join(label))
else:
console.print(Tree("".join(label)))
self.console.print(Tree("".join(label)))
def _validate_pickleable(self, obj: Any) -> bool:
try:
@ -708,3 +715,4 @@ class ProcessAction(BaseAction):
return True
except (pickle.PicklingError, TypeError):
return False

View File

@ -1,5 +1,7 @@
# Falyx CLI Framework — (c) 2025 rtj.dev LLC — MIT Licensed
"""context.py"""
from __future__ import annotations
import time
from datetime import datetime
from typing import Any
@ -24,6 +26,8 @@ class ExecutionContext(BaseModel):
extra: dict[str, Any] = Field(default_factory=dict)
console: Console = Field(default_factory=lambda: Console(color_system="auto"))
shared_context: SharedContext | None = None
model_config = ConfigDict(arbitrary_types_allowed=True)
def start_timer(self):
@ -34,6 +38,9 @@ class ExecutionContext(BaseModel):
self.end_time = time.perf_counter()
self.end_wall = datetime.now()
def get_shared_context(self) -> SharedContext:
return self.shared_context or SharedContext(name="default")
@property
def duration(self) -> float | None:
if self.start_time is None:
@ -104,7 +111,7 @@ class ExecutionContext(BaseModel):
)
class ResultsContext(BaseModel):
class SharedContext(BaseModel):
name: str
results: list[Any] = Field(default_factory=list)
errors: list[tuple[int, Exception]] = Field(default_factory=list)
@ -112,11 +119,16 @@ class ResultsContext(BaseModel):
is_parallel: bool = False
shared_result: Any | None = None
share: dict[str, Any] = Field(default_factory=dict)
model_config = ConfigDict(arbitrary_types_allowed=True)
def add_result(self, result: Any) -> None:
self.results.append(result)
def add_error(self, index: int, error: Exception) -> None:
self.errors.append((index, error))
def set_shared_result(self, result: Any) -> None:
self.shared_result = result
if self.is_parallel:
@ -127,10 +139,16 @@ class ResultsContext(BaseModel):
return self.shared_result
return self.results[-1] if self.results else None
def get(self, key: str, default: Any = None) -> Any:
return self.share.get(key, default)
def set(self, key: str, value: Any) -> None:
self.share[key] = value
def __str__(self) -> str:
parallel_label = "Parallel" if self.is_parallel else "Sequential"
return (
f"<{parallel_label}ResultsContext '{self.name}' | "
f"<{parallel_label}SharedContext '{self.name}' | "
f"Results: {self.results} | "
f"Errors: {self.errors}>"
)

View File

@ -1,6 +1,7 @@
# Falyx CLI Framework — (c) 2025 rtj.dev LLC — MIT Licensed
"""hooks.py"""
import time
from typing import Callable
from falyx.context import ExecutionContext
from falyx.exceptions import CircuitBreakerOpen
@ -8,8 +9,9 @@ from falyx.themes.colors import OneColors
from falyx.utils import logger
class ResultReporter:
def __init__(self, formatter: callable = None):
def __init__(self, formatter: Callable[[], str] | None = None):
"""
Optional result formatter. If not provided, uses repr(result).
"""