Add ProcessPoolAction, update CAP to look only at keywords correctly
This commit is contained in:
parent
fb1ffbe9f6
commit
f196e38c57
|
@ -52,7 +52,8 @@ poetry install
|
||||||
import asyncio
|
import asyncio
|
||||||
import random
|
import random
|
||||||
|
|
||||||
from falyx import Falyx, Action, ChainedAction
|
from falyx import Falyx
|
||||||
|
from falyx.action import Action, ChainedAction
|
||||||
|
|
||||||
# A flaky async step that fails randomly
|
# A flaky async step that fails randomly
|
||||||
async def flaky_step():
|
async def flaky_step():
|
||||||
|
@ -62,8 +63,8 @@ async def flaky_step():
|
||||||
return "ok"
|
return "ok"
|
||||||
|
|
||||||
# Create the actions
|
# Create the actions
|
||||||
step1 = Action(name="step_1", action=flaky_step, retry=True)
|
step1 = Action(name="step_1", action=flaky_step)
|
||||||
step2 = Action(name="step_2", action=flaky_step, retry=True)
|
step2 = Action(name="step_2", action=flaky_step)
|
||||||
|
|
||||||
# Chain the actions
|
# Chain the actions
|
||||||
chain = ChainedAction(name="my_pipeline", actions=[step1, step2])
|
chain = ChainedAction(name="my_pipeline", actions=[step1, step2])
|
||||||
|
@ -74,9 +75,9 @@ falyx.add_command(
|
||||||
key="R",
|
key="R",
|
||||||
description="Run My Pipeline",
|
description="Run My Pipeline",
|
||||||
action=chain,
|
action=chain,
|
||||||
logging_hooks=True,
|
|
||||||
preview_before_confirm=True,
|
preview_before_confirm=True,
|
||||||
confirm=True,
|
confirm=True,
|
||||||
|
retry_all=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Entry point
|
# Entry point
|
||||||
|
|
|
@ -1,26 +1,36 @@
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
|
|
||||||
from falyx import Falyx
|
from falyx import Falyx
|
||||||
from falyx.action import ProcessAction
|
from falyx.action import ProcessPoolAction
|
||||||
|
from falyx.action.process_pool_action import ProcessTask
|
||||||
|
from falyx.execution_registry import ExecutionRegistry as er
|
||||||
from falyx.themes import NordColors as nc
|
from falyx.themes import NordColors as nc
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
falyx = Falyx(title="🚀 Process Pool Demo")
|
falyx = Falyx(title="🚀 Process Pool Demo")
|
||||||
|
|
||||||
|
|
||||||
def generate_primes(n):
|
def generate_primes(start: int = 2, end: int = 100_000) -> list[int]:
|
||||||
primes = []
|
primes: list[int] = []
|
||||||
for num in range(2, n):
|
console.print(f"Generating primes from {start} to {end}...", style=nc.YELLOW)
|
||||||
|
for num in range(start, end):
|
||||||
if all(num % p != 0 for p in primes):
|
if all(num % p != 0 for p in primes):
|
||||||
primes.append(num)
|
primes.append(num)
|
||||||
console.print(f"Generated {len(primes)} primes up to {n}.", style=nc.GREEN)
|
console.print(
|
||||||
|
f"Generated {len(primes)} primes from {start} to {end}.", style=nc.GREEN
|
||||||
|
)
|
||||||
return primes
|
return primes
|
||||||
|
|
||||||
|
|
||||||
# Will not block the event loop
|
actions = [ProcessTask(task=generate_primes)]
|
||||||
heavy_action = ProcessAction("Prime Generator", generate_primes, args=(100_000,))
|
|
||||||
|
|
||||||
falyx.add_command("R", "Generate Primes", heavy_action, spinner=True)
|
# Will not block the event loop
|
||||||
|
heavy_action = ProcessPoolAction(
|
||||||
|
name="Prime Generator",
|
||||||
|
actions=actions,
|
||||||
|
)
|
||||||
|
|
||||||
|
falyx.add_command("R", "Generate Primes", heavy_action)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -16,6 +16,7 @@ from .io_action import BaseIOAction, ShellAction
|
||||||
from .literal_input_action import LiteralInputAction
|
from .literal_input_action import LiteralInputAction
|
||||||
from .menu_action import MenuAction
|
from .menu_action import MenuAction
|
||||||
from .process_action import ProcessAction
|
from .process_action import ProcessAction
|
||||||
|
from .process_pool_action import ProcessPoolAction
|
||||||
from .prompt_menu_action import PromptMenuAction
|
from .prompt_menu_action import PromptMenuAction
|
||||||
from .select_file_action import SelectFileAction
|
from .select_file_action import SelectFileAction
|
||||||
from .selection_action import SelectionAction
|
from .selection_action import SelectionAction
|
||||||
|
@ -40,4 +41,5 @@ __all__ = [
|
||||||
"LiteralInputAction",
|
"LiteralInputAction",
|
||||||
"UserInputAction",
|
"UserInputAction",
|
||||||
"PromptMenuAction",
|
"PromptMenuAction",
|
||||||
|
"ProcessPoolAction",
|
||||||
]
|
]
|
||||||
|
|
|
@ -165,5 +165,6 @@ class ActionGroup(BaseAction, ActionListMixin):
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return (
|
return (
|
||||||
f"ActionGroup(name={self.name!r}, actions={[a.name for a in self.actions]!r},"
|
f"ActionGroup(name={self.name!r}, actions={[a.name for a in self.actions]!r},"
|
||||||
f" inject_last_result={self.inject_last_result})"
|
f" inject_last_result={self.inject_last_result}, "
|
||||||
|
f"inject_into={self.inject_into!r})"
|
||||||
)
|
)
|
||||||
|
|
|
@ -0,0 +1,166 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import random
|
||||||
|
from concurrent.futures import ProcessPoolExecutor
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from functools import partial
|
||||||
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
from rich.tree import Tree
|
||||||
|
|
||||||
|
from falyx.action.base import BaseAction
|
||||||
|
from falyx.context import ExecutionContext, SharedContext
|
||||||
|
from falyx.execution_registry import ExecutionRegistry as er
|
||||||
|
from falyx.hook_manager import HookManager, HookType
|
||||||
|
from falyx.logger import logger
|
||||||
|
from falyx.parsers.utils import same_argument_definitions
|
||||||
|
from falyx.themes import OneColors
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ProcessTask:
|
||||||
|
task: Callable[..., Any]
|
||||||
|
args: tuple = ()
|
||||||
|
kwargs: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if not callable(self.task):
|
||||||
|
raise TypeError(f"Expected a callable task, got {type(self.task).__name__}")
|
||||||
|
|
||||||
|
|
||||||
|
class ProcessPoolAction(BaseAction):
|
||||||
|
""" """
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
actions: list[ProcessTask] | None = None,
|
||||||
|
*,
|
||||||
|
hooks: HookManager | None = None,
|
||||||
|
executor: ProcessPoolExecutor | None = None,
|
||||||
|
inject_last_result: bool = False,
|
||||||
|
inject_into: str = "last_result",
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
name,
|
||||||
|
hooks=hooks,
|
||||||
|
inject_last_result=inject_last_result,
|
||||||
|
inject_into=inject_into,
|
||||||
|
)
|
||||||
|
self.executor = executor or ProcessPoolExecutor()
|
||||||
|
self.is_retryable = True
|
||||||
|
self.actions: list[ProcessTask] = []
|
||||||
|
if actions:
|
||||||
|
self.set_actions(actions)
|
||||||
|
|
||||||
|
def set_actions(self, actions: list[ProcessTask]) -> None:
|
||||||
|
"""Replaces the current action list with a new one."""
|
||||||
|
self.actions.clear()
|
||||||
|
for action in actions:
|
||||||
|
self.add_action(action)
|
||||||
|
|
||||||
|
def add_action(self, action: ProcessTask) -> None:
|
||||||
|
if not isinstance(action, ProcessTask):
|
||||||
|
raise TypeError(f"Expected a ProcessTask, got {type(action).__name__}")
|
||||||
|
self.actions.append(action)
|
||||||
|
|
||||||
|
def get_infer_target(self) -> tuple[Callable[..., Any] | None, None]:
|
||||||
|
arg_defs = same_argument_definitions([action.task for action in self.actions])
|
||||||
|
if arg_defs:
|
||||||
|
return self.actions[0].task, None
|
||||||
|
logger.debug(
|
||||||
|
"[%s] auto_args disabled: mismatched ProcessPoolAction arguments",
|
||||||
|
self.name,
|
||||||
|
)
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
async def _run(self, *args, **kwargs) -> Any:
|
||||||
|
shared_context = SharedContext(name=self.name, action=self, is_parallel=True)
|
||||||
|
if self.shared_context:
|
||||||
|
shared_context.set_shared_result(self.shared_context.last_result())
|
||||||
|
if self.inject_last_result and self.shared_context:
|
||||||
|
last_result = self.shared_context.last_result()
|
||||||
|
if not self._validate_pickleable(last_result):
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot inject last result into {self.name}: "
|
||||||
|
f"last result is not pickleable."
|
||||||
|
)
|
||||||
|
print(kwargs)
|
||||||
|
updated_kwargs = self._maybe_inject_last_result(kwargs)
|
||||||
|
print(updated_kwargs)
|
||||||
|
context = ExecutionContext(
|
||||||
|
name=self.name,
|
||||||
|
args=args,
|
||||||
|
kwargs=updated_kwargs,
|
||||||
|
action=self,
|
||||||
|
)
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
|
context.start_timer()
|
||||||
|
try:
|
||||||
|
await self.hooks.trigger(HookType.BEFORE, context)
|
||||||
|
futures = [
|
||||||
|
loop.run_in_executor(
|
||||||
|
self.executor,
|
||||||
|
partial(
|
||||||
|
task.task,
|
||||||
|
*(*args, *task.args),
|
||||||
|
**{**updated_kwargs, **task.kwargs},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for task in self.actions
|
||||||
|
]
|
||||||
|
results = await asyncio.gather(*futures, return_exceptions=True)
|
||||||
|
context.result = results
|
||||||
|
await self.hooks.trigger(HookType.ON_SUCCESS, context)
|
||||||
|
return results
|
||||||
|
except Exception as error:
|
||||||
|
context.exception = error
|
||||||
|
await self.hooks.trigger(HookType.ON_ERROR, context)
|
||||||
|
if context.result is not None:
|
||||||
|
return context.result
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
context.stop_timer()
|
||||||
|
await self.hooks.trigger(HookType.AFTER, context)
|
||||||
|
await self.hooks.trigger(HookType.ON_TEARDOWN, context)
|
||||||
|
er.record(context)
|
||||||
|
|
||||||
|
def _validate_pickleable(self, obj: Any) -> bool:
|
||||||
|
try:
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
pickle.dumps(obj)
|
||||||
|
return True
|
||||||
|
except (pickle.PicklingError, TypeError):
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def preview(self, parent: Tree | None = None):
|
||||||
|
label = [f"[{OneColors.DARK_YELLOW_b}]🧠 ProcessPoolAction[/] '{self.name}'"]
|
||||||
|
if self.inject_last_result:
|
||||||
|
label.append(f" [dim](receives '{self.inject_into}')[/dim]")
|
||||||
|
tree = parent.add("".join(label)) if parent else Tree("".join(label))
|
||||||
|
actions = self.actions.copy()
|
||||||
|
random.shuffle(actions)
|
||||||
|
for action in actions:
|
||||||
|
label = [
|
||||||
|
f"[{OneColors.DARK_YELLOW_b}] - {getattr(action.task, '__name__', repr(action.task))}[/] "
|
||||||
|
f"[dim]({', '.join(map(repr, action.args))})[/]"
|
||||||
|
]
|
||||||
|
if action.kwargs:
|
||||||
|
label.append(
|
||||||
|
f" [dim]({', '.join(f'{k}={v!r}' for k, v in action.kwargs.items())})[/]"
|
||||||
|
)
|
||||||
|
tree.add("".join(label))
|
||||||
|
|
||||||
|
if not parent:
|
||||||
|
self.console.print(tree)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return (
|
||||||
|
f"ProcessPoolAction(name={self.name!r}, "
|
||||||
|
f"actions={[getattr(action.task, '__name__', repr(action.task)) for action in self.actions]}, "
|
||||||
|
f"inject_last_result={self.inject_last_result}, "
|
||||||
|
f"inject_into={self.inject_into!r})"
|
||||||
|
)
|
|
@ -166,8 +166,8 @@ class CommandArgumentParser:
|
||||||
self.help_epilogue: str = help_epilogue
|
self.help_epilogue: str = help_epilogue
|
||||||
self.aliases: list[str] = aliases or []
|
self.aliases: list[str] = aliases or []
|
||||||
self._arguments: list[Argument] = []
|
self._arguments: list[Argument] = []
|
||||||
self._positional: list[Argument] = []
|
self._positional: dict[str, Argument] = {}
|
||||||
self._keyword: list[Argument] = []
|
self._keyword: dict[str, Argument] = {}
|
||||||
self._flag_map: dict[str, Argument] = {}
|
self._flag_map: dict[str, Argument] = {}
|
||||||
self._dest_set: set[str] = set()
|
self._dest_set: set[str] = set()
|
||||||
self._add_help()
|
self._add_help()
|
||||||
|
@ -482,12 +482,12 @@ class CommandArgumentParser:
|
||||||
)
|
)
|
||||||
for flag in flags:
|
for flag in flags:
|
||||||
self._flag_map[flag] = argument
|
self._flag_map[flag] = argument
|
||||||
|
if not positional:
|
||||||
|
self._keyword[flag] = argument
|
||||||
self._dest_set.add(dest)
|
self._dest_set.add(dest)
|
||||||
self._arguments.append(argument)
|
self._arguments.append(argument)
|
||||||
if positional:
|
if positional:
|
||||||
self._positional.append(argument)
|
self._positional[dest] = argument
|
||||||
else:
|
|
||||||
self._keyword.append(argument)
|
|
||||||
|
|
||||||
def get_argument(self, dest: str) -> Argument | None:
|
def get_argument(self, dest: str) -> Argument | None:
|
||||||
return next((a for a in self._arguments if a.dest == dest), None)
|
return next((a for a in self._arguments if a.dest == dest), None)
|
||||||
|
@ -663,8 +663,8 @@ class CommandArgumentParser:
|
||||||
i = 0
|
i = 0
|
||||||
while i < len(args):
|
while i < len(args):
|
||||||
token = args[i]
|
token = args[i]
|
||||||
if token in self._flag_map:
|
if token in self._keyword:
|
||||||
spec = self._flag_map[token]
|
spec = self._keyword[token]
|
||||||
action = spec.action
|
action = spec.action
|
||||||
|
|
||||||
if action == ArgumentAction.HELP:
|
if action == ArgumentAction.HELP:
|
||||||
|
@ -836,7 +836,7 @@ class CommandArgumentParser:
|
||||||
# Options
|
# Options
|
||||||
# Add all keyword arguments to the options list
|
# Add all keyword arguments to the options list
|
||||||
options_list = []
|
options_list = []
|
||||||
for arg in self._keyword:
|
for arg in self._keyword.values():
|
||||||
choice_text = arg.get_choice_text()
|
choice_text = arg.get_choice_text()
|
||||||
if choice_text:
|
if choice_text:
|
||||||
options_list.extend([f"[{arg.flags[0]} {choice_text}]"])
|
options_list.extend([f"[{arg.flags[0]} {choice_text}]"])
|
||||||
|
@ -844,7 +844,7 @@ class CommandArgumentParser:
|
||||||
options_list.extend([f"[{arg.flags[0]}]"])
|
options_list.extend([f"[{arg.flags[0]}]"])
|
||||||
|
|
||||||
# Add positional arguments to the options list
|
# Add positional arguments to the options list
|
||||||
for arg in self._positional:
|
for arg in self._positional.values():
|
||||||
choice_text = arg.get_choice_text()
|
choice_text = arg.get_choice_text()
|
||||||
if isinstance(arg.nargs, int):
|
if isinstance(arg.nargs, int):
|
||||||
choice_text = " ".join([choice_text] * arg.nargs)
|
choice_text = " ".join([choice_text] * arg.nargs)
|
||||||
|
@ -870,14 +870,14 @@ class CommandArgumentParser:
|
||||||
if self._arguments:
|
if self._arguments:
|
||||||
if self._positional:
|
if self._positional:
|
||||||
self.console.print("[bold]positional:[/bold]")
|
self.console.print("[bold]positional:[/bold]")
|
||||||
for arg in self._positional:
|
for arg in self._positional.values():
|
||||||
flags = arg.get_positional_text()
|
flags = arg.get_positional_text()
|
||||||
arg_line = Text(f" {flags:<30} ")
|
arg_line = Text(f" {flags:<30} ")
|
||||||
help_text = arg.help or ""
|
help_text = arg.help or ""
|
||||||
arg_line.append(help_text)
|
arg_line.append(help_text)
|
||||||
self.console.print(arg_line)
|
self.console.print(arg_line)
|
||||||
self.console.print("[bold]options:[/bold]")
|
self.console.print("[bold]options:[/bold]")
|
||||||
for arg in self._keyword:
|
for arg in self._keyword.values():
|
||||||
flags = ", ".join(arg.flags)
|
flags = ", ".join(arg.flags)
|
||||||
flags_choice = f"{flags} {arg.get_choice_text()}"
|
flags_choice = f"{flags} {arg.get_choice_text()}"
|
||||||
arg_line = Text(f" {flags_choice:<30} ")
|
arg_line = Text(f" {flags_choice:<30} ")
|
||||||
|
@ -906,8 +906,8 @@ class CommandArgumentParser:
|
||||||
required = sum(arg.required for arg in self._arguments)
|
required = sum(arg.required for arg in self._arguments)
|
||||||
return (
|
return (
|
||||||
f"CommandArgumentParser(args={len(self._arguments)}, "
|
f"CommandArgumentParser(args={len(self._arguments)}, "
|
||||||
f"flags={len(self._flag_map)}, dests={len(self._dest_set)}, "
|
f"flags={len(self._flag_map)}, keywords={len(self._keyword)}, "
|
||||||
f"required={required}, positional={positional})"
|
f"positional={positional}, required={required})"
|
||||||
)
|
)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
__version__ = "0.1.38"
|
__version__ = "0.1.39"
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "falyx"
|
name = "falyx"
|
||||||
version = "0.1.38"
|
version = "0.1.39"
|
||||||
description = "Reliable and introspectable async CLI action framework."
|
description = "Reliable and introspectable async CLI action framework."
|
||||||
authors = ["Roland Thomas Jr <roland@rtj.dev>"]
|
authors = ["Roland Thomas Jr <roland@rtj.dev>"]
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
|
|
|
@ -0,0 +1,25 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from falyx.action import Action, ActionFactoryAction, ChainedAction
|
||||||
|
|
||||||
|
|
||||||
|
def make_chain(value) -> ChainedAction:
|
||||||
|
return ChainedAction(
|
||||||
|
"test_chain",
|
||||||
|
[
|
||||||
|
Action("action1", lambda: value + "_1"),
|
||||||
|
Action("action2", lambda: value + "_2"),
|
||||||
|
],
|
||||||
|
return_list=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_action_factory_action():
|
||||||
|
action = ActionFactoryAction(
|
||||||
|
name="test_action", factory=make_chain, args=("test_value",)
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await action()
|
||||||
|
|
||||||
|
assert result == ["test_value_1", "test_value_2"]
|
|
@ -0,0 +1,49 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from falyx.exceptions import CommandArgumentError
|
||||||
|
from falyx.parsers import CommandArgumentParser
|
||||||
|
|
||||||
|
|
||||||
|
def test_str():
|
||||||
|
"""Test the string representation of CommandArgumentParser."""
|
||||||
|
parser = CommandArgumentParser()
|
||||||
|
assert (
|
||||||
|
str(parser)
|
||||||
|
== "CommandArgumentParser(args=1, flags=2, keywords=2, positional=0, required=0)"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument("test", action="store", help="Test argument")
|
||||||
|
assert (
|
||||||
|
str(parser)
|
||||||
|
== "CommandArgumentParser(args=2, flags=3, keywords=2, positional=1, required=1)"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument("-o", "--optional", action="store", help="Optional argument")
|
||||||
|
assert (
|
||||||
|
str(parser)
|
||||||
|
== "CommandArgumentParser(args=3, flags=5, keywords=4, positional=1, required=1)"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument("--flag", action="store", help="Flag argument", required=True)
|
||||||
|
assert (
|
||||||
|
str(parser)
|
||||||
|
== "CommandArgumentParser(args=4, flags=6, keywords=5, positional=1, required=2)"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
repr(parser)
|
||||||
|
== "CommandArgumentParser(args=4, flags=6, keywords=5, positional=1, required=2)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_positional_text_with_choices():
|
||||||
|
parser = CommandArgumentParser()
|
||||||
|
parser.add_argument("path", choices=["a", "b"])
|
||||||
|
args = await parser.parse_args(["a"])
|
||||||
|
assert args["path"] == "a"
|
||||||
|
|
||||||
|
with pytest.raises(CommandArgumentError):
|
||||||
|
await parser.parse_args(["c"])
|
||||||
|
|
||||||
|
with pytest.raises(CommandArgumentError):
|
||||||
|
await parser.parse_args([])
|
Loading…
Reference in New Issue