falyx/tests/test_actions.py

231 lines
6.7 KiB
Python

import pytest
import asyncio
import pickle
import warnings
from falyx.action import Action, ChainedAction, ActionGroup, ProcessAction
from falyx.execution_registry import ExecutionRegistry as er
from falyx.hook_manager import HookManager, HookType
from falyx.context import ExecutionContext, ResultsContext
asyncio_default_fixture_loop_scope = "function"
# --- Helpers ---
async def dummy_action(x: int = 0) -> int:
return x + 1
async def capturing_hook(context: ExecutionContext):
context.extra["hook_triggered"] = True
# --- Fixtures ---
@pytest.fixture
def sample_action():
return Action(name="increment", action=dummy_action, kwargs={"x": 5})
@pytest.fixture
def hook_manager():
hm = HookManager()
hm.register(HookType.BEFORE, capturing_hook)
return hm
@pytest.fixture(autouse=True)
def clean_registry():
er.clear()
yield
er.clear()
# --- Tests ---
@pytest.mark.asyncio
async def test_action_runs_correctly(sample_action):
result = await sample_action()
assert result == 6
@pytest.mark.asyncio
async def test_action_hook_lifecycle(hook_manager):
action = Action(
name="hooked",
action=lambda: 42,
hooks=hook_manager
)
await action()
context = er.get_latest()
assert context.name == "hooked"
assert context.extra.get("hook_triggered") is True
@pytest.mark.asyncio
async def test_chained_action_with_result_injection():
actions = [
Action(name="start", action=lambda: 1),
Action(name="add_last", action=lambda last_result: last_result + 5, inject_last_result=True),
Action(name="multiply", action=lambda last_result: last_result * 2, inject_last_result=True)
]
chain = ChainedAction(name="test_chain", actions=actions, inject_last_result=True)
result = await chain()
assert result == [1, 6, 12]
@pytest.mark.asyncio
async def test_action_group_runs_in_parallel():
actions = [
Action(name="a", action=lambda: 1),
Action(name="b", action=lambda: 2),
Action(name="c", action=lambda: 3),
]
group = ActionGroup(name="parallel", actions=actions)
result = await group()
result_dict = dict(result)
assert result_dict == {"a": 1, "b": 2, "c": 3}
@pytest.mark.asyncio
async def test_chained_action_inject_from_action():
inner_chain = ChainedAction(
name="inner_chain",
actions=[
Action(name="inner_first", action=lambda last_result: last_result + 10, inject_last_result=True),
Action(name="inner_second", action=lambda last_result: last_result + 5, inject_last_result=True),
]
)
actions = [
Action(name="first", action=lambda: 1),
Action(name="second", action=lambda last_result: last_result + 2, inject_last_result=True),
inner_chain,
]
outer_chain = ChainedAction(name="test_chain", actions=actions)
result = await outer_chain()
assert result == [1, 3, [13, 18]]
@pytest.mark.asyncio
async def test_chained_action_with_group():
group = ActionGroup(
name="group",
actions=[
Action(name="a", action=lambda last_result: last_result + 1, inject_last_result=True),
Action(name="b", action=lambda last_result: last_result + 2, inject_last_result=True),
Action(name="c", action=lambda: 3),
]
)
actions = [
Action(name="first", action=lambda: 1),
Action(name="second", action=lambda last_result: last_result + 2, inject_last_result=True),
group,
]
chain = ChainedAction(name="test_chain", actions=actions)
result = await chain()
assert result == [1, 3, [("a", 4), ("b", 5), ("c", 3)]]
@pytest.mark.asyncio
async def test_action_error_triggers_error_hook():
def fail():
raise ValueError("boom")
hooks = HookManager()
flag = {}
async def error_hook(ctx):
flag["called"] = True
hooks.register(HookType.ON_ERROR, error_hook)
action = Action(name="fail_action", action=fail, hooks=hooks)
with pytest.raises(ValueError):
await action()
assert flag.get("called") is True
@pytest.mark.asyncio
async def test_chained_action_rollback_on_failure():
rollback_called = []
async def success():
return "ok"
async def fail():
raise RuntimeError("fail")
async def rollback_fn():
rollback_called.append("rolled back")
actions = [
Action(name="ok", action=success, rollback=rollback_fn),
Action(name="fail", action=fail, rollback=rollback_fn)
]
chain = ChainedAction(name="chain", actions=actions)
with pytest.raises(RuntimeError):
await chain()
assert rollback_called == ["rolled back"]
def slow_add(x, y):
return x + y
@pytest.mark.asyncio
async def test_process_action_executes_correctly():
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
action = ProcessAction(name="proc", func=slow_add, args=(2, 3))
result = await action()
assert result == 5
unpickleable = lambda x: x + 1
@pytest.mark.asyncio
async def test_process_action_rejects_unpickleable():
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
action = ProcessAction(name="proc_fail", func=unpickleable, args=(2,))
with pytest.raises(pickle.PicklingError, match="Can't pickle"):
await action()
@pytest.mark.asyncio
async def test_register_hooks_recursively_propagates():
hook = lambda ctx: ctx.extra.update({"test_marker": True})
chain = ChainedAction(name="chain", actions=[
Action(name="a", action=lambda: 1),
Action(name="b", action=lambda: 2),
])
chain.register_hooks_recursively(HookType.BEFORE, hook)
await chain()
for ctx in er.get_by_name("a") + er.get_by_name("b"):
assert ctx.extra.get("test_marker") is True
@pytest.mark.asyncio
async def test_action_hook_recovers_error():
async def flaky():
raise ValueError("fail")
async def recovery_hook(ctx):
ctx.result = 99
ctx.exception = None
hooks = HookManager()
hooks.register(HookType.ON_ERROR, recovery_hook)
action = Action(name="recovering", action=flaky, hooks=hooks)
result = await action()
assert result == 99
@pytest.mark.asyncio
async def test_action_group_injects_last_result():
group = ActionGroup(name="group", actions=[
Action(name="g1", action=lambda last_result: last_result + 10, inject_last_result=True),
Action(name="g2", action=lambda last_result: last_result + 20, inject_last_result=True),
])
chain = ChainedAction(name="with_group", actions=[
Action(name="first", action=lambda: 5),
group,
])
result = await chain()
result_dict = dict(result[1])
assert result_dict == {"g1": 15, "g2": 25}