Add compatibility between BaseAction and BaseIOAction
This commit is contained in:
		| @@ -55,6 +55,7 @@ class BaseAction(ABC): | ||||
|         self.results_context: ResultsContext | None = None | ||||
|         self.inject_last_result: bool = inject_last_result | ||||
|         self.inject_last_result_as: str = inject_last_result_as | ||||
|         self.requires_injection: bool = False | ||||
|  | ||||
|         if logging_hooks: | ||||
|             register_debug_hooks(self.hooks) | ||||
| @@ -102,12 +103,6 @@ class BaseAction(ABC): | ||||
|         """Register a hook for all actions and sub-actions.""" | ||||
|         self.hooks.register(hook_type, hook) | ||||
|  | ||||
|     def __str__(self): | ||||
|         return f"<{self.__class__.__name__} '{self.name}'>" | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return str(self) | ||||
|  | ||||
|     @classmethod | ||||
|     def enable_retries_recursively(cls, action: BaseAction, policy: RetryPolicy | None): | ||||
|         if not policy: | ||||
| @@ -115,12 +110,26 @@ class BaseAction(ABC): | ||||
|         if isinstance(action, Action): | ||||
|             action.retry_policy = policy | ||||
|             action.retry_policy.enabled = True | ||||
|             action.hooks.register("on_error", RetryHandler(policy).retry_on_error) | ||||
|             action.hooks.register(HookType.ON_ERROR, RetryHandler(policy).retry_on_error) | ||||
|  | ||||
|         if hasattr(action, "actions"): | ||||
|             for sub in action.actions: | ||||
|                 cls.enable_retries_recursively(sub, policy) | ||||
|  | ||||
|     async def _write_stdout(self, data: str) -> None: | ||||
|         """Override in subclasses that produce terminal output.""" | ||||
|         pass | ||||
|  | ||||
|     def requires_io_injection(self) -> bool: | ||||
|         """Checks to see if the action requires input injection.""" | ||||
|         return self.requires_injection | ||||
|  | ||||
|     def __str__(self): | ||||
|         return f"<{self.__class__.__name__} '{self.name}'>" | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return str(self) | ||||
|  | ||||
|  | ||||
| class Action(BaseAction): | ||||
|     """A simple action that runs a callable. It can be a function or a coroutine.""" | ||||
| @@ -205,6 +214,13 @@ class Action(BaseAction): | ||||
|             console.print(Tree("".join(label))) | ||||
|  | ||||
|  | ||||
| class LiteralInputAction(Action): | ||||
|     def __init__(self, value: Any): | ||||
|         async def literal(*args, **kwargs): | ||||
|             return value | ||||
|         super().__init__("Input", literal, inject_last_result=True) | ||||
|  | ||||
|  | ||||
| class ActionListMixin: | ||||
|     """Mixin for managing a list of actions.""" | ||||
|     def __init__(self) -> None: | ||||
| @@ -241,16 +257,32 @@ class ChainedAction(BaseAction, ActionListMixin): | ||||
|     def __init__( | ||||
|             self, | ||||
|             name: str, | ||||
|             actions: list[BaseAction] | None = None, | ||||
|             actions: list[BaseAction | Any] | None = None, | ||||
|             hooks: HookManager | None = None, | ||||
|             inject_last_result: bool = False, | ||||
|             inject_last_result_as: str = "last_result", | ||||
|             auto_inject: bool = False, | ||||
|     ) -> None: | ||||
|         super().__init__(name, hooks, inject_last_result, inject_last_result_as) | ||||
|         ActionListMixin.__init__(self) | ||||
|         self.auto_inject = auto_inject | ||||
|         if actions: | ||||
|             self.set_actions(actions) | ||||
|  | ||||
|     def _wrap_literal_if_needed(self, action: BaseAction | Any) -> BaseAction: | ||||
|         return LiteralInputAction(action) if not isinstance(action, BaseAction) else action | ||||
|  | ||||
|     def _apply_auto_inject(self, action: BaseAction) -> None: | ||||
|         if self.auto_inject and not action.inject_last_result: | ||||
|             action.inject_last_result = True | ||||
|  | ||||
|     def set_actions(self, actions: list[BaseAction | Any]): | ||||
|         self.actions.clear() | ||||
|         for action in actions: | ||||
|             action = self._wrap_literal_if_needed(action) | ||||
|             self._apply_auto_inject(action) | ||||
|             self.add_action(action) | ||||
|  | ||||
|     async def _run(self, *args, **kwargs) -> list[Any]: | ||||
|         results_context = ResultsContext(name=self.name) | ||||
|         if self.results_context: | ||||
| @@ -270,7 +302,11 @@ class ChainedAction(BaseAction, ActionListMixin): | ||||
|             for index, action in enumerate(self.actions): | ||||
|                 results_context.current_index = index | ||||
|                 prepared = action.prepare_for_chain(results_context) | ||||
|                 result = await prepared(*args, **updated_kwargs) | ||||
|                 last_result = results_context.last_result() | ||||
|                 if self.requires_io_injection() and last_result is not None: | ||||
|                     result = await prepared(**{prepared.inject_last_result_as: last_result}) | ||||
|                 else: | ||||
|                     result = await prepared(*args, **updated_kwargs) | ||||
|                 results_context.add_result(result) | ||||
|                 context.extra["results"].append(result) | ||||
|                 context.extra["rollback_stack"].append(prepared) | ||||
| @@ -302,7 +338,7 @@ class ChainedAction(BaseAction, ActionListMixin): | ||||
|                     logger.error("[%s]⚠️ Rollback failed: %s", action.name, error) | ||||
|  | ||||
|     async def preview(self, parent: Tree | None = None): | ||||
|         label = f"[{OneColors.CYAN_b}]⛓ ChainedAction[/] '{self.name}'" | ||||
|         label = [f"[{OneColors.CYAN_b}]⛓ ChainedAction[/] '{self.name}'"] | ||||
|         if self.inject_last_result: | ||||
|             label.append(f" [dim](injects '{self.inject_last_result_as}')[/dim]") | ||||
|         tree = parent.add("".join(label)) if parent else Tree("".join(label)) | ||||
|   | ||||
| @@ -113,6 +113,7 @@ class Falyx: | ||||
|         self.cli_args: Namespace | None = cli_args | ||||
|         self.custom_table: Callable[["Falyx"], Table] | Table | None = custom_table | ||||
|         self.set_options(cli_args, options) | ||||
|         self._session: PromptSession | None = None | ||||
|  | ||||
|     def set_options( | ||||
|             self, | ||||
| @@ -127,6 +128,7 @@ class Falyx: | ||||
|         if options and not cli_args: | ||||
|             raise FalyxError("Options are set, but CLI arguments are not.") | ||||
|  | ||||
|         assert isinstance(cli_args, Namespace), "CLI arguments must be a Namespace object." | ||||
|         if options is None: | ||||
|             self.options.from_namespace(cli_args, "cli_args") | ||||
|  | ||||
| @@ -301,6 +303,7 @@ class Falyx: | ||||
|         """Forces the session to be recreated on the next access.""" | ||||
|         if hasattr(self, "session"): | ||||
|             del self.session | ||||
|         self._session = None | ||||
|  | ||||
|     def add_help_command(self): | ||||
|         """Adds a help command to the menu if it doesn't already exist.""" | ||||
| @@ -348,15 +351,17 @@ class Falyx: | ||||
|     @cached_property | ||||
|     def session(self) -> PromptSession: | ||||
|         """Returns the prompt session for the menu.""" | ||||
|         return PromptSession( | ||||
|             message=self.prompt, | ||||
|             multiline=False, | ||||
|             completer=self._get_completer(), | ||||
|             reserve_space_for_menu=1, | ||||
|             validator=self._get_validator(), | ||||
|             bottom_toolbar=self._get_bottom_bar_render(), | ||||
|             key_bindings=self.key_bindings, | ||||
|         ) | ||||
|         if self._session is None: | ||||
|             self._session = PromptSession( | ||||
|                                 message=self.prompt, | ||||
|                                 multiline=False, | ||||
|                                 completer=self._get_completer(), | ||||
|                                 reserve_space_for_menu=1, | ||||
|                                 validator=self._get_validator(), | ||||
|                                 bottom_toolbar=self._get_bottom_bar_render(), | ||||
|                                 key_bindings=self.key_bindings, | ||||
|                             ) | ||||
|         return self._session | ||||
|  | ||||
|     def register_all_hooks(self, hook_type: HookType, hooks: Hook | list[Hook]) -> None: | ||||
|         """Registers hooks for all commands in the menu and actions recursively.""" | ||||
| @@ -745,7 +750,7 @@ class Falyx: | ||||
|                 selected_command.retry_policy.delay = self.cli_args.retry_delay | ||||
|             if self.cli_args.retry_backoff: | ||||
|                 selected_command.retry_policy.backoff = self.cli_args.retry_backoff | ||||
|             selected_command.update_retry_policy(selected_command.retry_policy) | ||||
|             #selected_command.update_retry_policy(selected_command.retry_policy) | ||||
|  | ||||
|     def print_message(self, message: str | Markdown | dict[str, Any]) -> None: | ||||
|         """Prints a message to the console.""" | ||||
|   | ||||
| @@ -12,6 +12,7 @@ from falyx.context import ExecutionContext | ||||
| from falyx.exceptions import FalyxError | ||||
| from falyx.execution_registry import ExecutionRegistry as er | ||||
| from falyx.hook_manager import HookManager, HookType | ||||
| from falyx.utils import logger | ||||
| from falyx.themes.colors import OneColors | ||||
|  | ||||
| console = Console() | ||||
| @@ -33,6 +34,7 @@ class BaseIOAction(BaseAction): | ||||
|             inject_last_result=inject_last_result, | ||||
|         ) | ||||
|         self.mode = mode | ||||
|         self.requires_injection = True | ||||
|  | ||||
|     def from_input(self, raw: str | bytes) -> Any: | ||||
|         raise NotImplementedError | ||||
| @@ -53,6 +55,7 @@ class BaseIOAction(BaseAction): | ||||
|         if self.inject_last_result and self.results_context: | ||||
|             return self.results_context.last_result() | ||||
|  | ||||
|         logger.debug("[%s] No input provided and no last result found for injection.", self.name) | ||||
|         raise FalyxError("No input provided and no last result to inject.") | ||||
|  | ||||
|     async def __call__(self, *args, **kwargs): | ||||
| @@ -75,7 +78,6 @@ class BaseIOAction(BaseAction): | ||||
|             else: | ||||
|                 parsed_input = await self._resolve_input(kwargs) | ||||
|                 result = await self._run(parsed_input, *args, **kwargs) | ||||
|                 result = await self._run(parsed_input) | ||||
|                 output = self.to_output(result) | ||||
|                 await self._write_stdout(output) | ||||
|             context.result = result | ||||
|   | ||||
		Reference in New Issue
	
	Block a user