| 1 | """Tests for assistant-turn helpers running on RuntimeContext.""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | from dataclasses import dataclass |
| 6 | from pathlib import Path |
| 7 | from types import SimpleNamespace |
| 8 | |
| 9 | import pytest |
| 10 | |
| 11 | from loader.llm.base import CompletionResponse, Message, Role, StreamChunk, ToolCall |
| 12 | from loader.runtime.assistant_turns import AssistantTurnRequester |
| 13 | from loader.runtime.capabilities import resolve_backend_capability_profile |
| 14 | from loader.runtime.context import RuntimeContext |
| 15 | from loader.runtime.events import AgentEvent |
| 16 | from loader.runtime.permissions import ( |
| 17 | PermissionMode, |
| 18 | build_permission_policy, |
| 19 | load_permission_rules, |
| 20 | ) |
| 21 | from loader.runtime.tracing import RuntimeTracer |
| 22 | from loader.tools.base import create_default_registry |
| 23 | from tests.helpers.runtime_harness import ScriptedBackend |
| 24 | |
| 25 | |
@dataclass
class FakeCompaction:
    """Canned compaction result handed out by ``FakeSession.maybe_compact``.

    Tests build it with fixed numbers; this file never interprets the values.
    """

    # Count of messages the fake compaction claims to have removed.
    removed_message_count: int
    # Input-token figure before the (fake) compaction.
    original_input_tokens: int
    # Input-token figure after the (fake) compaction.
    compressed_input_tokens: int
| 31 | |
| 32 | |
class FakeSession:
    """Minimal session double for ``AssistantTurnRequester`` tests.

    It serves a fixed request transcript and, at most once, a canned
    compaction result.
    """

    def __init__(
        self,
        *,
        project_root: Path,
        request_messages: list[Message],
        compaction: FakeCompaction | None = None,
    ) -> None:
        # Two independent copies. Mutating ``messages`` (or the caller's list)
        # must not alter the transcript returned by build_request_messages().
        self.messages = [*request_messages]
        self._request_messages = [*request_messages]
        self.storage_path = project_root / ".loader" / "sessions" / "fake.json"
        # Consumed by the first maybe_compact() call.
        self._compaction = compaction

    def maybe_compact(self) -> FakeCompaction | None:
        """Return the queued compaction once, then ``None`` on later calls."""
        pending, self._compaction = self._compaction, None
        return pending

    def build_request_messages(self) -> list[Message]:
        """Return a fresh copy of the canned request messages."""
        return [*self._request_messages]
| 53 | |
| 54 | |
class FakeCodeFilter:
    """Code-filter double that only counts how many times it was reset."""

    def __init__(self) -> None:
        # Tests assert on this counter after running a turn.
        self.reset_calls = 0

    def reset(self) -> None:
        """Record one additional reset."""
        self.reset_calls = self.reset_calls + 1
| 61 | |
| 62 | |
class FakeSafeguards:
    """Configurable stand-in for the runtime safeguards object.

    Streamed chunks and completed content go through caller-supplied
    transforms. The default is ``str``, so string input comes back unchanged.
    Steering fires only when a canned ``steering_message`` is given. The loop
    detectors never report a loop.
    """

    def __init__(
        self,
        *,
        stream_transform=str,
        complete_transform=str,
        steering_message: str | None = None,
    ) -> None:
        # Callables used by filter_stream_chunk / filter_complete_content.
        self._stream_transform = stream_transform
        self._complete_transform = complete_transform
        # When set, should_steer() is True and this text is handed out.
        self._steering_message = steering_message
        # Opaque placeholders; nothing in this fake reads them.
        self.action_tracker = object()
        self.validator = object()
        # Counts reset() calls so tests can assert on them.
        self.code_filter = FakeCodeFilter()

    def filter_stream_chunk(self, content: str) -> str:
        """Return ``content`` after applying the streaming transform."""
        transform = self._stream_transform
        return transform(content)

    def filter_complete_content(self, content: str) -> str:
        """Return ``content`` after applying the completion transform."""
        transform = self._complete_transform
        return transform(content)

    def should_steer(self) -> bool:
        """Report whether a canned steering message was configured."""
        return self.get_steering_message() is not None

    def get_steering_message(self) -> str | None:
        """Return the canned steering message, or ``None`` when unset."""
        return self._steering_message

    def record_response(self, content: str) -> None:
        """Accept and ignore a response; this fake keeps no history."""

    def detect_text_loop(self, content: str) -> tuple[bool, str]:
        """Never flag a text loop."""
        return (False, "")

    def detect_loop(self) -> tuple[bool, str]:
        """Never flag an action loop."""
        return (False, "")
| 98 | |
| 99 | |
def build_runtime_context(
    *,
    temp_dir: Path,
    backend: ScriptedBackend,
    session: FakeSession,
    safeguards: FakeSafeguards,
    supports_native_tools: bool = True,
) -> tuple[RuntimeContext, list[str]]:
    """Assemble a ``RuntimeContext`` around the supplied test doubles.

    The tool registry and permission policy are real objects rooted at
    ``temp_dir``. The policy runs in workspace-write mode with rules loaded
    from ``temp_dir``. Steering messages queued through the context's callback
    are appended to the returned list.

    Args:
        temp_dir: Project and workspace root for the context.
        backend: Scripted backend the context talks to.
        session: Session double exposed as ``context.session``.
        safeguards: Safeguards double exposed as ``context.safeguards``.
        supports_native_tools: When False, the config forces ReAct mode.
            This flag is independent of the backend's own flag, so callers
            should pass the same value they gave ``ScriptedBackend``.

    Returns:
        ``(context, queued_messages)``.
    """
    registry = create_default_registry(temp_dir)
    registry.configure_workspace_root(temp_dir)

    rule_status = load_permission_rules(temp_dir)
    policy = build_permission_policy(
        active_mode=PermissionMode.WORKSPACE_WRITE,
        workspace_root=temp_dir,
        tool_requirements=registry.get_tool_requirements(),
        rules=rule_status.rules,
    )

    # Streaming is on exactly when the scripted backend holds queued streams
    # (``backend._streams`` is non-empty). Otherwise the requester is expected
    # to use completions.
    runtime_config = SimpleNamespace(
        force_react=not supports_native_tools,
        stream=bool(backend._streams),
        temperature=0.2,
    )

    queued: list[str] = []
    context = RuntimeContext(
        project_root=temp_dir,
        backend=backend,
        registry=registry,
        session=session,  # type: ignore[arg-type]
        config=runtime_config,
        capability_profile=resolve_backend_capability_profile(backend),
        project_context=None,
        permission_policy=policy,
        permission_config_status=rule_status,
        workflow_mode="execute",
        safeguards=safeguards,
        queue_steering_message_callback=queued.append,
    )
    return context, queued
| 137 | |
| 138 | |
@pytest.mark.asyncio
async def test_assistant_turn_requester_streams_with_fake_runtime_context(
    temp_dir: Path,
) -> None:
    """Streaming path: chunks are filtered for display and tool calls surface.

    The session also carries a one-shot compaction. The test expects it to
    show up as the leading ``artifact`` event.
    """
    read_call = ToolCall(id="call_read", name="read", arguments={"file_path": "README.md"})
    scripted_stream = [
        StreamChunk(content="hello", pending_tool_call=read_call),
        StreamChunk(
            content="",
            full_content="hello",
            tool_calls=[read_call],
            is_done=True,
            usage={"completion_tokens": 1},
        ),
    ]
    backend = ScriptedBackend(streams=[scripted_stream], supports_native_tools=True)

    compaction = FakeCompaction(
        removed_message_count=2,
        original_input_tokens=120,
        compressed_input_tokens=50,
    )
    session = FakeSession(
        project_root=temp_dir,
        request_messages=[Message(role=Role.USER, content="Read the file")],
        compaction=compaction,
    )
    context, queued_messages = build_runtime_context(
        temp_dir=temp_dir,
        backend=backend,
        session=session,
        safeguards=FakeSafeguards(stream_transform=str.upper),
        supports_native_tools=True,
    )
    requester = AssistantTurnRequester(context, RuntimeTracer())

    emitted: list[AgentEvent] = []

    async def emit(event: AgentEvent) -> None:
        emitted.append(event)

    turn = await requester.request_turn(emit=emit, max_tokens=256)

    # The turn keeps the raw model text; only the emitted stream is upper-cased.
    assert turn.content == "hello"
    assert turn.response_content == "hello"
    assert turn.tool_calls == [read_call]
    assert turn.pending_tool_calls_seen == {"call_read"}
    assert turn.usage == {"completion_tokens": 1}
    assert queued_messages == []
    assert context.safeguards.code_filter.reset_calls == 1

    assert [event.type for event in emitted] == ["artifact", "stream", "tool_call", "stream"]
    _, chunk_event, tool_event, _ = emitted
    assert chunk_event.content == "HELLO"
    assert tool_event.tool_name == "read"
| 191 | |
| 192 | |
@pytest.mark.asyncio
async def test_assistant_turn_requester_completes_with_fake_runtime_context(
    temp_dir: Path,
) -> None:
    """Completion path (no native tools): content is filtered and steering is queued."""
    completion = CompletionResponse(content="final answer", usage={"completion_tokens": 2})
    backend = ScriptedBackend(completions=[completion], supports_native_tools=False)

    session = FakeSession(
        project_root=temp_dir,
        request_messages=[Message(role=Role.USER, content="Explain the issue")],
    )
    safeguards = FakeSafeguards(
        complete_transform=lambda content: f"FILTERED: {content}",
        steering_message="Stay on the task at hand.",
    )
    context, queued_messages = build_runtime_context(
        temp_dir=temp_dir,
        backend=backend,
        session=session,
        safeguards=safeguards,
        supports_native_tools=False,
    )
    requester = AssistantTurnRequester(context, RuntimeTracer())

    async def discard(_event: AgentEvent) -> None:
        """Drop events; this test only inspects the returned turn."""
        return None

    turn = await requester.request_turn(emit=discard, max_tokens=256)

    # ``content`` is the filtered text; ``response_content`` keeps the raw reply.
    assert turn.content == "FILTERED: final answer"
    assert turn.response_content == "final answer"
    assert turn.tool_calls == []
    assert turn.usage == {"completion_tokens": 2}
    assert queued_messages == ["Stay on the task at hand."]
    assert context.safeguards.code_filter.reset_calls == 1