tenseleyflow/loader / 825fa06

Browse files

Add direct tests for context-owned state controllers

Authored by espadonne
SHA
825fa06edd3959ddb046e65fcd76956cd85845f7
Parents
6378cd3
Tree
c946622

1 changed file

StatusFile+-
A tests/test_runtime_state_controllers.py 167 0
tests/test_runtime_state_controllers.pyadded
@@ -0,0 +1,167 @@
1
+"""Direct tests for context-owned workflow and phase controllers."""
2
+
3
+from __future__ import annotations
4
+
5
+from pathlib import Path
6
+from types import SimpleNamespace
7
+
8
+import pytest
9
+
10
+from loader.llm.base import Message
11
+from loader.runtime.context import RuntimeContext
12
+from loader.runtime.dod import DefinitionOfDoneStore
13
+from loader.runtime.events import TurnSummary
14
+from loader.runtime.permissions import (
15
+    PermissionMode,
16
+    build_permission_policy,
17
+    load_permission_rules,
18
+)
19
+from loader.runtime.phases import TurnPhase, TurnPhaseTracker
20
+from loader.runtime.tracing import RuntimeTracer
21
+from loader.runtime.workflow import WorkflowDecisionKind, WorkflowMode
22
+from loader.runtime.workflow_policy import ModeDecision
23
+from loader.runtime.workflow_state import WorkflowStateController
24
+from loader.tools.base import create_default_registry
25
+from tests.helpers.runtime_harness import ScriptedBackend
26
+
27
+
28
+class FakeCodeFilter:
29
+    def reset(self) -> None:
30
+        return None
31
+
32
+
33
+class FakeSafeguards:
34
+    def __init__(self) -> None:
35
+        self.action_tracker = object()
36
+        self.validator = object()
37
+        self.code_filter = FakeCodeFilter()
38
+
39
+    def filter_stream_chunk(self, content: str) -> str:
40
+        return content
41
+
42
+    def filter_complete_content(self, content: str) -> str:
43
+        return content
44
+
45
+    def should_steer(self) -> bool:
46
+        return False
47
+
48
+    def get_steering_message(self) -> str | None:
49
+        return None
50
+
51
+    def record_response(self, content: str) -> None:
52
+        return None
53
+
54
+
55
+class FakeSession:
56
+    def __init__(self) -> None:
57
+        self.messages: list[Message] = []
58
+        self.workflow_timeline = []
59
+        self.workflow_mode: str | None = None
60
+        self.workflow_reason_code: str | None = None
61
+        self.active_turn_phase: str | None = None
62
+        self.last_turn_transition_summary: str | None = None
63
+        self.last_turn_transition_kind: str | None = None
64
+        self.last_turn_transition_reason_code: str | None = None
65
+
66
+    def update_runtime_state(self, **kwargs: object) -> None:
67
+        for key, value in kwargs.items():
68
+            setattr(self, key, value)
69
+
70
+    def append_workflow_timeline_entry(self, entry) -> None:
71
+        self.workflow_timeline.append(entry)
72
+
73
+    def append(self, message: Message) -> None:
74
+        self.messages.append(message)
75
+
76
+
77
+def build_context(temp_dir: Path) -> tuple[RuntimeContext, list[str], FakeSession]:
78
+    registry = create_default_registry(temp_dir)
79
+    registry.configure_workspace_root(temp_dir)
80
+    rule_status = load_permission_rules(temp_dir)
81
+    policy = build_permission_policy(
82
+        active_mode=PermissionMode.WORKSPACE_WRITE,
83
+        workspace_root=temp_dir,
84
+        tool_requirements=registry.get_tool_requirements(),
85
+        rules=rule_status.rules,
86
+    )
87
+    workflow_modes: list[str] = []
88
+    session = FakeSession()
89
+    context = RuntimeContext(
90
+        project_root=temp_dir,
91
+        backend=ScriptedBackend(),
92
+        registry=registry,
93
+        session=session,  # type: ignore[arg-type]
94
+        config=SimpleNamespace(force_react=False, stream=False),
95
+        capability_profile=SimpleNamespace(supports_native_tools=True),  # type: ignore[arg-type]
96
+        project_context=None,
97
+        permission_policy=policy,
98
+        permission_config_status=rule_status,
99
+        workflow_mode="execute",
100
+        safeguards=FakeSafeguards(),
101
+        prompt_format="native",
102
+        prompt_sections=["Workflow Context"],
103
+        set_workflow_mode_callback=workflow_modes.append,
104
+    )
105
+    return context, workflow_modes, session
106
+
107
+
108
+@pytest.mark.asyncio
109
+async def test_workflow_state_controller_runs_on_runtime_context(
110
+    temp_dir: Path,
111
+) -> None:
112
+    context, workflow_modes, session = build_context(temp_dir)
113
+    controller = WorkflowStateController(
114
+        context,
115
+        dod_store=DefinitionOfDoneStore(temp_dir),
116
+    )
117
+    dod = controller.dod_store.create_or_resume("Plan the runtime context migration.")
118
+    summary = TurnSummary(final_response="")
119
+    events = []
120
+
121
+    async def emit(event) -> None:
122
+        events.append(event)
123
+
124
+    decision = ModeDecision.transition(
125
+        WorkflowMode.PLAN,
126
+        reason_code="task_is_complex",
127
+        reason_summary="task complexity favors a plan first",
128
+        decision_kind=WorkflowDecisionKind.HANDOFF,
129
+    )
130
+
131
+    await controller.set_workflow_mode(
132
+        decision,
133
+        dod=dod,
134
+        emit=emit,
135
+        summary=summary,
136
+    )
137
+
138
+    assert workflow_modes == ["plan"]
139
+    assert context.workflow_mode == "plan"
140
+    assert session.workflow_mode == "plan"
141
+    assert summary.workflow_timeline[-1].prompt_format == "native"
142
+    assert any(event.type == "workflow_mode" for event in events)
143
+
144
+
145
+@pytest.mark.asyncio
146
+async def test_turn_phase_tracker_runs_on_runtime_context(
147
+    temp_dir: Path,
148
+) -> None:
149
+    context, _workflow_modes, session = build_context(temp_dir)
150
+    tracker = TurnPhaseTracker(context, RuntimeTracer())
151
+    events = []
152
+
153
+    async def emit(event) -> None:
154
+        events.append(event)
155
+
156
+    await tracker.enter(
157
+        TurnPhase.PREPARE,
158
+        emit,
159
+        detail="Preparing the turn",
160
+        reason_code="prepare_turn",
161
+    )
162
+    tracker.clear()
163
+
164
+    assert session.last_turn_transition_summary == "start -> prepare [normal] Preparing the turn"
165
+    assert session.last_turn_transition_reason_code == "prepare_turn"
166
+    assert session.active_turn_phase is None
167
+    assert events[0].turn_phase == "prepare"