Add ProcessPoolAction, update CAP to look only at keywords correctly

This commit is contained in:
Roland Thomas Jr 2025-05-28 00:58:50 -04:00
parent fb1ffbe9f6
commit f196e38c57
Signed by: roland
GPG Key ID: 7C3C2B085A4C2872
10 changed files with 282 additions and 28 deletions

View File

@ -52,7 +52,8 @@ poetry install
import asyncio
import random
from falyx import Falyx, Action, ChainedAction
from falyx import Falyx
from falyx.action import Action, ChainedAction
# A flaky async step that fails randomly
async def flaky_step():
@ -62,8 +63,8 @@ async def flaky_step():
return "ok"
# Create the actions
step1 = Action(name="step_1", action=flaky_step, retry=True)
step2 = Action(name="step_2", action=flaky_step, retry=True)
step1 = Action(name="step_1", action=flaky_step)
step2 = Action(name="step_2", action=flaky_step)
# Chain the actions
chain = ChainedAction(name="my_pipeline", actions=[step1, step2])
@ -74,9 +75,9 @@ falyx.add_command(
key="R",
description="Run My Pipeline",
action=chain,
logging_hooks=True,
preview_before_confirm=True,
confirm=True,
retry_all=True,
)
# Entry point

View File

@ -1,26 +1,36 @@
from rich.console import Console
from falyx import Falyx
from falyx.action import ProcessAction
from falyx.action import ProcessPoolAction
from falyx.action.process_pool_action import ProcessTask
from falyx.execution_registry import ExecutionRegistry as er
from falyx.themes import NordColors as nc
console = Console()
falyx = Falyx(title="🚀 Process Pool Demo")
def generate_primes(n):
primes = []
for num in range(2, n):
def generate_primes(start: int = 2, end: int = 100_000) -> list[int]:
primes: list[int] = []
console.print(f"Generating primes from {start} to {end}...", style=nc.YELLOW)
for num in range(start, end):
if all(num % p != 0 for p in primes):
primes.append(num)
console.print(f"Generated {len(primes)} primes up to {n}.", style=nc.GREEN)
console.print(
f"Generated {len(primes)} primes from {start} to {end}.", style=nc.GREEN
)
return primes
# Will not block the event loop
heavy_action = ProcessAction("Prime Generator", generate_primes, args=(100_000,))
actions = [ProcessTask(task=generate_primes)]
falyx.add_command("R", "Generate Primes", heavy_action, spinner=True)
# Will not block the event loop
heavy_action = ProcessPoolAction(
name="Prime Generator",
actions=actions,
)
falyx.add_command("R", "Generate Primes", heavy_action)
if __name__ == "__main__":

View File

@ -16,6 +16,7 @@ from .io_action import BaseIOAction, ShellAction
from .literal_input_action import LiteralInputAction
from .menu_action import MenuAction
from .process_action import ProcessAction
from .process_pool_action import ProcessPoolAction
from .prompt_menu_action import PromptMenuAction
from .select_file_action import SelectFileAction
from .selection_action import SelectionAction
@ -40,4 +41,5 @@ __all__ = [
"LiteralInputAction",
"UserInputAction",
"PromptMenuAction",
"ProcessPoolAction",
]

View File

@ -165,5 +165,6 @@ class ActionGroup(BaseAction, ActionListMixin):
def __str__(self):
return (
f"ActionGroup(name={self.name!r}, actions={[a.name for a in self.actions]!r},"
f" inject_last_result={self.inject_last_result})"
f" inject_last_result={self.inject_last_result}, "
f"inject_into={self.inject_into!r})"
)

View File

@ -0,0 +1,166 @@
from __future__ import annotations
import asyncio
import random
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Callable
from rich.tree import Tree
from falyx.action.base import BaseAction
from falyx.context import ExecutionContext, SharedContext
from falyx.execution_registry import ExecutionRegistry as er
from falyx.hook_manager import HookManager, HookType
from falyx.logger import logger
from falyx.parsers.utils import same_argument_definitions
from falyx.themes import OneColors
@dataclass
class ProcessTask:
task: Callable[..., Any]
args: tuple = ()
kwargs: dict[str, Any] = field(default_factory=dict)
def __post_init__(self):
if not callable(self.task):
raise TypeError(f"Expected a callable task, got {type(self.task).__name__}")
class ProcessPoolAction(BaseAction):
""" """
def __init__(
self,
name: str,
actions: list[ProcessTask] | None = None,
*,
hooks: HookManager | None = None,
executor: ProcessPoolExecutor | None = None,
inject_last_result: bool = False,
inject_into: str = "last_result",
):
super().__init__(
name,
hooks=hooks,
inject_last_result=inject_last_result,
inject_into=inject_into,
)
self.executor = executor or ProcessPoolExecutor()
self.is_retryable = True
self.actions: list[ProcessTask] = []
if actions:
self.set_actions(actions)
def set_actions(self, actions: list[ProcessTask]) -> None:
"""Replaces the current action list with a new one."""
self.actions.clear()
for action in actions:
self.add_action(action)
def add_action(self, action: ProcessTask) -> None:
if not isinstance(action, ProcessTask):
raise TypeError(f"Expected a ProcessTask, got {type(action).__name__}")
self.actions.append(action)
def get_infer_target(self) -> tuple[Callable[..., Any] | None, None]:
arg_defs = same_argument_definitions([action.task for action in self.actions])
if arg_defs:
return self.actions[0].task, None
logger.debug(
"[%s] auto_args disabled: mismatched ProcessPoolAction arguments",
self.name,
)
return None, None
async def _run(self, *args, **kwargs) -> Any:
shared_context = SharedContext(name=self.name, action=self, is_parallel=True)
if self.shared_context:
shared_context.set_shared_result(self.shared_context.last_result())
if self.inject_last_result and self.shared_context:
last_result = self.shared_context.last_result()
if not self._validate_pickleable(last_result):
raise ValueError(
f"Cannot inject last result into {self.name}: "
f"last result is not pickleable."
)
print(kwargs)
updated_kwargs = self._maybe_inject_last_result(kwargs)
print(updated_kwargs)
context = ExecutionContext(
name=self.name,
args=args,
kwargs=updated_kwargs,
action=self,
)
loop = asyncio.get_running_loop()
context.start_timer()
try:
await self.hooks.trigger(HookType.BEFORE, context)
futures = [
loop.run_in_executor(
self.executor,
partial(
task.task,
*(*args, *task.args),
**{**updated_kwargs, **task.kwargs},
),
)
for task in self.actions
]
results = await asyncio.gather(*futures, return_exceptions=True)
context.result = results
await self.hooks.trigger(HookType.ON_SUCCESS, context)
return results
except Exception as error:
context.exception = error
await self.hooks.trigger(HookType.ON_ERROR, context)
if context.result is not None:
return context.result
raise
finally:
context.stop_timer()
await self.hooks.trigger(HookType.AFTER, context)
await self.hooks.trigger(HookType.ON_TEARDOWN, context)
er.record(context)
def _validate_pickleable(self, obj: Any) -> bool:
try:
import pickle
pickle.dumps(obj)
return True
except (pickle.PicklingError, TypeError):
return False
async def preview(self, parent: Tree | None = None):
label = [f"[{OneColors.DARK_YELLOW_b}]🧠 ProcessPoolAction[/] '{self.name}'"]
if self.inject_last_result:
label.append(f" [dim](receives '{self.inject_into}')[/dim]")
tree = parent.add("".join(label)) if parent else Tree("".join(label))
actions = self.actions.copy()
random.shuffle(actions)
for action in actions:
label = [
f"[{OneColors.DARK_YELLOW_b}] - {getattr(action.task, '__name__', repr(action.task))}[/] "
f"[dim]({', '.join(map(repr, action.args))})[/]"
]
if action.kwargs:
label.append(
f" [dim]({', '.join(f'{k}={v!r}' for k, v in action.kwargs.items())})[/]"
)
tree.add("".join(label))
if not parent:
self.console.print(tree)
def __str__(self) -> str:
return (
f"ProcessPoolAction(name={self.name!r}, "
f"actions={[getattr(action.task, '__name__', repr(action.task)) for action in self.actions]}, "
f"inject_last_result={self.inject_last_result}, "
f"inject_into={self.inject_into!r})"
)

View File

@ -166,8 +166,8 @@ class CommandArgumentParser:
self.help_epilogue: str = help_epilogue
self.aliases: list[str] = aliases or []
self._arguments: list[Argument] = []
self._positional: list[Argument] = []
self._keyword: list[Argument] = []
self._positional: dict[str, Argument] = {}
self._keyword: dict[str, Argument] = {}
self._flag_map: dict[str, Argument] = {}
self._dest_set: set[str] = set()
self._add_help()
@ -482,12 +482,12 @@ class CommandArgumentParser:
)
for flag in flags:
self._flag_map[flag] = argument
if not positional:
self._keyword[flag] = argument
self._dest_set.add(dest)
self._arguments.append(argument)
if positional:
self._positional.append(argument)
else:
self._keyword.append(argument)
self._positional[dest] = argument
def get_argument(self, dest: str) -> Argument | None:
return next((a for a in self._arguments if a.dest == dest), None)
@ -663,8 +663,8 @@ class CommandArgumentParser:
i = 0
while i < len(args):
token = args[i]
if token in self._flag_map:
spec = self._flag_map[token]
if token in self._keyword:
spec = self._keyword[token]
action = spec.action
if action == ArgumentAction.HELP:
@ -836,7 +836,7 @@ class CommandArgumentParser:
# Options
# Add all keyword arguments to the options list
options_list = []
for arg in self._keyword:
for arg in self._keyword.values():
choice_text = arg.get_choice_text()
if choice_text:
options_list.extend([f"[{arg.flags[0]} {choice_text}]"])
@ -844,7 +844,7 @@ class CommandArgumentParser:
options_list.extend([f"[{arg.flags[0]}]"])
# Add positional arguments to the options list
for arg in self._positional:
for arg in self._positional.values():
choice_text = arg.get_choice_text()
if isinstance(arg.nargs, int):
choice_text = " ".join([choice_text] * arg.nargs)
@ -870,14 +870,14 @@ class CommandArgumentParser:
if self._arguments:
if self._positional:
self.console.print("[bold]positional:[/bold]")
for arg in self._positional:
for arg in self._positional.values():
flags = arg.get_positional_text()
arg_line = Text(f" {flags:<30} ")
help_text = arg.help or ""
arg_line.append(help_text)
self.console.print(arg_line)
self.console.print("[bold]options:[/bold]")
for arg in self._keyword:
for arg in self._keyword.values():
flags = ", ".join(arg.flags)
flags_choice = f"{flags} {arg.get_choice_text()}"
arg_line = Text(f" {flags_choice:<30} ")
@ -906,8 +906,8 @@ class CommandArgumentParser:
required = sum(arg.required for arg in self._arguments)
return (
f"CommandArgumentParser(args={len(self._arguments)}, "
f"flags={len(self._flag_map)}, dests={len(self._dest_set)}, "
f"required={required}, positional={positional})"
f"flags={len(self._flag_map)}, keywords={len(self._keyword)}, "
f"positional={positional}, required={required})"
)
def __repr__(self) -> str:

View File

@ -1 +1 @@
__version__ = "0.1.38"
__version__ = "0.1.39"

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "falyx"
version = "0.1.38"
version = "0.1.39"
description = "Reliable and introspectable async CLI action framework."
authors = ["Roland Thomas Jr <roland@rtj.dev>"]
license = "MIT"

View File

@ -0,0 +1,25 @@
import pytest
from falyx.action import Action, ActionFactoryAction, ChainedAction
def make_chain(value) -> ChainedAction:
return ChainedAction(
"test_chain",
[
Action("action1", lambda: value + "_1"),
Action("action2", lambda: value + "_2"),
],
return_list=True,
)
@pytest.mark.asyncio
async def test_action_factory_action():
action = ActionFactoryAction(
name="test_action", factory=make_chain, args=("test_value",)
)
result = await action()
assert result == ["test_value_1", "test_value_2"]

View File

@ -0,0 +1,49 @@
import pytest
from falyx.exceptions import CommandArgumentError
from falyx.parsers import CommandArgumentParser
def test_str():
"""Test the string representation of CommandArgumentParser."""
parser = CommandArgumentParser()
assert (
str(parser)
== "CommandArgumentParser(args=1, flags=2, keywords=2, positional=0, required=0)"
)
parser.add_argument("test", action="store", help="Test argument")
assert (
str(parser)
== "CommandArgumentParser(args=2, flags=3, keywords=2, positional=1, required=1)"
)
parser.add_argument("-o", "--optional", action="store", help="Optional argument")
assert (
str(parser)
== "CommandArgumentParser(args=3, flags=5, keywords=4, positional=1, required=1)"
)
parser.add_argument("--flag", action="store", help="Flag argument", required=True)
assert (
str(parser)
== "CommandArgumentParser(args=4, flags=6, keywords=5, positional=1, required=2)"
)
assert (
repr(parser)
== "CommandArgumentParser(args=4, flags=6, keywords=5, positional=1, required=2)"
)
@pytest.mark.asyncio
async def test_positional_text_with_choices():
parser = CommandArgumentParser()
parser.add_argument("path", choices=["a", "b"])
args = await parser.parse_args(["a"])
assert args["path"] == "a"
with pytest.raises(CommandArgumentError):
await parser.parse_args(["c"])
with pytest.raises(CommandArgumentError):
await parser.parse_args([])