Add compatibility between BaseAction and BaseIOAction

This commit is contained in:
Roland Thomas Jr 2025-04-24 22:39:42 -04:00
parent 1fe0cd2675
commit e9fdd9cec6
Signed by: roland
GPG Key ID: 7C3C2B085A4C2872
3 changed files with 64 additions and 21 deletions

View File

@ -55,6 +55,7 @@ class BaseAction(ABC):
self.results_context: ResultsContext | None = None self.results_context: ResultsContext | None = None
self.inject_last_result: bool = inject_last_result self.inject_last_result: bool = inject_last_result
self.inject_last_result_as: str = inject_last_result_as self.inject_last_result_as: str = inject_last_result_as
self.requires_injection: bool = False
if logging_hooks: if logging_hooks:
register_debug_hooks(self.hooks) register_debug_hooks(self.hooks)
@ -102,12 +103,6 @@ class BaseAction(ABC):
"""Register a hook for all actions and sub-actions.""" """Register a hook for all actions and sub-actions."""
self.hooks.register(hook_type, hook) self.hooks.register(hook_type, hook)
def __str__(self):
return f"<{self.__class__.__name__} '{self.name}'>"
def __repr__(self):
return str(self)
@classmethod @classmethod
def enable_retries_recursively(cls, action: BaseAction, policy: RetryPolicy | None): def enable_retries_recursively(cls, action: BaseAction, policy: RetryPolicy | None):
if not policy: if not policy:
@ -115,12 +110,26 @@ class BaseAction(ABC):
if isinstance(action, Action): if isinstance(action, Action):
action.retry_policy = policy action.retry_policy = policy
action.retry_policy.enabled = True action.retry_policy.enabled = True
action.hooks.register("on_error", RetryHandler(policy).retry_on_error) action.hooks.register(HookType.ON_ERROR, RetryHandler(policy).retry_on_error)
if hasattr(action, "actions"): if hasattr(action, "actions"):
for sub in action.actions: for sub in action.actions:
cls.enable_retries_recursively(sub, policy) cls.enable_retries_recursively(sub, policy)
async def _write_stdout(self, data: str) -> None:
"""Override in subclasses that produce terminal output."""
pass
def requires_io_injection(self) -> bool:
"""Checks to see if the action requires input injection."""
return self.requires_injection
def __str__(self):
return f"<{self.__class__.__name__} '{self.name}'>"
def __repr__(self):
return str(self)
class Action(BaseAction): class Action(BaseAction):
"""A simple action that runs a callable. It can be a function or a coroutine.""" """A simple action that runs a callable. It can be a function or a coroutine."""
@ -205,6 +214,13 @@ class Action(BaseAction):
console.print(Tree("".join(label))) console.print(Tree("".join(label)))
class LiteralInputAction(Action):
def __init__(self, value: Any):
async def literal(*args, **kwargs):
return value
super().__init__("Input", literal, inject_last_result=True)
class ActionListMixin: class ActionListMixin:
"""Mixin for managing a list of actions.""" """Mixin for managing a list of actions."""
def __init__(self) -> None: def __init__(self) -> None:
@ -241,16 +257,32 @@ class ChainedAction(BaseAction, ActionListMixin):
def __init__( def __init__(
self, self,
name: str, name: str,
actions: list[BaseAction] | None = None, actions: list[BaseAction | Any] | None = None,
hooks: HookManager | None = None, hooks: HookManager | None = None,
inject_last_result: bool = False, inject_last_result: bool = False,
inject_last_result_as: str = "last_result", inject_last_result_as: str = "last_result",
auto_inject: bool = False,
) -> None: ) -> None:
super().__init__(name, hooks, inject_last_result, inject_last_result_as) super().__init__(name, hooks, inject_last_result, inject_last_result_as)
ActionListMixin.__init__(self) ActionListMixin.__init__(self)
self.auto_inject = auto_inject
if actions: if actions:
self.set_actions(actions) self.set_actions(actions)
def _wrap_literal_if_needed(self, action: BaseAction | Any) -> BaseAction:
return LiteralInputAction(action) if not isinstance(action, BaseAction) else action
def _apply_auto_inject(self, action: BaseAction) -> None:
if self.auto_inject and not action.inject_last_result:
action.inject_last_result = True
def set_actions(self, actions: list[BaseAction | Any]):
self.actions.clear()
for action in actions:
action = self._wrap_literal_if_needed(action)
self._apply_auto_inject(action)
self.add_action(action)
async def _run(self, *args, **kwargs) -> list[Any]: async def _run(self, *args, **kwargs) -> list[Any]:
results_context = ResultsContext(name=self.name) results_context = ResultsContext(name=self.name)
if self.results_context: if self.results_context:
@ -270,7 +302,11 @@ class ChainedAction(BaseAction, ActionListMixin):
for index, action in enumerate(self.actions): for index, action in enumerate(self.actions):
results_context.current_index = index results_context.current_index = index
prepared = action.prepare_for_chain(results_context) prepared = action.prepare_for_chain(results_context)
result = await prepared(*args, **updated_kwargs) last_result = results_context.last_result()
if self.requires_io_injection() and last_result is not None:
result = await prepared(**{prepared.inject_last_result_as: last_result})
else:
result = await prepared(*args, **updated_kwargs)
results_context.add_result(result) results_context.add_result(result)
context.extra["results"].append(result) context.extra["results"].append(result)
context.extra["rollback_stack"].append(prepared) context.extra["rollback_stack"].append(prepared)
@ -302,7 +338,7 @@ class ChainedAction(BaseAction, ActionListMixin):
logger.error("[%s]⚠️ Rollback failed: %s", action.name, error) logger.error("[%s]⚠️ Rollback failed: %s", action.name, error)
async def preview(self, parent: Tree | None = None): async def preview(self, parent: Tree | None = None):
label = f"[{OneColors.CYAN_b}]⛓ ChainedAction[/] '{self.name}'" label = [f"[{OneColors.CYAN_b}]⛓ ChainedAction[/] '{self.name}'"]
if self.inject_last_result: if self.inject_last_result:
label.append(f" [dim](injects '{self.inject_last_result_as}')[/dim]") label.append(f" [dim](injects '{self.inject_last_result_as}')[/dim]")
tree = parent.add("".join(label)) if parent else Tree("".join(label)) tree = parent.add("".join(label)) if parent else Tree("".join(label))

View File

@ -113,6 +113,7 @@ class Falyx:
self.cli_args: Namespace | None = cli_args self.cli_args: Namespace | None = cli_args
self.custom_table: Callable[["Falyx"], Table] | Table | None = custom_table self.custom_table: Callable[["Falyx"], Table] | Table | None = custom_table
self.set_options(cli_args, options) self.set_options(cli_args, options)
self._session: PromptSession | None = None
def set_options( def set_options(
self, self,
@ -127,6 +128,7 @@ class Falyx:
if options and not cli_args: if options and not cli_args:
raise FalyxError("Options are set, but CLI arguments are not.") raise FalyxError("Options are set, but CLI arguments are not.")
assert isinstance(cli_args, Namespace), "CLI arguments must be a Namespace object."
if options is None: if options is None:
self.options.from_namespace(cli_args, "cli_args") self.options.from_namespace(cli_args, "cli_args")
@ -301,6 +303,7 @@ class Falyx:
"""Forces the session to be recreated on the next access.""" """Forces the session to be recreated on the next access."""
if hasattr(self, "session"): if hasattr(self, "session"):
del self.session del self.session
self._session = None
def add_help_command(self): def add_help_command(self):
"""Adds a help command to the menu if it doesn't already exist.""" """Adds a help command to the menu if it doesn't already exist."""
@ -348,15 +351,17 @@ class Falyx:
@cached_property @cached_property
def session(self) -> PromptSession: def session(self) -> PromptSession:
"""Returns the prompt session for the menu.""" """Returns the prompt session for the menu."""
return PromptSession( if self._session is None:
message=self.prompt, self._session = PromptSession(
multiline=False, message=self.prompt,
completer=self._get_completer(), multiline=False,
reserve_space_for_menu=1, completer=self._get_completer(),
validator=self._get_validator(), reserve_space_for_menu=1,
bottom_toolbar=self._get_bottom_bar_render(), validator=self._get_validator(),
key_bindings=self.key_bindings, bottom_toolbar=self._get_bottom_bar_render(),
) key_bindings=self.key_bindings,
)
return self._session
def register_all_hooks(self, hook_type: HookType, hooks: Hook | list[Hook]) -> None: def register_all_hooks(self, hook_type: HookType, hooks: Hook | list[Hook]) -> None:
"""Registers hooks for all commands in the menu and actions recursively.""" """Registers hooks for all commands in the menu and actions recursively."""
@ -745,7 +750,7 @@ class Falyx:
selected_command.retry_policy.delay = self.cli_args.retry_delay selected_command.retry_policy.delay = self.cli_args.retry_delay
if self.cli_args.retry_backoff: if self.cli_args.retry_backoff:
selected_command.retry_policy.backoff = self.cli_args.retry_backoff selected_command.retry_policy.backoff = self.cli_args.retry_backoff
selected_command.update_retry_policy(selected_command.retry_policy) #selected_command.update_retry_policy(selected_command.retry_policy)
def print_message(self, message: str | Markdown | dict[str, Any]) -> None: def print_message(self, message: str | Markdown | dict[str, Any]) -> None:
"""Prints a message to the console.""" """Prints a message to the console."""

View File

@ -12,6 +12,7 @@ from falyx.context import ExecutionContext
from falyx.exceptions import FalyxError from falyx.exceptions import FalyxError
from falyx.execution_registry import ExecutionRegistry as er from falyx.execution_registry import ExecutionRegistry as er
from falyx.hook_manager import HookManager, HookType from falyx.hook_manager import HookManager, HookType
from falyx.utils import logger
from falyx.themes.colors import OneColors from falyx.themes.colors import OneColors
console = Console() console = Console()
@ -33,6 +34,7 @@ class BaseIOAction(BaseAction):
inject_last_result=inject_last_result, inject_last_result=inject_last_result,
) )
self.mode = mode self.mode = mode
self.requires_injection = True
def from_input(self, raw: str | bytes) -> Any: def from_input(self, raw: str | bytes) -> Any:
raise NotImplementedError raise NotImplementedError
@ -53,6 +55,7 @@ class BaseIOAction(BaseAction):
if self.inject_last_result and self.results_context: if self.inject_last_result and self.results_context:
return self.results_context.last_result() return self.results_context.last_result()
logger.debug("[%s] No input provided and no last result found for injection.", self.name)
raise FalyxError("No input provided and no last result to inject.") raise FalyxError("No input provided and no last result to inject.")
async def __call__(self, *args, **kwargs): async def __call__(self, *args, **kwargs):
@ -75,7 +78,6 @@ class BaseIOAction(BaseAction):
else: else:
parsed_input = await self._resolve_input(kwargs) parsed_input = await self._resolve_input(kwargs)
result = await self._run(parsed_input, *args, **kwargs) result = await self._run(parsed_input, *args, **kwargs)
result = await self._run(parsed_input)
output = self.to_output(result) output = self.to_output(result)
await self._write_stdout(output) await self._write_stdout(output)
context.result = result context.result = result