Compare commits

...

1 Commits
main ... pipes

Author SHA1 Message Date
Roland Thomas Jr 3fd27094d4
Experiemental feature pipes 2025-04-24 19:13:52 -04:00
1 changed files with 72 additions and 11 deletions

View File

@ -11,6 +11,7 @@ This guarantees:
from __future__ import annotations
import asyncio
import inspect
import random
from abc import ABC, abstractmethod
from concurrent.futures import ProcessPoolExecutor
@ -22,6 +23,7 @@ from rich.tree import Tree
from falyx.context import ExecutionContext, ResultsContext
from falyx.debug import register_debug_hooks
from falyx.exceptions import FalyxError
from falyx.execution_registry import ExecutionRegistry as er
from falyx.hook_manager import Hook, HookManager, HookType
from falyx.retry import RetryHandler, RetryPolicy
@ -102,12 +104,6 @@ class BaseAction(ABC):
"""Register a hook for all actions and sub-actions."""
self.hooks.register(hook_type, hook)
def __str__(self):
return f"<{self.__class__.__name__} '{self.name}'>"
def __repr__(self):
return str(self)
@classmethod
def enable_retries_recursively(cls, action: BaseAction, policy: RetryPolicy | None):
if not policy:
@ -115,12 +111,46 @@ class BaseAction(ABC):
if isinstance(action, Action):
action.retry_policy = policy
action.retry_policy.enabled = True
action.hooks.register("on_error", RetryHandler(policy).retry_on_error)
action.hooks.register(HookType.ON_ERROR, RetryHandler(policy).retry_on_error)
if hasattr(action, "actions"):
for sub in action.actions:
cls.enable_retries_recursively(sub, policy)
async def _write_stdout(self, data: str) -> None:
"""Override in subclasses that produce terminal output."""
pass
def __str__(self):
return f"<{self.__class__.__name__} '{self.name}'>"
def __repr__(self):
return str(self)
def __or__(self, other: BaseAction) -> ChainedAction:
"""Chain this action with another action."""
if not isinstance(other, BaseAction):
raise FalyxError(f"Cannot chain {type(other)} with {type(self)}")
return ChainedAction(name=f"{self.name} | {other.name}", actions=[self, other])
async def __ror__(self, other: Any):
if inspect.isawaitable(other):
print(1)
other = await other
if self.inject_last_result:
print(2)
return await self(**{self.inject_last_result_as: other})
literal_action = Action(
name=f"Input | {self.name}",
action=lambda: other,
)
chain = ChainedAction(name=f"{other} | {self.name}", actions=[literal_action, self])
print(3)
print(self.name, other)
return await chain()
class Action(BaseAction):
"""A simple action that runs a callable. It can be a function or a coroutine."""
@ -205,6 +235,12 @@ class Action(BaseAction):
console.print(Tree("".join(label)))
class LiteralInputAction(Action):
def __init__(self, value: Any):
async def literal(*args, **kwargs): return value
super().__init__("Input", literal, inject_last_result=True)
class ActionListMixin:
"""Mixin for managing a list of actions."""
def __init__(self) -> None:
@ -266,18 +302,43 @@ class ChainedAction(BaseAction, ActionListMixin):
context.start_timer()
try:
await self.hooks.trigger(HookType.BEFORE, context)
last_result = self.results_context.last_result() if self.results_context else None
for index, action in enumerate(self.actions):
results_context.current_index = index
prepared = action.prepare_for_chain(results_context)
result = await prepared(*args, **updated_kwargs)
run_kwargs = dict(updated_kwargs)
underlying = getattr(prepared, "action", None)
if underlying:
signature = inspect.signature(underlying)
else:
signature = inspect.signature(prepared._run)
parameters = signature.parameters
if last_result is not None:
if action.inject_last_result_as in parameters:
run_kwargs[action.inject_last_result_as] = last_result
result = await prepared(*args, **run_kwargs)
elif (
len(parameters) == 1 and
not parameters.get("self")
):
result = await prepared(last_result)
else:
result = await prepared(*args, **updated_kwargs)
else:
result = await prepared(*args, **updated_kwargs)
last_result = result
results_context.add_result(result)
context.extra["results"].append(result)
context.extra["rollback_stack"].append(prepared)
context.result = context.extra["results"]
context.result = last_result
await self.hooks.trigger(HookType.ON_SUCCESS, context)
return context.result
return last_result
except Exception as error:
context.exception = error
@ -302,7 +363,7 @@ class ChainedAction(BaseAction, ActionListMixin):
logger.error("[%s]⚠️ Rollback failed: %s", action.name, error)
async def preview(self, parent: Tree | None = None):
label = f"[{OneColors.CYAN_b}]⛓ ChainedAction[/] '{self.name}'"
label = [f"[{OneColors.CYAN_b}]⛓ ChainedAction[/] '{self.name}'"]
if self.inject_last_result:
label.append(f" [dim](injects '{self.inject_last_result_as}')[/dim]")
tree = parent.add("".join(label)) if parent else Tree("".join(label))