Rename ResultsContext -> SharedContext
This commit is contained in:
		| @@ -2,6 +2,7 @@ import asyncio | ||||
|  | ||||
| from falyx import Action, ActionGroup, ChainedAction | ||||
|  | ||||
|  | ||||
| # Actions can be defined as synchronous functions | ||||
| # Falyx will automatically convert them to async functions | ||||
| def hello() -> None: | ||||
|   | ||||
| @@ -1,7 +1,7 @@ | ||||
| import asyncio | ||||
| import random | ||||
|  | ||||
| from falyx import Falyx, Action, ChainedAction | ||||
| from falyx import Action, ChainedAction, Falyx | ||||
| from falyx.utils import setup_logging | ||||
|  | ||||
| setup_logging() | ||||
|   | ||||
| @@ -8,7 +8,7 @@ import logging | ||||
|  | ||||
| from .action import Action, ActionGroup, ChainedAction, ProcessAction | ||||
| from .command import Command | ||||
| from .context import ExecutionContext, ResultsContext | ||||
| from .context import ExecutionContext, SharedContext | ||||
| from .execution_registry import ExecutionRegistry | ||||
| from .falyx import Falyx | ||||
|  | ||||
| @@ -24,6 +24,6 @@ __all__ = [ | ||||
|     "Falyx", | ||||
|     "Command", | ||||
|     "ExecutionContext", | ||||
|     "ResultsContext", | ||||
|     "SharedContext", | ||||
|     "ExecutionRegistry", | ||||
| ] | ||||
|   | ||||
| @@ -38,7 +38,7 @@ from typing import Any, Callable | ||||
| from rich.console import Console | ||||
| from rich.tree import Tree | ||||
|  | ||||
| from falyx.context import ExecutionContext, ResultsContext | ||||
| from falyx.context import ExecutionContext, SharedContext | ||||
| from falyx.debug import register_debug_hooks | ||||
| from falyx.exceptions import EmptyChainError | ||||
| from falyx.execution_registry import ExecutionRegistry as er | ||||
| @@ -47,8 +47,6 @@ from falyx.retry import RetryHandler, RetryPolicy | ||||
| from falyx.themes.colors import OneColors | ||||
| from falyx.utils import ensure_async, logger | ||||
|  | ||||
| console = Console() | ||||
|  | ||||
|  | ||||
| class BaseAction(ABC): | ||||
|     """ | ||||
| @@ -72,11 +70,12 @@ class BaseAction(ABC): | ||||
|         self.name = name | ||||
|         self.hooks = hooks or HookManager() | ||||
|         self.is_retryable: bool = False | ||||
|         self.results_context: ResultsContext | None = None | ||||
|         self.shared_context: SharedContext | None = None | ||||
|         self.inject_last_result: bool = inject_last_result | ||||
|         self.inject_last_result_as: str = inject_last_result_as | ||||
|         self._requires_injection: bool = False | ||||
|         self._skip_in_chain: bool = False | ||||
|         self.console = Console(color_system="auto") | ||||
|  | ||||
|         if logging_hooks: | ||||
|             register_debug_hooks(self.hooks) | ||||
| @@ -92,32 +91,32 @@ class BaseAction(ABC): | ||||
|     async def preview(self, parent: Tree | None = None): | ||||
|         raise NotImplementedError("preview must be implemented by subclasses") | ||||
|  | ||||
|     def set_results_context(self, results_context: ResultsContext): | ||||
|         self.results_context = results_context | ||||
|     def set_shared_context(self, shared_context: SharedContext): | ||||
|         self.shared_context = shared_context | ||||
|  | ||||
|     def prepare_for_chain(self, results_context: ResultsContext) -> BaseAction: | ||||
|     def prepare_for_chain(self, shared_context: SharedContext) -> BaseAction: | ||||
|         """ | ||||
|         Prepare the action specifically for sequential (ChainedAction) execution. | ||||
|         Can be overridden for chain-specific logic. | ||||
|         """ | ||||
|         self.set_results_context(results_context) | ||||
|         self.set_shared_context(shared_context) | ||||
|         return self | ||||
|  | ||||
|     def prepare_for_group(self, results_context: ResultsContext) -> BaseAction: | ||||
|     def prepare_for_group(self, shared_context: SharedContext) -> BaseAction: | ||||
|         """ | ||||
|         Prepare the action specifically for parallel (ActionGroup) execution. | ||||
|         Can be overridden for group-specific logic. | ||||
|         """ | ||||
|         self.set_results_context(results_context) | ||||
|         self.set_shared_context(shared_context) | ||||
|         return self | ||||
|  | ||||
|     def _maybe_inject_last_result(self, kwargs: dict[str, Any]) -> dict[str, Any]: | ||||
|         if self.inject_last_result and self.results_context: | ||||
|         if self.inject_last_result and self.shared_context: | ||||
|             key = self.inject_last_result_as | ||||
|             if key in kwargs: | ||||
|                 logger.warning("[%s] ⚠️ Overriding '%s' with last_result", self.name, key) | ||||
|             kwargs = dict(kwargs) | ||||
|             kwargs[key] = self.results_context.last_result() | ||||
|             kwargs[key] = self.shared_context.last_result() | ||||
|         return kwargs | ||||
|  | ||||
|     def register_hooks_recursively(self, hook_type: HookType, hook: Hook): | ||||
| @@ -146,7 +145,7 @@ class BaseAction(ABC): | ||||
|         return self._requires_injection | ||||
|  | ||||
|     def __str__(self): | ||||
|         return f"<{self.__class__.__name__} '{self.name}'>" | ||||
|         return f"{self.__class__.__name__}('{self.name}')" | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return str(self) | ||||
| @@ -261,7 +260,7 @@ class Action(BaseAction): | ||||
|         if parent: | ||||
|             parent.add("".join(label)) | ||||
|         else: | ||||
|             console.print(Tree("".join(label))) | ||||
|             self.console.print(Tree("".join(label))) | ||||
|  | ||||
|     def __str__(self): | ||||
|         return f"Action(name={self.name}, action={self.action.__name__})" | ||||
| @@ -405,9 +404,9 @@ class ChainedAction(BaseAction, ActionListMixin): | ||||
|         if not self.actions: | ||||
|             raise EmptyChainError(f"[{self.name}] No actions to execute.") | ||||
|  | ||||
|         results_context = ResultsContext(name=self.name) | ||||
|         if self.results_context: | ||||
|             results_context.add_result(self.results_context.last_result()) | ||||
|         shared_context = SharedContext(name=self.name) | ||||
|         if self.shared_context: | ||||
|             shared_context.add_result(self.shared_context.last_result()) | ||||
|         updated_kwargs = self._maybe_inject_last_result(kwargs) | ||||
|         context = ExecutionContext( | ||||
|             name=self.name, | ||||
| @@ -415,6 +414,7 @@ class ChainedAction(BaseAction, ActionListMixin): | ||||
|             kwargs=updated_kwargs, | ||||
|             action=self, | ||||
|             extra={"results": [], "rollback_stack": []}, | ||||
|             shared_context=shared_context, | ||||
|         ) | ||||
|         context.start_timer() | ||||
|         try: | ||||
| @@ -424,9 +424,9 @@ class ChainedAction(BaseAction, ActionListMixin): | ||||
|                 if action._skip_in_chain: | ||||
|                     logger.debug("[%s] ⚠️ Skipping consumed action '%s'", self.name, action.name) | ||||
|                     continue | ||||
|                 results_context.current_index = index | ||||
|                 prepared = action.prepare_for_chain(results_context) | ||||
|                 last_result = results_context.last_result() | ||||
|                 shared_context.current_index = index | ||||
|                 prepared = action.prepare_for_chain(shared_context) | ||||
|                 last_result = shared_context.last_result() | ||||
|                 try: | ||||
|                     if self.requires_io_injection() and last_result is not None: | ||||
|                         result = await prepared(**{prepared.inject_last_result_as: last_result}) | ||||
| @@ -436,14 +436,14 @@ class ChainedAction(BaseAction, ActionListMixin): | ||||
|                     if index + 1 < len(self.actions) and isinstance(self.actions[index + 1], FallbackAction): | ||||
|                         logger.warning("[%s] ⚠️ Fallback triggered: %s, recovering with fallback '%s'.", | ||||
|                                        self.name, error, self.actions[index + 1].name) | ||||
|                         results_context.add_result(None) | ||||
|                         shared_context.add_result(None) | ||||
|                         context.extra["results"].append(None) | ||||
|                         fallback = self.actions[index + 1].prepare_for_chain(results_context) | ||||
|                         fallback = self.actions[index + 1].prepare_for_chain(shared_context) | ||||
|                         result = await fallback() | ||||
|                         fallback._skip_in_chain = True | ||||
|                     else: | ||||
|                         raise | ||||
|                 results_context.add_result(result) | ||||
|                 shared_context.add_result(result) | ||||
|                 context.extra["results"].append(result) | ||||
|                 context.extra["rollback_stack"].append(prepared) | ||||
|  | ||||
| @@ -455,7 +455,7 @@ class ChainedAction(BaseAction, ActionListMixin): | ||||
|  | ||||
|         except Exception as error: | ||||
|             context.exception = error | ||||
|             results_context.errors.append((results_context.current_index, error)) | ||||
|             shared_context.add_error(shared_context.current_index, error) | ||||
|             await self._rollback(context.extra["rollback_stack"], *args, **kwargs) | ||||
|             await self.hooks.trigger(HookType.ON_ERROR, context) | ||||
|             raise | ||||
| @@ -495,7 +495,7 @@ class ChainedAction(BaseAction, ActionListMixin): | ||||
|         for action in self.actions: | ||||
|             await action.preview(parent=tree) | ||||
|         if not parent: | ||||
|             console.print(tree) | ||||
|             self.console.print(tree) | ||||
|  | ||||
|     def register_hooks_recursively(self, hook_type: HookType, hook: Hook): | ||||
|         """Register a hook for all actions and sub-actions.""" | ||||
| @@ -503,6 +503,9 @@ class ChainedAction(BaseAction, ActionListMixin): | ||||
|         for action in self.actions: | ||||
|             action.register_hooks_recursively(hook_type, hook) | ||||
|  | ||||
|     def __str__(self): | ||||
|         return f"ChainedAction(name={self.name}, actions={self.actions})" | ||||
|  | ||||
|  | ||||
| class ActionGroup(BaseAction, ActionListMixin): | ||||
|     """ | ||||
| @@ -550,9 +553,9 @@ class ActionGroup(BaseAction, ActionListMixin): | ||||
|             self.set_actions(actions) | ||||
|  | ||||
|     async def _run(self, *args, **kwargs) -> list[tuple[str, Any]]: | ||||
|         results_context = ResultsContext(name=self.name, is_parallel=True) | ||||
|         if self.results_context: | ||||
|             results_context.set_shared_result(self.results_context.last_result()) | ||||
|         shared_context = SharedContext(name=self.name, is_parallel=True) | ||||
|         if self.shared_context: | ||||
|             shared_context.set_shared_result(self.shared_context.last_result()) | ||||
|         updated_kwargs = self._maybe_inject_last_result(kwargs) | ||||
|         context = ExecutionContext( | ||||
|             name=self.name, | ||||
| @@ -560,15 +563,16 @@ class ActionGroup(BaseAction, ActionListMixin): | ||||
|             kwargs=updated_kwargs, | ||||
|             action=self, | ||||
|             extra={"results": [], "errors": []}, | ||||
|             shared_context=shared_context, | ||||
|         ) | ||||
|         async def run_one(action: BaseAction): | ||||
|             try: | ||||
|                 prepared = action.prepare_for_group(results_context) | ||||
|                 prepared = action.prepare_for_group(shared_context) | ||||
|                 result = await prepared(*args, **updated_kwargs) | ||||
|                 results_context.add_result((action.name, result)) | ||||
|                 shared_context.add_result((action.name, result)) | ||||
|                 context.extra["results"].append((action.name, result)) | ||||
|             except Exception as error: | ||||
|                 results_context.errors.append((results_context.current_index, error)) | ||||
|                 shared_context.add_error(shared_context.current_index, error) | ||||
|                 context.extra["errors"].append((action.name, error)) | ||||
|  | ||||
|         context.start_timer() | ||||
| @@ -606,7 +610,7 @@ class ActionGroup(BaseAction, ActionListMixin): | ||||
|         random.shuffle(actions) | ||||
|         await asyncio.gather(*(action.preview(parent=tree) for action in actions)) | ||||
|         if not parent: | ||||
|             console.print(tree) | ||||
|             self.console.print(tree) | ||||
|  | ||||
|     def register_hooks_recursively(self, hook_type: HookType, hook: Hook): | ||||
|         """Register a hook for all actions and sub-actions.""" | ||||
| @@ -614,6 +618,9 @@ class ActionGroup(BaseAction, ActionListMixin): | ||||
|         for action in self.actions: | ||||
|             action.register_hooks_recursively(hook_type, hook) | ||||
|  | ||||
|     def __str__(self): | ||||
|         return f"ActionGroup(name={self.name}, actions={self.actions})" | ||||
|  | ||||
|  | ||||
| class ProcessAction(BaseAction): | ||||
|     """ | ||||
| @@ -655,7 +662,7 @@ class ProcessAction(BaseAction): | ||||
|  | ||||
|     async def _run(self, *args, **kwargs): | ||||
|         if self.inject_last_result: | ||||
|             last_result = self.results_context.last_result() | ||||
|             last_result = self.shared_context.last_result() | ||||
|             if not self._validate_pickleable(last_result): | ||||
|                 raise ValueError( | ||||
|                     f"Cannot inject last result into {self.name}: " | ||||
| @@ -699,7 +706,7 @@ class ProcessAction(BaseAction): | ||||
|         if parent: | ||||
|             parent.add("".join(label)) | ||||
|         else: | ||||
|             console.print(Tree("".join(label))) | ||||
|             self.console.print(Tree("".join(label))) | ||||
|  | ||||
|     def _validate_pickleable(self, obj: Any) -> bool: | ||||
|         try: | ||||
| @@ -708,3 +715,4 @@ class ProcessAction(BaseAction): | ||||
|             return True | ||||
|         except (pickle.PicklingError, TypeError): | ||||
|             return False | ||||
|  | ||||
|   | ||||
| @@ -1,5 +1,7 @@ | ||||
| # Falyx CLI Framework — (c) 2025 rtj.dev LLC — MIT Licensed | ||||
| """context.py""" | ||||
| from __future__ import annotations | ||||
|  | ||||
| import time | ||||
| from datetime import datetime | ||||
| from typing import Any | ||||
| @@ -24,6 +26,8 @@ class ExecutionContext(BaseModel): | ||||
|     extra: dict[str, Any] = Field(default_factory=dict) | ||||
|     console: Console = Field(default_factory=lambda: Console(color_system="auto")) | ||||
|  | ||||
|     shared_context: SharedContext | None = None | ||||
|  | ||||
|     model_config = ConfigDict(arbitrary_types_allowed=True) | ||||
|  | ||||
|     def start_timer(self): | ||||
| @@ -34,6 +38,9 @@ class ExecutionContext(BaseModel): | ||||
|         self.end_time = time.perf_counter() | ||||
|         self.end_wall = datetime.now() | ||||
|  | ||||
|     def get_shared_context(self) -> SharedContext: | ||||
|         return self.shared_context or SharedContext(name="default") | ||||
|  | ||||
|     @property | ||||
|     def duration(self) -> float | None: | ||||
|         if self.start_time is None: | ||||
| @@ -104,7 +111,7 @@ class ExecutionContext(BaseModel): | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class ResultsContext(BaseModel): | ||||
| class SharedContext(BaseModel): | ||||
|     name: str | ||||
|     results: list[Any] = Field(default_factory=list) | ||||
|     errors: list[tuple[int, Exception]] = Field(default_factory=list) | ||||
| @@ -112,11 +119,16 @@ class ResultsContext(BaseModel): | ||||
|     is_parallel: bool = False | ||||
|     shared_result: Any | None = None | ||||
|  | ||||
|     share: dict[str, Any] = Field(default_factory=dict) | ||||
|  | ||||
|     model_config = ConfigDict(arbitrary_types_allowed=True) | ||||
|  | ||||
|     def add_result(self, result: Any) -> None: | ||||
|         self.results.append(result) | ||||
|  | ||||
|     def add_error(self, index: int, error: Exception) -> None: | ||||
|         self.errors.append((index, error)) | ||||
|  | ||||
|     def set_shared_result(self, result: Any) -> None: | ||||
|         self.shared_result = result | ||||
|         if self.is_parallel: | ||||
| @@ -127,10 +139,16 @@ class ResultsContext(BaseModel): | ||||
|             return self.shared_result | ||||
|         return self.results[-1] if self.results else None | ||||
|  | ||||
|     def get(self, key: str, default: Any = None) -> Any: | ||||
|         return self.share.get(key, default) | ||||
|  | ||||
|     def set(self, key: str, value: Any) -> None: | ||||
|         self.share[key] = value | ||||
|  | ||||
|     def __str__(self) -> str: | ||||
|         parallel_label = "Parallel" if self.is_parallel else "Sequential" | ||||
|         return ( | ||||
|             f"<{parallel_label}ResultsContext '{self.name}' | " | ||||
|             f"<{parallel_label}SharedContext '{self.name}' | " | ||||
|             f"Results: {self.results} | " | ||||
|             f"Errors: {self.errors}>" | ||||
|         ) | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| # Falyx CLI Framework — (c) 2025 rtj.dev LLC — MIT Licensed | ||||
| """hooks.py""" | ||||
| import time | ||||
| from typing import Callable | ||||
|  | ||||
| from falyx.context import ExecutionContext | ||||
| from falyx.exceptions import CircuitBreakerOpen | ||||
| @@ -8,8 +9,9 @@ from falyx.themes.colors import OneColors | ||||
| from falyx.utils import logger | ||||
|  | ||||
|  | ||||
|  | ||||
| class ResultReporter: | ||||
|     def __init__(self, formatter: callable = None): | ||||
|     def __init__(self, formatter: Callable[[], str] | None = None): | ||||
|         """ | ||||
|         Optional result formatter. If not provided, uses repr(result). | ||||
|         """ | ||||
|   | ||||
		Reference in New Issue
	
	Block a user