tenseleyflow/sway / 3c0542d

Browse files

tests/unit: 40 tests covering tool_use_fidelity probe + parsers + schema check

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA
3c0542d01ecd735b3adccea862820ca6fb0b4cb9
Parents
4f1cffb
Tree
f3f126e

1 changed file

Status | File | additions (+) | deletions (-)
A tests/unit/test_probe_tool_use_fidelity.py 420 0
tests/unit/test_probe_tool_use_fidelity.py (added)
@@ -0,0 +1,420 @@
1
+"""Tests for :mod:`dlm_sway.probes.tool_use_fidelity`."""
2
+
3
+from __future__ import annotations
4
+
5
+import json
6
+from typing import Any
7
+
8
+from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
9
+from dlm_sway.core.result import Verdict
10
+from dlm_sway.probes.base import RunContext, build_probe
11
+from dlm_sway.probes.tool_use_fidelity import (
12
+    _extract_fenced_json,
13
+    _field_disagreement,
14
+    _first_balanced_braces,
15
+    _matches_schema,
16
+    _parse_tool_call,
17
+)
18
+
19
+# ---------------------------------------------------------------------------
20
+# Schema fixtures
21
+# ---------------------------------------------------------------------------
22
+
23
+
24
def _tool_spec(
    name: str, properties: dict[str, Any], required: list[str]
) -> dict[str, Any]:
    """Build a minimal JSON-schema-style tool specification."""
    return {
        "name": name,
        "parameters": {
            "type": "object",
            "properties": properties,
            "required": required,
        },
    }


# A web-search tool: one required string arg, one optional integer arg.
SEARCH_SPEC: dict[str, Any] = _tool_spec(
    "search_web",
    {"query": {"type": "string"}, "max_results": {"type": "integer"}},
    ["query"],
)

# A calculator tool: two required numeric args.
CALCULATOR_SPEC: dict[str, Any] = _tool_spec(
    "calculator",
    {"a": {"type": "number"}, "b": {"type": "number"}},
    ["a", "b"],
)
47
+
48
+
49
def _backend(
    generations_base: dict[str, str], generations_ft: dict[str, str]
) -> DummyDifferentialBackend:
    """Construct a two-view dummy backend whose canned generations are keyed by prompt."""
    base_view = DummyResponses(generations=generations_base)
    ft_view = DummyResponses(generations=generations_ft)
    return DummyDifferentialBackend(base=base_view, ft=ft_view)
57
+
58
+
59
+def _call(name: str, **args: Any) -> str:
60
+    return json.dumps({"name": name, "arguments": args})
61
+
62
+
63
+# ---------------------------------------------------------------------------
64
+# End-to-end probe behavior
65
+# ---------------------------------------------------------------------------
66
+
67
+
68
class TestProbeBehavior:
    """End-to-end probe runs against the dummy differential backend."""

    @staticmethod
    def _single_case_config(
        prompt: str,
        tool_spec: dict[str, Any],
        gold: str,
        **extra: Any,
    ) -> dict[str, Any]:
        """Assemble a one-case probe config; *extra* keys land at the top level."""
        config: dict[str, Any] = {
            "name": "tuf",
            "kind": "tool_use_fidelity",
            "cases": [
                {
                    "prompt": prompt,
                    "tool_spec": tool_spec,
                    "gold_tool_name": gold,
                }
            ],
        }
        config.update(extra)
        return config

    @staticmethod
    def _run(backend: DummyDifferentialBackend, config: dict[str, Any]):
        """Build the probe described by *config* and execute it on *backend*."""
        probe, spec = build_probe(config)
        return probe.run(spec, RunContext(backend=backend))

    def test_pass_when_ft_preserves_validity_and_no_hallucination(self) -> None:
        """Both views emit schema-valid calls naming the gold tool."""
        prompt = "Search the web for cats"
        backend = _backend(
            generations_base={prompt: _call("search_web", query="cats")},
            generations_ft={prompt: _call("search_web", query="cats", max_results=5)},
        )
        result = self._run(
            backend, self._single_case_config(prompt, SEARCH_SPEC, "search_web")
        )
        assert result.verdict == Verdict.PASS, result.message
        evidence = result.evidence
        assert evidence["json_valid_rate_ft"] == 1.0
        assert evidence["json_valid_rate_base"] == 1.0
        assert evidence["validity_delta"] == 0.0
        assert evidence["hallucination_rate"] == 0.0

    def test_fail_when_ft_breaks_json(self) -> None:
        """ft regresses on validity → FAIL on the validity-delta floor."""
        prompt = "Search the web for cats"
        backend = _backend(
            generations_base={prompt: _call("search_web", query="cats")},
            generations_ft={prompt: "I cannot help with that request."},
        )
        result = self._run(
            backend, self._single_case_config(prompt, SEARCH_SPEC, "search_web")
        )
        assert result.verdict == Verdict.FAIL
        evidence = result.evidence
        assert evidence["json_valid_rate_ft"] == 0.0
        assert evidence["json_valid_rate_base"] == 1.0
        assert evidence["validity_delta"] == -1.0

    def test_fail_when_ft_hallucinates_tool_name(self) -> None:
        """A schema-valid call to a tool outside the surface fails the hallucination cap."""
        prompt = "Search the web for cats"
        backend = _backend(
            generations_base={prompt: _call("search_web", query="cats")},
            # ft invents a tool that is not in `allowed_tools`.
            generations_ft={prompt: _call("delete_database", query="cats")},
        )
        result = self._run(
            backend,
            self._single_case_config(
                prompt, SEARCH_SPEC, "search_web", allowed_tools=["search_web"]
            ),
        )
        assert result.verdict == Verdict.FAIL
        assert result.evidence["hallucination_rate"] == 1.0

    def test_skip_when_no_cases(self) -> None:
        """A spec without cases has nothing to run and yields SKIP."""
        result = self._run(
            _backend({}, {}), {"name": "tuf", "kind": "tool_use_fidelity"}
        )
        assert result.verdict == Verdict.SKIP

    def test_pass_with_arg_drift_under_validity_floor(self) -> None:
        """Argument value drift is recorded as evidence but doesn't fail in v1."""
        prompt = "calc 2+2"
        backend = _backend(
            generations_base={prompt: _call("calculator", a=2, b=2)},
            generations_ft={prompt: _call("calculator", a=2, b=3)},  # b drifted
        )
        result = self._run(
            backend, self._single_case_config(prompt, CALCULATOR_SPEC, "calculator")
        )
        assert result.verdict == Verdict.PASS
        # Exactly one of the two leaves (b) disagrees.
        assert result.evidence["mean_arg_disagreement"] == 0.5

    def test_allowed_tools_check_supersedes_gold(self) -> None:
        """When allowed_tools is set, gold is informational; the surface gate wins."""
        prompt = "search"
        backend = _backend(
            generations_base={prompt: _call("search_web", query="x")},
            # ft selects a tool inside the allowed surface but != gold;
            # the surface gate accepts it → no hallucination.
            generations_ft={prompt: _call("search_web_v2", query="x")},
        )
        result = self._run(
            backend,
            self._single_case_config(
                prompt,
                {**SEARCH_SPEC, "name": "search_web_v2"},
                "search_web",
                allowed_tools=["search_web", "search_web_v2"],
            ),
        )
        assert result.evidence["hallucination_rate"] == 0.0

    def test_fenced_json_in_generation_parses(self) -> None:
        """A model that wraps the call in ```json ... ``` is still counted valid."""
        prompt = "search for cats"
        fenced = "Here you go:\n```json\n" + _call("search_web", query="cats") + "\n```"
        backend = _backend(
            generations_base={prompt: _call("search_web", query="cats")},
            generations_ft={prompt: fenced},
        )
        result = self._run(
            backend, self._single_case_config(prompt, SEARCH_SPEC, "search_web")
        )
        assert result.evidence["json_valid_rate_ft"] == 1.0
228
+
229
+
230
+# ---------------------------------------------------------------------------
231
+# JSON parsing helpers
232
+# ---------------------------------------------------------------------------
233
+
234
+
235
class TestParseToolCall:
    """Unit tests for the top-level tool-call extraction helper."""

    def test_whole_text_is_json(self) -> None:
        parsed = _parse_tool_call('{"name": "x", "arguments": {}}')
        assert parsed == {"name": "x", "arguments": {}}

    def test_embedded_braces_extracted(self) -> None:
        text = 'I will call: {"name": "x", "arguments": {"q": 1}} now.'
        parsed = _parse_tool_call(text)
        assert parsed is not None
        assert parsed["name"] == "x"

    def test_fenced_json_block_extracted(self) -> None:
        text = 'Sure.\n```json\n{"name": "x", "arguments": {}}\n```\nDone.'
        assert _parse_tool_call(text) == {"name": "x", "arguments": {}}

    def test_returns_none_for_pure_prose(self) -> None:
        assert _parse_tool_call("I'm sorry, I can't do that.") is None

    def test_returns_none_for_non_dict_root(self) -> None:
        assert _parse_tool_call("[1, 2, 3]") is None

    def test_returns_none_for_unbalanced_braces(self) -> None:
        assert _parse_tool_call("{not closed") is None

    def test_braces_inside_string_dont_throw_off_balance(self) -> None:
        # The `}` inside the string literal must not close the outer object.
        parsed = _parse_tool_call('text {"name": "x", "arguments": {"q": "}"}}')
        assert parsed is not None
        assert parsed["arguments"]["q"] == "}"
263
+
264
+
265
class TestExtractFencedJson:
    """Tests for pulling the body out of a triple-backtick fenced block."""

    def test_extracts_with_language_tag(self) -> None:
        text = 'pre\n```json\n{"a": 1}\n```\npost'
        assert _extract_fenced_json(text) == '{"a": 1}'

    def test_extracts_without_language_tag(self) -> None:
        text = '```\n{"a": 1}\n```'
        assert _extract_fenced_json(text) == '{"a": 1}'

    def test_returns_none_when_no_fence(self) -> None:
        assert _extract_fenced_json("no fences here") is None

    def test_returns_none_when_unclosed(self) -> None:
        assert _extract_fenced_json('```json\n{"a": 1}') is None
279
+
280
+
281
class TestFirstBalancedBraces:
    """Tests for extracting the first balanced ``{...}`` span from text."""

    def test_simple_object(self) -> None:
        assert _first_balanced_braces('text {"a": 1} more') == '{"a": 1}'

    def test_nested(self) -> None:
        nested = '{"a": {"b": 1}} trailing'
        assert _first_balanced_braces(nested) == '{"a": {"b": 1}}'

    def test_no_braces_returns_none(self) -> None:
        assert _first_balanced_braces("plain text") is None

    def test_unbalanced_returns_none(self) -> None:
        assert _first_balanced_braces("{unclosed") is None
293
+
294
+
295
+# ---------------------------------------------------------------------------
296
+# Schema validator
297
+# ---------------------------------------------------------------------------
298
+
299
+
300
class TestMatchesSchema:
    """Tests for the lightweight tool-call schema validator."""

    def test_full_valid_call(self) -> None:
        good = {"name": "search_web", "arguments": {"query": "cats", "max_results": 5}}
        assert _matches_schema(good, SEARCH_SPEC) is True

    def test_missing_name_fails(self) -> None:
        nameless = {"arguments": {"query": "cats"}}
        assert _matches_schema(nameless, SEARCH_SPEC) is False

    def test_missing_arguments_fails(self) -> None:
        argless = {"name": "search_web"}
        assert _matches_schema(argless, SEARCH_SPEC) is False

    def test_missing_required_arg_fails(self) -> None:
        incomplete = {"name": "search_web", "arguments": {"max_results": 5}}
        assert _matches_schema(incomplete, SEARCH_SPEC) is False

    def test_wrong_type_fails(self) -> None:
        mistyped = {"name": "search_web", "arguments": {"query": 123}}
        assert _matches_schema(mistyped, SEARCH_SPEC) is False

    def test_extra_unknown_field_allowed(self) -> None:
        """Real models add benign extras; the validator tolerates them."""
        with_extra = {
            "name": "search_web",
            "arguments": {"query": "cats", "_internal_id": "abc"},
        }
        assert _matches_schema(with_extra, SEARCH_SPEC) is True

    def test_bool_rejected_as_integer(self) -> None:
        """Python treats ``bool`` as an ``int`` subclass; the validator must not."""
        int_spec = {
            "name": "x",
            "parameters": {
                "type": "object",
                "properties": {"n": {"type": "integer"}},
                "required": ["n"],
            },
        }
        assert _matches_schema({"name": "x", "arguments": {"n": True}}, int_spec) is False
        assert _matches_schema({"name": "x", "arguments": {"n": 5}}, int_spec) is True

    def test_number_accepts_int_and_float(self) -> None:
        mixed = {"name": "calculator", "arguments": {"a": 1, "b": 2.5}}
        assert _matches_schema(mixed, CALCULATOR_SPEC) is True

    def test_array_type_check(self) -> None:
        array_spec = {
            "name": "x",
            "parameters": {
                "type": "object",
                "properties": {"items": {"type": "array"}},
                "required": ["items"],
            },
        }
        good = {"name": "x", "arguments": {"items": [1, 2]}}
        bad = {"name": "x", "arguments": {"items": "nope"}}
        assert _matches_schema(good, array_spec) is True
        assert _matches_schema(bad, array_spec) is False
361
+
362
+
363
+# ---------------------------------------------------------------------------
364
+# Field disagreement
365
+# ---------------------------------------------------------------------------
366
+
367
+
368
class TestFieldDisagreement:
    """Tests for the leaf-level argument disagreement metric."""

    def test_identical_calls_zero(self) -> None:
        left = {"arguments": {"x": 1, "y": "hi"}}
        right = {"arguments": {"x": 1, "y": "hi"}}
        assert _field_disagreement(left, right) == 0.0

    def test_one_field_drifted(self) -> None:
        left = {"arguments": {"x": 1, "y": "hi"}}
        right = {"arguments": {"x": 1, "y": "bye"}}
        assert _field_disagreement(left, right) == 0.5

    def test_missing_field_counts_as_disagreement(self) -> None:
        left = {"arguments": {"x": 1, "y": "hi"}}
        right = {"arguments": {"x": 1}}
        assert _field_disagreement(left, right) == 0.5

    def test_nested_dicts_walked(self) -> None:
        left = {"arguments": {"outer": {"inner": "a"}}}
        right = {"arguments": {"outer": {"inner": "b"}}}
        assert _field_disagreement(left, right) == 1.0

    def test_lists_treated_as_leaves(self) -> None:
        # A reordered list is one differing leaf → full disagreement.
        left = {"arguments": {"items": [1, 2, 3]}}
        right = {"arguments": {"items": [3, 2, 1]}}
        assert _field_disagreement(left, right) == 1.0

    def test_empty_args_zero(self) -> None:
        assert _field_disagreement({"arguments": {}}, {"arguments": {}}) == 0.0

    def test_non_dict_arguments_full_disagreement(self) -> None:
        assert _field_disagreement({"arguments": "oops"}, {"arguments": {}}) == 1.0

    def test_present_none_distinct_from_absent(self) -> None:
        """An explicit ``None`` value is not the same as an absent key."""
        left = {"arguments": {"x": None}}
        right = {"arguments": {}}
        assert _field_disagreement(left, right) == 1.0
406
+
407
+
408
+# ---------------------------------------------------------------------------
409
+# Calibration spec — sanity check the null-adapter handoff
410
+# ---------------------------------------------------------------------------
411
+
412
+
413
class TestCalibrateSpec:
    """Sanity-check the calibration spec handed to the null adapter."""

    def test_calibrate_spec_returns_two_sentinels(self) -> None:
        from dlm_sway.probes.tool_use_fidelity import ToolUseFidelityProbe

        calibration = ToolUseFidelityProbe.calibrate_spec(RunContext())
        assert calibration is not None
        assert calibration.kind == "tool_use_fidelity"
        assert len(calibration.cases) == 2