Experiemental feature pipes
This commit is contained in:
parent
6c72e22415
commit
3fd27094d4
|
@ -11,6 +11,7 @@ This guarantees:
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import inspect
|
||||||
import random
|
import random
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from concurrent.futures import ProcessPoolExecutor
|
from concurrent.futures import ProcessPoolExecutor
|
||||||
|
@ -22,6 +23,7 @@ from rich.tree import Tree
|
||||||
|
|
||||||
from falyx.context import ExecutionContext, ResultsContext
|
from falyx.context import ExecutionContext, ResultsContext
|
||||||
from falyx.debug import register_debug_hooks
|
from falyx.debug import register_debug_hooks
|
||||||
|
from falyx.exceptions import FalyxError
|
||||||
from falyx.execution_registry import ExecutionRegistry as er
|
from falyx.execution_registry import ExecutionRegistry as er
|
||||||
from falyx.hook_manager import Hook, HookManager, HookType
|
from falyx.hook_manager import Hook, HookManager, HookType
|
||||||
from falyx.retry import RetryHandler, RetryPolicy
|
from falyx.retry import RetryHandler, RetryPolicy
|
||||||
|
@ -102,12 +104,6 @@ class BaseAction(ABC):
|
||||||
"""Register a hook for all actions and sub-actions."""
|
"""Register a hook for all actions and sub-actions."""
|
||||||
self.hooks.register(hook_type, hook)
|
self.hooks.register(hook_type, hook)
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return f"<{self.__class__.__name__} '{self.name}'>"
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return str(self)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def enable_retries_recursively(cls, action: BaseAction, policy: RetryPolicy | None):
|
def enable_retries_recursively(cls, action: BaseAction, policy: RetryPolicy | None):
|
||||||
if not policy:
|
if not policy:
|
||||||
|
@ -115,12 +111,46 @@ class BaseAction(ABC):
|
||||||
if isinstance(action, Action):
|
if isinstance(action, Action):
|
||||||
action.retry_policy = policy
|
action.retry_policy = policy
|
||||||
action.retry_policy.enabled = True
|
action.retry_policy.enabled = True
|
||||||
action.hooks.register("on_error", RetryHandler(policy).retry_on_error)
|
action.hooks.register(HookType.ON_ERROR, RetryHandler(policy).retry_on_error)
|
||||||
|
|
||||||
if hasattr(action, "actions"):
|
if hasattr(action, "actions"):
|
||||||
for sub in action.actions:
|
for sub in action.actions:
|
||||||
cls.enable_retries_recursively(sub, policy)
|
cls.enable_retries_recursively(sub, policy)
|
||||||
|
|
||||||
|
async def _write_stdout(self, data: str) -> None:
|
||||||
|
"""Override in subclasses that produce terminal output."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"<{self.__class__.__name__} '{self.name}'>"
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return str(self)
|
||||||
|
|
||||||
|
def __or__(self, other: BaseAction) -> ChainedAction:
|
||||||
|
"""Chain this action with another action."""
|
||||||
|
if not isinstance(other, BaseAction):
|
||||||
|
raise FalyxError(f"Cannot chain {type(other)} with {type(self)}")
|
||||||
|
return ChainedAction(name=f"{self.name} | {other.name}", actions=[self, other])
|
||||||
|
|
||||||
|
async def __ror__(self, other: Any):
|
||||||
|
if inspect.isawaitable(other):
|
||||||
|
print(1)
|
||||||
|
other = await other
|
||||||
|
|
||||||
|
if self.inject_last_result:
|
||||||
|
print(2)
|
||||||
|
return await self(**{self.inject_last_result_as: other})
|
||||||
|
|
||||||
|
literal_action = Action(
|
||||||
|
name=f"Input | {self.name}",
|
||||||
|
action=lambda: other,
|
||||||
|
)
|
||||||
|
|
||||||
|
chain = ChainedAction(name=f"{other} | {self.name}", actions=[literal_action, self])
|
||||||
|
print(3)
|
||||||
|
print(self.name, other)
|
||||||
|
return await chain()
|
||||||
|
|
||||||
class Action(BaseAction):
|
class Action(BaseAction):
|
||||||
"""A simple action that runs a callable. It can be a function or a coroutine."""
|
"""A simple action that runs a callable. It can be a function or a coroutine."""
|
||||||
|
@ -205,6 +235,12 @@ class Action(BaseAction):
|
||||||
console.print(Tree("".join(label)))
|
console.print(Tree("".join(label)))
|
||||||
|
|
||||||
|
|
||||||
|
class LiteralInputAction(Action):
|
||||||
|
def __init__(self, value: Any):
|
||||||
|
async def literal(*args, **kwargs): return value
|
||||||
|
super().__init__("Input", literal, inject_last_result=True)
|
||||||
|
|
||||||
|
|
||||||
class ActionListMixin:
|
class ActionListMixin:
|
||||||
"""Mixin for managing a list of actions."""
|
"""Mixin for managing a list of actions."""
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
|
@ -266,18 +302,43 @@ class ChainedAction(BaseAction, ActionListMixin):
|
||||||
context.start_timer()
|
context.start_timer()
|
||||||
try:
|
try:
|
||||||
await self.hooks.trigger(HookType.BEFORE, context)
|
await self.hooks.trigger(HookType.BEFORE, context)
|
||||||
|
last_result = self.results_context.last_result() if self.results_context else None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
for index, action in enumerate(self.actions):
|
for index, action in enumerate(self.actions):
|
||||||
results_context.current_index = index
|
results_context.current_index = index
|
||||||
prepared = action.prepare_for_chain(results_context)
|
prepared = action.prepare_for_chain(results_context)
|
||||||
result = await prepared(*args, **updated_kwargs)
|
run_kwargs = dict(updated_kwargs)
|
||||||
|
|
||||||
|
underlying = getattr(prepared, "action", None)
|
||||||
|
if underlying:
|
||||||
|
signature = inspect.signature(underlying)
|
||||||
|
else:
|
||||||
|
signature = inspect.signature(prepared._run)
|
||||||
|
parameters = signature.parameters
|
||||||
|
|
||||||
|
if last_result is not None:
|
||||||
|
if action.inject_last_result_as in parameters:
|
||||||
|
run_kwargs[action.inject_last_result_as] = last_result
|
||||||
|
result = await prepared(*args, **run_kwargs)
|
||||||
|
elif (
|
||||||
|
len(parameters) == 1 and
|
||||||
|
not parameters.get("self")
|
||||||
|
):
|
||||||
|
result = await prepared(last_result)
|
||||||
|
else:
|
||||||
|
result = await prepared(*args, **updated_kwargs)
|
||||||
|
else:
|
||||||
|
result = await prepared(*args, **updated_kwargs)
|
||||||
|
last_result = result
|
||||||
results_context.add_result(result)
|
results_context.add_result(result)
|
||||||
context.extra["results"].append(result)
|
context.extra["results"].append(result)
|
||||||
context.extra["rollback_stack"].append(prepared)
|
context.extra["rollback_stack"].append(prepared)
|
||||||
|
|
||||||
context.result = context.extra["results"]
|
context.result = last_result
|
||||||
await self.hooks.trigger(HookType.ON_SUCCESS, context)
|
await self.hooks.trigger(HookType.ON_SUCCESS, context)
|
||||||
return context.result
|
return last_result
|
||||||
|
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
context.exception = error
|
context.exception = error
|
||||||
|
@ -302,7 +363,7 @@ class ChainedAction(BaseAction, ActionListMixin):
|
||||||
logger.error("[%s]⚠️ Rollback failed: %s", action.name, error)
|
logger.error("[%s]⚠️ Rollback failed: %s", action.name, error)
|
||||||
|
|
||||||
async def preview(self, parent: Tree | None = None):
|
async def preview(self, parent: Tree | None = None):
|
||||||
label = f"[{OneColors.CYAN_b}]⛓ ChainedAction[/] '{self.name}'"
|
label = [f"[{OneColors.CYAN_b}]⛓ ChainedAction[/] '{self.name}'"]
|
||||||
if self.inject_last_result:
|
if self.inject_last_result:
|
||||||
label.append(f" [dim](injects '{self.inject_last_result_as}')[/dim]")
|
label.append(f" [dim](injects '{self.inject_last_result_as}')[/dim]")
|
||||||
tree = parent.add("".join(label)) if parent else Tree("".join(label))
|
tree = parent.add("".join(label)) if parent else Tree("".join(label))
|
||||||
|
|
Loading…
Reference in New Issue