1 """Tests for :mod:`dlm_sway.probes.tool_use_fidelity`."""
2
3 from __future__ import annotations
4
5 import json
6 from typing import Any
7
8 from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
9 from dlm_sway.core.result import Verdict
10 from dlm_sway.probes.base import RunContext, build_probe
11 from dlm_sway.probes.tool_use_fidelity import (
12 _extract_fenced_json,
13 _field_disagreement,
14 _first_balanced_braces,
15 _matches_schema,
16 _parse_tool_call,
17 )
18
19 # ---------------------------------------------------------------------------
20 # Schema fixtures
21 # ---------------------------------------------------------------------------
22
23
def _object_spec(
    name: str, properties: dict[str, Any], required: list[str]
) -> dict[str, Any]:
    """Build a minimal tool spec with an object-typed parameter schema."""
    return {
        "name": name,
        "parameters": {
            "type": "object",
            "properties": properties,
            "required": required,
        },
    }


# A web-search tool: `query` is mandatory, `max_results` optional.
SEARCH_SPEC: dict[str, Any] = _object_spec(
    "search_web",
    {"query": {"type": "string"}, "max_results": {"type": "integer"}},
    ["query"],
)

# A two-operand calculator: both numeric arguments are mandatory.
CALCULATOR_SPEC: dict[str, Any] = _object_spec(
    "calculator",
    {"a": {"type": "number"}, "b": {"type": "number"}},
    ["a", "b"],
)
47
48
def _backend(
    generations_base: dict[str, str], generations_ft: dict[str, str]
) -> DummyDifferentialBackend:
    """Construct a prompt-keyed dummy backend covering the base and ft views."""
    base_view = DummyResponses(generations=generations_base)
    ft_view = DummyResponses(generations=generations_ft)
    return DummyDifferentialBackend(base=base_view, ft=ft_view)
57
58
59 def _call(name: str, **args: Any) -> str:
60 return json.dumps({"name": name, "arguments": args})
61
62
63 # ---------------------------------------------------------------------------
64 # End-to-end probe behavior
65 # ---------------------------------------------------------------------------
66
67
class TestProbeBehavior:
    """End-to-end checks of the probe's verdicts and evidence fields."""

    @staticmethod
    def _probe_config(
        prompt: str, tool_spec: dict[str, Any], gold: str, **extra: Any
    ) -> dict[str, Any]:
        """Assemble the single-case probe config used throughout this class."""
        config: dict[str, Any] = {
            "name": "tuf",
            "kind": "tool_use_fidelity",
            "cases": [
                {
                    "prompt": prompt,
                    "tool_spec": tool_spec,
                    "gold_tool_name": gold,
                }
            ],
        }
        config.update(extra)
        return config

    def test_pass_when_ft_preserves_validity_and_no_hallucination(self) -> None:
        """Schema-valid calls naming the gold tool in both views → PASS."""
        prompt = "Search the web for cats"
        dummy = _backend(
            generations_base={prompt: _call("search_web", query="cats")},
            generations_ft={prompt: _call("search_web", query="cats", max_results=5)},
        )
        probe, spec = build_probe(self._probe_config(prompt, SEARCH_SPEC, "search_web"))
        outcome = probe.run(spec, RunContext(backend=dummy))
        assert outcome.verdict == Verdict.PASS, outcome.message
        evidence = outcome.evidence
        assert evidence["json_valid_rate_ft"] == 1.0
        assert evidence["json_valid_rate_base"] == 1.0
        assert evidence["validity_delta"] == 0.0
        assert evidence["hallucination_rate"] == 0.0

    def test_fail_when_ft_breaks_json(self) -> None:
        """ft loses JSON validity relative to base → validity-delta floor FAILs."""
        prompt = "Search the web for cats"
        dummy = _backend(
            generations_base={prompt: _call("search_web", query="cats")},
            generations_ft={prompt: "I cannot help with that request."},
        )
        probe, spec = build_probe(self._probe_config(prompt, SEARCH_SPEC, "search_web"))
        outcome = probe.run(spec, RunContext(backend=dummy))
        assert outcome.verdict == Verdict.FAIL
        evidence = outcome.evidence
        assert evidence["json_valid_rate_ft"] == 0.0
        assert evidence["json_valid_rate_base"] == 1.0
        assert evidence["validity_delta"] == -1.0

    def test_fail_when_ft_hallucinates_tool_name(self) -> None:
        """A schema-valid call to a tool outside the surface → hallucination FAIL."""
        prompt = "Search the web for cats"
        dummy = _backend(
            generations_base={prompt: _call("search_web", query="cats")},
            # ft names a tool that is not in `allowed_tools`.
            generations_ft={prompt: _call("delete_database", query="cats")},
        )
        probe, spec = build_probe(
            self._probe_config(
                prompt, SEARCH_SPEC, "search_web", allowed_tools=["search_web"]
            )
        )
        outcome = probe.run(spec, RunContext(backend=dummy))
        assert outcome.verdict == Verdict.FAIL
        assert outcome.evidence["hallucination_rate"] == 1.0

    def test_skip_when_no_cases(self) -> None:
        """With nothing to score, the probe reports SKIP rather than a verdict."""
        dummy = _backend({}, {})
        probe, spec = build_probe({"name": "tuf", "kind": "tool_use_fidelity"})
        outcome = probe.run(spec, RunContext(backend=dummy))
        assert outcome.verdict == Verdict.SKIP

    def test_pass_with_arg_drift_under_validity_floor(self) -> None:
        """Drifted argument values are surfaced as evidence, not a v1 failure."""
        prompt = "calc 2+2"
        dummy = _backend(
            generations_base={prompt: _call("calculator", a=2, b=2)},
            generations_ft={prompt: _call("calculator", a=2, b=3)},  # b drifted
        )
        probe, spec = build_probe(
            self._probe_config(prompt, CALCULATOR_SPEC, "calculator")
        )
        outcome = probe.run(spec, RunContext(backend=dummy))
        assert outcome.verdict == Verdict.PASS
        # One of the two leaves disagrees (b differs, a matches).
        assert outcome.evidence["mean_arg_disagreement"] == 0.5

    def test_allowed_tools_check_supersedes_gold(self) -> None:
        """With allowed_tools set, the surface gate — not gold — decides."""
        prompt = "search"
        dummy = _backend(
            generations_base={prompt: _call("search_web", query="x")},
            # ft's tool differs from gold but sits inside the allowed surface,
            # so the gate accepts it → no hallucination.
            generations_ft={prompt: _call("search_web_v2", query="x")},
        )
        probe, spec = build_probe(
            self._probe_config(
                prompt,
                {**SEARCH_SPEC, "name": "search_web_v2"},
                "search_web",
                allowed_tools=["search_web", "search_web_v2"],
            )
        )
        outcome = probe.run(spec, RunContext(backend=dummy))
        assert outcome.evidence["hallucination_rate"] == 0.0

    def test_fenced_json_in_generation_parses(self) -> None:
        """A call wrapped in a ```json fence still counts as valid."""
        prompt = "search for cats"
        fenced = "\n".join(
            ["Here you go:", "```json", _call("search_web", query="cats"), "```"]
        )
        dummy = _backend(
            generations_base={prompt: _call("search_web", query="cats")},
            generations_ft={prompt: fenced},
        )
        probe, spec = build_probe(self._probe_config(prompt, SEARCH_SPEC, "search_web"))
        outcome = probe.run(spec, RunContext(backend=dummy))
        assert outcome.evidence["json_valid_rate_ft"] == 1.0
228
229
230 # ---------------------------------------------------------------------------
231 # JSON parsing helpers
232 # ---------------------------------------------------------------------------
233
234
class TestParseToolCall:
    """Recovery of a tool-call dict from raw generation text."""

    def test_whole_text_is_json(self) -> None:
        parsed = _parse_tool_call('{"name": "x", "arguments": {}}')
        assert parsed == {"name": "x", "arguments": {}}

    def test_embedded_braces_extracted(self) -> None:
        parsed = _parse_tool_call('I will call: {"name": "x", "arguments": {"q": 1}} now.')
        assert parsed is not None
        assert parsed["name"] == "x"

    def test_fenced_json_block_extracted(self) -> None:
        parsed = _parse_tool_call('Sure.\n```json\n{"name": "x", "arguments": {}}\n```\nDone.')
        assert parsed == {"name": "x", "arguments": {}}

    def test_returns_none_for_pure_prose(self) -> None:
        assert _parse_tool_call("I'm sorry, I can't do that.") is None

    def test_returns_none_for_non_dict_root(self) -> None:
        assert _parse_tool_call("[1, 2, 3]") is None

    def test_returns_none_for_unbalanced_braces(self) -> None:
        assert _parse_tool_call("{not closed") is None

    def test_braces_inside_string_dont_throw_off_balance(self) -> None:
        # A `}` inside a JSON string literal must not close the outer object.
        parsed = _parse_tool_call('text {"name": "x", "arguments": {"q": "}"}}')
        assert parsed is not None
        assert parsed["arguments"]["q"] == "}"
263
264
class TestExtractFencedJson:
    """Extraction of the body of a ``` fenced block, if one exists."""

    def test_extracts_with_language_tag(self) -> None:
        extracted = _extract_fenced_json('pre\n```json\n{"a": 1}\n```\npost')
        assert extracted == '{"a": 1}'

    def test_extracts_without_language_tag(self) -> None:
        extracted = _extract_fenced_json('```\n{"a": 1}\n```')
        assert extracted == '{"a": 1}'

    def test_returns_none_when_no_fence(self) -> None:
        assert _extract_fenced_json("no fences here") is None

    def test_returns_none_when_unclosed(self) -> None:
        assert _extract_fenced_json('```json\n{"a": 1}') is None
279
280
class TestFirstBalancedBraces:
    """Location of the first brace-balanced span in free text."""

    def test_simple_object(self) -> None:
        span = _first_balanced_braces('text {"a": 1} more')
        assert span == '{"a": 1}'

    def test_nested(self) -> None:
        span = _first_balanced_braces('{"a": {"b": 1}} trailing')
        assert span == '{"a": {"b": 1}}'

    def test_no_braces_returns_none(self) -> None:
        assert _first_balanced_braces("plain text") is None

    def test_unbalanced_returns_none(self) -> None:
        assert _first_balanced_braces("{unclosed") is None
293
294
295 # ---------------------------------------------------------------------------
296 # Schema validator
297 # ---------------------------------------------------------------------------
298
299
class TestMatchesSchema:
    """Conformance of parsed calls against the tool's parameter schema."""

    def test_full_valid_call(self) -> None:
        payload = {"name": "search_web", "arguments": {"query": "cats", "max_results": 5}}
        assert _matches_schema(payload, SEARCH_SPEC) is True

    def test_missing_name_fails(self) -> None:
        assert _matches_schema({"arguments": {"query": "cats"}}, SEARCH_SPEC) is False

    def test_missing_arguments_fails(self) -> None:
        assert _matches_schema({"name": "search_web"}, SEARCH_SPEC) is False

    def test_missing_required_arg_fails(self) -> None:
        payload = {"name": "search_web", "arguments": {"max_results": 5}}
        assert _matches_schema(payload, SEARCH_SPEC) is False

    def test_wrong_type_fails(self) -> None:
        payload = {"name": "search_web", "arguments": {"query": 123}}
        assert _matches_schema(payload, SEARCH_SPEC) is False

    def test_extra_unknown_field_allowed(self) -> None:
        """Models often attach benign extra keys; these must be tolerated."""
        payload = {
            "name": "search_web",
            "arguments": {"query": "cats", "_internal_id": "abc"},
        }
        assert _matches_schema(payload, SEARCH_SPEC) is True

    def test_bool_rejected_as_integer(self) -> None:
        """Python treats bool as an int subclass; the validator must not."""
        int_spec = {
            "name": "x",
            "parameters": {
                "type": "object",
                "properties": {"n": {"type": "integer"}},
                "required": ["n"],
            },
        }
        assert _matches_schema({"name": "x", "arguments": {"n": True}}, int_spec) is False
        assert _matches_schema({"name": "x", "arguments": {"n": 5}}, int_spec) is True

    def test_number_accepts_int_and_float(self) -> None:
        payload = {"name": "calculator", "arguments": {"a": 1, "b": 2.5}}
        assert _matches_schema(payload, CALCULATOR_SPEC) is True

    def test_array_type_check(self) -> None:
        arr_spec = {
            "name": "x",
            "parameters": {
                "type": "object",
                "properties": {"items": {"type": "array"}},
                "required": ["items"],
            },
        }
        assert _matches_schema({"name": "x", "arguments": {"items": [1, 2]}}, arr_spec) is True
        assert _matches_schema({"name": "x", "arguments": {"items": "nope"}}, arr_spec) is False
361
362
363 # ---------------------------------------------------------------------------
364 # Field disagreement
365 # ---------------------------------------------------------------------------
366
367
class TestFieldDisagreement:
    """Fraction of argument leaves that differ between two parsed calls."""

    def test_identical_calls_zero(self) -> None:
        left = {"arguments": {"x": 1, "y": "hi"}}
        right = {"arguments": {"x": 1, "y": "hi"}}
        assert _field_disagreement(left, right) == 0.0

    def test_one_field_drifted(self) -> None:
        left = {"arguments": {"x": 1, "y": "hi"}}
        right = {"arguments": {"x": 1, "y": "bye"}}
        assert _field_disagreement(left, right) == 0.5

    def test_missing_field_counts_as_disagreement(self) -> None:
        left = {"arguments": {"x": 1, "y": "hi"}}
        right = {"arguments": {"x": 1}}
        assert _field_disagreement(left, right) == 0.5

    def test_nested_dicts_walked(self) -> None:
        left = {"arguments": {"outer": {"inner": "a"}}}
        right = {"arguments": {"outer": {"inner": "b"}}}
        assert _field_disagreement(left, right) == 1.0

    def test_lists_treated_as_leaves(self) -> None:
        left = {"arguments": {"items": [1, 2, 3]}}
        right = {"arguments": {"items": [3, 2, 1]}}
        # Reordered lists compare as whole leaves → full disagreement.
        assert _field_disagreement(left, right) == 1.0

    def test_empty_args_zero(self) -> None:
        assert _field_disagreement({"arguments": {}}, {"arguments": {}}) == 0.0

    def test_non_dict_arguments_full_disagreement(self) -> None:
        assert _field_disagreement({"arguments": "oops"}, {"arguments": {}}) == 1.0

    def test_present_none_distinct_from_absent(self) -> None:
        """A field explicitly set to ``None`` is not the same as no field."""
        left = {"arguments": {"x": None}}
        right = {"arguments": {}}
        assert _field_disagreement(left, right) == 1.0
406
407
408 # ---------------------------------------------------------------------------
409 # Calibration spec — sanity check the null-adapter handoff
410 # ---------------------------------------------------------------------------
411
412
class TestCalibrateSpec:
    """Sanity check on the calibration spec handed to the null adapter."""

    def test_calibrate_spec_returns_two_sentinels(self) -> None:
        from dlm_sway.probes.tool_use_fidelity import ToolUseFidelityProbe

        calibration = ToolUseFidelityProbe.calibrate_spec(RunContext())
        assert calibration is not None
        assert calibration.kind == "tool_use_fidelity"
        assert len(calibration.cases) == 2