diff --git a/examples/action_example.py b/examples/action_example.py index cd4666b..644043e 100644 --- a/examples/action_example.py +++ b/examples/action_example.py @@ -2,6 +2,7 @@ import asyncio from falyx import Action, ActionGroup, ChainedAction + # Actions can be defined as synchronous functions # Falyx will automatically convert them to async functions def hello() -> None: diff --git a/examples/simple.py b/examples/simple.py index 3b7c754..0ccb5e3 100644 --- a/examples/simple.py +++ b/examples/simple.py @@ -1,7 +1,7 @@ import asyncio import random -from falyx import Falyx, Action, ChainedAction +from falyx import Action, ChainedAction, Falyx from falyx.utils import setup_logging setup_logging() diff --git a/falyx/__init__.py b/falyx/__init__.py index 5afa428..2ebe73f 100644 --- a/falyx/__init__.py +++ b/falyx/__init__.py @@ -8,7 +8,7 @@ import logging from .action import Action, ActionGroup, ChainedAction, ProcessAction from .command import Command -from .context import ExecutionContext, ResultsContext +from .context import ExecutionContext, SharedContext from .execution_registry import ExecutionRegistry from .falyx import Falyx @@ -24,6 +24,6 @@ __all__ = [ "Falyx", "Command", "ExecutionContext", - "ResultsContext", + "SharedContext", "ExecutionRegistry", ] diff --git a/falyx/action.py b/falyx/action.py index a082b56..d050e99 100644 --- a/falyx/action.py +++ b/falyx/action.py @@ -38,7 +38,7 @@ from typing import Any, Callable from rich.console import Console from rich.tree import Tree -from falyx.context import ExecutionContext, ResultsContext +from falyx.context import ExecutionContext, SharedContext from falyx.debug import register_debug_hooks from falyx.exceptions import EmptyChainError from falyx.execution_registry import ExecutionRegistry as er @@ -47,8 +47,6 @@ from falyx.retry import RetryHandler, RetryPolicy from falyx.themes.colors import OneColors from falyx.utils import ensure_async, logger -console = Console() - class BaseAction(ABC): """ @@ -72,11 +70,12 @@ class BaseAction(ABC): self.name = name self.hooks = hooks or HookManager() self.is_retryable: bool = False - self.results_context: ResultsContext | None = None + self.shared_context: SharedContext | None = None self.inject_last_result: bool = inject_last_result self.inject_last_result_as: str = inject_last_result_as self._requires_injection: bool = False self._skip_in_chain: bool = False + self.console = Console(color_system="auto") if logging_hooks: register_debug_hooks(self.hooks) @@ -92,32 +91,32 @@ class BaseAction(ABC): async def preview(self, parent: Tree | None = None): raise NotImplementedError("preview must be implemented by subclasses") - def set_results_context(self, results_context: ResultsContext): - self.results_context = results_context + def set_shared_context(self, shared_context: SharedContext): + self.shared_context = shared_context - def prepare_for_chain(self, results_context: ResultsContext) -> BaseAction: + def prepare_for_chain(self, shared_context: SharedContext) -> BaseAction: """ Prepare the action specifically for sequential (ChainedAction) execution. Can be overridden for chain-specific logic. """ - self.set_results_context(results_context) + self.set_shared_context(shared_context) return self - def prepare_for_group(self, results_context: ResultsContext) -> BaseAction: + def prepare_for_group(self, shared_context: SharedContext) -> BaseAction: """ Prepare the action specifically for parallel (ActionGroup) execution. Can be overridden for group-specific logic. """ - self.set_results_context(results_context) + self.set_shared_context(shared_context) return self def _maybe_inject_last_result(self, kwargs: dict[str, Any]) -> dict[str, Any]: - if self.inject_last_result and self.results_context: + if self.inject_last_result and self.shared_context: key = self.inject_last_result_as if key in kwargs: logger.warning("[%s] ⚠️ Overriding '%s' with last_result", self.name, key) kwargs = dict(kwargs) - kwargs[key] = self.results_context.last_result() + kwargs[key] = self.shared_context.last_result() return kwargs def register_hooks_recursively(self, hook_type: HookType, hook: Hook): @@ -146,7 +145,7 @@ class BaseAction(ABC): return self._requires_injection def __str__(self): - return f"<{self.__class__.__name__} '{self.name}'>" + return f"{self.__class__.__name__}('{self.name}')" def __repr__(self): return str(self) @@ -261,7 +260,7 @@ class Action(BaseAction): if parent: parent.add("".join(label)) else: - console.print(Tree("".join(label))) + self.console.print(Tree("".join(label))) def __str__(self): return f"Action(name={self.name}, action={self.action.__name__})" @@ -405,9 +404,9 @@ class ChainedAction(BaseAction, ActionListMixin): if not self.actions: raise EmptyChainError(f"[{self.name}] No actions to execute.") - results_context = ResultsContext(name=self.name) - if self.results_context: - results_context.add_result(self.results_context.last_result()) + shared_context = SharedContext(name=self.name) + if self.shared_context: + shared_context.add_result(self.shared_context.last_result()) updated_kwargs = self._maybe_inject_last_result(kwargs) context = ExecutionContext( name=self.name, @@ -415,6 +414,7 @@ class ChainedAction(BaseAction, ActionListMixin): kwargs=updated_kwargs, action=self, extra={"results": [], "rollback_stack": []}, + shared_context=shared_context, ) context.start_timer() try: @@ -424,9 +424,9 @@ class ChainedAction(BaseAction, ActionListMixin): if action._skip_in_chain: logger.debug("[%s] ⚠️ Skipping consumed action '%s'", self.name, action.name) continue - results_context.current_index = index - prepared = action.prepare_for_chain(results_context) - last_result = results_context.last_result() + shared_context.current_index = index + prepared = action.prepare_for_chain(shared_context) + last_result = shared_context.last_result() try: if self.requires_io_injection() and last_result is not None: result = await prepared(**{prepared.inject_last_result_as: last_result}) @@ -436,14 +436,14 @@ class ChainedAction(BaseAction, ActionListMixin): if index + 1 < len(self.actions) and isinstance(self.actions[index + 1], FallbackAction): logger.warning("[%s] ⚠️ Fallback triggered: %s, recovering with fallback '%s'.", self.name, error, self.actions[index + 1].name) - results_context.add_result(None) + shared_context.add_result(None) context.extra["results"].append(None) - fallback = self.actions[index + 1].prepare_for_chain(results_context) + fallback = self.actions[index + 1].prepare_for_chain(shared_context) result = await fallback() fallback._skip_in_chain = True else: raise - results_context.add_result(result) + shared_context.add_result(result) context.extra["results"].append(result) context.extra["rollback_stack"].append(prepared) @@ -455,7 +455,7 @@ class ChainedAction(BaseAction, ActionListMixin): except Exception as error: context.exception = error - results_context.errors.append((results_context.current_index, error)) + shared_context.add_error(shared_context.current_index, error) await self._rollback(context.extra["rollback_stack"], *args, **kwargs) await self.hooks.trigger(HookType.ON_ERROR, context) raise @@ -495,7 +495,7 @@ class ChainedAction(BaseAction, ActionListMixin): for action in self.actions: await action.preview(parent=tree) if not parent: - console.print(tree) + self.console.print(tree) def register_hooks_recursively(self, hook_type: HookType, hook: Hook): """Register a hook for all actions and sub-actions.""" @@ -503,6 +503,9 @@ class ChainedAction(BaseAction, ActionListMixin): for action in self.actions: action.register_hooks_recursively(hook_type, hook) + def __str__(self): + return f"ChainedAction(name={self.name}, actions={self.actions})" + class ActionGroup(BaseAction, ActionListMixin): """ @@ -550,9 +553,9 @@ class ActionGroup(BaseAction, ActionListMixin): self.set_actions(actions) async def _run(self, *args, **kwargs) -> list[tuple[str, Any]]: - results_context = ResultsContext(name=self.name, is_parallel=True) - if self.results_context: - results_context.set_shared_result(self.results_context.last_result()) + shared_context = SharedContext(name=self.name, is_parallel=True) + if self.shared_context: + shared_context.set_shared_result(self.shared_context.last_result()) updated_kwargs = self._maybe_inject_last_result(kwargs) context = ExecutionContext( name=self.name, @@ -560,15 +563,16 @@ class ActionGroup(BaseAction, ActionListMixin): kwargs=updated_kwargs, action=self, extra={"results": [], "errors": []}, + shared_context=shared_context, ) async def run_one(action: BaseAction): try: - prepared = action.prepare_for_group(results_context) + prepared = action.prepare_for_group(shared_context) result = await prepared(*args, **updated_kwargs) - results_context.add_result((action.name, result)) + shared_context.add_result((action.name, result)) context.extra["results"].append((action.name, result)) except Exception as error: - results_context.errors.append((results_context.current_index, error)) + shared_context.add_error(shared_context.current_index, error) context.extra["errors"].append((action.name, error)) context.start_timer() @@ -606,7 +610,7 @@ class ActionGroup(BaseAction, ActionListMixin): random.shuffle(actions) await asyncio.gather(*(action.preview(parent=tree) for action in actions)) if not parent: - console.print(tree) + self.console.print(tree) def register_hooks_recursively(self, hook_type: HookType, hook: Hook): """Register a hook for all actions and sub-actions.""" @@ -614,6 +618,9 @@ class ActionGroup(BaseAction, ActionListMixin): for action in self.actions: action.register_hooks_recursively(hook_type, hook) + def __str__(self): + return f"ActionGroup(name={self.name}, actions={self.actions})" + class ProcessAction(BaseAction): """ @@ -655,7 +662,7 @@ class ProcessAction(BaseAction): async def _run(self, *args, **kwargs): if self.inject_last_result: - last_result = self.results_context.last_result() + last_result = self.shared_context.last_result() if not self._validate_pickleable(last_result): raise ValueError( f"Cannot inject last result into {self.name}: " @@ -699,7 +706,7 @@ class ProcessAction(BaseAction): if parent: parent.add("".join(label)) else: - console.print(Tree("".join(label))) + self.console.print(Tree("".join(label))) def _validate_pickleable(self, obj: Any) -> bool: try: @@ -708,3 +715,4 @@ class ProcessAction(BaseAction): return True except (pickle.PicklingError, TypeError): return False + diff --git a/falyx/context.py b/falyx/context.py index f8de893..623bac0 100644 --- a/falyx/context.py +++ b/falyx/context.py @@ -1,5 +1,7 @@ # Falyx CLI Framework — (c) 2025 rtj.dev LLC — MIT Licensed """context.py""" +from __future__ import annotations + import time from datetime import datetime from typing import Any @@ -24,6 +26,8 @@ class ExecutionContext(BaseModel): extra: dict[str, Any] = Field(default_factory=dict) console: Console = Field(default_factory=lambda: Console(color_system="auto")) + shared_context: SharedContext | None = None + model_config = ConfigDict(arbitrary_types_allowed=True) def start_timer(self): @@ -34,6 +38,9 @@ class ExecutionContext(BaseModel): self.end_time = time.perf_counter() self.end_wall = datetime.now() + def get_shared_context(self) -> SharedContext: + return self.shared_context or SharedContext(name="default") + @property def duration(self) -> float | None: if self.start_time is None: @@ -104,7 +111,7 @@ class ExecutionContext(BaseModel): ) -class ResultsContext(BaseModel): +class SharedContext(BaseModel): name: str results: list[Any] = Field(default_factory=list) errors: list[tuple[int, Exception]] = Field(default_factory=list) @@ -112,11 +119,16 @@ class ResultsContext(BaseModel): is_parallel: bool = False shared_result: Any | None = None + share: dict[str, Any] = Field(default_factory=dict) + model_config = ConfigDict(arbitrary_types_allowed=True) def add_result(self, result: Any) -> None: self.results.append(result) + def add_error(self, index: int, error: Exception) -> None: + self.errors.append((index, error)) + def set_shared_result(self, result: Any) -> None: self.shared_result = result if self.is_parallel: @@ -127,10 +139,16 @@ class ResultsContext(BaseModel): return self.shared_result return self.results[-1] if self.results else None + def get(self, key: str, default: Any = None) -> Any: + return self.share.get(key, default) + + def set(self, key: str, value: Any) -> None: + self.share[key] = value + def __str__(self) -> str: parallel_label = "Parallel" if self.is_parallel else "Sequential" return ( - f"<{parallel_label}ResultsContext '{self.name}' | " + f"<{parallel_label}SharedContext '{self.name}' | " f"Results: {self.results} | " f"Errors: {self.errors}>" ) diff --git a/falyx/hooks.py b/falyx/hooks.py index 0923718..7a7bc99 100644 --- a/falyx/hooks.py +++ b/falyx/hooks.py @@ -1,6 +1,7 @@ # Falyx CLI Framework — (c) 2025 rtj.dev LLC — MIT Licensed """hooks.py""" import time +from typing import Callable from falyx.context import ExecutionContext from falyx.exceptions import CircuitBreakerOpen @@ -8,8 +9,9 @@ from falyx.themes.colors import OneColors from falyx.utils import logger + class ResultReporter: - def __init__(self, formatter: callable = None): + def __init__(self, formatter: Callable[[], str] | None = None): """ Optional result formatter. If not provided, uses repr(result). """