777 lines
30 KiB
Python
777 lines
30 KiB
Python
# Falyx CLI Framework — (c) 2025 rtj.dev LLC — MIT Licensed
|
|
"""action.py
|
|
|
|
Core action system for Falyx.
|
|
|
|
This module defines the building blocks for executable actions and workflows,
|
|
providing a structured way to compose, execute, recover, and manage sequences of operations.
|
|
|
|
All actions are callable and follow a unified signature:
|
|
result = action(*args, **kwargs)
|
|
|
|
Core guarantees:
|
|
- Full hook lifecycle support (before, on_success, on_error, after, on_teardown).
|
|
- Consistent timing and execution context tracking for each run.
|
|
- Unified, predictable result handling and error propagation.
|
|
- Optional last_result injection to enable flexible, data-driven workflows.
|
|
- Built-in support for retries, rollbacks, parallel groups, chaining, and fallback recovery.
|
|
|
|
Key components:
|
|
- Action: wraps a function or coroutine into a standard executable unit.
|
|
- ChainedAction: runs actions sequentially, optionally injecting last results.
|
|
- ActionGroup: runs actions in parallel and gathers results.
|
|
- ProcessAction: executes CPU-bound functions in a separate process.
|
|
- LiteralInputAction: injects static values into workflows.
|
|
- FallbackAction: gracefully recovers from failures or missing data.
|
|
|
|
This design promotes clean, fault-tolerant, modular CLI and automation systems.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import random
|
|
from abc import ABC, abstractmethod
|
|
from concurrent.futures import ProcessPoolExecutor
|
|
from functools import cached_property, partial
|
|
from typing import Any, Callable
|
|
|
|
from rich.console import Console
|
|
from rich.tree import Tree
|
|
|
|
from falyx.context import ExecutionContext, SharedContext
|
|
from falyx.debug import register_debug_hooks
|
|
from falyx.exceptions import EmptyChainError
|
|
from falyx.execution_registry import ExecutionRegistry as er
|
|
from falyx.hook_manager import Hook, HookManager, HookType
|
|
from falyx.options_manager import OptionsManager
|
|
from falyx.retry import RetryHandler, RetryPolicy
|
|
from falyx.themes.colors import OneColors
|
|
from falyx.utils import ensure_async, logger
|
|
|
|
|
|
class BaseAction(ABC):
|
|
"""
|
|
Base class for actions. Actions can be simple functions or more
|
|
complex actions like `ChainedAction` or `ActionGroup`. They can also
|
|
be run independently or as part of Falyx.
|
|
|
|
inject_last_result (bool): Whether to inject the previous action's result into kwargs.
|
|
inject_last_result_as (str): The name of the kwarg key to inject the result as
|
|
(default: 'last_result').
|
|
_requires_injection (bool): Whether the action requires input injection.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
name: str,
|
|
hooks: HookManager | None = None,
|
|
inject_last_result: bool = False,
|
|
inject_last_result_as: str = "last_result",
|
|
never_prompt: bool = False,
|
|
logging_hooks: bool = False,
|
|
) -> None:
|
|
self.name = name
|
|
self.hooks = hooks or HookManager()
|
|
self.is_retryable: bool = False
|
|
self.shared_context: SharedContext | None = None
|
|
self.inject_last_result: bool = inject_last_result
|
|
self.inject_last_result_as: str = inject_last_result_as
|
|
self._never_prompt: bool = never_prompt
|
|
self._requires_injection: bool = False
|
|
self._skip_in_chain: bool = False
|
|
self.console = Console(color_system="auto")
|
|
self.options_manager: OptionsManager | None = None
|
|
|
|
if logging_hooks:
|
|
register_debug_hooks(self.hooks)
|
|
|
|
async def __call__(self, *args, **kwargs) -> Any:
|
|
return await self._run(*args, **kwargs)
|
|
|
|
@abstractmethod
|
|
async def _run(self, *args, **kwargs) -> Any:
|
|
raise NotImplementedError("_run must be implemented by subclasses")
|
|
|
|
@abstractmethod
|
|
async def preview(self, parent: Tree | None = None):
|
|
raise NotImplementedError("preview must be implemented by subclasses")
|
|
|
|
def set_options_manager(self, options_manager: OptionsManager) -> None:
|
|
self.options_manager = options_manager
|
|
|
|
def set_shared_context(self, shared_context: SharedContext) -> None:
|
|
self.shared_context = shared_context
|
|
|
|
def get_option(self, option_name: str, default: Any = None) -> Any:
|
|
"""Resolve an option from the OptionsManager if present, otherwise use the fallback."""
|
|
if self.options_manager:
|
|
return self.options_manager.get(option_name, default)
|
|
return default
|
|
|
|
@property
|
|
def last_result(self) -> Any:
|
|
"""Return the last result from the shared context."""
|
|
if self.shared_context:
|
|
return self.shared_context.last_result()
|
|
return None
|
|
|
|
@property
|
|
def never_prompt(self) -> bool:
|
|
return self.get_option("never_prompt", self._never_prompt)
|
|
|
|
def prepare(
|
|
self, shared_context: SharedContext, options_manager: OptionsManager | None = None
|
|
) -> BaseAction:
|
|
"""
|
|
Prepare the action specifically for sequential (ChainedAction) execution.
|
|
Can be overridden for chain-specific logic.
|
|
"""
|
|
self.set_shared_context(shared_context)
|
|
if options_manager:
|
|
self.set_options_manager(options_manager)
|
|
return self
|
|
|
|
def _maybe_inject_last_result(self, kwargs: dict[str, Any]) -> dict[str, Any]:
|
|
if self.inject_last_result and self.shared_context:
|
|
key = self.inject_last_result_as
|
|
if key in kwargs:
|
|
logger.warning("[%s] ⚠️ Overriding '%s' with last_result", self.name, key)
|
|
kwargs = dict(kwargs)
|
|
kwargs[key] = self.shared_context.last_result()
|
|
return kwargs
|
|
|
|
def register_hooks_recursively(self, hook_type: HookType, hook: Hook):
|
|
"""Register a hook for all actions and sub-actions."""
|
|
self.hooks.register(hook_type, hook)
|
|
|
|
async def _write_stdout(self, data: str) -> None:
|
|
"""Override in subclasses that produce terminal output."""
|
|
|
|
def requires_io_injection(self) -> bool:
|
|
"""Checks to see if the action requires input injection."""
|
|
return self._requires_injection
|
|
|
|
def __repr__(self) -> str:
|
|
return str(self)
|
|
|
|
|
|
class Action(BaseAction):
|
|
"""
|
|
Action wraps a simple function or coroutine into a standard executable unit.
|
|
|
|
It supports:
|
|
- Optional retry logic.
|
|
- Hook lifecycle (before, success, error, after, teardown).
|
|
- Last result injection for chaining.
|
|
- Optional rollback handlers for undo logic.
|
|
|
|
Args:
|
|
name (str): Name of the action.
|
|
action (Callable): The function or coroutine to execute.
|
|
rollback (Callable, optional): Rollback function to undo the action.
|
|
args (tuple, optional): Static positional arguments.
|
|
kwargs (dict, optional): Static keyword arguments.
|
|
hooks (HookManager, optional): Hook manager for lifecycle events.
|
|
inject_last_result (bool, optional): Enable last_result injection.
|
|
inject_last_result_as (str, optional): Name of injected key.
|
|
retry (bool, optional): Enable retry logic.
|
|
retry_policy (RetryPolicy, optional): Retry settings.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
name: str,
|
|
action: Callable[..., Any],
|
|
rollback: Callable[..., Any] | None = None,
|
|
args: tuple[Any, ...] = (),
|
|
kwargs: dict[str, Any] | None = None,
|
|
hooks: HookManager | None = None,
|
|
inject_last_result: bool = False,
|
|
inject_last_result_as: str = "last_result",
|
|
retry: bool = False,
|
|
retry_policy: RetryPolicy | None = None,
|
|
) -> None:
|
|
super().__init__(name, hooks, inject_last_result, inject_last_result_as)
|
|
self.action = action
|
|
self.rollback = rollback
|
|
self.args = args
|
|
self.kwargs = kwargs or {}
|
|
self.is_retryable = True
|
|
self.retry_policy = retry_policy or RetryPolicy()
|
|
if retry or (retry_policy and retry_policy.enabled):
|
|
self.enable_retry()
|
|
|
|
@property
|
|
def action(self) -> Callable[..., Any]:
|
|
return self._action
|
|
|
|
@action.setter
|
|
def action(self, value: Callable[..., Any]):
|
|
self._action = ensure_async(value)
|
|
|
|
@property
|
|
def rollback(self) -> Callable[..., Any] | None:
|
|
return self._rollback
|
|
|
|
@rollback.setter
|
|
def rollback(self, value: Callable[..., Any] | None):
|
|
if value is None:
|
|
self._rollback = None
|
|
else:
|
|
self._rollback = ensure_async(value)
|
|
|
|
def enable_retry(self):
|
|
"""Enable retry with the existing retry policy."""
|
|
self.retry_policy.enable_policy()
|
|
logger.debug("[%s] Registering retry handler", self.name)
|
|
handler = RetryHandler(self.retry_policy)
|
|
self.hooks.register(HookType.ON_ERROR, handler.retry_on_error)
|
|
|
|
def set_retry_policy(self, policy: RetryPolicy):
|
|
"""Set a new retry policy and re-register the handler."""
|
|
self.retry_policy = policy
|
|
if policy.enabled:
|
|
self.enable_retry()
|
|
|
|
async def _run(self, *args, **kwargs) -> Any:
|
|
combined_args = args + self.args
|
|
combined_kwargs = self._maybe_inject_last_result({**self.kwargs, **kwargs})
|
|
|
|
context = ExecutionContext(
|
|
name=self.name,
|
|
args=combined_args,
|
|
kwargs=combined_kwargs,
|
|
action=self,
|
|
)
|
|
|
|
context.start_timer()
|
|
try:
|
|
await self.hooks.trigger(HookType.BEFORE, context)
|
|
result = await self.action(*combined_args, **combined_kwargs)
|
|
context.result = result
|
|
await self.hooks.trigger(HookType.ON_SUCCESS, context)
|
|
return context.result
|
|
except Exception as error:
|
|
context.exception = error
|
|
await self.hooks.trigger(HookType.ON_ERROR, context)
|
|
if context.result is not None:
|
|
logger.info("[%s] ✅ Recovered: %s", self.name, self.name)
|
|
return context.result
|
|
raise error
|
|
finally:
|
|
context.stop_timer()
|
|
await self.hooks.trigger(HookType.AFTER, context)
|
|
await self.hooks.trigger(HookType.ON_TEARDOWN, context)
|
|
er.record(context)
|
|
|
|
async def preview(self, parent: Tree | None = None):
|
|
label = [f"[{OneColors.GREEN_b}]⚙ Action[/] '{self.name}'"]
|
|
if self.inject_last_result:
|
|
label.append(f" [dim](injects '{self.inject_last_result_as}')[/dim]")
|
|
if self.retry_policy.enabled:
|
|
label.append(
|
|
f"\n[dim]↻ Retries:[/] {self.retry_policy.max_retries}x, "
|
|
f"delay {self.retry_policy.delay}s, backoff {self.retry_policy.backoff}x"
|
|
)
|
|
|
|
if parent:
|
|
parent.add("".join(label))
|
|
else:
|
|
self.console.print(Tree("".join(label)))
|
|
|
|
def __str__(self):
|
|
return (
|
|
f"Action(name={self.name!r}, action={getattr(self._action, '__name__', repr(self._action))}, "
|
|
f"args={self.args!r}, kwargs={self.kwargs!r}, retry={self.retry_policy.enabled})"
|
|
)
|
|
|
|
|
|
class LiteralInputAction(Action):
|
|
"""
|
|
LiteralInputAction injects a static value into a ChainedAction.
|
|
|
|
This allows embedding hardcoded values mid-pipeline, useful when:
|
|
- Providing default or fallback inputs.
|
|
- Starting a pipeline with a fixed input.
|
|
- Supplying missing context manually.
|
|
|
|
Args:
|
|
value (Any): The static value to inject.
|
|
"""
|
|
|
|
def __init__(self, value: Any):
|
|
self._value = value
|
|
|
|
async def literal(*args, **kwargs):
|
|
return value
|
|
|
|
super().__init__("Input", literal)
|
|
|
|
@cached_property
|
|
def value(self) -> Any:
|
|
"""Return the literal value."""
|
|
return self._value
|
|
|
|
def __str__(self) -> str:
|
|
return f"LiteralInputAction(value={self.value!r})"
|
|
|
|
|
|
class FallbackAction(Action):
|
|
"""
|
|
FallbackAction provides a default value if the previous action failed or returned None.
|
|
|
|
It injects the last result and checks:
|
|
- If last_result is not None, it passes it through unchanged.
|
|
- If last_result is None (e.g., due to failure), it replaces it with a fallback value.
|
|
|
|
Used in ChainedAction pipelines to gracefully recover from errors or missing data.
|
|
When activated, it consumes the preceding error and allows the chain to continue normally.
|
|
|
|
Args:
|
|
fallback (Any): The fallback value to use if last_result is None.
|
|
"""
|
|
|
|
def __init__(self, fallback: Any):
|
|
self._fallback = fallback
|
|
|
|
async def _fallback_logic(last_result):
|
|
return last_result if last_result is not None else fallback
|
|
|
|
super().__init__(name="Fallback", action=_fallback_logic, inject_last_result=True)
|
|
|
|
@cached_property
|
|
def fallback(self) -> Any:
|
|
"""Return the fallback value."""
|
|
return self._fallback
|
|
|
|
def __str__(self) -> str:
|
|
return f"FallbackAction(fallback={self.fallback!r})"
|
|
|
|
|
|
class ActionListMixin:
|
|
"""Mixin for managing a list of actions."""
|
|
|
|
def __init__(self) -> None:
|
|
self.actions: list[BaseAction] = []
|
|
|
|
def set_actions(self, actions: list[BaseAction]) -> None:
|
|
"""Replaces the current action list with a new one."""
|
|
self.actions.clear()
|
|
for action in actions:
|
|
self.add_action(action)
|
|
|
|
def add_action(self, action: BaseAction) -> None:
|
|
"""Adds an action to the list."""
|
|
self.actions.append(action)
|
|
|
|
def remove_action(self, name: str) -> None:
|
|
"""Removes an action by name."""
|
|
self.actions = [action for action in self.actions if action.name != name]
|
|
|
|
def has_action(self, name: str) -> bool:
|
|
"""Checks if an action with the given name exists."""
|
|
return any(action.name == name for action in self.actions)
|
|
|
|
def get_action(self, name: str) -> BaseAction | None:
|
|
"""Retrieves an action by name."""
|
|
for action in self.actions:
|
|
if action.name == name:
|
|
return action
|
|
return None
|
|
|
|
|
|
class ChainedAction(BaseAction, ActionListMixin):
|
|
"""
|
|
ChainedAction executes a sequence of actions one after another.
|
|
|
|
Features:
|
|
- Supports optional automatic last_result injection (auto_inject).
|
|
- Recovers from intermediate errors using FallbackAction if present.
|
|
- Rolls back all previously executed actions if a failure occurs.
|
|
- Handles literal values with LiteralInputAction.
|
|
|
|
Best used for defining robust, ordered workflows where each step can depend on previous results.
|
|
|
|
Args:
|
|
name (str): Name of the chain.
|
|
actions (list): List of actions or literals to execute.
|
|
hooks (HookManager, optional): Hooks for lifecycle events.
|
|
inject_last_result (bool, optional): Whether to inject last results into kwargs by default.
|
|
inject_last_result_as (str, optional): Key name for injection.
|
|
auto_inject (bool, optional): Auto-enable injection for subsequent actions.
|
|
return_list (bool, optional): Whether to return a list of all results. False returns the last result.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
name: str,
|
|
actions: list[BaseAction | Any] | None = None,
|
|
hooks: HookManager | None = None,
|
|
inject_last_result: bool = False,
|
|
inject_last_result_as: str = "last_result",
|
|
auto_inject: bool = False,
|
|
return_list: bool = False,
|
|
) -> None:
|
|
super().__init__(name, hooks, inject_last_result, inject_last_result_as)
|
|
ActionListMixin.__init__(self)
|
|
self.auto_inject = auto_inject
|
|
self.return_list = return_list
|
|
if actions:
|
|
self.set_actions(actions)
|
|
|
|
def _wrap_literal_if_needed(self, action: BaseAction | Any) -> BaseAction:
|
|
return (
|
|
LiteralInputAction(action) if not isinstance(action, BaseAction) else action
|
|
)
|
|
|
|
def add_action(self, action: BaseAction | Any) -> None:
|
|
action = self._wrap_literal_if_needed(action)
|
|
if self.actions and self.auto_inject and not action.inject_last_result:
|
|
action.inject_last_result = True
|
|
super().add_action(action)
|
|
|
|
async def _run(self, *args, **kwargs) -> list[Any]:
|
|
if not self.actions:
|
|
raise EmptyChainError(f"[{self.name}] No actions to execute.")
|
|
|
|
shared_context = SharedContext(name=self.name)
|
|
if self.shared_context:
|
|
shared_context.add_result(self.shared_context.last_result())
|
|
updated_kwargs = self._maybe_inject_last_result(kwargs)
|
|
context = ExecutionContext(
|
|
name=self.name,
|
|
args=args,
|
|
kwargs=updated_kwargs,
|
|
action=self,
|
|
extra={"results": [], "rollback_stack": []},
|
|
shared_context=shared_context,
|
|
)
|
|
context.start_timer()
|
|
try:
|
|
await self.hooks.trigger(HookType.BEFORE, context)
|
|
|
|
for index, action in enumerate(self.actions):
|
|
if action._skip_in_chain:
|
|
logger.debug(
|
|
"[%s] ⚠️ Skipping consumed action '%s'", self.name, action.name
|
|
)
|
|
continue
|
|
shared_context.current_index = index
|
|
prepared = action.prepare(shared_context, self.options_manager)
|
|
last_result = shared_context.last_result()
|
|
try:
|
|
if self.requires_io_injection() and last_result is not None:
|
|
result = await prepared(
|
|
**{prepared.inject_last_result_as: last_result}
|
|
)
|
|
else:
|
|
result = await prepared(*args, **updated_kwargs)
|
|
except Exception as error:
|
|
if index + 1 < len(self.actions) and isinstance(
|
|
self.actions[index + 1], FallbackAction
|
|
):
|
|
logger.warning(
|
|
"[%s] ⚠️ Fallback triggered: %s, recovering with fallback '%s'.",
|
|
self.name,
|
|
error,
|
|
self.actions[index + 1].name,
|
|
)
|
|
shared_context.add_result(None)
|
|
context.extra["results"].append(None)
|
|
fallback = self.actions[index + 1].prepare(shared_context)
|
|
result = await fallback()
|
|
fallback._skip_in_chain = True
|
|
else:
|
|
raise
|
|
shared_context.add_result(result)
|
|
context.extra["results"].append(result)
|
|
context.extra["rollback_stack"].append(prepared)
|
|
|
|
all_results = context.extra["results"]
|
|
assert (
|
|
all_results
|
|
), f"[{self.name}] No results captured. Something seriously went wrong."
|
|
context.result = all_results if self.return_list else all_results[-1]
|
|
await self.hooks.trigger(HookType.ON_SUCCESS, context)
|
|
return context.result
|
|
|
|
except Exception as error:
|
|
context.exception = error
|
|
shared_context.add_error(shared_context.current_index, error)
|
|
await self._rollback(context.extra["rollback_stack"], *args, **kwargs)
|
|
await self.hooks.trigger(HookType.ON_ERROR, context)
|
|
raise
|
|
finally:
|
|
context.stop_timer()
|
|
await self.hooks.trigger(HookType.AFTER, context)
|
|
await self.hooks.trigger(HookType.ON_TEARDOWN, context)
|
|
er.record(context)
|
|
|
|
async def _rollback(self, rollback_stack, *args, **kwargs):
|
|
"""
|
|
Roll back all executed actions in reverse order.
|
|
|
|
Rollbacks run even if a fallback recovered from failure,
|
|
ensuring consistent undo of all side effects.
|
|
|
|
Actions without rollback handlers are skipped.
|
|
|
|
Args:
|
|
rollback_stack (list): Actions to roll back.
|
|
*args, **kwargs: Passed to rollback handlers.
|
|
"""
|
|
for action in reversed(rollback_stack):
|
|
rollback = getattr(action, "rollback", None)
|
|
if rollback:
|
|
try:
|
|
logger.warning("[%s] ↩️ Rolling back...", action.name)
|
|
await action.rollback(*args, **kwargs)
|
|
except Exception as error:
|
|
logger.error("[%s] ⚠️ Rollback failed: %s", action.name, error)
|
|
|
|
async def preview(self, parent: Tree | None = None):
|
|
label = [f"[{OneColors.CYAN_b}]⛓ ChainedAction[/] '{self.name}'"]
|
|
if self.inject_last_result:
|
|
label.append(f" [dim](injects '{self.inject_last_result_as}')[/dim]")
|
|
tree = parent.add("".join(label)) if parent else Tree("".join(label))
|
|
for action in self.actions:
|
|
await action.preview(parent=tree)
|
|
if not parent:
|
|
self.console.print(tree)
|
|
|
|
def register_hooks_recursively(self, hook_type: HookType, hook: Hook):
|
|
"""Register a hook for all actions and sub-actions."""
|
|
self.hooks.register(hook_type, hook)
|
|
for action in self.actions:
|
|
action.register_hooks_recursively(hook_type, hook)
|
|
|
|
def __str__(self):
|
|
return (
|
|
f"ChainedAction(name={self.name!r}, actions={[a.name for a in self.actions]!r}, "
|
|
f"auto_inject={self.auto_inject}, return_list={self.return_list})"
|
|
)
|
|
|
|
|
|
class ActionGroup(BaseAction, ActionListMixin):
|
|
"""
|
|
ActionGroup executes multiple actions concurrently in parallel.
|
|
|
|
It is ideal for independent tasks that can be safely run simultaneously,
|
|
improving overall throughput and responsiveness of workflows.
|
|
|
|
Core features:
|
|
- Parallel execution of all contained actions.
|
|
- Shared last_result injection across all actions if configured.
|
|
- Aggregated collection of individual results as (name, result) pairs.
|
|
- Hook lifecycle support (before, on_success, on_error, after, on_teardown).
|
|
- Error aggregation: captures all action errors and reports them together.
|
|
|
|
Behavior:
|
|
- If any action fails, the group collects the errors but continues executing
|
|
other actions without interruption.
|
|
- After all actions complete, ActionGroup raises a single exception summarizing
|
|
all failures, or returns all results if successful.
|
|
|
|
Best used for:
|
|
- Batch processing multiple independent tasks.
|
|
- Reducing latency for workflows with parallelizable steps.
|
|
- Isolating errors while maximizing successful execution.
|
|
|
|
Args:
|
|
name (str): Name of the chain.
|
|
actions (list): List of actions or literals to execute.
|
|
hooks (HookManager, optional): Hooks for lifecycle events.
|
|
inject_last_result (bool, optional): Whether to inject last results into kwargs by default.
|
|
inject_last_result_as (str, optional): Key name for injection.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
name: str,
|
|
actions: list[BaseAction] | None = None,
|
|
hooks: HookManager | None = None,
|
|
inject_last_result: bool = False,
|
|
inject_last_result_as: str = "last_result",
|
|
):
|
|
super().__init__(name, hooks, inject_last_result, inject_last_result_as)
|
|
ActionListMixin.__init__(self)
|
|
if actions:
|
|
self.set_actions(actions)
|
|
|
|
async def _run(self, *args, **kwargs) -> list[tuple[str, Any]]:
|
|
shared_context = SharedContext(name=self.name, is_parallel=True)
|
|
if self.shared_context:
|
|
shared_context.set_shared_result(self.shared_context.last_result())
|
|
updated_kwargs = self._maybe_inject_last_result(kwargs)
|
|
context = ExecutionContext(
|
|
name=self.name,
|
|
args=args,
|
|
kwargs=updated_kwargs,
|
|
action=self,
|
|
extra={"results": [], "errors": []},
|
|
shared_context=shared_context,
|
|
)
|
|
|
|
async def run_one(action: BaseAction):
|
|
try:
|
|
prepared = action.prepare(shared_context, self.options_manager)
|
|
result = await prepared(*args, **updated_kwargs)
|
|
shared_context.add_result((action.name, result))
|
|
context.extra["results"].append((action.name, result))
|
|
except Exception as error:
|
|
shared_context.add_error(shared_context.current_index, error)
|
|
context.extra["errors"].append((action.name, error))
|
|
|
|
context.start_timer()
|
|
try:
|
|
await self.hooks.trigger(HookType.BEFORE, context)
|
|
await asyncio.gather(*[run_one(a) for a in self.actions])
|
|
|
|
if context.extra["errors"]:
|
|
context.exception = Exception(
|
|
f"{len(context.extra['errors'])} action(s) failed: "
|
|
f"{' ,'.join(name for name, _ in context.extra["errors"])}"
|
|
)
|
|
await self.hooks.trigger(HookType.ON_ERROR, context)
|
|
raise context.exception
|
|
|
|
context.result = context.extra["results"]
|
|
await self.hooks.trigger(HookType.ON_SUCCESS, context)
|
|
return context.result
|
|
|
|
except Exception as error:
|
|
context.exception = error
|
|
raise
|
|
finally:
|
|
context.stop_timer()
|
|
await self.hooks.trigger(HookType.AFTER, context)
|
|
await self.hooks.trigger(HookType.ON_TEARDOWN, context)
|
|
er.record(context)
|
|
|
|
async def preview(self, parent: Tree | None = None):
|
|
label = [f"[{OneColors.MAGENTA_b}]⏩ ActionGroup (parallel)[/] '{self.name}'"]
|
|
if self.inject_last_result:
|
|
label.append(f" [dim](receives '{self.inject_last_result_as}')[/dim]")
|
|
tree = parent.add("".join(label)) if parent else Tree("".join(label))
|
|
actions = self.actions.copy()
|
|
random.shuffle(actions)
|
|
await asyncio.gather(*(action.preview(parent=tree) for action in actions))
|
|
if not parent:
|
|
self.console.print(tree)
|
|
|
|
def register_hooks_recursively(self, hook_type: HookType, hook: Hook):
|
|
"""Register a hook for all actions and sub-actions."""
|
|
super().register_hooks_recursively(hook_type, hook)
|
|
for action in self.actions:
|
|
action.register_hooks_recursively(hook_type, hook)
|
|
|
|
def __str__(self):
|
|
return (
|
|
f"ActionGroup(name={self.name!r}, actions={[a.name for a in self.actions]!r}, "
|
|
f"inject_last_result={self.inject_last_result})"
|
|
)
|
|
|
|
|
|
class ProcessAction(BaseAction):
|
|
"""
|
|
ProcessAction runs a function in a separate process using ProcessPoolExecutor.
|
|
|
|
Features:
|
|
- Executes CPU-bound or blocking tasks without blocking the main event loop.
|
|
- Supports last_result injection into the subprocess.
|
|
- Validates that last_result is pickleable when injection is enabled.
|
|
|
|
Args:
|
|
name (str): Name of the action.
|
|
func (Callable): Function to execute in a new process.
|
|
args (tuple, optional): Positional arguments.
|
|
kwargs (dict, optional): Keyword arguments.
|
|
hooks (HookManager, optional): Hook manager for lifecycle events.
|
|
executor (ProcessPoolExecutor, optional): Custom executor if desired.
|
|
inject_last_result (bool, optional): Inject last result into the function.
|
|
inject_last_result_as (str, optional): Name of the injected key.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
name: str,
|
|
func: Callable[..., Any],
|
|
args: tuple = (),
|
|
kwargs: dict[str, Any] | None = None,
|
|
hooks: HookManager | None = None,
|
|
executor: ProcessPoolExecutor | None = None,
|
|
inject_last_result: bool = False,
|
|
inject_last_result_as: str = "last_result",
|
|
):
|
|
super().__init__(name, hooks, inject_last_result, inject_last_result_as)
|
|
self.func = func
|
|
self.args = args
|
|
self.kwargs = kwargs or {}
|
|
self.executor = executor or ProcessPoolExecutor()
|
|
self.is_retryable = True
|
|
|
|
async def _run(self, *args, **kwargs):
|
|
if self.inject_last_result:
|
|
last_result = self.shared_context.last_result()
|
|
if not self._validate_pickleable(last_result):
|
|
raise ValueError(
|
|
f"Cannot inject last result into {self.name}: "
|
|
f"last result is not pickleable."
|
|
)
|
|
combined_args = args + self.args
|
|
combined_kwargs = self._maybe_inject_last_result({**self.kwargs, **kwargs})
|
|
context = ExecutionContext(
|
|
name=self.name,
|
|
args=combined_args,
|
|
kwargs=combined_kwargs,
|
|
action=self,
|
|
)
|
|
loop = asyncio.get_running_loop()
|
|
|
|
context.start_timer()
|
|
try:
|
|
await self.hooks.trigger(HookType.BEFORE, context)
|
|
result = await loop.run_in_executor(
|
|
self.executor, partial(self.func, *combined_args, **combined_kwargs)
|
|
)
|
|
context.result = result
|
|
await self.hooks.trigger(HookType.ON_SUCCESS, context)
|
|
return result
|
|
except Exception as error:
|
|
context.exception = error
|
|
await self.hooks.trigger(HookType.ON_ERROR, context)
|
|
if context.result is not None:
|
|
return context.result
|
|
raise
|
|
finally:
|
|
context.stop_timer()
|
|
await self.hooks.trigger(HookType.AFTER, context)
|
|
await self.hooks.trigger(HookType.ON_TEARDOWN, context)
|
|
er.record(context)
|
|
|
|
async def preview(self, parent: Tree | None = None):
|
|
label = [
|
|
f"[{OneColors.DARK_YELLOW_b}]🧠 ProcessAction (new process)[/] '{self.name}'"
|
|
]
|
|
if self.inject_last_result:
|
|
label.append(f" [dim](injects '{self.inject_last_result_as}')[/dim]")
|
|
if parent:
|
|
parent.add("".join(label))
|
|
else:
|
|
self.console.print(Tree("".join(label)))
|
|
|
|
def _validate_pickleable(self, obj: Any) -> bool:
|
|
try:
|
|
import pickle
|
|
|
|
pickle.dumps(obj)
|
|
return True
|
|
except (pickle.PicklingError, TypeError):
|
|
return False
|
|
|
|
def __str__(self) -> str:
|
|
return (
|
|
f"ProcessAction(name={self.name!r}, func={getattr(self.func, '__name__', repr(self.func))}, "
|
|
f"args={self.args!r}, kwargs={self.kwargs!r})"
|
|
)
|