466 lines
14 KiB
Python
466 lines
14 KiB
Python
import pytest
|
|
|
|
from falyx.action import Action, ChainedAction, ActionGroup, FallbackAction
|
|
from falyx.execution_registry import ExecutionRegistry as er
|
|
from falyx.hook_manager import HookManager, HookType
|
|
from falyx.context import ExecutionContext
|
|
|
|
asyncio_default_fixture_loop_scope = "function"
|
|
|
|
# --- Helpers ---
|
|
async def capturing_hook(context: ExecutionContext):
|
|
context.extra["hook_triggered"] = True
|
|
|
|
# --- Fixtures ---
|
|
@pytest.fixture
|
|
def hook_manager():
|
|
hm = HookManager()
|
|
hm.register(HookType.BEFORE, capturing_hook)
|
|
return hm
|
|
|
|
@pytest.fixture(autouse=True)
|
|
def clean_registry():
|
|
er.clear()
|
|
yield
|
|
er.clear()
|
|
|
|
# --- Tests ---
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_action_runs_correctly():
|
|
async def dummy_action(x: int = 0) -> int: return x + 1
|
|
sample_action = Action(name="increment", action=dummy_action, kwargs={"x": 5})
|
|
result = await sample_action()
|
|
assert result == 6
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_action_hook_lifecycle(hook_manager):
|
|
async def a1(): return 42
|
|
action = Action(
|
|
name="hooked",
|
|
action=a1,
|
|
hooks=hook_manager
|
|
)
|
|
|
|
await action()
|
|
|
|
context = er.get_latest()
|
|
assert context.name == "hooked"
|
|
assert context.extra.get("hook_triggered") is True
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_chained_action_with_result_injection():
|
|
async def a1(): return 1
|
|
async def a2(last_result): return last_result + 5
|
|
async def a3(last_result): return last_result * 2
|
|
actions = [
|
|
Action(name="start", action=a1),
|
|
Action(name="add_last", action=a2, inject_last_result=True),
|
|
Action(name="multiply", action=a3, inject_last_result=True)
|
|
]
|
|
chain = ChainedAction(name="test_chain", actions=actions, inject_last_result=True, return_list=True)
|
|
result = await chain()
|
|
assert result == [1, 6, 12]
|
|
chain = ChainedAction(name="test_chain", actions=actions, inject_last_result=True)
|
|
result = await chain()
|
|
assert result == 12
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_action_group_runs_in_parallel():
|
|
async def a1(): return 1
|
|
async def a2(): return 2
|
|
async def a3(): return 3
|
|
actions = [
|
|
Action(name="a", action=a1),
|
|
Action(name="b", action=a2),
|
|
Action(name="c", action=a3),
|
|
]
|
|
group = ActionGroup(name="parallel", actions=actions)
|
|
result = await group()
|
|
result_dict = dict(result)
|
|
assert result_dict == {"a": 1, "b": 2, "c": 3}
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_chained_action_inject_from_action():
|
|
async def a1(last_result): return last_result + 10
|
|
async def a2(last_result): return last_result + 5
|
|
inner_chain = ChainedAction(
|
|
name="inner_chain",
|
|
actions=[
|
|
Action(name="inner_first", action=a1, inject_last_result=True),
|
|
Action(name="inner_second", action=a2, inject_last_result=True),
|
|
],
|
|
return_list=True,
|
|
)
|
|
async def a3(): return 1
|
|
async def a4(last_result): return last_result + 2
|
|
actions = [
|
|
Action(name="first", action=a3),
|
|
Action(name="second", action=a4, inject_last_result=True),
|
|
inner_chain,
|
|
]
|
|
outer_chain = ChainedAction(name="test_chain", actions=actions, return_list=True)
|
|
result = await outer_chain()
|
|
assert result == [1, 3, [13, 18]]
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_chained_action_with_group():
|
|
async def a1(last_result): return last_result + 1
|
|
async def a2(last_result): return last_result + 2
|
|
async def a3(): return 3
|
|
group = ActionGroup(
|
|
name="group",
|
|
actions=[
|
|
Action(name="a", action=a1, inject_last_result=True),
|
|
Action(name="b", action=a2, inject_last_result=True),
|
|
Action(name="c", action=a3),
|
|
]
|
|
)
|
|
async def a4(): return 1
|
|
async def a5(last_result): return last_result + 2
|
|
actions = [
|
|
Action(name="first", action=a4),
|
|
Action(name="second", action=a5, inject_last_result=True),
|
|
group,
|
|
]
|
|
chain = ChainedAction(name="test_chain", actions=actions, return_list=True)
|
|
result = await chain()
|
|
assert result == [1, 3, [("a", 4), ("b", 5), ("c", 3)]]
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_action_error_triggers_error_hook():
|
|
def fail():
|
|
raise ValueError("boom")
|
|
|
|
hooks = HookManager()
|
|
flag = {}
|
|
|
|
async def error_hook(ctx):
|
|
flag["called"] = True
|
|
|
|
hooks.register(HookType.ON_ERROR, error_hook)
|
|
action = Action(name="fail_action", action=fail, hooks=hooks)
|
|
|
|
with pytest.raises(ValueError):
|
|
await action()
|
|
|
|
assert flag.get("called") is True
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_chained_action_rollback_on_failure():
|
|
rollback_called = []
|
|
|
|
async def success():
|
|
return "ok"
|
|
|
|
async def fail():
|
|
raise RuntimeError("fail")
|
|
|
|
async def rollback_fn():
|
|
rollback_called.append("rolled back")
|
|
|
|
actions = [
|
|
Action(name="ok", action=success, rollback=rollback_fn),
|
|
Action(name="fail", action=fail, rollback=rollback_fn)
|
|
]
|
|
|
|
chain = ChainedAction(name="chain", actions=actions)
|
|
|
|
with pytest.raises(RuntimeError):
|
|
await chain()
|
|
|
|
assert rollback_called == ["rolled back"]
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_register_hooks_recursively_propagates():
|
|
def hook(context):
|
|
context.extra.update({"test_marker": True})
|
|
|
|
async def a1(): return 1
|
|
async def a2(): return 2
|
|
|
|
chain = ChainedAction(
|
|
name="chain",
|
|
actions=[
|
|
Action(name="a", action=a1),
|
|
Action(name="b", action=a2),
|
|
],
|
|
)
|
|
chain.register_hooks_recursively(HookType.BEFORE, hook)
|
|
|
|
await chain()
|
|
|
|
for ctx in er.get_by_name("a") + er.get_by_name("b"):
|
|
assert ctx.extra.get("test_marker") is True
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_action_hook_recovers_error():
|
|
async def flaky():
|
|
raise ValueError("fail")
|
|
|
|
async def recovery_hook(ctx):
|
|
ctx.result = 99
|
|
ctx.exception = None
|
|
|
|
hooks = HookManager()
|
|
hooks.register(HookType.ON_ERROR, recovery_hook)
|
|
action = Action(name="recovering", action=flaky, hooks=hooks)
|
|
|
|
result = await action()
|
|
assert result == 99
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_action_group_injects_last_result():
|
|
async def a1(last_result): return last_result + 10
|
|
async def a2(last_result): return last_result + 20
|
|
group = ActionGroup(name="group", actions=[
|
|
Action(name="g1", action=a1, inject_last_result=True),
|
|
Action(name="g2", action=a2, inject_last_result=True),
|
|
])
|
|
async def a3(): return 5
|
|
chain = ChainedAction(
|
|
name="with_group",
|
|
actions=[
|
|
Action(name="first", action=a3),
|
|
group,
|
|
],
|
|
return_list=True,
|
|
)
|
|
result = await chain()
|
|
result_dict = dict(result[1])
|
|
assert result_dict == {"g1": 15, "g2": 25}
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_action_inject_last_result():
|
|
async def a1(): return 1
|
|
async def a2(last_result): return last_result + 1
|
|
a1 = Action(name="a1", action=a1)
|
|
a2 = Action(name="a2", action=a2, inject_last_result=True)
|
|
chain = ChainedAction(name="chain", actions=[a1, a2])
|
|
result = await chain()
|
|
assert result == 2
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_action_inject_last_result_fail():
|
|
async def a1(): return 1
|
|
async def a2(last_result): return last_result + 1
|
|
a1 = Action(name="a1", action=a1)
|
|
a2 = Action(name="a2", action=a2)
|
|
chain = ChainedAction(name="chain", actions=[a1, a2])
|
|
|
|
with pytest.raises(TypeError) as exc_info:
|
|
await chain()
|
|
|
|
assert "last_result" in str(exc_info.value)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_chained_action_auto_inject():
|
|
async def a1(): return 1
|
|
async def a2(last_result): return last_result + 2
|
|
a1 = Action(name="a1", action=a1)
|
|
a2 = Action(name="a2", action=a2)
|
|
chain = ChainedAction(name="chain", actions=[a1, a2], auto_inject=True, return_list=True)
|
|
result = await chain()
|
|
assert result == [1, 3] # a2 receives last_result=1
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_chained_action_no_auto_inject():
|
|
async def a1(): return 1
|
|
async def a2(): return 2
|
|
a1 = Action(name="a1", action=a1)
|
|
a2 = Action(name="a2", action=a2)
|
|
chain = ChainedAction(name="no_inject", actions=[a1, a2], auto_inject=False, return_list=True)
|
|
result = await chain()
|
|
assert result == [1, 2] # a2 does not receive 1
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_chained_action_auto_inject_after_first():
|
|
async def a1(): return 1
|
|
async def a2(last_result): return last_result + 1
|
|
a1 = Action(name="a1", action=a1)
|
|
a2 = Action(name="a2", action=a2)
|
|
chain = ChainedAction(name="auto_inject", actions=[a1, a2], auto_inject=True)
|
|
result = await chain()
|
|
assert result == 2 # a2 receives last_result=1
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_chained_action_with_literal_input():
|
|
async def a1(last_result): return last_result + " world"
|
|
a1 = Action(name="a1", action=a1)
|
|
chain = ChainedAction(name="literal_inject", actions=["hello", a1], auto_inject=True)
|
|
result = await chain()
|
|
assert result == "hello world" # "hello" is injected as last_result
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_chained_action_manual_inject_override():
|
|
async def a1(): return 10
|
|
async def a2(last_result): return last_result * 2
|
|
a1 = Action(name="a1", action=a1)
|
|
a2 = Action(name="a2", action=a2, inject_last_result=True)
|
|
chain = ChainedAction(name="manual_override", actions=[a1, a2], auto_inject=False)
|
|
result = await chain()
|
|
assert result == 20 # Even without auto_inject, a2 still gets last_result
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_chained_action_with_mid_literal():
|
|
async def fetch_data():
|
|
# Imagine this is some dynamic API call
|
|
return None # Simulate failure or missing data
|
|
|
|
async def validate_data(last_result):
|
|
if last_result is None:
|
|
raise ValueError("Missing data!")
|
|
return last_result
|
|
|
|
async def enrich_data(last_result):
|
|
return f"Enriched: {last_result}"
|
|
|
|
chain = ChainedAction(
|
|
name="fallback_pipeline",
|
|
actions=[
|
|
Action(name="FetchData", action=fetch_data),
|
|
"default_value", # <-- literal fallback injected mid-chain
|
|
Action(name="ValidateData", action=validate_data),
|
|
Action(name="EnrichData", action=enrich_data),
|
|
],
|
|
auto_inject=True,
|
|
return_list=True,
|
|
)
|
|
|
|
result = await chain()
|
|
assert result == [None, "default_value", "default_value", "Enriched: default_value"]
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_chained_action_with_mid_fallback():
|
|
async def fetch_data():
|
|
# Imagine this is some dynamic API call
|
|
return None # Simulate failure or missing data
|
|
|
|
async def validate_data(last_result):
|
|
if last_result is None:
|
|
raise ValueError("Missing data!")
|
|
return last_result
|
|
|
|
async def enrich_data(last_result):
|
|
return f"Enriched: {last_result}"
|
|
|
|
chain = ChainedAction(
|
|
name="fallback_pipeline",
|
|
actions=[
|
|
Action(name="FetchData", action=fetch_data),
|
|
FallbackAction(fallback="default_value"),
|
|
Action(name="ValidateData", action=validate_data),
|
|
Action(name="EnrichData", action=enrich_data),
|
|
],
|
|
auto_inject=True,
|
|
return_list=True,
|
|
)
|
|
|
|
result = await chain()
|
|
assert result == [None, "default_value", "default_value", "Enriched: default_value"]
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_chained_action_with_success_mid_fallback():
|
|
async def fetch_data():
|
|
# Imagine this is some dynamic API call
|
|
return "Result" # Simulate success
|
|
|
|
async def validate_data(last_result):
|
|
if last_result is None:
|
|
raise ValueError("Missing data!")
|
|
return last_result
|
|
|
|
async def enrich_data(last_result):
|
|
return f"Enriched: {last_result}"
|
|
|
|
chain = ChainedAction(
|
|
name="fallback_pipeline",
|
|
actions=[
|
|
Action(name="FetchData", action=fetch_data),
|
|
FallbackAction(fallback="default_value"),
|
|
Action(name="ValidateData", action=validate_data),
|
|
Action(name="EnrichData", action=enrich_data),
|
|
],
|
|
auto_inject=True,
|
|
return_list=True,
|
|
)
|
|
|
|
result = await chain()
|
|
assert result == ["Result", "Result", "Result", "Enriched: Result"]
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_action_group_partial_failure():
|
|
async def succeed(): return "ok"
|
|
async def fail(): raise ValueError("oops")
|
|
|
|
group = ActionGroup(name="partial_group", actions=[
|
|
Action(name="succeed_action", action=succeed),
|
|
Action(name="fail_action", action=fail),
|
|
])
|
|
|
|
with pytest.raises(Exception) as exc_info:
|
|
await group()
|
|
|
|
assert er.get_by_name("succeed_action")[0].result == "ok"
|
|
assert er.get_by_name("fail_action")[0].exception is not None
|
|
assert "fail_action" in str(exc_info.value)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_chained_action_with_nested_group():
|
|
async def g1(last_result): return last_result + "10"
|
|
async def g2(last_result): return last_result + "20"
|
|
group = ActionGroup(
|
|
name="nested_group",
|
|
actions=[
|
|
Action(name="g1", action=g1, inject_last_result=True),
|
|
Action(name="g2", action=g2, inject_last_result=True),
|
|
],
|
|
)
|
|
|
|
chain = ChainedAction(
|
|
name="chain_with_group",
|
|
actions=[
|
|
"start",
|
|
group,
|
|
],
|
|
auto_inject=True,
|
|
return_list=True,
|
|
)
|
|
|
|
result = await chain()
|
|
# "start" -> group both receive "start" as last_result
|
|
assert result[0] == "start"
|
|
assert dict(result[1]) == {"g1": "start10", "g2": "start20"} # Assuming string concatenation for example
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_chained_action_double_fallback():
|
|
async def fetch_data(last_result=None):
|
|
raise ValueError("No data!") # Simulate failure
|
|
|
|
async def validate_data(last_result):
|
|
if last_result is None:
|
|
raise ValueError("No data!")
|
|
return last_result
|
|
|
|
async def enrich(last_result):
|
|
return f"Enriched: {last_result}"
|
|
|
|
chain = ChainedAction(
|
|
name="fallback_chain",
|
|
actions=[
|
|
Action(name="Fetch", action=fetch_data),
|
|
FallbackAction(fallback="default1"),
|
|
Action(name="Validate", action=validate_data),
|
|
Action(name="Fetch", action=fetch_data),
|
|
FallbackAction(fallback="default2"),
|
|
Action(name="Enrich", action=enrich),
|
|
],
|
|
auto_inject=True,
|
|
return_list=True,
|
|
)
|
|
|
|
result = await chain()
|
|
assert result == [None, "default1", "default1", None, "default2", "Enriched: default2"]
|
|
|