Rename ResultsContext -> SharedContext

This commit is contained in:
2025-04-30 21:45:11 -04:00
parent 80de941335
commit bc1637143c
6 changed files with 69 additions and 40 deletions

View File

@ -2,6 +2,7 @@ import asyncio
from falyx import Action, ActionGroup, ChainedAction from falyx import Action, ActionGroup, ChainedAction
# Actions can be defined as synchronous functions # Actions can be defined as synchronous functions
# Falyx will automatically convert them to async functions # Falyx will automatically convert them to async functions
def hello() -> None: def hello() -> None:

View File

@ -1,7 +1,7 @@
import asyncio import asyncio
import random import random
from falyx import Falyx, Action, ChainedAction from falyx import Action, ChainedAction, Falyx
from falyx.utils import setup_logging from falyx.utils import setup_logging
setup_logging() setup_logging()

View File

@ -8,7 +8,7 @@ import logging
from .action import Action, ActionGroup, ChainedAction, ProcessAction from .action import Action, ActionGroup, ChainedAction, ProcessAction
from .command import Command from .command import Command
from .context import ExecutionContext, ResultsContext from .context import ExecutionContext, SharedContext
from .execution_registry import ExecutionRegistry from .execution_registry import ExecutionRegistry
from .falyx import Falyx from .falyx import Falyx
@ -24,6 +24,6 @@ __all__ = [
"Falyx", "Falyx",
"Command", "Command",
"ExecutionContext", "ExecutionContext",
"ResultsContext", "SharedContext",
"ExecutionRegistry", "ExecutionRegistry",
] ]

View File

@ -38,7 +38,7 @@ from typing import Any, Callable
from rich.console import Console from rich.console import Console
from rich.tree import Tree from rich.tree import Tree
from falyx.context import ExecutionContext, ResultsContext from falyx.context import ExecutionContext, SharedContext
from falyx.debug import register_debug_hooks from falyx.debug import register_debug_hooks
from falyx.exceptions import EmptyChainError from falyx.exceptions import EmptyChainError
from falyx.execution_registry import ExecutionRegistry as er from falyx.execution_registry import ExecutionRegistry as er
@ -47,8 +47,6 @@ from falyx.retry import RetryHandler, RetryPolicy
from falyx.themes.colors import OneColors from falyx.themes.colors import OneColors
from falyx.utils import ensure_async, logger from falyx.utils import ensure_async, logger
console = Console()
class BaseAction(ABC): class BaseAction(ABC):
""" """
@ -72,11 +70,12 @@ class BaseAction(ABC):
self.name = name self.name = name
self.hooks = hooks or HookManager() self.hooks = hooks or HookManager()
self.is_retryable: bool = False self.is_retryable: bool = False
self.results_context: ResultsContext | None = None self.shared_context: SharedContext | None = None
self.inject_last_result: bool = inject_last_result self.inject_last_result: bool = inject_last_result
self.inject_last_result_as: str = inject_last_result_as self.inject_last_result_as: str = inject_last_result_as
self._requires_injection: bool = False self._requires_injection: bool = False
self._skip_in_chain: bool = False self._skip_in_chain: bool = False
self.console = Console(color_system="auto")
if logging_hooks: if logging_hooks:
register_debug_hooks(self.hooks) register_debug_hooks(self.hooks)
@ -92,32 +91,32 @@ class BaseAction(ABC):
async def preview(self, parent: Tree | None = None): async def preview(self, parent: Tree | None = None):
raise NotImplementedError("preview must be implemented by subclasses") raise NotImplementedError("preview must be implemented by subclasses")
def set_results_context(self, results_context: ResultsContext): def set_shared_context(self, shared_context: SharedContext):
self.results_context = results_context self.shared_context = shared_context
def prepare_for_chain(self, results_context: ResultsContext) -> BaseAction: def prepare_for_chain(self, shared_context: SharedContext) -> BaseAction:
""" """
Prepare the action specifically for sequential (ChainedAction) execution. Prepare the action specifically for sequential (ChainedAction) execution.
Can be overridden for chain-specific logic. Can be overridden for chain-specific logic.
""" """
self.set_results_context(results_context) self.set_shared_context(shared_context)
return self return self
def prepare_for_group(self, results_context: ResultsContext) -> BaseAction: def prepare_for_group(self, shared_context: SharedContext) -> BaseAction:
""" """
Prepare the action specifically for parallel (ActionGroup) execution. Prepare the action specifically for parallel (ActionGroup) execution.
Can be overridden for group-specific logic. Can be overridden for group-specific logic.
""" """
self.set_results_context(results_context) self.set_shared_context(shared_context)
return self return self
def _maybe_inject_last_result(self, kwargs: dict[str, Any]) -> dict[str, Any]: def _maybe_inject_last_result(self, kwargs: dict[str, Any]) -> dict[str, Any]:
if self.inject_last_result and self.results_context: if self.inject_last_result and self.shared_context:
key = self.inject_last_result_as key = self.inject_last_result_as
if key in kwargs: if key in kwargs:
logger.warning("[%s] ⚠️ Overriding '%s' with last_result", self.name, key) logger.warning("[%s] ⚠️ Overriding '%s' with last_result", self.name, key)
kwargs = dict(kwargs) kwargs = dict(kwargs)
kwargs[key] = self.results_context.last_result() kwargs[key] = self.shared_context.last_result()
return kwargs return kwargs
def register_hooks_recursively(self, hook_type: HookType, hook: Hook): def register_hooks_recursively(self, hook_type: HookType, hook: Hook):
@ -146,7 +145,7 @@ class BaseAction(ABC):
return self._requires_injection return self._requires_injection
def __str__(self): def __str__(self):
return f"<{self.__class__.__name__} '{self.name}'>" return f"{self.__class__.__name__}('{self.name}')"
def __repr__(self): def __repr__(self):
return str(self) return str(self)
@ -261,7 +260,7 @@ class Action(BaseAction):
if parent: if parent:
parent.add("".join(label)) parent.add("".join(label))
else: else:
console.print(Tree("".join(label))) self.console.print(Tree("".join(label)))
def __str__(self): def __str__(self):
return f"Action(name={self.name}, action={self.action.__name__})" return f"Action(name={self.name}, action={self.action.__name__})"
@ -405,9 +404,9 @@ class ChainedAction(BaseAction, ActionListMixin):
if not self.actions: if not self.actions:
raise EmptyChainError(f"[{self.name}] No actions to execute.") raise EmptyChainError(f"[{self.name}] No actions to execute.")
results_context = ResultsContext(name=self.name) shared_context = SharedContext(name=self.name)
if self.results_context: if self.shared_context:
results_context.add_result(self.results_context.last_result()) shared_context.add_result(self.shared_context.last_result())
updated_kwargs = self._maybe_inject_last_result(kwargs) updated_kwargs = self._maybe_inject_last_result(kwargs)
context = ExecutionContext( context = ExecutionContext(
name=self.name, name=self.name,
@ -415,6 +414,7 @@ class ChainedAction(BaseAction, ActionListMixin):
kwargs=updated_kwargs, kwargs=updated_kwargs,
action=self, action=self,
extra={"results": [], "rollback_stack": []}, extra={"results": [], "rollback_stack": []},
shared_context=shared_context,
) )
context.start_timer() context.start_timer()
try: try:
@ -424,9 +424,9 @@ class ChainedAction(BaseAction, ActionListMixin):
if action._skip_in_chain: if action._skip_in_chain:
logger.debug("[%s] ⚠️ Skipping consumed action '%s'", self.name, action.name) logger.debug("[%s] ⚠️ Skipping consumed action '%s'", self.name, action.name)
continue continue
results_context.current_index = index shared_context.current_index = index
prepared = action.prepare_for_chain(results_context) prepared = action.prepare_for_chain(shared_context)
last_result = results_context.last_result() last_result = shared_context.last_result()
try: try:
if self.requires_io_injection() and last_result is not None: if self.requires_io_injection() and last_result is not None:
result = await prepared(**{prepared.inject_last_result_as: last_result}) result = await prepared(**{prepared.inject_last_result_as: last_result})
@ -436,14 +436,14 @@ class ChainedAction(BaseAction, ActionListMixin):
if index + 1 < len(self.actions) and isinstance(self.actions[index + 1], FallbackAction): if index + 1 < len(self.actions) and isinstance(self.actions[index + 1], FallbackAction):
logger.warning("[%s] ⚠️ Fallback triggered: %s, recovering with fallback '%s'.", logger.warning("[%s] ⚠️ Fallback triggered: %s, recovering with fallback '%s'.",
self.name, error, self.actions[index + 1].name) self.name, error, self.actions[index + 1].name)
results_context.add_result(None) shared_context.add_result(None)
context.extra["results"].append(None) context.extra["results"].append(None)
fallback = self.actions[index + 1].prepare_for_chain(results_context) fallback = self.actions[index + 1].prepare_for_chain(shared_context)
result = await fallback() result = await fallback()
fallback._skip_in_chain = True fallback._skip_in_chain = True
else: else:
raise raise
results_context.add_result(result) shared_context.add_result(result)
context.extra["results"].append(result) context.extra["results"].append(result)
context.extra["rollback_stack"].append(prepared) context.extra["rollback_stack"].append(prepared)
@ -455,7 +455,7 @@ class ChainedAction(BaseAction, ActionListMixin):
except Exception as error: except Exception as error:
context.exception = error context.exception = error
results_context.errors.append((results_context.current_index, error)) shared_context.add_error(shared_context.current_index, error)
await self._rollback(context.extra["rollback_stack"], *args, **kwargs) await self._rollback(context.extra["rollback_stack"], *args, **kwargs)
await self.hooks.trigger(HookType.ON_ERROR, context) await self.hooks.trigger(HookType.ON_ERROR, context)
raise raise
@ -495,7 +495,7 @@ class ChainedAction(BaseAction, ActionListMixin):
for action in self.actions: for action in self.actions:
await action.preview(parent=tree) await action.preview(parent=tree)
if not parent: if not parent:
console.print(tree) self.console.print(tree)
def register_hooks_recursively(self, hook_type: HookType, hook: Hook): def register_hooks_recursively(self, hook_type: HookType, hook: Hook):
"""Register a hook for all actions and sub-actions.""" """Register a hook for all actions and sub-actions."""
@ -503,6 +503,9 @@ class ChainedAction(BaseAction, ActionListMixin):
for action in self.actions: for action in self.actions:
action.register_hooks_recursively(hook_type, hook) action.register_hooks_recursively(hook_type, hook)
def __str__(self):
return f"ChainedAction(name={self.name}, actions={self.actions})"
class ActionGroup(BaseAction, ActionListMixin): class ActionGroup(BaseAction, ActionListMixin):
""" """
@ -550,9 +553,9 @@ class ActionGroup(BaseAction, ActionListMixin):
self.set_actions(actions) self.set_actions(actions)
async def _run(self, *args, **kwargs) -> list[tuple[str, Any]]: async def _run(self, *args, **kwargs) -> list[tuple[str, Any]]:
results_context = ResultsContext(name=self.name, is_parallel=True) shared_context = SharedContext(name=self.name, is_parallel=True)
if self.results_context: if self.shared_context:
results_context.set_shared_result(self.results_context.last_result()) shared_context.set_shared_result(self.shared_context.last_result())
updated_kwargs = self._maybe_inject_last_result(kwargs) updated_kwargs = self._maybe_inject_last_result(kwargs)
context = ExecutionContext( context = ExecutionContext(
name=self.name, name=self.name,
@ -560,15 +563,16 @@ class ActionGroup(BaseAction, ActionListMixin):
kwargs=updated_kwargs, kwargs=updated_kwargs,
action=self, action=self,
extra={"results": [], "errors": []}, extra={"results": [], "errors": []},
shared_context=shared_context,
) )
async def run_one(action: BaseAction): async def run_one(action: BaseAction):
try: try:
prepared = action.prepare_for_group(results_context) prepared = action.prepare_for_group(shared_context)
result = await prepared(*args, **updated_kwargs) result = await prepared(*args, **updated_kwargs)
results_context.add_result((action.name, result)) shared_context.add_result((action.name, result))
context.extra["results"].append((action.name, result)) context.extra["results"].append((action.name, result))
except Exception as error: except Exception as error:
results_context.errors.append((results_context.current_index, error)) shared_context.add_error(shared_context.current_index, error)
context.extra["errors"].append((action.name, error)) context.extra["errors"].append((action.name, error))
context.start_timer() context.start_timer()
@ -606,7 +610,7 @@ class ActionGroup(BaseAction, ActionListMixin):
random.shuffle(actions) random.shuffle(actions)
await asyncio.gather(*(action.preview(parent=tree) for action in actions)) await asyncio.gather(*(action.preview(parent=tree) for action in actions))
if not parent: if not parent:
console.print(tree) self.console.print(tree)
def register_hooks_recursively(self, hook_type: HookType, hook: Hook): def register_hooks_recursively(self, hook_type: HookType, hook: Hook):
"""Register a hook for all actions and sub-actions.""" """Register a hook for all actions and sub-actions."""
@ -614,6 +618,9 @@ class ActionGroup(BaseAction, ActionListMixin):
for action in self.actions: for action in self.actions:
action.register_hooks_recursively(hook_type, hook) action.register_hooks_recursively(hook_type, hook)
def __str__(self):
return f"ActionGroup(name={self.name}, actions={self.actions})"
class ProcessAction(BaseAction): class ProcessAction(BaseAction):
""" """
@ -655,7 +662,7 @@ class ProcessAction(BaseAction):
async def _run(self, *args, **kwargs): async def _run(self, *args, **kwargs):
if self.inject_last_result: if self.inject_last_result:
last_result = self.results_context.last_result() last_result = self.shared_context.last_result()
if not self._validate_pickleable(last_result): if not self._validate_pickleable(last_result):
raise ValueError( raise ValueError(
f"Cannot inject last result into {self.name}: " f"Cannot inject last result into {self.name}: "
@ -699,7 +706,7 @@ class ProcessAction(BaseAction):
if parent: if parent:
parent.add("".join(label)) parent.add("".join(label))
else: else:
console.print(Tree("".join(label))) self.console.print(Tree("".join(label)))
def _validate_pickleable(self, obj: Any) -> bool: def _validate_pickleable(self, obj: Any) -> bool:
try: try:
@ -708,3 +715,4 @@ class ProcessAction(BaseAction):
return True return True
except (pickle.PicklingError, TypeError): except (pickle.PicklingError, TypeError):
return False return False

View File

@ -1,5 +1,7 @@
# Falyx CLI Framework — (c) 2025 rtj.dev LLC — MIT Licensed # Falyx CLI Framework — (c) 2025 rtj.dev LLC — MIT Licensed
"""context.py""" """context.py"""
from __future__ import annotations
import time import time
from datetime import datetime from datetime import datetime
from typing import Any from typing import Any
@ -24,6 +26,8 @@ class ExecutionContext(BaseModel):
extra: dict[str, Any] = Field(default_factory=dict) extra: dict[str, Any] = Field(default_factory=dict)
console: Console = Field(default_factory=lambda: Console(color_system="auto")) console: Console = Field(default_factory=lambda: Console(color_system="auto"))
shared_context: SharedContext | None = None
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
def start_timer(self): def start_timer(self):
@ -34,6 +38,9 @@ class ExecutionContext(BaseModel):
self.end_time = time.perf_counter() self.end_time = time.perf_counter()
self.end_wall = datetime.now() self.end_wall = datetime.now()
def get_shared_context(self) -> SharedContext:
return self.shared_context or SharedContext(name="default")
@property @property
def duration(self) -> float | None: def duration(self) -> float | None:
if self.start_time is None: if self.start_time is None:
@ -104,7 +111,7 @@ class ExecutionContext(BaseModel):
) )
class ResultsContext(BaseModel): class SharedContext(BaseModel):
name: str name: str
results: list[Any] = Field(default_factory=list) results: list[Any] = Field(default_factory=list)
errors: list[tuple[int, Exception]] = Field(default_factory=list) errors: list[tuple[int, Exception]] = Field(default_factory=list)
@ -112,11 +119,16 @@ class ResultsContext(BaseModel):
is_parallel: bool = False is_parallel: bool = False
shared_result: Any | None = None shared_result: Any | None = None
share: dict[str, Any] = Field(default_factory=dict)
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
def add_result(self, result: Any) -> None: def add_result(self, result: Any) -> None:
self.results.append(result) self.results.append(result)
def add_error(self, index: int, error: Exception) -> None:
self.errors.append((index, error))
def set_shared_result(self, result: Any) -> None: def set_shared_result(self, result: Any) -> None:
self.shared_result = result self.shared_result = result
if self.is_parallel: if self.is_parallel:
@ -127,10 +139,16 @@ class ResultsContext(BaseModel):
return self.shared_result return self.shared_result
return self.results[-1] if self.results else None return self.results[-1] if self.results else None
def get(self, key: str, default: Any = None) -> Any:
return self.share.get(key, default)
def set(self, key: str, value: Any) -> None:
self.share[key] = value
def __str__(self) -> str: def __str__(self) -> str:
parallel_label = "Parallel" if self.is_parallel else "Sequential" parallel_label = "Parallel" if self.is_parallel else "Sequential"
return ( return (
f"<{parallel_label}ResultsContext '{self.name}' | " f"<{parallel_label}SharedContext '{self.name}' | "
f"Results: {self.results} | " f"Results: {self.results} | "
f"Errors: {self.errors}>" f"Errors: {self.errors}>"
) )

View File

@ -1,6 +1,7 @@
# Falyx CLI Framework — (c) 2025 rtj.dev LLC — MIT Licensed # Falyx CLI Framework — (c) 2025 rtj.dev LLC — MIT Licensed
"""hooks.py""" """hooks.py"""
import time import time
from typing import Callable
from falyx.context import ExecutionContext from falyx.context import ExecutionContext
from falyx.exceptions import CircuitBreakerOpen from falyx.exceptions import CircuitBreakerOpen
@ -8,8 +9,9 @@ from falyx.themes.colors import OneColors
from falyx.utils import logger from falyx.utils import logger
class ResultReporter: class ResultReporter:
def __init__(self, formatter: callable = None): def __init__(self, formatter: Callable[[], str] | None = None):
""" """
Optional result formatter. If not provided, uses repr(result). Optional result formatter. If not provided, uses repr(result).
""" """