| """ |
| Tests for the Workflow Engine core components. |
| """ |
|
|
| import pytest |
| import asyncio |
| from typing import Dict, Any |
|
|
| from app.engine.state import WorkflowState, StateManager |
| from app.engine.node import Node, NodeType, node, create_node_from_function |
| from app.engine.graph import Graph, END |
| from app.engine.executor import Executor, ExecutionStatus, execute_graph |
|
|
|
|
| |
| |
| |
|
|
class TestWorkflowState:
    """Unit tests for constructing, updating and serializing WorkflowState."""

    def test_create_empty_state(self):
        """A default-constructed state has no data, no iterations and no visits."""
        blank = WorkflowState()
        assert blank.data == {}
        assert blank.iteration == 0
        assert blank.visited_nodes == []

    def test_create_state_with_data(self):
        """Seeded data is readable via get(); misses yield None or the given default."""
        seeded = WorkflowState(data={"key": "value"})
        assert seeded.get("key") == "value"
        assert seeded.get("missing") is None
        assert seeded.get("missing", "default") == "default"

    def test_state_immutability(self):
        """set() returns a fresh instance and leaves the source state unchanged."""
        before = WorkflowState(data={"a": 1})
        after = before.set("b", 2)

        assert before.get("b") is None
        assert after.get("b") == 2
        assert before is not after

    def test_state_update_multiple(self):
        """update() merges several keys while keeping the existing ones."""
        base = WorkflowState(data={"a": 1})
        merged = base.update({"b": 2, "c": 3})

        assert merged.get("a") == 1
        assert merged.get("b") == 2
        assert merged.get("c") == 3

    def test_state_mark_visited(self):
        """mark_visited() records each node and tracks the latest as current."""
        tracked = WorkflowState()
        for visited in ("node1", "node2"):
            tracked = tracked.mark_visited(visited)

        assert "node1" in tracked.visited_nodes
        assert "node2" in tracked.visited_nodes
        assert tracked.current_node == "node2"

    def test_state_to_from_dict(self):
        """A state survives a to_dict() / from_dict() round trip."""
        source = WorkflowState(data={"test": 123})
        payload = source.to_dict()

        assert "data" in payload
        assert payload["data"]["test"] == 123

        rebuilt = WorkflowState.from_dict(payload)
        assert rebuilt.get("test") == 123
|
|
|
|
class TestStateManager:
    """Tests for StateManager initialization, updates and history tracking."""

    def test_initialize(self):
        """initialize() seeds both the returned state and the manager's current state."""
        manager = StateManager()
        state = manager.initialize({"input": "test"})

        # Previously the returned state was bound but never checked; verify the
        # caller-facing return value carries the initial input too.
        assert state.get("input") == "test"
        assert manager.current_state is not None
        assert manager.current_state.get("input") == "test"
        assert manager.current_state.started_at is not None

    def test_update_and_history(self):
        """Each update() records a history entry tagged with the producing node."""
        manager = StateManager()
        state = manager.initialize({"count": 0})

        new_state = state.set("count", 1)
        manager.update(new_state, "node1")

        assert len(manager.history) == 1
        assert manager.history[0].node_name == "node1"
        assert manager.current_state.get("count") == 1
|
|
|
|
| |
| |
| |
|
|
class TestNode:
    """Unit tests for Node construction, validation, execution and the @node decorator."""

    def test_create_node(self):
        """A plain node keeps its name and handler and defaults to STANDARD type."""
        def identity(state):
            return state

        built = Node(name="test_node", handler=identity)

        assert built.name == "test_node"
        assert built.handler == identity
        assert built.node_type == NodeType.STANDARD

    def test_node_validation(self):
        """Empty names and non-callable handlers are rejected with ValueError."""
        with pytest.raises(ValueError, match="name cannot be empty"):
            Node(name="", handler=lambda x: x)

        with pytest.raises(ValueError, match="must be callable"):
            Node(name="test", handler="not a function")

    @pytest.mark.asyncio
    async def test_sync_node_execution(self):
        """A synchronous handler runs through execute() and its output is returned."""
        def mark_processed(state):
            state["processed"] = True
            return state

        sync_node = Node(name="test", handler=mark_processed)
        output = await sync_node.execute({"input": "data"})

        assert output["processed"] is True
        assert output["input"] == "data"

    @pytest.mark.asyncio
    async def test_async_node_execution(self):
        """A coroutine handler is detected as async and awaited by execute()."""
        async def mark_async(state):
            await asyncio.sleep(0.01)
            state["async_processed"] = True
            return state

        async_node = Node(name="async_test", handler=mark_async)
        assert async_node.is_async is True

        output = await async_node.execute({"input": "data"})
        assert output["async_processed"] is True

    def test_node_decorator(self):
        """@node attaches its metadata to the decorated function."""
        @node(name="decorated_node", description="A test node")
        def decorated(state):
            return state

        assert hasattr(decorated, "_node_metadata")
        assert decorated._node_metadata["name"] == "decorated_node"
|
|
|
|
| |
| |
| |
|
|
class TestGraph:
    """Tests for Graph construction, edges, validation and Mermaid export."""

    def test_create_graph(self):
        """A new graph keeps its name and starts with no nodes."""
        graph = Graph(name="Test Graph")
        assert graph.name == "Test Graph"
        assert len(graph.nodes) == 0

    def test_add_nodes(self):
        """Added nodes are registered and the first one becomes the entry point."""
        graph = Graph()
        graph.add_node("node1", handler=lambda s: s)
        graph.add_node("node2", handler=lambda s: s)

        assert "node1" in graph.nodes
        assert "node2" in graph.nodes
        assert graph.entry_point == "node1"

    def test_add_edges(self):
        """A direct edge is stored as source -> target."""
        graph = Graph()
        graph.add_node("a", handler=lambda s: s)
        graph.add_node("b", handler=lambda s: s)
        graph.add_edge("a", "b")

        assert graph.edges["a"] == "b"

    def test_add_edge_to_end(self):
        """An edge may target the END sentinel."""
        graph = Graph()
        graph.add_node("a", handler=lambda s: s)
        graph.add_edge("a", END)

        assert graph.edges["a"] == END

    def test_invalid_edge(self):
        """An edge to an unknown node raises ValueError."""
        graph = Graph()
        graph.add_node("a", handler=lambda s: s)

        with pytest.raises(ValueError, match="not found"):
            graph.add_edge("a", "nonexistent")

    def test_conditional_edge(self):
        """get_next_node() follows the condition's result through the mapping."""
        graph = Graph()
        graph.add_node("check", handler=lambda s: s)
        graph.add_node("yes", handler=lambda s: s)
        graph.add_node("no", handler=lambda s: s)

        def condition(state):
            return "yes" if state.get("value") else "no"

        graph.add_conditional_edge("check", condition, {"yes": "yes", "no": "no"})

        assert graph.get_next_node("check", {"value": True}) == "yes"
        assert graph.get_next_node("check", {"value": False}) == "no"

    def test_graph_validation(self):
        """An empty graph is invalid; a single node wired to END is valid."""
        graph = Graph()

        # No nodes at all: validation must report something.
        errors = graph.validate()
        assert len(errors) > 0

        # Minimal well-formed graph.
        graph.add_node("start", handler=lambda s: s)
        graph.add_edge("start", END)

        errors = graph.validate()
        assert len(errors) == 0

    def test_mermaid_generation(self):
        """to_mermaid() emits a top-down graph that mentions every node."""
        import re

        graph = Graph()
        graph.add_node("a", handler=lambda s: s)
        graph.add_node("b", handler=lambda s: s)
        graph.add_edge("a", "b")
        graph.add_edge("b", END)

        mermaid = graph.to_mermaid()

        assert "graph TD" in mermaid
        # A plain `"a" in mermaid` check is vacuous because the header
        # "graph TD" already contains an "a"; require each node name to
        # appear as a standalone token instead.
        assert re.search(r"\ba\b", mermaid), mermaid
        assert re.search(r"\bb\b", mermaid), mermaid
|
|
|
|
| |
| |
| |
|
|
class TestExecutor:
    """Tests for graph execution: sequencing, branching, loops, limits and errors."""

    @pytest.mark.asyncio
    async def test_simple_execution(self):
        """A single node transforms the state and the run completes."""
        graph = Graph()
        graph.add_node("double", handler=lambda s: {**s, "value": s["value"] * 2})
        graph.add_edge("double", END)

        result = await execute_graph(graph, {"value": 5})

        assert result.status == ExecutionStatus.COMPLETED, result.error
        assert result.final_state["value"] == 10

    @pytest.mark.asyncio
    async def test_multi_node_execution(self):
        """Chained nodes run in order and each produces one log entry."""
        graph = Graph()
        graph.add_node("add1", handler=lambda s: {**s, "value": s["value"] + 1})
        graph.add_node("add2", handler=lambda s: {**s, "value": s["value"] + 2})
        graph.add_edge("add1", "add2")
        graph.add_edge("add2", END)

        result = await execute_graph(graph, {"value": 0})

        assert result.status == ExecutionStatus.COMPLETED, result.error
        assert result.final_state["value"] == 3
        assert len(result.execution_log) == 2

    @pytest.mark.asyncio
    async def test_conditional_execution(self):
        """A conditional edge routes to the branch selected by the router."""
        graph = Graph()
        graph.add_node("start", handler=lambda s: s)
        graph.add_node("high", handler=lambda s: {**s, "path": "high"})
        graph.add_node("low", handler=lambda s: {**s, "path": "low"})

        def route(state):
            return "high" if state["value"] > 5 else "low"

        graph.add_conditional_edge("start", route, {"high": "high", "low": "low"})
        graph.add_edge("high", END)
        graph.add_edge("low", END)

        # Value above the threshold takes the "high" branch.
        result = await execute_graph(graph, {"value": 10})
        assert result.final_state["path"] == "high"

        # Value at or below the threshold takes the "low" branch.
        result = await execute_graph(graph, {"value": 2})
        assert result.final_state["path"] == "low"

    @pytest.mark.asyncio
    async def test_loop_execution(self):
        """A self-loop repeats until the condition routes to END."""
        graph = Graph(max_iterations=10)

        def increment(state):
            return {**state, "count": state["count"] + 1}

        def check_count(state):
            return "done" if state["count"] >= 3 else "continue"

        graph.add_node("increment", handler=increment)
        graph.add_conditional_edge("increment", check_count, {"done": END, "continue": "increment"})

        result = await execute_graph(graph, {"count": 0})

        assert result.status == ExecutionStatus.COMPLETED, result.error
        assert result.final_state["count"] == 3

    @pytest.mark.asyncio
    async def test_max_iterations(self):
        """An unbounded loop is stopped by max_iterations and reported as FAILED."""
        graph = Graph(max_iterations=3)

        # The router always says "continue", so this never reaches END on its own.
        graph.add_node("loop", handler=lambda s: s)
        graph.add_conditional_edge("loop", lambda s: "continue", {"continue": "loop"})

        result = await execute_graph(graph, {})

        assert result.status == ExecutionStatus.FAILED
        assert "Max iterations" in result.error

    @pytest.mark.asyncio
    async def test_error_handling(self):
        """An exception in a handler marks the run FAILED and surfaces its message."""
        def failing_handler(state):
            raise ValueError("Intentional error")

        graph = Graph()
        graph.add_node("fail", handler=failing_handler)

        result = await execute_graph(graph, {})

        assert result.status == ExecutionStatus.FAILED
        assert "Intentional error" in result.error

    @pytest.mark.asyncio
    async def test_execution_log(self):
        """The execution log lists each executed node in order with a timing."""
        graph = Graph()
        graph.add_node("step1", handler=lambda s: s)
        graph.add_node("step2", handler=lambda s: s)
        graph.add_edge("step1", "step2")
        graph.add_edge("step2", END)

        result = await execute_graph(graph, {})

        assert len(result.execution_log) == 2
        assert result.execution_log[0].node == "step1"
        assert result.execution_log[1].node == "step2"
        # Identity handlers can finish within the clock's resolution (or be
        # rounded to 0 ms), so a strict `> 0` makes this test flaky; only
        # require a recorded, non-negative duration.
        assert all(s.duration_ms >= 0 for s in result.execution_log)
|
|
|
|
| |
| |
| |
|
|
class TestCodeReviewWorkflow:
    """Integration tests for the Code Review workflow."""

    @pytest.mark.asyncio
    async def test_code_review_workflow(self):
        """Run the full code review workflow end to end on a small sample module."""
        import textwrap

        from app.workflows.code_review import create_code_review_workflow

        # Dedent so the sample is valid module-level Python regardless of how
        # far this literal is indented inside the test method.
        sample_code = textwrap.dedent('''
            def hello():
                """Says hello."""
                print("Hello, World!")

            def add(a, b):
                return a + b
            ''')

        workflow = create_code_review_workflow(max_iterations=3, quality_threshold=5.0)
        result = await execute_graph(workflow, {
            "code": sample_code,
            "quality_threshold": 5.0,
        })

        # Surface the executor's error message if the run did not complete.
        assert result.status == ExecutionStatus.COMPLETED, result.error
        assert "functions" in result.final_state
        assert "quality_score" in result.final_state
        assert len(result.execution_log) > 0
|
|
|
|
if __name__ == "__main__":
    # pytest.main returns the session exit code; propagate it so running this
    # file directly reports test failures to the shell instead of exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
|