From 7443084809fb641f58666e3892eda4f1b7ab071c Mon Sep 17 00:00:00 2001 From: Roland Thomas Date: Thu, 27 Mar 2025 22:02:16 -0400 Subject: [PATCH] Update to Menu and Action --- menu/action.py | 218 +++++++++++++++++++++++ menu/bottom_bar.py | 59 +++++++ menu/callbacks.py | 10 +- menu/hook_manager.py | 66 +++++++ menu/hooks.py | 117 +++++++------ menu/logging_utils.py | 2 +- menu/menu.py | 395 ++++++++++++------------------------------ menu/menu_utils.py | 79 +++++++++ menu/option.py | 110 ++++++++++++ menu/task.py | 2 +- 10 files changed, 718 insertions(+), 340 deletions(-) create mode 100644 menu/action.py create mode 100644 menu/bottom_bar.py create mode 100644 menu/hook_manager.py create mode 100644 menu/menu_utils.py create mode 100644 menu/option.py diff --git a/menu/action.py b/menu/action.py new file mode 100644 index 0000000..227ae97 --- /dev/null +++ b/menu/action.py @@ -0,0 +1,218 @@ +"""action.py + +Any Action or Option is callable and supports the signature: + result = thing(*args, **kwargs) + +This guarantees: +- Hook lifecycle (before/after/error/teardown) +- Timing +- Consistent return values +""" +from __future__ import annotations + +import asyncio +import logging +import time +import inspect +from abc import ABC, abstractmethod +from typing import Optional + +from hook_manager import HookManager +from menu_utils import TimingMixin, run_async + + +logger = logging.getLogger("menu") + + +class BaseAction(ABC, TimingMixin): + """Base class for actions. They are the building blocks of the menu. + Actions can be simple functions or more complex actions like + `ChainedAction` or `ActionGroup`. They can also be run independently + or as part of a menu.""" + def __init__(self, name: str, hooks: Optional[HookManager] = None): + self.name = name + self.hooks = hooks or HookManager() + self.start_time: float | None = None + self.end_time: float | None = None + self._duration: float | None = None + + def __call__(self, *args, **kwargs): + context = { + "name": self.name, + "duration": None, + "args": args, + "kwargs": kwargs, + "action": self + } + self._start_timer() + try: + run_async(self.hooks.trigger("before", context)) + result = self._run(*args, **kwargs) + context["result"] = result + return result + except Exception as error: + context["exception"] = error + run_async(self.hooks.trigger("on_error", context)) + if "exception" not in context: + logger.info(f"✅ Recovery hook handled error for Action '{self.name}'") + return context.get("result") + raise + finally: + self._stop_timer() + context["duration"] = self.get_duration() + if "exception" not in context: + run_async(self.hooks.trigger("after", context)) + run_async(self.hooks.trigger("on_teardown", context)) + + @abstractmethod + def _run(self, *args, **kwargs): + raise NotImplementedError("_run must be implemented by subclasses") + + async def run_async(self, *args, **kwargs): + if inspect.iscoroutinefunction(self._run): + return await self._run(*args, **kwargs) + + return await asyncio.to_thread(self.__call__, *args, **kwargs) + + def __await__(self): + return self.run_async().__await__() + + @abstractmethod + def dry_run(self): + raise NotImplementedError("dry_run must be implemented by subclasses") + + def __str__(self): + return f"<{self.__class__.__name__} '{self.name}'>" + + def __repr__(self): + return str(self) + + +class Action(BaseAction): + def __init__(self, name: str, fn, rollback=None, hooks=None): + super().__init__(name, hooks) + self.fn = fn + self.rollback = rollback + + def _run(self, *args, **kwargs): + if inspect.iscoroutinefunction(self.fn): + return asyncio.run(self.fn(*args, **kwargs)) + return self.fn(*args, **kwargs) + + def dry_run(self): + print(f"[DRY RUN] Would run: {self.name}") + + +class ChainedAction(BaseAction): + def __init__(self, name: str, actions: list[BaseAction], hooks=None): + super().__init__(name, hooks) + self.actions = actions + + def _run(self, *args, **kwargs): + rollback_stack = [] + for action in self.actions: + try: + result = action(*args, **kwargs) + rollback_stack.append(action) + except Exception: + self._rollback(rollback_stack, *args, **kwargs) + raise + return None + + def dry_run(self): + print(f"[DRY RUN] ChainedAction '{self.name}' with steps:") + for action in self.actions: + action.dry_run() + + def _rollback(self, rollback_stack, *args, **kwargs): + for action in reversed(rollback_stack): + if hasattr(action, "rollback") and action.rollback: + try: + print(f"↩️ Rolling back {action.name}") + action.rollback(*args, **kwargs) + except Exception as e: + print(f"⚠️ Rollback failed for {action.name}: {e}") + + +class ActionGroup(BaseAction): + def __init__(self, name: str, actions: list[BaseAction], hooks=None): + super().__init__(name, hooks) + self.actions = actions + self.results = [] + self.errors = [] + + def _run(self, *args, **kwargs): + asyncio.run(self._run_async(*args, **kwargs)) + + def dry_run(self): + print(f"[DRY RUN] ActionGroup '{self.name}' (parallel execution):") + for action in self.actions: + action.dry_run() + + async def _run_async(self, *args, **kwargs): + async def run(action): + try: + result = await asyncio.to_thread(action, *args, **kwargs) + self.results.append((action.name, result)) + except Exception as e: + self.errors.append((action.name, e)) + + await self.hooks.trigger("before", name=self.name) + + await asyncio.gather(*[run(a) for a in self.actions]) + + if self.errors: + await self.hooks.trigger("on_error", name=self.name, errors=self.errors) + else: + await self.hooks.trigger("after", name=self.name, results=self.results) + + await self.hooks.trigger("on_teardown", name=self.name) + + + +# if __name__ == "__main__": +# # Example usage +# def build(): print("Build!") +# def test(): print("Test!") +# def deploy(): print("Deploy!") + + +# pipeline = ChainedAction("CI/CD", [ +# Action("Build", build), +# Action("Test", test), +# ActionGroup("Deploy Parallel", [ +# Action("Deploy A", deploy), +# Action("Deploy B", deploy) +# ]) +# ]) + +# pipeline() +# Sample functions +def sync_hello(): + time.sleep(1) + return "Hello from sync function" + +async def async_hello(): + await asyncio.sleep(1) + return "Hello from async function" + + +# Example usage +async def main(): + sync_action = Action("sync_hello", sync_hello) + async_action = Action("async_hello", async_hello) + + print("⏳ Awaiting sync action...") + result1 = await sync_action + print("✅", result1) + + print("⏳ Awaiting async action...") + result2 = await async_action + print("✅", result2) + + print(f"⏱️ sync took {sync_action.get_duration():.2f}s") + print(f"⏱️ async took {async_action.get_duration():.2f}s") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/menu/bottom_bar.py b/menu/bottom_bar.py new file mode 100644 index 0000000..002e351 --- /dev/null +++ b/menu/bottom_bar.py @@ -0,0 +1,59 @@ +from prompt_toolkit.formatted_text import HTML, merge_formatted_text +from typing import Callable, Literal, Optional +from rich.console import Console + + +class BottomBar: + def __init__(self, columns: int = 3): + self.columns = columns + self.console = Console() + self._items: list[Callable[[], HTML]] = [] + self._named_items: dict[str, Callable[[], HTML]] = {} + self._states: dict[str, any] = {} + + def get_space(self) -> str: + return self.console.width // self.columns + + def add_static(self, name: str, text: str) -> None: + def render(): + return HTML(f"") + self._add_named(name, render) + + def add_counter(self, name: str, label: str, current: int, total: int) -> None: + self._states[name] = (label, current, total) + + def render(): + l, c, t = self._states[name] + text = f"{l}: {c}/{t}" + return HTML(f"") + + self._add_named(name, render) + + def add_toggle(self, name: str, label: str, state: bool) -> None: + self._states[name] = (label, state) + + def render(): + l, s = self._states[name] + color = '#A3BE8C' if s else '#BF616A' + status = "ON" if s else "OFF" + text = f"{l}: {status}" + return HTML(f"") + + self._add_named(name, render) + + def update_toggle(self, name: str, state: bool) -> None: + if name in self._states: + label, _ = self._states[name] + self._states[name] = (label, state) + + def update_counter(self, name: str, current: Optional[int] = None, total: Optional[int] = None) -> None: + if name in self._states: + label, c, t = self._states[name] + self._states[name] = (label, current if current is not None else c, total if total is not None else t) + + def _add_named(self, name: str, render_fn: Callable[[], HTML]) -> None: + self._named_items[name] = render_fn + self._items = list(self._named_items.values()) + + def render(self): + return merge_formatted_text([fn() for fn in self._items]) diff --git a/menu/callbacks.py b/menu/callbacks.py index e7778be..f2c9abf 100644 --- a/menu/callbacks.py +++ b/menu/callbacks.py @@ -11,7 +11,7 @@ console = Console() setup_logging() logger = logging.getLogger("menu") -def retry(max_retries=3, delay=1, backoff=2, exceptions=(Exception,), logger=None, spinner_text=None): +def retry(max_retries=3, delay=1, backoff=2, exceptions=(Exception,)): def decorator(func): is_coroutine = inspect.iscoroutinefunction(func) @@ -22,8 +22,7 @@ def retry(max_retries=3, delay=1, backoff=2, exceptions=(Exception,), logger=Non if logger: logger.debug(f"Retrying {retries + 1}/{max_retries} for '{func.__name__}' after {current_delay}s due to '{exceptions}'.") try: - with console.status(spinner_text, spinner="dots"): - return await func(*args, **kwargs) + return await func(*args, **kwargs) except exceptions as e: if retries == max_retries: if logger: @@ -44,8 +43,7 @@ def retry(max_retries=3, delay=1, backoff=2, exceptions=(Exception,), logger=Non if logger: logger.debug(f"Retrying {retries + 1}/{max_retries} for '{func.__name__}' after {current_delay}s due to '{exceptions}'.") try: - with console.status(spinner_text, spinner="dots"): - return func(*args, **kwargs) + return func(*args, **kwargs) except exceptions as e: if retries == max_retries: if logger: @@ -62,7 +60,7 @@ def retry(max_retries=3, delay=1, backoff=2, exceptions=(Exception,), logger=Non return async_wrapper if is_coroutine else sync_wrapper return decorator -@retry(max_retries=10, delay=1, logger=logger, spinner_text="Trying risky thing...") +@retry(max_retries=10, delay=1, spinner_text="Trying risky thing...") def might_fail(): time.sleep(4) if random.random() < 0.6: diff --git a/menu/hook_manager.py b/menu/hook_manager.py new file mode 100644 index 0000000..564cc20 --- /dev/null +++ b/menu/hook_manager.py @@ -0,0 +1,66 @@ +"""hook_manager.py""" +from __future__ import annotations + +import inspect +import logging +from typing import (Any, Awaitable, Callable, Dict, List, Optional, TypedDict, + Union, TYPE_CHECKING) + +if TYPE_CHECKING: + from action import BaseAction + from option import Option + + +logger = logging.getLogger("menu") + + +class HookContext(TypedDict, total=False): + name: str + args: tuple[Any, ...] + kwargs: dict[str, Any] + result: Any | None + exception: Exception | None + option: Option | None + action: BaseAction | None + + +Hook = Union[Callable[[HookContext], None], Callable[[HookContext], Awaitable[None]]] + + +class HookManager: + def __init__(self): + self._hooks: Dict[str, List[Hook]] = { + "before": [], + "after": [], + "on_error": [], + "on_teardown": [], + } + + def register(self, hook_type: str, hook: Hook): + if hook_type not in self._hooks: + raise ValueError(f"Unsupported hook type: {hook_type}") + self._hooks[hook_type].append(hook) + + def clear(self, hook_type: Optional[str] = None): + if hook_type: + self._hooks[hook_type] = [] + else: + for k in self._hooks: + self._hooks[k] = [] + + async def trigger(self, hook_type: str, context: HookContext): + if hook_type not in self._hooks: + raise ValueError(f"Unsupported hook type: {hook_type}") + for hook in self._hooks[hook_type]: + try: + if inspect.iscoroutinefunction(hook): + await hook(context) + else: + hook(context) + except Exception as hook_error: + name = context.get("name", "") + logger.warning(f"⚠️ Hook '{hook.__name__}' raised an exception during '{hook_type}'" + f" for '{name}': {hook_error}") + + if hook_type == "on_error": + raise context.get("exception") from hook_error diff --git a/menu/hooks.py b/menu/hooks.py index 533feb8..6a53ab9 100644 --- a/menu/hooks.py +++ b/menu/hooks.py @@ -1,41 +1,40 @@ -import time +import functools import logging import random -import functools -from menu import Menu, Option +import time + +from hook_manager import HookContext +from menu_utils import run_async logger = logging.getLogger("menu") -def timing_before_hook(option: Option) -> None: - option._start_time = time.perf_counter() - -def timing_after_hook(option: Option) -> None: - option._end_time = time.perf_counter() - option._duration = option._end_time - option._start_time - - -def timing_error_hook(option: Option, _: Exception) -> None: - option._end_time = time.perf_counter() - option._duration = option._end_time - option._start_time - - -def log_before(option: Option) -> None: - logger.info(f"🚀 Starting action '{option.description}' (key='{option.key}')") - - -def log_after(option: Option) -> None: - if option._duration is not None: - logger.info(f"✅ Completed '{option.description}' (key='{option.key}') in {option._duration:.2f}s") +def log_before(context: dict) -> None: + name = context.get("name", "") + option = context.get("option") + if option: + logger.info(f"🚀 Starting action '{option.description}' (key='{option.key}')") else: - logger.info(f"✅ Completed '{option.description}' (key='{option.key}')") + logger.info(f"🚀 Starting action '{name}'") -def log_error(option: Option, error: Exception) -> None: - if option._duration is not None: - logger.error(f"❌ Error '{option.description}' (key='{option.key}') after {option._duration:.2f}s: {error}") +def log_after(context: dict) -> None: + name = context.get("name", "") + duration = context.get("duration") + if duration is not None: + logger.info(f"✅ Completed '{name}' in {duration:.2f}s") else: - logger.error(f"❌ Error '{option.description}' (key='{option.key}'): {error}") + logger.info(f"✅ Completed '{name}'") + + +def log_error(context: dict) -> None: + name = context.get("name", "") + error = context.get("exception") + duration = context.get("duration") + if duration is not None: + logger.error(f"❌ Error '{name}' after {duration:.2f}s: {error}") + else: + logger.error(f"❌ Error '{name}': {error}") class CircuitBreakerOpen(Exception): @@ -49,23 +48,25 @@ class CircuitBreaker: self.failures = 0 self.open_until = None - def before_hook(self, option: Option): + def before_hook(self, context: HookContext): + name = context.get("name", "") if self.open_until: if time.time() < self.open_until: - raise CircuitBreakerOpen(f"🔴 Circuit open for '{option.description}' until {time.ctime(self.open_until)}.") + raise CircuitBreakerOpen(f"🔴 Circuit open for '{name}' until {time.ctime(self.open_until)}.") else: - logger.info(f"🟢 Circuit closed again for '{option.description}'.") + logger.info(f"🟢 Circuit closed again for '{name}'.") self.failures = 0 self.open_until = None - def error_hook(self, option: Option, error: Exception): + def error_hook(self, context: HookContext): + name = context.get("name", "") self.failures += 1 - logger.warning(f"⚠️ CircuitBreaker: '{option.description}' failure {self.failures}/{self.max_failures}.") + logger.warning(f"⚠️ CircuitBreaker: '{name}' failure {self.failures}/{self.max_failures}.") if self.failures >= self.max_failures: self.open_until = time.time() + self.reset_timeout - logger.error(f"🔴 Circuit opened for '{option.description}' until {time.ctime(self.open_until)}.") + logger.error(f"🔴 Circuit opened for '{name}' until {time.ctime(self.open_until)}.") - def after_hook(self, option: Option): + def after_hook(self, context: HookContext): self.failures = 0 def is_open(self): @@ -78,33 +79,52 @@ class CircuitBreaker: class RetryHandler: - def __init__(self, max_retries=2, delay=1, backoff=2): + def __init__(self, max_retries=5, delay=1, backoff=2): self.max_retries = max_retries self.delay = delay self.backoff = backoff - def retry_on_error(self, option: Option, error: Exception): + def retry_on_error(self, context: HookContext): + name = context.get("name", "") + error = context.get("exception") + option = context.get("option") + action = context.get("action") + retries_done = 0 current_delay = self.delay last_error = error + if not (option or action): + logger.warning(f"⚠️ RetryHandler: No Option or Action in context for '{name}'. Skipping retry.") + return + + target = option or action + while retries_done < self.max_retries: try: retries_done += 1 - logger.info(f"🔄 Retrying '{option.description}' ({retries_done}/{self.max_retries}) in {current_delay}s due to '{error}'...") + logger.info(f"🔄 Retrying '{name}' ({retries_done}/{self.max_retries}) in {current_delay}s due to '{last_error}'...") time.sleep(current_delay) - result = option.action() - print(result) - option.set_result(result) - logger.info(f"✅ Retry succeeded for '{option.description}' on attempt {retries_done}.") - option.after_action.run_hooks(option) + result = target(*context.get("args", ()), **context.get("kwargs", {})) + if option: + option.set_result(result) + + context["result"] = result + context["duration"] = target.get_duration() or 0.0 + context.pop("exception", None) + + logger.info(f"✅ Retry succeeded for '{name}' on attempt {retries_done}.") + + if hasattr(target, "hooks"): + run_async(target.hooks.trigger("after", context)) + return except Exception as retry_error: - logger.warning(f"⚠️ Retry attempt {retries_done} for '{option.description}' failed due to '{retry_error}'.") + logger.warning(f"⚠️ Retry attempt {retries_done} for '{name}' failed due to '{retry_error}'.") last_error = retry_error current_delay *= self.backoff - logger.exception(f"❌ '{option.description}' failed after {self.max_retries} retries.") + logger.exception(f"❌ '{name}' failed after {self.max_retries} retries.") raise last_error @@ -133,15 +153,13 @@ def retry(max_retries=3, delay=1, backoff=2, exceptions=(Exception,), logger=Non def setup_hooks(menu): - menu.add_before(timing_before_hook) - menu.add_after(timing_after_hook) - menu.add_on_error(timing_error_hook) menu.add_before(log_before) menu.add_after(log_after) menu.add_on_error(log_error) if __name__ == "__main__": + from menu import Menu def risky_task(): if random.random() > 0.1: time.sleep(1) @@ -151,9 +169,6 @@ if __name__ == "__main__": retry_handler = RetryHandler(max_retries=30, delay=2, backoff=2) menu = Menu(never_confirm=True) - menu.add_before(timing_before_hook) - menu.add_after(timing_after_hook) - menu.add_on_error(timing_error_hook) menu.add_before(log_before) menu.add_after(log_after) menu.add_on_error(log_error) diff --git a/menu/logging_utils.py b/menu/logging_utils.py index 4e7d471..da87772 100644 --- a/menu/logging_utils.py +++ b/menu/logging_utils.py @@ -8,7 +8,7 @@ def setup_logging( ): """Set up logging configuration with separate console and file handlers.""" root_logger = logging.getLogger() - root_logger.setLevel(logging.WARNING) + root_logger.setLevel(logging.DEBUG) if root_logger.hasHandlers(): root_logger.handlers.clear() diff --git a/menu/menu.py b/menu/menu.py index 1865c77..40101f1 100644 --- a/menu/menu.py +++ b/menu/menu.py @@ -12,168 +12,34 @@ formatted and visually appealing way. This class also uses the `prompt_toolkit` library to handle user input and create an interactive experience. """ - +import asyncio import logging from functools import cached_property -from itertools import islice from typing import Any, Callable from prompt_toolkit import PromptSession from prompt_toolkit.completion import WordCompleter from prompt_toolkit.formatted_text import AnyFormattedText +from prompt_toolkit.key_binding import KeyBindings from prompt_toolkit.shortcuts import confirm from prompt_toolkit.validation import Validator -from pydantic import BaseModel, Field, field_validator, PrivateAttr from rich import box from rich.console import Console from rich.markdown import Markdown from rich.table import Table +from action import BaseAction +from bottom_bar import BottomBar from colors import get_nord_theme +from hook_manager import HookManager +from menu_utils import (CaseInsensitiveDict, InvalidActionError, MenuError, + NotAMenuError, OptionAlreadyExistsError, chunks, run_async) from one_colors import OneColors +from option import Option logger = logging.getLogger("menu") -def chunks(iterator, size): - """Yield successive n-sized chunks from an iterator.""" - iterator = iter(iterator) - while True: - chunk = list(islice(iterator, size)) - if not chunk: - break - yield chunk - - -class MenuError(Exception): - """Custom exception for the Menu class.""" - - -class OptionAlreadyExistsError(MenuError): - """Exception raised when an option with the same key already exists in the menu.""" - - -class InvalidHookError(MenuError): - """Exception raised when a hook is not callable.""" - - -class InvalidActionError(MenuError): - """Exception raised when an action is not callable.""" - - -class NotAMenuError(MenuError): - """Exception raised when the provided submenu is not an instance of Menu.""" - - -class CaseInsensitiveDict(dict): - """A case-insensitive dictionary that treats all keys as uppercase.""" - - def __setitem__(self, key, value): - super().__setitem__(key.upper(), value) - - def __getitem__(self, key): - return super().__getitem__(key.upper()) - - def __contains__(self, key): - return super().__contains__(key.upper()) - - def get(self, key, default=None): - return super().get(key.upper(), default) - - def pop(self, key, default=None): - return super().pop(key.upper(), default) - - def update(self, other=None, **kwargs): - if other: - other = {k.upper(): v for k, v in other.items()} - kwargs = {k.upper(): v for k, v in kwargs.items()} - super().update(other, **kwargs) - - -class Hooks(BaseModel): - """Class to manage hooks for the menu and options.""" - - hooks: list[Callable[["Option"], None]] | list[Callable[["Option", Exception], None]] = Field( - default_factory=list - ) - - @field_validator("hooks", mode="before") - @classmethod - def validate_hooks(cls, hooks): - if hooks is None: - return [] - if not all(callable(hook) for hook in hooks): - raise InvalidHookError("All hooks must be callable.") - return hooks - - def add_hook(self, hook: Callable[["Option"], None] | Callable[["Option", Exception], None]) -> None: - """Add a hook to the list.""" - if not callable(hook): - raise InvalidHookError("Hook must be a callable.") - if hook not in self.hooks: - self.hooks.append(hook) - - def run_hooks(self, *args, **kwargs) -> None: - """Run all hooks with the given arguments.""" - for hook in self.hooks: - try: - hook(*args, **kwargs) - except Exception as hook_error: - logger.exception(f"Hook '{hook.__name__}': {hook_error}") - - -class Option(BaseModel): - """Class representing an option in the menu. - - Hooks must have the signature: - def hook(option: Option) -> None: - where `option` is the selected option. - - Error hooks must have the signature: - def error_hook(option: Option, error: Exception) -> None: - where `option` is the selected option and `error` is the exception raised. - """ - - key: str - description: str - action: Callable[[], Any] = lambda: None - color: str = OneColors.WHITE - confirm: bool = False - confirm_message: str = "Are you sure?" - spinner: bool = False - spinner_message: str = "Processing..." - spinner_type: str = "dots" - spinner_style: str = OneColors.CYAN - spinner_kwargs: dict[str, Any] = Field(default_factory=dict) - - before_action: Hooks = Field(default_factory=Hooks) - after_action: Hooks = Field(default_factory=Hooks) - on_error: Hooks = Field(default_factory=Hooks) - - _start_time: float | None = PrivateAttr(default=None) - _end_time: float | None = PrivateAttr(default=None) - _duration: float | None = PrivateAttr(default=None) - - _result: Any | None = PrivateAttr(default=None) - - def __str__(self): - return f"Option(key='{self.key}', description='{self.description}')" - - def set_result(self, result: Any) -> None: - """Set the result of the action.""" - self._result = result - - def get_result(self) -> Any: - """Get the result of the action.""" - return self._result - - @field_validator("action") - def validate_action(cls, action): - if not callable(action): - raise InvalidActionError("Action must be a callable.") - return action - - class Menu: """Class to create a menu with options. @@ -218,21 +84,21 @@ class Menu: self.title: str | Markdown = title self.prompt: str | AnyFormattedText = prompt self.columns: int = columns - self.bottom_bar: str | Callable[[], None] | None = bottom_bar + self.bottom_bar: str | Callable[[], None] | None = bottom_bar or BottomBar(columns=columns) self.options: dict[str, Option] = CaseInsensitiveDict() self.back_option: Option = self._get_back_option() self.console: Console = Console(color_system="truecolor", theme=get_nord_theme()) - self.session: PromptSession = self._get_prompt_session() + #self.session: PromptSession = self._get_prompt_session() self.welcome_message: str | Markdown = welcome_message self.exit_message: str | Markdown = exit_message - self.before_action: Hooks = Hooks() - self.after_action: Hooks = Hooks() - self.on_error: Hooks = Hooks() + self.hooks: HookManager = HookManager() self.run_hooks_on_back_option: bool = run_hooks_on_back_option self.continue_on_error_prompt: bool = continue_on_error_prompt self._never_confirm: bool = never_confirm self._verbose: bool = _verbose self.last_run_option: Option | None = None + self.key_bindings: KeyBindings = KeyBindings() + self.toggles: dict[str, str] = {} def get_title(self) -> str: """Returns the string title of the menu.""" @@ -271,7 +137,8 @@ class Menu: self.session.validator = self._get_validator() self._invalidate_table_cache() - def _get_prompt_session(self) -> PromptSession: + @cached_property + def session(self) -> PromptSession: """Returns the prompt session for the menu.""" return PromptSession( message=self.prompt, @@ -279,35 +146,50 @@ class Menu: completer=self._get_completer(), reserve_space_for_menu=1, validator=self._get_validator(), - bottom_toolbar=self.bottom_bar, + bottom_toolbar=self.bottom_bar.render, ) - def add_before(self, hook: Callable[["Option"], None]) -> None: - """Adds a hook to be executed before the action of the menu.""" - self.before_action.add_hook(hook) + def add_toggle(self, key: str, label: str, state: bool = False): + if key in self.options or key in self.toggles: + raise ValueError(f"Key '{key}' is already in use.") - def add_after(self, hook: Callable[["Option"], None]) -> None: - """Adds a hook to be executed after the action of the menu.""" - self.after_action.add_hook(hook) + self.toggles[key] = label + self.bottom_bar.add_toggle(label, label, state) - def add_on_error(self, hook: Callable[["Option", Exception], None]) -> None: - """Adds a hook to be executed on error of the menu.""" - self.on_error.add_hook(hook) + @self.key_bindings.add(key) + def _(event): + current = self.bottom_bar._states[label][1] + self.bottom_bar.update_toggle(label, not current) + self.console.print(f"Toggled [{label}] to {'ON' if not current else 'OFF'}") + + def add_counter(self, name: str, label: str, current: int, total: int): + self.bottom_bar.add_counter(name, label, current, total) + + def update_counter(self, name: str, current: int | None = None, total: int | None = None): + self.bottom_bar.update_counter(name, current=current, total=total) + + def update_toggle(self, name: str, state: bool): + self.bottom_bar.update_toggle(name, state) def debug_hooks(self) -> None: if not self._verbose: return - logger.debug(f"Menu-level before hooks: {[hook.__name__ for hook in self.before_action.hooks]}") - logger.debug(f"Menu-level after hooks: {[hook.__name__ for hook in self.after_action.hooks]}") - logger.debug(f"Menu-level error hooks: {[hook.__name__ for hook in self.on_error.hooks]}") + + def hook_names(hook_list): + return [hook.__name__ for hook in hook_list] + + logger.debug(f"Menu-level before hooks: {hook_names(self.hooks._hooks['before'])}") + logger.debug(f"Menu-level after hooks: {hook_names(self.hooks._hooks['after'])}") + logger.debug(f"Menu-level error hooks: {hook_names(self.hooks._hooks['on_error'])}") + for key, option in self.options.items(): - logger.debug(f"[Option '{key}'] before: {[hook.__name__ for hook in option.before_action.hooks]}") - logger.debug(f"[Option '{key}'] after: {[hook.__name__ for hook in option.after_action.hooks]}") - logger.debug(f"[Option '{key}'] error: {[hook.__name__ for hook in option.on_error.hooks]}") + logger.debug(f"[Option '{key}'] before: {hook_names(option.hooks._hooks['before'])}") + logger.debug(f"[Option '{key}'] after: {hook_names(option.hooks._hooks['after'])}") + logger.debug(f"[Option '{key}'] error: {hook_names(option.hooks._hooks['on_error'])}") def _validate_option_key(self, key: str) -> None: """Validates the option key to ensure it is unique.""" - if key in self.options or key.upper() == self.back_option.key.upper(): + if key.upper() in self.options or key.upper() == self.back_option.key.upper(): raise OptionAlreadyExistsError(f"Option with key '{key}' already exists.") def update_back_option( @@ -350,7 +232,7 @@ class Menu: self, key: str, description: str, - action: Callable[[], Any], + action: BaseAction | Callable[[], Any], color: str = OneColors.WHITE, confirm: bool = False, confirm_message: str = "Are you sure?", @@ -358,15 +240,14 @@ class Menu: spinner_message: str = "Processing...", spinner_type: str = "dots", spinner_style: str = OneColors.CYAN, - spinner_kwargs: dict[str, Any] = None, - before_hooks: list[Callable[[Option], None]] = None, - after_hooks: list[Callable[[Option], None]] = None, - error_hooks: list[Callable[[Option, Exception], None]] = None, + spinner_kwargs: dict[str, Any] | None = None, + before_hooks: list[Callable] | None = None, + after_hooks: list[Callable] | None = None, + error_hooks: list[Callable] | None = None, ) -> Option: """Adds an option to the menu, preventing duplicates.""" + spinner_kwargs: dict[str, Any] = spinner_kwargs or {} self._validate_option_key(key) - if not spinner_kwargs: - spinner_kwargs = {} option = Option( key=key, description=description, @@ -379,10 +260,15 @@ class Menu: spinner_type=spinner_type, spinner_style=spinner_style, spinner_kwargs=spinner_kwargs, - before_action=Hooks(hooks=before_hooks), - after_action=Hooks(hooks=after_hooks), - on_error=Hooks(hooks=error_hooks), ) + + for hook in before_hooks or []: + option.hooks.register("before", hook) + for hook in after_hooks or []: + option.hooks.register("after", hook) + for hook in error_hooks or []: + option.hooks.register("on_error", hook) + self.options[key] = option self._refresh_session() return option @@ -405,15 +291,20 @@ class Menu: return self.back_option return self.options.get(choice) - def _should_hooks_run(self, selected_option: Option) -> bool: - """Determines if hooks should be run based on the selected option.""" - return selected_option != self.back_option or self.run_hooks_on_back_option - def _should_run_action(self, selected_option: Option) -> bool: if selected_option.confirm and not self._never_confirm: return confirm(selected_option.confirm_message) return True + def _create_context(self, selected_option: Option) -> dict[str, Any]: + """Creates a context dictionary for the selected option.""" + return { + "name": selected_option.description, + "option": selected_option, + "args": (), + "kwargs": {}, + } + def _run_action_with_spinner(self, option: Option) -> Any: """Runs the action of the selected option with a spinner.""" with self.console.status( @@ -422,15 +313,13 @@ class Menu: spinner_style=option.spinner_style, **option.spinner_kwargs, ): - return option.action() + return option() def _handle_action_error(self, selected_option: Option, error: Exception) -> bool: """Handles errors that occur during the action of the selected option.""" logger.exception(f"Error executing '{selected_option.description}': {error}") self.console.print(f"[{OneColors.DARK_RED}]An error occurred while executing " f"{selected_option.description}:[/] {error}") - selected_option.on_error.run_hooks(selected_option, error) - self.on_error.run_hooks(selected_option, error) if self.continue_on_error_prompt and not self._never_confirm: return confirm("An error occurred. Do you wish to continue?") if self._never_confirm: @@ -442,25 +331,38 @@ class Menu: choice = self.session.prompt() selected_option = self.get_option(choice) self.last_run_option = selected_option - should_hooks_run = self._should_hooks_run(selected_option) + + if selected_option == self.back_option: + logger.info(f"🔙 Back selected: exiting {self.get_title()}") + return False + if not self._should_run_action(selected_option): logger.info(f"[{OneColors.DARK_RED}] {selected_option.description} cancelled.") return True + + context = self._create_context(selected_option) + try: - if should_hooks_run: - self.before_action.run_hooks(selected_option) - selected_option.before_action.run_hooks(selected_option) + run_async(self.hooks.trigger("before", context)) + if selected_option.spinner: result = self._run_action_with_spinner(selected_option) else: - result = selected_option.action() + result = selected_option() + selected_option.set_result(result) - selected_option.after_action.run_hooks(selected_option) - if should_hooks_run: - self.after_action.run_hooks(selected_option) + context["result"] = result + context["duration"] = selected_option.get_duration() + run_async(self.hooks.trigger("after", context)) except Exception as error: + context["exception"] = error + context["duration"] = selected_option.get_duration() + run_async(self.hooks.trigger("on_error", context)) + if "exception" not in context: + logger.info(f"✅ Recovery hook handled error for '{selected_option.description}'") + return True return self._handle_action_error(selected_option, error) - return selected_option != self.back_option + return True def run_headless(self, option_key: str, never_confirm: bool | None = None) -> Any: """Runs the action of the selected option without displaying the menu.""" @@ -470,34 +372,45 @@ class Menu: selected_option = self.get_option(option_key) self.last_run_option = selected_option + if not selected_option: - raise MenuError(f"[Headless] Option '{option_key}' not found.") + logger.info("[Headless] Back option selected. Exiting menu.") + return logger.info(f"[Headless] 🚀 Running: '{selected_option.description}'") - should_hooks_run = self._should_hooks_run(selected_option) + if not self._should_run_action(selected_option): - logger.info(f"[Headless] ⛔ '{selected_option.description}' cancelled.") raise MenuError(f"[Headless] '{selected_option.description}' cancelled by confirmation.") + context = self._create_context(selected_option) + try: - if should_hooks_run: - self.before_action.run_hooks(selected_option) - selected_option.before_action.run_hooks(selected_option) + run_async(self.hooks.trigger("before", context)) + if selected_option.spinner: result = self._run_action_with_spinner(selected_option) else: - result = selected_option.action() + result = selected_option() + selected_option.set_result(result) - selected_option.after_action.run_hooks(selected_option) - if should_hooks_run: - self.after_action.run_hooks(selected_option) + context["result"] = result + context["duration"] = selected_option.get_duration() + + run_async(self.hooks.trigger("after", context)) logger.info(f"[Headless] ✅ '{selected_option.description}' complete.") except (KeyboardInterrupt, EOFError): raise MenuError(f"[Headless] ⚠️ '{selected_option.description}' interrupted by user.") except Exception as error: + context["exception"] = error + context["duration"] = selected_option.get_duration() + run_async(self.hooks.trigger("on_error", context)) + if "exception" not in context: + logger.info(f"[Headless] ✅ Recovery hook handled error for '{selected_option.description}'") + return True continue_running = self._handle_action_error(selected_option, error) if not continue_running: raise MenuError(f"[Headless] ❌ '{selected_option.description}' failed.") from error + return selected_option.get_result() def run(self) -> None: @@ -517,83 +430,3 @@ class Menu: logger.info(f"Exiting menu: {self.get_title()}") if self.exit_message: self.console.print(self.exit_message) - - -if __name__ == "__main__": - from rich.traceback import install - from logging_utils import setup_logging - - install(show_locals=True) - setup_logging() - - def say_hello(): - print("Hello!") - - def say_goodbye(): - print("Goodbye!") - - def say_nested(): - print("This is a nested menu!") - - def my_action(): - print("This is my action!") - - def long_running_task(): - import time - - time.sleep(5) - - nested_menu = Menu( - Markdown("## Nested Menu", style=OneColors.DARK_YELLOW), - columns=2, - bottom_bar="Menu within a menu", - ) - nested_menu.add_option("1", "Say Nested", say_nested, color=OneColors.MAGENTA) - nested_menu.add_before(lambda opt: logger.info(f"Global BEFORE '{opt.description}'")) - nested_menu.add_after(lambda opt: logger.info(f"Global AFTER '{opt.description}'")) - - nested_menu.add_option( - "2", - "Test Action", - action=my_action, - before_hooks=[lambda opt: logger.info(f"Option-specific BEFORE '{opt.description}'")], - after_hooks=[lambda opt: logger.info(f"Option-specific AFTER '{opt.description}'")], - ) - - def bottom_bar(): - return ( - f"Press Q to quit | Options available: {', '.join([f'[{key}]' for key in menu.options.keys()])}" - ) - - welcome_message = Markdown("# Welcome to the Menu!") - exit_message = Markdown("# Thank you for using the menu!") - menu = Menu( - Markdown("## Main Menu", style=OneColors.CYAN), - columns=3, - bottom_bar=bottom_bar, - welcome_message=welcome_message, - exit_message=exit_message, - ) - menu.add_option("1", "Say Hello", say_hello, color=OneColors.GREEN) - menu.add_option("2", "Say Goodbye", say_goodbye, color=OneColors.LIGHT_RED) - menu.add_option("3", "Do Nothing", lambda: None, color=OneColors.BLUE) - menu.add_submenu("4", "Nested Menu", nested_menu, color=OneColors.MAGENTA) - menu.add_option("5", "Do Nothing", lambda: None, color=OneColors.BLUE) - menu.add_option( - "6", - "Long Running Task", - action=long_running_task, - spinner=True, - spinner_message="Working, please wait...", - spinner_type="moon", - spinner_style=OneColors.GREEN, - spinner_kwargs={"speed": 0.7}, - ) - - menu.update_back_option("Q", "Quit", color=OneColors.DARK_RED) - - try: - menu.run() - except EOFError as error: - logger.exception("EOFError: Exiting program.", exc_info=error) - print("Exiting program.") diff --git a/menu/menu_utils.py b/menu/menu_utils.py new file mode 100644 index 0000000..4bd2715 --- /dev/null +++ b/menu/menu_utils.py @@ -0,0 +1,79 @@ +import asyncio +import time +from itertools import islice + + +def chunks(iterator, size): + """Yield successive n-sized chunks from an iterator.""" + iterator = iter(iterator) + while True: + chunk = list(islice(iterator, size)) + if not chunk: + break + yield chunk + + +def run_async(coro): + """Run an async function in a synchronous context.""" + try: + _ = asyncio.get_running_loop() + return asyncio.create_task(coro) + except RuntimeError: + return asyncio.run(coro) + + +class TimingMixin: + def _start_timer(self): + self.start_time = time.perf_counter() + + def _stop_timer(self): + self.end_time = time.perf_counter() + self._duration = self.end_time - self.start_time + + def get_duration(self) -> float | None: + return getattr(self, "_duration", None) + + +class MenuError(Exception): + """Custom exception for the Menu class.""" + + +class OptionAlreadyExistsError(MenuError): + """Exception raised when an option with the same key already exists in the menu.""" + + +class InvalidHookError(MenuError): + """Exception raised when a hook is not callable.""" + + +class InvalidActionError(MenuError): + """Exception raised when an action is not callable.""" + + +class NotAMenuError(MenuError): + """Exception raised when the provided submenu is not an instance of Menu.""" + + +class CaseInsensitiveDict(dict): + """A case-insensitive dictionary that treats all keys as uppercase.""" + + def __setitem__(self, key, value): + super().__setitem__(key.upper(), value) + + def __getitem__(self, key): + return super().__getitem__(key.upper()) + + def __contains__(self, key): + return super().__contains__(key.upper()) + + def get(self, key, default=None): + return super().get(key.upper(), default) + + def pop(self, key, default=None): + return super().pop(key.upper(), default) + + def update(self, other=None, **kwargs): + if other: + other = {k.upper(): v for k, v in other.items()} + kwargs = {k.upper(): v for k, v in kwargs.items()} + super().update(other, **kwargs) diff --git a/menu/option.py b/menu/option.py new file mode 100644 index 0000000..1f60196 --- /dev/null +++ b/menu/option.py @@ -0,0 +1,110 @@ +"""option.py +Any Action or Option is callable and supports the signature: + result = thing(*args, **kwargs) + +This guarantees: +- Hook lifecycle (before/after/error/teardown) +- Timing +- Consistent return values +""" +from __future__ import annotations + +import asyncio +import logging +from typing import Any, Callable + +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr + +from action import BaseAction +from colors import OneColors +from hook_manager import HookManager +from menu_utils import TimingMixin, run_async + +logger = logging.getLogger("menu") + + +class Option(BaseModel, TimingMixin): + """Class representing an option in the menu. + + Hooks must have the signature: + def hook(option: Option) -> None: + where `option` is the selected option. + + Error hooks must have the signature: + def error_hook(option: Option, error: Exception) -> None: + where `option` is the selected option and `error` is the exception raised. + """ + key: str + description: str + action: BaseAction | Callable[[], Any] = lambda: None + color: str = OneColors.WHITE + confirm: bool = False + confirm_message: str = "Are you sure?" + spinner: bool = False + spinner_message: str = "Processing..." + spinner_type: str = "dots" + spinner_style: str = OneColors.CYAN + spinner_kwargs: dict[str, Any] = Field(default_factory=dict) + + hooks: "HookManager" = Field(default_factory=HookManager) + + start_time: float | None = None + end_time: float | None = None + _duration: float | None = PrivateAttr(default=None) + _result: Any | None = PrivateAttr(default=None) + + model_config = ConfigDict(arbitrary_types_allowed=True) + + def __str__(self): + return f"Option(key='{self.key}', description='{self.description}')" + + def set_result(self, result: Any) -> None: + """Set the result of the action.""" + self._result = result + + def get_result(self) -> Any: + """Get the result of the action.""" + return self._result + + def __call__(self, *args, **kwargs) -> Any: + context = { + "name": self.description, + "duration": None, + "args": args, + "kwargs": kwargs, + "option": self, + } + self._start_timer() + try: + run_async(self.hooks.trigger("before", context)) + result = self._execute_action(*args, **kwargs) + self.set_result(result) + context["result"] = result + return result + except Exception as error: + context["exception"] = error + run_async(self.hooks.trigger("on_error", context)) + if "exception" not in context: + logger.info(f"✅ Recovery hook handled error for Option '{self.key}'") + return self.get_result() + raise + finally: + self._stop_timer() + context["duration"] = self.get_duration() + if "exception" not in context: + run_async(self.hooks.trigger("after", context)) + run_async(self.hooks.trigger("on_teardown", context)) + + def _execute_action(self, *args, **kwargs) -> Any: + if isinstance(self.action, BaseAction): + return self.action(*args, **kwargs) + return self.action() + + def dry_run(self): + print(f"[DRY RUN] Option '{self.key}' would run: {self.description}") + if isinstance(self.action, BaseAction): + self.action.dry_run() + elif callable(self.action): + print(f"[DRY RUN] Action is a raw callable: {self.action.__name__}") + else: + print("[DRY RUN] Action is not callable.") \ No newline at end of file diff --git a/menu/task.py b/menu/task.py index 4c3b360..d71d67a 100644 --- a/menu/task.py +++ b/menu/task.py @@ -3,7 +3,7 @@ import time def risky_task() -> str: - if random.random() > 0.25: + if random.random() > 0.4: time.sleep(1) raise ValueError("Random failure occurred") return "Task succeeded!"