Compare commits
	
		
			1 Commits
		
	
	
		
			3b2c33d28f
			...
			pipes
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 3fd27094d4 | 
| @@ -11,6 +11,7 @@ This guarantees: | ||||
| from __future__ import annotations | ||||
|  | ||||
| import asyncio | ||||
| import inspect | ||||
| import random | ||||
| from abc import ABC, abstractmethod | ||||
| from concurrent.futures import ProcessPoolExecutor | ||||
| @@ -22,6 +23,7 @@ from rich.tree import Tree | ||||
|  | ||||
| from falyx.context import ExecutionContext, ResultsContext | ||||
| from falyx.debug import register_debug_hooks | ||||
| from falyx.exceptions import FalyxError | ||||
| from falyx.execution_registry import ExecutionRegistry as er | ||||
| from falyx.hook_manager import Hook, HookManager, HookType | ||||
| from falyx.retry import RetryHandler, RetryPolicy | ||||
| @@ -102,12 +104,6 @@ class BaseAction(ABC): | ||||
|         """Register a hook for all actions and sub-actions.""" | ||||
|         self.hooks.register(hook_type, hook) | ||||
|  | ||||
|     def __str__(self): | ||||
|         return f"<{self.__class__.__name__} '{self.name}'>" | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return str(self) | ||||
|  | ||||
|     @classmethod | ||||
|     def enable_retries_recursively(cls, action: BaseAction, policy: RetryPolicy | None): | ||||
|         if not policy: | ||||
| @@ -115,12 +111,46 @@ class BaseAction(ABC): | ||||
|         if isinstance(action, Action): | ||||
|             action.retry_policy = policy | ||||
|             action.retry_policy.enabled = True | ||||
|             action.hooks.register("on_error", RetryHandler(policy).retry_on_error) | ||||
|             action.hooks.register(HookType.ON_ERROR, RetryHandler(policy).retry_on_error) | ||||
|  | ||||
|         if hasattr(action, "actions"): | ||||
|             for sub in action.actions: | ||||
|                 cls.enable_retries_recursively(sub, policy) | ||||
|  | ||||
|     async def _write_stdout(self, data: str) -> None: | ||||
|         """Override in subclasses that produce terminal output.""" | ||||
|         pass | ||||
|  | ||||
|     def __str__(self): | ||||
|         return f"<{self.__class__.__name__} '{self.name}'>" | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return str(self) | ||||
|  | ||||
|     def __or__(self, other: BaseAction) -> ChainedAction: | ||||
|         """Chain this action with another action.""" | ||||
|         if not isinstance(other, BaseAction): | ||||
|             raise FalyxError(f"Cannot chain {type(other)} with {type(self)}") | ||||
|         return ChainedAction(name=f"{self.name} | {other.name}", actions=[self, other]) | ||||
|  | ||||
|     async def __ror__(self, other: Any): | ||||
|         if inspect.isawaitable(other): | ||||
|             print(1) | ||||
|             other = await other | ||||
|  | ||||
|         if self.inject_last_result: | ||||
|             print(2) | ||||
|             return await self(**{self.inject_last_result_as: other}) | ||||
|  | ||||
|         literal_action = Action( | ||||
|             name=f"Input | {self.name}", | ||||
|             action=lambda: other, | ||||
|         ) | ||||
|  | ||||
|         chain = ChainedAction(name=f"{other} | {self.name}", actions=[literal_action, self]) | ||||
|         print(3) | ||||
|         print(self.name, other) | ||||
|         return await chain() | ||||
|  | ||||
| class Action(BaseAction): | ||||
|     """A simple action that runs a callable. It can be a function or a coroutine.""" | ||||
| @@ -205,6 +235,12 @@ class Action(BaseAction): | ||||
|             console.print(Tree("".join(label))) | ||||
|  | ||||
|  | ||||
| class LiteralInputAction(Action): | ||||
|     def __init__(self, value: Any): | ||||
|         async def literal(*args, **kwargs): return value | ||||
|         super().__init__("Input", literal, inject_last_result=True) | ||||
|  | ||||
|  | ||||
| class ActionListMixin: | ||||
|     """Mixin for managing a list of actions.""" | ||||
|     def __init__(self) -> None: | ||||
| @@ -266,18 +302,43 @@ class ChainedAction(BaseAction, ActionListMixin): | ||||
|         context.start_timer() | ||||
|         try: | ||||
|             await self.hooks.trigger(HookType.BEFORE, context) | ||||
|             last_result = self.results_context.last_result() if self.results_context else None | ||||
|  | ||||
|  | ||||
|  | ||||
|             for index, action in enumerate(self.actions): | ||||
|                 results_context.current_index = index | ||||
|                 prepared = action.prepare_for_chain(results_context) | ||||
|                 run_kwargs = dict(updated_kwargs) | ||||
|  | ||||
|                 underlying = getattr(prepared, "action", None) | ||||
|                 if underlying: | ||||
|                     signature = inspect.signature(underlying) | ||||
|                 else: | ||||
|                     signature = inspect.signature(prepared._run) | ||||
|                 parameters = signature.parameters | ||||
|  | ||||
|                 if last_result is not None: | ||||
|                     if action.inject_last_result_as in parameters: | ||||
|                         run_kwargs[action.inject_last_result_as] = last_result | ||||
|                         result = await prepared(*args, **run_kwargs) | ||||
|                     elif ( | ||||
|                         len(parameters) == 1 and | ||||
|                         not parameters.get("self") | ||||
|                     ): | ||||
|                         result = await prepared(last_result) | ||||
|                     else: | ||||
|                         result = await prepared(*args, **updated_kwargs) | ||||
|                 else: | ||||
|                     result = await prepared(*args, **updated_kwargs) | ||||
|                 last_result = result | ||||
|                 results_context.add_result(result) | ||||
|                 context.extra["results"].append(result) | ||||
|                 context.extra["rollback_stack"].append(prepared) | ||||
|  | ||||
|             context.result = context.extra["results"] | ||||
|             context.result = last_result | ||||
|             await self.hooks.trigger(HookType.ON_SUCCESS, context) | ||||
|             return context.result | ||||
|             return last_result | ||||
|  | ||||
|         except Exception as error: | ||||
|             context.exception = error | ||||
| @@ -302,7 +363,7 @@ class ChainedAction(BaseAction, ActionListMixin): | ||||
|                     logger.error("[%s]⚠️ Rollback failed: %s", action.name, error) | ||||
|  | ||||
|     async def preview(self, parent: Tree | None = None): | ||||
|         label = f"[{OneColors.CYAN_b}]⛓ ChainedAction[/] '{self.name}'" | ||||
|         label = [f"[{OneColors.CYAN_b}]⛓ ChainedAction[/] '{self.name}'"] | ||||
|         if self.inject_last_result: | ||||
|             label.append(f" [dim](injects '{self.inject_last_result_as}')[/dim]") | ||||
|         tree = parent.add("".join(label)) if parent else Tree("".join(label)) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user