diff --git a/falyx/action.py b/falyx/action.py index 52e3253..4e4da26 100644 --- a/falyx/action.py +++ b/falyx/action.py @@ -11,6 +11,7 @@ This guarantees: from __future__ import annotations import asyncio +import inspect import random from abc import ABC, abstractmethod from concurrent.futures import ProcessPoolExecutor @@ -22,6 +23,7 @@ from rich.tree import Tree from falyx.context import ExecutionContext, ResultsContext from falyx.debug import register_debug_hooks +from falyx.exceptions import FalyxError from falyx.execution_registry import ExecutionRegistry as er from falyx.hook_manager import Hook, HookManager, HookType from falyx.retry import RetryHandler, RetryPolicy @@ -102,12 +104,6 @@ class BaseAction(ABC): """Register a hook for all actions and sub-actions.""" self.hooks.register(hook_type, hook) - def __str__(self): - return f"<{self.__class__.__name__} '{self.name}'>" - - def __repr__(self): - return str(self) - @classmethod def enable_retries_recursively(cls, action: BaseAction, policy: RetryPolicy | None): if not policy: @@ -115,12 +111,46 @@ class BaseAction(ABC): if isinstance(action, Action): action.retry_policy = policy action.retry_policy.enabled = True - action.hooks.register("on_error", RetryHandler(policy).retry_on_error) + action.hooks.register(HookType.ON_ERROR, RetryHandler(policy).retry_on_error) if hasattr(action, "actions"): for sub in action.actions: cls.enable_retries_recursively(sub, policy) + async def _write_stdout(self, data: str) -> None: + """Override in subclasses that produce terminal output.""" + pass + + def __str__(self): + return f"<{self.__class__.__name__} '{self.name}'>" + + def __repr__(self): + return str(self) + + def __or__(self, other: BaseAction) -> ChainedAction: + """Chain this action with another action.""" + if not isinstance(other, BaseAction): + raise FalyxError(f"Cannot chain {type(other)} with {type(self)}") + return ChainedAction(name=f"{self.name} | {other.name}", actions=[self, other]) + + async def __ror__(self, other: Any): + if inspect.isawaitable(other): + print(1) + other = await other + + if self.inject_last_result: + print(2) + return await self(**{self.inject_last_result_as: other}) + + literal_action = Action( + name=f"Input | {self.name}", + action=lambda: other, + ) + + chain = ChainedAction(name=f"{other} | {self.name}", actions=[literal_action, self]) + print(3) + print(self.name, other) + return await chain() class Action(BaseAction): """A simple action that runs a callable. It can be a function or a coroutine.""" @@ -205,6 +235,12 @@ class Action(BaseAction): console.print(Tree("".join(label))) +class LiteralInputAction(Action): + def __init__(self, value: Any): + async def literal(*args, **kwargs): return value + super().__init__("Input", literal, inject_last_result=True) + + class ActionListMixin: """Mixin for managing a list of actions.""" def __init__(self) -> None: @@ -266,18 +302,43 @@ class ChainedAction(BaseAction, ActionListMixin): context.start_timer() try: await self.hooks.trigger(HookType.BEFORE, context) + last_result = self.results_context.last_result() if self.results_context else None + + for index, action in enumerate(self.actions): results_context.current_index = index prepared = action.prepare_for_chain(results_context) - result = await prepared(*args, **updated_kwargs) + run_kwargs = dict(updated_kwargs) + + underlying = getattr(prepared, "action", None) + if underlying: + signature = inspect.signature(underlying) + else: + signature = inspect.signature(prepared._run) + parameters = signature.parameters + + if last_result is not None: + if action.inject_last_result_as in parameters: + run_kwargs[action.inject_last_result_as] = last_result + result = await prepared(*args, **run_kwargs) + elif ( + len(parameters) == 1 and + not parameters.get("self") + ): + result = await prepared(last_result) + else: + result = await prepared(*args, **updated_kwargs) + else: + result = await prepared(*args, **updated_kwargs) + last_result = result results_context.add_result(result) context.extra["results"].append(result) context.extra["rollback_stack"].append(prepared) - context.result = context.extra["results"] + context.result = last_result await self.hooks.trigger(HookType.ON_SUCCESS, context) - return context.result + return last_result except Exception as error: context.exception = error @@ -302,7 +363,7 @@ class ChainedAction(BaseAction, ActionListMixin): logger.error("[%s]⚠️ Rollback failed: %s", action.name, error) async def preview(self, parent: Tree | None = None): - label = f"[{OneColors.CYAN_b}]⛓ ChainedAction[/] '{self.name}'" + label = [f"[{OneColors.CYAN_b}]⛓ ChainedAction[/] '{self.name}'"] if self.inject_last_result: label.append(f" [dim](injects '{self.inject_last_result_as}')[/dim]") tree = parent.add("".join(label)) if parent else Tree("".join(label))