Add args, kwargs to ChainedAction, ActionGroup, Add type_word_cancel and acknowledge ConfirmTypes, update ChainedAction rollback logic
This commit is contained in:
		| @@ -13,6 +13,9 @@ class Place(Enum): | ||||
|     SAN_FRANCISCO = "San Francisco" | ||||
|     LONDON = "London" | ||||
|  | ||||
|     def __str__(self): | ||||
|         return self.value | ||||
|  | ||||
|  | ||||
| async def test_args( | ||||
|     service: str, | ||||
|   | ||||
| @@ -112,7 +112,16 @@ class ActionFactory(BaseAction): | ||||
|         tree = parent.add(label) if parent else Tree(label) | ||||
|  | ||||
|         try: | ||||
|             generated = await self.factory(*self.preview_args, **self.preview_kwargs) | ||||
|             generated = None | ||||
|             if self.args or self.kwargs: | ||||
|                 try: | ||||
|                     generated = await self.factory(*self.args, **self.kwargs) | ||||
|                 except TypeError: | ||||
|                     ... | ||||
|  | ||||
|             if not generated: | ||||
|                 generated = await self.factory(*self.preview_args, **self.preview_kwargs) | ||||
|  | ||||
|             if isinstance(generated, BaseAction): | ||||
|                 await generated.preview(parent=tree) | ||||
|             else: | ||||
|   | ||||
| @@ -60,6 +60,8 @@ class ActionGroup(BaseAction, ActionListMixin): | ||||
|             Sequence[BaseAction | Callable[..., Any] | Callable[..., Awaitable]] | None | ||||
|         ) = None, | ||||
|         *, | ||||
|         args: tuple[Any, ...] = (), | ||||
|         kwargs: dict[str, Any] | None = None, | ||||
|         hooks: HookManager | None = None, | ||||
|         inject_last_result: bool = False, | ||||
|         inject_into: str = "last_result", | ||||
| @@ -71,6 +73,8 @@ class ActionGroup(BaseAction, ActionListMixin): | ||||
|             inject_into=inject_into, | ||||
|         ) | ||||
|         ActionListMixin.__init__(self) | ||||
|         self.args = args | ||||
|         self.kwargs = kwargs or {} | ||||
|         if actions: | ||||
|             self.set_actions(actions) | ||||
|  | ||||
| @@ -115,13 +119,17 @@ class ActionGroup(BaseAction, ActionListMixin): | ||||
|     async def _run(self, *args, **kwargs) -> list[tuple[str, Any]]: | ||||
|         if not self.actions: | ||||
|             raise EmptyGroupError(f"[{self.name}] No actions to execute.") | ||||
|  | ||||
|         combined_args = args + self.args | ||||
|         combined_kwargs = {**self.kwargs, **kwargs} | ||||
|  | ||||
|         shared_context = SharedContext(name=self.name, action=self, is_parallel=True) | ||||
|         if self.shared_context: | ||||
|             shared_context.set_shared_result(self.shared_context.last_result()) | ||||
|         updated_kwargs = self._maybe_inject_last_result(kwargs) | ||||
|         updated_kwargs = self._maybe_inject_last_result(combined_kwargs) | ||||
|         context = ExecutionContext( | ||||
|             name=self.name, | ||||
|             args=args, | ||||
|             args=combined_args, | ||||
|             kwargs=updated_kwargs, | ||||
|             action=self, | ||||
|             extra={"results": [], "errors": []}, | ||||
| @@ -131,7 +139,7 @@ class ActionGroup(BaseAction, ActionListMixin): | ||||
|         async def run_one(action: BaseAction): | ||||
|             try: | ||||
|                 prepared = action.prepare(shared_context, self.options_manager) | ||||
|                 result = await prepared(*args, **updated_kwargs) | ||||
|                 result = await prepared(*combined_args, **updated_kwargs) | ||||
|                 shared_context.add_result((action.name, result)) | ||||
|                 context.extra["results"].append((action.name, result)) | ||||
|             except Exception as error: | ||||
|   | ||||
| @@ -61,7 +61,9 @@ class ConfirmType(Enum): | ||||
|     YES_CANCEL = "yes_cancel" | ||||
|     YES_NO_CANCEL = "yes_no_cancel" | ||||
|     TYPE_WORD = "type_word" | ||||
|     TYPE_WORD_CANCEL = "type_word_cancel" | ||||
|     OK_CANCEL = "ok_cancel" | ||||
|     ACKNOWLEDGE = "acknowledge" | ||||
|  | ||||
|     @classmethod | ||||
|     def choices(cls) -> list[ConfirmType]: | ||||
|   | ||||
| @@ -54,6 +54,8 @@ class ChainedAction(BaseAction, ActionListMixin): | ||||
|             | None | ||||
|         ) = None, | ||||
|         *, | ||||
|         args: tuple[Any, ...] = (), | ||||
|         kwargs: dict[str, Any] | None = None, | ||||
|         hooks: HookManager | None = None, | ||||
|         inject_last_result: bool = False, | ||||
|         inject_into: str = "last_result", | ||||
| @@ -67,6 +69,8 @@ class ChainedAction(BaseAction, ActionListMixin): | ||||
|             inject_into=inject_into, | ||||
|         ) | ||||
|         ActionListMixin.__init__(self) | ||||
|         self.args = args | ||||
|         self.kwargs = kwargs or {} | ||||
|         self.auto_inject = auto_inject | ||||
|         self.return_list = return_list | ||||
|         if actions: | ||||
| @@ -111,13 +115,16 @@ class ChainedAction(BaseAction, ActionListMixin): | ||||
|         if not self.actions: | ||||
|             raise EmptyChainError(f"[{self.name}] No actions to execute.") | ||||
|  | ||||
|         combined_args = args + self.args | ||||
|         combined_kwargs = {**self.kwargs, **kwargs} | ||||
|  | ||||
|         shared_context = SharedContext(name=self.name, action=self) | ||||
|         if self.shared_context: | ||||
|             shared_context.add_result(self.shared_context.last_result()) | ||||
|         updated_kwargs = self._maybe_inject_last_result(kwargs) | ||||
|         updated_kwargs = self._maybe_inject_last_result(combined_kwargs) | ||||
|         context = ExecutionContext( | ||||
|             name=self.name, | ||||
|             args=args, | ||||
|             args=combined_args, | ||||
|             kwargs=updated_kwargs, | ||||
|             action=self, | ||||
|             extra={"results": [], "rollback_stack": []}, | ||||
| @@ -136,7 +143,7 @@ class ChainedAction(BaseAction, ActionListMixin): | ||||
|                 shared_context.current_index = index | ||||
|                 prepared = action.prepare(shared_context, self.options_manager) | ||||
|                 try: | ||||
|                     result = await prepared(*args, **updated_kwargs) | ||||
|                     result = await prepared(*combined_args, **updated_kwargs) | ||||
|                 except Exception as error: | ||||
|                     if index + 1 < len(self.actions) and isinstance( | ||||
|                         self.actions[index + 1], FallbackAction | ||||
| @@ -155,10 +162,12 @@ class ChainedAction(BaseAction, ActionListMixin): | ||||
|                         fallback._skip_in_chain = True | ||||
|                     else: | ||||
|                         raise | ||||
|                 args, updated_kwargs = self._clear_args() | ||||
|                 shared_context.add_result(result) | ||||
|                 context.extra["results"].append(result) | ||||
|                 context.extra["rollback_stack"].append(prepared) | ||||
|                 context.extra["rollback_stack"].append( | ||||
|                     (prepared, combined_args, updated_kwargs) | ||||
|                 ) | ||||
|                 combined_args, updated_kwargs = self._clear_args() | ||||
|  | ||||
|             all_results = context.extra["results"] | ||||
|             assert ( | ||||
| @@ -171,11 +180,11 @@ class ChainedAction(BaseAction, ActionListMixin): | ||||
|             logger.info("[%s] Chain broken: %s", self.name, error) | ||||
|             context.exception = error | ||||
|             shared_context.add_error(shared_context.current_index, error) | ||||
|             await self._rollback(context.extra["rollback_stack"], *args, **kwargs) | ||||
|             await self._rollback(context.extra["rollback_stack"]) | ||||
|         except Exception as error: | ||||
|             context.exception = error | ||||
|             shared_context.add_error(shared_context.current_index, error) | ||||
|             await self._rollback(context.extra["rollback_stack"], *args, **kwargs) | ||||
|             await self._rollback(context.extra["rollback_stack"]) | ||||
|             await self.hooks.trigger(HookType.ON_ERROR, context) | ||||
|             raise | ||||
|         finally: | ||||
| @@ -184,7 +193,9 @@ class ChainedAction(BaseAction, ActionListMixin): | ||||
|             await self.hooks.trigger(HookType.ON_TEARDOWN, context) | ||||
|             er.record(context) | ||||
|  | ||||
|     async def _rollback(self, rollback_stack, *args, **kwargs): | ||||
|     async def _rollback( | ||||
|         self, rollback_stack: list[tuple[Action, tuple[Any, ...], dict[str, Any]]] | ||||
|     ): | ||||
|         """ | ||||
|         Roll back all executed actions in reverse order. | ||||
|  | ||||
| @@ -197,12 +208,12 @@ class ChainedAction(BaseAction, ActionListMixin): | ||||
|             rollback_stack (list): Actions to roll back. | ||||
|             *args, **kwargs: Passed to rollback handlers. | ||||
|         """ | ||||
|         for action in reversed(rollback_stack): | ||||
|         for action, args, kwargs in reversed(rollback_stack): | ||||
|             rollback = getattr(action, "rollback", None) | ||||
|             if rollback: | ||||
|                 try: | ||||
|                     logger.warning("[%s] Rolling back...", action.name) | ||||
|                     await action.rollback(*args, **kwargs) | ||||
|                     await rollback(*args, **kwargs) | ||||
|                 except Exception as error: | ||||
|                     logger.error("[%s] Rollback failed: %s", action.name, error) | ||||
|  | ||||
|   | ||||
| @@ -112,6 +112,14 @@ class ConfirmAction(BaseAction): | ||||
|                     validator=word_validator(self.word), | ||||
|                 ) | ||||
|                 return answer.upper().strip() != "N" | ||||
|             case ConfirmType.TYPE_WORD_CANCEL: | ||||
|                 answer = await self.prompt_session.prompt_async( | ||||
|                     f"❓ {self.message} [{self.word}] to confirm or [N/n] > ", | ||||
|                     validator=word_validator(self.word), | ||||
|                 ) | ||||
|                 if answer.upper().strip() == "N": | ||||
|                     raise CancelSignal(f"Action '{self.name}' was cancelled by the user.") | ||||
|                 return answer.upper().strip() == self.word.upper().strip() | ||||
|             case ConfirmType.YES_CANCEL: | ||||
|                 answer = await confirm_async( | ||||
|                     self.message, | ||||
| @@ -131,6 +139,12 @@ class ConfirmAction(BaseAction): | ||||
|                 if answer.upper() == "C": | ||||
|                     raise CancelSignal(f"Action '{self.name}' was cancelled by the user.") | ||||
|                 return answer.upper() == "O" | ||||
|             case ConfirmType.ACKNOWLEDGE: | ||||
|                 answer = await self.prompt_session.prompt_async( | ||||
|                     f"❓ {self.message} [A]cknowledge > ", | ||||
|                     validator=word_validator("A"), | ||||
|                 ) | ||||
|                 return answer.upper().strip() == "A" | ||||
|             case _: | ||||
|                 raise ValueError(f"Unknown confirm_type: {self.confirm_type}") | ||||
|  | ||||
| @@ -151,7 +165,7 @@ class ConfirmAction(BaseAction): | ||||
|                 and not should_prompt_user(confirm=True, options=self.options_manager) | ||||
|             ): | ||||
|                 logger.debug( | ||||
|                     "Skipping confirmation for action '%s' as 'confirm' is False or options manager indicates no prompt.", | ||||
|                     "Skipping confirmation for '%s' due to never_prompt or options_manager settings.", | ||||
|                     self.name, | ||||
|                 ) | ||||
|                 if self.return_last_result: | ||||
| @@ -189,7 +203,7 @@ class ConfirmAction(BaseAction): | ||||
|         tree.add(f"[bold]Message:[/] {self.message}") | ||||
|         tree.add(f"[bold]Type:[/] {self.confirm_type.value}") | ||||
|         tree.add(f"[bold]Prompt Required:[/] {'No' if self.never_prompt else 'Yes'}") | ||||
|         if self.confirm_type == ConfirmType.TYPE_WORD: | ||||
|         if self.confirm_type in (ConfirmType.TYPE_WORD, ConfirmType.TYPE_WORD_CANCEL): | ||||
|             tree.add(f"[bold]Confirmation Word:[/] {self.word}") | ||||
|         if parent is None: | ||||
|             self.console.print(tree) | ||||
|   | ||||
| @@ -91,9 +91,7 @@ class ProcessPoolAction(BaseAction): | ||||
|                     f"Cannot inject last result into {self.name}: " | ||||
|                     f"last result is not pickleable." | ||||
|                 ) | ||||
|         print(kwargs) | ||||
|         updated_kwargs = self._maybe_inject_last_result(kwargs) | ||||
|         print(updated_kwargs) | ||||
|         context = ExecutionContext( | ||||
|             name=self.name, | ||||
|             args=args, | ||||
|   | ||||
| @@ -1 +1 @@ | ||||
| __version__ = "0.1.61" | ||||
| __version__ = "0.1.62" | ||||
|   | ||||
| @@ -1,6 +1,6 @@ | ||||
| [tool.poetry] | ||||
| name = "falyx" | ||||
| version = "0.1.61" | ||||
| version = "0.1.62" | ||||
| description = "Reliable and introspectable async CLI action framework." | ||||
| authors = ["Roland Thomas Jr <roland@rtj.dev>"] | ||||
| license = "MIT" | ||||
|   | ||||
		Reference in New Issue
	
	Block a user