3 Commits

6 changed files with 248 additions and 70 deletions

View File

@ -11,7 +11,6 @@ This guarantees:
from __future__ import annotations
import asyncio
import inspect
import random
from abc import ABC, abstractmethod
from concurrent.futures import ProcessPoolExecutor
@ -23,7 +22,6 @@ from rich.tree import Tree
from falyx.context import ExecutionContext, ResultsContext
from falyx.debug import register_debug_hooks
from falyx.exceptions import FalyxError
from falyx.execution_registry import ExecutionRegistry as er
from falyx.hook_manager import Hook, HookManager, HookType
from falyx.retry import RetryHandler, RetryPolicy
@ -57,6 +55,7 @@ class BaseAction(ABC):
self.results_context: ResultsContext | None = None
self.inject_last_result: bool = inject_last_result
self.inject_last_result_as: str = inject_last_result_as
self.requires_injection: bool = False
if logging_hooks:
register_debug_hooks(self.hooks)
@ -121,36 +120,16 @@ class BaseAction(ABC):
"""Override in subclasses that produce terminal output."""
pass
def requires_io_injection(self) -> bool:
"""Checks to see if the action requires input injection."""
return self.requires_injection
def __str__(self):
return f"<{self.__class__.__name__} '{self.name}'>"
def __repr__(self):
return str(self)
def __or__(self, other: BaseAction) -> ChainedAction:
"""Chain this action with another action."""
if not isinstance(other, BaseAction):
raise FalyxError(f"Cannot chain {type(other)} with {type(self)}")
return ChainedAction(name=f"{self.name} | {other.name}", actions=[self, other])
async def __ror__(self, other: Any):
if inspect.isawaitable(other):
print(1)
other = await other
if self.inject_last_result:
print(2)
return await self(**{self.inject_last_result_as: other})
literal_action = Action(
name=f"Input | {self.name}",
action=lambda: other,
)
chain = ChainedAction(name=f"{other} | {self.name}", actions=[literal_action, self])
print(3)
print(self.name, other)
return await chain()
class Action(BaseAction):
"""A simple action that runs a callable. It can be a function or a coroutine."""
@ -237,7 +216,8 @@ class Action(BaseAction):
class LiteralInputAction(Action):
def __init__(self, value: Any):
async def literal(*args, **kwargs): return value
async def literal(*args, **kwargs):
return value
super().__init__("Input", literal, inject_last_result=True)
@ -277,16 +257,32 @@ class ChainedAction(BaseAction, ActionListMixin):
def __init__(
self,
name: str,
actions: list[BaseAction] | None = None,
actions: list[BaseAction | Any] | None = None,
hooks: HookManager | None = None,
inject_last_result: bool = False,
inject_last_result_as: str = "last_result",
auto_inject: bool = False,
) -> None:
super().__init__(name, hooks, inject_last_result, inject_last_result_as)
ActionListMixin.__init__(self)
self.auto_inject = auto_inject
if actions:
self.set_actions(actions)
def _wrap_literal_if_needed(self, action: BaseAction | Any) -> BaseAction:
return LiteralInputAction(action) if not isinstance(action, BaseAction) else action
def _apply_auto_inject(self, action: BaseAction) -> None:
if self.auto_inject and not action.inject_last_result:
action.inject_last_result = True
def set_actions(self, actions: list[BaseAction | Any]):
self.actions.clear()
for action in actions:
action = self._wrap_literal_if_needed(action)
self._apply_auto_inject(action)
self.add_action(action)
async def _run(self, *args, **kwargs) -> list[Any]:
results_context = ResultsContext(name=self.name)
if self.results_context:
@ -302,43 +298,22 @@ class ChainedAction(BaseAction, ActionListMixin):
context.start_timer()
try:
await self.hooks.trigger(HookType.BEFORE, context)
last_result = self.results_context.last_result() if self.results_context else None
for index, action in enumerate(self.actions):
results_context.current_index = index
prepared = action.prepare_for_chain(results_context)
run_kwargs = dict(updated_kwargs)
underlying = getattr(prepared, "action", None)
if underlying:
signature = inspect.signature(underlying)
else:
signature = inspect.signature(prepared._run)
parameters = signature.parameters
if last_result is not None:
if action.inject_last_result_as in parameters:
run_kwargs[action.inject_last_result_as] = last_result
result = await prepared(*args, **run_kwargs)
elif (
len(parameters) == 1 and
not parameters.get("self")
):
result = await prepared(last_result)
last_result = results_context.last_result()
if self.requires_io_injection() and last_result is not None:
result = await prepared(**{prepared.inject_last_result_as: last_result})
else:
result = await prepared(*args, **updated_kwargs)
else:
result = await prepared(*args, **updated_kwargs)
last_result = result
results_context.add_result(result)
context.extra["results"].append(result)
context.extra["rollback_stack"].append(prepared)
context.result = last_result
context.result = context.extra["results"]
await self.hooks.trigger(HookType.ON_SUCCESS, context)
return last_result
return context.result
except Exception as error:
context.exception = error

View File

@ -30,7 +30,6 @@ class BottomBar:
) -> None:
self.columns = columns
self.console = Console()
self._items: list[Callable[[], HTML]] = []
self._named_items: dict[str, Callable[[], HTML]] = {}
self._value_getters: dict[str, Callable[[], Any]] = CaseInsensitiveDict()
self.toggle_keys: list[str] = []
@ -99,6 +98,7 @@ class BottomBar:
total: int,
fg: str = OneColors.BLACK,
bg: str = OneColors.WHITE,
enforce_total: bool = True,
) -> None:
if not callable(get_current):
raise ValueError("`get_current` must be a callable returning int")
@ -108,7 +108,7 @@ class BottomBar:
def render():
get_current_ = self._value_getters[name]
current_value = get_current_()
if current_value > total:
if current_value > total and enforce_total:
raise ValueError(
f"Current value {current_value} is greater than total value {total}"
)
@ -169,6 +169,7 @@ class BottomBar:
bg_on: str = OneColors.GREEN,
bg_off: str = OneColors.DARK_RED,
) -> None:
"""Add a toggle to the bottom bar based on an option from OptionsManager."""
self.add_toggle(
key=key,
label=label,
@ -185,15 +186,30 @@ class BottomBar:
return {label: getter() for label, getter in self._value_getters.items()}
def get_value(self, name: str) -> Any:
"""Get the current value of a registered item."""
if name not in self._value_getters:
raise ValueError(f"No value getter registered under name: '{name}'")
return self._value_getters[name]()
def remove_item(self, name: str) -> None:
"""Remove an item from the bottom bar."""
self._named_items.pop(name, None)
self._value_getters.pop(name, None)
if name in self.toggle_keys:
self.toggle_keys.remove(name)
def clear(self) -> None:
"""Clear all items from the bottom bar."""
self._value_getters.clear()
self._named_items.clear()
self.toggle_keys.clear()
def _add_named(self, name: str, render_fn: Callable[[], HTML]) -> None:
if name in self._named_items:
raise ValueError(f"Bottom bar item '{name}' already exists")
self._named_items[name] = render_fn
self._items = list(self._named_items.values())
def render(self):
return merge_formatted_text([fn() for fn in self._items])
"""Render the bottom bar."""
return merge_formatted_text([fn() for fn in self._named_items.values()])

View File

@ -159,7 +159,7 @@ class Command(BaseModel):
elif callable(self.action):
console.print(f"{label}")
console.print(
f"[{OneColors.LIGHT_RED_b}]→ Would call:[/] {self.action.__name__} "
f"[{OneColors.LIGHT_RED_b}]→ Would call:[/] {self.action.__name__}"
f"[dim](args={self.args}, kwargs={self.kwargs})[/dim]"
)
else:

View File

@ -113,6 +113,7 @@ class Falyx:
self.cli_args: Namespace | None = cli_args
self.custom_table: Callable[["Falyx"], Table] | Table | None = custom_table
self.set_options(cli_args, options)
self._session: PromptSession | None = None
def set_options(
self,
@ -127,6 +128,7 @@ class Falyx:
if options and not cli_args:
raise FalyxError("Options are set, but CLI arguments are not.")
assert isinstance(cli_args, Namespace), "CLI arguments must be a Namespace object."
if options is None:
self.options.from_namespace(cli_args, "cli_args")
@ -301,6 +303,7 @@ class Falyx:
"""Forces the session to be recreated on the next access."""
if hasattr(self, "session"):
del self.session
self._session = None
def add_help_command(self):
"""Adds a help command to the menu if it doesn't already exist."""
@ -335,7 +338,7 @@ class Falyx:
def _get_bottom_bar_render(self) -> Callable[[], Any] | str | None:
"""Returns the bottom bar for the menu."""
if isinstance(self.bottom_bar, BottomBar) and self.bottom_bar._items:
if isinstance(self.bottom_bar, BottomBar) and self.bottom_bar._named_items:
return self._bottom_bar.render
elif callable(self._bottom_bar):
return self._bottom_bar
@ -348,7 +351,8 @@ class Falyx:
@cached_property
def session(self) -> PromptSession:
"""Returns the prompt session for the menu."""
return PromptSession(
if self._session is None:
self._session = PromptSession(
message=self.prompt,
multiline=False,
completer=self._get_completer(),
@ -357,6 +361,7 @@ class Falyx:
bottom_toolbar=self._get_bottom_bar_render(),
key_bindings=self.key_bindings,
)
return self._session
def register_all_hooks(self, hook_type: HookType, hooks: Hook | list[Hook]) -> None:
"""Registers hooks for all commands in the menu and actions recursively."""
@ -745,7 +750,7 @@ class Falyx:
selected_command.retry_policy.delay = self.cli_args.retry_delay
if self.cli_args.retry_backoff:
selected_command.retry_policy.backoff = self.cli_args.retry_backoff
selected_command.update_retry_policy(selected_command.retry_policy)
#selected_command.update_retry_policy(selected_command.retry_policy)
def print_message(self, message: str | Markdown | dict[str, Any]) -> None:
"""Prints a message to the console."""

View File

@ -19,6 +19,8 @@ class ResultReporter:
return "ResultReporter"
async def report(self, context: ExecutionContext):
if not callable(self.formatter):
raise TypeError("formatter must be callable")
if context.result is not None:
result_text = self.formatter(context.result)
duration = f"{context.duration:.3f}s" if context.duration is not None else "n/a"

180
falyx/io_action.py Normal file
View File

@ -0,0 +1,180 @@
"""io_action.py"""
import asyncio
import subprocess
import sys
from typing import Any
from rich.console import Console
from rich.tree import Tree
from falyx.action import BaseAction
from falyx.context import ExecutionContext
from falyx.exceptions import FalyxError
from falyx.execution_registry import ExecutionRegistry as er
from falyx.hook_manager import HookManager, HookType
from falyx.utils import logger
from falyx.themes.colors import OneColors
console = Console()
class BaseIOAction(BaseAction):
def __init__(
self,
name: str,
hooks: HookManager | None = None,
mode: str = "buffered",
logging_hooks: bool = True,
inject_last_result: bool = True,
):
super().__init__(
name=name,
hooks=hooks,
logging_hooks=logging_hooks,
inject_last_result=inject_last_result,
)
self.mode = mode
self.requires_injection = True
def from_input(self, raw: str | bytes) -> Any:
raise NotImplementedError
def to_output(self, data: Any) -> str | bytes:
raise NotImplementedError
async def _resolve_input(self, kwargs: dict[str, Any]) -> str | bytes:
last_result = kwargs.pop(self.inject_last_result_as, None)
data = await self._read_stdin()
if data:
return self.from_input(data)
if last_result is not None:
return last_result
if self.inject_last_result and self.results_context:
return self.results_context.last_result()
logger.debug("[%s] No input provided and no last result found for injection.", self.name)
raise FalyxError("No input provided and no last result to inject.")
async def __call__(self, *args, **kwargs):
context = ExecutionContext(
name=self.name,
args=args,
kwargs=kwargs,
action=self,
)
context.start_timer()
await self.hooks.trigger(HookType.BEFORE, context)
try:
if self.mode == "stream":
line_gen = await self._read_stdin_stream()
async for line in self._stream_lines(line_gen, args, kwargs):
pass
result = getattr(self, "_last_result", None)
else:
parsed_input = await self._resolve_input(kwargs)
result = await self._run(parsed_input, *args, **kwargs)
output = self.to_output(result)
await self._write_stdout(output)
context.result = result
await self.hooks.trigger(HookType.ON_SUCCESS, context)
return result
except Exception as error:
context.exception = error
await self.hooks.trigger(HookType.ON_ERROR, context)
raise
finally:
context.stop_timer()
await self.hooks.trigger(HookType.AFTER, context)
await self.hooks.trigger(HookType.ON_TEARDOWN, context)
er.record(context)
async def _read_stdin(self) -> str:
if not sys.stdin.isatty():
return await asyncio.to_thread(sys.stdin.read)
return ""
async def _read_stdin_stream(self) -> Any:
"""Returns a generator that yields lines from stdin in a background thread."""
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, lambda: iter(sys.stdin))
async def _stream_lines(self, line_gen, args, kwargs):
for line in line_gen:
parsed = self.from_input(line)
result = await self._run(parsed, *args, **kwargs)
self._last_result = result
output = self.to_output(result)
await self._write_stdout(output)
yield result
async def _write_stdout(self, data: str) -> None:
await asyncio.to_thread(sys.stdout.write, data)
await asyncio.to_thread(sys.stdout.flush)
async def _run(self, parsed_input: Any, *args, **kwargs) -> Any:
"""Subclasses should override this with actual logic."""
raise NotImplementedError("Must implement _run()")
def __str__(self):
return f"<{self.__class__.__name__} '{self.name}' IOAction>"
async def preview(self, parent: Tree | None = None):
label = [f"[{OneColors.GREEN_b}]⚙ IOAction[/] '{self.name}'"]
if self.inject_last_result:
label.append(f" [dim](injects '{self.inject_last_result_as}')[/dim]")
if parent:
parent.add("".join(label))
else:
console.print(Tree("".join(label)))
class UppercaseIO(BaseIOAction):
def from_input(self, raw: str | bytes) -> str:
if not isinstance(raw, (str, bytes)):
raise TypeError(f"{self.name} expected str or bytes input, got {type(raw).__name__}")
return raw.strip() if isinstance(raw, str) else raw.decode("utf-8").strip()
async def _run(self, parsed_input: str, *args, **kwargs) -> str:
return parsed_input.upper()
def to_output(self, data: str) -> str:
return data + "\n"
class ShellAction(BaseIOAction):
def __init__(self, name: str, command_template: str, **kwargs):
super().__init__(name=name, **kwargs)
self.command_template = command_template
def from_input(self, raw: str | bytes) -> str:
if not isinstance(raw, (str, bytes)):
raise TypeError(f"{self.name} expected str or bytes input, got {type(raw).__name__}")
return raw.strip() if isinstance(raw, str) else raw.decode("utf-8").strip()
async def _run(self, parsed_input: str) -> str:
# Replace placeholder in template, or use raw input ddas full command
command = self.command_template.format(parsed_input)
result = subprocess.run(
command, shell=True, text=True, capture_output=True
)
if result.returncode != 0:
raise RuntimeError(result.stderr.strip())
return result.stdout.strip()
def to_output(self, result: str) -> str:
return result
async def preview(self, parent: Tree | None = None):
label = [f"[{OneColors.GREEN_b}]⚙ ShellAction[/] '{self.name}'"]
if self.inject_last_result:
label.append(f" [dim](injects '{self.inject_last_result_as}')[/dim]")
if parent:
parent.add("".join(label))
else:
console.print(Tree("".join(label)))