Add args, kwargs to ChainedAction, ActionGroup, Add type_word_cancel and acknowledge ConfirmTypes, update ChainedAction rollback logic

This commit is contained in:
2025-07-16 18:54:03 -04:00
parent 2288015cf3
commit dc1764e752
9 changed files with 65 additions and 20 deletions

View File

@ -13,6 +13,9 @@ class Place(Enum):
SAN_FRANCISCO = "San Francisco" SAN_FRANCISCO = "San Francisco"
LONDON = "London" LONDON = "London"
def __str__(self):
return self.value
async def test_args( async def test_args(
service: str, service: str,

View File

@ -112,7 +112,16 @@ class ActionFactory(BaseAction):
tree = parent.add(label) if parent else Tree(label) tree = parent.add(label) if parent else Tree(label)
try: try:
generated = await self.factory(*self.preview_args, **self.preview_kwargs) generated = None
if self.args or self.kwargs:
try:
generated = await self.factory(*self.args, **self.kwargs)
except TypeError:
...
if not generated:
generated = await self.factory(*self.preview_args, **self.preview_kwargs)
if isinstance(generated, BaseAction): if isinstance(generated, BaseAction):
await generated.preview(parent=tree) await generated.preview(parent=tree)
else: else:

View File

@ -60,6 +60,8 @@ class ActionGroup(BaseAction, ActionListMixin):
Sequence[BaseAction | Callable[..., Any] | Callable[..., Awaitable]] | None Sequence[BaseAction | Callable[..., Any] | Callable[..., Awaitable]] | None
) = None, ) = None,
*, *,
args: tuple[Any, ...] = (),
kwargs: dict[str, Any] | None = None,
hooks: HookManager | None = None, hooks: HookManager | None = None,
inject_last_result: bool = False, inject_last_result: bool = False,
inject_into: str = "last_result", inject_into: str = "last_result",
@ -71,6 +73,8 @@ class ActionGroup(BaseAction, ActionListMixin):
inject_into=inject_into, inject_into=inject_into,
) )
ActionListMixin.__init__(self) ActionListMixin.__init__(self)
self.args = args
self.kwargs = kwargs or {}
if actions: if actions:
self.set_actions(actions) self.set_actions(actions)
@ -115,13 +119,17 @@ class ActionGroup(BaseAction, ActionListMixin):
async def _run(self, *args, **kwargs) -> list[tuple[str, Any]]: async def _run(self, *args, **kwargs) -> list[tuple[str, Any]]:
if not self.actions: if not self.actions:
raise EmptyGroupError(f"[{self.name}] No actions to execute.") raise EmptyGroupError(f"[{self.name}] No actions to execute.")
combined_args = args + self.args
combined_kwargs = {**self.kwargs, **kwargs}
shared_context = SharedContext(name=self.name, action=self, is_parallel=True) shared_context = SharedContext(name=self.name, action=self, is_parallel=True)
if self.shared_context: if self.shared_context:
shared_context.set_shared_result(self.shared_context.last_result()) shared_context.set_shared_result(self.shared_context.last_result())
updated_kwargs = self._maybe_inject_last_result(kwargs) updated_kwargs = self._maybe_inject_last_result(combined_kwargs)
context = ExecutionContext( context = ExecutionContext(
name=self.name, name=self.name,
args=args, args=combined_args,
kwargs=updated_kwargs, kwargs=updated_kwargs,
action=self, action=self,
extra={"results": [], "errors": []}, extra={"results": [], "errors": []},
@ -131,7 +139,7 @@ class ActionGroup(BaseAction, ActionListMixin):
async def run_one(action: BaseAction): async def run_one(action: BaseAction):
try: try:
prepared = action.prepare(shared_context, self.options_manager) prepared = action.prepare(shared_context, self.options_manager)
result = await prepared(*args, **updated_kwargs) result = await prepared(*combined_args, **updated_kwargs)
shared_context.add_result((action.name, result)) shared_context.add_result((action.name, result))
context.extra["results"].append((action.name, result)) context.extra["results"].append((action.name, result))
except Exception as error: except Exception as error:

View File

@ -61,7 +61,9 @@ class ConfirmType(Enum):
YES_CANCEL = "yes_cancel" YES_CANCEL = "yes_cancel"
YES_NO_CANCEL = "yes_no_cancel" YES_NO_CANCEL = "yes_no_cancel"
TYPE_WORD = "type_word" TYPE_WORD = "type_word"
TYPE_WORD_CANCEL = "type_word_cancel"
OK_CANCEL = "ok_cancel" OK_CANCEL = "ok_cancel"
ACKNOWLEDGE = "acknowledge"
@classmethod @classmethod
def choices(cls) -> list[ConfirmType]: def choices(cls) -> list[ConfirmType]:

View File

@ -54,6 +54,8 @@ class ChainedAction(BaseAction, ActionListMixin):
| None | None
) = None, ) = None,
*, *,
args: tuple[Any, ...] = (),
kwargs: dict[str, Any] | None = None,
hooks: HookManager | None = None, hooks: HookManager | None = None,
inject_last_result: bool = False, inject_last_result: bool = False,
inject_into: str = "last_result", inject_into: str = "last_result",
@ -67,6 +69,8 @@ class ChainedAction(BaseAction, ActionListMixin):
inject_into=inject_into, inject_into=inject_into,
) )
ActionListMixin.__init__(self) ActionListMixin.__init__(self)
self.args = args
self.kwargs = kwargs or {}
self.auto_inject = auto_inject self.auto_inject = auto_inject
self.return_list = return_list self.return_list = return_list
if actions: if actions:
@ -111,13 +115,16 @@ class ChainedAction(BaseAction, ActionListMixin):
if not self.actions: if not self.actions:
raise EmptyChainError(f"[{self.name}] No actions to execute.") raise EmptyChainError(f"[{self.name}] No actions to execute.")
combined_args = args + self.args
combined_kwargs = {**self.kwargs, **kwargs}
shared_context = SharedContext(name=self.name, action=self) shared_context = SharedContext(name=self.name, action=self)
if self.shared_context: if self.shared_context:
shared_context.add_result(self.shared_context.last_result()) shared_context.add_result(self.shared_context.last_result())
updated_kwargs = self._maybe_inject_last_result(kwargs) updated_kwargs = self._maybe_inject_last_result(combined_kwargs)
context = ExecutionContext( context = ExecutionContext(
name=self.name, name=self.name,
args=args, args=combined_args,
kwargs=updated_kwargs, kwargs=updated_kwargs,
action=self, action=self,
extra={"results": [], "rollback_stack": []}, extra={"results": [], "rollback_stack": []},
@ -136,7 +143,7 @@ class ChainedAction(BaseAction, ActionListMixin):
shared_context.current_index = index shared_context.current_index = index
prepared = action.prepare(shared_context, self.options_manager) prepared = action.prepare(shared_context, self.options_manager)
try: try:
result = await prepared(*args, **updated_kwargs) result = await prepared(*combined_args, **updated_kwargs)
except Exception as error: except Exception as error:
if index + 1 < len(self.actions) and isinstance( if index + 1 < len(self.actions) and isinstance(
self.actions[index + 1], FallbackAction self.actions[index + 1], FallbackAction
@ -155,10 +162,12 @@ class ChainedAction(BaseAction, ActionListMixin):
fallback._skip_in_chain = True fallback._skip_in_chain = True
else: else:
raise raise
args, updated_kwargs = self._clear_args()
shared_context.add_result(result) shared_context.add_result(result)
context.extra["results"].append(result) context.extra["results"].append(result)
context.extra["rollback_stack"].append(prepared) context.extra["rollback_stack"].append(
(prepared, combined_args, updated_kwargs)
)
combined_args, updated_kwargs = self._clear_args()
all_results = context.extra["results"] all_results = context.extra["results"]
assert ( assert (
@ -171,11 +180,11 @@ class ChainedAction(BaseAction, ActionListMixin):
logger.info("[%s] Chain broken: %s", self.name, error) logger.info("[%s] Chain broken: %s", self.name, error)
context.exception = error context.exception = error
shared_context.add_error(shared_context.current_index, error) shared_context.add_error(shared_context.current_index, error)
await self._rollback(context.extra["rollback_stack"], *args, **kwargs) await self._rollback(context.extra["rollback_stack"])
except Exception as error: except Exception as error:
context.exception = error context.exception = error
shared_context.add_error(shared_context.current_index, error) shared_context.add_error(shared_context.current_index, error)
await self._rollback(context.extra["rollback_stack"], *args, **kwargs) await self._rollback(context.extra["rollback_stack"])
await self.hooks.trigger(HookType.ON_ERROR, context) await self.hooks.trigger(HookType.ON_ERROR, context)
raise raise
finally: finally:
@ -184,7 +193,9 @@ class ChainedAction(BaseAction, ActionListMixin):
await self.hooks.trigger(HookType.ON_TEARDOWN, context) await self.hooks.trigger(HookType.ON_TEARDOWN, context)
er.record(context) er.record(context)
async def _rollback(self, rollback_stack, *args, **kwargs): async def _rollback(
self, rollback_stack: list[tuple[Action, tuple[Any, ...], dict[str, Any]]]
):
""" """
Roll back all executed actions in reverse order. Roll back all executed actions in reverse order.
@ -197,12 +208,12 @@ class ChainedAction(BaseAction, ActionListMixin):
rollback_stack (list): Actions to roll back. rollback_stack (list): Actions to roll back.
*args, **kwargs: Passed to rollback handlers. *args, **kwargs: Passed to rollback handlers.
""" """
for action in reversed(rollback_stack): for action, args, kwargs in reversed(rollback_stack):
rollback = getattr(action, "rollback", None) rollback = getattr(action, "rollback", None)
if rollback: if rollback:
try: try:
logger.warning("[%s] Rolling back...", action.name) logger.warning("[%s] Rolling back...", action.name)
await action.rollback(*args, **kwargs) await rollback(*args, **kwargs)
except Exception as error: except Exception as error:
logger.error("[%s] Rollback failed: %s", action.name, error) logger.error("[%s] Rollback failed: %s", action.name, error)

View File

@ -112,6 +112,14 @@ class ConfirmAction(BaseAction):
validator=word_validator(self.word), validator=word_validator(self.word),
) )
return answer.upper().strip() != "N" return answer.upper().strip() != "N"
case ConfirmType.TYPE_WORD_CANCEL:
answer = await self.prompt_session.prompt_async(
f"{self.message} [{self.word}] to confirm or [N/n] > ",
validator=word_validator(self.word),
)
if answer.upper().strip() == "N":
raise CancelSignal(f"Action '{self.name}' was cancelled by the user.")
return answer.upper().strip() == self.word.upper().strip()
case ConfirmType.YES_CANCEL: case ConfirmType.YES_CANCEL:
answer = await confirm_async( answer = await confirm_async(
self.message, self.message,
@ -131,6 +139,12 @@ class ConfirmAction(BaseAction):
if answer.upper() == "C": if answer.upper() == "C":
raise CancelSignal(f"Action '{self.name}' was cancelled by the user.") raise CancelSignal(f"Action '{self.name}' was cancelled by the user.")
return answer.upper() == "O" return answer.upper() == "O"
case ConfirmType.ACKNOWLEDGE:
answer = await self.prompt_session.prompt_async(
f"{self.message} [A]cknowledge > ",
validator=word_validator("A"),
)
return answer.upper().strip() == "A"
case _: case _:
raise ValueError(f"Unknown confirm_type: {self.confirm_type}") raise ValueError(f"Unknown confirm_type: {self.confirm_type}")
@ -151,7 +165,7 @@ class ConfirmAction(BaseAction):
and not should_prompt_user(confirm=True, options=self.options_manager) and not should_prompt_user(confirm=True, options=self.options_manager)
): ):
logger.debug( logger.debug(
"Skipping confirmation for action '%s' as 'confirm' is False or options manager indicates no prompt.", "Skipping confirmation for '%s' due to never_prompt or options_manager settings.",
self.name, self.name,
) )
if self.return_last_result: if self.return_last_result:
@ -189,7 +203,7 @@ class ConfirmAction(BaseAction):
tree.add(f"[bold]Message:[/] {self.message}") tree.add(f"[bold]Message:[/] {self.message}")
tree.add(f"[bold]Type:[/] {self.confirm_type.value}") tree.add(f"[bold]Type:[/] {self.confirm_type.value}")
tree.add(f"[bold]Prompt Required:[/] {'No' if self.never_prompt else 'Yes'}") tree.add(f"[bold]Prompt Required:[/] {'No' if self.never_prompt else 'Yes'}")
if self.confirm_type == ConfirmType.TYPE_WORD: if self.confirm_type in (ConfirmType.TYPE_WORD, ConfirmType.TYPE_WORD_CANCEL):
tree.add(f"[bold]Confirmation Word:[/] {self.word}") tree.add(f"[bold]Confirmation Word:[/] {self.word}")
if parent is None: if parent is None:
self.console.print(tree) self.console.print(tree)

View File

@ -91,9 +91,7 @@ class ProcessPoolAction(BaseAction):
f"Cannot inject last result into {self.name}: " f"Cannot inject last result into {self.name}: "
f"last result is not pickleable." f"last result is not pickleable."
) )
print(kwargs)
updated_kwargs = self._maybe_inject_last_result(kwargs) updated_kwargs = self._maybe_inject_last_result(kwargs)
print(updated_kwargs)
context = ExecutionContext( context = ExecutionContext(
name=self.name, name=self.name,
args=args, args=args,

View File

@ -1 +1 @@
__version__ = "0.1.61" __version__ = "0.1.62"

View File

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "falyx" name = "falyx"
version = "0.1.61" version = "0.1.62"
description = "Reliable and introspectable async CLI action framework." description = "Reliable and introspectable async CLI action framework."
authors = ["Roland Thomas Jr <roland@rtj.dev>"] authors = ["Roland Thomas Jr <roland@rtj.dev>"]
license = "MIT" license = "MIT"