Add args, kwargs to ChainedAction, ActionGroup, Add type_word_cancel and acknowledge ConfirmTypes, update ChainedAction rollback logic

This commit is contained in:
2025-07-16 18:54:03 -04:00
parent 2288015cf3
commit dc1764e752
9 changed files with 65 additions and 20 deletions

View File

@ -13,6 +13,9 @@ class Place(Enum):
SAN_FRANCISCO = "San Francisco"
LONDON = "London"
def __str__(self):
return self.value
async def test_args(
service: str,

View File

@ -112,7 +112,16 @@ class ActionFactory(BaseAction):
tree = parent.add(label) if parent else Tree(label)
try:
generated = await self.factory(*self.preview_args, **self.preview_kwargs)
generated = None
if self.args or self.kwargs:
try:
generated = await self.factory(*self.args, **self.kwargs)
except TypeError:
...
if not generated:
generated = await self.factory(*self.preview_args, **self.preview_kwargs)
if isinstance(generated, BaseAction):
await generated.preview(parent=tree)
else:

View File

@ -60,6 +60,8 @@ class ActionGroup(BaseAction, ActionListMixin):
Sequence[BaseAction | Callable[..., Any] | Callable[..., Awaitable]] | None
) = None,
*,
args: tuple[Any, ...] = (),
kwargs: dict[str, Any] | None = None,
hooks: HookManager | None = None,
inject_last_result: bool = False,
inject_into: str = "last_result",
@ -71,6 +73,8 @@ class ActionGroup(BaseAction, ActionListMixin):
inject_into=inject_into,
)
ActionListMixin.__init__(self)
self.args = args
self.kwargs = kwargs or {}
if actions:
self.set_actions(actions)
@ -115,13 +119,17 @@ class ActionGroup(BaseAction, ActionListMixin):
async def _run(self, *args, **kwargs) -> list[tuple[str, Any]]:
if not self.actions:
raise EmptyGroupError(f"[{self.name}] No actions to execute.")
combined_args = args + self.args
combined_kwargs = {**self.kwargs, **kwargs}
shared_context = SharedContext(name=self.name, action=self, is_parallel=True)
if self.shared_context:
shared_context.set_shared_result(self.shared_context.last_result())
updated_kwargs = self._maybe_inject_last_result(kwargs)
updated_kwargs = self._maybe_inject_last_result(combined_kwargs)
context = ExecutionContext(
name=self.name,
args=args,
args=combined_args,
kwargs=updated_kwargs,
action=self,
extra={"results": [], "errors": []},
@ -131,7 +139,7 @@ class ActionGroup(BaseAction, ActionListMixin):
async def run_one(action: BaseAction):
try:
prepared = action.prepare(shared_context, self.options_manager)
result = await prepared(*args, **updated_kwargs)
result = await prepared(*combined_args, **updated_kwargs)
shared_context.add_result((action.name, result))
context.extra["results"].append((action.name, result))
except Exception as error:

View File

@ -61,7 +61,9 @@ class ConfirmType(Enum):
YES_CANCEL = "yes_cancel"
YES_NO_CANCEL = "yes_no_cancel"
TYPE_WORD = "type_word"
TYPE_WORD_CANCEL = "type_word_cancel"
OK_CANCEL = "ok_cancel"
ACKNOWLEDGE = "acknowledge"
@classmethod
def choices(cls) -> list[ConfirmType]:

View File

@ -54,6 +54,8 @@ class ChainedAction(BaseAction, ActionListMixin):
| None
) = None,
*,
args: tuple[Any, ...] = (),
kwargs: dict[str, Any] | None = None,
hooks: HookManager | None = None,
inject_last_result: bool = False,
inject_into: str = "last_result",
@ -67,6 +69,8 @@ class ChainedAction(BaseAction, ActionListMixin):
inject_into=inject_into,
)
ActionListMixin.__init__(self)
self.args = args
self.kwargs = kwargs or {}
self.auto_inject = auto_inject
self.return_list = return_list
if actions:
@ -111,13 +115,16 @@ class ChainedAction(BaseAction, ActionListMixin):
if not self.actions:
raise EmptyChainError(f"[{self.name}] No actions to execute.")
combined_args = args + self.args
combined_kwargs = {**self.kwargs, **kwargs}
shared_context = SharedContext(name=self.name, action=self)
if self.shared_context:
shared_context.add_result(self.shared_context.last_result())
updated_kwargs = self._maybe_inject_last_result(kwargs)
updated_kwargs = self._maybe_inject_last_result(combined_kwargs)
context = ExecutionContext(
name=self.name,
args=args,
args=combined_args,
kwargs=updated_kwargs,
action=self,
extra={"results": [], "rollback_stack": []},
@ -136,7 +143,7 @@ class ChainedAction(BaseAction, ActionListMixin):
shared_context.current_index = index
prepared = action.prepare(shared_context, self.options_manager)
try:
result = await prepared(*args, **updated_kwargs)
result = await prepared(*combined_args, **updated_kwargs)
except Exception as error:
if index + 1 < len(self.actions) and isinstance(
self.actions[index + 1], FallbackAction
@ -155,10 +162,12 @@ class ChainedAction(BaseAction, ActionListMixin):
fallback._skip_in_chain = True
else:
raise
args, updated_kwargs = self._clear_args()
shared_context.add_result(result)
context.extra["results"].append(result)
context.extra["rollback_stack"].append(prepared)
context.extra["rollback_stack"].append(
(prepared, combined_args, updated_kwargs)
)
combined_args, updated_kwargs = self._clear_args()
all_results = context.extra["results"]
assert (
@ -171,11 +180,11 @@ class ChainedAction(BaseAction, ActionListMixin):
logger.info("[%s] Chain broken: %s", self.name, error)
context.exception = error
shared_context.add_error(shared_context.current_index, error)
await self._rollback(context.extra["rollback_stack"], *args, **kwargs)
await self._rollback(context.extra["rollback_stack"])
except Exception as error:
context.exception = error
shared_context.add_error(shared_context.current_index, error)
await self._rollback(context.extra["rollback_stack"], *args, **kwargs)
await self._rollback(context.extra["rollback_stack"])
await self.hooks.trigger(HookType.ON_ERROR, context)
raise
finally:
@ -184,7 +193,9 @@ class ChainedAction(BaseAction, ActionListMixin):
await self.hooks.trigger(HookType.ON_TEARDOWN, context)
er.record(context)
async def _rollback(self, rollback_stack, *args, **kwargs):
async def _rollback(
self, rollback_stack: list[tuple[Action, tuple[Any, ...], dict[str, Any]]]
):
"""
Roll back all executed actions in reverse order.
@ -197,12 +208,12 @@ class ChainedAction(BaseAction, ActionListMixin):
rollback_stack (list): Actions to roll back.
*args, **kwargs: Passed to rollback handlers.
"""
for action in reversed(rollback_stack):
for action, args, kwargs in reversed(rollback_stack):
rollback = getattr(action, "rollback", None)
if rollback:
try:
logger.warning("[%s] Rolling back...", action.name)
await action.rollback(*args, **kwargs)
await rollback(*args, **kwargs)
except Exception as error:
logger.error("[%s] Rollback failed: %s", action.name, error)

View File

@ -112,6 +112,14 @@ class ConfirmAction(BaseAction):
validator=word_validator(self.word),
)
return answer.upper().strip() != "N"
case ConfirmType.TYPE_WORD_CANCEL:
answer = await self.prompt_session.prompt_async(
f"{self.message} [{self.word}] to confirm or [N/n] > ",
validator=word_validator(self.word),
)
if answer.upper().strip() == "N":
raise CancelSignal(f"Action '{self.name}' was cancelled by the user.")
return answer.upper().strip() == self.word.upper().strip()
case ConfirmType.YES_CANCEL:
answer = await confirm_async(
self.message,
@ -131,6 +139,12 @@ class ConfirmAction(BaseAction):
if answer.upper() == "C":
raise CancelSignal(f"Action '{self.name}' was cancelled by the user.")
return answer.upper() == "O"
case ConfirmType.ACKNOWLEDGE:
answer = await self.prompt_session.prompt_async(
f"{self.message} [A]cknowledge > ",
validator=word_validator("A"),
)
return answer.upper().strip() == "A"
case _:
raise ValueError(f"Unknown confirm_type: {self.confirm_type}")
@ -151,7 +165,7 @@ class ConfirmAction(BaseAction):
and not should_prompt_user(confirm=True, options=self.options_manager)
):
logger.debug(
"Skipping confirmation for action '%s' as 'confirm' is False or options manager indicates no prompt.",
"Skipping confirmation for '%s' due to never_prompt or options_manager settings.",
self.name,
)
if self.return_last_result:
@ -189,7 +203,7 @@ class ConfirmAction(BaseAction):
tree.add(f"[bold]Message:[/] {self.message}")
tree.add(f"[bold]Type:[/] {self.confirm_type.value}")
tree.add(f"[bold]Prompt Required:[/] {'No' if self.never_prompt else 'Yes'}")
if self.confirm_type == ConfirmType.TYPE_WORD:
if self.confirm_type in (ConfirmType.TYPE_WORD, ConfirmType.TYPE_WORD_CANCEL):
tree.add(f"[bold]Confirmation Word:[/] {self.word}")
if parent is None:
self.console.print(tree)

View File

@ -91,9 +91,7 @@ class ProcessPoolAction(BaseAction):
f"Cannot inject last result into {self.name}: "
f"last result is not pickleable."
)
print(kwargs)
updated_kwargs = self._maybe_inject_last_result(kwargs)
print(updated_kwargs)
context = ExecutionContext(
name=self.name,
args=args,

View File

@ -1 +1 @@
__version__ = "0.1.61"
__version__ = "0.1.62"

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "falyx"
version = "0.1.61"
version = "0.1.62"
description = "Reliable and introspectable async CLI action framework."
authors = ["Roland Thomas Jr <roland@rtj.dev>"]
license = "MIT"