1 Commits

Author SHA1 Message Date
3fd27094d4 Experiemental feature pipes 2025-04-24 19:13:52 -04:00
6 changed files with 70 additions and 248 deletions

View File

@ -11,6 +11,7 @@ This guarantees:
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import inspect
import random import random
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from concurrent.futures import ProcessPoolExecutor from concurrent.futures import ProcessPoolExecutor
@ -22,6 +23,7 @@ from rich.tree import Tree
from falyx.context import ExecutionContext, ResultsContext from falyx.context import ExecutionContext, ResultsContext
from falyx.debug import register_debug_hooks from falyx.debug import register_debug_hooks
from falyx.exceptions import FalyxError
from falyx.execution_registry import ExecutionRegistry as er from falyx.execution_registry import ExecutionRegistry as er
from falyx.hook_manager import Hook, HookManager, HookType from falyx.hook_manager import Hook, HookManager, HookType
from falyx.retry import RetryHandler, RetryPolicy from falyx.retry import RetryHandler, RetryPolicy
@ -55,7 +57,6 @@ class BaseAction(ABC):
self.results_context: ResultsContext | None = None self.results_context: ResultsContext | None = None
self.inject_last_result: bool = inject_last_result self.inject_last_result: bool = inject_last_result
self.inject_last_result_as: str = inject_last_result_as self.inject_last_result_as: str = inject_last_result_as
self.requires_injection: bool = False
if logging_hooks: if logging_hooks:
register_debug_hooks(self.hooks) register_debug_hooks(self.hooks)
@ -120,16 +121,36 @@ class BaseAction(ABC):
"""Override in subclasses that produce terminal output.""" """Override in subclasses that produce terminal output."""
pass pass
def requires_io_injection(self) -> bool:
"""Checks to see if the action requires input injection."""
return self.requires_injection
def __str__(self): def __str__(self):
return f"<{self.__class__.__name__} '{self.name}'>" return f"<{self.__class__.__name__} '{self.name}'>"
def __repr__(self): def __repr__(self):
return str(self) return str(self)
def __or__(self, other: BaseAction) -> ChainedAction:
"""Chain this action with another action."""
if not isinstance(other, BaseAction):
raise FalyxError(f"Cannot chain {type(other)} with {type(self)}")
return ChainedAction(name=f"{self.name} | {other.name}", actions=[self, other])
async def __ror__(self, other: Any):
if inspect.isawaitable(other):
print(1)
other = await other
if self.inject_last_result:
print(2)
return await self(**{self.inject_last_result_as: other})
literal_action = Action(
name=f"Input | {self.name}",
action=lambda: other,
)
chain = ChainedAction(name=f"{other} | {self.name}", actions=[literal_action, self])
print(3)
print(self.name, other)
return await chain()
class Action(BaseAction): class Action(BaseAction):
"""A simple action that runs a callable. It can be a function or a coroutine.""" """A simple action that runs a callable. It can be a function or a coroutine."""
@ -216,8 +237,7 @@ class Action(BaseAction):
class LiteralInputAction(Action): class LiteralInputAction(Action):
def __init__(self, value: Any): def __init__(self, value: Any):
async def literal(*args, **kwargs): async def literal(*args, **kwargs): return value
return value
super().__init__("Input", literal, inject_last_result=True) super().__init__("Input", literal, inject_last_result=True)
@ -257,32 +277,16 @@ class ChainedAction(BaseAction, ActionListMixin):
def __init__( def __init__(
self, self,
name: str, name: str,
actions: list[BaseAction | Any] | None = None, actions: list[BaseAction] | None = None,
hooks: HookManager | None = None, hooks: HookManager | None = None,
inject_last_result: bool = False, inject_last_result: bool = False,
inject_last_result_as: str = "last_result", inject_last_result_as: str = "last_result",
auto_inject: bool = False,
) -> None: ) -> None:
super().__init__(name, hooks, inject_last_result, inject_last_result_as) super().__init__(name, hooks, inject_last_result, inject_last_result_as)
ActionListMixin.__init__(self) ActionListMixin.__init__(self)
self.auto_inject = auto_inject
if actions: if actions:
self.set_actions(actions) self.set_actions(actions)
def _wrap_literal_if_needed(self, action: BaseAction | Any) -> BaseAction:
return LiteralInputAction(action) if not isinstance(action, BaseAction) else action
def _apply_auto_inject(self, action: BaseAction) -> None:
if self.auto_inject and not action.inject_last_result:
action.inject_last_result = True
def set_actions(self, actions: list[BaseAction | Any]):
self.actions.clear()
for action in actions:
action = self._wrap_literal_if_needed(action)
self._apply_auto_inject(action)
self.add_action(action)
async def _run(self, *args, **kwargs) -> list[Any]: async def _run(self, *args, **kwargs) -> list[Any]:
results_context = ResultsContext(name=self.name) results_context = ResultsContext(name=self.name)
if self.results_context: if self.results_context:
@ -298,22 +302,43 @@ class ChainedAction(BaseAction, ActionListMixin):
context.start_timer() context.start_timer()
try: try:
await self.hooks.trigger(HookType.BEFORE, context) await self.hooks.trigger(HookType.BEFORE, context)
last_result = self.results_context.last_result() if self.results_context else None
for index, action in enumerate(self.actions): for index, action in enumerate(self.actions):
results_context.current_index = index results_context.current_index = index
prepared = action.prepare_for_chain(results_context) prepared = action.prepare_for_chain(results_context)
last_result = results_context.last_result() run_kwargs = dict(updated_kwargs)
if self.requires_io_injection() and last_result is not None:
result = await prepared(**{prepared.inject_last_result_as: last_result}) underlying = getattr(prepared, "action", None)
if underlying:
signature = inspect.signature(underlying)
else:
signature = inspect.signature(prepared._run)
parameters = signature.parameters
if last_result is not None:
if action.inject_last_result_as in parameters:
run_kwargs[action.inject_last_result_as] = last_result
result = await prepared(*args, **run_kwargs)
elif (
len(parameters) == 1 and
not parameters.get("self")
):
result = await prepared(last_result)
else: else:
result = await prepared(*args, **updated_kwargs) result = await prepared(*args, **updated_kwargs)
else:
result = await prepared(*args, **updated_kwargs)
last_result = result
results_context.add_result(result) results_context.add_result(result)
context.extra["results"].append(result) context.extra["results"].append(result)
context.extra["rollback_stack"].append(prepared) context.extra["rollback_stack"].append(prepared)
context.result = context.extra["results"] context.result = last_result
await self.hooks.trigger(HookType.ON_SUCCESS, context) await self.hooks.trigger(HookType.ON_SUCCESS, context)
return context.result return last_result
except Exception as error: except Exception as error:
context.exception = error context.exception = error

View File

@ -30,6 +30,7 @@ class BottomBar:
) -> None: ) -> None:
self.columns = columns self.columns = columns
self.console = Console() self.console = Console()
self._items: list[Callable[[], HTML]] = []
self._named_items: dict[str, Callable[[], HTML]] = {} self._named_items: dict[str, Callable[[], HTML]] = {}
self._value_getters: dict[str, Callable[[], Any]] = CaseInsensitiveDict() self._value_getters: dict[str, Callable[[], Any]] = CaseInsensitiveDict()
self.toggle_keys: list[str] = [] self.toggle_keys: list[str] = []
@ -98,7 +99,6 @@ class BottomBar:
total: int, total: int,
fg: str = OneColors.BLACK, fg: str = OneColors.BLACK,
bg: str = OneColors.WHITE, bg: str = OneColors.WHITE,
enforce_total: bool = True,
) -> None: ) -> None:
if not callable(get_current): if not callable(get_current):
raise ValueError("`get_current` must be a callable returning int") raise ValueError("`get_current` must be a callable returning int")
@ -108,7 +108,7 @@ class BottomBar:
def render(): def render():
get_current_ = self._value_getters[name] get_current_ = self._value_getters[name]
current_value = get_current_() current_value = get_current_()
if current_value > total and enforce_total: if current_value > total:
raise ValueError( raise ValueError(
f"Current value {current_value} is greater than total value {total}" f"Current value {current_value} is greater than total value {total}"
) )
@ -169,7 +169,6 @@ class BottomBar:
bg_on: str = OneColors.GREEN, bg_on: str = OneColors.GREEN,
bg_off: str = OneColors.DARK_RED, bg_off: str = OneColors.DARK_RED,
) -> None: ) -> None:
"""Add a toggle to the bottom bar based on an option from OptionsManager."""
self.add_toggle( self.add_toggle(
key=key, key=key,
label=label, label=label,
@ -186,30 +185,15 @@ class BottomBar:
return {label: getter() for label, getter in self._value_getters.items()} return {label: getter() for label, getter in self._value_getters.items()}
def get_value(self, name: str) -> Any: def get_value(self, name: str) -> Any:
"""Get the current value of a registered item."""
if name not in self._value_getters: if name not in self._value_getters:
raise ValueError(f"No value getter registered under name: '{name}'") raise ValueError(f"No value getter registered under name: '{name}'")
return self._value_getters[name]() return self._value_getters[name]()
def remove_item(self, name: str) -> None:
"""Remove an item from the bottom bar."""
self._named_items.pop(name, None)
self._value_getters.pop(name, None)
if name in self.toggle_keys:
self.toggle_keys.remove(name)
def clear(self) -> None:
"""Clear all items from the bottom bar."""
self._value_getters.clear()
self._named_items.clear()
self.toggle_keys.clear()
def _add_named(self, name: str, render_fn: Callable[[], HTML]) -> None: def _add_named(self, name: str, render_fn: Callable[[], HTML]) -> None:
if name in self._named_items: if name in self._named_items:
raise ValueError(f"Bottom bar item '{name}' already exists") raise ValueError(f"Bottom bar item '{name}' already exists")
self._named_items[name] = render_fn self._named_items[name] = render_fn
self._items = list(self._named_items.values())
def render(self): def render(self):
"""Render the bottom bar.""" return merge_formatted_text([fn() for fn in self._items])
return merge_formatted_text([fn() for fn in self._named_items.values()])

View File

@ -113,7 +113,6 @@ class Falyx:
self.cli_args: Namespace | None = cli_args self.cli_args: Namespace | None = cli_args
self.custom_table: Callable[["Falyx"], Table] | Table | None = custom_table self.custom_table: Callable[["Falyx"], Table] | Table | None = custom_table
self.set_options(cli_args, options) self.set_options(cli_args, options)
self._session: PromptSession | None = None
def set_options( def set_options(
self, self,
@ -128,7 +127,6 @@ class Falyx:
if options and not cli_args: if options and not cli_args:
raise FalyxError("Options are set, but CLI arguments are not.") raise FalyxError("Options are set, but CLI arguments are not.")
assert isinstance(cli_args, Namespace), "CLI arguments must be a Namespace object."
if options is None: if options is None:
self.options.from_namespace(cli_args, "cli_args") self.options.from_namespace(cli_args, "cli_args")
@ -303,7 +301,6 @@ class Falyx:
"""Forces the session to be recreated on the next access.""" """Forces the session to be recreated on the next access."""
if hasattr(self, "session"): if hasattr(self, "session"):
del self.session del self.session
self._session = None
def add_help_command(self): def add_help_command(self):
"""Adds a help command to the menu if it doesn't already exist.""" """Adds a help command to the menu if it doesn't already exist."""
@ -338,7 +335,7 @@ class Falyx:
def _get_bottom_bar_render(self) -> Callable[[], Any] | str | None: def _get_bottom_bar_render(self) -> Callable[[], Any] | str | None:
"""Returns the bottom bar for the menu.""" """Returns the bottom bar for the menu."""
if isinstance(self.bottom_bar, BottomBar) and self.bottom_bar._named_items: if isinstance(self.bottom_bar, BottomBar) and self.bottom_bar._items:
return self._bottom_bar.render return self._bottom_bar.render
elif callable(self._bottom_bar): elif callable(self._bottom_bar):
return self._bottom_bar return self._bottom_bar
@ -351,8 +348,7 @@ class Falyx:
@cached_property @cached_property
def session(self) -> PromptSession: def session(self) -> PromptSession:
"""Returns the prompt session for the menu.""" """Returns the prompt session for the menu."""
if self._session is None: return PromptSession(
self._session = PromptSession(
message=self.prompt, message=self.prompt,
multiline=False, multiline=False,
completer=self._get_completer(), completer=self._get_completer(),
@ -361,7 +357,6 @@ class Falyx:
bottom_toolbar=self._get_bottom_bar_render(), bottom_toolbar=self._get_bottom_bar_render(),
key_bindings=self.key_bindings, key_bindings=self.key_bindings,
) )
return self._session
def register_all_hooks(self, hook_type: HookType, hooks: Hook | list[Hook]) -> None: def register_all_hooks(self, hook_type: HookType, hooks: Hook | list[Hook]) -> None:
"""Registers hooks for all commands in the menu and actions recursively.""" """Registers hooks for all commands in the menu and actions recursively."""
@ -750,7 +745,7 @@ class Falyx:
selected_command.retry_policy.delay = self.cli_args.retry_delay selected_command.retry_policy.delay = self.cli_args.retry_delay
if self.cli_args.retry_backoff: if self.cli_args.retry_backoff:
selected_command.retry_policy.backoff = self.cli_args.retry_backoff selected_command.retry_policy.backoff = self.cli_args.retry_backoff
#selected_command.update_retry_policy(selected_command.retry_policy) selected_command.update_retry_policy(selected_command.retry_policy)
def print_message(self, message: str | Markdown | dict[str, Any]) -> None: def print_message(self, message: str | Markdown | dict[str, Any]) -> None:
"""Prints a message to the console.""" """Prints a message to the console."""

View File

@ -19,8 +19,6 @@ class ResultReporter:
return "ResultReporter" return "ResultReporter"
async def report(self, context: ExecutionContext): async def report(self, context: ExecutionContext):
if not callable(self.formatter):
raise TypeError("formatter must be callable")
if context.result is not None: if context.result is not None:
result_text = self.formatter(context.result) result_text = self.formatter(context.result)
duration = f"{context.duration:.3f}s" if context.duration is not None else "n/a" duration = f"{context.duration:.3f}s" if context.duration is not None else "n/a"

View File

@ -1,180 +0,0 @@
"""io_action.py"""
import asyncio
import subprocess
import sys
from typing import Any
from rich.console import Console
from rich.tree import Tree
from falyx.action import BaseAction
from falyx.context import ExecutionContext
from falyx.exceptions import FalyxError
from falyx.execution_registry import ExecutionRegistry as er
from falyx.hook_manager import HookManager, HookType
from falyx.utils import logger
from falyx.themes.colors import OneColors
console = Console()
class BaseIOAction(BaseAction):
def __init__(
self,
name: str,
hooks: HookManager | None = None,
mode: str = "buffered",
logging_hooks: bool = True,
inject_last_result: bool = True,
):
super().__init__(
name=name,
hooks=hooks,
logging_hooks=logging_hooks,
inject_last_result=inject_last_result,
)
self.mode = mode
self.requires_injection = True
def from_input(self, raw: str | bytes) -> Any:
raise NotImplementedError
def to_output(self, data: Any) -> str | bytes:
raise NotImplementedError
async def _resolve_input(self, kwargs: dict[str, Any]) -> str | bytes:
last_result = kwargs.pop(self.inject_last_result_as, None)
data = await self._read_stdin()
if data:
return self.from_input(data)
if last_result is not None:
return last_result
if self.inject_last_result and self.results_context:
return self.results_context.last_result()
logger.debug("[%s] No input provided and no last result found for injection.", self.name)
raise FalyxError("No input provided and no last result to inject.")
async def __call__(self, *args, **kwargs):
context = ExecutionContext(
name=self.name,
args=args,
kwargs=kwargs,
action=self,
)
context.start_timer()
await self.hooks.trigger(HookType.BEFORE, context)
try:
if self.mode == "stream":
line_gen = await self._read_stdin_stream()
async for line in self._stream_lines(line_gen, args, kwargs):
pass
result = getattr(self, "_last_result", None)
else:
parsed_input = await self._resolve_input(kwargs)
result = await self._run(parsed_input, *args, **kwargs)
output = self.to_output(result)
await self._write_stdout(output)
context.result = result
await self.hooks.trigger(HookType.ON_SUCCESS, context)
return result
except Exception as error:
context.exception = error
await self.hooks.trigger(HookType.ON_ERROR, context)
raise
finally:
context.stop_timer()
await self.hooks.trigger(HookType.AFTER, context)
await self.hooks.trigger(HookType.ON_TEARDOWN, context)
er.record(context)
async def _read_stdin(self) -> str:
if not sys.stdin.isatty():
return await asyncio.to_thread(sys.stdin.read)
return ""
async def _read_stdin_stream(self) -> Any:
"""Returns a generator that yields lines from stdin in a background thread."""
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, lambda: iter(sys.stdin))
async def _stream_lines(self, line_gen, args, kwargs):
for line in line_gen:
parsed = self.from_input(line)
result = await self._run(parsed, *args, **kwargs)
self._last_result = result
output = self.to_output(result)
await self._write_stdout(output)
yield result
async def _write_stdout(self, data: str) -> None:
await asyncio.to_thread(sys.stdout.write, data)
await asyncio.to_thread(sys.stdout.flush)
async def _run(self, parsed_input: Any, *args, **kwargs) -> Any:
"""Subclasses should override this with actual logic."""
raise NotImplementedError("Must implement _run()")
def __str__(self):
return f"<{self.__class__.__name__} '{self.name}' IOAction>"
async def preview(self, parent: Tree | None = None):
label = [f"[{OneColors.GREEN_b}]⚙ IOAction[/] '{self.name}'"]
if self.inject_last_result:
label.append(f" [dim](injects '{self.inject_last_result_as}')[/dim]")
if parent:
parent.add("".join(label))
else:
console.print(Tree("".join(label)))
class UppercaseIO(BaseIOAction):
def from_input(self, raw: str | bytes) -> str:
if not isinstance(raw, (str, bytes)):
raise TypeError(f"{self.name} expected str or bytes input, got {type(raw).__name__}")
return raw.strip() if isinstance(raw, str) else raw.decode("utf-8").strip()
async def _run(self, parsed_input: str, *args, **kwargs) -> str:
return parsed_input.upper()
def to_output(self, data: str) -> str:
return data + "\n"
class ShellAction(BaseIOAction):
def __init__(self, name: str, command_template: str, **kwargs):
super().__init__(name=name, **kwargs)
self.command_template = command_template
def from_input(self, raw: str | bytes) -> str:
if not isinstance(raw, (str, bytes)):
raise TypeError(f"{self.name} expected str or bytes input, got {type(raw).__name__}")
return raw.strip() if isinstance(raw, str) else raw.decode("utf-8").strip()
async def _run(self, parsed_input: str) -> str:
# Replace placeholder in template, or use raw input ddas full command
command = self.command_template.format(parsed_input)
result = subprocess.run(
command, shell=True, text=True, capture_output=True
)
if result.returncode != 0:
raise RuntimeError(result.stderr.strip())
return result.stdout.strip()
def to_output(self, result: str) -> str:
return result
async def preview(self, parent: Tree | None = None):
label = [f"[{OneColors.GREEN_b}]⚙ ShellAction[/] '{self.name}'"]
if self.inject_last_result:
label.append(f" [dim](injects '{self.inject_last_result_as}')[/dim]")
if parent:
parent.add("".join(label))
else:
console.print(Tree("".join(label)))