Compare commits

...

1 Commits
main ... pipes

Author SHA1 Message Date
Roland Thomas Jr 3fd27094d4
Experiemental feature pipes 2025-04-24 19:13:52 -04:00
1 changed files with 72 additions and 11 deletions

View File

@ -11,6 +11,7 @@ This guarantees:
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import inspect
import random import random
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from concurrent.futures import ProcessPoolExecutor from concurrent.futures import ProcessPoolExecutor
@ -22,6 +23,7 @@ from rich.tree import Tree
from falyx.context import ExecutionContext, ResultsContext from falyx.context import ExecutionContext, ResultsContext
from falyx.debug import register_debug_hooks from falyx.debug import register_debug_hooks
from falyx.exceptions import FalyxError
from falyx.execution_registry import ExecutionRegistry as er from falyx.execution_registry import ExecutionRegistry as er
from falyx.hook_manager import Hook, HookManager, HookType from falyx.hook_manager import Hook, HookManager, HookType
from falyx.retry import RetryHandler, RetryPolicy from falyx.retry import RetryHandler, RetryPolicy
@ -102,12 +104,6 @@ class BaseAction(ABC):
"""Register a hook for all actions and sub-actions.""" """Register a hook for all actions and sub-actions."""
self.hooks.register(hook_type, hook) self.hooks.register(hook_type, hook)
def __str__(self):
return f"<{self.__class__.__name__} '{self.name}'>"
def __repr__(self):
return str(self)
@classmethod @classmethod
def enable_retries_recursively(cls, action: BaseAction, policy: RetryPolicy | None): def enable_retries_recursively(cls, action: BaseAction, policy: RetryPolicy | None):
if not policy: if not policy:
@ -115,12 +111,46 @@ class BaseAction(ABC):
if isinstance(action, Action): if isinstance(action, Action):
action.retry_policy = policy action.retry_policy = policy
action.retry_policy.enabled = True action.retry_policy.enabled = True
action.hooks.register("on_error", RetryHandler(policy).retry_on_error) action.hooks.register(HookType.ON_ERROR, RetryHandler(policy).retry_on_error)
if hasattr(action, "actions"): if hasattr(action, "actions"):
for sub in action.actions: for sub in action.actions:
cls.enable_retries_recursively(sub, policy) cls.enable_retries_recursively(sub, policy)
async def _write_stdout(self, data: str) -> None:
"""Override in subclasses that produce terminal output."""
pass
def __str__(self):
return f"<{self.__class__.__name__} '{self.name}'>"
def __repr__(self):
return str(self)
def __or__(self, other: BaseAction) -> ChainedAction:
"""Chain this action with another action."""
if not isinstance(other, BaseAction):
raise FalyxError(f"Cannot chain {type(other)} with {type(self)}")
return ChainedAction(name=f"{self.name} | {other.name}", actions=[self, other])
async def __ror__(self, other: Any):
if inspect.isawaitable(other):
print(1)
other = await other
if self.inject_last_result:
print(2)
return await self(**{self.inject_last_result_as: other})
literal_action = Action(
name=f"Input | {self.name}",
action=lambda: other,
)
chain = ChainedAction(name=f"{other} | {self.name}", actions=[literal_action, self])
print(3)
print(self.name, other)
return await chain()
class Action(BaseAction): class Action(BaseAction):
"""A simple action that runs a callable. It can be a function or a coroutine.""" """A simple action that runs a callable. It can be a function or a coroutine."""
@ -205,6 +235,12 @@ class Action(BaseAction):
console.print(Tree("".join(label))) console.print(Tree("".join(label)))
class LiteralInputAction(Action):
def __init__(self, value: Any):
async def literal(*args, **kwargs): return value
super().__init__("Input", literal, inject_last_result=True)
class ActionListMixin: class ActionListMixin:
"""Mixin for managing a list of actions.""" """Mixin for managing a list of actions."""
def __init__(self) -> None: def __init__(self) -> None:
@ -266,18 +302,43 @@ class ChainedAction(BaseAction, ActionListMixin):
context.start_timer() context.start_timer()
try: try:
await self.hooks.trigger(HookType.BEFORE, context) await self.hooks.trigger(HookType.BEFORE, context)
last_result = self.results_context.last_result() if self.results_context else None
for index, action in enumerate(self.actions): for index, action in enumerate(self.actions):
results_context.current_index = index results_context.current_index = index
prepared = action.prepare_for_chain(results_context) prepared = action.prepare_for_chain(results_context)
run_kwargs = dict(updated_kwargs)
underlying = getattr(prepared, "action", None)
if underlying:
signature = inspect.signature(underlying)
else:
signature = inspect.signature(prepared._run)
parameters = signature.parameters
if last_result is not None:
if action.inject_last_result_as in parameters:
run_kwargs[action.inject_last_result_as] = last_result
result = await prepared(*args, **run_kwargs)
elif (
len(parameters) == 1 and
not parameters.get("self")
):
result = await prepared(last_result)
else:
result = await prepared(*args, **updated_kwargs) result = await prepared(*args, **updated_kwargs)
else:
result = await prepared(*args, **updated_kwargs)
last_result = result
results_context.add_result(result) results_context.add_result(result)
context.extra["results"].append(result) context.extra["results"].append(result)
context.extra["rollback_stack"].append(prepared) context.extra["rollback_stack"].append(prepared)
context.result = context.extra["results"] context.result = last_result
await self.hooks.trigger(HookType.ON_SUCCESS, context) await self.hooks.trigger(HookType.ON_SUCCESS, context)
return context.result return last_result
except Exception as error: except Exception as error:
context.exception = error context.exception = error
@ -302,7 +363,7 @@ class ChainedAction(BaseAction, ActionListMixin):
logger.error("[%s]⚠️ Rollback failed: %s", action.name, error) logger.error("[%s]⚠️ Rollback failed: %s", action.name, error)
async def preview(self, parent: Tree | None = None): async def preview(self, parent: Tree | None = None):
label = f"[{OneColors.CYAN_b}]⛓ ChainedAction[/] '{self.name}'" label = [f"[{OneColors.CYAN_b}]⛓ ChainedAction[/] '{self.name}'"]
if self.inject_last_result: if self.inject_last_result:
label.append(f" [dim](injects '{self.inject_last_result_as}')[/dim]") label.append(f" [dim](injects '{self.inject_last_result_as}')[/dim]")
tree = parent.add("".join(label)) if parent else Tree("".join(label)) tree = parent.add("".join(label)) if parent else Tree("".join(label))