falyx/falyx/action.py

477 lines
18 KiB
Python

"""action.py
Any Action or Command is callable and supports the signature:
result = thing(*args, **kwargs)
This guarantees:
- Hook lifecycle (before/after/error/teardown)
- Timing
- Consistent return values
"""
from __future__ import annotations
import asyncio
import random
from abc import ABC, abstractmethod
from concurrent.futures import ProcessPoolExecutor
from functools import partial
from typing import Any, Callable
from rich.console import Console
from rich.tree import Tree
from falyx.context import ExecutionContext, ResultsContext
from falyx.debug import register_debug_hooks
from falyx.execution_registry import ExecutionRegistry as er
from falyx.hook_manager import Hook, HookManager, HookType
from falyx.retry import RetryHandler, RetryPolicy
from falyx.themes.colors import OneColors
from falyx.utils import ensure_async, logger
console = Console()
class BaseAction(ABC):
    """
    Base class for actions. Actions can be simple functions or more
    complex actions like `ChainedAction` or `ActionGroup`. They can also
    be run independently or as part of Menu.

    Args:
        name (str): Name of the action; used in log messages and `str()`.
        hooks (HookManager | None): Hook manager to use. A new `HookManager`
            is created when none is supplied.
        inject_last_result (bool): Whether to inject the previous action's
            result into kwargs.
        inject_last_result_as (str): The name of the kwarg key to inject the
            result as (default: 'last_result').
        logging_hooks (bool): When true, `register_debug_hooks` is applied to
            this action's hook manager.
    """

    def __init__(
        self,
        name: str,
        hooks: HookManager | None = None,
        inject_last_result: bool = False,
        inject_last_result_as: str = "last_result",
        logging_hooks: bool = False,
    ) -> None:
        self.name = name
        self.hooks = hooks or HookManager()
        # Subclasses that support retries flip this to True.
        self.is_retryable: bool = False
        # Set by a parent container through prepare_for_chain / prepare_for_group.
        self.results_context: ResultsContext | None = None
        self.inject_last_result: bool = inject_last_result
        self.inject_last_result_as: str = inject_last_result_as
        if logging_hooks:
            register_debug_hooks(self.hooks)

    async def __call__(self, *args, **kwargs) -> Any:
        """Execute the action by delegating to `_run`."""
        return await self._run(*args, **kwargs)

    @abstractmethod
    async def _run(self, *args, **kwargs) -> Any:
        """Perform the action's work. Must be implemented by subclasses."""
        raise NotImplementedError("_run must be implemented by subclasses")

    @abstractmethod
    async def preview(self, parent: Tree | None = None):
        """Render a preview of the action. Must be implemented by subclasses."""
        raise NotImplementedError("preview must be implemented by subclasses")

    def set_results_context(self, results_context: ResultsContext):
        """Attach the results context that `_maybe_inject_last_result` reads from."""
        self.results_context = results_context

    def prepare_for_chain(self, results_context: ResultsContext) -> BaseAction:
        """
        Prepare the action specifically for sequential (ChainedAction) execution.
        Can be overridden for chain-specific logic.
        """
        self.set_results_context(results_context)
        return self

    def prepare_for_group(self, results_context: ResultsContext) -> BaseAction:
        """
        Prepare the action specifically for parallel (ActionGroup) execution.
        Can be overridden for group-specific logic.
        """
        self.set_results_context(results_context)
        return self

    def _maybe_inject_last_result(self, kwargs: dict[str, Any]) -> dict[str, Any]:
        """
        Return kwargs with the last result added when injection is enabled.

        The input mapping is never mutated: when injection applies, a copy is
        made and the value of `results_context.last_result()` is stored under
        `inject_last_result_as`. Otherwise the original mapping is returned.
        """
        if self.inject_last_result and self.results_context:
            key = self.inject_last_result_as
            if key in kwargs:
                logger.warning("[%s] ⚠️ Overriding '%s' with last_result", self.name, key)
            kwargs = dict(kwargs)
            kwargs[key] = self.results_context.last_result()
        return kwargs

    def register_hooks_recursively(self, hook_type: HookType, hook: Hook):
        """Register a hook for all actions and sub-actions."""
        self.hooks.register(hook_type, hook)

    def __str__(self):
        return f"<{self.__class__.__name__} '{self.name}'>"

    def __repr__(self):
        return str(self)

    @classmethod
    def enable_retries_recursively(cls, action: BaseAction, policy: RetryPolicy | None):
        """
        Enable retries on `action` and on every action nested under it.

        A default `RetryPolicy(enabled=True)` is created when `policy` is
        falsy. The same policy object is shared by every `Action` visited.
        Any object exposing an `actions` attribute is traversed recursively.
        """
        if not policy:
            policy = RetryPolicy(enabled=True)
        if isinstance(action, Action):
            action.retry_policy = policy
            action.retry_policy.enabled = True
            # Use the HookType member (not the raw "on_error" string) so this
            # registration matches every other hook registration in the file.
            action.hooks.register(HookType.ON_ERROR, RetryHandler(policy).retry_on_error)
        if hasattr(action, "actions"):
            for sub in action.actions:
                cls.enable_retries_recursively(sub, policy)
class Action(BaseAction):
    """
    Leaf action wrapping a single callable.

    The callable is passed through `ensure_async` at construction time and is
    awaited on every run. Positional arguments given at call time are placed
    before the stored `args`; keyword arguments given at call time override
    the stored `kwargs`.
    """

    def __init__(
        self,
        name: str,
        action,
        rollback=None,
        args: tuple[Any, ...] = (),
        kwargs: dict[str, Any] | None = None,
        hooks: HookManager | None = None,
        inject_last_result: bool = False,
        inject_last_result_as: str = "last_result",
        retry: bool = False,
        retry_policy: RetryPolicy | None = None,
    ) -> None:
        super().__init__(name, hooks, inject_last_result, inject_last_result_as)
        self.action = ensure_async(action)
        self.rollback = rollback
        self.args = args
        self.kwargs = kwargs if kwargs else {}
        self.is_retryable = True
        self.retry_policy = retry_policy if retry_policy else RetryPolicy()
        # Retry is switched on either explicitly or by an already-enabled policy.
        if retry or (retry_policy and retry_policy.enabled):
            self.enable_retry()

    def enable_retry(self):
        """Mark the current policy enabled and hook a RetryHandler onto ON_ERROR."""
        self.retry_policy.enabled = True
        logger.debug(f"[Action:{self.name}] Registering retry handler")
        retry_handler = RetryHandler(self.retry_policy)
        self.hooks.register(HookType.ON_ERROR, retry_handler.retry_on_error)

    def set_retry_policy(self, policy: RetryPolicy):
        """Swap in `policy`, then enable retries with it."""
        self.retry_policy = policy
        self.enable_retry()

    async def _run(self, *args, **kwargs) -> Any:
        """
        Invoke the wrapped callable inside the hook lifecycle.

        BEFORE fires first; ON_SUCCESS after a normal return; ON_ERROR when
        the callable (or an earlier hook) raises. If an ON_ERROR hook leaves
        a non-None `context.result`, that value is returned instead of
        re-raising. AFTER and ON_TEARDOWN always fire, and the context is
        recorded in the execution registry.
        """
        call_args = args + self.args
        merged_kwargs = {**self.kwargs, **kwargs}
        call_kwargs = self._maybe_inject_last_result(merged_kwargs)
        context = ExecutionContext(
            name=self.name,
            args=call_args,
            kwargs=call_kwargs,
            action=self,
        )
        context.start_timer()
        try:
            await self.hooks.trigger(HookType.BEFORE, context)
            context.result = await self.action(*call_args, **call_kwargs)
            await self.hooks.trigger(HookType.ON_SUCCESS, context)
            return context.result
        except Exception as error:
            context.exception = error
            await self.hooks.trigger(HookType.ON_ERROR, context)
            if context.result is None:
                # No hook supplied a replacement result: propagate the failure.
                raise error
            logger.info("[%s] ✅ Recovered: %s", self.name, self.name)
            return context.result
        finally:
            context.stop_timer()
            await self.hooks.trigger(HookType.AFTER, context)
            await self.hooks.trigger(HookType.ON_TEARDOWN, context)
            er.record(context)

    async def preview(self, parent: Tree | None = None):
        """Add a one-node description to `parent`, or print it as its own tree."""
        pieces = [f"[{OneColors.GREEN_b}]⚙ Action[/] '{self.name}'"]
        if self.inject_last_result:
            pieces.append(f" [dim](injects '{self.inject_last_result_as}')[/dim]")
        if self.retry_policy.enabled:
            pieces.append(
                f"\n[dim]↻ Retries:[/] {self.retry_policy.max_retries}x, "
                f"delay {self.retry_policy.delay}s, backoff {self.retry_policy.backoff}x"
            )
        text = "".join(pieces)
        if not parent:
            console.print(Tree(text))
            return
        parent.add(text)
class ActionListMixin:
    """Mixin for managing a list of actions."""

    def __init__(self) -> None:
        self.actions: list[BaseAction] = []

    def set_actions(self, actions: list[BaseAction]) -> None:
        """Replaces the current action list with a new one.

        Every element is routed through `add_action`, so subclass overrides
        of `add_action` are honoured.
        """
        # Snapshot the input before clearing: if the caller passes
        # `self.actions` itself, clearing in place would also empty the input
        # and the list would silently end up empty.
        incoming = list(actions)
        self.actions.clear()
        for action in incoming:
            self.add_action(action)

    def add_action(self, action: BaseAction) -> None:
        """Adds an action to the list."""
        self.actions.append(action)

    def remove_action(self, name: str) -> None:
        """Removes every action whose name equals `name`."""
        self.actions = [action for action in self.actions if action.name != name]

    def has_action(self, name: str) -> bool:
        """Checks if an action with the given name exists."""
        return any(action.name == name for action in self.actions)

    def get_action(self, name: str) -> BaseAction | None:
        """Retrieves the first action with the given name, or None if absent."""
        return next((action for action in self.actions if action.name == name), None)
class ChainedAction(BaseAction, ActionListMixin):
    """A ChainedAction is a sequence of actions that are executed in order.

    Each run creates a fresh `ResultsContext` that every child is prepared
    with (`prepare_for_chain`), so children can read the previous child's
    result. If a child raises, the children that already completed are rolled
    back in reverse order and the exception is re-raised.
    """

    def __init__(
        self,
        name: str,
        actions: list[BaseAction] | None = None,
        hooks: HookManager | None = None,
        inject_last_result: bool = False,
        inject_last_result_as: str = "last_result",
    ) -> None:
        super().__init__(name, hooks, inject_last_result, inject_last_result_as)
        ActionListMixin.__init__(self)
        if actions:
            self.set_actions(actions)

    async def _run(self, *args, **kwargs) -> list[Any]:
        """
        Run every child in order and return the list of their results.

        Every child receives the same positional args and the (possibly
        injected) kwargs. On failure the error is recorded on the results
        context, completed children are rolled back, ON_ERROR fires and the
        exception propagates. AFTER and ON_TEARDOWN always fire and the
        context is recorded in the execution registry.
        """
        results_context = ResultsContext(name=self.name)
        if self.results_context:
            # Seed the chain with the enclosing context's last result so the
            # first child can read it.
            results_context.add_result(self.results_context.last_result())
        updated_kwargs = self._maybe_inject_last_result(kwargs)
        context = ExecutionContext(
            name=self.name,
            args=args,
            kwargs=updated_kwargs,
            action=self,
            extra={"results": [], "rollback_stack": []},
        )
        context.start_timer()
        try:
            await self.hooks.trigger(HookType.BEFORE, context)
            for index, action in enumerate(self.actions):
                results_context.current_index = index
                prepared = action.prepare_for_chain(results_context)
                result = await prepared(*args, **updated_kwargs)
                results_context.add_result(result)
                context.extra["results"].append(result)
                # Only children that completed are eligible for rollback.
                context.extra["rollback_stack"].append(prepared)
            context.result = context.extra["results"]
            await self.hooks.trigger(HookType.ON_SUCCESS, context)
            return context.result
        except Exception as error:
            context.exception = error
            results_context.errors.append((results_context.current_index, error))
            # Rollbacks receive the caller's original kwargs, not the injected ones.
            await self._rollback(context.extra["rollback_stack"], *args, **kwargs)
            await self.hooks.trigger(HookType.ON_ERROR, context)
            raise
        finally:
            context.stop_timer()
            await self.hooks.trigger(HookType.AFTER, context)
            await self.hooks.trigger(HookType.ON_TEARDOWN, context)
            er.record(context)

    async def _rollback(self, rollback_stack, *args, **kwargs):
        """
        Invoke the `rollback` callable of each completed action, newest first.

        Actions without a truthy `rollback` attribute are skipped. A rollback
        may be synchronous or asynchronous: its return value is awaited only
        when it is awaitable. A failing rollback is logged and does not stop
        the remaining rollbacks.
        """
        import inspect  # local import: not part of the module's import block

        for action in reversed(rollback_stack):
            rollback = getattr(action, "rollback", None)
            if not rollback:
                continue
            try:
                logger.warning("[%s] ↩️ Rolling back...", action.name)
                outcome = rollback(*args, **kwargs)
                # Awaiting unconditionally would turn a successful synchronous
                # rollback into a spurious "Rollback failed" TypeError.
                if inspect.isawaitable(outcome):
                    await outcome
            except Exception as error:
                logger.error("[%s]⚠️ Rollback failed: %s", action.name, error)

    async def preview(self, parent: Tree | None = None):
        """Render this chain and its children under `parent`, or print a new tree."""
        # The label must be a list: it is appended to below and joined afterwards.
        label = [f"[{OneColors.CYAN_b}]⛓ ChainedAction[/] '{self.name}'"]
        if self.inject_last_result:
            label.append(f" [dim](injects '{self.inject_last_result_as}')[/dim]")
        tree = parent.add("".join(label)) if parent else Tree("".join(label))
        for action in self.actions:
            await action.preview(parent=tree)
        if not parent:
            console.print(tree)

    def register_hooks_recursively(self, hook_type: HookType, hook: Hook):
        """Register a hook for all actions and sub-actions."""
        self.hooks.register(hook_type, hook)
        for action in self.actions:
            action.register_hooks_recursively(hook_type, hook)
class ActionGroup(BaseAction, ActionListMixin):
    """An ActionGroup is a collection of actions that can be run in parallel.

    All children are started together with `asyncio.gather`. A failure in one
    child does not cancel the others; failures are collected and reported as
    a single exception after every child has finished.
    """

    def __init__(
        self,
        name: str,
        actions: list[BaseAction] | None = None,
        hooks: HookManager | None = None,
        inject_last_result: bool = False,
        inject_last_result_as: str = "last_result",
    ):
        super().__init__(name, hooks, inject_last_result, inject_last_result_as)
        ActionListMixin.__init__(self)
        if actions:
            self.set_actions(actions)

    async def _run(self, *args, **kwargs) -> list[tuple[str, Any]]:
        """
        Run every child concurrently and return `(name, result)` pairs.

        Pairs are appended as each child finishes, so their order follows
        completion, not declaration. If any child raised, ON_ERROR fires and
        an `Exception` naming the failed children is raised. AFTER and
        ON_TEARDOWN always fire and the context is recorded in the execution
        registry.
        """
        results_context = ResultsContext(name=self.name, is_parallel=True)
        if self.results_context:
            results_context.set_shared_result(self.results_context.last_result())
        updated_kwargs = self._maybe_inject_last_result(kwargs)
        context = ExecutionContext(
            name=self.name,
            args=args,
            kwargs=updated_kwargs,
            action=self,
            extra={"results": [], "errors": []},
        )

        async def run_one(action: BaseAction):
            # Errors are captured rather than raised so one failing child
            # does not abort its siblings.
            try:
                prepared = action.prepare_for_group(results_context)
                result = await prepared(*args, **updated_kwargs)
                results_context.add_result((action.name, result))
                context.extra["results"].append((action.name, result))
            except Exception as error:
                results_context.errors.append((results_context.current_index, error))
                context.extra["errors"].append((action.name, error))

        context.start_timer()
        try:
            await self.hooks.trigger(HookType.BEFORE, context)
            await asyncio.gather(*[run_one(a) for a in self.actions])
            errors = context.extra["errors"]
            if errors:
                # Built outside the f-string: reusing the enclosing quote
                # character inside an f-string expression is a SyntaxError
                # before Python 3.12.
                failed_names = ", ".join(name for name, _ in errors)
                context.exception = Exception(
                    f"{len(errors)} action(s) failed: {failed_names}"
                )
                await self.hooks.trigger(HookType.ON_ERROR, context)
                raise context.exception
            context.result = context.extra["results"]
            await self.hooks.trigger(HookType.ON_SUCCESS, context)
            return context.result
        except Exception as error:
            context.exception = error
            raise
        finally:
            context.stop_timer()
            await self.hooks.trigger(HookType.AFTER, context)
            await self.hooks.trigger(HookType.ON_TEARDOWN, context)
            er.record(context)

    async def preview(self, parent: Tree | None = None):
        """Render this group and its children under `parent`, or print a new tree.

        Children are previewed concurrently from a shuffled copy of the
        action list, so the displayed order varies between calls.
        """
        label = [f"[{OneColors.MAGENTA_b}]⏩ ActionGroup (parallel)[/] '{self.name}'"]
        if self.inject_last_result:
            label.append(f" [dim](receives '{self.inject_last_result_as}')[/dim]")
        tree = parent.add("".join(label)) if parent else Tree("".join(label))
        actions = self.actions.copy()
        random.shuffle(actions)
        await asyncio.gather(*(action.preview(parent=tree) for action in actions))
        if not parent:
            console.print(tree)

    def register_hooks_recursively(self, hook_type: HookType, hook: Hook):
        """Register a hook for all actions and sub-actions."""
        super().register_hooks_recursively(hook_type, hook)
        for action in self.actions:
            action.register_hooks_recursively(hook_type, hook)
class ProcessAction(BaseAction):
    """A ProcessAction runs a function in a separate process using ProcessPoolExecutor.

    Args:
        name (str): Name of the action.
        func (Callable): Function submitted to the executor.
        args (tuple): Positional arguments appended after call-time arguments.
        kwargs (dict | None): Keyword arguments; call-time kwargs override them.
        hooks (HookManager | None): Hook manager passed to the base class.
        executor (ProcessPoolExecutor | None): Executor to use. A new
            `ProcessPoolExecutor` is created when none is supplied.
        inject_last_result (bool): Whether to inject the previous result.
        inject_last_result_as (str): Kwarg name used for the injected result.
    """

    def __init__(
        self,
        name: str,
        func: Callable[..., Any],
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
        hooks: HookManager | None = None,
        executor: ProcessPoolExecutor | None = None,
        inject_last_result: bool = False,
        inject_last_result_as: str = "last_result",
    ):
        super().__init__(name, hooks, inject_last_result, inject_last_result_as)
        self.func = func
        self.args = args
        self.kwargs = kwargs or {}
        self.executor = executor or ProcessPoolExecutor()
        self.is_retryable = True

    async def _run(self, *args, **kwargs):
        """
        Run `func` in the executor inside the hook lifecycle.

        Raises:
            ValueError: If a last result is about to be injected but cannot
                be pickled.

        If an ON_ERROR hook leaves a non-None `context.result`, that value is
        returned instead of re-raising. AFTER and ON_TEARDOWN always fire and
        the context is recorded in the execution registry.
        """
        # Same guard as _maybe_inject_last_result: without a results context
        # nothing is injected, so there is nothing to validate (and nothing
        # to dereference).
        if self.inject_last_result and self.results_context:
            last_result = self.results_context.last_result()
            if not self._validate_pickleable(last_result):
                raise ValueError(
                    f"Cannot inject last result into {self.name}: "
                    f"last result is not pickleable."
                )
        combined_args = args + self.args
        combined_kwargs = self._maybe_inject_last_result({**self.kwargs, **kwargs})
        context = ExecutionContext(
            name=self.name,
            args=combined_args,
            kwargs=combined_kwargs,
            action=self,
        )
        loop = asyncio.get_running_loop()
        context.start_timer()
        try:
            await self.hooks.trigger(HookType.BEFORE, context)
            # run_in_executor accepts positional arguments only, so the call
            # is bound with functools.partial.
            result = await loop.run_in_executor(
                self.executor, partial(self.func, *combined_args, **combined_kwargs)
            )
            context.result = result
            await self.hooks.trigger(HookType.ON_SUCCESS, context)
            return result
        except Exception as error:
            context.exception = error
            await self.hooks.trigger(HookType.ON_ERROR, context)
            if context.result is not None:
                return context.result
            raise
        finally:
            context.stop_timer()
            await self.hooks.trigger(HookType.AFTER, context)
            await self.hooks.trigger(HookType.ON_TEARDOWN, context)
            er.record(context)

    async def preview(self, parent: Tree | None = None):
        """Add a one-node description to `parent`, or print it as its own tree."""
        label = [f"[{OneColors.DARK_YELLOW_b}]🧠 ProcessAction (new process)[/] '{self.name}'"]
        if self.inject_last_result:
            label.append(f" [dim](injects '{self.inject_last_result_as}')[/dim]")
        if parent:
            parent.add("".join(label))
        else:
            console.print(Tree("".join(label)))

    def _validate_pickleable(self, obj: Any) -> bool:
        """Return True when `pickle.dumps(obj)` succeeds, False otherwise."""
        import pickle

        try:
            pickle.dumps(obj)
        except (pickle.PicklingError, TypeError, AttributeError):
            # AttributeError is what CPython raises for local/nested
            # functions and classes ("Can't pickle local object ...").
            return False
        return True