Update to Menu and Action

This commit is contained in:
Roland Thomas Jr 2025-03-27 22:02:16 -04:00
parent 03af4f8077
commit 7443084809
Signed by: roland
GPG Key ID: 7C3C2B085A4C2872
10 changed files with 718 additions and 340 deletions

218
menu/action.py Normal file
View File

@ -0,0 +1,218 @@
"""action.py
Any Action or Option is callable and supports the signature:
result = thing(*args, **kwargs)
This guarantees:
- Hook lifecycle (before/after/error/teardown)
- Timing
- Consistent return values
"""
from __future__ import annotations
import asyncio
import logging
import time
import inspect
from abc import ABC, abstractmethod
from typing import Optional
from hook_manager import HookManager
from menu_utils import TimingMixin, run_async
logger = logging.getLogger("menu")
class BaseAction(ABC, TimingMixin):
"""Base class for actions. They are the building blocks of the menu.
Actions can be simple functions or more complex actions like
`ChainedAction` or `ActionGroup`. They can also be run independently
or as part of a menu."""
def __init__(self, name: str, hooks: Optional[HookManager] = None):
self.name = name
self.hooks = hooks or HookManager()
self.start_time: float | None = None
self.end_time: float | None = None
self._duration: float | None = None
def __call__(self, *args, **kwargs):
context = {
"name": self.name,
"duration": None,
"args": args,
"kwargs": kwargs,
"action": self
}
self._start_timer()
try:
run_async(self.hooks.trigger("before", context))
result = self._run(*args, **kwargs)
context["result"] = result
return result
except Exception as error:
context["exception"] = error
run_async(self.hooks.trigger("on_error", context))
if "exception" not in context:
logger.info(f"✅ Recovery hook handled error for Action '{self.name}'")
return context.get("result")
raise
finally:
self._stop_timer()
context["duration"] = self.get_duration()
if "exception" not in context:
run_async(self.hooks.trigger("after", context))
run_async(self.hooks.trigger("on_teardown", context))
@abstractmethod
def _run(self, *args, **kwargs):
raise NotImplementedError("_run must be implemented by subclasses")
async def run_async(self, *args, **kwargs):
if inspect.iscoroutinefunction(self._run):
return await self._run(*args, **kwargs)
return await asyncio.to_thread(self.__call__, *args, **kwargs)
def __await__(self):
return self.run_async().__await__()
@abstractmethod
def dry_run(self):
raise NotImplementedError("dry_run must be implemented by subclasses")
def __str__(self):
return f"<{self.__class__.__name__} '{self.name}'>"
def __repr__(self):
return str(self)
class Action(BaseAction):
def __init__(self, name: str, fn, rollback=None, hooks=None):
super().__init__(name, hooks)
self.fn = fn
self.rollback = rollback
def _run(self, *args, **kwargs):
if inspect.iscoroutinefunction(self.fn):
return asyncio.run(self.fn(*args, **kwargs))
return self.fn(*args, **kwargs)
def dry_run(self):
print(f"[DRY RUN] Would run: {self.name}")
class ChainedAction(BaseAction):
def __init__(self, name: str, actions: list[BaseAction], hooks=None):
super().__init__(name, hooks)
self.actions = actions
def _run(self, *args, **kwargs):
rollback_stack = []
for action in self.actions:
try:
result = action(*args, **kwargs)
rollback_stack.append(action)
except Exception:
self._rollback(rollback_stack, *args, **kwargs)
raise
return None
def dry_run(self):
print(f"[DRY RUN] ChainedAction '{self.name}' with steps:")
for action in self.actions:
action.dry_run()
def _rollback(self, rollback_stack, *args, **kwargs):
for action in reversed(rollback_stack):
if hasattr(action, "rollback") and action.rollback:
try:
print(f"↩️ Rolling back {action.name}")
action.rollback(*args, **kwargs)
except Exception as e:
print(f"⚠️ Rollback failed for {action.name}: {e}")
class ActionGroup(BaseAction):
def __init__(self, name: str, actions: list[BaseAction], hooks=None):
super().__init__(name, hooks)
self.actions = actions
self.results = []
self.errors = []
def _run(self, *args, **kwargs):
asyncio.run(self._run_async(*args, **kwargs))
def dry_run(self):
print(f"[DRY RUN] ActionGroup '{self.name}' (parallel execution):")
for action in self.actions:
action.dry_run()
async def _run_async(self, *args, **kwargs):
async def run(action):
try:
result = await asyncio.to_thread(action, *args, **kwargs)
self.results.append((action.name, result))
except Exception as e:
self.errors.append((action.name, e))
await self.hooks.trigger("before", name=self.name)
await asyncio.gather(*[run(a) for a in self.actions])
if self.errors:
await self.hooks.trigger("on_error", name=self.name, errors=self.errors)
else:
await self.hooks.trigger("after", name=self.name, results=self.results)
await self.hooks.trigger("on_teardown", name=self.name)
# if __name__ == "__main__":
# # Example usage
# def build(): print("Build!")
# def test(): print("Test!")
# def deploy(): print("Deploy!")
# pipeline = ChainedAction("CI/CD", [
# Action("Build", build),
# Action("Test", test),
# ActionGroup("Deploy Parallel", [
# Action("Deploy A", deploy),
# Action("Deploy B", deploy)
# ])
# ])
# pipeline()
# Sample functions
def sync_hello():
time.sleep(1)
return "Hello from sync function"
async def async_hello():
await asyncio.sleep(1)
return "Hello from async function"
# Example usage
async def main():
sync_action = Action("sync_hello", sync_hello)
async_action = Action("async_hello", async_hello)
print("⏳ Awaiting sync action...")
result1 = await sync_action
print("", result1)
print("⏳ Awaiting async action...")
result2 = await async_action
print("", result2)
print(f"⏱️ sync took {sync_action.get_duration():.2f}s")
print(f"⏱️ async took {async_action.get_duration():.2f}s")
if __name__ == "__main__":
asyncio.run(main())

59
menu/bottom_bar.py Normal file
View File

@ -0,0 +1,59 @@
from prompt_toolkit.formatted_text import HTML, merge_formatted_text
from typing import Callable, Literal, Optional
from rich.console import Console
class BottomBar:
def __init__(self, columns: int = 3):
self.columns = columns
self.console = Console()
self._items: list[Callable[[], HTML]] = []
self._named_items: dict[str, Callable[[], HTML]] = {}
self._states: dict[str, any] = {}
def get_space(self) -> str:
return self.console.width // self.columns
def add_static(self, name: str, text: str) -> None:
def render():
return HTML(f"<style fg='#D8DEE9'>{text:^{self.get_space()}}</style>")
self._add_named(name, render)
def add_counter(self, name: str, label: str, current: int, total: int) -> None:
self._states[name] = (label, current, total)
def render():
l, c, t = self._states[name]
text = f"{l}: {c}/{t}"
return HTML(f"<style fg='#A3BE8C'>{text:^{self.get_space()}}</style>")
self._add_named(name, render)
def add_toggle(self, name: str, label: str, state: bool) -> None:
self._states[name] = (label, state)
def render():
l, s = self._states[name]
color = '#A3BE8C' if s else '#BF616A'
status = "ON" if s else "OFF"
text = f"{l}: {status}"
return HTML(f"<style fg='{color}'>{text:^{self.get_space()}}</style>")
self._add_named(name, render)
def update_toggle(self, name: str, state: bool) -> None:
if name in self._states:
label, _ = self._states[name]
self._states[name] = (label, state)
def update_counter(self, name: str, current: Optional[int] = None, total: Optional[int] = None) -> None:
if name in self._states:
label, c, t = self._states[name]
self._states[name] = (label, current if current is not None else c, total if total is not None else t)
def _add_named(self, name: str, render_fn: Callable[[], HTML]) -> None:
self._named_items[name] = render_fn
self._items = list(self._named_items.values())
def render(self):
return merge_formatted_text([fn() for fn in self._items])

View File

@ -11,7 +11,7 @@ console = Console()
setup_logging() setup_logging()
logger = logging.getLogger("menu") logger = logging.getLogger("menu")
def retry(max_retries=3, delay=1, backoff=2, exceptions=(Exception,), logger=None, spinner_text=None): def retry(max_retries=3, delay=1, backoff=2, exceptions=(Exception,)):
def decorator(func): def decorator(func):
is_coroutine = inspect.iscoroutinefunction(func) is_coroutine = inspect.iscoroutinefunction(func)
@ -22,8 +22,7 @@ def retry(max_retries=3, delay=1, backoff=2, exceptions=(Exception,), logger=Non
if logger: if logger:
logger.debug(f"Retrying {retries + 1}/{max_retries} for '{func.__name__}' after {current_delay}s due to '{exceptions}'.") logger.debug(f"Retrying {retries + 1}/{max_retries} for '{func.__name__}' after {current_delay}s due to '{exceptions}'.")
try: try:
with console.status(spinner_text, spinner="dots"): return await func(*args, **kwargs)
return await func(*args, **kwargs)
except exceptions as e: except exceptions as e:
if retries == max_retries: if retries == max_retries:
if logger: if logger:
@ -44,8 +43,7 @@ def retry(max_retries=3, delay=1, backoff=2, exceptions=(Exception,), logger=Non
if logger: if logger:
logger.debug(f"Retrying {retries + 1}/{max_retries} for '{func.__name__}' after {current_delay}s due to '{exceptions}'.") logger.debug(f"Retrying {retries + 1}/{max_retries} for '{func.__name__}' after {current_delay}s due to '{exceptions}'.")
try: try:
with console.status(spinner_text, spinner="dots"): return func(*args, **kwargs)
return func(*args, **kwargs)
except exceptions as e: except exceptions as e:
if retries == max_retries: if retries == max_retries:
if logger: if logger:
@ -62,7 +60,7 @@ def retry(max_retries=3, delay=1, backoff=2, exceptions=(Exception,), logger=Non
return async_wrapper if is_coroutine else sync_wrapper return async_wrapper if is_coroutine else sync_wrapper
return decorator return decorator
@retry(max_retries=10, delay=1, logger=logger, spinner_text="Trying risky thing...") @retry(max_retries=10, delay=1, spinner_text="Trying risky thing...")
def might_fail(): def might_fail():
time.sleep(4) time.sleep(4)
if random.random() < 0.6: if random.random() < 0.6:

66
menu/hook_manager.py Normal file
View File

@ -0,0 +1,66 @@
"""hook_manager.py"""
from __future__ import annotations
import inspect
import logging
from typing import (Any, Awaitable, Callable, Dict, List, Optional, TypedDict,
Union, TYPE_CHECKING)
if TYPE_CHECKING:
from action import BaseAction
from option import Option
logger = logging.getLogger("menu")
class HookContext(TypedDict, total=False):
name: str
args: tuple[Any, ...]
kwargs: dict[str, Any]
result: Any | None
exception: Exception | None
option: Option | None
action: BaseAction | None
Hook = Union[Callable[[HookContext], None], Callable[[HookContext], Awaitable[None]]]
class HookManager:
def __init__(self):
self._hooks: Dict[str, List[Hook]] = {
"before": [],
"after": [],
"on_error": [],
"on_teardown": [],
}
def register(self, hook_type: str, hook: Hook):
if hook_type not in self._hooks:
raise ValueError(f"Unsupported hook type: {hook_type}")
self._hooks[hook_type].append(hook)
def clear(self, hook_type: Optional[str] = None):
if hook_type:
self._hooks[hook_type] = []
else:
for k in self._hooks:
self._hooks[k] = []
async def trigger(self, hook_type: str, context: HookContext):
if hook_type not in self._hooks:
raise ValueError(f"Unsupported hook type: {hook_type}")
for hook in self._hooks[hook_type]:
try:
if inspect.iscoroutinefunction(hook):
await hook(context)
else:
hook(context)
except Exception as hook_error:
name = context.get("name", "<unnamed>")
logger.warning(f"⚠️ Hook '{hook.__name__}' raised an exception during '{hook_type}'"
f" for '{name}': {hook_error}")
if hook_type == "on_error":
raise context.get("exception") from hook_error

View File

@ -1,41 +1,40 @@
import time import functools
import logging import logging
import random import random
import functools import time
from menu import Menu, Option
from hook_manager import HookContext
from menu_utils import run_async
logger = logging.getLogger("menu") logger = logging.getLogger("menu")
def timing_before_hook(option: Option) -> None:
option._start_time = time.perf_counter()
def log_before(context: dict) -> None:
def timing_after_hook(option: Option) -> None: name = context.get("name", "<unnamed>")
option._end_time = time.perf_counter() option = context.get("option")
option._duration = option._end_time - option._start_time if option:
logger.info(f"🚀 Starting action '{option.description}' (key='{option.key}')")
def timing_error_hook(option: Option, _: Exception) -> None:
option._end_time = time.perf_counter()
option._duration = option._end_time - option._start_time
def log_before(option: Option) -> None:
logger.info(f"🚀 Starting action '{option.description}' (key='{option.key}')")
def log_after(option: Option) -> None:
if option._duration is not None:
logger.info(f"✅ Completed '{option.description}' (key='{option.key}') in {option._duration:.2f}s")
else: else:
logger.info(f"✅ Completed '{option.description}' (key='{option.key}')") logger.info(f"🚀 Starting action '{name}'")
def log_error(option: Option, error: Exception) -> None: def log_after(context: dict) -> None:
if option._duration is not None: name = context.get("name", "<unnamed>")
logger.error(f"❌ Error '{option.description}' (key='{option.key}') after {option._duration:.2f}s: {error}") duration = context.get("duration")
if duration is not None:
logger.info(f"✅ Completed '{name}' in {duration:.2f}s")
else: else:
logger.error(f"❌ Error '{option.description}' (key='{option.key}'): {error}") logger.info(f"✅ Completed '{name}'")
def log_error(context: dict) -> None:
name = context.get("name", "<unnamed>")
error = context.get("exception")
duration = context.get("duration")
if duration is not None:
logger.error(f"❌ Error '{name}' after {duration:.2f}s: {error}")
else:
logger.error(f"❌ Error '{name}': {error}")
class CircuitBreakerOpen(Exception): class CircuitBreakerOpen(Exception):
@ -49,23 +48,25 @@ class CircuitBreaker:
self.failures = 0 self.failures = 0
self.open_until = None self.open_until = None
def before_hook(self, option: Option): def before_hook(self, context: HookContext):
name = context.get("name", "<unnamed>")
if self.open_until: if self.open_until:
if time.time() < self.open_until: if time.time() < self.open_until:
raise CircuitBreakerOpen(f"🔴 Circuit open for '{option.description}' until {time.ctime(self.open_until)}.") raise CircuitBreakerOpen(f"🔴 Circuit open for '{name}' until {time.ctime(self.open_until)}.")
else: else:
logger.info(f"🟢 Circuit closed again for '{option.description}'.") logger.info(f"🟢 Circuit closed again for '{name}'.")
self.failures = 0 self.failures = 0
self.open_until = None self.open_until = None
def error_hook(self, option: Option, error: Exception): def error_hook(self, context: HookContext):
name = context.get("name", "<unnamed>")
self.failures += 1 self.failures += 1
logger.warning(f"⚠️ CircuitBreaker: '{option.description}' failure {self.failures}/{self.max_failures}.") logger.warning(f"⚠️ CircuitBreaker: '{name}' failure {self.failures}/{self.max_failures}.")
if self.failures >= self.max_failures: if self.failures >= self.max_failures:
self.open_until = time.time() + self.reset_timeout self.open_until = time.time() + self.reset_timeout
logger.error(f"🔴 Circuit opened for '{option.description}' until {time.ctime(self.open_until)}.") logger.error(f"🔴 Circuit opened for '{name}' until {time.ctime(self.open_until)}.")
def after_hook(self, option: Option): def after_hook(self, context: HookContext):
self.failures = 0 self.failures = 0
def is_open(self): def is_open(self):
@ -78,33 +79,52 @@ class CircuitBreaker:
class RetryHandler: class RetryHandler:
def __init__(self, max_retries=2, delay=1, backoff=2): def __init__(self, max_retries=5, delay=1, backoff=2):
self.max_retries = max_retries self.max_retries = max_retries
self.delay = delay self.delay = delay
self.backoff = backoff self.backoff = backoff
def retry_on_error(self, option: Option, error: Exception): def retry_on_error(self, context: HookContext):
name = context.get("name", "<unnamed>")
error = context.get("exception")
option = context.get("option")
action = context.get("action")
retries_done = 0 retries_done = 0
current_delay = self.delay current_delay = self.delay
last_error = error last_error = error
if not (option or action):
logger.warning(f"⚠️ RetryHandler: No Option or Action in context for '{name}'. Skipping retry.")
return
target = option or action
while retries_done < self.max_retries: while retries_done < self.max_retries:
try: try:
retries_done += 1 retries_done += 1
logger.info(f"🔄 Retrying '{option.description}' ({retries_done}/{self.max_retries}) in {current_delay}s due to '{error}'...") logger.info(f"🔄 Retrying '{name}' ({retries_done}/{self.max_retries}) in {current_delay}s due to '{last_error}'...")
time.sleep(current_delay) time.sleep(current_delay)
result = option.action() result = target(*context.get("args", ()), **context.get("kwargs", {}))
print(result) if option:
option.set_result(result) option.set_result(result)
logger.info(f"✅ Retry succeeded for '{option.description}' on attempt {retries_done}.")
option.after_action.run_hooks(option) context["result"] = result
context["duration"] = target.get_duration() or 0.0
context.pop("exception", None)
logger.info(f"✅ Retry succeeded for '{name}' on attempt {retries_done}.")
if hasattr(target, "hooks"):
run_async(target.hooks.trigger("after", context))
return return
except Exception as retry_error: except Exception as retry_error:
logger.warning(f"⚠️ Retry attempt {retries_done} for '{option.description}' failed due to '{retry_error}'.") logger.warning(f"⚠️ Retry attempt {retries_done} for '{name}' failed due to '{retry_error}'.")
last_error = retry_error last_error = retry_error
current_delay *= self.backoff current_delay *= self.backoff
logger.exception(f"'{option.description}' failed after {self.max_retries} retries.") logger.exception(f"'{name}' failed after {self.max_retries} retries.")
raise last_error raise last_error
@ -133,15 +153,13 @@ def retry(max_retries=3, delay=1, backoff=2, exceptions=(Exception,), logger=Non
def setup_hooks(menu): def setup_hooks(menu):
menu.add_before(timing_before_hook)
menu.add_after(timing_after_hook)
menu.add_on_error(timing_error_hook)
menu.add_before(log_before) menu.add_before(log_before)
menu.add_after(log_after) menu.add_after(log_after)
menu.add_on_error(log_error) menu.add_on_error(log_error)
if __name__ == "__main__": if __name__ == "__main__":
from menu import Menu
def risky_task(): def risky_task():
if random.random() > 0.1: if random.random() > 0.1:
time.sleep(1) time.sleep(1)
@ -151,9 +169,6 @@ if __name__ == "__main__":
retry_handler = RetryHandler(max_retries=30, delay=2, backoff=2) retry_handler = RetryHandler(max_retries=30, delay=2, backoff=2)
menu = Menu(never_confirm=True) menu = Menu(never_confirm=True)
menu.add_before(timing_before_hook)
menu.add_after(timing_after_hook)
menu.add_on_error(timing_error_hook)
menu.add_before(log_before) menu.add_before(log_before)
menu.add_after(log_after) menu.add_after(log_after)
menu.add_on_error(log_error) menu.add_on_error(log_error)

View File

@ -8,7 +8,7 @@ def setup_logging(
): ):
"""Set up logging configuration with separate console and file handlers.""" """Set up logging configuration with separate console and file handlers."""
root_logger = logging.getLogger() root_logger = logging.getLogger()
root_logger.setLevel(logging.WARNING) root_logger.setLevel(logging.DEBUG)
if root_logger.hasHandlers(): if root_logger.hasHandlers():
root_logger.handlers.clear() root_logger.handlers.clear()

View File

@ -12,168 +12,34 @@ formatted and visually appealing way.
This class also uses the `prompt_toolkit` library to handle This class also uses the `prompt_toolkit` library to handle
user input and create an interactive experience. user input and create an interactive experience.
""" """
import asyncio
import logging import logging
from functools import cached_property from functools import cached_property
from itertools import islice
from typing import Any, Callable from typing import Any, Callable
from prompt_toolkit import PromptSession from prompt_toolkit import PromptSession
from prompt_toolkit.completion import WordCompleter from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.formatted_text import AnyFormattedText from prompt_toolkit.formatted_text import AnyFormattedText
from prompt_toolkit.key_binding import KeyBindings
from prompt_toolkit.shortcuts import confirm from prompt_toolkit.shortcuts import confirm
from prompt_toolkit.validation import Validator from prompt_toolkit.validation import Validator
from pydantic import BaseModel, Field, field_validator, PrivateAttr
from rich import box from rich import box
from rich.console import Console from rich.console import Console
from rich.markdown import Markdown from rich.markdown import Markdown
from rich.table import Table from rich.table import Table
from action import BaseAction
from bottom_bar import BottomBar
from colors import get_nord_theme from colors import get_nord_theme
from hook_manager import HookManager
from menu_utils import (CaseInsensitiveDict, InvalidActionError, MenuError,
NotAMenuError, OptionAlreadyExistsError, chunks, run_async)
from one_colors import OneColors from one_colors import OneColors
from option import Option
logger = logging.getLogger("menu") logger = logging.getLogger("menu")
def chunks(iterator, size):
"""Yield successive n-sized chunks from an iterator."""
iterator = iter(iterator)
while True:
chunk = list(islice(iterator, size))
if not chunk:
break
yield chunk
class MenuError(Exception):
"""Custom exception for the Menu class."""
class OptionAlreadyExistsError(MenuError):
"""Exception raised when an option with the same key already exists in the menu."""
class InvalidHookError(MenuError):
"""Exception raised when a hook is not callable."""
class InvalidActionError(MenuError):
"""Exception raised when an action is not callable."""
class NotAMenuError(MenuError):
"""Exception raised when the provided submenu is not an instance of Menu."""
class CaseInsensitiveDict(dict):
"""A case-insensitive dictionary that treats all keys as uppercase."""
def __setitem__(self, key, value):
super().__setitem__(key.upper(), value)
def __getitem__(self, key):
return super().__getitem__(key.upper())
def __contains__(self, key):
return super().__contains__(key.upper())
def get(self, key, default=None):
return super().get(key.upper(), default)
def pop(self, key, default=None):
return super().pop(key.upper(), default)
def update(self, other=None, **kwargs):
if other:
other = {k.upper(): v for k, v in other.items()}
kwargs = {k.upper(): v for k, v in kwargs.items()}
super().update(other, **kwargs)
class Hooks(BaseModel):
"""Class to manage hooks for the menu and options."""
hooks: list[Callable[["Option"], None]] | list[Callable[["Option", Exception], None]] = Field(
default_factory=list
)
@field_validator("hooks", mode="before")
@classmethod
def validate_hooks(cls, hooks):
if hooks is None:
return []
if not all(callable(hook) for hook in hooks):
raise InvalidHookError("All hooks must be callable.")
return hooks
def add_hook(self, hook: Callable[["Option"], None] | Callable[["Option", Exception], None]) -> None:
"""Add a hook to the list."""
if not callable(hook):
raise InvalidHookError("Hook must be a callable.")
if hook not in self.hooks:
self.hooks.append(hook)
def run_hooks(self, *args, **kwargs) -> None:
"""Run all hooks with the given arguments."""
for hook in self.hooks:
try:
hook(*args, **kwargs)
except Exception as hook_error:
logger.exception(f"Hook '{hook.__name__}': {hook_error}")
class Option(BaseModel):
"""Class representing an option in the menu.
Hooks must have the signature:
def hook(option: Option) -> None:
where `option` is the selected option.
Error hooks must have the signature:
def error_hook(option: Option, error: Exception) -> None:
where `option` is the selected option and `error` is the exception raised.
"""
key: str
description: str
action: Callable[[], Any] = lambda: None
color: str = OneColors.WHITE
confirm: bool = False
confirm_message: str = "Are you sure?"
spinner: bool = False
spinner_message: str = "Processing..."
spinner_type: str = "dots"
spinner_style: str = OneColors.CYAN
spinner_kwargs: dict[str, Any] = Field(default_factory=dict)
before_action: Hooks = Field(default_factory=Hooks)
after_action: Hooks = Field(default_factory=Hooks)
on_error: Hooks = Field(default_factory=Hooks)
_start_time: float | None = PrivateAttr(default=None)
_end_time: float | None = PrivateAttr(default=None)
_duration: float | None = PrivateAttr(default=None)
_result: Any | None = PrivateAttr(default=None)
def __str__(self):
return f"Option(key='{self.key}', description='{self.description}')"
def set_result(self, result: Any) -> None:
"""Set the result of the action."""
self._result = result
def get_result(self) -> Any:
"""Get the result of the action."""
return self._result
@field_validator("action")
def validate_action(cls, action):
if not callable(action):
raise InvalidActionError("Action must be a callable.")
return action
class Menu: class Menu:
"""Class to create a menu with options. """Class to create a menu with options.
@ -218,21 +84,21 @@ class Menu:
self.title: str | Markdown = title self.title: str | Markdown = title
self.prompt: str | AnyFormattedText = prompt self.prompt: str | AnyFormattedText = prompt
self.columns: int = columns self.columns: int = columns
self.bottom_bar: str | Callable[[], None] | None = bottom_bar self.bottom_bar: str | Callable[[], None] | None = bottom_bar or BottomBar(columns=columns)
self.options: dict[str, Option] = CaseInsensitiveDict() self.options: dict[str, Option] = CaseInsensitiveDict()
self.back_option: Option = self._get_back_option() self.back_option: Option = self._get_back_option()
self.console: Console = Console(color_system="truecolor", theme=get_nord_theme()) self.console: Console = Console(color_system="truecolor", theme=get_nord_theme())
self.session: PromptSession = self._get_prompt_session() #self.session: PromptSession = self._get_prompt_session()
self.welcome_message: str | Markdown = welcome_message self.welcome_message: str | Markdown = welcome_message
self.exit_message: str | Markdown = exit_message self.exit_message: str | Markdown = exit_message
self.before_action: Hooks = Hooks() self.hooks: HookManager = HookManager()
self.after_action: Hooks = Hooks()
self.on_error: Hooks = Hooks()
self.run_hooks_on_back_option: bool = run_hooks_on_back_option self.run_hooks_on_back_option: bool = run_hooks_on_back_option
self.continue_on_error_prompt: bool = continue_on_error_prompt self.continue_on_error_prompt: bool = continue_on_error_prompt
self._never_confirm: bool = never_confirm self._never_confirm: bool = never_confirm
self._verbose: bool = _verbose self._verbose: bool = _verbose
self.last_run_option: Option | None = None self.last_run_option: Option | None = None
self.key_bindings: KeyBindings = KeyBindings()
self.toggles: dict[str, str] = {}
def get_title(self) -> str: def get_title(self) -> str:
"""Returns the string title of the menu.""" """Returns the string title of the menu."""
@ -271,7 +137,8 @@ class Menu:
self.session.validator = self._get_validator() self.session.validator = self._get_validator()
self._invalidate_table_cache() self._invalidate_table_cache()
def _get_prompt_session(self) -> PromptSession: @cached_property
def session(self) -> PromptSession:
"""Returns the prompt session for the menu.""" """Returns the prompt session for the menu."""
return PromptSession( return PromptSession(
message=self.prompt, message=self.prompt,
@ -279,35 +146,50 @@ class Menu:
completer=self._get_completer(), completer=self._get_completer(),
reserve_space_for_menu=1, reserve_space_for_menu=1,
validator=self._get_validator(), validator=self._get_validator(),
bottom_toolbar=self.bottom_bar, bottom_toolbar=self.bottom_bar.render,
) )
def add_before(self, hook: Callable[["Option"], None]) -> None: def add_toggle(self, key: str, label: str, state: bool = False):
"""Adds a hook to be executed before the action of the menu.""" if key in self.options or key in self.toggles:
self.before_action.add_hook(hook) raise ValueError(f"Key '{key}' is already in use.")
def add_after(self, hook: Callable[["Option"], None]) -> None: self.toggles[key] = label
"""Adds a hook to be executed after the action of the menu.""" self.bottom_bar.add_toggle(label, label, state)
self.after_action.add_hook(hook)
def add_on_error(self, hook: Callable[["Option", Exception], None]) -> None: @self.key_bindings.add(key)
"""Adds a hook to be executed on error of the menu.""" def _(event):
self.on_error.add_hook(hook) current = self.bottom_bar._states[label][1]
self.bottom_bar.update_toggle(label, not current)
self.console.print(f"Toggled [{label}] to {'ON' if not current else 'OFF'}")
def add_counter(self, name: str, label: str, current: int, total: int):
self.bottom_bar.add_counter(name, label, current, total)
def update_counter(self, name: str, current: int | None = None, total: int | None = None):
self.bottom_bar.update_counter(name, current=current, total=total)
def update_toggle(self, name: str, state: bool):
self.bottom_bar.update_toggle(name, state)
def debug_hooks(self) -> None: def debug_hooks(self) -> None:
if not self._verbose: if not self._verbose:
return return
logger.debug(f"Menu-level before hooks: {[hook.__name__ for hook in self.before_action.hooks]}")
logger.debug(f"Menu-level after hooks: {[hook.__name__ for hook in self.after_action.hooks]}") def hook_names(hook_list):
logger.debug(f"Menu-level error hooks: {[hook.__name__ for hook in self.on_error.hooks]}") return [hook.__name__ for hook in hook_list]
logger.debug(f"Menu-level before hooks: {hook_names(self.hooks._hooks['before'])}")
logger.debug(f"Menu-level after hooks: {hook_names(self.hooks._hooks['after'])}")
logger.debug(f"Menu-level error hooks: {hook_names(self.hooks._hooks['on_error'])}")
for key, option in self.options.items(): for key, option in self.options.items():
logger.debug(f"[Option '{key}'] before: {[hook.__name__ for hook in option.before_action.hooks]}") logger.debug(f"[Option '{key}'] before: {hook_names(option.hooks._hooks['before'])}")
logger.debug(f"[Option '{key}'] after: {[hook.__name__ for hook in option.after_action.hooks]}") logger.debug(f"[Option '{key}'] after: {hook_names(option.hooks._hooks['after'])}")
logger.debug(f"[Option '{key}'] error: {[hook.__name__ for hook in option.on_error.hooks]}") logger.debug(f"[Option '{key}'] error: {hook_names(option.hooks._hooks['on_error'])}")
def _validate_option_key(self, key: str) -> None: def _validate_option_key(self, key: str) -> None:
"""Validates the option key to ensure it is unique.""" """Validates the option key to ensure it is unique."""
if key in self.options or key.upper() == self.back_option.key.upper(): if key.upper() in self.options or key.upper() == self.back_option.key.upper():
raise OptionAlreadyExistsError(f"Option with key '{key}' already exists.") raise OptionAlreadyExistsError(f"Option with key '{key}' already exists.")
def update_back_option( def update_back_option(
@ -350,7 +232,7 @@ class Menu:
self, self,
key: str, key: str,
description: str, description: str,
action: Callable[[], Any], action: BaseAction | Callable[[], Any],
color: str = OneColors.WHITE, color: str = OneColors.WHITE,
confirm: bool = False, confirm: bool = False,
confirm_message: str = "Are you sure?", confirm_message: str = "Are you sure?",
@ -358,15 +240,14 @@ class Menu:
spinner_message: str = "Processing...", spinner_message: str = "Processing...",
spinner_type: str = "dots", spinner_type: str = "dots",
spinner_style: str = OneColors.CYAN, spinner_style: str = OneColors.CYAN,
spinner_kwargs: dict[str, Any] = None, spinner_kwargs: dict[str, Any] | None = None,
before_hooks: list[Callable[[Option], None]] = None, before_hooks: list[Callable] | None = None,
after_hooks: list[Callable[[Option], None]] = None, after_hooks: list[Callable] | None = None,
error_hooks: list[Callable[[Option, Exception], None]] = None, error_hooks: list[Callable] | None = None,
) -> Option: ) -> Option:
"""Adds an option to the menu, preventing duplicates.""" """Adds an option to the menu, preventing duplicates."""
spinner_kwargs: dict[str, Any] = spinner_kwargs or {}
self._validate_option_key(key) self._validate_option_key(key)
if not spinner_kwargs:
spinner_kwargs = {}
option = Option( option = Option(
key=key, key=key,
description=description, description=description,
@ -379,10 +260,15 @@ class Menu:
spinner_type=spinner_type, spinner_type=spinner_type,
spinner_style=spinner_style, spinner_style=spinner_style,
spinner_kwargs=spinner_kwargs, spinner_kwargs=spinner_kwargs,
before_action=Hooks(hooks=before_hooks),
after_action=Hooks(hooks=after_hooks),
on_error=Hooks(hooks=error_hooks),
) )
for hook in before_hooks or []:
option.hooks.register("before", hook)
for hook in after_hooks or []:
option.hooks.register("after", hook)
for hook in error_hooks or []:
option.hooks.register("on_error", hook)
self.options[key] = option self.options[key] = option
self._refresh_session() self._refresh_session()
return option return option
@ -405,15 +291,20 @@ class Menu:
return self.back_option return self.back_option
return self.options.get(choice) return self.options.get(choice)
def _should_hooks_run(self, selected_option: Option) -> bool:
"""Determines if hooks should be run based on the selected option."""
return selected_option != self.back_option or self.run_hooks_on_back_option
def _should_run_action(self, selected_option: Option) -> bool: def _should_run_action(self, selected_option: Option) -> bool:
if selected_option.confirm and not self._never_confirm: if selected_option.confirm and not self._never_confirm:
return confirm(selected_option.confirm_message) return confirm(selected_option.confirm_message)
return True return True
def _create_context(self, selected_option: Option) -> dict[str, Any]:
"""Creates a context dictionary for the selected option."""
return {
"name": selected_option.description,
"option": selected_option,
"args": (),
"kwargs": {},
}
def _run_action_with_spinner(self, option: Option) -> Any: def _run_action_with_spinner(self, option: Option) -> Any:
"""Runs the action of the selected option with a spinner.""" """Runs the action of the selected option with a spinner."""
with self.console.status( with self.console.status(
@ -422,15 +313,13 @@ class Menu:
spinner_style=option.spinner_style, spinner_style=option.spinner_style,
**option.spinner_kwargs, **option.spinner_kwargs,
): ):
return option.action() return option()
def _handle_action_error(self, selected_option: Option, error: Exception) -> bool: def _handle_action_error(self, selected_option: Option, error: Exception) -> bool:
"""Handles errors that occur during the action of the selected option.""" """Handles errors that occur during the action of the selected option."""
logger.exception(f"Error executing '{selected_option.description}': {error}") logger.exception(f"Error executing '{selected_option.description}': {error}")
self.console.print(f"[{OneColors.DARK_RED}]An error occurred while executing " self.console.print(f"[{OneColors.DARK_RED}]An error occurred while executing "
f"{selected_option.description}:[/] {error}") f"{selected_option.description}:[/] {error}")
selected_option.on_error.run_hooks(selected_option, error)
self.on_error.run_hooks(selected_option, error)
if self.continue_on_error_prompt and not self._never_confirm: if self.continue_on_error_prompt and not self._never_confirm:
return confirm("An error occurred. Do you wish to continue?") return confirm("An error occurred. Do you wish to continue?")
if self._never_confirm: if self._never_confirm:
@ -442,25 +331,38 @@ class Menu:
choice = self.session.prompt() choice = self.session.prompt()
selected_option = self.get_option(choice) selected_option = self.get_option(choice)
self.last_run_option = selected_option self.last_run_option = selected_option
should_hooks_run = self._should_hooks_run(selected_option)
if selected_option == self.back_option:
logger.info(f"🔙 Back selected: exiting {self.get_title()}")
return False
if not self._should_run_action(selected_option): if not self._should_run_action(selected_option):
logger.info(f"[{OneColors.DARK_RED}] {selected_option.description} cancelled.") logger.info(f"[{OneColors.DARK_RED}] {selected_option.description} cancelled.")
return True return True
context = self._create_context(selected_option)
try: try:
if should_hooks_run: run_async(self.hooks.trigger("before", context))
self.before_action.run_hooks(selected_option)
selected_option.before_action.run_hooks(selected_option)
if selected_option.spinner: if selected_option.spinner:
result = self._run_action_with_spinner(selected_option) result = self._run_action_with_spinner(selected_option)
else: else:
result = selected_option.action() result = selected_option()
selected_option.set_result(result) selected_option.set_result(result)
selected_option.after_action.run_hooks(selected_option) context["result"] = result
if should_hooks_run: context["duration"] = selected_option.get_duration()
self.after_action.run_hooks(selected_option) run_async(self.hooks.trigger("after", context))
except Exception as error: except Exception as error:
context["exception"] = error
context["duration"] = selected_option.get_duration()
run_async(self.hooks.trigger("on_error", context))
if "exception" not in context:
logger.info(f"✅ Recovery hook handled error for '{selected_option.description}'")
return True
return self._handle_action_error(selected_option, error) return self._handle_action_error(selected_option, error)
return selected_option != self.back_option return True
def run_headless(self, option_key: str, never_confirm: bool | None = None) -> Any: def run_headless(self, option_key: str, never_confirm: bool | None = None) -> Any:
"""Runs the action of the selected option without displaying the menu.""" """Runs the action of the selected option without displaying the menu."""
@ -470,34 +372,45 @@ class Menu:
selected_option = self.get_option(option_key) selected_option = self.get_option(option_key)
self.last_run_option = selected_option self.last_run_option = selected_option
if not selected_option: if not selected_option:
raise MenuError(f"[Headless] Option '{option_key}' not found.") logger.info("[Headless] Back option selected. Exiting menu.")
return
logger.info(f"[Headless] 🚀 Running: '{selected_option.description}'") logger.info(f"[Headless] 🚀 Running: '{selected_option.description}'")
should_hooks_run = self._should_hooks_run(selected_option)
if not self._should_run_action(selected_option): if not self._should_run_action(selected_option):
logger.info(f"[Headless] ⛔ '{selected_option.description}' cancelled.")
raise MenuError(f"[Headless] '{selected_option.description}' cancelled by confirmation.") raise MenuError(f"[Headless] '{selected_option.description}' cancelled by confirmation.")
context = self._create_context(selected_option)
try: try:
if should_hooks_run: run_async(self.hooks.trigger("before", context))
self.before_action.run_hooks(selected_option)
selected_option.before_action.run_hooks(selected_option)
if selected_option.spinner: if selected_option.spinner:
result = self._run_action_with_spinner(selected_option) result = self._run_action_with_spinner(selected_option)
else: else:
result = selected_option.action() result = selected_option()
selected_option.set_result(result) selected_option.set_result(result)
selected_option.after_action.run_hooks(selected_option) context["result"] = result
if should_hooks_run: context["duration"] = selected_option.get_duration()
self.after_action.run_hooks(selected_option)
run_async(self.hooks.trigger("after", context))
logger.info(f"[Headless] ✅ '{selected_option.description}' complete.") logger.info(f"[Headless] ✅ '{selected_option.description}' complete.")
except (KeyboardInterrupt, EOFError): except (KeyboardInterrupt, EOFError):
raise MenuError(f"[Headless] ⚠️ '{selected_option.description}' interrupted by user.") raise MenuError(f"[Headless] ⚠️ '{selected_option.description}' interrupted by user.")
except Exception as error: except Exception as error:
context["exception"] = error
context["duration"] = selected_option.get_duration()
run_async(self.hooks.trigger("on_error", context))
if "exception" not in context:
logger.info(f"[Headless] ✅ Recovery hook handled error for '{selected_option.description}'")
return True
continue_running = self._handle_action_error(selected_option, error) continue_running = self._handle_action_error(selected_option, error)
if not continue_running: if not continue_running:
raise MenuError(f"[Headless] ❌ '{selected_option.description}' failed.") from error raise MenuError(f"[Headless] ❌ '{selected_option.description}' failed.") from error
return selected_option.get_result() return selected_option.get_result()
def run(self) -> None: def run(self) -> None:
@ -517,83 +430,3 @@ class Menu:
logger.info(f"Exiting menu: {self.get_title()}") logger.info(f"Exiting menu: {self.get_title()}")
if self.exit_message: if self.exit_message:
self.console.print(self.exit_message) self.console.print(self.exit_message)
if __name__ == "__main__":
from rich.traceback import install
from logging_utils import setup_logging
install(show_locals=True)
setup_logging()
def say_hello():
print("Hello!")
def say_goodbye():
print("Goodbye!")
def say_nested():
print("This is a nested menu!")
def my_action():
print("This is my action!")
def long_running_task():
import time
time.sleep(5)
nested_menu = Menu(
Markdown("## Nested Menu", style=OneColors.DARK_YELLOW),
columns=2,
bottom_bar="Menu within a menu",
)
nested_menu.add_option("1", "Say Nested", say_nested, color=OneColors.MAGENTA)
nested_menu.add_before(lambda opt: logger.info(f"Global BEFORE '{opt.description}'"))
nested_menu.add_after(lambda opt: logger.info(f"Global AFTER '{opt.description}'"))
nested_menu.add_option(
"2",
"Test Action",
action=my_action,
before_hooks=[lambda opt: logger.info(f"Option-specific BEFORE '{opt.description}'")],
after_hooks=[lambda opt: logger.info(f"Option-specific AFTER '{opt.description}'")],
)
def bottom_bar():
return (
f"Press Q to quit | Options available: {', '.join([f'[{key}]' for key in menu.options.keys()])}"
)
welcome_message = Markdown("# Welcome to the Menu!")
exit_message = Markdown("# Thank you for using the menu!")
menu = Menu(
Markdown("## Main Menu", style=OneColors.CYAN),
columns=3,
bottom_bar=bottom_bar,
welcome_message=welcome_message,
exit_message=exit_message,
)
menu.add_option("1", "Say Hello", say_hello, color=OneColors.GREEN)
menu.add_option("2", "Say Goodbye", say_goodbye, color=OneColors.LIGHT_RED)
menu.add_option("3", "Do Nothing", lambda: None, color=OneColors.BLUE)
menu.add_submenu("4", "Nested Menu", nested_menu, color=OneColors.MAGENTA)
menu.add_option("5", "Do Nothing", lambda: None, color=OneColors.BLUE)
menu.add_option(
"6",
"Long Running Task",
action=long_running_task,
spinner=True,
spinner_message="Working, please wait...",
spinner_type="moon",
spinner_style=OneColors.GREEN,
spinner_kwargs={"speed": 0.7},
)
menu.update_back_option("Q", "Quit", color=OneColors.DARK_RED)
try:
menu.run()
except EOFError as error:
logger.exception("EOFError: Exiting program.", exc_info=error)
print("Exiting program.")

79
menu/menu_utils.py Normal file
View File

@ -0,0 +1,79 @@
import asyncio
import time
from itertools import islice
def chunks(iterator, size):
"""Yield successive n-sized chunks from an iterator."""
iterator = iter(iterator)
while True:
chunk = list(islice(iterator, size))
if not chunk:
break
yield chunk
def run_async(coro):
"""Run an async function in a synchronous context."""
try:
_ = asyncio.get_running_loop()
return asyncio.create_task(coro)
except RuntimeError:
return asyncio.run(coro)
class TimingMixin:
def _start_timer(self):
self.start_time = time.perf_counter()
def _stop_timer(self):
self.end_time = time.perf_counter()
self._duration = self.end_time - self.start_time
def get_duration(self) -> float | None:
return getattr(self, "_duration", None)
class MenuError(Exception):
"""Custom exception for the Menu class."""
class OptionAlreadyExistsError(MenuError):
"""Exception raised when an option with the same key already exists in the menu."""
class InvalidHookError(MenuError):
"""Exception raised when a hook is not callable."""
class InvalidActionError(MenuError):
"""Exception raised when an action is not callable."""
class NotAMenuError(MenuError):
"""Exception raised when the provided submenu is not an instance of Menu."""
class CaseInsensitiveDict(dict):
"""A case-insensitive dictionary that treats all keys as uppercase."""
def __setitem__(self, key, value):
super().__setitem__(key.upper(), value)
def __getitem__(self, key):
return super().__getitem__(key.upper())
def __contains__(self, key):
return super().__contains__(key.upper())
def get(self, key, default=None):
return super().get(key.upper(), default)
def pop(self, key, default=None):
return super().pop(key.upper(), default)
def update(self, other=None, **kwargs):
if other:
other = {k.upper(): v for k, v in other.items()}
kwargs = {k.upper(): v for k, v in kwargs.items()}
super().update(other, **kwargs)

110
menu/option.py Normal file
View File

@ -0,0 +1,110 @@
"""option.py
Any Action or Option is callable and supports the signature:
result = thing(*args, **kwargs)
This guarantees:
- Hook lifecycle (before/after/error/teardown)
- Timing
- Consistent return values
"""
from __future__ import annotations
import asyncio
import logging
from typing import Any, Callable
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
from action import BaseAction
from colors import OneColors
from hook_manager import HookManager
from menu_utils import TimingMixin, run_async
logger = logging.getLogger("menu")
class Option(BaseModel, TimingMixin):
"""Class representing an option in the menu.
Hooks must have the signature:
def hook(option: Option) -> None:
where `option` is the selected option.
Error hooks must have the signature:
def error_hook(option: Option, error: Exception) -> None:
where `option` is the selected option and `error` is the exception raised.
"""
key: str
description: str
action: BaseAction | Callable[[], Any] = lambda: None
color: str = OneColors.WHITE
confirm: bool = False
confirm_message: str = "Are you sure?"
spinner: bool = False
spinner_message: str = "Processing..."
spinner_type: str = "dots"
spinner_style: str = OneColors.CYAN
spinner_kwargs: dict[str, Any] = Field(default_factory=dict)
hooks: "HookManager" = Field(default_factory=HookManager)
start_time: float | None = None
end_time: float | None = None
_duration: float | None = PrivateAttr(default=None)
_result: Any | None = PrivateAttr(default=None)
model_config = ConfigDict(arbitrary_types_allowed=True)
def __str__(self):
return f"Option(key='{self.key}', description='{self.description}')"
def set_result(self, result: Any) -> None:
"""Set the result of the action."""
self._result = result
def get_result(self) -> Any:
"""Get the result of the action."""
return self._result
def __call__(self, *args, **kwargs) -> Any:
context = {
"name": self.description,
"duration": None,
"args": args,
"kwargs": kwargs,
"option": self,
}
self._start_timer()
try:
run_async(self.hooks.trigger("before", context))
result = self._execute_action(*args, **kwargs)
self.set_result(result)
context["result"] = result
return result
except Exception as error:
context["exception"] = error
run_async(self.hooks.trigger("on_error", context))
if "exception" not in context:
logger.info(f"✅ Recovery hook handled error for Option '{self.key}'")
return self.get_result()
raise
finally:
self._stop_timer()
context["duration"] = self.get_duration()
if "exception" not in context:
run_async(self.hooks.trigger("after", context))
run_async(self.hooks.trigger("on_teardown", context))
def _execute_action(self, *args, **kwargs) -> Any:
if isinstance(self.action, BaseAction):
return self.action(*args, **kwargs)
return self.action()
def dry_run(self):
print(f"[DRY RUN] Option '{self.key}' would run: {self.description}")
if isinstance(self.action, BaseAction):
self.action.dry_run()
elif callable(self.action):
print(f"[DRY RUN] Action is a raw callable: {self.action.__name__}")
else:
print("[DRY RUN] Action is not callable.")

View File

@ -3,7 +3,7 @@ import time
def risky_task() -> str: def risky_task() -> str:
if random.random() > 0.25: if random.random() > 0.4:
time.sleep(1) time.sleep(1)
raise ValueError("Random failure occurred") raise ValueError("Random failure occurred")
return "Task succeeded!" return "Task succeeded!"