From e9fdd9cec6a65bf30fb9a0c6fdbc3abf2ee0b9a6 Mon Sep 17 00:00:00 2001 From: Roland Thomas Date: Thu, 24 Apr 2025 22:39:42 -0400 Subject: [PATCH] Add compatibility between BaseAction and BaseIOAction --- falyx/action.py | 56 +++++++++++++++++++++++++++++++++++++--------- falyx/falyx.py | 25 ++++++++++++--------- falyx/io_action.py | 4 +++- 3 files changed, 64 insertions(+), 21 deletions(-) diff --git a/falyx/action.py b/falyx/action.py index 52e3253..0b9e7d5 100644 --- a/falyx/action.py +++ b/falyx/action.py @@ -55,6 +55,7 @@ class BaseAction(ABC): self.results_context: ResultsContext | None = None self.inject_last_result: bool = inject_last_result self.inject_last_result_as: str = inject_last_result_as + self.requires_injection: bool = False if logging_hooks: register_debug_hooks(self.hooks) @@ -102,12 +103,6 @@ class BaseAction(ABC): """Register a hook for all actions and sub-actions.""" self.hooks.register(hook_type, hook) - def __str__(self): - return f"<{self.__class__.__name__} '{self.name}'>" - - def __repr__(self): - return str(self) - @classmethod def enable_retries_recursively(cls, action: BaseAction, policy: RetryPolicy | None): if not policy: @@ -115,12 +110,26 @@ class BaseAction(ABC): if isinstance(action, Action): action.retry_policy = policy action.retry_policy.enabled = True - action.hooks.register("on_error", RetryHandler(policy).retry_on_error) + action.hooks.register(HookType.ON_ERROR, RetryHandler(policy).retry_on_error) if hasattr(action, "actions"): for sub in action.actions: cls.enable_retries_recursively(sub, policy) + async def _write_stdout(self, data: str) -> None: + """Override in subclasses that produce terminal output.""" + pass + + def requires_io_injection(self) -> bool: + """Checks to see if the action requires input injection.""" + return self.requires_injection + + def __str__(self): + return f"<{self.__class__.__name__} '{self.name}'>" + + def __repr__(self): + return str(self) + class Action(BaseAction): """A simple action that runs a callable. It can be a function or a coroutine.""" @@ -205,6 +214,13 @@ class Action(BaseAction): console.print(Tree("".join(label))) +class LiteralInputAction(Action): + def __init__(self, value: Any): + async def literal(*args, **kwargs): + return value + super().__init__("Input", literal, inject_last_result=True) + + class ActionListMixin: """Mixin for managing a list of actions.""" def __init__(self) -> None: @@ -241,16 +257,32 @@ class ChainedAction(BaseAction, ActionListMixin): def __init__( self, name: str, - actions: list[BaseAction] | None = None, + actions: list[BaseAction | Any] | None = None, hooks: HookManager | None = None, inject_last_result: bool = False, inject_last_result_as: str = "last_result", + auto_inject: bool = False, ) -> None: super().__init__(name, hooks, inject_last_result, inject_last_result_as) ActionListMixin.__init__(self) + self.auto_inject = auto_inject if actions: self.set_actions(actions) + def _wrap_literal_if_needed(self, action: BaseAction | Any) -> BaseAction: + return LiteralInputAction(action) if not isinstance(action, BaseAction) else action + + def _apply_auto_inject(self, action: BaseAction) -> None: + if self.auto_inject and not action.inject_last_result: + action.inject_last_result = True + + def set_actions(self, actions: list[BaseAction | Any]): + self.actions.clear() + for action in actions: + action = self._wrap_literal_if_needed(action) + self._apply_auto_inject(action) + self.add_action(action) + async def _run(self, *args, **kwargs) -> list[Any]: results_context = ResultsContext(name=self.name) if self.results_context: @@ -270,7 +302,11 @@ class ChainedAction(BaseAction, ActionListMixin): for index, action in enumerate(self.actions): results_context.current_index = index prepared = action.prepare_for_chain(results_context) - result = await prepared(*args, **updated_kwargs) + last_result = results_context.last_result() + if self.requires_io_injection() and last_result is not None: + result = await prepared(**{prepared.inject_last_result_as: last_result}) + else: + result = await prepared(*args, **updated_kwargs) results_context.add_result(result) context.extra["results"].append(result) context.extra["rollback_stack"].append(prepared) @@ -302,7 +338,7 @@ class ChainedAction(BaseAction, ActionListMixin): logger.error("[%s]⚠️ Rollback failed: %s", action.name, error) async def preview(self, parent: Tree | None = None): - label = f"[{OneColors.CYAN_b}]⛓ ChainedAction[/] '{self.name}'" + label = [f"[{OneColors.CYAN_b}]⛓ ChainedAction[/] '{self.name}'"] if self.inject_last_result: label.append(f" [dim](injects '{self.inject_last_result_as}')[/dim]") tree = parent.add("".join(label)) if parent else Tree("".join(label)) diff --git a/falyx/falyx.py b/falyx/falyx.py index c1f3e59..d9db7a8 100644 --- a/falyx/falyx.py +++ b/falyx/falyx.py @@ -113,6 +113,7 @@ class Falyx: self.cli_args: Namespace | None = cli_args self.custom_table: Callable[["Falyx"], Table] | Table | None = custom_table self.set_options(cli_args, options) + self._session: PromptSession | None = None def set_options( self, @@ -127,6 +128,7 @@ class Falyx: if options and not cli_args: raise FalyxError("Options are set, but CLI arguments are not.") + assert isinstance(cli_args, Namespace), "CLI arguments must be a Namespace object." if options is None: self.options.from_namespace(cli_args, "cli_args") @@ -301,6 +303,7 @@ class Falyx: """Forces the session to be recreated on the next access.""" if hasattr(self, "session"): del self.session + self._session = None def add_help_command(self): """Adds a help command to the menu if it doesn't already exist.""" @@ -348,15 +351,17 @@ class Falyx: @cached_property def session(self) -> PromptSession: """Returns the prompt session for the menu.""" - return PromptSession( - message=self.prompt, - multiline=False, - completer=self._get_completer(), - reserve_space_for_menu=1, - validator=self._get_validator(), - bottom_toolbar=self._get_bottom_bar_render(), - key_bindings=self.key_bindings, - ) + if self._session is None: + self._session = PromptSession( + message=self.prompt, + multiline=False, + completer=self._get_completer(), + reserve_space_for_menu=1, + validator=self._get_validator(), + bottom_toolbar=self._get_bottom_bar_render(), + key_bindings=self.key_bindings, + ) + return self._session def register_all_hooks(self, hook_type: HookType, hooks: Hook | list[Hook]) -> None: """Registers hooks for all commands in the menu and actions recursively.""" @@ -745,7 +750,7 @@ class Falyx: selected_command.retry_policy.delay = self.cli_args.retry_delay if self.cli_args.retry_backoff: selected_command.retry_policy.backoff = self.cli_args.retry_backoff - selected_command.update_retry_policy(selected_command.retry_policy) + #selected_command.update_retry_policy(selected_command.retry_policy) def print_message(self, message: str | Markdown | dict[str, Any]) -> None: """Prints a message to the console.""" diff --git a/falyx/io_action.py b/falyx/io_action.py index c09e7fa..a0426d6 100644 --- a/falyx/io_action.py +++ b/falyx/io_action.py @@ -12,6 +12,7 @@ from falyx.context import ExecutionContext from falyx.exceptions import FalyxError from falyx.execution_registry import ExecutionRegistry as er from falyx.hook_manager import HookManager, HookType +from falyx.utils import logger from falyx.themes.colors import OneColors console = Console() @@ -33,6 +34,7 @@ class BaseIOAction(BaseAction): inject_last_result=inject_last_result, ) self.mode = mode + self.requires_injection = True def from_input(self, raw: str | bytes) -> Any: raise NotImplementedError @@ -53,6 +55,7 @@ class BaseIOAction(BaseAction): if self.inject_last_result and self.results_context: return self.results_context.last_result() + logger.debug("[%s] No input provided and no last result found for injection.", self.name) raise FalyxError("No input provided and no last result to inject.") async def __call__(self, *args, **kwargs): @@ -75,7 +78,6 @@ class BaseIOAction(BaseAction): else: parsed_input = await self._resolve_input(kwargs) result = await self._run(parsed_input, *args, **kwargs) - result = await self._run(parsed_input) output = self.to_output(result) await self._write_stdout(output) context.result = result