3 Commits

6 changed files with 265 additions and 26 deletions

View File

@ -55,6 +55,7 @@ class BaseAction(ABC):
self.results_context: ResultsContext | None = None self.results_context: ResultsContext | None = None
self.inject_last_result: bool = inject_last_result self.inject_last_result: bool = inject_last_result
self.inject_last_result_as: str = inject_last_result_as self.inject_last_result_as: str = inject_last_result_as
self.requires_injection: bool = False
if logging_hooks: if logging_hooks:
register_debug_hooks(self.hooks) register_debug_hooks(self.hooks)
@ -102,12 +103,6 @@ class BaseAction(ABC):
"""Register a hook for all actions and sub-actions.""" """Register a hook for all actions and sub-actions."""
self.hooks.register(hook_type, hook) self.hooks.register(hook_type, hook)
def __str__(self):
return f"<{self.__class__.__name__} '{self.name}'>"
def __repr__(self):
return str(self)
@classmethod @classmethod
def enable_retries_recursively(cls, action: BaseAction, policy: RetryPolicy | None): def enable_retries_recursively(cls, action: BaseAction, policy: RetryPolicy | None):
if not policy: if not policy:
@ -115,12 +110,26 @@ class BaseAction(ABC):
if isinstance(action, Action): if isinstance(action, Action):
action.retry_policy = policy action.retry_policy = policy
action.retry_policy.enabled = True action.retry_policy.enabled = True
action.hooks.register("on_error", RetryHandler(policy).retry_on_error) action.hooks.register(HookType.ON_ERROR, RetryHandler(policy).retry_on_error)
if hasattr(action, "actions"): if hasattr(action, "actions"):
for sub in action.actions: for sub in action.actions:
cls.enable_retries_recursively(sub, policy) cls.enable_retries_recursively(sub, policy)
async def _write_stdout(self, data: str) -> None:
"""Override in subclasses that produce terminal output."""
pass
def requires_io_injection(self) -> bool:
"""Checks to see if the action requires input injection."""
return self.requires_injection
def __str__(self):
return f"<{self.__class__.__name__} '{self.name}'>"
def __repr__(self):
return str(self)
class Action(BaseAction): class Action(BaseAction):
"""A simple action that runs a callable. It can be a function or a coroutine.""" """A simple action that runs a callable. It can be a function or a coroutine."""
@ -205,6 +214,13 @@ class Action(BaseAction):
console.print(Tree("".join(label))) console.print(Tree("".join(label)))
class LiteralInputAction(Action):
def __init__(self, value: Any):
async def literal(*args, **kwargs):
return value
super().__init__("Input", literal, inject_last_result=True)
class ActionListMixin: class ActionListMixin:
"""Mixin for managing a list of actions.""" """Mixin for managing a list of actions."""
def __init__(self) -> None: def __init__(self) -> None:
@ -241,16 +257,32 @@ class ChainedAction(BaseAction, ActionListMixin):
def __init__( def __init__(
self, self,
name: str, name: str,
actions: list[BaseAction] | None = None, actions: list[BaseAction | Any] | None = None,
hooks: HookManager | None = None, hooks: HookManager | None = None,
inject_last_result: bool = False, inject_last_result: bool = False,
inject_last_result_as: str = "last_result", inject_last_result_as: str = "last_result",
auto_inject: bool = False,
) -> None: ) -> None:
super().__init__(name, hooks, inject_last_result, inject_last_result_as) super().__init__(name, hooks, inject_last_result, inject_last_result_as)
ActionListMixin.__init__(self) ActionListMixin.__init__(self)
self.auto_inject = auto_inject
if actions: if actions:
self.set_actions(actions) self.set_actions(actions)
def _wrap_literal_if_needed(self, action: BaseAction | Any) -> BaseAction:
return LiteralInputAction(action) if not isinstance(action, BaseAction) else action
def _apply_auto_inject(self, action: BaseAction) -> None:
if self.auto_inject and not action.inject_last_result:
action.inject_last_result = True
def set_actions(self, actions: list[BaseAction | Any]):
self.actions.clear()
for action in actions:
action = self._wrap_literal_if_needed(action)
self._apply_auto_inject(action)
self.add_action(action)
async def _run(self, *args, **kwargs) -> list[Any]: async def _run(self, *args, **kwargs) -> list[Any]:
results_context = ResultsContext(name=self.name) results_context = ResultsContext(name=self.name)
if self.results_context: if self.results_context:
@ -270,6 +302,10 @@ class ChainedAction(BaseAction, ActionListMixin):
for index, action in enumerate(self.actions): for index, action in enumerate(self.actions):
results_context.current_index = index results_context.current_index = index
prepared = action.prepare_for_chain(results_context) prepared = action.prepare_for_chain(results_context)
last_result = results_context.last_result()
if self.requires_io_injection() and last_result is not None:
result = await prepared(**{prepared.inject_last_result_as: last_result})
else:
result = await prepared(*args, **updated_kwargs) result = await prepared(*args, **updated_kwargs)
results_context.add_result(result) results_context.add_result(result)
context.extra["results"].append(result) context.extra["results"].append(result)
@ -302,7 +338,7 @@ class ChainedAction(BaseAction, ActionListMixin):
logger.error("[%s]⚠️ Rollback failed: %s", action.name, error) logger.error("[%s]⚠️ Rollback failed: %s", action.name, error)
async def preview(self, parent: Tree | None = None): async def preview(self, parent: Tree | None = None):
label = f"[{OneColors.CYAN_b}]⛓ ChainedAction[/] '{self.name}'" label = [f"[{OneColors.CYAN_b}]⛓ ChainedAction[/] '{self.name}'"]
if self.inject_last_result: if self.inject_last_result:
label.append(f" [dim](injects '{self.inject_last_result_as}')[/dim]") label.append(f" [dim](injects '{self.inject_last_result_as}')[/dim]")
tree = parent.add("".join(label)) if parent else Tree("".join(label)) tree = parent.add("".join(label)) if parent else Tree("".join(label))

View File

@ -30,7 +30,6 @@ class BottomBar:
) -> None: ) -> None:
self.columns = columns self.columns = columns
self.console = Console() self.console = Console()
self._items: list[Callable[[], HTML]] = []
self._named_items: dict[str, Callable[[], HTML]] = {} self._named_items: dict[str, Callable[[], HTML]] = {}
self._value_getters: dict[str, Callable[[], Any]] = CaseInsensitiveDict() self._value_getters: dict[str, Callable[[], Any]] = CaseInsensitiveDict()
self.toggle_keys: list[str] = [] self.toggle_keys: list[str] = []
@ -99,6 +98,7 @@ class BottomBar:
total: int, total: int,
fg: str = OneColors.BLACK, fg: str = OneColors.BLACK,
bg: str = OneColors.WHITE, bg: str = OneColors.WHITE,
enforce_total: bool = True,
) -> None: ) -> None:
if not callable(get_current): if not callable(get_current):
raise ValueError("`get_current` must be a callable returning int") raise ValueError("`get_current` must be a callable returning int")
@ -108,7 +108,7 @@ class BottomBar:
def render(): def render():
get_current_ = self._value_getters[name] get_current_ = self._value_getters[name]
current_value = get_current_() current_value = get_current_()
if current_value > total: if current_value > total and enforce_total:
raise ValueError( raise ValueError(
f"Current value {current_value} is greater than total value {total}" f"Current value {current_value} is greater than total value {total}"
) )
@ -169,6 +169,7 @@ class BottomBar:
bg_on: str = OneColors.GREEN, bg_on: str = OneColors.GREEN,
bg_off: str = OneColors.DARK_RED, bg_off: str = OneColors.DARK_RED,
) -> None: ) -> None:
"""Add a toggle to the bottom bar based on an option from OptionsManager."""
self.add_toggle( self.add_toggle(
key=key, key=key,
label=label, label=label,
@ -185,15 +186,30 @@ class BottomBar:
return {label: getter() for label, getter in self._value_getters.items()} return {label: getter() for label, getter in self._value_getters.items()}
def get_value(self, name: str) -> Any: def get_value(self, name: str) -> Any:
"""Get the current value of a registered item."""
if name not in self._value_getters: if name not in self._value_getters:
raise ValueError(f"No value getter registered under name: '{name}'") raise ValueError(f"No value getter registered under name: '{name}'")
return self._value_getters[name]() return self._value_getters[name]()
def remove_item(self, name: str) -> None:
"""Remove an item from the bottom bar."""
self._named_items.pop(name, None)
self._value_getters.pop(name, None)
if name in self.toggle_keys:
self.toggle_keys.remove(name)
def clear(self) -> None:
"""Clear all items from the bottom bar."""
self._value_getters.clear()
self._named_items.clear()
self.toggle_keys.clear()
def _add_named(self, name: str, render_fn: Callable[[], HTML]) -> None: def _add_named(self, name: str, render_fn: Callable[[], HTML]) -> None:
if name in self._named_items: if name in self._named_items:
raise ValueError(f"Bottom bar item '{name}' already exists") raise ValueError(f"Bottom bar item '{name}' already exists")
self._named_items[name] = render_fn self._named_items[name] = render_fn
self._items = list(self._named_items.values())
def render(self): def render(self):
return merge_formatted_text([fn() for fn in self._items]) """Render the bottom bar."""
return merge_formatted_text([fn() for fn in self._named_items.values()])

View File

@ -159,7 +159,7 @@ class Command(BaseModel):
elif callable(self.action): elif callable(self.action):
console.print(f"{label}") console.print(f"{label}")
console.print( console.print(
f"[{OneColors.LIGHT_RED_b}]→ Would call:[/] {self.action.__name__} " f"[{OneColors.LIGHT_RED_b}]→ Would call:[/] {self.action.__name__}"
f"[dim](args={self.args}, kwargs={self.kwargs})[/dim]" f"[dim](args={self.args}, kwargs={self.kwargs})[/dim]"
) )
else: else:

View File

@ -113,6 +113,7 @@ class Falyx:
self.cli_args: Namespace | None = cli_args self.cli_args: Namespace | None = cli_args
self.custom_table: Callable[["Falyx"], Table] | Table | None = custom_table self.custom_table: Callable[["Falyx"], Table] | Table | None = custom_table
self.set_options(cli_args, options) self.set_options(cli_args, options)
self._session: PromptSession | None = None
def set_options( def set_options(
self, self,
@ -127,6 +128,7 @@ class Falyx:
if options and not cli_args: if options and not cli_args:
raise FalyxError("Options are set, but CLI arguments are not.") raise FalyxError("Options are set, but CLI arguments are not.")
assert isinstance(cli_args, Namespace), "CLI arguments must be a Namespace object."
if options is None: if options is None:
self.options.from_namespace(cli_args, "cli_args") self.options.from_namespace(cli_args, "cli_args")
@ -301,6 +303,7 @@ class Falyx:
"""Forces the session to be recreated on the next access.""" """Forces the session to be recreated on the next access."""
if hasattr(self, "session"): if hasattr(self, "session"):
del self.session del self.session
self._session = None
def add_help_command(self): def add_help_command(self):
"""Adds a help command to the menu if it doesn't already exist.""" """Adds a help command to the menu if it doesn't already exist."""
@ -335,7 +338,7 @@ class Falyx:
def _get_bottom_bar_render(self) -> Callable[[], Any] | str | None: def _get_bottom_bar_render(self) -> Callable[[], Any] | str | None:
"""Returns the bottom bar for the menu.""" """Returns the bottom bar for the menu."""
if isinstance(self.bottom_bar, BottomBar) and self.bottom_bar._items: if isinstance(self.bottom_bar, BottomBar) and self.bottom_bar._named_items:
return self._bottom_bar.render return self._bottom_bar.render
elif callable(self._bottom_bar): elif callable(self._bottom_bar):
return self._bottom_bar return self._bottom_bar
@ -348,7 +351,8 @@ class Falyx:
@cached_property @cached_property
def session(self) -> PromptSession: def session(self) -> PromptSession:
"""Returns the prompt session for the menu.""" """Returns the prompt session for the menu."""
return PromptSession( if self._session is None:
self._session = PromptSession(
message=self.prompt, message=self.prompt,
multiline=False, multiline=False,
completer=self._get_completer(), completer=self._get_completer(),
@ -357,6 +361,7 @@ class Falyx:
bottom_toolbar=self._get_bottom_bar_render(), bottom_toolbar=self._get_bottom_bar_render(),
key_bindings=self.key_bindings, key_bindings=self.key_bindings,
) )
return self._session
def register_all_hooks(self, hook_type: HookType, hooks: Hook | list[Hook]) -> None: def register_all_hooks(self, hook_type: HookType, hooks: Hook | list[Hook]) -> None:
"""Registers hooks for all commands in the menu and actions recursively.""" """Registers hooks for all commands in the menu and actions recursively."""
@ -745,7 +750,7 @@ class Falyx:
selected_command.retry_policy.delay = self.cli_args.retry_delay selected_command.retry_policy.delay = self.cli_args.retry_delay
if self.cli_args.retry_backoff: if self.cli_args.retry_backoff:
selected_command.retry_policy.backoff = self.cli_args.retry_backoff selected_command.retry_policy.backoff = self.cli_args.retry_backoff
selected_command.update_retry_policy(selected_command.retry_policy) #selected_command.update_retry_policy(selected_command.retry_policy)
def print_message(self, message: str | Markdown | dict[str, Any]) -> None: def print_message(self, message: str | Markdown | dict[str, Any]) -> None:
"""Prints a message to the console.""" """Prints a message to the console."""

View File

@ -19,6 +19,8 @@ class ResultReporter:
return "ResultReporter" return "ResultReporter"
async def report(self, context: ExecutionContext): async def report(self, context: ExecutionContext):
if not callable(self.formatter):
raise TypeError("formatter must be callable")
if context.result is not None: if context.result is not None:
result_text = self.formatter(context.result) result_text = self.formatter(context.result)
duration = f"{context.duration:.3f}s" if context.duration is not None else "n/a" duration = f"{context.duration:.3f}s" if context.duration is not None else "n/a"

180
falyx/io_action.py Normal file
View File

@ -0,0 +1,180 @@
"""io_action.py"""
import asyncio
import subprocess
import sys
from typing import Any
from rich.console import Console
from rich.tree import Tree
from falyx.action import BaseAction
from falyx.context import ExecutionContext
from falyx.exceptions import FalyxError
from falyx.execution_registry import ExecutionRegistry as er
from falyx.hook_manager import HookManager, HookType
from falyx.utils import logger
from falyx.themes.colors import OneColors
console = Console()
class BaseIOAction(BaseAction):
def __init__(
self,
name: str,
hooks: HookManager | None = None,
mode: str = "buffered",
logging_hooks: bool = True,
inject_last_result: bool = True,
):
super().__init__(
name=name,
hooks=hooks,
logging_hooks=logging_hooks,
inject_last_result=inject_last_result,
)
self.mode = mode
self.requires_injection = True
def from_input(self, raw: str | bytes) -> Any:
raise NotImplementedError
def to_output(self, data: Any) -> str | bytes:
raise NotImplementedError
async def _resolve_input(self, kwargs: dict[str, Any]) -> str | bytes:
last_result = kwargs.pop(self.inject_last_result_as, None)
data = await self._read_stdin()
if data:
return self.from_input(data)
if last_result is not None:
return last_result
if self.inject_last_result and self.results_context:
return self.results_context.last_result()
logger.debug("[%s] No input provided and no last result found for injection.", self.name)
raise FalyxError("No input provided and no last result to inject.")
async def __call__(self, *args, **kwargs):
context = ExecutionContext(
name=self.name,
args=args,
kwargs=kwargs,
action=self,
)
context.start_timer()
await self.hooks.trigger(HookType.BEFORE, context)
try:
if self.mode == "stream":
line_gen = await self._read_stdin_stream()
async for line in self._stream_lines(line_gen, args, kwargs):
pass
result = getattr(self, "_last_result", None)
else:
parsed_input = await self._resolve_input(kwargs)
result = await self._run(parsed_input, *args, **kwargs)
output = self.to_output(result)
await self._write_stdout(output)
context.result = result
await self.hooks.trigger(HookType.ON_SUCCESS, context)
return result
except Exception as error:
context.exception = error
await self.hooks.trigger(HookType.ON_ERROR, context)
raise
finally:
context.stop_timer()
await self.hooks.trigger(HookType.AFTER, context)
await self.hooks.trigger(HookType.ON_TEARDOWN, context)
er.record(context)
async def _read_stdin(self) -> str:
if not sys.stdin.isatty():
return await asyncio.to_thread(sys.stdin.read)
return ""
async def _read_stdin_stream(self) -> Any:
"""Returns a generator that yields lines from stdin in a background thread."""
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, lambda: iter(sys.stdin))
async def _stream_lines(self, line_gen, args, kwargs):
for line in line_gen:
parsed = self.from_input(line)
result = await self._run(parsed, *args, **kwargs)
self._last_result = result
output = self.to_output(result)
await self._write_stdout(output)
yield result
async def _write_stdout(self, data: str) -> None:
await asyncio.to_thread(sys.stdout.write, data)
await asyncio.to_thread(sys.stdout.flush)
async def _run(self, parsed_input: Any, *args, **kwargs) -> Any:
"""Subclasses should override this with actual logic."""
raise NotImplementedError("Must implement _run()")
def __str__(self):
return f"<{self.__class__.__name__} '{self.name}' IOAction>"
async def preview(self, parent: Tree | None = None):
label = [f"[{OneColors.GREEN_b}]⚙ IOAction[/] '{self.name}'"]
if self.inject_last_result:
label.append(f" [dim](injects '{self.inject_last_result_as}')[/dim]")
if parent:
parent.add("".join(label))
else:
console.print(Tree("".join(label)))
class UppercaseIO(BaseIOAction):
def from_input(self, raw: str | bytes) -> str:
if not isinstance(raw, (str, bytes)):
raise TypeError(f"{self.name} expected str or bytes input, got {type(raw).__name__}")
return raw.strip() if isinstance(raw, str) else raw.decode("utf-8").strip()
async def _run(self, parsed_input: str, *args, **kwargs) -> str:
return parsed_input.upper()
def to_output(self, data: str) -> str:
return data + "\n"
class ShellAction(BaseIOAction):
def __init__(self, name: str, command_template: str, **kwargs):
super().__init__(name=name, **kwargs)
self.command_template = command_template
def from_input(self, raw: str | bytes) -> str:
if not isinstance(raw, (str, bytes)):
raise TypeError(f"{self.name} expected str or bytes input, got {type(raw).__name__}")
return raw.strip() if isinstance(raw, str) else raw.decode("utf-8").strip()
async def _run(self, parsed_input: str) -> str:
# Replace placeholder in template, or use raw input ddas full command
command = self.command_template.format(parsed_input)
result = subprocess.run(
command, shell=True, text=True, capture_output=True
)
if result.returncode != 0:
raise RuntimeError(result.stderr.strip())
return result.stdout.strip()
def to_output(self, result: str) -> str:
return result
async def preview(self, parent: Tree | None = None):
label = [f"[{OneColors.GREEN_b}]⚙ ShellAction[/] '{self.name}'"]
if self.inject_last_result:
label.append(f" [dim](injects '{self.inject_last_result_as}')[/dim]")
if parent:
parent.add("".join(label))
else:
console.print(Tree("".join(label)))