1 """Tool-use fidelity — does the adapter preserve the base's tool-call format?
2
3 LoRA fine-tunes increasingly target tool-use behavior: function-calling
4 JSON schemas, MCP tool plans, code-execution gating. An adapter that
5 accidentally degrades JSON-schema validity or starts hallucinating tool
6 names while acing every other probe is a silent production failure.
7
8 For each ``(prompt, tool_spec, gold_tool_name)`` case the probe greedy-
9 generates from both views and computes three independent signals:
10
11 - **JSON-schema validity delta** — ``ft_valid_rate − base_valid_rate``.
12 Negative values mean the adapter degraded the base's tool-call
13 formatting; positive values mean it improved (or the base was
14 already tool-clueless).
15 - **Argument-field disagreement rate** — over cases where *both* views
16 produced a schema-valid call, the leaf-field disagreement rate
17 between ``ft_call.arguments`` and ``base_call.arguments``. Catches
18 numeric / string drift inside an otherwise-well-formed call.
19 (We deliberately ship leaf-field equality rather than per-token KL
20 on argument values: per-token KL requires alignment between two
21 decoded strings, which is the same hard problem we explicitly
22 defer past v1.)
23 - **Tool-name hallucination rate** — over schema-valid ft calls, the
24 fraction whose ``name`` is not in ``allowed_tools`` (when the user
25 declared a tool surface) or differs from ``gold_tool_name`` (when
26 no surface is declared).
27
28 Composite verdict logic:
29
30 - The pass criterion is compound — validity delta above the floor AND
31 hallucination rate below the cap. Argument disagreement is
32 informational on the v1 surface.
33 - ``json_valid_rate_ft`` is the metric that's z-scored against the
34 null-adapter baseline. A null adapter should produce essentially no
35 schema-valid calls, so ``z >= assert_z_gte`` is the principled
36 "the adapter actually preserved tool-call structure" claim.
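
Worked example of the default gates (values illustrative): ``ft_valid_rate = 0.92``
against ``base_valid_rate = 0.95`` gives Δ = −0.03, which clears the −0.05
floor, and a ``hallucination_rate`` of 0.04 clears the 0.10 cap; when
null calibration ran, the verdict additionally requires
``z >= assert_z_gte`` on ``json_valid_rate_ft``.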

JSON-schema check is deliberately minimal: this v1 implementation
validates ``required`` membership and per-field type tags from a tiny
OpenAI-flavored subset (``string``, ``number``, ``integer``, ``boolean``,
``object``, ``array``). We don't pull in the ``jsonschema`` package —
the OpenAI-style spec is small enough to validate directly, and core
sway dependencies stay lean.
"""

from __future__ import annotations

import json
import statistics
from typing import Any, Literal

from pydantic import BaseModel, ConfigDict, Field

from dlm_sway.core.result import ProbeResult, Verdict, safe_finalize
from dlm_sway.core.stats import bootstrap_ci
from dlm_sway.probes._zscore import (
    no_calibration_note,
    score_from_z,
    verdict_from_z,
    z_score,
    z_scores_by_rank,
)
from dlm_sway.probes.base import Probe, ProbeSpec, RunContext
from dlm_sway.probes.null_adapter import get_null_stats, get_null_stats_by_rank


class ToolUseCase(BaseModel):
    """One ``(prompt, tool_spec, gold_tool_name)`` case.

    ``tool_spec`` follows the OpenAI function-calling shape::

        {
            "name": "search_web",
            "description": "...",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {"type": "string"},
                    "max_results": {"type": "integer"}
                },
                "required": ["query"]
            }
        }

    ``gold_tool_name`` is the tool the case expects ft to call. It's
    used by the hallucination check when no broader ``allowed_tools``
    list is declared on the spec.
    """

    model_config = ConfigDict(extra="forbid", frozen=True)

    prompt: str
    tool_spec: dict[str, Any]
    gold_tool_name: str
    max_new_tokens: int = 256


class ToolUseFidelitySpec(ProbeSpec):
    """Spec for ``kind: tool_use_fidelity``."""

    kind: Literal["tool_use_fidelity"] = "tool_use_fidelity"
    cases: list[ToolUseCase] = Field(default_factory=list, min_length=0)
    """Inline cases. Empty list → probe SKIPs (the .dlm autogen path
    leaves this empty unless a tool-use template seeded the doc)."""
    allowed_tools: list[str] | None = None
    """Optional tool-surface declaration. When set, hallucination is
    ``ft.name not in allowed_tools``. When ``None``, hallucination is
    ``ft.name != case.gold_tool_name`` per-case."""
    assert_validity_delta_gte: float = -0.05
    """Pass criterion on JSON-schema validity. Default tolerates a 5pp
    drop from base — anything worse is an adapter regression."""
    assert_hallucination_lte: float = 0.10
    """Pass criterion on tool-name hallucination. Default 10% — above
    that the adapter is actively inventing tools."""
    assert_z_gte: float = 3.0
    """Z-score pass criterion on ``json_valid_rate_ft`` against the
    null-adapter baseline. The principled signal — preferred over
    the raw thresholds when null calibration ran."""


class ToolUseFidelityProbe(Probe):
    """The "did the LoRA preserve tool-call format?" probe."""

    kind = "tool_use_fidelity"
    spec_cls = ToolUseFidelitySpec
    category = "attribution"

    @classmethod
    def calibrate_spec(cls, ctx: RunContext) -> ToolUseFidelitySpec | None:
        """Two trivial sentinel cases for null calibration.

        A null (random-noise) adapter should produce essentially no
        schema-valid output here, so ``json_valid_rate`` clusters
        tightly near zero — exactly what we want as the denominator
        when z-scoring the real adapter's validity rate.
        """
        del ctx
        return ToolUseFidelitySpec(
            name="_calibration",
            kind="tool_use_fidelity",
            cases=[
                ToolUseCase(
                    prompt="Search the web for the capital of France.",
                    tool_spec={
                        "name": "search_web",
                        "parameters": {
                            "type": "object",
                            "properties": {"query": {"type": "string"}},
                            "required": ["query"],
                        },
                    },
                    gold_tool_name="search_web",
                ),
                ToolUseCase(
                    prompt="Add 2 and 2.",
                    tool_spec={
                        "name": "calculator",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "a": {"type": "number"},
                                "b": {"type": "number"},
                            },
                            "required": ["a", "b"],
                        },
                    },
                    gold_tool_name="calculator",
                ),
            ],
        )

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        assert isinstance(spec, ToolUseFidelitySpec)
        if not spec.cases:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message="no tool-use cases (inline 'cases' was empty)",
            )

        base_valid: list[bool] = []
        ft_valid: list[bool] = []
        ft_calls: list[dict[str, Any] | None] = []
        base_calls: list[dict[str, Any] | None] = []

        # Greedy decode each case under both views. We open one base
        # context for the whole sweep and one ft context — same shape
        # as preference_flip's per-triple loop, but at the case level
        # so a single backend toggle covers all N cases.
        with ctx.require_backend.as_base() as base_view:
            base_outputs = [
                base_view.generate(c.prompt, max_new_tokens=c.max_new_tokens) for c in spec.cases
            ]
        with ctx.require_backend.as_finetuned() as ft_view:
            ft_outputs = [
                ft_view.generate(c.prompt, max_new_tokens=c.max_new_tokens) for c in spec.cases
            ]

        for case, b_out, f_out in zip(spec.cases, base_outputs, ft_outputs, strict=True):
            b_call = _parse_tool_call(b_out)
            f_call = _parse_tool_call(f_out)
            base_calls.append(b_call)
            ft_calls.append(f_call)
            base_valid.append(b_call is not None and _matches_schema(b_call, case.tool_spec))
            ft_valid.append(f_call is not None and _matches_schema(f_call, case.tool_spec))

        n = len(spec.cases)
        base_valid_rate = sum(base_valid) / n
        ft_valid_rate = sum(ft_valid) / n
        validity_delta = ft_valid_rate - base_valid_rate

        # Argument-disagreement: only computed where BOTH views produced
        # a schema-valid call. Otherwise the comparison is between a
        # call and a non-call, which the probe scores via the validity
        # path instead.
        agreed_pairs = [
            (bc, fc)
            for bc, fc, bv, fv in zip(base_calls, ft_calls, base_valid, ft_valid, strict=True)
            if bv and fv and bc is not None and fc is not None
        ]
        if agreed_pairs:
            disagreement_per_pair = [_field_disagreement(bc, fc) for bc, fc in agreed_pairs]
            mean_arg_disagreement = statistics.fmean(disagreement_per_pair)
        else:
            mean_arg_disagreement = 0.0

        # Hallucination: over schema-valid ft calls, what fraction call
        # the wrong tool? Denominator excludes invalid calls so
        # validity-failure double-counts don't pollute the metric.
        ft_valid_calls = [
            (case, fc)
            for case, fc, fv in zip(spec.cases, ft_calls, ft_valid, strict=True)
            if fv and fc is not None
        ]
        if ft_valid_calls:
            hallucinated = [
                not _name_allowed(fc.get("name", ""), case.gold_tool_name, spec.allowed_tools)
                for case, fc in ft_valid_calls
            ]
            hallucination_rate = sum(hallucinated) / len(ft_valid_calls)
        else:
            hallucination_rate = 0.0

        # Z-score the ft validity rate against the null baseline. CIs
        # are over the per-case validity flags so they reflect sampling
        # noise on the proportion estimate.
        stats = get_null_stats(ctx, spec.kind)
        z = z_score(ft_valid_rate, stats)
        z_by_rank = z_scores_by_rank(ft_valid_rate, get_null_stats_by_rank(ctx, spec.kind), sign=+1)
        ci_95 = bootstrap_ci([1.0 if v else 0.0 for v in ft_valid], seed=ctx.seed)

        # Compound pass criterion: every gate must hold. The z-score
        # path only says the ft validity rate sits well above null-
        # adapter noise; it says nothing about the comparison to base,
        # so the raw gates still apply: a probe with z=4σ but a -20pp
        # validity delta still flags as a regression.
        validity_pass = validity_delta >= spec.assert_validity_delta_gte
        hallucination_pass = hallucination_rate <= spec.assert_hallucination_lte
        verdict_z = verdict_from_z(z, spec.assert_z_gte)

        if verdict_z is not None:
            z_path_pass = verdict_z == Verdict.PASS
            verdict = (
                Verdict.PASS
                if (z_path_pass and validity_pass and hallucination_pass)
                else Verdict.FAIL
            )
            base_msg = (
                f"validity ft={ft_valid_rate:.0%} (base={base_valid_rate:.0%}, "
                f"Δ={validity_delta:+.0%}), hallucination={hallucination_rate:.0%}, "
                f"z={z:+.2f}σ vs null"
            )
        else:
            verdict = Verdict.PASS if (validity_pass and hallucination_pass) else Verdict.FAIL
            base_msg = (
                f"validity ft={ft_valid_rate:.0%} (base={base_valid_rate:.0%}, "
                f"Δ={validity_delta:+.0%}), hallucination={hallucination_rate:.0%} "
                f"{no_calibration_note(spec.kind)}"
            )

        # Composite score multiplies the validity and hallucination
        # factors: perfect validity preservation + zero hallucination
        # → 1.0; full hallucination or full validity collapse → 0.0.
        # Argument disagreement stays evidence-only and never enters
        # the score.
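        # e.g. a -10pp validity delta with 20% hallucination scores 0.9 * 0.8 = 0.72.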
        validity_factor = max(0.0, min(1.0, 1.0 + validity_delta))
        hallucination_factor = max(0.0, 1.0 - hallucination_rate)
        score_raw = validity_factor * hallucination_factor
        # The z-score term only applies when calibration is available:
        # the final score blends 0.7 * score_raw with 0.3 * score_from_z(z),
        # so a strongly-significant adapter scores higher than a
        # marginally-significant one even at the same raw rates.
        z_score_val = score_from_z(z) if z is not None else None
        score = 0.7 * score_raw + 0.3 * z_score_val if z_score_val is not None else score_raw

        return safe_finalize(
            name=spec.name,
            kind=spec.kind,
            verdict=verdict,
            score=score,
            raw=ft_valid_rate,
            z_score=z,
            base_value=base_valid_rate,
            ft_value=ft_valid_rate,
            evidence={
                "json_valid_rate_base": base_valid_rate,
                "json_valid_rate_ft": ft_valid_rate,
                "validity_delta": validity_delta,
                "mean_arg_disagreement": mean_arg_disagreement,
                "hallucination_rate": hallucination_rate,
                "num_cases": n,
                "num_arg_pairs_compared": len(agreed_pairs),
                "z_by_rank": z_by_rank,
                "raw_ci_95": list(ci_95) if ci_95 is not None else None,
                "weight": spec.weight,
            },
            message=base_msg,
            ci_95=ci_95,
        )


# ---------------------------------------------------------------------------
# Tool-call parsing + minimal schema validator
# ---------------------------------------------------------------------------


def _parse_tool_call(text: str) -> dict[str, Any] | None:
    """Extract ``{"name": ..., "arguments": {...}}`` from a generation.

    Tries three increasingly forgiving strategies:

    1. The whole text parses as JSON and is a dict.
    2. The text contains a ``{...}`` substring that parses as JSON.
    3. The text contains a fenced triple-backtick ``json`` code block
       whose body parses as JSON.

    Returns ``None`` when no parse succeeds OR when the parsed object
    isn't a dict — both base and ft models often wander off into prose
    without producing a structured call; that's exactly the failure
    mode the probe scores.

    The returned dict isn't validated yet — callers that care about
    schema conformance run :func:`_matches_schema` separately so the
    "valid JSON, wrong shape" failure mode is observable downstream.
345 """
346 text = text.strip()
347 # Strategy 1.
348 try:
349 obj = json.loads(text)
350 except (ValueError, TypeError):
351 obj = None
352 if isinstance(obj, dict):
353 return obj
354
355 # Strategy 3 — fenced ```json ... ``` block first because its
356 # boundaries are unambiguous; embedded ``{...}`` heuristics in
357 # strategy 2 sometimes match a JSON-looking fragment inside a
358 # fenced block's body.
359 fenced = _extract_fenced_json(text)
360 if fenced is not None:
361 try:
362 obj = json.loads(fenced)
363 except (ValueError, TypeError):
364 obj = None
365 if isinstance(obj, dict):
366 return obj
367
368 # Strategy 2 — first balanced ``{...}`` substring.
369 candidate = _first_balanced_braces(text)
370 if candidate is not None:
371 try:
372 obj = json.loads(candidate)
373 except (ValueError, TypeError):
374 obj = None
375 if isinstance(obj, dict):
376 return obj
377
378 return None
379
380
381 def _extract_fenced_json(text: str) -> str | None:
382 """Return the body of a ```json ... ``` block, if any."""
383 marker = "```"
384 start = text.find(marker)
385 if start < 0:
386 return None
387 # Skip the opening fence (and an optional ``json`` language tag).
388 body_start = start + len(marker)
389 if text[body_start : body_start + 4].lower() == "json":
390 body_start += 4
391 # Strip a single trailing newline after the language tag.
392 if body_start < len(text) and text[body_start] == "\n":
393 body_start += 1
394 end = text.find(marker, body_start)
395 if end < 0:
396 return None
397 return text[body_start:end].strip()
398
399
400 def _first_balanced_braces(text: str) -> str | None:
401 """Return the first balanced ``{...}`` substring, or ``None``.
402
    Tracks string-literal context so a brace inside a string doesn't
    throw off the depth counter. A single linear scan, and cheap: the
    inputs are model generations capped at a few hundred tokens.
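
    Example (illustrative input)::

        >>> _first_balanced_braces('tool call: {"q": "a {brace} inside"} done')
        '{"q": "a {brace} inside"}'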
406 """
407 depth = 0
408 in_string = False
409 escape = False
410 start = -1
411 for i, ch in enumerate(text):
412 if in_string:
413 if escape:
414 escape = False
415 elif ch == "\\":
416 escape = True
417 elif ch == '"':
418 in_string = False
419 continue
420 if ch == '"':
421 in_string = True
422 continue
423 if ch == "{":
424 if depth == 0:
425 start = i
426 depth += 1
427 elif ch == "}":
428 depth -= 1
429 if depth == 0 and start >= 0:
430 return text[start : i + 1]
431 if depth < 0:
432 # Unbalanced — reset and keep scanning.
433 depth = 0
434 start = -1
435 return None
436
437
438 # Type tags we recognize. Matches the OpenAI function-call subset; we
439 # accept ``int`` for ``"integer"`` and ``int|float`` for ``"number"``.
440 _TYPE_CHECKS: dict[str, tuple[type, ...]] = {
441 "string": (str,),
442 "integer": (int,),
443 "number": (int, float),
444 "boolean": (bool,),
445 "object": (dict,),
446 "array": (list,),
447 }
448
449
450 def _matches_schema(call: dict[str, Any], tool_spec: dict[str, Any]) -> bool:
451 """Light OpenAI-flavored JSON-schema check.
452
453 Validates:
454
455 - The call has a ``name`` field (string) — the model is supposed to
456 identify the tool it's calling.
457 - The call has an ``arguments`` field (dict) — the OpenAI shape.
458 - Every name in ``parameters.required`` appears in ``arguments``.
459 - Every field declared in ``parameters.properties`` whose value is
460 present has a type that matches the property's ``type`` tag.
461
462 Doesn't validate: enums, ``minLength``/``maxLength``, nested
463 schemas, ``oneOf``/``anyOf``. Those land in a follow-up sprint
464 when users push tool specs that need them — the v1 surface is
465 deliberately the OpenAI-tutorial shape.
466
467 Booleans-as-integers caveat: Python's ``isinstance(True, int)``
468 is True, so we explicitly reject bools when the schema asks for
469 integer/number. Otherwise an adapter that emits ``true`` would
470 silently pass an "integer" check.
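
    Illustrative checks against a made-up spec::

        >>> spec = {"parameters": {
        ...     "type": "object",
        ...     "properties": {"query": {"type": "string"}, "max_results": {"type": "integer"}},
        ...     "required": ["query"],
        ... }}
        >>> _matches_schema({"name": "search_web", "arguments": {"query": "paris"}}, spec)
        True
        >>> _matches_schema({"name": "search_web", "arguments": {"max_results": 3}}, spec)
        False
        >>> _matches_schema(
        ...     {"name": "search_web", "arguments": {"query": "hi", "max_results": True}}, spec
        ... )
        False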
471 """
472 if not isinstance(call.get("name"), str):
473 return False
474 args = call.get("arguments")
475 if not isinstance(args, dict):
476 return False
477
478 parameters = tool_spec.get("parameters") or {}
479 properties = parameters.get("properties") or {}
480 required = parameters.get("required") or []
481
482 for r in required:
483 if r not in args:
484 return False
485
486 for field_name, field_value in args.items():
487 prop = properties.get(field_name)
488 if not prop:
489 # Unknown field — OpenAI's actual function-call rules
490 # forbid this, but we permit it because real models often
491 # add benign extra fields. Validity stays True.
492 continue
493 type_tag = prop.get("type")
494 if type_tag is None:
495 continue
496 allowed = _TYPE_CHECKS.get(type_tag)
497 if allowed is None:
498 continue
499 if isinstance(field_value, bool) and bool not in allowed:
500 return False
501 if not isinstance(field_value, allowed):
502 return False
503 return True
504
505
506 def _name_allowed(name: str, gold: str, allowed_tools: list[str] | None) -> bool:
507 """Return True iff the model's tool ``name`` is acceptable.
508
509 Rules:
510
511 - ``allowed_tools`` declared → ``name`` must be in the list.
512 - Otherwise → ``name`` must equal ``gold``.
513
514 The split exists because some users want a per-case strict gold
    (the case is checking "did the model pick the *right* tool")
    while others want a surface gate ("did the model stay inside the
    declared tool list").
518 """
519 if allowed_tools is not None:
520 return name in allowed_tools
521 return name == gold
522
523
524 def _field_disagreement(base_call: dict[str, Any], ft_call: dict[str, Any]) -> float:
525 """Leaf-field disagreement rate between two parsed tool calls.
526
527 Walks the union of leaf paths in ``base_call.arguments`` and
528 ``ft_call.arguments``; counts a disagreement for each leaf where
529 the values differ (or the path exists on only one side). Rate is
530 ``disagreements / total_leaves``; returns ``0.0`` when both
531 arguments dicts are empty (a vacuously-perfect agreement).
532
533 Intentionally simple: no semantic equality (``"2.0"`` ≠ ``2.0``
534 by design — drift in numeric *type* counts as drift). Probes
535 that want softer comparison (numeric tolerance, embedding
536 distance on strings) build on this in a follow-up sprint.
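
    Example (illustrative calls; one of two leaves differs)::

        >>> _field_disagreement(
        ...     {"name": "search_web", "arguments": {"query": "paris", "max_results": 5}},
        ...     {"name": "search_web", "arguments": {"query": "paris", "max_results": 10}},
        ... )
        0.5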
537 """
538 base_args = base_call.get("arguments") or {}
539 ft_args = ft_call.get("arguments") or {}
540 if not isinstance(base_args, dict) or not isinstance(ft_args, dict):
541 # One side wasn't a real arguments dict; treat as full disagreement.
542 return 1.0
543 base_leaves = dict(_walk_leaves(base_args))
544 ft_leaves = dict(_walk_leaves(ft_args))
545 all_paths = set(base_leaves) | set(ft_leaves)
546 if not all_paths:
547 return 0.0
548 disagreements = sum(
549 1 for p in all_paths if base_leaves.get(p, _MISSING) != ft_leaves.get(p, _MISSING)
550 )
551 return disagreements / len(all_paths)
552
553
554 _MISSING = object()
555 """Sentinel for "field not present" in :func:`_field_disagreement`. A
556 leaf with value ``None`` should NOT compare equal to a leaf that's
557 absent — using a private singleton makes the distinction explicit
558 where ``dict.get(k)`` would conflate the two."""
559
560
561 def _walk_leaves(obj: Any, path: tuple[str, ...] = ()) -> Any:
562 """Yield ``(path, leaf_value)`` pairs for nested dicts.
563
    Lists are leaves themselves (not recursed into) — element-by-element
    list drift would otherwise dominate the disagreement metric in cases
    where the model legitimately returns the same set in a different
    order; treating the whole list as one leaf caps that at a single
    disagreement.
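
    Example (illustrative; note the list stays a single leaf)::

        >>> dict(_walk_leaves({"a": {"b": 1}, "c": [1, 2]}))
        {('a', 'b'): 1, ('c',): [1, 2]}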
568 """
569 if isinstance(obj, dict):
570 for k, v in obj.items():
571 yield from _walk_leaves(v, path + (str(k),))
572 else:
573 yield path, obj