Python · 7359 bytes Raw Blame History
1 """Tests for assistant-turn helpers running on RuntimeContext."""
2
3 from __future__ import annotations
4
5 from dataclasses import dataclass
6 from pathlib import Path
7 from types import SimpleNamespace
8
9 import pytest
10
11 from loader.llm.base import CompletionResponse, Message, Role, StreamChunk, ToolCall
12 from loader.runtime.assistant_turns import AssistantTurnRequester
13 from loader.runtime.capabilities import resolve_backend_capability_profile
14 from loader.runtime.context import RuntimeContext
15 from loader.runtime.events import AgentEvent
16 from loader.runtime.permissions import (
17 PermissionMode,
18 build_permission_policy,
19 load_permission_rules,
20 )
21 from loader.runtime.tracing import RuntimeTracer
22 from loader.tools.base import create_default_registry
23 from tests.helpers.runtime_harness import ScriptedBackend
24
25
26 @dataclass
27 class FakeCompaction:
28 removed_message_count: int
29 original_input_tokens: int
30 compressed_input_tokens: int
31
32
33 class FakeSession:
34 def __init__(
35 self,
36 *,
37 project_root: Path,
38 request_messages: list[Message],
39 compaction: FakeCompaction | None = None,
40 ) -> None:
41 self.messages = list(request_messages)
42 self.storage_path = project_root / ".loader" / "sessions" / "fake.json"
43 self._request_messages = list(request_messages)
44 self._compaction = compaction
45
46 def maybe_compact(self) -> FakeCompaction | None:
47 result = self._compaction
48 self._compaction = None
49 return result
50
51 def build_request_messages(self) -> list[Message]:
52 return list(self._request_messages)
53
54
55 class FakeCodeFilter:
56 def __init__(self) -> None:
57 self.reset_calls = 0
58
59 def reset(self) -> None:
60 self.reset_calls += 1
61
62
63 class FakeSafeguards:
64 def __init__(
65 self,
66 *,
67 stream_transform=str,
68 complete_transform=str,
69 steering_message: str | None = None,
70 ) -> None:
71 self.action_tracker = object()
72 self.validator = object()
73 self.code_filter = FakeCodeFilter()
74 self._stream_transform = stream_transform
75 self._complete_transform = complete_transform
76 self._steering_message = steering_message
77
78 def filter_stream_chunk(self, content: str) -> str:
79 return self._stream_transform(content)
80
81 def filter_complete_content(self, content: str) -> str:
82 return self._complete_transform(content)
83
84 def should_steer(self) -> bool:
85 return self._steering_message is not None
86
87 def get_steering_message(self) -> str | None:
88 return self._steering_message
89
90 def record_response(self, content: str) -> None:
91 return None
92
93 def detect_text_loop(self, content: str) -> tuple[bool, str]:
94 return False, ""
95
96 def detect_loop(self) -> tuple[bool, str]:
97 return False, ""
98
99
100 def build_runtime_context(
101 *,
102 temp_dir: Path,
103 backend: ScriptedBackend,
104 session: FakeSession,
105 safeguards: FakeSafeguards,
106 supports_native_tools: bool = True,
107 ) -> tuple[RuntimeContext, list[str]]:
108 registry = create_default_registry(temp_dir)
109 registry.configure_workspace_root(temp_dir)
110 rule_status = load_permission_rules(temp_dir)
111 policy = build_permission_policy(
112 active_mode=PermissionMode.WORKSPACE_WRITE,
113 workspace_root=temp_dir,
114 tool_requirements=registry.get_tool_requirements(),
115 rules=rule_status.rules,
116 )
117 queued_messages: list[str] = []
118 context = RuntimeContext(
119 project_root=temp_dir,
120 backend=backend,
121 registry=registry,
122 session=session, # type: ignore[arg-type]
123 config=SimpleNamespace(
124 force_react=not supports_native_tools,
125 stream=bool(backend._streams),
126 temperature=0.2,
127 ),
128 capability_profile=resolve_backend_capability_profile(backend),
129 project_context=None,
130 permission_policy=policy,
131 permission_config_status=rule_status,
132 workflow_mode="execute",
133 safeguards=safeguards,
134 queue_steering_message_callback=queued_messages.append,
135 )
136 return context, queued_messages
137
138
139 @pytest.mark.asyncio
140 async def test_assistant_turn_requester_streams_with_fake_runtime_context(
141 temp_dir: Path,
142 ) -> None:
143 pending_tool = ToolCall(id="call_read", name="read", arguments={"file_path": "README.md"})
144 backend = ScriptedBackend(
145 streams=[[
146 StreamChunk(content="hello", pending_tool_call=pending_tool),
147 StreamChunk(
148 content="",
149 full_content="hello",
150 tool_calls=[pending_tool],
151 is_done=True,
152 usage={"completion_tokens": 1},
153 ),
154 ]],
155 supports_native_tools=True,
156 )
157 session = FakeSession(
158 project_root=temp_dir,
159 request_messages=[Message(role=Role.USER, content="Read the file")],
160 compaction=FakeCompaction(
161 removed_message_count=2,
162 original_input_tokens=120,
163 compressed_input_tokens=50,
164 ),
165 )
166 context, queued_messages = build_runtime_context(
167 temp_dir=temp_dir,
168 backend=backend,
169 session=session,
170 safeguards=FakeSafeguards(stream_transform=str.upper),
171 supports_native_tools=True,
172 )
173 requester = AssistantTurnRequester(context, RuntimeTracer())
174 events: list[AgentEvent] = []
175
176 async def emit(event: AgentEvent) -> None:
177 events.append(event)
178
179 turn = await requester.request_turn(emit=emit, max_tokens=256)
180
181 assert turn.content == "hello"
182 assert turn.response_content == "hello"
183 assert turn.tool_calls == [pending_tool]
184 assert turn.pending_tool_calls_seen == {"call_read"}
185 assert turn.usage == {"completion_tokens": 1}
186 assert queued_messages == []
187 assert context.safeguards.code_filter.reset_calls == 1
188 assert [event.type for event in events] == ["artifact", "stream", "tool_call", "stream"]
189 assert events[1].content == "HELLO"
190 assert events[2].tool_name == "read"
191
192
193 @pytest.mark.asyncio
194 async def test_assistant_turn_requester_completes_with_fake_runtime_context(
195 temp_dir: Path,
196 ) -> None:
197 backend = ScriptedBackend(
198 completions=[CompletionResponse(content="final answer", usage={"completion_tokens": 2})],
199 supports_native_tools=False,
200 )
201 session = FakeSession(
202 project_root=temp_dir,
203 request_messages=[Message(role=Role.USER, content="Explain the issue")],
204 )
205 context, queued_messages = build_runtime_context(
206 temp_dir=temp_dir,
207 backend=backend,
208 session=session,
209 safeguards=FakeSafeguards(
210 complete_transform=lambda content: f"FILTERED: {content}",
211 steering_message="Stay on the task at hand.",
212 ),
213 supports_native_tools=False,
214 )
215 requester = AssistantTurnRequester(context, RuntimeTracer())
216
217 async def emit(_: AgentEvent) -> None:
218 return None
219
220 turn = await requester.request_turn(emit=emit, max_tokens=256)
221
222 assert turn.content == "FILTERED: final answer"
223 assert turn.response_content == "final answer"
224 assert turn.tool_calls == []
225 assert turn.usage == {"completion_tokens": 2}
226 assert queued_messages == ["Stay on the task at hand."]
227 assert context.safeguards.code_filter.reset_calls == 1