diff --git a/menu/action.py b/menu/action.py
new file mode 100644
index 0000000..227ae97
--- /dev/null
+++ b/menu/action.py
@@ -0,0 +1,218 @@
+"""action.py
+
+Any Action or Option is callable and supports the signature:
+ result = thing(*args, **kwargs)
+
+This guarantees:
+- Hook lifecycle (before/after/error/teardown)
+- Timing
+- Consistent return values
+"""
+from __future__ import annotations
+
+import asyncio
+import logging
+import time
+import inspect
+from abc import ABC, abstractmethod
+from typing import Optional
+
+from hook_manager import HookManager
+from menu_utils import TimingMixin, run_async
+
+
+logger = logging.getLogger("menu")
+
+
+class BaseAction(ABC, TimingMixin):
+ """Base class for actions. They are the building blocks of the menu.
+ Actions can be simple functions or more complex actions like
+ `ChainedAction` or `ActionGroup`. They can also be run independently
+ or as part of a menu."""
+ def __init__(self, name: str, hooks: Optional[HookManager] = None):
+ self.name = name
+ self.hooks = hooks or HookManager()
+ self.start_time: float | None = None
+ self.end_time: float | None = None
+ self._duration: float | None = None
+
+ def __call__(self, *args, **kwargs):
+ context = {
+ "name": self.name,
+ "duration": None,
+ "args": args,
+ "kwargs": kwargs,
+ "action": self
+ }
+ self._start_timer()
+ try:
+ run_async(self.hooks.trigger("before", context))
+ result = self._run(*args, **kwargs)
+ context["result"] = result
+ return result
+ except Exception as error:
+ context["exception"] = error
+ run_async(self.hooks.trigger("on_error", context))
+ if "exception" not in context:
+ logger.info(f"✅ Recovery hook handled error for Action '{self.name}'")
+ return context.get("result")
+ raise
+ finally:
+ self._stop_timer()
+ context["duration"] = self.get_duration()
+ if "exception" not in context:
+ run_async(self.hooks.trigger("after", context))
+ run_async(self.hooks.trigger("on_teardown", context))
+
+ @abstractmethod
+ def _run(self, *args, **kwargs):
+ raise NotImplementedError("_run must be implemented by subclasses")
+
+ async def run_async(self, *args, **kwargs):
+ if inspect.iscoroutinefunction(self._run):
+ return await self._run(*args, **kwargs)
+
+ return await asyncio.to_thread(self.__call__, *args, **kwargs)
+
+ def __await__(self):
+ return self.run_async().__await__()
+
+ @abstractmethod
+ def dry_run(self):
+ raise NotImplementedError("dry_run must be implemented by subclasses")
+
+ def __str__(self):
+ return f"<{self.__class__.__name__} '{self.name}'>"
+
+ def __repr__(self):
+ return str(self)
+
+
+class Action(BaseAction):
+ def __init__(self, name: str, fn, rollback=None, hooks=None):
+ super().__init__(name, hooks)
+ self.fn = fn
+ self.rollback = rollback
+
+ def _run(self, *args, **kwargs):
+ if inspect.iscoroutinefunction(self.fn):
+ return asyncio.run(self.fn(*args, **kwargs))
+ return self.fn(*args, **kwargs)
+
+ def dry_run(self):
+ print(f"[DRY RUN] Would run: {self.name}")
+
+
+class ChainedAction(BaseAction):
+ def __init__(self, name: str, actions: list[BaseAction], hooks=None):
+ super().__init__(name, hooks)
+ self.actions = actions
+
+ def _run(self, *args, **kwargs):
+ rollback_stack = []
+ for action in self.actions:
+ try:
+ result = action(*args, **kwargs)
+ rollback_stack.append(action)
+ except Exception:
+ self._rollback(rollback_stack, *args, **kwargs)
+ raise
+ return None
+
+ def dry_run(self):
+ print(f"[DRY RUN] ChainedAction '{self.name}' with steps:")
+ for action in self.actions:
+ action.dry_run()
+
+ def _rollback(self, rollback_stack, *args, **kwargs):
+ for action in reversed(rollback_stack):
+ if hasattr(action, "rollback") and action.rollback:
+ try:
+ print(f"↩️ Rolling back {action.name}")
+ action.rollback(*args, **kwargs)
+ except Exception as e:
+ print(f"⚠️ Rollback failed for {action.name}: {e}")
+
+
+class ActionGroup(BaseAction):
+ def __init__(self, name: str, actions: list[BaseAction], hooks=None):
+ super().__init__(name, hooks)
+ self.actions = actions
+ self.results = []
+ self.errors = []
+
+ def _run(self, *args, **kwargs):
+ asyncio.run(self._run_async(*args, **kwargs))
+
+ def dry_run(self):
+ print(f"[DRY RUN] ActionGroup '{self.name}' (parallel execution):")
+ for action in self.actions:
+ action.dry_run()
+
+ async def _run_async(self, *args, **kwargs):
+ async def run(action):
+ try:
+ result = await asyncio.to_thread(action, *args, **kwargs)
+ self.results.append((action.name, result))
+ except Exception as e:
+ self.errors.append((action.name, e))
+
+ await self.hooks.trigger("before", name=self.name)
+
+ await asyncio.gather(*[run(a) for a in self.actions])
+
+ if self.errors:
+ await self.hooks.trigger("on_error", name=self.name, errors=self.errors)
+ else:
+ await self.hooks.trigger("after", name=self.name, results=self.results)
+
+ await self.hooks.trigger("on_teardown", name=self.name)
+
+
+
+# if __name__ == "__main__":
+# # Example usage
+# def build(): print("Build!")
+# def test(): print("Test!")
+# def deploy(): print("Deploy!")
+
+
+# pipeline = ChainedAction("CI/CD", [
+# Action("Build", build),
+# Action("Test", test),
+# ActionGroup("Deploy Parallel", [
+# Action("Deploy A", deploy),
+# Action("Deploy B", deploy)
+# ])
+# ])
+
+# pipeline()
+# Sample functions
+def sync_hello():
+ time.sleep(1)
+ return "Hello from sync function"
+
+async def async_hello():
+ await asyncio.sleep(1)
+ return "Hello from async function"
+
+
+# Example usage
+async def main():
+ sync_action = Action("sync_hello", sync_hello)
+ async_action = Action("async_hello", async_hello)
+
+ print("⏳ Awaiting sync action...")
+ result1 = await sync_action
+ print("✅", result1)
+
+ print("⏳ Awaiting async action...")
+ result2 = await async_action
+ print("✅", result2)
+
+ print(f"⏱️ sync took {sync_action.get_duration():.2f}s")
+ print(f"⏱️ async took {async_action.get_duration():.2f}s")
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
diff --git a/menu/bottom_bar.py b/menu/bottom_bar.py
new file mode 100644
index 0000000..002e351
--- /dev/null
+++ b/menu/bottom_bar.py
@@ -0,0 +1,59 @@
+from prompt_toolkit.formatted_text import HTML, merge_formatted_text
+from typing import Callable, Literal, Optional
+from rich.console import Console
+
+
+class BottomBar:
+ def __init__(self, columns: int = 3):
+ self.columns = columns
+ self.console = Console()
+ self._items: list[Callable[[], HTML]] = []
+ self._named_items: dict[str, Callable[[], HTML]] = {}
+ self._states: dict[str, any] = {}
+
+ def get_space(self) -> str:
+ return self.console.width // self.columns
+
+ def add_static(self, name: str, text: str) -> None:
+ def render():
+ return HTML(f"")
+ self._add_named(name, render)
+
+ def add_counter(self, name: str, label: str, current: int, total: int) -> None:
+ self._states[name] = (label, current, total)
+
+ def render():
+ l, c, t = self._states[name]
+ text = f"{l}: {c}/{t}"
+ return HTML(f"")
+
+ self._add_named(name, render)
+
+ def add_toggle(self, name: str, label: str, state: bool) -> None:
+ self._states[name] = (label, state)
+
+ def render():
+ l, s = self._states[name]
+ color = '#A3BE8C' if s else '#BF616A'
+ status = "ON" if s else "OFF"
+ text = f"{l}: {status}"
+ return HTML(f"")
+
+ self._add_named(name, render)
+
+ def update_toggle(self, name: str, state: bool) -> None:
+ if name in self._states:
+ label, _ = self._states[name]
+ self._states[name] = (label, state)
+
+ def update_counter(self, name: str, current: Optional[int] = None, total: Optional[int] = None) -> None:
+ if name in self._states:
+ label, c, t = self._states[name]
+ self._states[name] = (label, current if current is not None else c, total if total is not None else t)
+
+ def _add_named(self, name: str, render_fn: Callable[[], HTML]) -> None:
+ self._named_items[name] = render_fn
+ self._items = list(self._named_items.values())
+
+ def render(self):
+ return merge_formatted_text([fn() for fn in self._items])
diff --git a/menu/callbacks.py b/menu/callbacks.py
index e7778be..f2c9abf 100644
--- a/menu/callbacks.py
+++ b/menu/callbacks.py
@@ -11,7 +11,7 @@ console = Console()
setup_logging()
logger = logging.getLogger("menu")
-def retry(max_retries=3, delay=1, backoff=2, exceptions=(Exception,), logger=None, spinner_text=None):
+def retry(max_retries=3, delay=1, backoff=2, exceptions=(Exception,)):
def decorator(func):
is_coroutine = inspect.iscoroutinefunction(func)
@@ -22,8 +22,7 @@ def retry(max_retries=3, delay=1, backoff=2, exceptions=(Exception,), logger=Non
if logger:
logger.debug(f"Retrying {retries + 1}/{max_retries} for '{func.__name__}' after {current_delay}s due to '{exceptions}'.")
try:
- with console.status(spinner_text, spinner="dots"):
- return await func(*args, **kwargs)
+ return await func(*args, **kwargs)
except exceptions as e:
if retries == max_retries:
if logger:
@@ -44,8 +43,7 @@ def retry(max_retries=3, delay=1, backoff=2, exceptions=(Exception,), logger=Non
if logger:
logger.debug(f"Retrying {retries + 1}/{max_retries} for '{func.__name__}' after {current_delay}s due to '{exceptions}'.")
try:
- with console.status(spinner_text, spinner="dots"):
- return func(*args, **kwargs)
+ return func(*args, **kwargs)
except exceptions as e:
if retries == max_retries:
if logger:
@@ -62,7 +60,7 @@ def retry(max_retries=3, delay=1, backoff=2, exceptions=(Exception,), logger=Non
return async_wrapper if is_coroutine else sync_wrapper
return decorator
-@retry(max_retries=10, delay=1, logger=logger, spinner_text="Trying risky thing...")
+@retry(max_retries=10, delay=1, spinner_text="Trying risky thing...")
def might_fail():
time.sleep(4)
if random.random() < 0.6:
diff --git a/menu/hook_manager.py b/menu/hook_manager.py
new file mode 100644
index 0000000..564cc20
--- /dev/null
+++ b/menu/hook_manager.py
@@ -0,0 +1,66 @@
+"""hook_manager.py"""
+from __future__ import annotations
+
+import inspect
+import logging
+from typing import (Any, Awaitable, Callable, Dict, List, Optional, TypedDict,
+ Union, TYPE_CHECKING)
+
+if TYPE_CHECKING:
+ from action import BaseAction
+ from option import Option
+
+
+logger = logging.getLogger("menu")
+
+
+class HookContext(TypedDict, total=False):
+ name: str
+ args: tuple[Any, ...]
+ kwargs: dict[str, Any]
+ result: Any | None
+ exception: Exception | None
+ option: Option | None
+ action: BaseAction | None
+
+
+Hook = Union[Callable[[HookContext], None], Callable[[HookContext], Awaitable[None]]]
+
+
+class HookManager:
+ def __init__(self):
+ self._hooks: Dict[str, List[Hook]] = {
+ "before": [],
+ "after": [],
+ "on_error": [],
+ "on_teardown": [],
+ }
+
+ def register(self, hook_type: str, hook: Hook):
+ if hook_type not in self._hooks:
+ raise ValueError(f"Unsupported hook type: {hook_type}")
+ self._hooks[hook_type].append(hook)
+
+ def clear(self, hook_type: Optional[str] = None):
+ if hook_type:
+ self._hooks[hook_type] = []
+ else:
+ for k in self._hooks:
+ self._hooks[k] = []
+
+ async def trigger(self, hook_type: str, context: HookContext):
+ if hook_type not in self._hooks:
+ raise ValueError(f"Unsupported hook type: {hook_type}")
+ for hook in self._hooks[hook_type]:
+ try:
+ if inspect.iscoroutinefunction(hook):
+ await hook(context)
+ else:
+ hook(context)
+ except Exception as hook_error:
+ name = context.get("name", "")
+ logger.warning(f"⚠️ Hook '{hook.__name__}' raised an exception during '{hook_type}'"
+ f" for '{name}': {hook_error}")
+
+ if hook_type == "on_error":
+ raise context.get("exception") from hook_error
diff --git a/menu/hooks.py b/menu/hooks.py
index 533feb8..6a53ab9 100644
--- a/menu/hooks.py
+++ b/menu/hooks.py
@@ -1,41 +1,40 @@
-import time
+import functools
import logging
import random
-import functools
-from menu import Menu, Option
+import time
+
+from hook_manager import HookContext
+from menu_utils import run_async
logger = logging.getLogger("menu")
-def timing_before_hook(option: Option) -> None:
- option._start_time = time.perf_counter()
-
-def timing_after_hook(option: Option) -> None:
- option._end_time = time.perf_counter()
- option._duration = option._end_time - option._start_time
-
-
-def timing_error_hook(option: Option, _: Exception) -> None:
- option._end_time = time.perf_counter()
- option._duration = option._end_time - option._start_time
-
-
-def log_before(option: Option) -> None:
- logger.info(f"🚀 Starting action '{option.description}' (key='{option.key}')")
-
-
-def log_after(option: Option) -> None:
- if option._duration is not None:
- logger.info(f"✅ Completed '{option.description}' (key='{option.key}') in {option._duration:.2f}s")
+def log_before(context: dict) -> None:
+ name = context.get("name", "")
+ option = context.get("option")
+ if option:
+ logger.info(f"🚀 Starting action '{option.description}' (key='{option.key}')")
else:
- logger.info(f"✅ Completed '{option.description}' (key='{option.key}')")
+ logger.info(f"🚀 Starting action '{name}'")
-def log_error(option: Option, error: Exception) -> None:
- if option._duration is not None:
- logger.error(f"❌ Error '{option.description}' (key='{option.key}') after {option._duration:.2f}s: {error}")
+def log_after(context: dict) -> None:
+ name = context.get("name", "")
+ duration = context.get("duration")
+ if duration is not None:
+ logger.info(f"✅ Completed '{name}' in {duration:.2f}s")
else:
- logger.error(f"❌ Error '{option.description}' (key='{option.key}'): {error}")
+ logger.info(f"✅ Completed '{name}'")
+
+
+def log_error(context: dict) -> None:
+ name = context.get("name", "")
+ error = context.get("exception")
+ duration = context.get("duration")
+ if duration is not None:
+ logger.error(f"❌ Error '{name}' after {duration:.2f}s: {error}")
+ else:
+ logger.error(f"❌ Error '{name}': {error}")
class CircuitBreakerOpen(Exception):
@@ -49,23 +48,25 @@ class CircuitBreaker:
self.failures = 0
self.open_until = None
- def before_hook(self, option: Option):
+ def before_hook(self, context: HookContext):
+ name = context.get("name", "")
if self.open_until:
if time.time() < self.open_until:
- raise CircuitBreakerOpen(f"🔴 Circuit open for '{option.description}' until {time.ctime(self.open_until)}.")
+ raise CircuitBreakerOpen(f"🔴 Circuit open for '{name}' until {time.ctime(self.open_until)}.")
else:
- logger.info(f"🟢 Circuit closed again for '{option.description}'.")
+ logger.info(f"🟢 Circuit closed again for '{name}'.")
self.failures = 0
self.open_until = None
- def error_hook(self, option: Option, error: Exception):
+ def error_hook(self, context: HookContext):
+ name = context.get("name", "")
self.failures += 1
- logger.warning(f"⚠️ CircuitBreaker: '{option.description}' failure {self.failures}/{self.max_failures}.")
+ logger.warning(f"⚠️ CircuitBreaker: '{name}' failure {self.failures}/{self.max_failures}.")
if self.failures >= self.max_failures:
self.open_until = time.time() + self.reset_timeout
- logger.error(f"🔴 Circuit opened for '{option.description}' until {time.ctime(self.open_until)}.")
+ logger.error(f"🔴 Circuit opened for '{name}' until {time.ctime(self.open_until)}.")
- def after_hook(self, option: Option):
+ def after_hook(self, context: HookContext):
self.failures = 0
def is_open(self):
@@ -78,33 +79,52 @@ class CircuitBreaker:
class RetryHandler:
- def __init__(self, max_retries=2, delay=1, backoff=2):
+ def __init__(self, max_retries=5, delay=1, backoff=2):
self.max_retries = max_retries
self.delay = delay
self.backoff = backoff
- def retry_on_error(self, option: Option, error: Exception):
+ def retry_on_error(self, context: HookContext):
+ name = context.get("name", "")
+ error = context.get("exception")
+ option = context.get("option")
+ action = context.get("action")
+
retries_done = 0
current_delay = self.delay
last_error = error
+ if not (option or action):
+ logger.warning(f"⚠️ RetryHandler: No Option or Action in context for '{name}'. Skipping retry.")
+ return
+
+ target = option or action
+
while retries_done < self.max_retries:
try:
retries_done += 1
- logger.info(f"🔄 Retrying '{option.description}' ({retries_done}/{self.max_retries}) in {current_delay}s due to '{error}'...")
+ logger.info(f"🔄 Retrying '{name}' ({retries_done}/{self.max_retries}) in {current_delay}s due to '{last_error}'...")
time.sleep(current_delay)
- result = option.action()
- print(result)
- option.set_result(result)
- logger.info(f"✅ Retry succeeded for '{option.description}' on attempt {retries_done}.")
- option.after_action.run_hooks(option)
+ result = target(*context.get("args", ()), **context.get("kwargs", {}))
+ if option:
+ option.set_result(result)
+
+ context["result"] = result
+ context["duration"] = target.get_duration() or 0.0
+ context.pop("exception", None)
+
+ logger.info(f"✅ Retry succeeded for '{name}' on attempt {retries_done}.")
+
+ if hasattr(target, "hooks"):
+ run_async(target.hooks.trigger("after", context))
+
return
except Exception as retry_error:
- logger.warning(f"⚠️ Retry attempt {retries_done} for '{option.description}' failed due to '{retry_error}'.")
+ logger.warning(f"⚠️ Retry attempt {retries_done} for '{name}' failed due to '{retry_error}'.")
last_error = retry_error
current_delay *= self.backoff
- logger.exception(f"❌ '{option.description}' failed after {self.max_retries} retries.")
+ logger.exception(f"❌ '{name}' failed after {self.max_retries} retries.")
raise last_error
@@ -133,15 +153,13 @@ def retry(max_retries=3, delay=1, backoff=2, exceptions=(Exception,), logger=Non
def setup_hooks(menu):
- menu.add_before(timing_before_hook)
- menu.add_after(timing_after_hook)
- menu.add_on_error(timing_error_hook)
menu.add_before(log_before)
menu.add_after(log_after)
menu.add_on_error(log_error)
if __name__ == "__main__":
+ from menu import Menu
def risky_task():
if random.random() > 0.1:
time.sleep(1)
@@ -151,9 +169,6 @@ if __name__ == "__main__":
retry_handler = RetryHandler(max_retries=30, delay=2, backoff=2)
menu = Menu(never_confirm=True)
- menu.add_before(timing_before_hook)
- menu.add_after(timing_after_hook)
- menu.add_on_error(timing_error_hook)
menu.add_before(log_before)
menu.add_after(log_after)
menu.add_on_error(log_error)
diff --git a/menu/logging_utils.py b/menu/logging_utils.py
index 4e7d471..da87772 100644
--- a/menu/logging_utils.py
+++ b/menu/logging_utils.py
@@ -8,7 +8,7 @@ def setup_logging(
):
"""Set up logging configuration with separate console and file handlers."""
root_logger = logging.getLogger()
- root_logger.setLevel(logging.WARNING)
+ root_logger.setLevel(logging.DEBUG)
if root_logger.hasHandlers():
root_logger.handlers.clear()
diff --git a/menu/menu.py b/menu/menu.py
index 1865c77..40101f1 100644
--- a/menu/menu.py
+++ b/menu/menu.py
@@ -12,168 +12,34 @@ formatted and visually appealing way.
This class also uses the `prompt_toolkit` library to handle
user input and create an interactive experience.
"""
-
+import asyncio
import logging
from functools import cached_property
-from itertools import islice
from typing import Any, Callable
from prompt_toolkit import PromptSession
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.formatted_text import AnyFormattedText
+from prompt_toolkit.key_binding import KeyBindings
from prompt_toolkit.shortcuts import confirm
from prompt_toolkit.validation import Validator
-from pydantic import BaseModel, Field, field_validator, PrivateAttr
from rich import box
from rich.console import Console
from rich.markdown import Markdown
from rich.table import Table
+from action import BaseAction
+from bottom_bar import BottomBar
from colors import get_nord_theme
+from hook_manager import HookManager
+from menu_utils import (CaseInsensitiveDict, InvalidActionError, MenuError,
+ NotAMenuError, OptionAlreadyExistsError, chunks, run_async)
from one_colors import OneColors
+from option import Option
logger = logging.getLogger("menu")
-def chunks(iterator, size):
- """Yield successive n-sized chunks from an iterator."""
- iterator = iter(iterator)
- while True:
- chunk = list(islice(iterator, size))
- if not chunk:
- break
- yield chunk
-
-
-class MenuError(Exception):
- """Custom exception for the Menu class."""
-
-
-class OptionAlreadyExistsError(MenuError):
- """Exception raised when an option with the same key already exists in the menu."""
-
-
-class InvalidHookError(MenuError):
- """Exception raised when a hook is not callable."""
-
-
-class InvalidActionError(MenuError):
- """Exception raised when an action is not callable."""
-
-
-class NotAMenuError(MenuError):
- """Exception raised when the provided submenu is not an instance of Menu."""
-
-
-class CaseInsensitiveDict(dict):
- """A case-insensitive dictionary that treats all keys as uppercase."""
-
- def __setitem__(self, key, value):
- super().__setitem__(key.upper(), value)
-
- def __getitem__(self, key):
- return super().__getitem__(key.upper())
-
- def __contains__(self, key):
- return super().__contains__(key.upper())
-
- def get(self, key, default=None):
- return super().get(key.upper(), default)
-
- def pop(self, key, default=None):
- return super().pop(key.upper(), default)
-
- def update(self, other=None, **kwargs):
- if other:
- other = {k.upper(): v for k, v in other.items()}
- kwargs = {k.upper(): v for k, v in kwargs.items()}
- super().update(other, **kwargs)
-
-
-class Hooks(BaseModel):
- """Class to manage hooks for the menu and options."""
-
- hooks: list[Callable[["Option"], None]] | list[Callable[["Option", Exception], None]] = Field(
- default_factory=list
- )
-
- @field_validator("hooks", mode="before")
- @classmethod
- def validate_hooks(cls, hooks):
- if hooks is None:
- return []
- if not all(callable(hook) for hook in hooks):
- raise InvalidHookError("All hooks must be callable.")
- return hooks
-
- def add_hook(self, hook: Callable[["Option"], None] | Callable[["Option", Exception], None]) -> None:
- """Add a hook to the list."""
- if not callable(hook):
- raise InvalidHookError("Hook must be a callable.")
- if hook not in self.hooks:
- self.hooks.append(hook)
-
- def run_hooks(self, *args, **kwargs) -> None:
- """Run all hooks with the given arguments."""
- for hook in self.hooks:
- try:
- hook(*args, **kwargs)
- except Exception as hook_error:
- logger.exception(f"Hook '{hook.__name__}': {hook_error}")
-
-
-class Option(BaseModel):
- """Class representing an option in the menu.
-
- Hooks must have the signature:
- def hook(option: Option) -> None:
- where `option` is the selected option.
-
- Error hooks must have the signature:
- def error_hook(option: Option, error: Exception) -> None:
- where `option` is the selected option and `error` is the exception raised.
- """
-
- key: str
- description: str
- action: Callable[[], Any] = lambda: None
- color: str = OneColors.WHITE
- confirm: bool = False
- confirm_message: str = "Are you sure?"
- spinner: bool = False
- spinner_message: str = "Processing..."
- spinner_type: str = "dots"
- spinner_style: str = OneColors.CYAN
- spinner_kwargs: dict[str, Any] = Field(default_factory=dict)
-
- before_action: Hooks = Field(default_factory=Hooks)
- after_action: Hooks = Field(default_factory=Hooks)
- on_error: Hooks = Field(default_factory=Hooks)
-
- _start_time: float | None = PrivateAttr(default=None)
- _end_time: float | None = PrivateAttr(default=None)
- _duration: float | None = PrivateAttr(default=None)
-
- _result: Any | None = PrivateAttr(default=None)
-
- def __str__(self):
- return f"Option(key='{self.key}', description='{self.description}')"
-
- def set_result(self, result: Any) -> None:
- """Set the result of the action."""
- self._result = result
-
- def get_result(self) -> Any:
- """Get the result of the action."""
- return self._result
-
- @field_validator("action")
- def validate_action(cls, action):
- if not callable(action):
- raise InvalidActionError("Action must be a callable.")
- return action
-
-
class Menu:
"""Class to create a menu with options.
@@ -218,21 +84,21 @@ class Menu:
self.title: str | Markdown = title
self.prompt: str | AnyFormattedText = prompt
self.columns: int = columns
- self.bottom_bar: str | Callable[[], None] | None = bottom_bar
+ self.bottom_bar: str | Callable[[], None] | None = bottom_bar or BottomBar(columns=columns)
self.options: dict[str, Option] = CaseInsensitiveDict()
self.back_option: Option = self._get_back_option()
self.console: Console = Console(color_system="truecolor", theme=get_nord_theme())
- self.session: PromptSession = self._get_prompt_session()
+ #self.session: PromptSession = self._get_prompt_session()
self.welcome_message: str | Markdown = welcome_message
self.exit_message: str | Markdown = exit_message
- self.before_action: Hooks = Hooks()
- self.after_action: Hooks = Hooks()
- self.on_error: Hooks = Hooks()
+ self.hooks: HookManager = HookManager()
self.run_hooks_on_back_option: bool = run_hooks_on_back_option
self.continue_on_error_prompt: bool = continue_on_error_prompt
self._never_confirm: bool = never_confirm
self._verbose: bool = _verbose
self.last_run_option: Option | None = None
+ self.key_bindings: KeyBindings = KeyBindings()
+ self.toggles: dict[str, str] = {}
def get_title(self) -> str:
"""Returns the string title of the menu."""
@@ -271,7 +137,8 @@ class Menu:
self.session.validator = self._get_validator()
self._invalidate_table_cache()
- def _get_prompt_session(self) -> PromptSession:
+ @cached_property
+ def session(self) -> PromptSession:
"""Returns the prompt session for the menu."""
return PromptSession(
message=self.prompt,
@@ -279,35 +146,50 @@ class Menu:
completer=self._get_completer(),
reserve_space_for_menu=1,
validator=self._get_validator(),
- bottom_toolbar=self.bottom_bar,
+ bottom_toolbar=self.bottom_bar.render,
)
- def add_before(self, hook: Callable[["Option"], None]) -> None:
- """Adds a hook to be executed before the action of the menu."""
- self.before_action.add_hook(hook)
+ def add_toggle(self, key: str, label: str, state: bool = False):
+ if key in self.options or key in self.toggles:
+ raise ValueError(f"Key '{key}' is already in use.")
- def add_after(self, hook: Callable[["Option"], None]) -> None:
- """Adds a hook to be executed after the action of the menu."""
- self.after_action.add_hook(hook)
+ self.toggles[key] = label
+ self.bottom_bar.add_toggle(label, label, state)
- def add_on_error(self, hook: Callable[["Option", Exception], None]) -> None:
- """Adds a hook to be executed on error of the menu."""
- self.on_error.add_hook(hook)
+ @self.key_bindings.add(key)
+ def _(event):
+ current = self.bottom_bar._states[label][1]
+ self.bottom_bar.update_toggle(label, not current)
+ self.console.print(f"Toggled [{label}] to {'ON' if not current else 'OFF'}")
+
+ def add_counter(self, name: str, label: str, current: int, total: int):
+ self.bottom_bar.add_counter(name, label, current, total)
+
+ def update_counter(self, name: str, current: int | None = None, total: int | None = None):
+ self.bottom_bar.update_counter(name, current=current, total=total)
+
+ def update_toggle(self, name: str, state: bool):
+ self.bottom_bar.update_toggle(name, state)
def debug_hooks(self) -> None:
if not self._verbose:
return
- logger.debug(f"Menu-level before hooks: {[hook.__name__ for hook in self.before_action.hooks]}")
- logger.debug(f"Menu-level after hooks: {[hook.__name__ for hook in self.after_action.hooks]}")
- logger.debug(f"Menu-level error hooks: {[hook.__name__ for hook in self.on_error.hooks]}")
+
+ def hook_names(hook_list):
+ return [hook.__name__ for hook in hook_list]
+
+ logger.debug(f"Menu-level before hooks: {hook_names(self.hooks._hooks['before'])}")
+ logger.debug(f"Menu-level after hooks: {hook_names(self.hooks._hooks['after'])}")
+ logger.debug(f"Menu-level error hooks: {hook_names(self.hooks._hooks['on_error'])}")
+
for key, option in self.options.items():
- logger.debug(f"[Option '{key}'] before: {[hook.__name__ for hook in option.before_action.hooks]}")
- logger.debug(f"[Option '{key}'] after: {[hook.__name__ for hook in option.after_action.hooks]}")
- logger.debug(f"[Option '{key}'] error: {[hook.__name__ for hook in option.on_error.hooks]}")
+ logger.debug(f"[Option '{key}'] before: {hook_names(option.hooks._hooks['before'])}")
+ logger.debug(f"[Option '{key}'] after: {hook_names(option.hooks._hooks['after'])}")
+ logger.debug(f"[Option '{key}'] error: {hook_names(option.hooks._hooks['on_error'])}")
def _validate_option_key(self, key: str) -> None:
"""Validates the option key to ensure it is unique."""
- if key in self.options or key.upper() == self.back_option.key.upper():
+ if key.upper() in self.options or key.upper() == self.back_option.key.upper():
raise OptionAlreadyExistsError(f"Option with key '{key}' already exists.")
def update_back_option(
@@ -350,7 +232,7 @@ class Menu:
self,
key: str,
description: str,
- action: Callable[[], Any],
+ action: BaseAction | Callable[[], Any],
color: str = OneColors.WHITE,
confirm: bool = False,
confirm_message: str = "Are you sure?",
@@ -358,15 +240,14 @@ class Menu:
spinner_message: str = "Processing...",
spinner_type: str = "dots",
spinner_style: str = OneColors.CYAN,
- spinner_kwargs: dict[str, Any] = None,
- before_hooks: list[Callable[[Option], None]] = None,
- after_hooks: list[Callable[[Option], None]] = None,
- error_hooks: list[Callable[[Option, Exception], None]] = None,
+ spinner_kwargs: dict[str, Any] | None = None,
+ before_hooks: list[Callable] | None = None,
+ after_hooks: list[Callable] | None = None,
+ error_hooks: list[Callable] | None = None,
) -> Option:
"""Adds an option to the menu, preventing duplicates."""
+ spinner_kwargs: dict[str, Any] = spinner_kwargs or {}
self._validate_option_key(key)
- if not spinner_kwargs:
- spinner_kwargs = {}
option = Option(
key=key,
description=description,
@@ -379,10 +260,15 @@ class Menu:
spinner_type=spinner_type,
spinner_style=spinner_style,
spinner_kwargs=spinner_kwargs,
- before_action=Hooks(hooks=before_hooks),
- after_action=Hooks(hooks=after_hooks),
- on_error=Hooks(hooks=error_hooks),
)
+
+ for hook in before_hooks or []:
+ option.hooks.register("before", hook)
+ for hook in after_hooks or []:
+ option.hooks.register("after", hook)
+ for hook in error_hooks or []:
+ option.hooks.register("on_error", hook)
+
self.options[key] = option
self._refresh_session()
return option
@@ -405,15 +291,20 @@ class Menu:
return self.back_option
return self.options.get(choice)
- def _should_hooks_run(self, selected_option: Option) -> bool:
- """Determines if hooks should be run based on the selected option."""
- return selected_option != self.back_option or self.run_hooks_on_back_option
-
def _should_run_action(self, selected_option: Option) -> bool:
if selected_option.confirm and not self._never_confirm:
return confirm(selected_option.confirm_message)
return True
+ def _create_context(self, selected_option: Option) -> dict[str, Any]:
+ """Creates a context dictionary for the selected option."""
+ return {
+ "name": selected_option.description,
+ "option": selected_option,
+ "args": (),
+ "kwargs": {},
+ }
+
def _run_action_with_spinner(self, option: Option) -> Any:
"""Runs the action of the selected option with a spinner."""
with self.console.status(
@@ -422,15 +313,13 @@ class Menu:
spinner_style=option.spinner_style,
**option.spinner_kwargs,
):
- return option.action()
+ return option()
def _handle_action_error(self, selected_option: Option, error: Exception) -> bool:
"""Handles errors that occur during the action of the selected option."""
logger.exception(f"Error executing '{selected_option.description}': {error}")
self.console.print(f"[{OneColors.DARK_RED}]An error occurred while executing "
f"{selected_option.description}:[/] {error}")
- selected_option.on_error.run_hooks(selected_option, error)
- self.on_error.run_hooks(selected_option, error)
if self.continue_on_error_prompt and not self._never_confirm:
return confirm("An error occurred. Do you wish to continue?")
if self._never_confirm:
@@ -442,25 +331,38 @@ class Menu:
choice = self.session.prompt()
selected_option = self.get_option(choice)
self.last_run_option = selected_option
- should_hooks_run = self._should_hooks_run(selected_option)
+
+ if selected_option == self.back_option:
+ logger.info(f"🔙 Back selected: exiting {self.get_title()}")
+ return False
+
if not self._should_run_action(selected_option):
logger.info(f"[{OneColors.DARK_RED}] {selected_option.description} cancelled.")
return True
+
+ context = self._create_context(selected_option)
+
try:
- if should_hooks_run:
- self.before_action.run_hooks(selected_option)
- selected_option.before_action.run_hooks(selected_option)
+ run_async(self.hooks.trigger("before", context))
+
if selected_option.spinner:
result = self._run_action_with_spinner(selected_option)
else:
- result = selected_option.action()
+ result = selected_option()
+
selected_option.set_result(result)
- selected_option.after_action.run_hooks(selected_option)
- if should_hooks_run:
- self.after_action.run_hooks(selected_option)
+ context["result"] = result
+ context["duration"] = selected_option.get_duration()
+ run_async(self.hooks.trigger("after", context))
except Exception as error:
+ context["exception"] = error
+ context["duration"] = selected_option.get_duration()
+ run_async(self.hooks.trigger("on_error", context))
+ if "exception" not in context:
+ logger.info(f"✅ Recovery hook handled error for '{selected_option.description}'")
+ return True
return self._handle_action_error(selected_option, error)
- return selected_option != self.back_option
+ return True
def run_headless(self, option_key: str, never_confirm: bool | None = None) -> Any:
"""Runs the action of the selected option without displaying the menu."""
@@ -470,34 +372,45 @@ class Menu:
selected_option = self.get_option(option_key)
self.last_run_option = selected_option
+
if not selected_option:
- raise MenuError(f"[Headless] Option '{option_key}' not found.")
+ logger.info("[Headless] Back option selected. Exiting menu.")
+ return
logger.info(f"[Headless] 🚀 Running: '{selected_option.description}'")
- should_hooks_run = self._should_hooks_run(selected_option)
+
if not self._should_run_action(selected_option):
- logger.info(f"[Headless] ⛔ '{selected_option.description}' cancelled.")
raise MenuError(f"[Headless] '{selected_option.description}' cancelled by confirmation.")
+ context = self._create_context(selected_option)
+
try:
- if should_hooks_run:
- self.before_action.run_hooks(selected_option)
- selected_option.before_action.run_hooks(selected_option)
+ run_async(self.hooks.trigger("before", context))
+
if selected_option.spinner:
result = self._run_action_with_spinner(selected_option)
else:
- result = selected_option.action()
+ result = selected_option()
+
selected_option.set_result(result)
- selected_option.after_action.run_hooks(selected_option)
- if should_hooks_run:
- self.after_action.run_hooks(selected_option)
+ context["result"] = result
+ context["duration"] = selected_option.get_duration()
+
+ run_async(self.hooks.trigger("after", context))
logger.info(f"[Headless] ✅ '{selected_option.description}' complete.")
except (KeyboardInterrupt, EOFError):
raise MenuError(f"[Headless] ⚠️ '{selected_option.description}' interrupted by user.")
except Exception as error:
+ context["exception"] = error
+ context["duration"] = selected_option.get_duration()
+ run_async(self.hooks.trigger("on_error", context))
+ if "exception" not in context:
+ logger.info(f"[Headless] ✅ Recovery hook handled error for '{selected_option.description}'")
+ return True
continue_running = self._handle_action_error(selected_option, error)
if not continue_running:
raise MenuError(f"[Headless] ❌ '{selected_option.description}' failed.") from error
+
return selected_option.get_result()
def run(self) -> None:
@@ -517,83 +430,3 @@ class Menu:
logger.info(f"Exiting menu: {self.get_title()}")
if self.exit_message:
self.console.print(self.exit_message)
-
-
-if __name__ == "__main__":
- from rich.traceback import install
- from logging_utils import setup_logging
-
- install(show_locals=True)
- setup_logging()
-
- def say_hello():
- print("Hello!")
-
- def say_goodbye():
- print("Goodbye!")
-
- def say_nested():
- print("This is a nested menu!")
-
- def my_action():
- print("This is my action!")
-
- def long_running_task():
- import time
-
- time.sleep(5)
-
- nested_menu = Menu(
- Markdown("## Nested Menu", style=OneColors.DARK_YELLOW),
- columns=2,
- bottom_bar="Menu within a menu",
- )
- nested_menu.add_option("1", "Say Nested", say_nested, color=OneColors.MAGENTA)
- nested_menu.add_before(lambda opt: logger.info(f"Global BEFORE '{opt.description}'"))
- nested_menu.add_after(lambda opt: logger.info(f"Global AFTER '{opt.description}'"))
-
- nested_menu.add_option(
- "2",
- "Test Action",
- action=my_action,
- before_hooks=[lambda opt: logger.info(f"Option-specific BEFORE '{opt.description}'")],
- after_hooks=[lambda opt: logger.info(f"Option-specific AFTER '{opt.description}'")],
- )
-
- def bottom_bar():
- return (
- f"Press Q to quit | Options available: {', '.join([f'[{key}]' for key in menu.options.keys()])}"
- )
-
- welcome_message = Markdown("# Welcome to the Menu!")
- exit_message = Markdown("# Thank you for using the menu!")
- menu = Menu(
- Markdown("## Main Menu", style=OneColors.CYAN),
- columns=3,
- bottom_bar=bottom_bar,
- welcome_message=welcome_message,
- exit_message=exit_message,
- )
- menu.add_option("1", "Say Hello", say_hello, color=OneColors.GREEN)
- menu.add_option("2", "Say Goodbye", say_goodbye, color=OneColors.LIGHT_RED)
- menu.add_option("3", "Do Nothing", lambda: None, color=OneColors.BLUE)
- menu.add_submenu("4", "Nested Menu", nested_menu, color=OneColors.MAGENTA)
- menu.add_option("5", "Do Nothing", lambda: None, color=OneColors.BLUE)
- menu.add_option(
- "6",
- "Long Running Task",
- action=long_running_task,
- spinner=True,
- spinner_message="Working, please wait...",
- spinner_type="moon",
- spinner_style=OneColors.GREEN,
- spinner_kwargs={"speed": 0.7},
- )
-
- menu.update_back_option("Q", "Quit", color=OneColors.DARK_RED)
-
- try:
- menu.run()
- except EOFError as error:
- logger.exception("EOFError: Exiting program.", exc_info=error)
- print("Exiting program.")
diff --git a/menu/menu_utils.py b/menu/menu_utils.py
new file mode 100644
index 0000000..4bd2715
--- /dev/null
+++ b/menu/menu_utils.py
@@ -0,0 +1,79 @@
+import asyncio
+import time
+from itertools import islice
+
+
+def chunks(iterator, size):
+ """Yield successive n-sized chunks from an iterator."""
+ iterator = iter(iterator)
+ while True:
+ chunk = list(islice(iterator, size))
+ if not chunk:
+ break
+ yield chunk
+
+
+def run_async(coro):
+ """Run an async function in a synchronous context."""
+ try:
+ _ = asyncio.get_running_loop()
+ return asyncio.create_task(coro)
+ except RuntimeError:
+ return asyncio.run(coro)
+
+
+class TimingMixin:
+ def _start_timer(self):
+ self.start_time = time.perf_counter()
+
+ def _stop_timer(self):
+ self.end_time = time.perf_counter()
+ self._duration = self.end_time - self.start_time
+
+ def get_duration(self) -> float | None:
+ return getattr(self, "_duration", None)
+
+
+class MenuError(Exception):
+ """Custom exception for the Menu class."""
+
+
+class OptionAlreadyExistsError(MenuError):
+ """Exception raised when an option with the same key already exists in the menu."""
+
+
+class InvalidHookError(MenuError):
+ """Exception raised when a hook is not callable."""
+
+
+class InvalidActionError(MenuError):
+ """Exception raised when an action is not callable."""
+
+
+class NotAMenuError(MenuError):
+ """Exception raised when the provided submenu is not an instance of Menu."""
+
+
+class CaseInsensitiveDict(dict):
+ """A case-insensitive dictionary that treats all keys as uppercase."""
+
+ def __setitem__(self, key, value):
+ super().__setitem__(key.upper(), value)
+
+ def __getitem__(self, key):
+ return super().__getitem__(key.upper())
+
+ def __contains__(self, key):
+ return super().__contains__(key.upper())
+
+ def get(self, key, default=None):
+ return super().get(key.upper(), default)
+
+ def pop(self, key, default=None):
+ return super().pop(key.upper(), default)
+
+ def update(self, other=None, **kwargs):
+ if other:
+ other = {k.upper(): v for k, v in other.items()}
+ kwargs = {k.upper(): v for k, v in kwargs.items()}
+ super().update(other, **kwargs)
diff --git a/menu/option.py b/menu/option.py
new file mode 100644
index 0000000..1f60196
--- /dev/null
+++ b/menu/option.py
@@ -0,0 +1,110 @@
+"""option.py
+Any Action or Option is callable and supports the signature:
+ result = thing(*args, **kwargs)
+
+This guarantees:
+- Hook lifecycle (before/after/error/teardown)
+- Timing
+- Consistent return values
+"""
+from __future__ import annotations
+
+import asyncio
+import logging
+from typing import Any, Callable
+
+from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
+
+from action import BaseAction
+from colors import OneColors
+from hook_manager import HookManager
+from menu_utils import TimingMixin, run_async
+
+logger = logging.getLogger("menu")
+
+
+class Option(BaseModel, TimingMixin):
+ """Class representing an option in the menu.
+
+ Hooks must have the signature:
+ def hook(option: Option) -> None:
+ where `option` is the selected option.
+
+ Error hooks must have the signature:
+ def error_hook(option: Option, error: Exception) -> None:
+ where `option` is the selected option and `error` is the exception raised.
+ """
+ key: str
+ description: str
+ action: BaseAction | Callable[[], Any] = lambda: None
+ color: str = OneColors.WHITE
+ confirm: bool = False
+ confirm_message: str = "Are you sure?"
+ spinner: bool = False
+ spinner_message: str = "Processing..."
+ spinner_type: str = "dots"
+ spinner_style: str = OneColors.CYAN
+ spinner_kwargs: dict[str, Any] = Field(default_factory=dict)
+
+ hooks: "HookManager" = Field(default_factory=HookManager)
+
+ start_time: float | None = None
+ end_time: float | None = None
+ _duration: float | None = PrivateAttr(default=None)
+ _result: Any | None = PrivateAttr(default=None)
+
+ model_config = ConfigDict(arbitrary_types_allowed=True)
+
+ def __str__(self):
+ return f"Option(key='{self.key}', description='{self.description}')"
+
+ def set_result(self, result: Any) -> None:
+ """Set the result of the action."""
+ self._result = result
+
+ def get_result(self) -> Any:
+ """Get the result of the action."""
+ return self._result
+
+ def __call__(self, *args, **kwargs) -> Any:
+ context = {
+ "name": self.description,
+ "duration": None,
+ "args": args,
+ "kwargs": kwargs,
+ "option": self,
+ }
+ self._start_timer()
+ try:
+ run_async(self.hooks.trigger("before", context))
+ result = self._execute_action(*args, **kwargs)
+ self.set_result(result)
+ context["result"] = result
+ return result
+ except Exception as error:
+ context["exception"] = error
+ run_async(self.hooks.trigger("on_error", context))
+ if "exception" not in context:
+ logger.info(f"✅ Recovery hook handled error for Option '{self.key}'")
+ return self.get_result()
+ raise
+ finally:
+ self._stop_timer()
+ context["duration"] = self.get_duration()
+ if "exception" not in context:
+ run_async(self.hooks.trigger("after", context))
+ run_async(self.hooks.trigger("on_teardown", context))
+
+ def _execute_action(self, *args, **kwargs) -> Any:
+ if isinstance(self.action, BaseAction):
+ return self.action(*args, **kwargs)
+ return self.action()
+
+ def dry_run(self):
+ print(f"[DRY RUN] Option '{self.key}' would run: {self.description}")
+ if isinstance(self.action, BaseAction):
+ self.action.dry_run()
+ elif callable(self.action):
+ print(f"[DRY RUN] Action is a raw callable: {self.action.__name__}")
+ else:
+ print("[DRY RUN] Action is not callable.")
\ No newline at end of file
diff --git a/menu/task.py b/menu/task.py
index 4c3b360..d71d67a 100644
--- a/menu/task.py
+++ b/menu/task.py
@@ -3,7 +3,7 @@ import time
def risky_task() -> str:
- if random.random() > 0.25:
+ if random.random() > 0.4:
time.sleep(1)
raise ValueError("Random failure occurred")
return "Task succeeded!"