style: ruff format pass + embed_warmup ARG002 ignore for HF callback protocol
SHA: b6deabb764cf7506bb56284fad6d257aa76ce586
Parent: 19f5a6c
Tree: 557a338

pyproject.toml (modified) @@ -137,6 +137,10 @@ ignore = [ | ||
| 137 | 137 | # will take so `--help` reflects the shipping surface — even though the |
| 138 | 138 | # stub body discards them. |
| 139 | 139 | "src/dlm/cli/commands.py" = ["ARG001"] |
| 140 | +# HuggingFace Trainer callbacks MUST accept `args`/`state`/`control` | |
| 141 | +# positionally even when the implementation only reads some of them — | |
| 142 | +# HF dispatches them by position. ARG002 for these wrappers is noise. | |
| 143 | +"src/dlm/train/cpt/embed_warmup.py" = ["ARG002"] | |
| 140 | 144 | |
| 141 | 145 | [tool.ruff.format] |
| 142 | 146 | quote-style = "double" |
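For context on the new ignore, here is a minimal, hypothetical HF-Trainer-style callback (not code from this repo) showing why ARG002 fires on these wrappers: the Trainer invokes each hook as `hook(args, state, control, **kwargs)` positionally, so all three parameters must stay in the signature even when the body only reads `state`.

    from typing import Any

    class StepLoggerCallback:
        """Hypothetical callback; mirrors the hook shape HF's Trainer dispatches to."""

        def on_step_end(self, args: Any, state: Any, control: Any, **kwargs: Any) -> None:
            # Only `state` is read; `args` and `control` are required by the
            # positional calling convention, so Ruff's ARG002 ("unused method
            # argument") would flag them without the per-file ignore above.
            if state.global_step % 100 == 0:
                print(f"step {state.global_step}")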
src/dlm/base_models/probes.py (modified) @@ -378,9 +378,7 @@ def probe_pretokenizer_hash( | ||
| 378 | 378 | # --- aggregate --------------------------------------------------------------- |
| 379 | 379 | |
| 380 | 380 | |
| 381 | -def run_all( | |
| 382 | - spec: BaseModelSpec, *, skip_export_probes: bool = False | |
| 383 | -) -> ProbeReport: | |
| 381 | +def run_all(spec: BaseModelSpec, *, skip_export_probes: bool = False) -> ProbeReport: | |
| 384 | 382 | """Run every probe; aggregate into a `ProbeReport`. |
| 385 | 383 | |
| 386 | 384 | `GatedModelError` from an individual probe propagates immediately — |
src/dlm/cli/commands.py (modified) @@ -139,9 +139,7 @@ def init_cmd( | ||
| 139 | 139 | from dlm.templates import TemplateError, apply_template |
| 140 | 140 | |
| 141 | 141 | try: |
| 142 | - applied_result = apply_template( | |
| 143 | - template, path, force=force, accept_license=True | |
| 144 | - ) | |
| 142 | + applied_result = apply_template(template, path, force=force, accept_license=True) | |
| 145 | 143 | except TemplateError as exc: |
| 146 | 144 | console.print(f"[red]init:[/red] {exc}") |
| 147 | 145 | raise typer.Exit(code=1) from exc |
@@ -1494,9 +1492,7 @@ def metrics_cmd( | ||
| 1494 | 1492 | parsed = parse_file(path) |
| 1495 | 1493 | store = for_dlm(parsed.frontmatter.dlm_id) |
| 1496 | 1494 | |
| 1497 | - runs = recent_runs( | |
| 1498 | - store.root, limit=limit, phase=phase, since=since_delta, run_id=run_id | |
| 1499 | - ) | |
| 1495 | + runs = recent_runs(store.root, limit=limit, phase=phase, since=since_delta, run_id=run_id) | |
| 1500 | 1496 | |
| 1501 | 1497 | if run_id is not None: |
| 1502 | 1498 | # Drill-down: show this run's steps + evals. |
@@ -1520,9 +1516,7 @@ def metrics_cmd( | ||
| 1520 | 1516 | writer.writerow(["step", "loss", "lr", "grad_norm", "val_loss"]) |
| 1521 | 1517 | eval_by_step = {e.step: e.val_loss for e in evals} |
| 1522 | 1518 | for s in steps: |
| 1523 | - writer.writerow( | |
| 1524 | - [s.step, s.loss, s.lr, s.grad_norm, eval_by_step.get(s.step)] | |
| 1525 | - ) | |
| 1519 | + writer.writerow([s.step, s.loss, s.lr, s.grad_norm, eval_by_step.get(s.step)]) | |
| 1526 | 1520 | return |
| 1527 | 1521 | console.print( |
| 1528 | 1522 | f"[green]run_id={run.run_id}[/green] phase={run.phase} " |
@@ -1545,9 +1539,7 @@ def metrics_cmd( | ||
| 1545 | 1539 | writer = csv.writer(sys.stdout) |
| 1546 | 1540 | writer.writerow(["run_id", "phase", "seed", "status", "started_at", "ended_at"]) |
| 1547 | 1541 | for r in runs: |
| 1548 | - writer.writerow( | |
| 1549 | - [r.run_id, r.phase, r.seed, r.status, r.started_at, r.ended_at] | |
| 1550 | - ) | |
| 1542 | + writer.writerow([r.run_id, r.phase, r.seed, r.status, r.started_at, r.ended_at]) | |
| 1551 | 1543 | return |
| 1552 | 1544 | |
| 1553 | 1545 | if not runs: |
@@ -1690,8 +1682,8 @@ def show_cmd( | ||
| 1690 | 1682 | raise typer.Exit(code=1) from exc |
| 1691 | 1683 | |
| 1692 | 1684 | store = for_dlm(parsed.frontmatter.dlm_id) |
| 1693 | - training_sources, discovered_configs = ( | |
| 1694 | - _summarize_training_sources_and_discovered(parsed, path.resolve().parent) | |
| 1685 | + training_sources, discovered_configs = _summarize_training_sources_and_discovered( | |
| 1686 | + parsed, path.resolve().parent | |
| 1695 | 1687 | ) |
| 1696 | 1688 | # Store may not exist yet (no `dlm train` run). Treat that as an |
| 1697 | 1689 | # informational state rather than an error — useful after `dlm init`. |
@@ -1830,9 +1822,7 @@ def _human_size(n: int) -> str: | ||
| 1830 | 1822 | return f"{n} PB" |
| 1831 | 1823 | |
| 1832 | 1824 | |
| 1833 | -def _summarize_training_sources( | |
| 1834 | - parsed: object, base_path: Path | |
| 1835 | -) -> list[dict[str, object]] | None: | |
| 1825 | +def _summarize_training_sources(parsed: object, base_path: Path) -> list[dict[str, object]] | None: | |
| 1836 | 1826 | """Best-effort resolution of `training.sources` for `dlm show`. |
| 1837 | 1827 | |
| 1838 | 1828 | Returns None when the frontmatter declares no directives; returns |
@@ -1905,9 +1895,7 @@ def _summarize_training_sources_and_discovered( | ||
| 1905 | 1895 | "has_ignore": bool(dc.ignore_rules), |
| 1906 | 1896 | "include": list(dc.config.include) if dc.config else [], |
| 1907 | 1897 | "exclude": list(dc.config.exclude) if dc.config else [], |
| 1908 | - "exclude_defaults": ( | |
| 1909 | - dc.config.exclude_defaults if dc.config else True | |
| 1910 | - ), | |
| 1898 | + "exclude_defaults": (dc.config.exclude_defaults if dc.config else True), | |
| 1911 | 1899 | "metadata": dict(dc.config.metadata) if dc.config else {}, |
| 1912 | 1900 | "ignore_rules": len(dc.ignore_rules), |
| 1913 | 1901 | } |
@@ -1915,9 +1903,7 @@ def _summarize_training_sources_and_discovered( | ||
| 1915 | 1903 | return records, discovered_records |
| 1916 | 1904 | |
| 1917 | 1905 | |
| 1918 | -def _summarize_training_cache( | |
| 1919 | - cache_dir: Path, store_root: Path | |
| 1920 | -) -> dict[str, object] | None: | |
| 1906 | +def _summarize_training_cache(cache_dir: Path, store_root: Path) -> dict[str, object] | None: | |
| 1921 | 1907 | """Return a JSON-friendly snapshot of the tokenized-section cache. |
| 1922 | 1908 | |
| 1923 | 1909 | None when the cache dir doesn't exist (store never trained with |
@@ -1956,9 +1942,7 @@ def _render_training_cache_text(console: object, snap: dict[str, object]) -> Non | ||
| 1956 | 1942 | console.print(f" last hit rate: {float(rate):.1%}") |
| 1957 | 1943 | |
| 1958 | 1944 | |
| 1959 | -def _render_training_sources_text( | |
| 1960 | - console: object, records: list[dict[str, object]] | |
| 1961 | -) -> None: | |
| 1945 | +def _render_training_sources_text(console: object, records: list[dict[str, object]]) -> None: | |
| 1962 | 1946 | from rich.console import Console |
| 1963 | 1947 | |
| 1964 | 1948 | assert isinstance(console, Console) |
@@ -2110,9 +2094,7 @@ def push_cmd( | ||
| 2110 | 2094 | bool, |
| 2111 | 2095 | typer.Option("--sign", help="Sign the pack with minisign before upload."), |
| 2112 | 2096 | ] = False, |
| 2113 | - include_exports: Annotated[ | |
| 2114 | - bool, typer.Option("--include-exports") | |
| 2115 | - ] = False, | |
| 2097 | + include_exports: Annotated[bool, typer.Option("--include-exports")] = False, | |
| 2116 | 2098 | include_base: Annotated[bool, typer.Option("--include-base")] = False, |
| 2117 | 2099 | include_logs: Annotated[bool, typer.Option("--include-logs")] = False, |
| 2118 | 2100 | licensee: Annotated[ |
@@ -2149,10 +2131,7 @@ def push_cmd( | ||
| 2149 | 2131 | raise typer.Exit(code=1) from exc |
| 2150 | 2132 | |
| 2151 | 2133 | size_mb = result.bytes_sent / (1024 * 1024) |
| 2152 | - console.print( | |
| 2153 | - f"[green]pushed:[/green] {result.destination} " | |
| 2154 | - f"({size_mb:.2f} MB)" | |
| 2155 | - ) | |
| 2134 | + console.print(f"[green]pushed:[/green] {result.destination} ({size_mb:.2f} MB)") | |
| 2156 | 2135 | if result.sink_kind.value == "hf": |
| 2157 | 2136 | console.print(f"[dim]install:[/dim] dlm pull {result.destination}") |
| 2158 | 2137 | if result.detail: |
@@ -2197,10 +2176,7 @@ def pull_cmd( | ||
| 2197 | 2176 | raise typer.Exit(code=1) from exc |
| 2198 | 2177 | |
| 2199 | 2178 | size_mb = result.bytes_received / (1024 * 1024) |
| 2200 | - console.print( | |
| 2201 | - f"[green]pulled:[/green] {result.source} → {result.dlm_path} " | |
| 2202 | - f"({size_mb:.2f} MB)" | |
| 2203 | - ) | |
| 2179 | + console.print(f"[green]pulled:[/green] {result.source} → {result.dlm_path} ({size_mb:.2f} MB)") | |
| 2204 | 2180 | |
| 2205 | 2181 | status = result.verification.status |
| 2206 | 2182 | if status == VerifyStatus.VERIFIED: |
@@ -2265,8 +2241,7 @@ def serve_cmd( | ||
| 2265 | 2241 | store = for_dlm(dlm_id) |
| 2266 | 2242 | if not store.manifest.exists(): |
| 2267 | 2243 | console.print( |
| 2268 | - f"[red]serve:[/red] no training state for {dlm_id} — run " | |
| 2269 | - "[bold]dlm train[/bold] first." | |
| 2244 | + f"[red]serve:[/red] no training state for {dlm_id} — run [bold]dlm train[/bold] first." | |
| 2270 | 2245 | ) |
| 2271 | 2246 | raise typer.Exit(code=1) |
| 2272 | 2247 | |
@@ -2387,8 +2362,7 @@ def cache_prune_cmd( | ||
| 2387 | 2362 | seconds = _parse_duration(older_than) |
| 2388 | 2363 | if seconds is None: |
| 2389 | 2364 | console.print( |
| 2390 | - f"[red]cache:[/red] invalid --older-than {older_than!r} " | |
| 2391 | - "(expected e.g. 30d, 12h, 45m)" | |
| 2365 | + f"[red]cache:[/red] invalid --older-than {older_than!r} (expected e.g. 30d, 12h, 45m)" | |
| 2392 | 2366 | ) |
| 2393 | 2367 | raise typer.Exit(code=2) |
| 2394 | 2368 | |
src/dlm/cli/scaffold.py (modified) @@ -83,9 +83,7 @@ def scaffold_train_target( | ||
| 83 | 83 | if not target.exists(): |
| 84 | 84 | raise ScaffoldError(f"target does not exist: {target}", path=target) |
| 85 | 85 | if not target.is_dir(): |
| 86 | - raise ScaffoldError( | |
| 87 | - f"scaffold expects a directory, got file: {target}", path=target | |
| 88 | - ) | |
| 86 | + raise ScaffoldError(f"scaffold expects a directory, got file: {target}", path=target) | |
| 89 | 87 | |
| 90 | 88 | dlm_dir = target / _SCAFFOLD_DIR |
| 91 | 89 | existing = sorted(dlm_dir.glob("*.dlm")) if dlm_dir.is_dir() else [] |
@@ -105,13 +103,9 @@ def scaffold_train_target( | ||
| 105 | 103 | return ScaffoldResult(dlm_path=named_match, scaffolded=False, dlm_id=dlm_id) |
| 106 | 104 | if name_is_default and len(existing) == 1: |
| 107 | 105 | dlm_id = _dlm_id_from_file(existing[0]) |
| 108 | - return ScaffoldResult( | |
| 109 | - dlm_path=existing[0], scaffolded=False, dlm_id=dlm_id | |
| 110 | - ) | |
| 106 | + return ScaffoldResult(dlm_path=existing[0], scaffolded=False, dlm_id=dlm_id) | |
| 111 | 107 | if name_is_default and len(existing) > 1: |
| 112 | - listing = "\n".join( | |
| 113 | - f" dlm train {target} --name {c.stem}" for c in existing | |
| 114 | - ) | |
| 108 | + listing = "\n".join(f" dlm train {target} --name {c.stem}" for c in existing) | |
| 115 | 109 | raise ScaffoldError( |
| 116 | 110 | f"multiple .dlm files found under {target / _SCAFFOLD_DIR}; " |
| 117 | 111 | f"pass --name to pick one:\n{listing}", |
@@ -128,11 +122,7 @@ def scaffold_train_target( | ||
| 128 | 122 | ) |
| 129 | 123 | |
| 130 | 124 | dlm_path = dlm_dir / f"{name}.dlm" |
| 131 | - existing_id = ( | |
| 132 | - _dlm_id_from_file(dlm_path) | |
| 133 | - if rescaffold and dlm_path.is_file() | |
| 134 | - else None | |
| 135 | - ) | |
| 125 | + existing_id = _dlm_id_from_file(dlm_path) if rescaffold and dlm_path.is_file() else None | |
| 136 | 126 | |
| 137 | 127 | dlm_id = existing_id or mint_ulid() |
| 138 | 128 | dlm_dir.mkdir(parents=True, exist_ok=True) |
@@ -146,9 +136,7 @@ def scaffold_train_target( | ||
| 146 | 136 | policy=policy, |
| 147 | 137 | target=target, |
| 148 | 138 | ) |
| 149 | - _LOG.info( | |
| 150 | - "scaffold: wrote %s (dlm_id=%s, base=%s)", dlm_path, dlm_id, base | |
| 151 | - ) | |
| 139 | + _LOG.info("scaffold: wrote %s (dlm_id=%s, base=%s)", dlm_path, dlm_id, base) | |
| 152 | 140 | return ScaffoldResult(dlm_path=dlm_path, scaffolded=True, dlm_id=dlm_id) |
| 153 | 141 | |
| 154 | 142 | |
@@ -205,17 +193,14 @@ def _write_scaffold( | ||
| 205 | 193 | [ |
| 206 | 194 | "---", |
| 207 | 195 | "", |
| 208 | - "# Auto-scaffolded by `dlm train`. Edit the frontmatter above " | |
| 209 | - "to refine training.", | |
| 196 | + "# Auto-scaffolded by `dlm train`. Edit the frontmatter above to refine training.", | |
| 210 | 197 | "", |
| 211 | 198 | ] |
| 212 | 199 | ) |
| 213 | 200 | atomic_write_text(dlm_path, "\n".join(lines)) |
| 214 | 201 | |
| 215 | 202 | |
| 216 | -def _build_include_globs( | |
| 217 | - include: tuple[str, ...], *, recursive: bool | |
| 218 | -) -> tuple[str, ...]: | |
| 203 | +def _build_include_globs(include: tuple[str, ...], *, recursive: bool) -> tuple[str, ...]: | |
| 219 | 204 | """Map `--include` flags + `--recursive` to frontmatter globs. |
| 220 | 205 | |
| 221 | 206 | Empty `--include` + `--recursive` → `["**/*"]`: train on every |
src/dlm/data/formatter.py (modified) @@ -57,7 +57,11 @@ def make_formatting_func(tokenizer: PreTrainedTokenizerBase) -> FormattingFunc: | ||
| 57 | 57 | if not isinstance(text, str): |
| 58 | 58 | raise DataFormatError(f"`text` field must be str, got {type(text).__name__}") |
| 59 | 59 | return text |
| 60 | - if row.get("prompt") is not None and row.get("chosen") is not None and row.get("rejected") is not None: | |
| 60 | + if ( | |
| 61 | + row.get("prompt") is not None | |
| 62 | + and row.get("chosen") is not None | |
| 63 | + and row.get("rejected") is not None | |
| 64 | + ): | |
| 61 | 65 | raise DataFormatError( |
| 62 | 66 | "preference rows (prompt/chosen/rejected) must be routed to DPOTrainer, " |
| 63 | 67 | "not SFTTrainer's formatting_func" |
src/dlm/directives/cache.py (modified) @@ -132,9 +132,7 @@ class TokenizedCache: | ||
| 132 | 132 | # ---- Open / construct -------------------------------------------- |
| 133 | 133 | |
| 134 | 134 | @classmethod |
| 135 | - def open( | |
| 136 | - cls, root: Path, *, max_bytes: int = _DEFAULT_MAX_BYTES | |
| 137 | - ) -> TokenizedCache: | |
| 135 | + def open(cls, root: Path, *, max_bytes: int = _DEFAULT_MAX_BYTES) -> TokenizedCache: | |
| 138 | 136 | """Open (or create) a cache at `root`. |
| 139 | 137 | |
| 140 | 138 | Creates the directory layout idempotently. Missing manifest → |
@@ -343,17 +341,15 @@ class TokenizedCache: | ||
| 343 | 341 | mid-put fallback. |
| 344 | 342 | """ |
| 345 | 343 | cutoff = time.time() - older_than_seconds |
| 346 | - stale_keys = [ | |
| 347 | - e.key_str | |
| 348 | - for e in self._manifest.values() | |
| 349 | - if e.last_access_ts < cutoff | |
| 350 | - ] | |
| 344 | + stale_keys = [e.key_str for e in self._manifest.values() if e.last_access_ts < cutoff] | |
| 351 | 345 | for key_str in stale_keys: |
| 352 | 346 | entry = self._manifest[key_str] |
| 353 | 347 | self._entry_path(entry).unlink(missing_ok=True) |
| 354 | 348 | del self._manifest[key_str] |
| 355 | 349 | if stale_keys: |
| 356 | - _LOG.info("cache: pruned %d entries older than %ds", len(stale_keys), older_than_seconds) | |
| 350 | + _LOG.info( | |
| 351 | + "cache: pruned %d entries older than %ds", len(stale_keys), older_than_seconds | |
| 352 | + ) | |
| 357 | 353 | return len(stale_keys) |
| 358 | 354 | |
| 359 | 355 | def clear(self) -> int: |
src/dlm/directives/cache_key.py (modified) @@ -50,9 +50,7 @@ class CacheKey: | ||
| 50 | 50 | birthday threshold. The full sha is persisted in the manifest |
| 51 | 51 | for verification if a collision ever occurs in practice. |
| 52 | 52 | """ |
| 53 | - return ( | |
| 54 | - f"{self.section_id}.{self.tokenizer_sha[:12]}.seq{self.sequence_len}.npz" | |
| 55 | - ) | |
| 53 | + return f"{self.section_id}.{self.tokenizer_sha[:12]}.seq{self.sequence_len}.npz" | |
| 56 | 54 | |
| 57 | 55 | def shard(self) -> str: |
| 58 | 56 | """First 2 hex chars of section_id — the directory shard.""" |
src/dlm/directives/discovery.py (modified) @@ -76,11 +76,7 @@ def discover_configs(root: Path) -> tuple[DiscoveredConfig, ...]: | ||
| 76 | 76 | anchor = dlm_dir.parent |
| 77 | 77 | config = _load_training_yaml(dlm_dir / _CONFIG_FILENAME) |
| 78 | 78 | ignore_rules = _load_ignore(dlm_dir / _IGNORE_FILENAME) |
| 79 | - discovered.append( | |
| 80 | - DiscoveredConfig( | |
| 81 | - anchor=anchor, config=config, ignore_rules=ignore_rules | |
| 82 | - ) | |
| 83 | - ) | |
| 79 | + discovered.append(DiscoveredConfig(anchor=anchor, config=config, ignore_rules=ignore_rules)) | |
| 84 | 80 | |
| 85 | 81 | discovered.sort(key=lambda d: len(d.anchor.as_posix())) |
| 86 | 82 | return tuple(discovered) |
@@ -121,9 +117,7 @@ def _load_training_yaml(path: Path) -> DlmTrainingConfig | None: | ||
| 121 | 117 | try: |
| 122 | 118 | return DlmTrainingConfig.model_validate(raw) |
| 123 | 119 | except ValidationError as exc: |
| 124 | - _LOG.warning( | |
| 125 | - "discovery: %s: schema violation (%s); skipping config", path, exc | |
| 126 | - ) | |
| 120 | + _LOG.warning("discovery: %s: schema violation (%s); skipping config", path, exc) | |
| 127 | 121 | return None |
| 128 | 122 | |
| 129 | 123 | |
src/dlm/directives/expand.py (modified) @@ -100,9 +100,7 @@ def expand_sources(parsed: ParsedDlm, *, base_path: Path) -> ExpandResult: | ||
| 100 | 100 | if not directives: |
| 101 | 101 | return ExpandResult(sections=(), provenance=(), discovered=()) |
| 102 | 102 | |
| 103 | - effective_base = ( | |
| 104 | - base_path.parent if base_path.name == ".dlm" else base_path | |
| 105 | - ) | |
| 103 | + effective_base = base_path.parent if base_path.name == ".dlm" else base_path | |
| 106 | 104 | strict = training.sources_policy == "strict" |
| 107 | 105 | sections: list[Section] = [] |
| 108 | 106 | provenance: list[SourceProvenance] = [] |
@@ -159,10 +157,7 @@ def _expand_one( | ||
| 159 | 157 | header_root = resolved_root if resolved_root.is_dir() else resolved_root.parent |
| 160 | 158 | |
| 161 | 159 | for file_path in _iter_candidates(resolved_root): |
| 162 | - if ( | |
| 163 | - directive.max_files is not None | |
| 164 | - and len(sections) >= directive.max_files | |
| 165 | - ): | |
| 160 | + if directive.max_files is not None and len(sections) >= directive.max_files: | |
| 166 | 161 | _LOG.info( |
| 167 | 162 | "directive: hit max_files=%d for %s; truncating deterministically", |
| 168 | 163 | directive.max_files, |
@@ -190,10 +185,7 @@ def _expand_one( | ||
| 190 | 185 | _LOG.warning("directive: stat failed for %s: %s; skipping", file_path, exc) |
| 191 | 186 | continue |
| 192 | 187 | |
| 193 | - if ( | |
| 194 | - directive.max_bytes_per_file is not None | |
| 195 | - and size > directive.max_bytes_per_file | |
| 196 | - ): | |
| 188 | + if directive.max_bytes_per_file is not None and size > directive.max_bytes_per_file: | |
| 197 | 189 | _LOG.info( |
| 198 | 190 | "directive: %s (%d bytes) exceeds max_bytes_per_file=%d; skipping", |
| 199 | 191 | file_path, |
@@ -210,9 +202,7 @@ def _expand_one( | ||
| 210 | 202 | continue |
| 211 | 203 | |
| 212 | 204 | if is_probably_binary(raw): |
| 213 | - _LOG.info( | |
| 214 | - "directive: %s looks binary (NUL in first KiB); skipping", file_path | |
| 215 | - ) | |
| 205 | + _LOG.info("directive: %s looks binary (NUL in first KiB); skipping", file_path) | |
| 216 | 206 | skipped_binary += 1 |
| 217 | 207 | continue |
| 218 | 208 | |
@@ -225,9 +215,7 @@ def _expand_one( | ||
| 225 | 215 | |
| 226 | 216 | relpath = file_path.relative_to(header_root).as_posix() |
| 227 | 217 | content = f"# source: {relpath}\n\n{text}" |
| 228 | - sections.append( | |
| 229 | - Section(type=SectionType.PROSE, content=content, tags=effective.tags) | |
| 230 | - ) | |
| 218 | + sections.append(Section(type=SectionType.PROSE, content=content, tags=effective.tags)) | |
| 231 | 219 | total_bytes += len(raw) |
| 232 | 220 | |
| 233 | 221 | return sections, SourceProvenance( |
src/dlm/directives/merge.py (modified) @@ -52,11 +52,7 @@ def ancestors_of( | ||
| 52 | 52 | """Return DiscoveredConfigs whose anchor is an ancestor of file_path, |
| 53 | 53 | sorted shallowest → deepest. Includes the direct-parent anchor.""" |
| 54 | 54 | abs_file = file_path.resolve() |
| 55 | - result = [ | |
| 56 | - d | |
| 57 | - for d in discovered | |
| 58 | - if _is_ancestor(d.anchor.resolve(), abs_file) | |
| 59 | - ] | |
| 55 | + result = [d for d in discovered if _is_ancestor(d.anchor.resolve(), abs_file)] | |
| 60 | 56 | result.sort(key=lambda d: len(d.anchor.as_posix())) |
| 61 | 57 | return tuple(result) |
| 62 | 58 | |
src/dlm/directives/safety.py (modified) @@ -158,9 +158,7 @@ def enumerate_matching_files( | ||
| 158 | 158 | yield candidate |
| 159 | 159 | |
| 160 | 160 | |
| 161 | -def _matches_filters( | |
| 162 | - rel_path: str, include: Iterable[str], exclude: Iterable[str] | |
| 163 | -) -> bool: | |
| 161 | +def _matches_filters(rel_path: str, include: Iterable[str], exclude: Iterable[str]) -> bool: | |
| 164 | 162 | """Match rel_path against include (any) and exclude (none).""" |
| 165 | 163 | if any(_compile_glob(pat).fullmatch(rel_path) for pat in exclude): |
| 166 | 164 | return False |
src/dlm/doc/parser.py (modified) @@ -210,9 +210,7 @@ def _tokenize_body(body: str, *, body_start_line: int, path: Path | None) -> lis | ||
| 210 | 210 | match = _FENCE_RE.match(line) |
| 211 | 211 | if match: |
| 212 | 212 | fence_name = match.group(1) |
| 213 | - fence_type, fence_adapter = _resolve_fence_type( | |
| 214 | - fence_name, source_line, path | |
| 215 | - ) | |
| 213 | + fence_type, fence_adapter = _resolve_fence_type(fence_name, source_line, path) | |
| 216 | 214 | flush() |
| 217 | 215 | current_type = fence_type |
| 218 | 216 | current_adapter = fence_adapter |
@@ -233,9 +231,7 @@ def _tokenize_body(body: str, *, body_start_line: int, path: Path | None) -> lis | ||
| 233 | 231 | return sections |
| 234 | 232 | |
| 235 | 233 | |
| 236 | -def _resolve_fence_type( | |
| 237 | - name: str, line: int, path: Path | None | |
| 238 | -) -> tuple[SectionType, str | None]: | |
| 234 | +def _resolve_fence_type(name: str, line: int, path: Path | None) -> tuple[SectionType, str | None]: | |
| 239 | 235 | """Map a fence name to `(SectionType, adapter_name|None)` or raise. |
| 240 | 236 | |
| 241 | 237 | Multi-adapter fences carry a `#<adapter>` suffix; the adapter part is |
@@ -267,8 +263,7 @@ def _resolve_fence_type( | ||
| 267 | 263 | section_type = SectionType(base) |
| 268 | 264 | except ValueError as exc: |
| 269 | 265 | raise FenceError( |
| 270 | - f"unknown section fence '::{name}::'; valid types are " | |
| 271 | - f"{[t.value for t in SectionType]}", | |
| 266 | + f"unknown section fence '::{name}::'; valid types are {[t.value for t in SectionType]}", | |
| 272 | 267 | path=path, |
| 273 | 268 | line=line, |
| 274 | 269 | col=1, |
src/dlm/doc/schema.py (modified) @@ -69,9 +69,7 @@ class PreferenceConfig(BaseModel): | ||
| 69 | 69 | |
| 70 | 70 | enabled: bool = False |
| 71 | 71 | method: Literal["dpo", "orpo"] = "dpo" |
| 72 | - hyperparams: PreferenceHyperparams = Field( | |
| 73 | - default_factory=lambda: PreferenceHyperparams() | |
| 74 | - ) | |
| 72 | + hyperparams: PreferenceHyperparams = Field(default_factory=lambda: PreferenceHyperparams()) | |
| 75 | 73 | # DPO-only fields — ignored for ORPO but kept on the config so a |
| 76 | 74 | # user switching methods doesn't have to delete them. |
| 77 | 75 | loss_type: Literal["sigmoid", "hinge", "ipo"] = "sigmoid" |
@@ -247,11 +245,7 @@ class TrainingConfig(BaseModel): | ||
| 247 | 245 | "target_modules": "auto", |
| 248 | 246 | "learning_rate": 2e-4, |
| 249 | 247 | } |
| 250 | - drift = [ | |
| 251 | - key | |
| 252 | - for key, default in flat_defaults.items() | |
| 253 | - if getattr(self, key) != default | |
| 254 | - ] | |
| 248 | + drift = [key for key, default in flat_defaults.items() if getattr(self, key) != default] | |
| 255 | 249 | if drift: |
| 256 | 250 | raise ValueError( |
| 257 | 251 | "training.adapters is declared; flat per-adapter fields " |
src/dlm/doc/serializer.py (modified) @@ -104,8 +104,10 @@ def _emit_nested_mapping(model: BaseModel, *, indent: int) -> list[str]: | ||
| 104 | 104 | lines.append(f"{pad}{field_name}:") |
| 105 | 105 | lines.extend(nested) |
| 106 | 106 | continue |
| 107 | - if isinstance(value, dict) and value and all( | |
| 108 | - isinstance(v, BaseModel) for v in value.values() | |
| 107 | + if ( | |
| 108 | + isinstance(value, dict) | |
| 109 | + and value | |
| 110 | + and all(isinstance(v, BaseModel) for v in value.values()) | |
| 109 | 111 | ): |
| 110 | 112 | # `dict[str, BaseModel]` (e.g. training.adapters) — emit |
| 111 | 113 | # each entry as a nested mapping. The key is the dict |
src/dlm/export/runner.py (modified) @@ -186,9 +186,7 @@ def run_export( | ||
| 186 | 186 | if not adapter_path.exists(): |
| 187 | 187 | from dlm.export.errors import ExportError |
| 188 | 188 | |
| 189 | - raise ExportError( | |
| 190 | - f"adapter_path_override {adapter_path} does not exist" | |
| 191 | - ) | |
| 189 | + raise ExportError(f"adapter_path_override {adapter_path} does not exist") | |
| 192 | 190 | elif adapter_name is None: |
| 193 | 191 | resolved = store.resolve_current_adapter() |
| 194 | 192 | pointer = store.adapter_current_pointer |
@@ -196,8 +194,7 @@ def run_export( | ||
| 196 | 194 | from dlm.export.errors import ExportError |
| 197 | 195 | |
| 198 | 196 | raise ExportError( |
| 199 | - f"no current adapter under {pointer}; " | |
| 200 | - "run `dlm train` before exporting." | |
| 197 | + f"no current adapter under {pointer}; run `dlm train` before exporting." | |
| 201 | 198 | ) |
| 202 | 199 | adapter_path = resolved |
| 203 | 200 | else: |
src/dlm/export/weighted_merge.py (modified) @@ -77,47 +77,37 @@ def parse_mix_spec(spec_str: str) -> list[MixEntry]: | ||
| 77 | 77 | for piece in raw.split(","): |
| 78 | 78 | token = piece.strip() |
| 79 | 79 | if not token: |
| 80 | - raise InvalidMixSpecError( | |
| 81 | - f"--adapter-mix: empty entry in spec {spec_str!r}" | |
| 82 | - ) | |
| 80 | + raise InvalidMixSpecError(f"--adapter-mix: empty entry in spec {spec_str!r}") | |
| 83 | 81 | if ":" not in token: |
| 84 | 82 | raise InvalidMixSpecError( |
| 85 | - f"--adapter-mix: entry {token!r} is missing a weight " | |
| 86 | - "(shape: `name:weight`)" | |
| 83 | + f"--adapter-mix: entry {token!r} is missing a weight (shape: `name:weight`)" | |
| 87 | 84 | ) |
| 88 | 85 | name, _, weight_str = token.rpartition(":") |
| 89 | 86 | name = name.strip() |
| 90 | 87 | weight_str = weight_str.strip() |
| 91 | 88 | if not _NAME_RE.fullmatch(name): |
| 92 | 89 | raise InvalidMixSpecError( |
| 93 | - f"--adapter-mix: adapter name {name!r} is not valid " | |
| 94 | - f"(must match {_NAME_RE.pattern})" | |
| 90 | + f"--adapter-mix: adapter name {name!r} is not valid (must match {_NAME_RE.pattern})" | |
| 95 | 91 | ) |
| 96 | 92 | if name in seen: |
| 97 | - raise InvalidMixSpecError( | |
| 98 | - f"--adapter-mix: adapter {name!r} appears twice" | |
| 99 | - ) | |
| 93 | + raise InvalidMixSpecError(f"--adapter-mix: adapter {name!r} appears twice") | |
| 100 | 94 | seen.add(name) |
| 101 | 95 | try: |
| 102 | 96 | weight = float(weight_str) |
| 103 | 97 | except ValueError as exc: |
| 104 | 98 | raise InvalidMixSpecError( |
| 105 | - f"--adapter-mix: weight {weight_str!r} for adapter " | |
| 106 | - f"{name!r} is not a number" | |
| 99 | + f"--adapter-mix: weight {weight_str!r} for adapter {name!r} is not a number" | |
| 107 | 100 | ) from exc |
| 108 | 101 | if weight < 0: |
| 109 | 102 | raise InvalidMixSpecError( |
| 110 | - f"--adapter-mix: weight {weight} for adapter {name!r} " | |
| 111 | - "is negative (must be >= 0)" | |
| 103 | + f"--adapter-mix: weight {weight} for adapter {name!r} is negative (must be >= 0)" | |
| 112 | 104 | ) |
| 113 | 105 | entries.append(MixEntry(name=name, weight=weight)) |
| 114 | 106 | |
| 115 | 107 | return entries |
| 116 | 108 | |
| 117 | 109 | |
| 118 | -def validate_mix_against_declared( | |
| 119 | - entries: list[MixEntry], declared: set[str] | |
| 120 | -) -> None: | |
| 110 | +def validate_mix_against_declared(entries: list[MixEntry], declared: set[str]) -> None: | |
| 121 | 111 | """Refuse mix entries that reference adapters not in `training.adapters`. |
| 122 | 112 | |
| 123 | 113 | Single source of error messaging so the CLI and the runner both |
@@ -170,9 +160,7 @@ def build_weighted_merged( # pragma: no cover - heavy path | ||
| 170 | 160 | |
| 171 | 161 | first = entries[0] |
| 172 | 162 | first_path = _resolve_or_raise(store, first.name) |
| 173 | - model = PeftModel.from_pretrained( | |
| 174 | - base_model, str(first_path), adapter_name=first.name | |
| 175 | - ) | |
| 163 | + model = PeftModel.from_pretrained(base_model, str(first_path), adapter_name=first.name) | |
| 176 | 164 | for extra in entries[1:]: |
| 177 | 165 | path = _resolve_or_raise(store, extra.name) |
| 178 | 166 | model.load_adapter(str(path), adapter_name=extra.name) |
@@ -208,9 +196,7 @@ def resolve_first_source_path(store: StorePath, entries: list[MixEntry]) -> Path | ||
| 208 | 196 | single-valued), so any source is interchangeable — we pick the first. |
| 209 | 197 | """ |
| 210 | 198 | if not entries: |
| 211 | - raise InvalidMixSpecError( | |
| 212 | - "resolve_first_source_path: empty mix" | |
| 213 | - ) | |
| 199 | + raise InvalidMixSpecError("resolve_first_source_path: empty mix") | |
| 214 | 200 | return _resolve_or_raise(store, entries[0].name) |
| 215 | 201 | |
| 216 | 202 | |
@@ -257,9 +243,7 @@ def save_merged_to_tmp( # pragma: no cover - heavy path | ||
| 257 | 243 | import shutil |
| 258 | 244 | |
| 259 | 245 | tmp_dir.mkdir(parents=True, exist_ok=True) |
| 260 | - merged_model.save_pretrained( | |
| 261 | - str(tmp_dir), selected_adapters=[_MERGED_ADAPTER_NAME] | |
| 262 | - ) | |
| 246 | + merged_model.save_pretrained(str(tmp_dir), selected_adapters=[_MERGED_ADAPTER_NAME]) | |
| 263 | 247 | |
| 264 | 248 | # PEFT nests under the adapter name; that's where run_export |
| 265 | 249 | # expects to find adapter_config.json + safetensors. |
src/dlm/hardware/capabilities.py (modified) @@ -156,9 +156,7 @@ def _get_unified_memory_gb(backend: Backend) -> float | None: | ||
| 156 | 156 | return psutil.virtual_memory().total / (1024**3) |
| 157 | 157 | |
| 158 | 158 | |
| 159 | -def _supports_bf16( | |
| 160 | - backend: Backend, sm: tuple[int, int] | None, rocm_arch: str | None | |
| 161 | -) -> bool: | |
| 159 | +def _supports_bf16(backend: Backend, sm: tuple[int, int] | None, rocm_arch: str | None) -> bool: | |
| 162 | 160 | if backend == Backend.CUDA: |
| 163 | 161 | return sm is not None and sm >= (8, 0) |
| 164 | 162 | if backend == Backend.ROCM: |
src/dlm/hardware/plan.py (modified) @@ -104,9 +104,7 @@ def resolve( | ||
| 104 | 104 | raise ValueError(f"world_size must be >= 1, got {world_size}") |
| 105 | 105 | if world_size > 1: |
| 106 | 106 | check_multi_gpu_refusals(caps, world_size) |
| 107 | - check_refusals( | |
| 108 | - training, caps, base_params, force=force, num_adapters=num_adapters | |
| 109 | - ) | |
| 107 | + check_refusals(training, caps, base_params, force=force, num_adapters=num_adapters) | |
| 110 | 108 | |
| 111 | 109 | use_qlora = _should_qlora(training, caps) |
| 112 | 110 | precision = _pick_precision(caps, override=training.precision) |
src/dlm/hardware/refusals.py (modified) @@ -83,15 +83,11 @@ def check_refusals( | ||
| 83 | 83 | per_adapter_gb = max(0.1, base_params * avg_lora_r / (1e9 * 64)) |
| 84 | 84 | activations_gb = base_params * 2.0 / 1e9 * 0.25 |
| 85 | 85 | qlora_adapter_count = _qlora_adapter_count(training, num_adapters) |
| 86 | - est_peak = ( | |
| 87 | - base_gb + per_adapter_gb * qlora_adapter_count + activations_gb | |
| 88 | - ) | |
| 86 | + est_peak = base_gb + per_adapter_gb * qlora_adapter_count + activations_gb | |
| 89 | 87 | budget = caps.vram_gb * 0.85 |
| 90 | 88 | if est_peak > budget: |
| 91 | 89 | offenders = _qlora_adapter_names(training) |
| 92 | - offender_note = ( | |
| 93 | - f" (offending adapters: {sorted(offenders)})" if offenders else "" | |
| 94 | - ) | |
| 90 | + offender_note = f" (offending adapters: {sorted(offenders)})" if offenders else "" | |
| 95 | 91 | raise ResolutionError( |
| 96 | 92 | "Multi-adapter QLoRA would exceed VRAM " |
| 97 | 93 | f"(~{est_peak:.1f} GB estimated vs {budget:.1f} GB budget " |
@@ -130,8 +126,7 @@ def check_multi_gpu_refusals(caps: Capabilities, world_size: int) -> None: | ||
| 130 | 126 | ) |
| 131 | 127 | if caps.backend == Backend.CPU: |
| 132 | 128 | raise ResolutionError( |
| 133 | - "Multi-GPU training on CPU is not supported. " | |
| 134 | - "Drop `--gpus` or run single-process.", | |
| 129 | + "Multi-GPU training on CPU is not supported. Drop `--gpus` or run single-process.", | |
| 135 | 130 | ) |
| 136 | 131 | if caps.backend == Backend.ROCM: |
| 137 | 132 | raise ResolutionError( |
@@ -181,9 +176,7 @@ def _avg_lora_r(training: TrainingConfig) -> float: | ||
| 181 | 176 | """Average LoRA rank across declared adapters (fallback: flat lora_r).""" |
| 182 | 177 | if training.adapters is None or not training.adapters: |
| 183 | 178 | return float(training.lora_r) |
| 184 | - return sum(a.lora_r for a in training.adapters.values()) / len( | |
| 185 | - training.adapters | |
| 186 | - ) | |
| 179 | + return sum(a.lora_r for a in training.adapters.values()) / len(training.adapters) | |
| 187 | 180 | |
| 188 | 181 | |
| 189 | 182 | def _qlora_adapter_count(training: TrainingConfig, fallback: int) -> int: |
src/dlm/inference/backends/mlx_backend.py (modified) @@ -54,8 +54,7 @@ def stage_mlx_adapter_dir(peft_adapter_dir: Path, dst_dir: Path) -> Path: | ||
| 54 | 54 | src_config = peft_adapter_dir / _ADAPTER_CONFIG_FILENAME |
| 55 | 55 | if not src_config.exists(): |
| 56 | 56 | raise MlxConversionError( |
| 57 | - f"{peft_adapter_dir} is not a PEFT adapter dir " | |
| 58 | - f"({_ADAPTER_CONFIG_FILENAME} is missing)" | |
| 57 | + f"{peft_adapter_dir} is not a PEFT adapter dir ({_ADAPTER_CONFIG_FILENAME} is missing)" | |
| 59 | 58 | ) |
| 60 | 59 | if not (peft_adapter_dir / "adapter_model.safetensors").exists(): |
| 61 | 60 | raise MlxConversionError( |
src/dlm/inference/loader.py (modified) @@ -105,9 +105,7 @@ def _torch_dtype_for(precision: str) -> Any: | ||
| 105 | 105 | return lookup.get(precision, torch.float16) |
| 106 | 106 | |
| 107 | 107 | |
| 108 | -def resolve_adapter_path( | |
| 109 | - store: StorePath, *, adapter_name: str | None | |
| 110 | -) -> Path: | |
| 108 | +def resolve_adapter_path(store: StorePath, *, adapter_name: str | None) -> Path: | |
| 111 | 109 | """Return the on-disk adapter version dir for inference. |
| 112 | 110 | |
| 113 | 111 | Single entry point for both the flat (unnamed) and named-adapter |
src/dlm/metrics/queries.py (modified) @@ -132,9 +132,7 @@ def evals_for_run(store_root: Path, run_id: int, *, since_step: int = 0) -> lis | ||
| 132 | 132 | return [EvalRow(*row) for row in rows] |
| 133 | 133 | |
| 134 | 134 | |
| 135 | -def tokenization_for_run( | |
| 136 | - store_root: Path, run_id: int | |
| 137 | -) -> TokenizationRow | None: | |
| 135 | +def tokenization_for_run(store_root: Path, run_id: int) -> TokenizationRow | None: | |
| 138 | 136 | """The tokenization row for `run_id`, or None when absent. |
| 139 | 137 | |
| 140 | 138 | Returns None when the table is empty for this run (i.e. the run |
src/dlm/metrics/sinks/wandb.py (modified) @@ -45,8 +45,7 @@ class WandbSink: | ||
| 45 | 45 | import wandb # type: ignore[import-not-found] |
| 46 | 46 | except ImportError as exc: # pragma: no cover - depends on install |
| 47 | 47 | raise RuntimeError( |
| 48 | - "W&B sink requires `wandb`; " | |
| 49 | - "run `uv sync --extra observability` to install it." | |
| 48 | + "W&B sink requires `wandb`; run `uv sync --extra observability` to install it." | |
| 50 | 49 | ) from exc |
| 51 | 50 | |
| 52 | 51 | wandb_dir = store_root / "wandb" |
src/dlm/store/inspect.py (modified) @@ -202,7 +202,7 @@ def _max_version(versions_dir: Path) -> int: | ||
| 202 | 202 | if not name.startswith(_VERSION_DIR_PREFIX): |
| 203 | 203 | continue |
| 204 | 204 | try: |
| 205 | - n = int(name[len(_VERSION_DIR_PREFIX):]) | |
| 205 | + n = int(name[len(_VERSION_DIR_PREFIX) :]) | |
| 206 | 206 | except ValueError: |
| 207 | 207 | continue |
| 208 | 208 | highest = max(highest, n) |
src/dlm/store/paths.py (modified) @@ -52,8 +52,7 @@ _ADAPTER_NAME_RE: Final[re.Pattern[str]] = re.compile(r"^[a-z][a-z0-9_]{0,31}$") | ||
| 52 | 52 | def _validate_adapter_name(name: str) -> None: |
| 53 | 53 | if not _ADAPTER_NAME_RE.fullmatch(name): |
| 54 | 54 | raise ValueError( |
| 55 | - f"adapter name {name!r} is not valid " | |
| 56 | - f"(must match {_ADAPTER_NAME_RE.pattern})" | |
| 55 | + f"adapter name {name!r} is not valid (must match {_ADAPTER_NAME_RE.pattern})" | |
| 57 | 56 | ) |
| 58 | 57 | |
| 59 | 58 | |
@@ -235,9 +234,7 @@ class StorePath: | ||
| 235 | 234 | ) from exc |
| 236 | 235 | from dlm.io.atomic import write_text as _atomic_write_text |
| 237 | 236 | |
| 238 | - _atomic_write_text( | |
| 239 | - self.adapter_current_pointer_for(name), f"{relative}\n" | |
| 240 | - ) | |
| 237 | + _atomic_write_text(self.adapter_current_pointer_for(name), f"{relative}\n") | |
| 241 | 238 | |
| 242 | 239 | def export_quant_dir(self, quant: str) -> Path: |
| 243 | 240 | """Return `exports/<quant>/` (does NOT create it).""" |
src/dlm/templates/init.py (modified) @@ -95,8 +95,7 @@ def apply_template( | ||
| 95 | 95 | ) from exc |
| 96 | 96 | if is_gated(spec) and not accept_license: |
| 97 | 97 | raise TemplateApplyError( |
| 98 | - f"template {name!r} uses gated base {spec.key!r}; " | |
| 99 | - "pass accept_license=True" | |
| 98 | + f"template {name!r} uses gated base {spec.key!r}; pass accept_license=True" | |
| 100 | 99 | ) |
| 101 | 100 | |
| 102 | 101 | parsed = parse_text(template.dlm_text) |
src/dlm/train/checkpoint_commit.py (modified) @@ -43,9 +43,7 @@ _LOG = logging.getLogger(__name__) | ||
| 43 | 43 | _VERSION_PREFIX = "v" |
| 44 | 44 | |
| 45 | 45 | |
| 46 | -def allocate_next_version( | |
| 47 | - store: StorePath, *, adapter_name: str | None = None | |
| 48 | -) -> Path: | |
| 46 | +def allocate_next_version(store: StorePath, *, adapter_name: str | None = None) -> Path: | |
| 49 | 47 | """Return the next empty `adapter/[<name>/]versions/vNNNN/` path. |
| 50 | 48 | |
| 51 | 49 | Creates the directory (and any missing parents). When `adapter_name` |
@@ -99,8 +97,7 @@ def commit_version( | ||
| 99 | 97 | ) |
| 100 | 98 | except OSError: |
| 101 | 99 | _LOG.exception( |
| 102 | - "non-finite adapter weights + rejected-dir rename failed; " | |
| 103 | - "leaving %s in place", | |
| 100 | + "non-finite adapter weights + rejected-dir rename failed; leaving %s in place", | |
| 104 | 101 | pending, |
| 105 | 102 | ) |
| 106 | 103 | raise |
@@ -156,9 +153,7 @@ def fsync_dir(path: Path) -> None: | ||
| 156 | 153 | os.close(fd) |
| 157 | 154 | |
| 158 | 155 | |
| 159 | -def list_pending_versions( | |
| 160 | - store: StorePath, *, adapter_name: str | None = None | |
| 161 | -) -> list[Path]: | |
| 156 | +def list_pending_versions(store: StorePath, *, adapter_name: str | None = None) -> list[Path]: | |
| 162 | 157 | """Return version dirs that exist on disk but aren't the current pointer. |
| 163 | 158 | |
| 164 | 159 | Used by the trainer's startup routine to detect crash-before-flip |
@@ -182,13 +177,9 @@ def list_pending_versions( | ||
| 182 | 177 | return [version_for(n) for n in sorted(existing) if n != current_n] |
| 183 | 178 | |
| 184 | 179 | |
| 185 | -def _existing_versions( | |
| 186 | - store: StorePath, *, adapter_name: str | None = None | |
| 187 | -) -> list[int]: | |
| 180 | +def _existing_versions(store: StorePath, *, adapter_name: str | None = None) -> list[int]: | |
| 188 | 181 | base = ( |
| 189 | - store.adapter_versions | |
| 190 | - if adapter_name is None | |
| 191 | - else store.adapter_versions_for(adapter_name) | |
| 182 | + store.adapter_versions if adapter_name is None else store.adapter_versions_for(adapter_name) | |
| 192 | 183 | ) |
| 193 | 184 | if not base.is_dir(): |
| 194 | 185 | return [] |
src/dlm/train/cpt/embed_warmup.py (modified) @@ -127,7 +127,11 @@ class EmbedWarmupCallback: # pragma: no cover - exercised by slow integration | ||
| 127 | 127 | self._active: bool = False |
| 128 | 128 | |
| 129 | 129 | def on_train_begin( |
| 130 | - self, args: Any, state: Any, control: Any, **kwargs: Any # noqa: ARG002 | |
| 130 | + self, | |
| 131 | + args: Any, | |
| 132 | + state: Any, | |
| 133 | + control: Any, | |
| 134 | + **kwargs: Any, # noqa: ARG002 | |
| 131 | 135 | ) -> None: |
| 132 | 136 | if self.n_steps <= 0: |
| 133 | 137 | return |
@@ -145,12 +149,20 @@ class EmbedWarmupCallback: # pragma: no cover - exercised by slow integration | ||
| 145 | 149 | self._active = False |
| 146 | 150 | |
| 147 | 151 | def on_step_end( |
| 148 | - self, args: Any, state: Any, control: Any, **kwargs: Any # noqa: ARG002 | |
| 152 | + self, | |
| 153 | + args: Any, | |
| 154 | + state: Any, | |
| 155 | + control: Any, | |
| 156 | + **kwargs: Any, # noqa: ARG002 | |
| 149 | 157 | ) -> None: |
| 150 | 158 | if self._active and state.global_step >= self.n_steps: |
| 151 | 159 | self._restore() |
| 152 | 160 | |
| 153 | 161 | def on_train_end( |
| 154 | - self, args: Any, state: Any, control: Any, **kwargs: Any # noqa: ARG002 | |
| 162 | + self, | |
| 163 | + args: Any, | |
| 164 | + state: Any, | |
| 165 | + control: Any, | |
| 166 | + **kwargs: Any, # noqa: ARG002 | |
| 155 | 167 | ) -> None: |
| 156 | 168 | self._restore() |
src/dlm/train/cpt/schedule.py (modified) @@ -50,13 +50,9 @@ def cosine_with_floor_lr( | ||
| 50 | 50 | if warmup_steps < 0: |
| 51 | 51 | raise ValueError(f"warmup_steps must be non-negative, got {warmup_steps}") |
| 52 | 52 | if warmup_steps >= total_steps: |
| 53 | - raise ValueError( | |
| 54 | - f"warmup_steps ({warmup_steps}) must be < total_steps ({total_steps})" | |
| 55 | - ) | |
| 53 | + raise ValueError(f"warmup_steps ({warmup_steps}) must be < total_steps ({total_steps})") | |
| 56 | 54 | if not 0.0 <= floor_ratio <= 1.0: |
| 57 | - raise ValueError( | |
| 58 | - f"floor_ratio must be in [0.0, 1.0], got {floor_ratio}" | |
| 59 | - ) | |
| 55 | + raise ValueError(f"floor_ratio must be in [0.0, 1.0], got {floor_ratio}") | |
| 60 | 56 | if step < 0: |
| 61 | 57 | raise ValueError(f"step must be non-negative, got {step}") |
| 62 | 58 | |
src/dlm/train/cpt/vocab_gap.py (modified) @@ -85,8 +85,7 @@ def compute_vocab_gap( | ||
| 85 | 85 | """ |
| 86 | 86 | if len(token_ids) != len(decoded_tokens): |
| 87 | 87 | raise ValueError( |
| 88 | - f"token_ids/decoded_tokens length mismatch: " | |
| 89 | - f"{len(token_ids)} vs {len(decoded_tokens)}" | |
| 88 | + f"token_ids/decoded_tokens length mismatch: {len(token_ids)} vs {len(decoded_tokens)}" | |
| 90 | 89 | ) |
| 91 | 90 | if top_n < 0: |
| 92 | 91 | raise ValueError(f"top_n must be non-negative, got {top_n}") |
@@ -95,11 +94,7 @@ def compute_vocab_gap( | ||
| 95 | 94 | total_words = _count_words(text) |
| 96 | 95 | tpw = total_tokens / total_words if total_words else 0.0 |
| 97 | 96 | |
| 98 | - unk_hits = ( | |
| 99 | - sum(1 for tid in token_ids if tid == unk_token_id) | |
| 100 | - if unk_token_id is not None | |
| 101 | - else 0 | |
| 102 | - ) | |
| 97 | + unk_hits = sum(1 for tid in token_ids if tid == unk_token_id) if unk_token_id is not None else 0 | |
| 103 | 98 | |
| 104 | 99 | counts: Counter[str] = Counter(decoded_tokens) |
| 105 | 100 | top_tokens = counts.most_common(top_n) |
@@ -113,7 +108,9 @@ def compute_vocab_gap( | ||
| 113 | 108 | ) |
| 114 | 109 | |
| 115 | 110 | |
| 116 | -def report(text: str, tokenizer: Any, *, top_n: int = 10) -> VocabGapReport: # pragma: no cover - network/heavy | |
| 111 | +def report( | |
| 112 | + text: str, tokenizer: Any, *, top_n: int = 10 | |
| 113 | +) -> VocabGapReport: # pragma: no cover - network/heavy | |
| 117 | 114 | """Run the base tokenizer over `text` and compute the fit report. |
| 118 | 115 | |
| 119 | 116 | Heavy-import shell around `compute_vocab_gap` — covered by the slow |
@@ -147,12 +144,8 @@ def render_report(r: VocabGapReport) -> str: | ||
| 147 | 144 | f" <unk> hits : {r.unk_hits}", |
| 148 | 145 | ] |
| 149 | 146 | if r.has_unk: |
| 150 | - lines.append( | |
| 151 | - " WARNING: non-zero <unk> count — tokenizer has rare-character" | |
| 152 | - ) | |
| 153 | - lines.append( | |
| 154 | - " holes for this domain. Consider a different base model." | |
| 155 | - ) | |
| 147 | + lines.append(" WARNING: non-zero <unk> count — tokenizer has rare-character") | |
| 148 | + lines.append(" holes for this domain. Consider a different base model.") | |
| 156 | 149 | if r.top_tokens: |
| 157 | 150 | lines.append(" top tokens:") |
| 158 | 151 | width = max(len(t) for t, _ in r.top_tokens) |
src/dlm/train/distributed/rank_env.py (modified) @@ -31,9 +31,7 @@ def detect_world_size() -> int: | ||
| 31 | 31 | try: |
| 32 | 32 | value = int(raw) |
| 33 | 33 | except ValueError as exc: |
| 34 | - raise ValueError( | |
| 35 | - f"WORLD_SIZE env var is not an integer: {raw!r}" | |
| 36 | - ) from exc | |
| 34 | + raise ValueError(f"WORLD_SIZE env var is not an integer: {raw!r}") from exc | |
| 37 | 35 | if value < 1: |
| 38 | 36 | return 1 |
| 39 | 37 | return value |
@@ -53,9 +51,7 @@ def detect_rank() -> int: | ||
| 53 | 51 | try: |
| 54 | 52 | value = int(raw) |
| 55 | 53 | except ValueError as exc: |
| 56 | - raise ValueError( | |
| 57 | - f"{key} env var is not an integer: {raw!r}" | |
| 58 | - ) from exc | |
| 54 | + raise ValueError(f"{key} env var is not an integer: {raw!r}") from exc | |
| 59 | 55 | if value < 0: |
| 60 | 56 | return 0 |
| 61 | 57 | return value |
src/dlm/train/multi_adapter/router.py (modified) @@ -107,7 +107,6 @@ def sections_for(parsed: ParsedDlm, adapter_name: str) -> list[Section]: | ||
| 107 | 107 | plan = build_plan(parsed) |
| 108 | 108 | if adapter_name not in plan.by_adapter: |
| 109 | 109 | raise UnknownAdapterError( |
| 110 | - f"adapter {adapter_name!r} not declared " | |
| 111 | - f"(declared: {sorted(plan.by_adapter)})" | |
| 110 | + f"adapter {adapter_name!r} not declared (declared: {sorted(plan.by_adapter)})" | |
| 112 | 111 | ) |
| 113 | 112 | return plan.by_adapter[adapter_name] |
src/dlm/train/preference/dpo_phase.py (modified) @@ -351,9 +351,7 @@ def _build_real_dpo_trainer( # pragma: no cover | ||
| 351 | 351 | # Policy: base + the SFT-trained adapter as trainable. |
| 352 | 352 | base_model = load_base_model(spec, plan) |
| 353 | 353 | adapter_dir = store.adapter_version(reference_adapter_version) |
| 354 | - policy_model = PeftModel.from_pretrained( | |
| 355 | - base_model, str(adapter_dir), is_trainable=True | |
| 356 | - ) | |
| 354 | + policy_model = PeftModel.from_pretrained(base_model, str(adapter_dir), is_trainable=True) | |
| 357 | 355 | |
| 358 | 356 | # Reference: frozen per preference.reference mode. We reload a |
| 359 | 357 | # clean base for the reference rather than sharing `base_model` so |
@@ -376,9 +374,7 @@ def _build_real_dpo_trainer( # pragma: no cover | ||
| 376 | 374 | doc_ds = build_dpo_dataset(list(parsed.sections)) |
| 377 | 375 | rng = _random.Random(seed + reference_adapter_version) |
| 378 | 376 | now = datetime.now(UTC).replace(tzinfo=None, microsecond=0) |
| 379 | - replay_rows = replay.sample_preference_rows( | |
| 380 | - k=max(8, 2 * len(doc_ds)), now=now, rng=rng | |
| 381 | - ) | |
| 377 | + replay_rows = replay.sample_preference_rows(k=max(8, 2 * len(doc_ds)), now=now, rng=rng) | |
| 382 | 378 | if replay_rows: |
| 383 | 379 | replay_ds = Dataset.from_list(replay_rows) |
| 384 | 380 | train_ds = concatenate_datasets([doc_ds, replay_ds]) |
src/dlm/train/preference/dpo_trainer.py (modified) @@ -119,9 +119,7 @@ def load_reference_model( # pragma: no cover | ||
| 119 | 119 | try: |
| 120 | 120 | ref = PeftModel.from_pretrained(model, str(adapter_path), is_trainable=False) |
| 121 | 121 | except Exception as exc: |
| 122 | - raise DpoReferenceLoadError( | |
| 123 | - adapter_path=str(adapter_path), cause=str(exc) | |
| 124 | - ) from exc | |
| 122 | + raise DpoReferenceLoadError(adapter_path=str(adapter_path), cause=str(exc)) from exc | |
| 125 | 123 | _freeze(ref) |
| 126 | 124 | return ref |
| 127 | 125 | |
src/dlm/train/preference/errors.py (modified) @@ -36,8 +36,6 @@ class DpoReferenceLoadError(DpoPhaseError): | ||
| 36 | 36 | adapter-version path that couldn't be opened.""" |
| 37 | 37 | |
| 38 | 38 | def __init__(self, *, adapter_path: str, cause: str) -> None: |
| 39 | - super().__init__( | |
| 40 | - f"could not load DPO reference model from {adapter_path}: {cause}" | |
| 41 | - ) | |
| 39 | + super().__init__(f"could not load DPO reference model from {adapter_path}: {cause}") | |
| 42 | 40 | self.adapter_path = adapter_path |
| 43 | 41 | self.cause = cause |
src/dlm/train/preference/orpo_phase.py (modified) @@ -42,7 +42,6 @@ from dlm.train.trainer import ( | ||
| 42 | 42 | ) |
| 43 | 43 | |
| 44 | 44 | if TYPE_CHECKING: |
| 45 | - | |
| 46 | 45 | from dlm.base_models import BaseModelSpec |
| 47 | 46 | from dlm.doc.parser import ParsedDlm |
| 48 | 47 | from dlm.hardware.capabilities import Capabilities |
@@ -294,9 +293,7 @@ def _build_real_orpo_trainer( # pragma: no cover | ||
| 294 | 293 | |
| 295 | 294 | base_model = load_base_model(spec, plan) |
| 296 | 295 | adapter_dir = store.adapter_version(reference_adapter_version) |
| 297 | - policy_model = PeftModel.from_pretrained( | |
| 298 | - base_model, str(adapter_dir), is_trainable=True | |
| 299 | - ) | |
| 296 | + policy_model = PeftModel.from_pretrained(base_model, str(adapter_dir), is_trainable=True) | |
| 300 | 297 | |
| 301 | 298 | tok_bringup = prepare_tokenizer(spec.hf_id, spec.revision) |
| 302 | 299 | |
@@ -305,9 +302,7 @@ def _build_real_orpo_trainer( # pragma: no cover | ||
| 305 | 302 | doc_ds = build_dpo_dataset(list(parsed.sections)) |
| 306 | 303 | rng = _random.Random(seed + reference_adapter_version) |
| 307 | 304 | now = datetime.now(UTC).replace(tzinfo=None, microsecond=0) |
| 308 | - replay_rows = replay.sample_preference_rows( | |
| 309 | - k=max(8, 2 * len(doc_ds)), now=now, rng=rng | |
| 310 | - ) | |
| 305 | + replay_rows = replay.sample_preference_rows(k=max(8, 2 * len(doc_ds)), now=now, rng=rng) | |
| 311 | 306 | if replay_rows: |
| 312 | 307 | replay_ds = Dataset.from_list(replay_rows) |
| 313 | 308 | train_ds = concatenate_datasets([doc_ds, replay_ds]) |
src/dlm/train/preference/phase_orchestrator.py (modified) @@ -116,9 +116,7 @@ def run_phases( | ||
| 116 | 116 | explicitly request DPO, skip with a warning instead of raising. |
| 117 | 117 | """ |
| 118 | 118 | sections = list(parsed.sections) |
| 119 | - pref_cfg = resolve_preference_enabled( | |
| 120 | - parsed.frontmatter.training.preference, sections | |
| 121 | - ) | |
| 119 | + pref_cfg = resolve_preference_enabled(parsed.frontmatter.training.preference, sections) | |
| 122 | 120 | results: list[PhaseResult] = [] |
| 123 | 121 | |
| 124 | 122 | sft_fn = sft_runner or _real_sft_runner() |
@@ -137,9 +135,7 @@ def run_phases( | ||
| 137 | 135 | sft_result = sft_fn(store, parsed, spec, plan, **sft_kwargs) |
| 138 | 136 | results.append(PhaseResult(phase="sft", result=sft_result)) |
| 139 | 137 | |
| 140 | - should_run_pref = phase == "preference" or ( | |
| 141 | - phase == "all" and pref_cfg.enabled | |
| 142 | - ) | |
| 138 | + should_run_pref = phase == "preference" or (phase == "all" and pref_cfg.enabled) | |
| 143 | 139 | if should_run_pref: |
| 144 | 140 | if not has_preference_content(sections): |
| 145 | 141 | if phase == "preference": |
src/dlm/train/trainer.py (modified) @@ -1257,9 +1257,7 @@ def _expand_directives( | ||
| 1257 | 1257 | if parsed.frontmatter.training.sources is None: |
| 1258 | 1258 | return parsed, () |
| 1259 | 1259 | |
| 1260 | - base_path = ( | |
| 1261 | - parsed.source_path.parent if parsed.source_path is not None else Path.cwd() | |
| 1262 | - ) | |
| 1260 | + base_path = parsed.source_path.parent if parsed.source_path is not None else Path.cwd() | |
| 1263 | 1261 | result = expand_sources(parsed, base_path=base_path) |
| 1264 | 1262 | if not result.sections: |
| 1265 | 1263 | return parsed, result.provenance |
tests/integration/directives/test_auto_scaffold_cycle.py (modified) @@ -96,9 +96,7 @@ def test_auto_scaffold_train_resume_cycle( | ||
| 96 | 96 | ), |
| 97 | 97 | ) |
| 98 | 98 | |
| 99 | - run1 = run_training( | |
| 100 | - store, parsed, spec, plan, mode="fresh", seed=42, max_steps=6 | |
| 101 | - ) | |
| 99 | + run1 = run_training(store, parsed, spec, plan, mode="fresh", seed=42, max_steps=6) | |
| 102 | 100 | assert run1.adapter_version == 1 |
| 103 | 101 | |
| 104 | 102 | # --- Second invocation: reuse scaffolded .dlm --------------------- |
@@ -118,9 +116,7 @@ def test_auto_scaffold_train_resume_cycle( | ||
| 118 | 116 | |
| 119 | 117 | # Train again — should produce adapter v0002 in the same store. |
| 120 | 118 | parsed2 = parse_file(result2.dlm_path) |
| 121 | - run2 = run_training( | |
| 122 | - store, parsed2, spec, plan, mode="fresh", seed=42, max_steps=6 | |
| 123 | - ) | |
| 119 | + run2 = run_training(store, parsed2, spec, plan, mode="fresh", seed=42, max_steps=6) | |
| 124 | 120 | assert run2.adapter_version == 2 |
| 125 | 121 | |
| 126 | 122 | manifest = load_manifest(store.manifest) |
tests/integration/directives/test_dlm_dir_descent.py (modified) @@ -23,31 +23,31 @@ _VALID_ULID = "01HZ4X7TGZM3J1A2B3C4D5E6F7" | ||
| 23 | 23 | def _build_tree(root: Path) -> None: |
| 24 | 24 | """Build a repo fixture: |
| 25 | 25 | |
| 26 | - root/ | |
| 26 | + root/ | |
| 27 | + .dlm/ | |
| 28 | + training.yaml include: ['src/**/*.py', 'docs/**/*.md'] | |
| 29 | + exclude: ['**/test_*.py'] | |
| 30 | + metadata: {language: python} | |
| 31 | + ignore *.log | |
| 32 | + src/ | |
| 33 | + main.py | |
| 34 | + test_main.py | |
| 35 | + vendor/ | |
| 27 | 36 | .dlm/ |
| 28 | - training.yaml include: ['src/**/*.py', 'docs/**/*.md'] | |
| 29 | - exclude: ['**/test_*.py'] | |
| 30 | - metadata: {language: python} | |
| 31 | - ignore *.log | |
| 32 | - src/ | |
| 33 | - main.py | |
| 34 | - test_main.py | |
| 35 | - vendor/ | |
| 36 | - .dlm/ | |
| 37 | - training.yaml exclude_defaults: false | |
| 38 | - metadata: {vendor: true_yes} | |
| 39 | - .git_shim/ (bare dir w/ file to prove defaults off) | |
| 40 | - HEAD | |
| 41 | - dep.py | |
| 42 | - docs/ | |
| 43 | - guide.md | |
| 44 | - .dlm/ | |
| 45 | - ignore !draft.md (re-include what parent excluded? N/A) | |
| 46 | - draft.md | |
| 47 | - debug.log | |
| 48 | - .env.local | |
| 49 | - build/ | |
| 50 | - output.py | |
| 37 | + training.yaml exclude_defaults: false | |
| 38 | + metadata: {vendor: true_yes} | |
| 39 | + .git_shim/ (bare dir w/ file to prove defaults off) | |
| 40 | + HEAD | |
| 41 | + dep.py | |
| 42 | + docs/ | |
| 43 | + guide.md | |
| 44 | + .dlm/ | |
| 45 | + ignore !draft.md (re-include what parent excluded? N/A) | |
| 46 | + draft.md | |
| 47 | + debug.log | |
| 48 | + .env.local | |
| 49 | + build/ | |
| 50 | + output.py | |
| 51 | 51 | """ |
| 52 | 52 | (root / ".dlm").mkdir() |
| 53 | 53 | (root / ".dlm" / "training.yaml").write_text( |
@@ -66,9 +66,7 @@ def _build_tree(root: Path) -> None: | ||
| 66 | 66 | (root / "src" / "vendor").mkdir() |
| 67 | 67 | (root / "src" / "vendor" / ".dlm").mkdir() |
| 68 | 68 | (root / "src" / "vendor" / ".dlm" / "training.yaml").write_text( |
| 69 | - "dlm_training_version: 1\n" | |
| 70 | - "exclude_defaults: false\n" | |
| 71 | - "metadata:\n vendor: true_yes\n", | |
| 69 | + "dlm_training_version: 1\nexclude_defaults: false\nmetadata:\n vendor: true_yes\n", | |
| 72 | 70 | encoding="utf-8", |
| 73 | 71 | ) |
| 74 | 72 | (root / "src" / "vendor" / "dep.py").write_text("def dep(): pass\n") |
tests/integration/directives/test_full_cycle.py (modified) @@ -60,8 +60,7 @@ def test_directive_tree_trains_and_summarizes( | ||
| 60 | 60 | tree = home / "src" |
| 61 | 61 | tree.mkdir() |
| 62 | 62 | (tree / "a.py").write_text( |
| 63 | - "def add(x, y):\n return x + y\n\n" | |
| 64 | - "def sub(x, y):\n return x - y\n", | |
| 63 | + "def add(x, y):\n return x + y\n\ndef sub(x, y):\n return x - y\n", | |
| 65 | 64 | encoding="utf-8", |
| 66 | 65 | ) |
| 67 | 66 | (tree / "b.py").write_text( |
tests/integration/metrics/test_full_cycle.py (modified) @@ -33,9 +33,7 @@ def test_trained_store_has_metrics_rows( # pragma: no cover - slow path | ||
| 33 | 33 | runs = recent_runs(trained_store.store.root, limit=10) |
| 34 | 34 | assert runs, "trainer.run() did not record any runs" |
| 35 | 35 | latest = runs[0] |
| 36 | - assert latest.status in ("ok", "running"), ( | |
| 37 | - f"expected 'ok' or 'running', got {latest.status!r}" | |
| 38 | - ) | |
| 36 | + assert latest.status in ("ok", "running"), f"expected 'ok' or 'running', got {latest.status!r}" | |
| 39 | 37 | |
| 40 | 38 | steps = steps_for_run(trained_store.store.root, latest.run_id) |
| 41 | 39 | # The tiny-model fixture runs at least one step. |
tests/integration/train/multi_adapter/test_two_adapters.py (modified) @@ -148,12 +148,12 @@ def test_two_adapters_each_get_their_own_version_history( | ||
| 148 | 148 | import json |
| 149 | 149 | |
| 150 | 150 | k_cfg = json.loads( |
| 151 | - (store.adapter_version_for("knowledge", 1) / "adapter_config.json") | |
| 152 | - .read_text(encoding="utf-8") | |
| 151 | + (store.adapter_version_for("knowledge", 1) / "adapter_config.json").read_text( | |
| 152 | + encoding="utf-8" | |
| 153 | + ) | |
| 153 | 154 | ) |
| 154 | 155 | t_cfg = json.loads( |
| 155 | - (store.adapter_version_for("tone", 1) / "adapter_config.json") | |
| 156 | - .read_text(encoding="utf-8") | |
| 156 | + (store.adapter_version_for("tone", 1) / "adapter_config.json").read_text(encoding="utf-8") | |
| 157 | 157 | ) |
| 158 | 158 | assert k_cfg["r"] == 8, f"knowledge lora_r: {k_cfg['r']}" |
| 159 | 159 | assert t_cfg["r"] == 4, f"tone lora_r: {t_cfg['r']}" |
tests/integration/train/multi_adapter/test_weighted_merge.py modified @@ -79,9 +79,7 @@ def _train_two_adapters( | ||
| 79 | 79 | make_dlm( |
| 80 | 80 | sections=[prose(_PROSE)], |
| 81 | 81 | base_model="smollm2-135m", |
| 82 | - training_overrides={ | |
| 83 | - "adapters": {"knowledge": {}, "tone": {}} | |
| 84 | - }, | |
| 82 | + training_overrides={"adapters": {"knowledge": {}, "tone": {}}}, | |
| 85 | 83 | ), |
| 86 | 84 | encoding="utf-8", |
| 87 | 85 | ) |
@@ -134,9 +132,7 @@ def test_weighted_merge_saves_tokenizer_files( | ||
| 134 | 132 | |
| 135 | 133 | spec = resolve_base_model(parsed.frontmatter.base_model, accept_license=True) |
| 136 | 134 | cached = download_spec(spec, local_files_only=True) |
| 137 | - base_model = AutoModelForCausalLM.from_pretrained( | |
| 138 | - str(cached.path), revision=spec.revision | |
| 139 | - ) | |
| 135 | + base_model = AutoModelForCausalLM.from_pretrained(str(cached.path), revision=spec.revision) | |
| 140 | 136 | |
| 141 | 137 | entries = [ |
| 142 | 138 | MixEntry(name="knowledge", weight=1.0), |
@@ -194,9 +190,7 @@ def test_weighted_merge_passes_preflight_tokenizer_vocab( | ||
| 194 | 190 | |
| 195 | 191 | spec = resolve_base_model(parsed.frontmatter.base_model, accept_license=True) |
| 196 | 192 | cached = download_spec(spec, local_files_only=True) |
| 197 | - base_model = AutoModelForCausalLM.from_pretrained( | |
| 198 | - str(cached.path), revision=spec.revision | |
| 199 | - ) | |
| 193 | + base_model = AutoModelForCausalLM.from_pretrained(str(cached.path), revision=spec.revision) | |
| 200 | 194 | |
| 201 | 195 | entries = [ |
| 202 | 196 | MixEntry(name="knowledge", weight=0.7), |
tests/integration/train/preference/test_dpo_tinymodel.py modified @@ -67,9 +67,17 @@ def _five_terse_preference_triples() -> str: | ||
| 67 | 67 | the rejected one — the direction DPO should push completions.""" |
| 68 | 68 | pairs = [ |
| 69 | 69 | ("What is 2 + 2?", "4.", "The sum of two and two is four, a basic arithmetic fact."), |
| 70 | - ("What color is grass?", "Green.", "Grass is typically a vibrant shade of green most of the year."), | |
| 70 | + ( | |
| 71 | + "What color is grass?", | |
| 72 | + "Green.", | |
| 73 | + "Grass is typically a vibrant shade of green most of the year.", | |
| 74 | + ), | |
| 71 | 75 | ("Is water wet?", "Yes.", "Water is generally considered wet in most everyday contexts."), |
| 72 | - ("Do birds fly?", "Most do.", "The majority of bird species can indeed fly, though a few cannot."), | |
| 76 | + ( | |
| 77 | + "Do birds fly?", | |
| 78 | + "Most do.", | |
| 79 | + "The majority of bird species can indeed fly, though a few cannot.", | |
| 80 | + ), | |
| 73 | 81 | ("What's 10 - 3?", "7.", "Ten minus three equals seven in standard arithmetic."), |
| 74 | 82 | ] |
| 75 | 83 | parts: list[str] = [] |
tests/integration/train/preference/test_orpo_tinymodel.py modified @@ -77,9 +77,17 @@ def test_orpo_phase_writes_second_adapter_version(trained_store) -> None: # typ | ||
| 77 | 77 | def _five_terse_preference_triples() -> str: |
| 78 | 78 | pairs = [ |
| 79 | 79 | ("What is 2 + 2?", "4.", "The sum of two and two is four, a basic arithmetic fact."), |
| 80 | - ("What color is grass?", "Green.", "Grass is typically a vibrant shade of green most of the year."), | |
| 80 | + ( | |
| 81 | + "What color is grass?", | |
| 82 | + "Green.", | |
| 83 | + "Grass is typically a vibrant shade of green most of the year.", | |
| 84 | + ), | |
| 81 | 85 | ("Is water wet?", "Yes.", "Water is generally considered wet in most everyday contexts."), |
| 82 | - ("Do birds fly?", "Most do.", "The majority of bird species can indeed fly, though a few cannot."), | |
| 86 | + ( | |
| 87 | + "Do birds fly?", | |
| 88 | + "Most do.", | |
| 89 | + "The majority of bird species can indeed fly, though a few cannot.", | |
| 90 | + ), | |
| 83 | 91 | ("What's 10 - 3?", "7.", "Ten minus three equals seven in standard arithmetic."), |
| 84 | 92 | ] |
| 85 | 93 | parts: list[str] = [] |
tests/unit/cli/test_prompt_adapter_flag.py modified @@ -48,9 +48,7 @@ def _scaffold_multi_doc(tmp_path: Path) -> Path: | ||
| 48 | 48 | |
| 49 | 49 | |
| 50 | 50 | class TestFlatDocRejectsAdapter: |
| 51 | - def test_single_adapter_doc_with_adapter_flag_exits_2( | |
| 52 | - self, tmp_path: Path | |
| 53 | - ) -> None: | |
| 51 | + def test_single_adapter_doc_with_adapter_flag_exits_2(self, tmp_path: Path) -> None: | |
| 54 | 52 | doc = _scaffold_flat_doc(tmp_path) |
| 55 | 53 | runner = CliRunner() |
| 56 | 54 | result = runner.invoke( |
tests/unit/cli/test_serve_guard.py modified @@ -18,12 +18,7 @@ from dlm.cli.app import app | ||
| 18 | 18 | |
| 19 | 19 | def _write_minimal_dlm(path: Path, dlm_id: str = "01KPQ9M3" + "0" * 18) -> None: |
| 20 | 20 | path.write_text( |
| 21 | - "---\n" | |
| 22 | - f"dlm_id: {dlm_id}\n" | |
| 23 | - "dlm_version: 6\n" | |
| 24 | - "base_model: smollm2-135m\n" | |
| 25 | - "---\n" | |
| 26 | - "body\n", | |
| 21 | + f"---\ndlm_id: {dlm_id}\ndlm_version: 6\nbase_model: smollm2-135m\n---\nbody\n", | |
| 27 | 22 | encoding="utf-8", |
| 28 | 23 | ) |
| 29 | 24 | |
@@ -40,8 +35,10 @@ class TestServeUntrainedGuard: | ||
| 40 | 35 | result = runner.invoke( |
| 41 | 36 | app, |
| 42 | 37 | [ |
| 43 | - "--home", str(tmp_path / "home"), | |
| 44 | - "serve", str(doc), | |
| 38 | + "--home", | |
| 39 | + str(tmp_path / "home"), | |
| 40 | + "serve", | |
| 41 | + str(doc), | |
| 45 | 42 | ], |
| 46 | 43 | ) |
| 47 | 44 | assert result.exit_code == 1, result.output |
tests/unit/cli/test_train_scaffold_cli.py modified @@ -32,9 +32,7 @@ def _captured() -> dict[str, Any]: | ||
| 32 | 32 | return {} |
| 33 | 33 | |
| 34 | 34 | |
| 35 | -def _install_capturing_fake( | |
| 36 | - monkeypatch: pytest.MonkeyPatch, captured: dict[str, Any] | |
| 37 | -) -> None: | |
| 35 | +def _install_capturing_fake(monkeypatch: pytest.MonkeyPatch, captured: dict[str, Any]) -> None: | |
| 38 | 36 | """Replace `run_phases` with a stub that records call args and |
| 39 | 37 | returns `[]` (triggering the CLI's "no-op: nothing to train" path |
| 40 | 38 | with exit code 0). The scaffold + manifest + expand_sources pipeline |
@@ -80,10 +78,14 @@ class TestDlmTrainDirScaffold: | ||
| 80 | 78 | result = runner.invoke( |
| 81 | 79 | app, |
| 82 | 80 | [ |
| 83 | - "--home", str(tmp_path / "home"), | |
| 84 | - "train", str(corpus), | |
| 85 | - "--base", "smollm2-135m", | |
| 86 | - "--include", "**/*.md", | |
| 81 | + "--home", | |
| 82 | + str(tmp_path / "home"), | |
| 83 | + "train", | |
| 84 | + str(corpus), | |
| 85 | + "--base", | |
| 86 | + "smollm2-135m", | |
| 87 | + "--include", | |
| 88 | + "**/*.md", | |
| 87 | 89 | ], |
| 88 | 90 | ) |
| 89 | 91 | |
@@ -105,12 +107,8 @@ class TestDlmTrainDirScaffold: | ||
| 105 | 107 | always 0.""" |
| 106 | 108 | corpus = tmp_path / "corpus" |
| 107 | 109 | corpus.mkdir() |
| 108 | - (corpus / "alpha.md").write_text( | |
| 109 | - "# Alpha\nalpha-unique-token\n", encoding="utf-8" | |
| 110 | - ) | |
| 111 | - (corpus / "beta.md").write_text( | |
| 112 | - "# Beta\nbeta-unique-token\n", encoding="utf-8" | |
| 113 | - ) | |
| 110 | + (corpus / "alpha.md").write_text("# Alpha\nalpha-unique-token\n", encoding="utf-8") | |
| 111 | + (corpus / "beta.md").write_text("# Beta\nbeta-unique-token\n", encoding="utf-8") | |
| 114 | 112 | |
| 115 | 113 | captured = _captured() |
| 116 | 114 | _install_capturing_fake(monkeypatch, captured) |
@@ -120,10 +118,14 @@ class TestDlmTrainDirScaffold: | ||
| 120 | 118 | result = runner.invoke( |
| 121 | 119 | app, |
| 122 | 120 | [ |
| 123 | - "--home", str(tmp_path / "home"), | |
| 124 | - "train", str(corpus), | |
| 125 | - "--base", "smollm2-135m", | |
| 126 | - "--include", "**/*.md", | |
| 121 | + "--home", | |
| 122 | + str(tmp_path / "home"), | |
| 123 | + "train", | |
| 124 | + str(corpus), | |
| 125 | + "--base", | |
| 126 | + "smollm2-135m", | |
| 127 | + "--include", | |
| 128 | + "**/*.md", | |
| 127 | 129 | ], |
| 128 | 130 | ) |
| 129 | 131 | |
@@ -146,12 +148,8 @@ class TestDlmTrainDirScaffold: | ||
| 146 | 148 | ) |
| 147 | 149 | combined = _section_texts(expanded.sections) |
| 148 | 150 | rendered = "\n".join(f" {s.content[:80]!r}" for s in expanded.sections) |
| 149 | - assert "alpha-unique-token" in combined, ( | |
| 150 | - "B2: alpha.md not ingested. got:\n" + rendered | |
| 151 | - ) | |
| 152 | - assert "beta-unique-token" in combined, ( | |
| 153 | - "B2: beta.md not ingested. got:\n" + rendered | |
| 154 | - ) | |
| 151 | + assert "alpha-unique-token" in combined, "B2: alpha.md not ingested. got:\n" + rendered | |
| 152 | + assert "beta-unique-token" in combined, "B2: beta.md not ingested. got:\n" + rendered | |
| 155 | 153 | assert expanded.provenance[0].file_count == 2 |
| 156 | 154 | assert expanded.provenance[0].total_bytes > 0 |
| 157 | 155 | |
@@ -173,10 +171,14 @@ class TestDlmTrainDirScaffold: | ||
| 173 | 171 | result = runner.invoke( |
| 174 | 172 | app, |
| 175 | 173 | [ |
| 176 | - "--home", str(tmp_path / "home"), | |
| 177 | - "train", str(corpus), | |
| 178 | - "--base", "smollm2-135m", | |
| 179 | - "--include", "**/*.md", | |
| 174 | + "--home", | |
| 175 | + str(tmp_path / "home"), | |
| 176 | + "train", | |
| 177 | + str(corpus), | |
| 178 | + "--base", | |
| 179 | + "smollm2-135m", | |
| 180 | + "--include", | |
| 181 | + "**/*.md", | |
| 180 | 182 | ], |
| 181 | 183 | ) |
| 182 | 184 | |
@@ -217,10 +219,14 @@ class TestDlmTrainDirScaffold: | ||
| 217 | 219 | r1 = runner.invoke( |
| 218 | 220 | app, |
| 219 | 221 | [ |
| 220 | - "--home", str(tmp_path / "home"), | |
| 221 | - "train", str(corpus), | |
| 222 | - "--base", "smollm2-135m", | |
| 223 | - "--include", "**/*.md", | |
| 222 | + "--home", | |
| 223 | + str(tmp_path / "home"), | |
| 224 | + "train", | |
| 225 | + str(corpus), | |
| 226 | + "--base", | |
| 227 | + "smollm2-135m", | |
| 228 | + "--include", | |
| 229 | + "**/*.md", | |
| 224 | 230 | ], |
| 225 | 231 | ) |
| 226 | 232 | assert r1.exit_code == 0, r1.output |
@@ -235,13 +241,14 @@ class TestDlmTrainDirScaffold: | ||
| 235 | 241 | r2 = runner.invoke( |
| 236 | 242 | app, |
| 237 | 243 | [ |
| 238 | - "--home", str(tmp_path / "home"), | |
| 239 | - "train", str(corpus), | |
| 244 | + "--home", | |
| 245 | + str(tmp_path / "home"), | |
| 246 | + "train", | |
| 247 | + str(corpus), | |
| 240 | 248 | ], |
| 241 | 249 | ) |
| 242 | 250 | |
| 243 | 251 | assert r2.exit_code == 0, r2.output |
| 244 | 252 | assert manifest_path.stat().st_mtime_ns == first_mtime, ( |
| 245 | - "manifest was rewritten on the resume path; " | |
| 246 | - "training history could be lost" | |
| 253 | + "manifest was rewritten on the resume path; training history could be lost" | |
| 247 | 254 | ) |
tests/unit/directives/test_cache.py modified @@ -98,9 +98,7 @@ class TestInvalidation: | ||
| 98 | 98 | cache.put(key_a, _tokens(4)) |
| 99 | 99 | assert cache.get(key_b) is None |
| 100 | 100 | |
| 101 | - def test_missing_file_recovers( | |
| 102 | - self, tmp_path: Path, caplog: pytest.LogCaptureFixture | |
| 103 | - ) -> None: | |
| 101 | + def test_missing_file_recovers(self, tmp_path: Path, caplog: pytest.LogCaptureFixture) -> None: | |
| 104 | 102 | """If the on-disk entry vanishes under us, get() should treat |
| 105 | 103 | it as a miss and clean up the stale manifest row.""" |
| 106 | 104 | import logging |
tests/unit/directives/test_defaults.py modified @@ -93,6 +93,4 @@ def test_default_excludes_catch_known_traps(path: str) -> None: | ||
| 93 | 93 | ], |
| 94 | 94 | ) |
| 95 | 95 | def test_default_excludes_leave_source_alone(path: str) -> None: |
| 96 | - assert not _matches_any_default(path), ( | |
| 97 | - f"DEFAULT_EXCLUDES wrongly caught: {path}" | |
| 98 | - ) | |
| 96 | + assert not _matches_any_default(path), f"DEFAULT_EXCLUDES wrongly caught: {path}" | |
tests/unit/directives/test_discovery.py modified @@ -79,9 +79,7 @@ def test_schema_violation_logs_and_continues( | ||
| 79 | 79 | tmp_path: Path, caplog: pytest.LogCaptureFixture |
| 80 | 80 | ) -> None: |
| 81 | 81 | (tmp_path / ".dlm").mkdir() |
| 82 | - (tmp_path / ".dlm" / "training.yaml").write_text( | |
| 83 | - "dlm_training_version: 1\nunknown_key: bad\n" | |
| 84 | - ) | |
| 82 | + (tmp_path / ".dlm" / "training.yaml").write_text("dlm_training_version: 1\nunknown_key: bad\n") | |
| 85 | 83 | caplog.set_level(logging.WARNING, logger="dlm.directives.discovery") |
| 86 | 84 | configs = discover_configs(tmp_path) |
| 87 | 85 | assert configs[0].config is None |
@@ -101,9 +99,7 @@ def test_training_yaml_non_mapping_top_level( | ||
| 101 | 99 | |
| 102 | 100 | def test_both_files_coexist(tmp_path: Path) -> None: |
| 103 | 101 | (tmp_path / ".dlm").mkdir() |
| 104 | - (tmp_path / ".dlm" / "training.yaml").write_text( | |
| 105 | - "dlm_training_version: 1\nexclude: ['a']\n" | |
| 106 | - ) | |
| 102 | + (tmp_path / ".dlm" / "training.yaml").write_text("dlm_training_version: 1\nexclude: ['a']\n") | |
| 107 | 103 | (tmp_path / ".dlm" / "ignore").write_text("*.tmp\n") |
| 108 | 104 | (c,) = discover_configs(tmp_path) |
| 109 | 105 | assert c.config is not None |
tests/unit/directives/test_expand.py modified @@ -81,12 +81,7 @@ def test_max_files_truncates_deterministically(tmp_path: Path) -> None: | ||
| 81 | 81 | src.mkdir() |
| 82 | 82 | for i in range(5): |
| 83 | 83 | (src / f"{i}.py").write_text(f"# {i}\n") |
| 84 | - body = ( | |
| 85 | - " sources:\n" | |
| 86 | - " - path: src\n" | |
| 87 | - " include: ['**/*.py']\n" | |
| 88 | - " max_files: 2\n" | |
| 89 | - ) | |
| 84 | + body = " sources:\n - path: src\n include: ['**/*.py']\n max_files: 2\n" | |
| 90 | 85 | parsed, _ = _make_parsed(body, tmp_path) |
| 91 | 86 | result = expand_sources(parsed, base_path=tmp_path) # type: ignore[arg-type] |
| 92 | 87 | # Sorted: 0.py, 1.py land; 2/3/4 get dropped |
@@ -101,12 +96,7 @@ def test_max_bytes_per_file_skips_oversize(tmp_path: Path) -> None: | ||
| 101 | 96 | src.mkdir() |
| 102 | 97 | (src / "small.py").write_text("x\n") # 2 bytes |
| 103 | 98 | (src / "big.py").write_text("x" * 100) |
| 104 | - body = ( | |
| 105 | - " sources:\n" | |
| 106 | - " - path: src\n" | |
| 107 | - " include: ['**/*.py']\n" | |
| 108 | - " max_bytes_per_file: 10\n" | |
| 109 | - ) | |
| 99 | + body = " sources:\n - path: src\n include: ['**/*.py']\n max_bytes_per_file: 10\n" | |
| 110 | 100 | parsed, _ = _make_parsed(body, tmp_path) |
| 111 | 101 | result = expand_sources(parsed, base_path=tmp_path) # type: ignore[arg-type] |
| 112 | 102 | assert len(result.sections) == 1 |
@@ -157,11 +147,7 @@ def test_strict_policy_refuses_external_path(tmp_path: Path) -> None: | ||
| 157 | 147 | outside.mkdir(exist_ok=True) |
| 158 | 148 | try: |
| 159 | 149 | (outside / "a.py").write_text("x") |
| 160 | - body = ( | |
| 161 | - " sources_policy: strict\n" | |
| 162 | - " sources:\n" | |
| 163 | - f" - path: {outside}\n" | |
| 164 | - ) | |
| 150 | + body = f" sources_policy: strict\n sources:\n - path: {outside}\n" | |
| 165 | 151 | parsed, _ = _make_parsed(body, tmp_path) |
| 166 | 152 | with pytest.raises(DirectivePolicyError): |
| 167 | 153 | expand_sources(parsed, base_path=tmp_path) # type: ignore[arg-type] |
@@ -176,7 +162,7 @@ def test_permissive_policy_allows_external_path(tmp_path: Path) -> None: | ||
| 176 | 162 | outside.mkdir(exist_ok=True) |
| 177 | 163 | try: |
| 178 | 164 | (outside / "a.py").write_text("ok\n") |
| 179 | - body = " sources:\n" f" - path: {outside}\n include: ['**/*.py']\n" | |
| 165 | + body = f" sources:\n - path: {outside}\n include: ['**/*.py']\n" | |
| 180 | 166 | parsed, _ = _make_parsed(body, tmp_path) |
| 181 | 167 | result = expand_sources(parsed, base_path=tmp_path) # type: ignore[arg-type] |
| 182 | 168 | assert len(result.sections) == 1 |
tests/unit/directives/test_merge.py modified @@ -99,18 +99,24 @@ def test_training_yaml_exclude_blocks_file(tmp_path: Path) -> None: | ||
| 99 | 99 | ) |
| 100 | 100 | directive = _directive(tmp_path, include=("**/*.py",)) |
| 101 | 101 | configs = discover_configs(tmp_path) |
| 102 | - assert effective_config_for( | |
| 103 | - tmp_path / "src" / "main.py", | |
| 104 | - source_root=tmp_path, | |
| 105 | - discovered=configs, | |
| 106 | - parent_directive=directive, | |
| 107 | - ).included is True | |
| 108 | - assert effective_config_for( | |
| 109 | - tmp_path / "src" / "test_main.py", | |
| 110 | - source_root=tmp_path, | |
| 111 | - discovered=configs, | |
| 112 | - parent_directive=directive, | |
| 113 | - ).included is False | |
| 102 | + assert ( | |
| 103 | + effective_config_for( | |
| 104 | + tmp_path / "src" / "main.py", | |
| 105 | + source_root=tmp_path, | |
| 106 | + discovered=configs, | |
| 107 | + parent_directive=directive, | |
| 108 | + ).included | |
| 109 | + is True | |
| 110 | + ) | |
| 111 | + assert ( | |
| 112 | + effective_config_for( | |
| 113 | + tmp_path / "src" / "test_main.py", | |
| 114 | + source_root=tmp_path, | |
| 115 | + discovered=configs, | |
| 116 | + parent_directive=directive, | |
| 117 | + ).included | |
| 118 | + is False | |
| 119 | + ) | |
| 114 | 120 | |
| 115 | 121 | |
| 116 | 122 | # ---- .dlm/ignore negation -------------------------------------------------- |
@@ -136,18 +142,24 @@ def test_ignore_negation_re_includes_file(tmp_path: Path) -> None: | ||
| 136 | 142 | _write(tmp_path / ".dlm" / "ignore", "*.log\n!special.log\n") |
| 137 | 143 | directive = _directive(tmp_path) |
| 138 | 144 | configs = discover_configs(tmp_path) |
| 139 | - assert effective_config_for( | |
| 140 | - tmp_path / "debug.log", | |
| 141 | - source_root=tmp_path, | |
| 142 | - discovered=configs, | |
| 143 | - parent_directive=directive, | |
| 144 | - ).included is False | |
| 145 | - assert effective_config_for( | |
| 146 | - tmp_path / "special.log", | |
| 147 | - source_root=tmp_path, | |
| 148 | - discovered=configs, | |
| 149 | - parent_directive=directive, | |
| 150 | - ).included is True | |
| 145 | + assert ( | |
| 146 | + effective_config_for( | |
| 147 | + tmp_path / "debug.log", | |
| 148 | + source_root=tmp_path, | |
| 149 | + discovered=configs, | |
| 150 | + parent_directive=directive, | |
| 151 | + ).included | |
| 152 | + is False | |
| 153 | + ) | |
| 154 | + assert ( | |
| 155 | + effective_config_for( | |
| 156 | + tmp_path / "special.log", | |
| 157 | + source_root=tmp_path, | |
| 158 | + discovered=configs, | |
| 159 | + parent_directive=directive, | |
| 160 | + ).included | |
| 161 | + is True | |
| 162 | + ) | |
| 151 | 163 | |
| 152 | 164 | |
| 153 | 165 | def test_deeper_ignore_negation_unblocks_parent_exclude(tmp_path: Path) -> None: |
@@ -176,18 +188,24 @@ def test_default_excludes_apply_by_default(tmp_path: Path) -> None: | ||
| 176 | 188 | _write(tmp_path / "src" / "main.py", "x") |
| 177 | 189 | directive = _directive(tmp_path) |
| 178 | 190 | configs = discover_configs(tmp_path) |
| 179 | - assert effective_config_for( | |
| 180 | - tmp_path / ".git" / "HEAD", | |
| 181 | - source_root=tmp_path, | |
| 182 | - discovered=configs, | |
| 183 | - parent_directive=directive, | |
| 184 | - ).included is False | |
| 185 | - assert effective_config_for( | |
| 186 | - tmp_path / "src" / "main.py", | |
| 187 | - source_root=tmp_path, | |
| 188 | - discovered=configs, | |
| 189 | - parent_directive=directive, | |
| 190 | - ).included is True | |
| 191 | + assert ( | |
| 192 | + effective_config_for( | |
| 193 | + tmp_path / ".git" / "HEAD", | |
| 194 | + source_root=tmp_path, | |
| 195 | + discovered=configs, | |
| 196 | + parent_directive=directive, | |
| 197 | + ).included | |
| 198 | + is False | |
| 199 | + ) | |
| 200 | + assert ( | |
| 201 | + effective_config_for( | |
| 202 | + tmp_path / "src" / "main.py", | |
| 203 | + source_root=tmp_path, | |
| 204 | + discovered=configs, | |
| 205 | + parent_directive=directive, | |
| 206 | + ).included | |
| 207 | + is True | |
| 208 | + ) | |
| 191 | 209 | |
| 192 | 210 | |
| 193 | 211 | def test_exclude_defaults_false_disables_default_set(tmp_path: Path) -> None: |
@@ -230,9 +248,9 @@ def test_metadata_shallow_to_deep_merge(tmp_path: Path) -> None: | ||
| 230 | 248 | parent_directive=directive, |
| 231 | 249 | ) |
| 232 | 250 | assert eff.tags == { |
| 233 | - "language": "python", # from shallower | |
| 251 | + "language": "python", # from shallower | |
| 234 | 252 | "domain": "vendor_override", # deeper overrides shallower |
| 235 | - "source": "third_party", # from deeper only | |
| 253 | + "source": "third_party", # from deeper only | |
| 236 | 254 | } |
| 237 | 255 | |
| 238 | 256 | |
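For context on the metadata hunk just above: the shallow-to-deep precedence the inline comments describe amounts to a key-by-key dict update in which the deeper `.dlm/training.yaml` wins. A minimal sketch with illustrative values — only the merged result mirrors the test; the shallower `domain: app` entry is invented for the example:

```python
def merge_tags(shallow_to_deep: list[dict[str, str]]) -> dict[str, str]:
    # Entries are ordered root -> deepest directory; later (deeper) metadata overrides key-by-key.
    merged: dict[str, str] = {}
    for tags in shallow_to_deep:
        merged.update(tags or {})
    return merged

merge_tags(
    [
        {"language": "python", "domain": "app"},                   # root .dlm/training.yaml (illustrative)
        {"domain": "vendor_override", "source": "third_party"},    # deeper, vendored subtree
    ]
)
# -> {"language": "python", "domain": "vendor_override", "source": "third_party"}
```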
tests/unit/directives/test_safety.py modified @@ -130,11 +130,7 @@ def test_enumerate_is_deterministic(tmp_path: Path) -> None: | ||
| 130 | 130 | def test_enumerate_exclude_wins(tmp_path: Path) -> None: |
| 131 | 131 | (tmp_path / "keep.py").write_text("x") |
| 132 | 132 | (tmp_path / "skip.py").write_text("x") |
| 133 | - got = list( | |
| 134 | - enumerate_matching_files( | |
| 135 | - tmp_path, include=("**/*.py",), exclude=("skip.py",) | |
| 136 | - ) | |
| 137 | - ) | |
| 133 | + got = list(enumerate_matching_files(tmp_path, include=("**/*.py",), exclude=("skip.py",))) | |
| 138 | 134 | assert [p.name for p in got] == ["keep.py"] |
| 139 | 135 | |
| 140 | 136 | |
tests/unit/doc/test_fence_adapter_suffix.py modified @@ -43,9 +43,7 @@ class TestParseFenceSuffix: | ||
| 43 | 43 | assert instr[0].adapter == "tone" |
| 44 | 44 | |
| 45 | 45 | def test_preference_fence_adapter(self) -> None: |
| 46 | - parsed = _parse( | |
| 47 | - "::preference#knowledge::\n### Prompt\nq\n### Chosen\nc\n### Rejected\nr\n" | |
| 48 | - ) | |
| 46 | + parsed = _parse("::preference#knowledge::\n### Prompt\nq\n### Chosen\nc\n### Rejected\nr\n") | |
| 49 | 47 | pref = [s for s in parsed.sections if s.type == SectionType.PREFERENCE] |
| 50 | 48 | assert pref |
| 51 | 49 | assert pref[0].adapter == "knowledge" |
@@ -97,9 +95,7 @@ class TestSectionIdentityUnchanged: | ||
| 97 | 95 | """Routing is structural, not content — same content with and |
| 98 | 96 | without a `#adapter` suffix must produce the same section_id |
| 99 | 97 | so replay snapshots don't duplicate rows on routing edits.""" |
| 100 | - s_plain = Section( | |
| 101 | - type=SectionType.INSTRUCTION, content="### Q\nhi\n### A\nbye" | |
| 102 | - ) | |
| 98 | + s_plain = Section(type=SectionType.INSTRUCTION, content="### Q\nhi\n### A\nbye") | |
| 103 | 99 | s_routed = Section( |
| 104 | 100 | type=SectionType.INSTRUCTION, |
| 105 | 101 | content="### Q\nhi\n### A\nbye", |
tests/unit/doc/test_migration_v1_to_v2.py modified @@ -87,9 +87,7 @@ class TestPartialDpoBlock: | ||
| 87 | 87 | |
| 88 | 88 | class TestReferenceRename: |
| 89 | 89 | def test_pre_dpo_adapter_becomes_pre_adapter(self) -> None: |
| 90 | - raw: dict[str, Any] = { | |
| 91 | - "training": {"dpo": {"reference": "pre_dpo_adapter"}} | |
| 92 | - } | |
| 90 | + raw: dict[str, Any] = {"training": {"dpo": {"reference": "pre_dpo_adapter"}}} | |
| 93 | 91 | out = migrate(raw) |
| 94 | 92 | assert out["training"]["preference"]["reference"] == "pre_adapter" |
| 95 | 93 | |
tests/unit/doc/test_round_trip_v4_adapters.py modified @@ -51,9 +51,7 @@ def test_round_trip_v4_multi_adapter_doc_is_idempotent() -> None: | ||
| 51 | 51 | |
| 52 | 52 | once = serialize(parse_text(original)) |
| 53 | 53 | twice = serialize(parse_text(once)) |
| 54 | - assert once == twice, ( | |
| 55 | - "v4 adapters doc not idempotent under serialize round-trip" | |
| 56 | - ) | |
| 54 | + assert once == twice, "v4 adapters doc not idempotent under serialize round-trip" | |
| 57 | 55 | |
| 58 | 56 | |
| 59 | 57 | def test_round_trip_preserves_fence_suffixes() -> None: |
tests/unit/doc/test_schema.py modified @@ -192,9 +192,7 @@ class TestTrainingConfigPreferenceSubfield: | ||
| 192 | 192 | |
| 193 | 193 | def test_rejects_unknown_field_inside_preference(self) -> None: |
| 194 | 194 | with pytest.raises(ValidationError): |
| 195 | - TrainingConfig.model_validate( | |
| 196 | - {"preference": {"enabled": True, "rubbish": 1}} | |
| 197 | - ) | |
| 195 | + TrainingConfig.model_validate({"preference": {"enabled": True, "rubbish": 1}}) | |
| 198 | 196 | |
| 199 | 197 | |
| 200 | 198 | class TestCptConfig: |
@@ -233,17 +231,13 @@ class TestTrainingConfigCptSubfield: | ||
| 233 | 231 | assert t.cpt.embed_warmup_steps == 0 |
| 234 | 232 | |
| 235 | 233 | def test_accepts_nested_dict_for_cpt(self) -> None: |
| 236 | - t = TrainingConfig.model_validate( | |
| 237 | - {"cpt": {"schedule": "dapt", "embed_warmup_steps": 200}} | |
| 238 | - ) | |
| 234 | + t = TrainingConfig.model_validate({"cpt": {"schedule": "dapt", "embed_warmup_steps": 200}}) | |
| 239 | 235 | assert t.cpt.schedule == "dapt" |
| 240 | 236 | assert t.cpt.embed_warmup_steps == 200 |
| 241 | 237 | |
| 242 | 238 | def test_rejects_unknown_field_inside_cpt(self) -> None: |
| 243 | 239 | with pytest.raises(ValidationError): |
| 244 | - TrainingConfig.model_validate( | |
| 245 | - {"cpt": {"schedule": "dapt", "rubbish": 1}} | |
| 246 | - ) | |
| 240 | + TrainingConfig.model_validate({"cpt": {"schedule": "dapt", "rubbish": 1}}) | |
| 247 | 241 | |
| 248 | 242 | |
| 249 | 243 | class TestAdapterConfig: |
@@ -317,15 +311,11 @@ class TestNamedAdapters: | ||
| 317 | 311 | |
| 318 | 312 | def test_flat_lora_r_with_block_rejected(self) -> None: |
| 319 | 313 | with pytest.raises(ValidationError, match="flat per-adapter fields"): |
| 320 | - TrainingConfig.model_validate( | |
| 321 | - {"lora_r": 32, "adapters": {"knowledge": {}}} | |
| 322 | - ) | |
| 314 | + TrainingConfig.model_validate({"lora_r": 32, "adapters": {"knowledge": {}}}) | |
| 323 | 315 | |
| 324 | 316 | def test_flat_learning_rate_with_block_rejected(self) -> None: |
| 325 | 317 | with pytest.raises(ValidationError, match="flat per-adapter fields"): |
| 326 | - TrainingConfig.model_validate( | |
| 327 | - {"learning_rate": 1e-3, "adapters": {"tone": {}}} | |
| 328 | - ) | |
| 318 | + TrainingConfig.model_validate({"learning_rate": 1e-3, "adapters": {"tone": {}}}) | |
| 329 | 319 | |
| 330 | 320 | def test_top_level_shared_knobs_allowed_alongside_block(self) -> None: |
| 331 | 321 | # seed, num_epochs, sequence_len, etc. are explicitly shared |
tests/unit/eval/test_summary.py modified @@ -130,9 +130,7 @@ class TestMixedModeFields: | ||
| 130 | 130 | |
| 131 | 131 | class TestSplitLossByMode: |
| 132 | 132 | def test_mixed_rows(self) -> None: |
| 133 | - out = split_loss_by_mode( | |
| 134 | - [(1.0, "cpt"), (2.0, "cpt"), (0.5, "sft"), (1.5, "sft")] | |
| 135 | - ) | |
| 133 | + out = split_loss_by_mode([(1.0, "cpt"), (2.0, "cpt"), (0.5, "sft"), (1.5, "sft")]) | |
| 136 | 134 | assert out == LossByMode(cpt=1.5, sft=1.0) |
| 137 | 135 | |
| 138 | 136 | def test_all_cpt(self) -> None: |
@@ -151,9 +149,7 @@ class TestSplitLossByMode: | ||
| 151 | 149 | assert out.sft is None |
| 152 | 150 | |
| 153 | 151 | def test_unknown_modes_ignored(self) -> None: |
| 154 | - out = split_loss_by_mode( | |
| 155 | - [(1.0, "cpt"), (2.0, "preference"), (3.0, "other")] | |
| 156 | - ) | |
| 152 | + out = split_loss_by_mode([(1.0, "cpt"), (2.0, "preference"), (3.0, "other")]) | |
| 157 | 153 | assert out.cpt == pytest.approx(1.0) |
| 158 | 154 | assert out.sft is None |
| 159 | 155 | |
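The `split_loss_by_mode` assertions reflowed above pin down its contract: per-mode mean loss, `None` where a mode has no rows, and modes other than `cpt`/`sft` silently ignored. A self-contained sketch consistent with those assertions — `LossByMode` is re-declared here only so the example runs standalone; the real definition is whatever the eval summary module ships:

```python
from dataclasses import dataclass


@dataclass
class LossByMode:
    cpt: float | None = None
    sft: float | None = None


def split_loss_by_mode(rows: list[tuple[float, str]]) -> LossByMode:
    # Mean loss per recognised mode; unknown modes ("preference", "other", ...) are dropped.
    buckets: dict[str, list[float]] = {"cpt": [], "sft": []}
    for loss, mode in rows:
        if mode in buckets:
            buckets[mode].append(loss)

    def _mean(values: list[float]) -> float | None:
        return sum(values) / len(values) if values else None

    return LossByMode(cpt=_mean(buckets["cpt"]), sft=_mean(buckets["sft"]))


assert split_loss_by_mode([(1.0, "cpt"), (2.0, "cpt"), (0.5, "sft"), (1.5, "sft")]) == LossByMode(1.5, 1.0)
```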
tests/unit/export/test_vendoring.py modified @@ -73,9 +73,7 @@ class TestLlamaQuantizeBin: | ||
| 73 | 73 | assert path.is_file() |
| 74 | 74 | assert path.name == "llama-quantize" |
| 75 | 75 | |
| 76 | - def test_missing_binary_raises( | |
| 77 | - self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch | |
| 78 | - ) -> None: | |
| 76 | + def test_missing_binary_raises(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: | |
| 79 | 77 | # Clear PATH so the `shutil.which` fallback can't find a |
| 80 | 78 | # brew-installed llama-quantize on the developer's machine. |
| 81 | 79 | monkeypatch.setenv("PATH", str(tmp_path / "empty")) |
tests/unit/export/test_weighted_merge.py modified @@ -128,9 +128,7 @@ class TestResolveFirstSourcePath: | ||
| 128 | 128 | with pytest.raises(InvalidMixSpecError, match="empty mix"): |
| 129 | 129 | resolve_first_source_path(store, []) |
| 130 | 130 | |
| 131 | - def test_missing_adapter_raises_export_error( | |
| 132 | - self, tmp_path | |
| 133 | - ) -> None: # type: ignore[no-untyped-def] | |
| 131 | + def test_missing_adapter_raises_export_error(self, tmp_path) -> None: # type: ignore[no-untyped-def] | |
| 134 | 132 | from dlm.export.errors import ExportError |
| 135 | 133 | from dlm.export.weighted_merge import resolve_first_source_path |
| 136 | 134 | from dlm.store.paths import StorePath |
tests/unit/hardware/test_capabilities.py modified @@ -84,9 +84,7 @@ class TestMlxAvailability: | ||
| 84 | 84 | caps = probe() |
| 85 | 85 | assert caps.has_mlx is True |
| 86 | 86 | |
| 87 | - def test_mps_reports_no_mlx_when_mlx_lm_missing( | |
| 88 | - self, monkeypatch: pytest.MonkeyPatch | |
| 89 | - ) -> None: | |
| 87 | + def test_mps_reports_no_mlx_when_mlx_lm_missing(self, monkeypatch: pytest.MonkeyPatch) -> None: | |
| 90 | 88 | from dlm.hardware import capabilities as caps_mod |
| 91 | 89 | |
| 92 | 90 | real_avail = caps_mod._module_available |
tests/unit/hardware/test_f28_multi_adapter_qlora.py modified @@ -14,17 +14,13 @@ from tests.fixtures.hardware_mocks import force_cuda | ||
| 14 | 14 | |
| 15 | 15 | def _qlora_multi_doc(num: int) -> TrainingConfig: |
| 16 | 16 | """TrainingConfig with `num` QLoRA adapters declared.""" |
| 17 | - adapters = { | |
| 18 | - f"a{i}": AdapterConfig(adapter="qlora") for i in range(num) | |
| 19 | - } | |
| 17 | + adapters = {f"a{i}": AdapterConfig(adapter="qlora") for i in range(num)} | |
| 20 | 18 | return TrainingConfig.model_validate({"adapters": adapters}) |
| 21 | 19 | |
| 22 | 20 | |
| 23 | 21 | def _qlora_multi_doc_with_rank(num: int, lora_r: int) -> TrainingConfig: |
| 24 | 22 | """Multi-adapter doc with `num` QLoRA adapters at the given lora_r.""" |
| 25 | - adapters = { | |
| 26 | - f"a{i}": AdapterConfig(adapter="qlora", lora_r=lora_r) for i in range(num) | |
| 27 | - } | |
| 23 | + adapters = {f"a{i}": AdapterConfig(adapter="qlora", lora_r=lora_r) for i in range(num)} | |
| 28 | 24 | return TrainingConfig.model_validate({"adapters": adapters}) |
| 29 | 25 | |
| 30 | 26 | |
@@ -85,9 +81,7 @@ class TestF28MultiAdapterQLoraRefusal: | ||
| 85 | 81 | adapters = {"a0": AdapterConfig(), "a1": AdapterConfig()} |
| 86 | 82 | lora_multi = TrainingConfig.model_validate({"adapters": adapters}) |
| 87 | 83 | # LoRA bypasses QLoRA refusals entirely. |
| 88 | - check_refusals( | |
| 89 | - lora_multi, caps, base_params=1_500_000_000, num_adapters=2 | |
| 90 | - ) | |
| 84 | + check_refusals(lora_multi, caps, base_params=1_500_000_000, num_adapters=2) | |
| 91 | 85 | |
| 92 | 86 | def test_small_base_low_rank_multi_qlora_passes(self) -> None: |
| 93 | 87 | """The old formula falsely refused small-base multi-QLoRA. |
@@ -136,9 +130,7 @@ class TestEffectiveAdapter: | ||
| 136 | 130 | } |
| 137 | 131 | ) |
| 138 | 132 | with pytest.raises(ResolutionError, match="Multi-adapter QLoRA"): |
| 139 | - check_refusals( | |
| 140 | - mixed, caps, base_params=7_000_000_000, num_adapters=3 | |
| 141 | - ) | |
| 133 | + check_refusals(mixed, caps, base_params=7_000_000_000, num_adapters=3) | |
| 142 | 134 | |
| 143 | 135 | def test_mixed_adapter_error_names_only_qlora_offenders(self) -> None: |
| 144 | 136 | with force_cuda(vram_gb=12.0): |
@@ -153,9 +145,7 @@ class TestEffectiveAdapter: | ||
| 153 | 145 | } |
| 154 | 146 | ) |
| 155 | 147 | with pytest.raises(ResolutionError) as exc_info: |
| 156 | - check_refusals( | |
| 157 | - mixed, caps, base_params=7_000_000_000, num_adapters=3 | |
| 158 | - ) | |
| 148 | + check_refusals(mixed, caps, base_params=7_000_000_000, num_adapters=3) | |
| 159 | 149 | message = str(exc_info.value) |
| 160 | 150 | assert "qlora_one" in message |
| 161 | 151 | assert "lora_a" not in message |
tests/unit/hardware/test_plan.py modified @@ -45,9 +45,7 @@ class TestPrecisionPicker: | ||
| 45 | 45 | with force_mps(): |
| 46 | 46 | caps = probe() |
| 47 | 47 | with caplog.at_level(logging.WARNING, logger="dlm.hardware.plan"): # type: ignore[attr-defined] |
| 48 | - plan = resolve( | |
| 49 | - _cfg(precision="fp16"), caps, base_params=8_000_000_000, seq_len=2048 | |
| 50 | - ) | |
| 48 | + plan = resolve(_cfg(precision="fp16"), caps, base_params=8_000_000_000, seq_len=2048) | |
| 51 | 49 | assert plan.precision == "fp16" |
| 52 | 50 | # The caller must see the risk explicitly — silent fp16 on MPS |
| 53 | 51 | # is what caused the original bug. |
@@ -63,9 +61,7 @@ class TestPrecisionPicker: | ||
| 63 | 61 | with force_mps(): |
| 64 | 62 | caps = probe() |
| 65 | 63 | with caplog.at_level(logging.WARNING, logger="dlm.hardware.plan"): # type: ignore[attr-defined] |
| 66 | - plan = resolve( | |
| 67 | - _cfg(precision="bf16"), caps, base_params=1_500_000_000, seq_len=2048 | |
| 68 | - ) | |
| 64 | + plan = resolve(_cfg(precision="bf16"), caps, base_params=1_500_000_000, seq_len=2048) | |
| 69 | 65 | assert plan.precision == "bf16" |
| 70 | 66 | assert caplog.records == [] # type: ignore[attr-defined] |
| 71 | 67 | |
@@ -73,9 +69,7 @@ class TestPrecisionPicker: | ||
| 73 | 69 | # CUDA default is bf16 (Ampere+) — override to fp32 honored. |
| 74 | 70 | with force_cuda(sm=(8, 0)): |
| 75 | 71 | caps = probe() |
| 76 | - plan = resolve( | |
| 77 | - _cfg(precision="fp32"), caps, base_params=1_500_000_000, seq_len=2048 | |
| 78 | - ) | |
| 72 | + plan = resolve(_cfg(precision="fp32"), caps, base_params=1_500_000_000, seq_len=2048) | |
| 79 | 73 | assert plan.precision == "fp32" |
| 80 | 74 | |
| 81 | 75 | |
@@ -204,9 +198,7 @@ class TestDpoPhaseAdjustments: | ||
| 204 | 198 | # at 1, not round to 0. |
| 205 | 199 | with force_cuda(sm=(8, 9), vram_gb=4.0): |
| 206 | 200 | caps = probe() |
| 207 | - dpo = resolve( | |
| 208 | - _cfg(), caps, base_params=1_500_000_000, seq_len=2048, phase="dpo" | |
| 209 | - ) | |
| 201 | + dpo = resolve(_cfg(), caps, base_params=1_500_000_000, seq_len=2048, phase="dpo") | |
| 210 | 202 | assert dpo.micro_batch_size >= 1 |
| 211 | 203 | |
| 212 | 204 | def test_dpo_peak_vram_exceeds_sft(self) -> None: |
@@ -228,12 +220,8 @@ class TestDpoPhaseAdjustments: | ||
| 228 | 220 | def test_dpo_reason_mentions_phase(self) -> None: |
| 229 | 221 | with force_cuda(sm=(8, 9), vram_gb=24.0): |
| 230 | 222 | caps = probe() |
| 231 | - dpo = resolve( | |
| 232 | - _cfg(), caps, base_params=1_500_000_000, seq_len=2048, phase="dpo" | |
| 233 | - ) | |
| 234 | - sft = resolve( | |
| 235 | - _cfg(), caps, base_params=1_500_000_000, seq_len=2048, phase="sft" | |
| 236 | - ) | |
| 223 | + dpo = resolve(_cfg(), caps, base_params=1_500_000_000, seq_len=2048, phase="dpo") | |
| 224 | + sft = resolve(_cfg(), caps, base_params=1_500_000_000, seq_len=2048, phase="sft") | |
| 237 | 225 | assert "phase=dpo" in dpo.reason |
| 238 | 226 | assert "phase=dpo" not in sft.reason |
| 239 | 227 | |
tests/unit/inference/test_resolve_adapter_path.py modified @@ -49,13 +49,9 @@ class TestNamedLayout: | ||
| 49 | 49 | v1 = s.adapter_version_for("knowledge", 1) |
| 50 | 50 | v1.mkdir(parents=True) |
| 51 | 51 | s.set_current_adapter_for("knowledge", v1) |
| 52 | - assert ( | |
| 53 | - resolve_adapter_path(s, adapter_name="knowledge") == v1.resolve() | |
| 54 | - ) | |
| 52 | + assert resolve_adapter_path(s, adapter_name="knowledge") == v1.resolve() | |
| 55 | 53 | |
| 56 | - def test_missing_named_pointer_mentions_adapter_name( | |
| 57 | - self, tmp_path: Path | |
| 58 | - ) -> None: | |
| 54 | + def test_missing_named_pointer_mentions_adapter_name(self, tmp_path: Path) -> None: | |
| 59 | 55 | s = _store(tmp_path) |
| 60 | 56 | s.ensure_adapter_layout("knowledge") |
| 61 | 57 | with pytest.raises(AdapterNotFoundError, match="'knowledge'"): |
tests/unit/lock/test_mismatch_policy.py modified @@ -49,9 +49,9 @@ class TestAccelerateUninstall: | ||
| 49 | 49 | current = _lock(pinned_versions={"torch": "2.5.1"}) |
| 50 | 50 | msgs = [m for _s, m in classify_mismatches(prior, current)] |
| 51 | 51 | # `_rule_minor_peers` fires on peer disappear with "no longer pinned". |
| 52 | - assert any( | |
| 53 | - "accelerate" in m and "no longer pinned" in m for m in msgs | |
| 54 | - ), f"expected accelerate-removal warning, got {msgs!r}" | |
| 52 | + assert any("accelerate" in m and "no longer pinned" in m for m in msgs), ( | |
| 53 | + f"expected accelerate-removal warning, got {msgs!r}" | |
| 54 | + ) | |
| 55 | 55 | |
| 56 | 56 | |
| 57 | 57 | class TestWorldSize: |
tests/unit/metrics/test_db_schema.py modified @@ -20,10 +20,7 @@ class TestConnect: | ||
| 20 | 20 | def test_creates_schema(self, tmp_path: Path) -> None: |
| 21 | 21 | with connect(tmp_path) as conn: |
| 22 | 22 | tables = { |
| 23 | - row[0] | |
| 24 | - for row in conn.execute( | |
| 25 | - "SELECT name FROM sqlite_master WHERE type='table'" | |
| 26 | - ) | |
| 23 | + row[0] for row in conn.execute("SELECT name FROM sqlite_master WHERE type='table'") | |
| 27 | 24 | } |
| 28 | 25 | assert tables == {"runs", "steps", "evals", "exports", "tokenization"} |
| 29 | 26 | |
@@ -65,14 +62,10 @@ class TestEnsureSchema: | ||
| 65 | 62 | conn.execute( |
| 66 | 63 | "INSERT INTO runs (run_id, started_at, status) VALUES (1, 'now', 'running')" |
| 67 | 64 | ) |
| 68 | - conn.execute( | |
| 69 | - "INSERT INTO steps (run_id, step, loss, at) VALUES (1, 1, 0.5, 'now')" | |
| 70 | - ) | |
| 65 | + conn.execute("INSERT INTO steps (run_id, step, loss, at) VALUES (1, 1, 0.5, 'now')") | |
| 71 | 66 | # Duplicate (1, 1) should violate PK unless we upsert. |
| 72 | 67 | try: |
| 73 | - conn.execute( | |
| 74 | - "INSERT INTO steps (run_id, step, loss, at) VALUES (1, 1, 0.4, 'now')" | |
| 75 | - ) | |
| 68 | + conn.execute("INSERT INTO steps (run_id, step, loss, at) VALUES (1, 1, 0.4, 'now')") | |
| 76 | 69 | raise AssertionError("duplicate PK accepted") |
| 77 | 70 | except sqlite3.IntegrityError: |
| 78 | 71 | pass |
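The `try`/`except sqlite3.IntegrityError` above is checking that `(run_id, step)` really is a composite primary key; the comment's point is that re-recording the same step must go through an upsert rather than a plain `INSERT`. A reduced sketch of that upsert — the column set is taken from the test, the real schema may carry more fields and foreign keys:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE steps (run_id INTEGER, step INTEGER, loss REAL, at TEXT, "
    "PRIMARY KEY (run_id, step))"
)
conn.execute("INSERT INTO steps (run_id, step, loss, at) VALUES (1, 1, 0.5, 'now')")
# A second plain INSERT for (1, 1) raises sqlite3.IntegrityError; an upsert replaces the row instead.
conn.execute(
    "INSERT INTO steps (run_id, step, loss, at) VALUES (1, 1, 0.4, 'later') "
    "ON CONFLICT(run_id, step) DO UPDATE SET loss = excluded.loss, at = excluded.at"
)
assert conn.execute("SELECT loss FROM steps WHERE run_id = 1 AND step = 1").fetchone()[0] == 0.4
```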
tests/unit/metrics/test_queries.py modified @@ -22,9 +22,7 @@ def _seed(store_root: Path) -> None: | ||
| 22 | 22 | """Populate a DB with three runs and a handful of steps/evals.""" |
| 23 | 23 | rec = MetricsRecorder(store_root) |
| 24 | 24 | for run_id in (1, 2, 3): |
| 25 | - rec.record_run_start( | |
| 26 | - RunStart(run_id=run_id, adapter_version=run_id, phase="sft", seed=42) | |
| 27 | - ) | |
| 25 | + rec.record_run_start(RunStart(run_id=run_id, adapter_version=run_id, phase="sft", seed=42)) | |
| 28 | 26 | for step in (10, 20, 30): |
| 29 | 27 | rec.record_step(StepEvent(run_id=run_id, step=step, loss=2.0 - 0.1 * step)) |
| 30 | 28 | rec.record_eval(EvalEvent(run_id=run_id, step=30, val_loss=1.5)) |
tests/unit/replay/test_store.py modified @@ -90,29 +90,21 @@ class TestSampleRows: | ||
| 90 | 90 | assert raw_sid in rows[0]["_dlm_section_id"] |
| 91 | 91 | |
| 92 | 92 | |
| 93 | -_PREF_BODY_A = ( | |
| 94 | - "### Prompt\nqA\n### Chosen\ncA\n### Rejected\nrA" | |
| 95 | -) | |
| 96 | -_PREF_BODY_B = ( | |
| 97 | - "### Prompt\nqB\n### Chosen\ncB\n### Rejected\nrB" | |
| 98 | -) | |
| 93 | +_PREF_BODY_A = "### Prompt\nqA\n### Chosen\ncA\n### Rejected\nrA" | |
| 94 | +_PREF_BODY_B = "### Prompt\nqB\n### Chosen\ncB\n### Rejected\nrB" | |
| 99 | 95 | |
| 100 | 96 | |
| 101 | 97 | class TestSamplePreferenceRows: |
| 102 | 98 | def test_empty_corpus_returns_empty(self, tmp_path: Path) -> None: |
| 103 | 99 | s = _store(tmp_path) |
| 104 | - rows = s.sample_preference_rows( | |
| 105 | - k=5, now=datetime(2026, 4, 1), rng=random.Random(0) | |
| 106 | - ) | |
| 100 | + rows = s.sample_preference_rows(k=5, now=datetime(2026, 4, 1), rng=random.Random(0)) | |
| 107 | 101 | assert rows == [] |
| 108 | 102 | |
| 109 | 103 | def test_corpus_with_no_preferences_returns_empty(self, tmp_path: Path) -> None: |
| 110 | 104 | s = _store(tmp_path) |
| 111 | 105 | s.append(_snap("a" * 16, "prose", "hello", added=datetime(2026, 1, 1))) |
| 112 | 106 | s.append(_snap("b" * 16, "instruction", "### Q\nq\n### A\na", added=datetime(2026, 1, 1))) |
| 113 | - rows = s.sample_preference_rows( | |
| 114 | - k=2, now=datetime(2026, 4, 1), rng=random.Random(0) | |
| 115 | - ) | |
| 107 | + rows = s.sample_preference_rows(k=2, now=datetime(2026, 4, 1), rng=random.Random(0)) | |
| 116 | 108 | assert rows == [] |
| 117 | 109 | |
| 118 | 110 | def test_filters_to_preferences_only(self, tmp_path: Path) -> None: |
@@ -120,9 +112,7 @@ class TestSamplePreferenceRows: | ||
| 120 | 112 | s.append(_snap("a" * 16, "prose", "prose body", added=datetime(2026, 1, 1))) |
| 121 | 113 | s.append(_snap("b" * 16, "preference", _PREF_BODY_A, added=datetime(2026, 1, 1))) |
| 122 | 114 | s.append(_snap("c" * 16, "preference", _PREF_BODY_B, added=datetime(2026, 1, 2))) |
| 123 | - rows = s.sample_preference_rows( | |
| 124 | - k=10, now=datetime(2026, 4, 1), rng=random.Random(0) | |
| 125 | - ) | |
| 115 | + rows = s.sample_preference_rows(k=10, now=datetime(2026, 4, 1), rng=random.Random(0)) | |
| 126 | 116 | assert len(rows) == 2 |
| 127 | 117 | assert {r["prompt"] for r in rows} == {"qA", "qB"} |
| 128 | 118 | assert all("chosen" in r and "rejected" in r for r in rows) |
@@ -133,15 +123,11 @@ class TestSamplePreferenceRows: | ||
| 133 | 123 | sid = f"{i:016x}" |
| 134 | 124 | body = f"### Prompt\nq{i}\n### Chosen\nc{i}\n### Rejected\nr{i}" |
| 135 | 125 | s.append(_snap(sid, "preference", body, added=datetime(2026, 1, 1))) |
| 136 | - rows = s.sample_preference_rows( | |
| 137 | - k=2, now=datetime(2026, 4, 1), rng=random.Random(0) | |
| 138 | - ) | |
| 126 | + rows = s.sample_preference_rows(k=2, now=datetime(2026, 4, 1), rng=random.Random(0)) | |
| 139 | 127 | assert len(rows) == 2 |
| 140 | 128 | |
| 141 | 129 | def test_replay_sid_prefix_applied(self, tmp_path: Path) -> None: |
| 142 | 130 | s = _store(tmp_path) |
| 143 | 131 | s.append(_snap("a" * 16, "preference", _PREF_BODY_A, added=datetime(2026, 1, 1))) |
| 144 | - rows = s.sample_preference_rows( | |
| 145 | - k=1, now=datetime(2026, 4, 1), rng=random.Random(0) | |
| 146 | - ) | |
| 132 | + rows = s.sample_preference_rows(k=1, now=datetime(2026, 4, 1), rng=random.Random(0)) | |
| 147 | 133 | assert rows[0]["_dlm_section_id"].startswith("replay:") |
tests/unit/store/test_inspect_named_adapters.py modified @@ -72,9 +72,7 @@ class TestMultiAdapterDiscovery: | ||
| 72 | 72 | |
| 73 | 73 | inspection = inspect_store(store) |
| 74 | 74 | assert inspection.named_adapters == [ |
| 75 | - NamedAdapterState( | |
| 76 | - name="knowledge", has_current=False, latest_version=1 | |
| 77 | - ) | |
| 75 | + NamedAdapterState(name="knowledge", has_current=False, latest_version=1) | |
| 78 | 76 | ] |
| 79 | 77 | |
| 80 | 78 | def test_empty_adapter_dir_without_versions_skipped(self, tmp_path: Path) -> None: |
tests/unit/store/test_paths_named_adapters.py modified @@ -26,17 +26,11 @@ class TestNamedAdapterPaths: | ||
| 26 | 26 | |
| 27 | 27 | def test_adapter_version_for_name_pads_four_digits(self, tmp_path: Path) -> None: |
| 28 | 28 | s = _store(tmp_path) |
| 29 | - assert ( | |
| 30 | - s.adapter_version_for("tone", 7) | |
| 31 | - == s.adapter / "tone" / "versions" / "v0007" | |
| 32 | - ) | |
| 29 | + assert s.adapter_version_for("tone", 7) == s.adapter / "tone" / "versions" / "v0007" | |
| 33 | 30 | |
| 34 | 31 | def test_pointer_path_for_name(self, tmp_path: Path) -> None: |
| 35 | 32 | s = _store(tmp_path) |
| 36 | - assert ( | |
| 37 | - s.adapter_current_pointer_for("knowledge") | |
| 38 | - == s.adapter / "knowledge" / "current.txt" | |
| 39 | - ) | |
| 33 | + assert s.adapter_current_pointer_for("knowledge") == s.adapter / "knowledge" / "current.txt" | |
| 40 | 34 | |
| 41 | 35 | |
| 42 | 36 | class TestNamedAdapterValidation: |
tests/unit/templates/test_registry.py modified @@ -25,8 +25,7 @@ def test_bundled_templates_dir_exists() -> None: | ||
| 25 | 25 | def test_list_bundled_returns_eight_templates() -> None: |
| 26 | 26 | templates = list_bundled() |
| 27 | 27 | assert len(templates) >= 8, ( |
| 28 | - f"expected at least 8 templates, got {len(templates)}: " | |
| 29 | - f"{[t.name for t in templates]}" | |
| 28 | + f"expected at least 8 templates, got {len(templates)}: {[t.name for t in templates]}" | |
| 30 | 29 | ) |
| 31 | 30 | names = {t.name for t in templates} |
| 32 | 31 | required = { |
@@ -69,17 +68,13 @@ def test_registry_drops_template_missing_sidecar(tmp_path: Path) -> None: | ||
| 69 | 68 | |
| 70 | 69 | |
| 71 | 70 | def test_registry_drops_template_with_malformed_meta(tmp_path: Path) -> None: |
| 72 | - (tmp_path / "broken.dlm").write_text( | |
| 73 | - "---\ndlm_id: 01AAAA\nbase_model: foo\n---\n# body\n" | |
| 74 | - ) | |
| 71 | + (tmp_path / "broken.dlm").write_text("---\ndlm_id: 01AAAA\nbase_model: foo\n---\n# body\n") | |
| 75 | 72 | (tmp_path / "broken.meta.yaml").write_text("not: a: valid: yaml: mapping\n") |
| 76 | 73 | assert list_bundled(gallery_dir=tmp_path) == [] |
| 77 | 74 | |
| 78 | 75 | |
| 79 | 76 | def test_load_template_with_mismatched_name_raises(tmp_path: Path) -> None: |
| 80 | - (tmp_path / "fine.dlm").write_text( | |
| 81 | - "---\ndlm_id: 01AAAA\nbase_model: foo\n---\n# body\n" | |
| 82 | - ) | |
| 77 | + (tmp_path / "fine.dlm").write_text("---\ndlm_id: 01AAAA\nbase_model: foo\n---\n# body\n") | |
| 83 | 78 | # meta.name doesn't match the filename stem. |
| 84 | 79 | (tmp_path / "fine.meta.yaml").write_text( |
| 85 | 80 | "name: different\ntitle: X\nrecommended_base: qwen2.5-1.5b\nsummary: hi\n" |
tests/unit/test_templates_parse.py modified @@ -23,8 +23,7 @@ def test_templates_dir_is_populated() -> None: | ||
| 23 | 23 | # Guard against a silent deletion of the bundled gallery. |
| 24 | 24 | paths = _template_paths() |
| 25 | 25 | assert len(paths) >= 8, ( |
| 26 | - f"expected at least 8 gallery templates under {bundled_templates_dir()}, " | |
| 27 | - f"got {len(paths)}" | |
| 26 | + f"expected at least 8 gallery templates under {bundled_templates_dir()}, got {len(paths)}" | |
| 28 | 27 | ) |
| 29 | 28 | |
| 30 | 29 | |
tests/unit/train/cpt/test_embed_warmup.py modified @@ -23,9 +23,7 @@ class _FakeParam: | ||
| 23 | 23 | self.requires_grad = requires_grad |
| 24 | 24 | |
| 25 | 25 | |
| 26 | -def _model( | |
| 27 | - *, embed_frozen: bool = True, head_frozen: bool = True, tied: bool = False | |
| 28 | -) -> Any: | |
| 26 | +def _model(*, embed_frozen: bool = True, head_frozen: bool = True, tied: bool = False) -> Any: | |
| 29 | 27 | embed_param = _FakeParam(requires_grad=not embed_frozen) |
| 30 | 28 | head_param = embed_param if tied else _FakeParam(requires_grad=not head_frozen) |
| 31 | 29 | embed_module = SimpleNamespace(weight=embed_param) |
@@ -77,9 +75,9 @@ class TestExtendModulesToSave: | ||
| 77 | 75 | assert extend_modules_to_save_for_embed_warmup(None, embed_warmup_steps=0) is None |
| 78 | 76 | |
| 79 | 77 | def test_zero_warmup_passes_through_list(self) -> None: |
| 80 | - assert extend_modules_to_save_for_embed_warmup( | |
| 81 | - ["embed_tokens"], embed_warmup_steps=0 | |
| 82 | - ) == ["embed_tokens"] | |
| 78 | + assert extend_modules_to_save_for_embed_warmup(["embed_tokens"], embed_warmup_steps=0) == [ | |
| 79 | + "embed_tokens" | |
| 80 | + ] | |
| 83 | 81 | |
| 84 | 82 | def test_warmup_on_with_no_existing(self) -> None: |
| 85 | 83 | out = extend_modules_to_save_for_embed_warmup(None, embed_warmup_steps=50) |
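For context on the reflowed calls above: the behaviour these tests pin down is that `embed_warmup_steps=0` is a strict pass-through (including `None`), while a positive warmup extends `modules_to_save` with the embedding modules. A sketch under that assumption — the truncated hunk does not show what the positive-warmup case returns, so the module names added here (`embed_tokens`) are illustrative only:

```python
def extend_modules_to_save_for_embed_warmup(
    existing: list[str] | None, *, embed_warmup_steps: int
) -> list[str] | None:
    if embed_warmup_steps <= 0:
        return existing  # pass-through, including None
    modules = list(existing or [])
    for name in ("embed_tokens",):  # assumed; the real list may also cover a tied/untied lm_head
        if name not in modules:
            modules.append(name)
    return modules
```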
tests/unit/train/cpt/test_schedule.py modified @@ -19,29 +19,20 @@ class TestWarmupRamp: | ||
| 19 | 19 | assert cosine_with_floor_lr(0, total_steps=100, warmup_steps=20) == 0.0 |
| 20 | 20 | |
| 21 | 21 | def test_half_warmup_is_half(self) -> None: |
| 22 | - assert cosine_with_floor_lr( | |
| 23 | - 10, total_steps=100, warmup_steps=20 | |
| 24 | - ) == pytest.approx(0.5) | |
| 22 | + assert cosine_with_floor_lr(10, total_steps=100, warmup_steps=20) == pytest.approx(0.5) | |
| 25 | 23 | |
| 26 | 24 | def test_warmup_end_is_peak(self) -> None: |
| 27 | 25 | # step == warmup_steps: first step of decay phase; cosine is 1.0 |
| 28 | 26 | # at decay_progress=0, so we're at peak. |
| 29 | - assert cosine_with_floor_lr( | |
| 30 | - 20, total_steps=100, warmup_steps=20 | |
| 31 | - ) == pytest.approx(1.0) | |
| 27 | + assert cosine_with_floor_lr(20, total_steps=100, warmup_steps=20) == pytest.approx(1.0) | |
| 32 | 28 | |
| 33 | 29 | def test_zero_warmup_jumps_to_peak(self) -> None: |
| 34 | - assert cosine_with_floor_lr( | |
| 35 | - 0, total_steps=100, warmup_steps=0 | |
| 36 | - ) == pytest.approx(1.0) | |
| 30 | + assert cosine_with_floor_lr(0, total_steps=100, warmup_steps=0) == pytest.approx(1.0) | |
| 37 | 31 | |
| 38 | 32 | |
| 39 | 33 | class TestCosineDecay: |
| 40 | 34 | def test_monotone_decrease_through_decay(self) -> None: |
| 41 | - lrs = [ | |
| 42 | - cosine_with_floor_lr(s, total_steps=100, warmup_steps=20) | |
| 43 | - for s in range(20, 100, 5) | |
| 44 | - ] | |
| 35 | + lrs = [cosine_with_floor_lr(s, total_steps=100, warmup_steps=20) for s in range(20, 100, 5)] | |
| 45 | 36 | for a, b in zip(lrs, lrs[1:], strict=False): |
| 46 | 37 | assert a > b |
| 47 | 38 | |
@@ -66,9 +57,7 @@ class TestCosineDecay: | ||
| 66 | 57 | # Midpoint of cosine decay (decay_progress=0.5) gives cos(pi/2)=0, |
| 67 | 58 | # so cosine multiplier = 0.5 → LR = floor + (1-floor)*0.5 |
| 68 | 59 | floor = 0.1 |
| 69 | - mid = cosine_with_floor_lr( | |
| 70 | - 60, total_steps=100, warmup_steps=20, floor_ratio=floor | |
| 71 | - ) | |
| 60 | + mid = cosine_with_floor_lr(60, total_steps=100, warmup_steps=20, floor_ratio=floor) | |
| 72 | 61 | expected = floor + (1.0 - floor) * 0.5 |
| 73 | 62 | assert mid == pytest.approx(expected) |
| 74 | 63 | |
@@ -97,9 +86,7 @@ class TestInputValidation: | ||
| 97 | 86 | @pytest.mark.parametrize("bad", [-0.01, 1.01, 2.0]) |
| 98 | 87 | def test_floor_ratio_out_of_range(self, bad: float) -> None: |
| 99 | 88 | with pytest.raises(ValueError, match="floor_ratio must be in"): |
| 100 | - cosine_with_floor_lr( | |
| 101 | - 0, total_steps=100, warmup_steps=10, floor_ratio=bad | |
| 102 | - ) | |
| 89 | + cosine_with_floor_lr(0, total_steps=100, warmup_steps=10, floor_ratio=bad) | |
| 103 | 90 | |
| 104 | 91 | |
| 105 | 92 | class TestDefaultConstants: |
@@ -130,11 +117,7 @@ class TestContinuityAcrossWarmup: | ||
| 130 | 117 | # At warmup_steps, cosine is exactly 1. They should differ by |
| 131 | 118 | # ~1/warmup_steps (the ramp's last sub-peak increment). |
| 132 | 119 | warmup = 50 |
| 133 | - last_ramp = cosine_with_floor_lr( | |
| 134 | - warmup - 1, total_steps=200, warmup_steps=warmup | |
| 135 | - ) | |
| 136 | - first_decay = cosine_with_floor_lr( | |
| 137 | - warmup, total_steps=200, warmup_steps=warmup | |
| 138 | - ) | |
| 120 | + last_ramp = cosine_with_floor_lr(warmup - 1, total_steps=200, warmup_steps=warmup) | |
| 121 | + first_decay = cosine_with_floor_lr(warmup, total_steps=200, warmup_steps=warmup) | |
| 139 | 122 | assert first_decay == pytest.approx(1.0) |
| 140 | 123 | assert math.isclose(first_decay - last_ramp, 1 / warmup, abs_tol=1e-9) |
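The reflowed assertions above fully determine the shape of `cosine_with_floor_lr`: a linear ramp from 0 to the peak over `warmup_steps`, then a cosine decay from 1.0 down to `floor_ratio`. A sketch that satisfies every assertion visible in this file — the `floor_ratio` default and any validation beyond the `floor_ratio must be in` check are assumptions:

```python
import math


def cosine_with_floor_lr(
    step: int, *, total_steps: int, warmup_steps: int, floor_ratio: float = 0.1
) -> float:
    """LR multiplier: linear ramp during warmup, then cosine decay to floor_ratio."""
    if not 0.0 <= floor_ratio <= 1.0:
        raise ValueError("floor_ratio must be in [0, 1]")
    if warmup_steps and step < warmup_steps:
        return step / warmup_steps  # 0.0 at step 0, 0.5 halfway, peak reached at warmup_steps
    decay_progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * decay_progress))  # 1.0 at the peak, 0.0 at the end
    return floor_ratio + (1.0 - floor_ratio) * cosine


# Midpoint of the decay (step 60 of 100 with 20 warmup steps): cos(pi/2) = 0, so the
# multiplier is floor + (1 - floor) * 0.5 — exactly what the mid-decay test expects.
assert math.isclose(
    cosine_with_floor_lr(60, total_steps=100, warmup_steps=20, floor_ratio=0.1), 0.1 + 0.9 * 0.5
)
```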
tests/unit/train/cpt/test_vocab_gap.py modified @@ -17,9 +17,7 @@ from dlm.train.cpt.vocab_gap import ( | ||
| 17 | 17 | |
| 18 | 18 | class TestComputeVocabGap: |
| 19 | 19 | def test_empty_inputs(self) -> None: |
| 20 | - r = compute_vocab_gap( | |
| 21 | - [], text="", unk_token_id=None, decoded_tokens=[] | |
| 22 | - ) | |
| 20 | + r = compute_vocab_gap([], text="", unk_token_id=None, decoded_tokens=[]) | |
| 23 | 21 | assert r.total_tokens == 0 |
| 24 | 22 | assert r.total_words == 0 |
| 25 | 23 | assert r.tokens_per_word == 0.0 |
@@ -113,9 +111,7 @@ class TestComputeVocabGapValidation: | ||
| 113 | 111 | |
| 114 | 112 | def test_negative_top_n_rejected(self) -> None: |
| 115 | 113 | with pytest.raises(ValueError, match="top_n must be non-negative"): |
| 116 | - compute_vocab_gap( | |
| 117 | - [], text="", unk_token_id=None, decoded_tokens=[], top_n=-1 | |
| 118 | - ) | |
| 114 | + compute_vocab_gap([], text="", unk_token_id=None, decoded_tokens=[], top_n=-1) | |
| 119 | 115 | |
| 120 | 116 | |
| 121 | 117 | class TestRenderReport: |
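The vocab-gap probe's headline ratio follows from the empty-input test above: token and word totals plus tokens-per-word, returning 0.0 rather than dividing by zero on empty input. A reduced sketch of just that ratio — the real `compute_vocab_gap` also takes `unk_token_id`, `decoded_tokens`, and a `top_n` report of the worst-split words, all omitted here:

```python
def tokens_per_word(token_ids: list[int], text: str) -> float:
    # How many sub-tokens the tokenizer spends per whitespace-separated word; a high
    # ratio on domain text suggests the base vocabulary splits that jargon heavily.
    total_tokens = len(token_ids)
    total_words = len(text.split())
    return total_tokens / total_words if total_words else 0.0


assert tokens_per_word([], "") == 0.0  # matches the empty-input expectations above
assert tokens_per_word([101, 7, 993, 42], "two words") == 2.0
```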
tests/unit/train/distributed/test_rank_env.py modified @@ -40,9 +40,7 @@ class TestDetectRank: | ||
| 40 | 40 | monkeypatch.delenv("LOCAL_RANK", raising=False) |
| 41 | 41 | assert detect_rank() == 0 |
| 42 | 42 | |
| 43 | - def test_rank_takes_precedence_over_local_rank( | |
| 44 | - self, monkeypatch: pytest.MonkeyPatch | |
| 45 | - ) -> None: | |
| 43 | + def test_rank_takes_precedence_over_local_rank(self, monkeypatch: pytest.MonkeyPatch) -> None: | |
| 46 | 44 | monkeypatch.setenv("RANK", "3") |
| 47 | 45 | monkeypatch.setenv("LOCAL_RANK", "1") |
| 48 | 46 | assert detect_rank() == 3 |
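The precedence these tests encode — `RANK` beats `LOCAL_RANK`, and an unset environment means rank 0 — fits in a few lines. A sketch consistent with the visible assertions; handling of non-numeric values is not covered by these tests and is left out:

```python
import os


def detect_rank() -> int:
    # torchrun-style launchers export RANK (global) and LOCAL_RANK (per-node);
    # the global rank wins, and a plain single-process run defaults to rank 0.
    for var in ("RANK", "LOCAL_RANK"):
        value = os.environ.get(var)
        if value is not None:
            return int(value)
    return 0
```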
tests/unit/train/multi_adapter/test_orchestrator.py modified @@ -92,18 +92,14 @@ def _single_adapter_parsed(dlm_id: str) -> ParsedDlm: | ||
| 92 | 92 | base_model="smollm2-135m", |
| 93 | 93 | training=TrainingConfig(seed=42), |
| 94 | 94 | ), |
| 95 | - sections=( | |
| 96 | - Section(type=SectionType.PROSE, content="Single-adapter prose."), | |
| 97 | - ), | |
| 95 | + sections=(Section(type=SectionType.PROSE, content="Single-adapter prose."),), | |
| 98 | 96 | ) |
| 99 | 97 | |
| 100 | 98 | |
| 101 | 99 | def _seed_store(tmp_path: Path, dlm_id: str) -> Any: |
| 102 | 100 | store = for_dlm(dlm_id, home=tmp_path) |
| 103 | 101 | store.ensure_layout() |
| 104 | - save_manifest( | |
| 105 | - store.manifest, Manifest(dlm_id=dlm_id, base_model="smollm2-135m") | |
| 106 | - ) | |
| 102 | + save_manifest(store.manifest, Manifest(dlm_id=dlm_id, base_model="smollm2-135m")) | |
| 107 | 103 | return store |
| 108 | 104 | |
| 109 | 105 | |
@@ -220,9 +216,7 @@ class TestMultiAdapterOrchestration: | ||
| 220 | 216 | # Flat field stays at 0 (untouched) for multi-adapter stores. |
| 221 | 217 | assert manifest.adapter_version == 0 |
| 222 | 218 | |
| 223 | - def test_training_run_summaries_carry_adapter_name( | |
| 224 | - self, tmp_path: Path | |
| 225 | - ) -> None: | |
| 219 | + def test_training_run_summaries_carry_adapter_name(self, tmp_path: Path) -> None: | |
| 226 | 220 | """Audit-07 M1: each TrainingRunSummary is tagged with the name.""" |
| 227 | 221 | dlm_id = "01HZ4X7TGZM3J1A2B3C4D5E6FB" |
| 228 | 222 | store = _seed_store(tmp_path, dlm_id) |
tests/unit/train/multi_adapter/test_router.py modified @@ -60,9 +60,7 @@ class TestProseFansOut: | ||
| 60 | 60 | prose_in_knowledge = [ |
| 61 | 61 | s for s in plan.by_adapter["knowledge"] if s.type is SectionType.PROSE |
| 62 | 62 | ] |
| 63 | - prose_in_tone = [ | |
| 64 | - s for s in plan.by_adapter["tone"] if s.type is SectionType.PROSE | |
| 65 | - ] | |
| 63 | + prose_in_tone = [s for s in plan.by_adapter["tone"] if s.type is SectionType.PROSE] | |
| 66 | 64 | assert len(prose_in_knowledge) == 1 |
| 67 | 65 | assert len(prose_in_tone) == 1 |
| 68 | 66 | |
@@ -80,38 +78,23 @@ class TestInstructionRouting: | ||
| 80 | 78 | parsed = parse_text(_doc(body, multi_adapter=True)) |
| 81 | 79 | plan = build_plan(parsed) |
| 82 | 80 | # First-declared is "knowledge". |
| 83 | - assert any( | |
| 84 | - s.type is SectionType.INSTRUCTION | |
| 85 | - for s in plan.by_adapter["knowledge"] | |
| 86 | - ) | |
| 87 | - assert not any( | |
| 88 | - s.type is SectionType.INSTRUCTION for s in plan.by_adapter["tone"] | |
| 89 | - ) | |
| 81 | + assert any(s.type is SectionType.INSTRUCTION for s in plan.by_adapter["knowledge"]) | |
| 82 | + assert not any(s.type is SectionType.INSTRUCTION for s in plan.by_adapter["tone"]) | |
| 90 | 83 | |
| 91 | 84 | def test_tagged_instruction_goes_to_named_adapter(self) -> None: |
| 92 | 85 | body = "::instruction#tone::\n### Q\nhi\n### A\nbye\n" |
| 93 | 86 | parsed = parse_text(_doc(body, multi_adapter=True)) |
| 94 | 87 | plan = build_plan(parsed) |
| 95 | - assert not any( | |
| 96 | - s.type is SectionType.INSTRUCTION | |
| 97 | - for s in plan.by_adapter["knowledge"] | |
| 98 | - ) | |
| 99 | - assert any( | |
| 100 | - s.type is SectionType.INSTRUCTION for s in plan.by_adapter["tone"] | |
| 101 | - ) | |
| 88 | + assert not any(s.type is SectionType.INSTRUCTION for s in plan.by_adapter["knowledge"]) | |
| 89 | + assert any(s.type is SectionType.INSTRUCTION for s in plan.by_adapter["tone"]) | |
| 102 | 90 | |
| 103 | 91 | |
| 104 | 92 | class TestPreferenceRouting: |
| 105 | 93 | def test_tagged_preference_goes_to_named_adapter(self) -> None: |
| 106 | - body = ( | |
| 107 | - "::preference#tone::\n" | |
| 108 | - "### Prompt\nq\n### Chosen\nc\n### Rejected\nr\n" | |
| 109 | - ) | |
| 94 | + body = "::preference#tone::\n### Prompt\nq\n### Chosen\nc\n### Rejected\nr\n" | |
| 110 | 95 | parsed = parse_text(_doc(body, multi_adapter=True)) |
| 111 | 96 | plan = build_plan(parsed) |
| 112 | - assert any( | |
| 113 | - s.type is SectionType.PREFERENCE for s in plan.by_adapter["tone"] | |
| 114 | - ) | |
| 97 | + assert any(s.type is SectionType.PREFERENCE for s in plan.by_adapter["tone"]) | |
| 115 | 98 | |
| 116 | 99 | |
| 117 | 100 | class TestUnknownAdapter: |
@@ -129,10 +112,7 @@ class TestUnknownAdapter:
| 129 | 112 | |
| 130 | 113 | class TestSingleAdapterDoc: |
| 131 | 114 | def test_single_adapter_doc_routes_all_to_default(self) -> None: |
| 132 | - body = ( | |
| 133 | - "# Prose\n\nShared.\n\n" | |
| 134 | - "::instruction::\n### Q\nh\n### A\nb\n" | |
| 135 | - ) | |
| 115 | + body = "# Prose\n\nShared.\n\n::instruction::\n### Q\nh\n### A\nb\n" | |
| 136 | 116 | parsed = parse_text(_doc(body, multi_adapter=False)) |
| 137 | 117 | plan = build_plan(parsed) |
| 138 | 118 | assert set(plan.by_adapter) == {"default"} |
@@ -149,10 +129,7 @@ class TestSingleAdapterDoc:
| 149 | 129 | |
| 150 | 130 | class TestSectionsForShortcut: |
| 151 | 131 | def test_returns_same_as_plan_entry(self) -> None: |
| 152 | - body = ( | |
| 153 | - "shared prose\n\n" | |
| 154 | - "::instruction#tone::\n### Q\nh\n### A\nb\n" | |
| 155 | - ) | |
| 132 | + body = "shared prose\n\n::instruction#tone::\n### Q\nh\n### A\nb\n" | |
| 156 | 133 | parsed = parse_text(_doc(body, multi_adapter=True)) |
| 157 | 134 | plan = build_plan(parsed) |
| 158 | 135 | assert sections_for(parsed, "tone") == plan.by_adapter["tone"] |
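The router tests above pin down three dispatch rules: prose sections fan out to every declared adapter, an untagged `::instruction::` block goes to the first-declared adapter, and a `#name`-tagged section goes only to the named adapter. A minimal, self-contained sketch of that dispatch logic — not the project's `build_plan` implementation; the `Section` type and field names here are illustrative only — could look like:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Section:
    type: str                    # "prose", "instruction", "preference", ...
    content: str
    adapter: str | None = None   # parsed from "::instruction#tone::"; None when untagged


def route(sections: list[Section], adapters: list[str]) -> dict[str, list[Section]]:
    """Toy router mirroring the rules the tests above assert."""
    plan: dict[str, list[Section]] = {name: [] for name in adapters}
    first = adapters[0]  # first-declared adapter receives untagged non-prose sections
    for s in sections:
        if s.type == "prose":
            for name in adapters:        # prose fans out to every adapter
                plan[name].append(s)
        elif s.adapter is not None:
            plan[s.adapter].append(s)    # tagged sections go only to the named adapter
        else:
            plan[first].append(s)        # untagged sections default to first-declared
    return plan


plan = route(
    [Section("prose", "Shared."), Section("instruction", "Q/A", adapter="tone")],
    adapters=["knowledge", "tone"],
)
assert [s.type for s in plan["tone"]] == ["prose", "instruction"]
assert [s.type for s in plan["knowledge"]] == ["prose"]
```

The sketch skips the unknown-adapter error path exercised by `TestUnknownAdapter`; it only illustrates the happy-path routing.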
tests/unit/train/preference/test_determinism_plumbing.py (modified) @@ -90,9 +90,7 @@ def _seed_store(tmp_path: Path, dlm_id: str) -> Any:
| 90 | 90 | |
| 91 | 91 | |
| 92 | 92 | class TestDpoSeedsRngBeforeTraining: |
| 93 | - def test_explicit_seed_flows_through_to_seed_everything( | |
| 94 | - self, tmp_path: Path | |
| 95 | - ) -> None: | |
| 93 | + def test_explicit_seed_flows_through_to_seed_everything(self, tmp_path: Path) -> None: | |
| 96 | 94 | from dlm.train.preference.dpo_phase import run |
| 97 | 95 | |
| 98 | 96 | store = _seed_store(tmp_path, "01KDPOSEED" + "0" * 16) |
@@ -100,9 +98,7 @@ class TestDpoSeedsRngBeforeTraining:
| 100 | 98 | |
| 101 | 99 | with patch( |
| 102 | 100 | "dlm.train.preference.dpo_phase.seed_everything", |
| 103 | - wraps=__import__( | |
| 104 | - "dlm.train.determinism", fromlist=["seed_everything"] | |
| 105 | - ).seed_everything, | |
| 101 | + wraps=__import__("dlm.train.determinism", fromlist=["seed_everything"]).seed_everything, | |
| 106 | 102 | ) as spy: |
| 107 | 103 | run( |
| 108 | 104 | store, |
@@ -123,9 +119,7 @@ class TestDpoSeedsRngBeforeTraining:
| 123 | 119 | |
| 124 | 120 | with patch( |
| 125 | 121 | "dlm.train.preference.dpo_phase.seed_everything", |
| 126 | - wraps=__import__( | |
| 127 | - "dlm.train.determinism", fromlist=["seed_everything"] | |
| 128 | - ).seed_everything, | |
| 122 | + wraps=__import__("dlm.train.determinism", fromlist=["seed_everything"]).seed_everything, | |
| 129 | 123 | ) as spy: |
| 130 | 124 | run( |
| 131 | 125 | store, |
@@ -148,9 +142,7 @@ class TestOrpoSeedsRngBeforeTraining:
| 148 | 142 | |
| 149 | 143 | with patch( |
| 150 | 144 | "dlm.train.preference.orpo_phase.seed_everything", |
| 151 | - wraps=__import__( | |
| 152 | - "dlm.train.determinism", fromlist=["seed_everything"] | |
| 153 | - ).seed_everything, | |
| 145 | + wraps=__import__("dlm.train.determinism", fromlist=["seed_everything"]).seed_everything, | |
| 154 | 146 | ) as spy: |
| 155 | 147 | run( |
| 156 | 148 | store, |
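These determinism tests use `unittest.mock.patch(..., wraps=...)` as a behavior-preserving spy: the name is replaced inside the phase module, the resulting mock records every call, and `wraps=` makes it delegate to the real `seed_everything`, so the RNG is genuinely seeded while the test asserts on the arguments. A standalone illustration of that pattern — the seeding helper below is a stand-in, not the project's function:

```python
import random
from unittest.mock import patch


def seed_everything(seed: int) -> None:
    """Stand-in for a real seeding helper."""
    random.seed(seed)


# The real tests patch the name as seen by the phase module
# (e.g. "dlm.train.preference.dpo_phase.seed_everything"); here we patch it
# in this module's own namespace purely to show the wraps= spy behaviour.
with patch(f"{__name__}.seed_everything", wraps=seed_everything) as spy:
    seed_everything(42)               # the real function still runs (RNG is seeded)...
    spy.assert_called_once_with(42)   # ...and the spy recorded the call
```

Extra keyword arguments to `patch()` are forwarded to the `MagicMock` it constructs, which is why `wraps=` works here without building the mock by hand.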
tests/unit/train/preference/test_dpo_dataset.py (modified) @@ -12,9 +12,7 @@ from dlm.train.preference.dpo_dataset import (
| 12 | 12 | ) |
| 13 | 13 | |
| 14 | 14 | _PREF_BODY_ONE = ( |
| 15 | - "### Prompt\nWhat time is it?\n" | |
| 16 | - "### Chosen\nIt is 3 PM.\n" | |
| 17 | - "### Rejected\nTime is an illusion.\n" | |
| 15 | + "### Prompt\nWhat time is it?\n### Chosen\nIt is 3 PM.\n### Rejected\nTime is an illusion.\n" | |
| 18 | 16 | ) |
| 19 | 17 | |
| 20 | 18 | _PREF_BODY_TWO = ( |
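The `### Prompt` / `### Chosen` / `### Rejected` bodies above are the on-disk shape of one preference pair. A minimal, self-contained parser for that shape — illustrative only; the project's actual parsing lives in `dlm.train.preference.dpo_dataset` and may differ — might look like:

```python
import re


def parse_preference_block(body: str) -> tuple[str, str, str]:
    """Split one preference section into (prompt, chosen, rejected).

    Assumes each of the three '### ' headings appears exactly once.
    """
    parts = re.split(r"^### (Prompt|Chosen|Rejected)\s*$", body, flags=re.MULTILINE)
    fields = {parts[i]: parts[i + 1].strip() for i in range(1, len(parts) - 1, 2)}
    return fields["Prompt"], fields["Chosen"], fields["Rejected"]


prompt, chosen, rejected = parse_preference_block(
    "### Prompt\nWhat time is it?\n### Chosen\nIt is 3 PM.\n### Rejected\nTime is an illusion.\n"
)
assert prompt == "What time is it?"
assert rejected == "Time is an illusion."
```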
tests/unit/train/preference/test_dpo_phase.py (modified) @@ -23,20 +23,14 @@ from dlm.train.state_sidecar import STATE_FILENAME, STATE_SHA_FILENAME
| 23 | 23 | |
| 24 | 24 | |
| 25 | 25 | def _parsed_with_preferences() -> ParsedDlm: |
| 26 | - pref_body = ( | |
| 27 | - "### Prompt\nq?\n### Chosen\nc.\n### Rejected\nr.\n" | |
| 28 | - ) | |
| 26 | + pref_body = "### Prompt\nq?\n### Chosen\nc.\n### Rejected\nr.\n" | |
| 29 | 27 | return ParsedDlm( |
| 30 | 28 | frontmatter=DlmFrontmatter( |
| 31 | 29 | dlm_id="01KABCD" + "0" * 19, |
| 32 | 30 | base_model="smollm2-135m", |
| 33 | - training=TrainingConfig( | |
| 34 | - seed=42, preference=PreferenceConfig(enabled=True) | |
| 35 | - ), | |
| 36 | - ), | |
| 37 | - sections=( | |
| 38 | - Section(type=SectionType.PREFERENCE, content=pref_body), | |
| 31 | + training=TrainingConfig(seed=42, preference=PreferenceConfig(enabled=True)), | |
| 39 | 32 | ), |
| 33 | + sections=(Section(type=SectionType.PREFERENCE, content=pref_body),), | |
| 40 | 34 | ) |
| 41 | 35 | |
| 42 | 36 | |
tests/unit/train/preference/test_dpo_trainer.py (modified) @@ -42,13 +42,9 @@ class TestCoreFields:
| 42 | 42 | cfg = PreferenceConfig( |
| 43 | 43 | enabled=True, |
| 44 | 44 | loss_type="ipo", |
| 45 | - hyperparams=PreferenceHyperparams( | |
| 46 | - beta=0.2, learning_rate=3e-6, num_epochs=2 | |
| 47 | - ), | |
| 48 | - ) | |
| 49 | - kwargs = build_dpo_config_kwargs( | |
| 50 | - cfg, _plan(), output_dir=tmp_path, max_length=1024, seed=7 | |
| 45 | + hyperparams=PreferenceHyperparams(beta=0.2, learning_rate=3e-6, num_epochs=2), | |
| 51 | 46 | ) |
| 47 | + kwargs = build_dpo_config_kwargs(cfg, _plan(), output_dir=tmp_path, max_length=1024, seed=7) | |
| 52 | 48 | assert kwargs["output_dir"] == str(tmp_path) |
| 53 | 49 | assert kwargs["learning_rate"] == 3e-6 |
| 54 | 50 | assert kwargs["num_train_epochs"] == 2 |
tests/unit/train/preference/test_method_registry.py (modified) @@ -46,6 +46,7 @@ class TestRegisterCanReplace:
| 46 | 46 | def test_register_overrides_existing(self) -> None: |
| 47 | 47 | saved = resolve("dpo") |
| 48 | 48 | try: |
| 49 | + | |
| 49 | 50 | def _stub(*args: object, **kwargs: object) -> str: # type: ignore[return-value] |
| 50 | 51 | return "stub" |
| 51 | 52 | |
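The only change in this file is whitespace: with the formatter settings used here, `ruff format` separates a nested `def` from the compound-statement header above it with one blank line, which is why the hunk adds an empty line between `try:` and `_stub`. A toy before/after, assuming the same formatter behaviour as shown in the hunk:

```python
# Before formatting: the stub definition sits directly under `try:`.
try:
    def _stub() -> str:
        return "stub"
finally:
    pass

# After formatting (mirroring the hunk above): one blank line precedes
# the nested definition.
try:

    def _stub() -> str:
        return "stub"
finally:
    pass
```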
tests/unit/train/preference/test_orpo_phase.py (modified) @@ -34,9 +34,7 @@ def _parsed_with_preferences() -> ParsedDlm:
| 34 | 34 | preference=PreferenceConfig(enabled=True, method="orpo"), |
| 35 | 35 | ), |
| 36 | 36 | ), |
| 37 | - sections=( | |
| 38 | - Section(type=SectionType.PREFERENCE, content=pref_body), | |
| 39 | - ), | |
| 37 | + sections=(Section(type=SectionType.PREFERENCE, content=pref_body),), | |
| 40 | 38 | ) |
| 41 | 39 | |
| 42 | 40 | |
@@ -85,9 +83,7 @@ def _seed_prior_sft(store, dlm_id: str = "01ORPOTEST") -> None: # type: ignore[
| 85 | 83 | store.ensure_layout() |
| 86 | 84 | save_manifest( |
| 87 | 85 | store.manifest, |
| 88 | - Manifest( | |
| 89 | - dlm_id=dlm_id, base_model="smollm2-135m", adapter_version=1 | |
| 90 | - ), | |
| 86 | + Manifest(dlm_id=dlm_id, base_model="smollm2-135m", adapter_version=1), | |
| 91 | 87 | ) |
| 92 | 88 | v0001 = store.adapter_version(1) |
| 93 | 89 | v0001.mkdir(parents=True, exist_ok=True) |
@@ -172,9 +168,7 @@ class TestRunHappyPath:
| 172 | 168 | |
| 173 | 169 | |
| 174 | 170 | class TestRunSteps: |
| 175 | - def test_factory_receives_reference_adapter_version( | |
| 176 | - self, tmp_path: Path | |
| 177 | - ) -> None: | |
| 171 | + def test_factory_receives_reference_adapter_version(self, tmp_path: Path) -> None: | |
| 178 | 172 | captured: dict[str, Any] = {} |
| 179 | 173 | |
| 180 | 174 | def _capturing_factory(**kwargs: Any) -> MagicMock: |
tests/unit/train/preference/test_orpo_trainer.py (modified) @@ -36,9 +36,7 @@ class TestCoreMapping:
| 36 | 36 | cfg = PreferenceConfig( |
| 37 | 37 | enabled=True, |
| 38 | 38 | method="orpo", |
| 39 | - hyperparams=PreferenceHyperparams( | |
| 40 | - alpha=0.15, learning_rate=3e-6, num_epochs=2 | |
| 41 | - ), | |
| 39 | + hyperparams=PreferenceHyperparams(alpha=0.15, learning_rate=3e-6, num_epochs=2), | |
| 42 | 40 | ) |
| 43 | 41 | kwargs = build_orpo_config_kwargs( |
| 44 | 42 | cfg, _plan(), output_dir=tmp_path, max_length=1024, seed=7 |
tests/unit/train/preference/test_phase_orchestrator.py (modified) @@ -45,9 +45,7 @@ def _instruction() -> Section:
| 45 | 45 | def _pref() -> Section: |
| 46 | 46 | return Section( |
| 47 | 47 | type=SectionType.PREFERENCE, |
| 48 | - content=( | |
| 49 | - "### Prompt\nq\n### Chosen\nc\n### Rejected\nr\n" | |
| 50 | - ), | |
| 48 | + content=("### Prompt\nq\n### Chosen\nc\n### Rejected\nr\n"), | |
| 51 | 49 | start_line=1, |
| 52 | 50 | ) |
| 53 | 51 | |
@@ -88,10 +86,7 @@ def _parsed(
| 88 | 86 | who wrote `training.preference.enabled: true/false` in their |
| 89 | 87 | frontmatter. |
| 90 | 88 | """ |
| 91 | - pref = ( | |
| 92 | - PreferenceConfig() if dpo_enabled is None | |
| 93 | - else PreferenceConfig(enabled=dpo_enabled) | |
| 94 | - ) | |
| 89 | + pref = PreferenceConfig() if dpo_enabled is None else PreferenceConfig(enabled=dpo_enabled) | |
| 95 | 90 | return _FakeParsed( |
| 96 | 91 | sections=tuple(sections), |
| 97 | 92 | frontmatter=_FakeFrontmatter(training=_FakeTraining(preference=pref)), |
tests/unit/train/test_resolve_adapter_hparams.py (modified) @@ -23,9 +23,7 @@ def _parsed(training: TrainingConfig) -> ParsedDlm:
| 23 | 23 | |
| 24 | 24 | class TestFlatConfig: |
| 25 | 25 | def test_returns_flat_fields_when_adapter_name_is_none(self) -> None: |
| 26 | - training = TrainingConfig( | |
| 27 | - lora_r=16, lora_alpha=32, lora_dropout=0.1, learning_rate=1e-3 | |
| 28 | - ) | |
| 26 | + training = TrainingConfig(lora_r=16, lora_alpha=32, lora_dropout=0.1, learning_rate=1e-3) | |
| 29 | 27 | r, alpha, dropout, lr = _resolve_adapter_hparams(_parsed(training), None) |
| 30 | 28 | assert (r, alpha, dropout) == (16, 32, pytest.approx(0.1)) |
| 31 | 29 | assert lr == pytest.approx(1e-3) |
@@ -52,32 +50,22 @@ class TestMultiAdapterConfig:
| 52 | 50 | } |
| 53 | 51 | } |
| 54 | 52 | ) |
| 55 | - k_r, k_alpha, _k_drop, _k_lr = _resolve_adapter_hparams( | |
| 56 | - _parsed(training), "knowledge" | |
| 57 | - ) | |
| 58 | - t_r, t_alpha, t_drop, t_lr = _resolve_adapter_hparams( | |
| 59 | - _parsed(training), "tone" | |
| 60 | - ) | |
| 53 | + k_r, k_alpha, _k_drop, _k_lr = _resolve_adapter_hparams(_parsed(training), "knowledge") | |
| 54 | + t_r, t_alpha, t_drop, t_lr = _resolve_adapter_hparams(_parsed(training), "tone") | |
| 61 | 55 | assert (k_r, k_alpha) == (8, 16) |
| 62 | 56 | assert (t_r, t_alpha) == (4, 8) |
| 63 | 57 | assert t_drop == pytest.approx(0.2) |
| 64 | 58 | assert t_lr == pytest.approx(1e-4) |
| 65 | 59 | |
| 66 | 60 | def test_unknown_adapter_name_falls_back_to_flat(self) -> None: |
| 67 | - training = TrainingConfig.model_validate( | |
| 68 | - {"adapters": {"knowledge": {"lora_r": 8}}} | |
| 69 | - ) | |
| 61 | + training = TrainingConfig.model_validate({"adapters": {"knowledge": {"lora_r": 8}}}) | |
| 70 | 62 | # ghost isn't declared; we fall back to defaults rather than crash. |
| 71 | 63 | r, _, _, _ = _resolve_adapter_hparams(_parsed(training), "ghost") |
| 72 | 64 | assert r == 8 # flat default |
| 73 | 65 | |
| 74 | 66 | def test_per_adapter_defaults_when_not_overridden(self) -> None: |
| 75 | - training = TrainingConfig.model_validate( | |
| 76 | - {"adapters": {"default_one": {}}} | |
| 77 | - ) | |
| 78 | - r, alpha, dropout, lr = _resolve_adapter_hparams( | |
| 79 | - _parsed(training), "default_one" | |
| 80 | - ) | |
| 67 | + training = TrainingConfig.model_validate({"adapters": {"default_one": {}}}) | |
| 68 | + r, alpha, dropout, lr = _resolve_adapter_hparams(_parsed(training), "default_one") | |
| 81 | 69 | # AdapterConfig() defaults: r=8, alpha=16, dropout=0.05, lr=2e-4 |
| 82 | 70 | assert (r, alpha) == (8, 16) |
| 83 | 71 | assert dropout == pytest.approx(0.05) |
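The resolution order these tests pin down: a per-adapter entry wins when the name is declared, otherwise the flat defaults apply, and an unknown name falls back rather than raising. A self-contained sketch of that lookup — field names and defaults mirror the comments in the hunk above; the real `_resolve_adapter_hparams` operates on the project's config models, not plain dicts:

```python
from typing import NamedTuple


class Hparams(NamedTuple):
    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    learning_rate: float = 2e-4


def resolve_hparams(adapters: dict[str, dict[str, float]] | None, name: str | None) -> Hparams:
    """Per-adapter overrides win; unknown or missing names fall back to defaults."""
    if adapters and name in adapters:
        return Hparams(**adapters[name])  # keys not overridden keep the NamedTuple defaults
    return Hparams()


assert resolve_hparams({"tone": {"lora_r": 4, "lora_alpha": 8}}, "tone").lora_r == 4
assert resolve_hparams({"knowledge": {"lora_r": 8}}, "ghost") == Hparams()  # unknown name
```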
tests/unit/watch/test_status.py (modified) @@ -16,23 +16,17 @@ class TestRenderStatus:
| 16 | 16 | |
| 17 | 17 | def test_after_cycle_shows_loss_and_steps(self) -> None: |
| 18 | 18 | status = WatchStatus(doc_path="mydoc.dlm", sections=12) |
| 19 | - status.mark_cycle_done( | |
| 20 | - train_loss=1.2, val_loss=1.35, steps=50, coalesced=1 | |
| 21 | - ) | |
| 19 | + status.mark_cycle_done(train_loss=1.2, val_loss=1.35, steps=50, coalesced=1) | |
| 22 | 20 | line = render_status(status) |
| 23 | 21 | assert "val loss: 1.35" in line |
| 24 | 22 | assert "steps: 50" in line |
| 25 | 23 | |
| 26 | 24 | def test_coalesced_only_shown_when_plural(self) -> None: |
| 27 | 25 | status = WatchStatus(doc_path="d") |
| 28 | - status.mark_cycle_done( | |
| 29 | - train_loss=None, val_loss=None, steps=10, coalesced=1 | |
| 30 | - ) | |
| 26 | + status.mark_cycle_done(train_loss=None, val_loss=None, steps=10, coalesced=1) | |
| 31 | 27 | assert "coalesced" not in render_status(status) |
| 32 | 28 | |
| 33 | - status.mark_cycle_done( | |
| 34 | - train_loss=None, val_loss=None, steps=10, coalesced=5 | |
| 35 | - ) | |
| 29 | + status.mark_cycle_done(train_loss=None, val_loss=None, steps=10, coalesced=5) | |
| 36 | 30 | assert "coalesced: 5" in render_status(status) |
| 37 | 31 | |
| 38 | 32 | def test_age_buckets(self) -> None: |