Add compatibility between BaseAction and BaseIOAction

This commit is contained in:
Roland Thomas Jr 2025-04-24 22:39:42 -04:00
parent 1fe0cd2675
commit e9fdd9cec6
Signed by: roland
GPG Key ID: 7C3C2B085A4C2872
3 changed files with 64 additions and 21 deletions

View File

@ -55,6 +55,7 @@ class BaseAction(ABC):
self.results_context: ResultsContext | None = None
self.inject_last_result: bool = inject_last_result
self.inject_last_result_as: str = inject_last_result_as
self.requires_injection: bool = False
if logging_hooks:
register_debug_hooks(self.hooks)
@ -102,12 +103,6 @@ class BaseAction(ABC):
"""Register a hook for all actions and sub-actions."""
self.hooks.register(hook_type, hook)
def __str__(self):
return f"<{self.__class__.__name__} '{self.name}'>"
def __repr__(self):
return str(self)
@classmethod
def enable_retries_recursively(cls, action: BaseAction, policy: RetryPolicy | None):
if not policy:
@ -115,12 +110,26 @@ class BaseAction(ABC):
if isinstance(action, Action):
action.retry_policy = policy
action.retry_policy.enabled = True
action.hooks.register("on_error", RetryHandler(policy).retry_on_error)
action.hooks.register(HookType.ON_ERROR, RetryHandler(policy).retry_on_error)
if hasattr(action, "actions"):
for sub in action.actions:
cls.enable_retries_recursively(sub, policy)
async def _write_stdout(self, data: str) -> None:
"""Override in subclasses that produce terminal output."""
pass
def requires_io_injection(self) -> bool:
"""Checks to see if the action requires input injection."""
return self.requires_injection
def __str__(self):
return f"<{self.__class__.__name__} '{self.name}'>"
def __repr__(self):
return str(self)
class Action(BaseAction):
"""A simple action that runs a callable. It can be a function or a coroutine."""
@ -205,6 +214,13 @@ class Action(BaseAction):
console.print(Tree("".join(label)))
class LiteralInputAction(Action):
def __init__(self, value: Any):
async def literal(*args, **kwargs):
return value
super().__init__("Input", literal, inject_last_result=True)
class ActionListMixin:
"""Mixin for managing a list of actions."""
def __init__(self) -> None:
@ -241,16 +257,32 @@ class ChainedAction(BaseAction, ActionListMixin):
def __init__(
self,
name: str,
actions: list[BaseAction] | None = None,
actions: list[BaseAction | Any] | None = None,
hooks: HookManager | None = None,
inject_last_result: bool = False,
inject_last_result_as: str = "last_result",
auto_inject: bool = False,
) -> None:
super().__init__(name, hooks, inject_last_result, inject_last_result_as)
ActionListMixin.__init__(self)
self.auto_inject = auto_inject
if actions:
self.set_actions(actions)
def _wrap_literal_if_needed(self, action: BaseAction | Any) -> BaseAction:
return LiteralInputAction(action) if not isinstance(action, BaseAction) else action
def _apply_auto_inject(self, action: BaseAction) -> None:
if self.auto_inject and not action.inject_last_result:
action.inject_last_result = True
def set_actions(self, actions: list[BaseAction | Any]):
self.actions.clear()
for action in actions:
action = self._wrap_literal_if_needed(action)
self._apply_auto_inject(action)
self.add_action(action)
async def _run(self, *args, **kwargs) -> list[Any]:
results_context = ResultsContext(name=self.name)
if self.results_context:
@ -270,6 +302,10 @@ class ChainedAction(BaseAction, ActionListMixin):
for index, action in enumerate(self.actions):
results_context.current_index = index
prepared = action.prepare_for_chain(results_context)
last_result = results_context.last_result()
if self.requires_io_injection() and last_result is not None:
result = await prepared(**{prepared.inject_last_result_as: last_result})
else:
result = await prepared(*args, **updated_kwargs)
results_context.add_result(result)
context.extra["results"].append(result)
@ -302,7 +338,7 @@ class ChainedAction(BaseAction, ActionListMixin):
logger.error("[%s]⚠️ Rollback failed: %s", action.name, error)
async def preview(self, parent: Tree | None = None):
label = f"[{OneColors.CYAN_b}]⛓ ChainedAction[/] '{self.name}'"
label = [f"[{OneColors.CYAN_b}]⛓ ChainedAction[/] '{self.name}'"]
if self.inject_last_result:
label.append(f" [dim](injects '{self.inject_last_result_as}')[/dim]")
tree = parent.add("".join(label)) if parent else Tree("".join(label))

View File

@ -113,6 +113,7 @@ class Falyx:
self.cli_args: Namespace | None = cli_args
self.custom_table: Callable[["Falyx"], Table] | Table | None = custom_table
self.set_options(cli_args, options)
self._session: PromptSession | None = None
def set_options(
self,
@ -127,6 +128,7 @@ class Falyx:
if options and not cli_args:
raise FalyxError("Options are set, but CLI arguments are not.")
assert isinstance(cli_args, Namespace), "CLI arguments must be a Namespace object."
if options is None:
self.options.from_namespace(cli_args, "cli_args")
@ -301,6 +303,7 @@ class Falyx:
"""Forces the session to be recreated on the next access."""
if hasattr(self, "session"):
del self.session
self._session = None
def add_help_command(self):
"""Adds a help command to the menu if it doesn't already exist."""
@ -348,7 +351,8 @@ class Falyx:
@cached_property
def session(self) -> PromptSession:
"""Returns the prompt session for the menu."""
return PromptSession(
if self._session is None:
self._session = PromptSession(
message=self.prompt,
multiline=False,
completer=self._get_completer(),
@ -357,6 +361,7 @@ class Falyx:
bottom_toolbar=self._get_bottom_bar_render(),
key_bindings=self.key_bindings,
)
return self._session
def register_all_hooks(self, hook_type: HookType, hooks: Hook | list[Hook]) -> None:
"""Registers hooks for all commands in the menu and actions recursively."""
@ -745,7 +750,7 @@ class Falyx:
selected_command.retry_policy.delay = self.cli_args.retry_delay
if self.cli_args.retry_backoff:
selected_command.retry_policy.backoff = self.cli_args.retry_backoff
selected_command.update_retry_policy(selected_command.retry_policy)
#selected_command.update_retry_policy(selected_command.retry_policy)
def print_message(self, message: str | Markdown | dict[str, Any]) -> None:
"""Prints a message to the console."""

View File

@ -12,6 +12,7 @@ from falyx.context import ExecutionContext
from falyx.exceptions import FalyxError
from falyx.execution_registry import ExecutionRegistry as er
from falyx.hook_manager import HookManager, HookType
from falyx.utils import logger
from falyx.themes.colors import OneColors
console = Console()
@ -33,6 +34,7 @@ class BaseIOAction(BaseAction):
inject_last_result=inject_last_result,
)
self.mode = mode
self.requires_injection = True
def from_input(self, raw: str | bytes) -> Any:
raise NotImplementedError
@ -53,6 +55,7 @@ class BaseIOAction(BaseAction):
if self.inject_last_result and self.results_context:
return self.results_context.last_result()
logger.debug("[%s] No input provided and no last result found for injection.", self.name)
raise FalyxError("No input provided and no last result to inject.")
async def __call__(self, *args, **kwargs):
@ -75,7 +78,6 @@ class BaseIOAction(BaseAction):
else:
parsed_input = await self._resolve_input(kwargs)
result = await self._run(parsed_input, *args, **kwargs)
result = await self._run(parsed_input)
output = self.to_output(result)
await self._write_stdout(output)
context.result = result