Python · 6330 bytes Raw Blame History
1 """Mismatch severity table for `dlm.lock` validation (Sprint 15).
2
3 Decides for each field how loudly a drift between the recorded lock
4 and the current runtime should surface:
5
6 - `ALLOW` — drift is fine; `dlm_sha256` is the obvious one (editing
7 the `.dlm` is the point).
8 - `WARN` — drift deserves a printed message but doesn't block the
9 run. `--strict-lock` upgrades all WARNs to ERROR.
10 - `ERROR` — drift blocks the run unless the operator explicitly
11 accepts via `--update-lock` (overwrite + continue) or
12 `--ignore-lock` (continue without touching the lock).
13
14 Each rule is a pure function from the prior & current `DlmLock` to
15 `(severity, message) | None`. The validator drives them all; adding
16 a new field only requires adding a rule here.
17 """
18
19 from __future__ import annotations
20
21 from collections.abc import Callable
22 from enum import Enum
23 from typing import Final
24
25 from packaging.version import InvalidVersion, Version
26
27 from dlm.lock.schema import DlmLock
28
29
class Severity(Enum):
    """How loudly a drift in one lock field should surface.

    Members are declared from least to most disruptive.
    """

    ALLOW = "allow"  # drift is expected; never blocks
    WARN = "warn"  # print a message, keep running (ERROR under --strict-lock)
    ERROR = "error"  # block unless --update-lock / --ignore-lock is given
36
37
# Rule signature: takes (prior, current) and returns a `(severity, message)`
# pair describing the mismatch, or `None` when the field hasn't drifted.
# Checks that can yield several mismatches at once (e.g. `_rule_minor_peers`,
# which returns a list) do not fit this shape and are invoked separately by
# `classify_mismatches`.
# NOTE: this alias is evaluated at import time (`from __future__ import
# annotations` only defers annotations), so the `X | None` union below
# requires Python 3.10+.
Rule = Callable[[DlmLock, DlmLock], tuple[Severity, str] | None]
41
42
43 # --- version parsing helper -------------------------------------------------
44
45
def _major(version_str: str) -> int | None:
    """Return the major (first release) component of `version_str`.

    Strings that `packaging` cannot parse yield `None` rather than raising.
    That is intentional: tags outside PEP 440 (such as llama.cpp's `b8816`)
    should be treated by callers as ordinary drift, not as a major-version
    break that blocks the run.
    """
    try:
        parsed = Version(version_str)
    except InvalidVersion:
        return None
    return parsed.major
57
58
59 # --- per-field rules --------------------------------------------------------
60
61
def _rule_dlm_sha(prior: DlmLock, current: DlmLock) -> tuple[Severity, str] | None:
    """Report a change to the recorded `.dlm` digest, always at `ALLOW`.

    Returns `None` when both locks carry the same `dlm_sha256`; otherwise an
    `ALLOW` outcome whose message shows the first 12 characters of each digest.
    """
    old_sha = prior.dlm_sha256
    new_sha = current.dlm_sha256
    # Editing the .dlm is the whole point of `dlm train` — never block.
    return None if old_sha == new_sha else (
        Severity.ALLOW,
        f"dlm_sha256 changed ({old_sha[:12]}… → {new_sha[:12]}…)",
    )
70
71
def _rule_base_revision(prior: DlmLock, current: DlmLock) -> tuple[Severity, str] | None:
    """Treat any change of `base_model_revision` as a blocking `ERROR`.

    Returns `None` when the revisions match; otherwise an `ERROR` outcome
    whose message tells the operator how to accept the new revision.
    """
    before = prior.base_model_revision
    after = current.base_model_revision
    return None if before == after else (
        Severity.ERROR,
        f"base_model_revision changed ({before} → {after}); "
        "re-run with --update-lock to accept",
    )
80
81
def _rule_hardware_tier(prior: DlmLock, current: DlmLock) -> tuple[Severity, str] | None:
    """WARN when the recorded `hardware_tier` differs between the two locks.

    Returns `None` when the tiers match; otherwise a `WARN` outcome whose
    message shows the old and new tier and recommends re-planning.
    """
    if prior.hardware_tier == current.hardware_tier:
        return None
    # The " → " separator keeps the two tier values readable and matches the
    # message format used by the dlm_sha256 / base_model_revision rules.
    return (
        Severity.WARN,
        f"hardware_tier changed ({prior.hardware_tier} → {current.hardware_tier}); "
        "re-plan recommended",
    )
90
91
def _rule_determinism_class(prior: DlmLock, current: DlmLock) -> tuple[Severity, str] | None:
    """WARN when the recorded `determinism_class` differs between the two locks.

    Returns `None` when the classes match; otherwise a `WARN` outcome whose
    message shows the old and new class.
    """
    if prior.determinism_class == current.determinism_class:
        return None
    # " → " separates old from new, consistent with the other rules' messages.
    return (
        Severity.WARN,
        f"determinism_class changed ({prior.determinism_class} → {current.determinism_class})",
    )
99
100
def _rule_determinism_flags(prior: DlmLock, current: DlmLock) -> tuple[Severity, str] | None:
    """WARN when `determinism_flags` differ; the message does not list which flags."""
    unchanged = prior.determinism_flags == current.determinism_flags
    return None if unchanged else (Severity.WARN, "determinism_flags changed")
105
106
def _rule_torch_version(prior: DlmLock, current: DlmLock) -> tuple[Severity, str] | None:
    """Compare the pinned `torch` versions recorded in two locks.

    Returns:
        `None` when either lock has no `torch` pin, or both pins are equal.
        `(Severity.ERROR, ...)` when both versions parse and their major
        components differ.
        `(Severity.WARN, ...)` for any other change, including when either
        version string cannot be parsed.
    """
    prior_v = prior.pinned_versions.get("torch")
    current_v = current.pinned_versions.get("torch")
    # A pin missing on either side is skipped rather than reported.
    if prior_v is None or current_v is None or prior_v == current_v:
        return None
    prior_major = _major(prior_v)
    current_major = _major(current_v)
    if prior_major is not None and current_major is not None and prior_major != current_major:
        return (
            Severity.ERROR,
            f"torch major-version mismatch ({prior_v} → {current_v})",
        )
    # Same major, or at least one side unparseable (`_major` returned None):
    # surface as non-blocking drift.
    return (Severity.WARN, f"torch minor-version drift ({prior_v} → {current_v})")
120
121
def _rule_bitsandbytes_any(prior: DlmLock, current: DlmLock) -> tuple[Severity, str] | None:
    """WARN on any change of the pinned `bitsandbytes` version.

    Unlike the torch rule, a pin appearing or disappearing also counts as
    drift; a missing pin is rendered as `None` in the message.
    """
    prior_v = prior.pinned_versions.get("bitsandbytes")
    current_v = current.pinned_versions.get("bitsandbytes")
    if prior_v == current_v:
        return None
    # Any bnb drift is a strong warning — QLoRA correctness is unusually
    # sensitive to bnb kernels.
    return (
        Severity.WARN,
        f"bitsandbytes changed ({prior_v!r} → {current_v!r}); QLoRA kernels are version-sensitive",
    )
133
134
def _rule_minor_peers(prior: DlmLock, current: DlmLock) -> list[tuple[Severity, str]]:
    """WARN on drift for transformers / peft / trl / accelerate / llama_cpp.

    Unlike the single-outcome rules, this checks several packages at once and
    returns a (possibly empty) list with one `WARN` entry per package whose
    pin differs, in the order listed above. A pin appearing or disappearing
    counts as drift; a missing pin is rendered as `None` in the message.
    """
    keys = ("transformers", "peft", "trl", "accelerate", "llama_cpp")
    pins = (
        (key, prior.pinned_versions.get(key), current.pinned_versions.get(key))
        for key in keys
    )
    return [
        (Severity.WARN, f"{key} changed ({prior_v!r} → {current_v!r})")
        for key, prior_v, current_v in pins
        if prior_v != current_v
    ]
148
149
150 # --- driver -----------------------------------------------------------------
151
152
# Single-outcome rules run by `classify_mismatches`, in this order. Each
# rule's mismatch (if any) is appended in sequence, so this tuple's order is
# the order of entries in the returned list. `_rule_minor_peers` is not listed
# because it returns a list rather than one `(severity, message) | None`
# outcome; `classify_mismatches` calls it separately, after these.
DEFAULT_RULES: Final[tuple[Rule, ...]] = (
    _rule_dlm_sha,
    _rule_base_revision,
    _rule_hardware_tier,
    _rule_determinism_class,
    _rule_determinism_flags,
    _rule_torch_version,
    _rule_bitsandbytes_any,
)
162
163
def classify_mismatches(
    prior: DlmLock,
    current: DlmLock,
    *,
    strict: bool = False,
) -> list[tuple[Severity, str]]:
    """Collect every field-level mismatch between `prior` and `current`.

    Outcomes from `DEFAULT_RULES` come first, in rule order, followed by the
    per-package warnings from `_rule_minor_peers`.

    With `strict=True`, every `WARN` is escalated to `ERROR`. `ALLOW` is left
    untouched even then — allowed drift stays allowed, otherwise
    `--strict-lock` would block every retrain.
    """
    found = [
        outcome
        for rule in DEFAULT_RULES
        if (outcome := rule(prior, current)) is not None
    ]
    found += _rule_minor_peers(prior, current)
    if not strict:
        return found
    return [
        (Severity.ERROR, message) if severity is Severity.WARN else (severity, message)
        for severity, message in found
    ]