Add ProcessPoolAction, update CAP to look only at keywords correctly
This commit is contained in:
parent
fb1ffbe9f6
commit
f196e38c57
|
@ -52,7 +52,8 @@ poetry install
|
|||
import asyncio
|
||||
import random
|
||||
|
||||
from falyx import Falyx, Action, ChainedAction
|
||||
from falyx import Falyx
|
||||
from falyx.action import Action, ChainedAction
|
||||
|
||||
# A flaky async step that fails randomly
|
||||
async def flaky_step():
|
||||
|
@ -62,8 +63,8 @@ async def flaky_step():
|
|||
return "ok"
|
||||
|
||||
# Create the actions
|
||||
step1 = Action(name="step_1", action=flaky_step, retry=True)
|
||||
step2 = Action(name="step_2", action=flaky_step, retry=True)
|
||||
step1 = Action(name="step_1", action=flaky_step)
|
||||
step2 = Action(name="step_2", action=flaky_step)
|
||||
|
||||
# Chain the actions
|
||||
chain = ChainedAction(name="my_pipeline", actions=[step1, step2])
|
||||
|
@ -74,9 +75,9 @@ falyx.add_command(
|
|||
key="R",
|
||||
description="Run My Pipeline",
|
||||
action=chain,
|
||||
logging_hooks=True,
|
||||
preview_before_confirm=True,
|
||||
confirm=True,
|
||||
retry_all=True,
|
||||
)
|
||||
|
||||
# Entry point
|
||||
|
|
|
@ -1,26 +1,36 @@
|
|||
from rich.console import Console
|
||||
|
||||
from falyx import Falyx
|
||||
from falyx.action import ProcessAction
|
||||
from falyx.action import ProcessPoolAction
|
||||
from falyx.action.process_pool_action import ProcessTask
|
||||
from falyx.execution_registry import ExecutionRegistry as er
|
||||
from falyx.themes import NordColors as nc
|
||||
|
||||
console = Console()
|
||||
falyx = Falyx(title="🚀 Process Pool Demo")
|
||||
|
||||
|
||||
def generate_primes(n):
|
||||
primes = []
|
||||
for num in range(2, n):
|
||||
def generate_primes(start: int = 2, end: int = 100_000) -> list[int]:
|
||||
primes: list[int] = []
|
||||
console.print(f"Generating primes from {start} to {end}...", style=nc.YELLOW)
|
||||
for num in range(start, end):
|
||||
if all(num % p != 0 for p in primes):
|
||||
primes.append(num)
|
||||
console.print(f"Generated {len(primes)} primes up to {n}.", style=nc.GREEN)
|
||||
console.print(
|
||||
f"Generated {len(primes)} primes from {start} to {end}.", style=nc.GREEN
|
||||
)
|
||||
return primes
|
||||
|
||||
|
||||
# Will not block the event loop
|
||||
heavy_action = ProcessAction("Prime Generator", generate_primes, args=(100_000,))
|
||||
actions = [ProcessTask(task=generate_primes)]
|
||||
|
||||
falyx.add_command("R", "Generate Primes", heavy_action, spinner=True)
|
||||
# Will not block the event loop
|
||||
heavy_action = ProcessPoolAction(
|
||||
name="Prime Generator",
|
||||
actions=actions,
|
||||
)
|
||||
|
||||
falyx.add_command("R", "Generate Primes", heavy_action)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -16,6 +16,7 @@ from .io_action import BaseIOAction, ShellAction
|
|||
from .literal_input_action import LiteralInputAction
|
||||
from .menu_action import MenuAction
|
||||
from .process_action import ProcessAction
|
||||
from .process_pool_action import ProcessPoolAction
|
||||
from .prompt_menu_action import PromptMenuAction
|
||||
from .select_file_action import SelectFileAction
|
||||
from .selection_action import SelectionAction
|
||||
|
@ -40,4 +41,5 @@ __all__ = [
|
|||
"LiteralInputAction",
|
||||
"UserInputAction",
|
||||
"PromptMenuAction",
|
||||
"ProcessPoolAction",
|
||||
]
|
||||
|
|
|
@ -165,5 +165,6 @@ class ActionGroup(BaseAction, ActionListMixin):
|
|||
def __str__(self):
|
||||
return (
|
||||
f"ActionGroup(name={self.name!r}, actions={[a.name for a in self.actions]!r},"
|
||||
f" inject_last_result={self.inject_last_result})"
|
||||
f" inject_last_result={self.inject_last_result}, "
|
||||
f"inject_into={self.inject_into!r})"
|
||||
)
|
||||
|
|
|
@ -0,0 +1,166 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from typing import Any, Callable
|
||||
|
||||
from rich.tree import Tree
|
||||
|
||||
from falyx.action.base import BaseAction
|
||||
from falyx.context import ExecutionContext, SharedContext
|
||||
from falyx.execution_registry import ExecutionRegistry as er
|
||||
from falyx.hook_manager import HookManager, HookType
|
||||
from falyx.logger import logger
|
||||
from falyx.parsers.utils import same_argument_definitions
|
||||
from falyx.themes import OneColors
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcessTask:
|
||||
task: Callable[..., Any]
|
||||
args: tuple = ()
|
||||
kwargs: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
if not callable(self.task):
|
||||
raise TypeError(f"Expected a callable task, got {type(self.task).__name__}")
|
||||
|
||||
|
||||
class ProcessPoolAction(BaseAction):
|
||||
""" """
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
actions: list[ProcessTask] | None = None,
|
||||
*,
|
||||
hooks: HookManager | None = None,
|
||||
executor: ProcessPoolExecutor | None = None,
|
||||
inject_last_result: bool = False,
|
||||
inject_into: str = "last_result",
|
||||
):
|
||||
super().__init__(
|
||||
name,
|
||||
hooks=hooks,
|
||||
inject_last_result=inject_last_result,
|
||||
inject_into=inject_into,
|
||||
)
|
||||
self.executor = executor or ProcessPoolExecutor()
|
||||
self.is_retryable = True
|
||||
self.actions: list[ProcessTask] = []
|
||||
if actions:
|
||||
self.set_actions(actions)
|
||||
|
||||
def set_actions(self, actions: list[ProcessTask]) -> None:
|
||||
"""Replaces the current action list with a new one."""
|
||||
self.actions.clear()
|
||||
for action in actions:
|
||||
self.add_action(action)
|
||||
|
||||
def add_action(self, action: ProcessTask) -> None:
|
||||
if not isinstance(action, ProcessTask):
|
||||
raise TypeError(f"Expected a ProcessTask, got {type(action).__name__}")
|
||||
self.actions.append(action)
|
||||
|
||||
def get_infer_target(self) -> tuple[Callable[..., Any] | None, None]:
|
||||
arg_defs = same_argument_definitions([action.task for action in self.actions])
|
||||
if arg_defs:
|
||||
return self.actions[0].task, None
|
||||
logger.debug(
|
||||
"[%s] auto_args disabled: mismatched ProcessPoolAction arguments",
|
||||
self.name,
|
||||
)
|
||||
return None, None
|
||||
|
||||
async def _run(self, *args, **kwargs) -> Any:
|
||||
shared_context = SharedContext(name=self.name, action=self, is_parallel=True)
|
||||
if self.shared_context:
|
||||
shared_context.set_shared_result(self.shared_context.last_result())
|
||||
if self.inject_last_result and self.shared_context:
|
||||
last_result = self.shared_context.last_result()
|
||||
if not self._validate_pickleable(last_result):
|
||||
raise ValueError(
|
||||
f"Cannot inject last result into {self.name}: "
|
||||
f"last result is not pickleable."
|
||||
)
|
||||
print(kwargs)
|
||||
updated_kwargs = self._maybe_inject_last_result(kwargs)
|
||||
print(updated_kwargs)
|
||||
context = ExecutionContext(
|
||||
name=self.name,
|
||||
args=args,
|
||||
kwargs=updated_kwargs,
|
||||
action=self,
|
||||
)
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
context.start_timer()
|
||||
try:
|
||||
await self.hooks.trigger(HookType.BEFORE, context)
|
||||
futures = [
|
||||
loop.run_in_executor(
|
||||
self.executor,
|
||||
partial(
|
||||
task.task,
|
||||
*(*args, *task.args),
|
||||
**{**updated_kwargs, **task.kwargs},
|
||||
),
|
||||
)
|
||||
for task in self.actions
|
||||
]
|
||||
results = await asyncio.gather(*futures, return_exceptions=True)
|
||||
context.result = results
|
||||
await self.hooks.trigger(HookType.ON_SUCCESS, context)
|
||||
return results
|
||||
except Exception as error:
|
||||
context.exception = error
|
||||
await self.hooks.trigger(HookType.ON_ERROR, context)
|
||||
if context.result is not None:
|
||||
return context.result
|
||||
raise
|
||||
finally:
|
||||
context.stop_timer()
|
||||
await self.hooks.trigger(HookType.AFTER, context)
|
||||
await self.hooks.trigger(HookType.ON_TEARDOWN, context)
|
||||
er.record(context)
|
||||
|
||||
def _validate_pickleable(self, obj: Any) -> bool:
|
||||
try:
|
||||
import pickle
|
||||
|
||||
pickle.dumps(obj)
|
||||
return True
|
||||
except (pickle.PicklingError, TypeError):
|
||||
return False
|
||||
|
||||
async def preview(self, parent: Tree | None = None):
|
||||
label = [f"[{OneColors.DARK_YELLOW_b}]🧠 ProcessPoolAction[/] '{self.name}'"]
|
||||
if self.inject_last_result:
|
||||
label.append(f" [dim](receives '{self.inject_into}')[/dim]")
|
||||
tree = parent.add("".join(label)) if parent else Tree("".join(label))
|
||||
actions = self.actions.copy()
|
||||
random.shuffle(actions)
|
||||
for action in actions:
|
||||
label = [
|
||||
f"[{OneColors.DARK_YELLOW_b}] - {getattr(action.task, '__name__', repr(action.task))}[/] "
|
||||
f"[dim]({', '.join(map(repr, action.args))})[/]"
|
||||
]
|
||||
if action.kwargs:
|
||||
label.append(
|
||||
f" [dim]({', '.join(f'{k}={v!r}' for k, v in action.kwargs.items())})[/]"
|
||||
)
|
||||
tree.add("".join(label))
|
||||
|
||||
if not parent:
|
||||
self.console.print(tree)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f"ProcessPoolAction(name={self.name!r}, "
|
||||
f"actions={[getattr(action.task, '__name__', repr(action.task)) for action in self.actions]}, "
|
||||
f"inject_last_result={self.inject_last_result}, "
|
||||
f"inject_into={self.inject_into!r})"
|
||||
)
|
|
@ -166,8 +166,8 @@ class CommandArgumentParser:
|
|||
self.help_epilogue: str = help_epilogue
|
||||
self.aliases: list[str] = aliases or []
|
||||
self._arguments: list[Argument] = []
|
||||
self._positional: list[Argument] = []
|
||||
self._keyword: list[Argument] = []
|
||||
self._positional: dict[str, Argument] = {}
|
||||
self._keyword: dict[str, Argument] = {}
|
||||
self._flag_map: dict[str, Argument] = {}
|
||||
self._dest_set: set[str] = set()
|
||||
self._add_help()
|
||||
|
@ -482,12 +482,12 @@ class CommandArgumentParser:
|
|||
)
|
||||
for flag in flags:
|
||||
self._flag_map[flag] = argument
|
||||
if not positional:
|
||||
self._keyword[flag] = argument
|
||||
self._dest_set.add(dest)
|
||||
self._arguments.append(argument)
|
||||
if positional:
|
||||
self._positional.append(argument)
|
||||
else:
|
||||
self._keyword.append(argument)
|
||||
self._positional[dest] = argument
|
||||
|
||||
def get_argument(self, dest: str) -> Argument | None:
|
||||
return next((a for a in self._arguments if a.dest == dest), None)
|
||||
|
@ -663,8 +663,8 @@ class CommandArgumentParser:
|
|||
i = 0
|
||||
while i < len(args):
|
||||
token = args[i]
|
||||
if token in self._flag_map:
|
||||
spec = self._flag_map[token]
|
||||
if token in self._keyword:
|
||||
spec = self._keyword[token]
|
||||
action = spec.action
|
||||
|
||||
if action == ArgumentAction.HELP:
|
||||
|
@ -836,7 +836,7 @@ class CommandArgumentParser:
|
|||
# Options
|
||||
# Add all keyword arguments to the options list
|
||||
options_list = []
|
||||
for arg in self._keyword:
|
||||
for arg in self._keyword.values():
|
||||
choice_text = arg.get_choice_text()
|
||||
if choice_text:
|
||||
options_list.extend([f"[{arg.flags[0]} {choice_text}]"])
|
||||
|
@ -844,7 +844,7 @@ class CommandArgumentParser:
|
|||
options_list.extend([f"[{arg.flags[0]}]"])
|
||||
|
||||
# Add positional arguments to the options list
|
||||
for arg in self._positional:
|
||||
for arg in self._positional.values():
|
||||
choice_text = arg.get_choice_text()
|
||||
if isinstance(arg.nargs, int):
|
||||
choice_text = " ".join([choice_text] * arg.nargs)
|
||||
|
@ -870,14 +870,14 @@ class CommandArgumentParser:
|
|||
if self._arguments:
|
||||
if self._positional:
|
||||
self.console.print("[bold]positional:[/bold]")
|
||||
for arg in self._positional:
|
||||
for arg in self._positional.values():
|
||||
flags = arg.get_positional_text()
|
||||
arg_line = Text(f" {flags:<30} ")
|
||||
help_text = arg.help or ""
|
||||
arg_line.append(help_text)
|
||||
self.console.print(arg_line)
|
||||
self.console.print("[bold]options:[/bold]")
|
||||
for arg in self._keyword:
|
||||
for arg in self._keyword.values():
|
||||
flags = ", ".join(arg.flags)
|
||||
flags_choice = f"{flags} {arg.get_choice_text()}"
|
||||
arg_line = Text(f" {flags_choice:<30} ")
|
||||
|
@ -906,8 +906,8 @@ class CommandArgumentParser:
|
|||
required = sum(arg.required for arg in self._arguments)
|
||||
return (
|
||||
f"CommandArgumentParser(args={len(self._arguments)}, "
|
||||
f"flags={len(self._flag_map)}, dests={len(self._dest_set)}, "
|
||||
f"required={required}, positional={positional})"
|
||||
f"flags={len(self._flag_map)}, keywords={len(self._keyword)}, "
|
||||
f"positional={positional}, required={required})"
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
|
|
|
@ -1 +1 @@
|
|||
__version__ = "0.1.38"
|
||||
__version__ = "0.1.39"
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[tool.poetry]
|
||||
name = "falyx"
|
||||
version = "0.1.38"
|
||||
version = "0.1.39"
|
||||
description = "Reliable and introspectable async CLI action framework."
|
||||
authors = ["Roland Thomas Jr <roland@rtj.dev>"]
|
||||
license = "MIT"
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
import pytest
|
||||
|
||||
from falyx.action import Action, ActionFactoryAction, ChainedAction
|
||||
|
||||
|
||||
def make_chain(value) -> ChainedAction:
|
||||
return ChainedAction(
|
||||
"test_chain",
|
||||
[
|
||||
Action("action1", lambda: value + "_1"),
|
||||
Action("action2", lambda: value + "_2"),
|
||||
],
|
||||
return_list=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_factory_action():
|
||||
action = ActionFactoryAction(
|
||||
name="test_action", factory=make_chain, args=("test_value",)
|
||||
)
|
||||
|
||||
result = await action()
|
||||
|
||||
assert result == ["test_value_1", "test_value_2"]
|
|
@ -0,0 +1,49 @@
|
|||
import pytest
|
||||
|
||||
from falyx.exceptions import CommandArgumentError
|
||||
from falyx.parsers import CommandArgumentParser
|
||||
|
||||
|
||||
def test_str():
|
||||
"""Test the string representation of CommandArgumentParser."""
|
||||
parser = CommandArgumentParser()
|
||||
assert (
|
||||
str(parser)
|
||||
== "CommandArgumentParser(args=1, flags=2, keywords=2, positional=0, required=0)"
|
||||
)
|
||||
|
||||
parser.add_argument("test", action="store", help="Test argument")
|
||||
assert (
|
||||
str(parser)
|
||||
== "CommandArgumentParser(args=2, flags=3, keywords=2, positional=1, required=1)"
|
||||
)
|
||||
|
||||
parser.add_argument("-o", "--optional", action="store", help="Optional argument")
|
||||
assert (
|
||||
str(parser)
|
||||
== "CommandArgumentParser(args=3, flags=5, keywords=4, positional=1, required=1)"
|
||||
)
|
||||
|
||||
parser.add_argument("--flag", action="store", help="Flag argument", required=True)
|
||||
assert (
|
||||
str(parser)
|
||||
== "CommandArgumentParser(args=4, flags=6, keywords=5, positional=1, required=2)"
|
||||
)
|
||||
assert (
|
||||
repr(parser)
|
||||
== "CommandArgumentParser(args=4, flags=6, keywords=5, positional=1, required=2)"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_positional_text_with_choices():
|
||||
parser = CommandArgumentParser()
|
||||
parser.add_argument("path", choices=["a", "b"])
|
||||
args = await parser.parse_args(["a"])
|
||||
assert args["path"] == "a"
|
||||
|
||||
with pytest.raises(CommandArgumentError):
|
||||
await parser.parse_args(["c"])
|
||||
|
||||
with pytest.raises(CommandArgumentError):
|
||||
await parser.parse_args([])
|
Loading…
Reference in New Issue