Compare commits
1 Commits
Author | SHA1 | Date |
---|---|---|
|
3fd27094d4 |
|
@ -11,6 +11,7 @@ This guarantees:
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import random
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
|
@ -22,6 +23,7 @@ from rich.tree import Tree
|
|||
|
||||
from falyx.context import ExecutionContext, ResultsContext
|
||||
from falyx.debug import register_debug_hooks
|
||||
from falyx.exceptions import FalyxError
|
||||
from falyx.execution_registry import ExecutionRegistry as er
|
||||
from falyx.hook_manager import Hook, HookManager, HookType
|
||||
from falyx.retry import RetryHandler, RetryPolicy
|
||||
|
@ -102,12 +104,6 @@ class BaseAction(ABC):
|
|||
"""Register a hook for all actions and sub-actions."""
|
||||
self.hooks.register(hook_type, hook)
|
||||
|
||||
def __str__(self):
|
||||
return f"<{self.__class__.__name__} '{self.name}'>"
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
||||
@classmethod
|
||||
def enable_retries_recursively(cls, action: BaseAction, policy: RetryPolicy | None):
|
||||
if not policy:
|
||||
|
@ -115,12 +111,46 @@ class BaseAction(ABC):
|
|||
if isinstance(action, Action):
|
||||
action.retry_policy = policy
|
||||
action.retry_policy.enabled = True
|
||||
action.hooks.register("on_error", RetryHandler(policy).retry_on_error)
|
||||
action.hooks.register(HookType.ON_ERROR, RetryHandler(policy).retry_on_error)
|
||||
|
||||
if hasattr(action, "actions"):
|
||||
for sub in action.actions:
|
||||
cls.enable_retries_recursively(sub, policy)
|
||||
|
||||
async def _write_stdout(self, data: str) -> None:
|
||||
"""Override in subclasses that produce terminal output."""
|
||||
pass
|
||||
|
||||
def __str__(self):
|
||||
return f"<{self.__class__.__name__} '{self.name}'>"
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
||||
def __or__(self, other: BaseAction) -> ChainedAction:
|
||||
"""Chain this action with another action."""
|
||||
if not isinstance(other, BaseAction):
|
||||
raise FalyxError(f"Cannot chain {type(other)} with {type(self)}")
|
||||
return ChainedAction(name=f"{self.name} | {other.name}", actions=[self, other])
|
||||
|
||||
async def __ror__(self, other: Any):
|
||||
if inspect.isawaitable(other):
|
||||
print(1)
|
||||
other = await other
|
||||
|
||||
if self.inject_last_result:
|
||||
print(2)
|
||||
return await self(**{self.inject_last_result_as: other})
|
||||
|
||||
literal_action = Action(
|
||||
name=f"Input | {self.name}",
|
||||
action=lambda: other,
|
||||
)
|
||||
|
||||
chain = ChainedAction(name=f"{other} | {self.name}", actions=[literal_action, self])
|
||||
print(3)
|
||||
print(self.name, other)
|
||||
return await chain()
|
||||
|
||||
class Action(BaseAction):
|
||||
"""A simple action that runs a callable. It can be a function or a coroutine."""
|
||||
|
@ -205,6 +235,12 @@ class Action(BaseAction):
|
|||
console.print(Tree("".join(label)))
|
||||
|
||||
|
||||
class LiteralInputAction(Action):
|
||||
def __init__(self, value: Any):
|
||||
async def literal(*args, **kwargs): return value
|
||||
super().__init__("Input", literal, inject_last_result=True)
|
||||
|
||||
|
||||
class ActionListMixin:
|
||||
"""Mixin for managing a list of actions."""
|
||||
def __init__(self) -> None:
|
||||
|
@ -266,18 +302,43 @@ class ChainedAction(BaseAction, ActionListMixin):
|
|||
context.start_timer()
|
||||
try:
|
||||
await self.hooks.trigger(HookType.BEFORE, context)
|
||||
last_result = self.results_context.last_result() if self.results_context else None
|
||||
|
||||
|
||||
|
||||
for index, action in enumerate(self.actions):
|
||||
results_context.current_index = index
|
||||
prepared = action.prepare_for_chain(results_context)
|
||||
result = await prepared(*args, **updated_kwargs)
|
||||
run_kwargs = dict(updated_kwargs)
|
||||
|
||||
underlying = getattr(prepared, "action", None)
|
||||
if underlying:
|
||||
signature = inspect.signature(underlying)
|
||||
else:
|
||||
signature = inspect.signature(prepared._run)
|
||||
parameters = signature.parameters
|
||||
|
||||
if last_result is not None:
|
||||
if action.inject_last_result_as in parameters:
|
||||
run_kwargs[action.inject_last_result_as] = last_result
|
||||
result = await prepared(*args, **run_kwargs)
|
||||
elif (
|
||||
len(parameters) == 1 and
|
||||
not parameters.get("self")
|
||||
):
|
||||
result = await prepared(last_result)
|
||||
else:
|
||||
result = await prepared(*args, **updated_kwargs)
|
||||
else:
|
||||
result = await prepared(*args, **updated_kwargs)
|
||||
last_result = result
|
||||
results_context.add_result(result)
|
||||
context.extra["results"].append(result)
|
||||
context.extra["rollback_stack"].append(prepared)
|
||||
|
||||
context.result = context.extra["results"]
|
||||
context.result = last_result
|
||||
await self.hooks.trigger(HookType.ON_SUCCESS, context)
|
||||
return context.result
|
||||
return last_result
|
||||
|
||||
except Exception as error:
|
||||
context.exception = error
|
||||
|
@ -302,7 +363,7 @@ class ChainedAction(BaseAction, ActionListMixin):
|
|||
logger.error("[%s]⚠️ Rollback failed: %s", action.name, error)
|
||||
|
||||
async def preview(self, parent: Tree | None = None):
|
||||
label = f"[{OneColors.CYAN_b}]⛓ ChainedAction[/] '{self.name}'"
|
||||
label = [f"[{OneColors.CYAN_b}]⛓ ChainedAction[/] '{self.name}'"]
|
||||
if self.inject_last_result:
|
||||
label.append(f" [dim](injects '{self.inject_last_result_as}')[/dim]")
|
||||
tree = parent.add("".join(label)) if parent else Tree("".join(label))
|
||||
|
|
Loading…
Reference in New Issue