| 1 |
"""Mismatch severity table for `dlm.lock` validation (Sprint 15). |
| 2 |
|
| 3 |
Decides for each field how loudly a drift between the recorded lock |
| 4 |
and the current runtime should surface: |
| 5 |
|
| 6 |
- `ALLOW` — drift is fine; `dlm_sha256` is the obvious one (editing |
| 7 |
the `.dlm` is the point). |
| 8 |
- `WARN` — drift deserves a printed message but doesn't block the |
| 9 |
run. `--strict-lock` upgrades all WARNs to ERROR. |
| 10 |
- `ERROR` — drift blocks the run unless the operator explicitly |
| 11 |
accepts via `--update-lock` (overwrite + continue) or |
| 12 |
`--ignore-lock` (continue without touching the lock). |
| 13 |
|
| 14 |
Each rule is a pure function from the prior & current `DlmLock` to |
| 15 |
`(severity, message) | None`. The validator drives them all; adding |
| 16 |
a new field only requires adding a rule here. |
| 17 |
""" |
| 18 |
|
| 19 |
from __future__ import annotations |
| 20 |
|
| 21 |
from collections.abc import Callable |
| 22 |
from enum import Enum |
| 23 |
from typing import Final |
| 24 |
|
| 25 |
from packaging.version import InvalidVersion, Version |
| 26 |
|
| 27 |
from dlm.lock.schema import DlmLock |
| 28 |
|
| 29 |
|
| 30 |
class Severity(Enum):
    """How loudly a drifted lock field should be reported.

    The three levels are ordered from quietest to loudest: `ALLOW`,
    `WARN`, `ERROR`.
    """

    # Drift is expected and never blocks the run.
    ALLOW = "allow"
    # Drift is reported; strict mode treats it as an ERROR.
    WARN = "warn"
    # Drift blocks the run unless the operator explicitly accepts it.
    ERROR = "error"
| 36 |
|
| 37 |
|
| 38 |
# A rule receives the recorded lock first and the current lock second.
# It returns a `(severity, message)` pair describing the drift, or `None`
# when the field it inspects is unchanged.
Rule = Callable[
    [DlmLock, DlmLock],
    tuple[Severity, str] | None,
]
| 41 |
|
| 42 |
|
| 43 |
# --- version parsing helper ------------------------------------------------- |
| 44 |
|
| 45 |
|
| 46 |
def _major(version_str: str) -> int | None:
    """Return the major component of `version_str`, or None if it cannot be parsed.

    Returning `None` instead of raising is intentional: packages that do
    not use parseable version strings (llama.cpp's `b8816` tag, for
    example) should have any change treated as minor drift rather than
    as a blocking major-version error.
    """
    try:
        parsed = Version(version_str)
    except InvalidVersion:
        # The string is not something `Version` accepts; signal "unknown".
        return None
    return parsed.major
| 57 |
|
| 58 |
|
| 59 |
# --- per-field rules -------------------------------------------------------- |
| 60 |
|
| 61 |
|
| 62 |
def _rule_dlm_sha(prior: DlmLock, current: DlmLock) -> tuple[Severity, str] | None: |
| 63 |
if prior.dlm_sha256 == current.dlm_sha256: |
| 64 |
return None |
| 65 |
# Editing the .dlm is the whole point of `dlm train` — never block. |
| 66 |
return ( |
| 67 |
Severity.ALLOW, |
| 68 |
f"dlm_sha256 changed ({prior.dlm_sha256[:12]}… → {current.dlm_sha256[:12]}…)", |
| 69 |
) |
| 70 |
|
| 71 |
|
| 72 |
def _rule_base_revision(prior: DlmLock, current: DlmLock) -> tuple[Severity, str] | None: |
| 73 |
if prior.base_model_revision == current.base_model_revision: |
| 74 |
return None |
| 75 |
return ( |
| 76 |
Severity.ERROR, |
| 77 |
f"base_model_revision changed ({prior.base_model_revision} → " |
| 78 |
f"{current.base_model_revision}); re-run with --update-lock to accept", |
| 79 |
) |
| 80 |
|
| 81 |
|
| 82 |
def _rule_hardware_tier(prior: DlmLock, current: DlmLock) -> tuple[Severity, str] | None: |
| 83 |
if prior.hardware_tier == current.hardware_tier: |
| 84 |
return None |
| 85 |
return ( |
| 86 |
Severity.WARN, |
| 87 |
f"hardware_tier changed ({prior.hardware_tier} → {current.hardware_tier}); " |
| 88 |
"re-plan recommended", |
| 89 |
) |
| 90 |
|
| 91 |
|
| 92 |
def _rule_determinism_class(prior: DlmLock, current: DlmLock) -> tuple[Severity, str] | None: |
| 93 |
if prior.determinism_class == current.determinism_class: |
| 94 |
return None |
| 95 |
return ( |
| 96 |
Severity.WARN, |
| 97 |
f"determinism_class changed ({prior.determinism_class} → {current.determinism_class})", |
| 98 |
) |
| 99 |
|
| 100 |
|
| 101 |
def _rule_determinism_flags(prior: DlmLock, current: DlmLock) -> tuple[Severity, str] | None: |
| 102 |
if prior.determinism_flags == current.determinism_flags: |
| 103 |
return None |
| 104 |
return (Severity.WARN, "determinism_flags changed") |
| 105 |
|
| 106 |
|
| 107 |
def _rule_torch_version(prior: DlmLock, current: DlmLock) -> tuple[Severity, str] | None: |
| 108 |
prior_v = prior.pinned_versions.get("torch") |
| 109 |
current_v = current.pinned_versions.get("torch") |
| 110 |
if prior_v is None or current_v is None or prior_v == current_v: |
| 111 |
return None |
| 112 |
prior_major = _major(prior_v) |
| 113 |
current_major = _major(current_v) |
| 114 |
if prior_major is not None and current_major is not None and prior_major != current_major: |
| 115 |
return ( |
| 116 |
Severity.ERROR, |
| 117 |
f"torch major-version mismatch ({prior_v} → {current_v})", |
| 118 |
) |
| 119 |
return (Severity.WARN, f"torch minor-version drift ({prior_v} → {current_v})") |
| 120 |
|
| 121 |
|
| 122 |
def _rule_bitsandbytes_any(prior: DlmLock, current: DlmLock) -> tuple[Severity, str] | None: |
| 123 |
prior_v = prior.pinned_versions.get("bitsandbytes") |
| 124 |
current_v = current.pinned_versions.get("bitsandbytes") |
| 125 |
if prior_v == current_v: |
| 126 |
return None |
| 127 |
# Any bnb drift is a strong warning — QLoRA correctness is unusually |
| 128 |
# sensitive to bnb kernels. |
| 129 |
return ( |
| 130 |
Severity.WARN, |
| 131 |
f"bitsandbytes changed ({prior_v!r} → {current_v!r}); QLoRA kernels are version-sensitive", |
| 132 |
) |
| 133 |
|
| 134 |
|
| 135 |
def _rule_minor_peers(prior: DlmLock, current: DlmLock) -> list[tuple[Severity, str]]: |
| 136 |
"""WARN on drift for transformers / peft / trl / accelerate / llama_cpp.""" |
| 137 |
keys = ("transformers", "peft", "trl", "accelerate", "llama_cpp") |
| 138 |
mismatches: list[tuple[Severity, str]] = [] |
| 139 |
for key in keys: |
| 140 |
prior_v = prior.pinned_versions.get(key) |
| 141 |
current_v = current.pinned_versions.get(key) |
| 142 |
if prior_v == current_v: |
| 143 |
continue |
| 144 |
mismatches.append( |
| 145 |
(Severity.WARN, f"{key} changed ({prior_v!r} → {current_v!r})"), |
| 146 |
) |
| 147 |
return mismatches |
| 148 |
|
| 149 |
|
| 150 |
# --- driver ----------------------------------------------------------------- |
| 151 |
|
| 152 |
|
| 153 |
# Single-outcome rules, applied by `classify_mismatches` in this order.
# `_rule_minor_peers` is not listed because it returns a list of mismatches
# rather than a single `Rule` outcome; the driver calls it separately.
DEFAULT_RULES: Final[tuple[Rule, ...]] = (
    _rule_dlm_sha,
    _rule_base_revision,
    _rule_hardware_tier,
    _rule_determinism_class,
    _rule_determinism_flags,
    _rule_torch_version,
    _rule_bitsandbytes_any,
)
| 162 |
|
| 163 |
|
| 164 |
def classify_mismatches(
    prior: DlmLock,
    current: DlmLock,
    *,
    strict: bool = False,
) -> list[tuple[Severity, str]]:
    """Collect every field-level mismatch between `prior` and `current`.

    Results from `DEFAULT_RULES` come first, in rule order, followed by
    the peer-package mismatches from `_rule_minor_peers`.

    Args:
        prior: The recorded lock.
        current: The lock describing the current runtime.
        strict: When true, every `WARN` is reported as `ERROR`. `ALLOW`
            is left alone even in strict mode; otherwise `--strict-lock`
            would block every retrain.

    Returns:
        A list of `(severity, message)` pairs; empty when nothing drifted.
    """
    outcomes = (rule(prior, current) for rule in DEFAULT_RULES)
    found: list[tuple[Severity, str]] = [
        outcome for outcome in outcomes if outcome is not None
    ]
    found.extend(_rule_minor_peers(prior, current))
    if not strict:
        return found
    # Strict mode promotes WARN to ERROR and leaves other severities as-is.
    return [
        (Severity.ERROR, msg) if sev is Severity.WARN else (sev, msg)
        for sev, msg in found
    ]