"""Unit tests for embed_warmup.

Uses fake model objects — we're testing requires_grad flipping and
the modules_to_save extension logic, not PyTorch.
"""

from __future__ import annotations

from types import SimpleNamespace
from typing import Any

import pytest

from dlm.train.cpt.embed_warmup import (
    EmbedWarmupCallback,
    extend_modules_to_save_for_embed_warmup,
    unfreeze_embeddings_for,
)


class _FakeParam:
    def __init__(self, requires_grad: bool = False) -> None:
        self.requires_grad = requires_grad


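# Minimal fake mirroring the `get_input_embeddings()` / `get_output_embeddings()`
# accessors of a Hugging Face `PreTrainedModel` (an assumption about the contract
# `unfreeze_embeddings_for` relies on). `tied=True` simulates weight tying, i.e.
# the input embedding and LM head sharing one parameter object.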
def _model(*, embed_frozen: bool = True, head_frozen: bool = True, tied: bool = False) -> Any:
    embed_param = _FakeParam(requires_grad=not embed_frozen)
    head_param = embed_param if tied else _FakeParam(requires_grad=not head_frozen)
    embed_module = SimpleNamespace(weight=embed_param)
    head_module = SimpleNamespace(weight=head_param)
    return SimpleNamespace(
        get_input_embeddings=lambda: embed_module,
        get_output_embeddings=lambda: head_module,
    )


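# Contract under test: `unfreeze_embeddings_for` is a context manager that yields
# the (deduplicated) embedding weights it unfroze and restores each weight's
# original requires_grad flag on exit, including when the body raises.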
class TestUnfreezeContextManager:
    def test_missing_embedding_modules_yield_empty_list(self) -> None:
        model = SimpleNamespace(
            get_input_embeddings=lambda: None,
            get_output_embeddings=lambda: None,
        )
        with unfreeze_embeddings_for(model) as weights:
            assert weights == []

    def test_modules_without_weight_are_skipped(self) -> None:
        model = SimpleNamespace(
            get_input_embeddings=lambda: SimpleNamespace(weight=None),
            get_output_embeddings=lambda: SimpleNamespace(weight=None),
        )
        with unfreeze_embeddings_for(model) as weights:
            assert weights == []

    def test_unfreezes_both_embeddings(self) -> None:
        model = _model(embed_frozen=True, head_frozen=True)
        with unfreeze_embeddings_for(model) as weights:
            assert all(w.requires_grad for w in weights)
            assert len(weights) == 2

    def test_restores_originals_on_exit(self) -> None:
        model = _model(embed_frozen=True, head_frozen=True)
        with unfreeze_embeddings_for(model):
            pass
        assert model.get_input_embeddings().weight.requires_grad is False
        assert model.get_output_embeddings().weight.requires_grad is False

    def test_preserves_per_weight_original_state(self) -> None:
        model = _model(embed_frozen=True, head_frozen=False)
        with unfreeze_embeddings_for(model):
            pass
        assert model.get_input_embeddings().weight.requires_grad is False
        assert model.get_output_embeddings().weight.requires_grad is True

    def test_restores_on_exception(self) -> None:
        model = _model(embed_frozen=True, head_frozen=True)
        with pytest.raises(RuntimeError, match="boom"):
            with unfreeze_embeddings_for(model):
                raise RuntimeError("boom")
        assert model.get_input_embeddings().weight.requires_grad is False
        assert model.get_output_embeddings().weight.requires_grad is False

    def test_tied_weights_deduplicated(self) -> None:
        model = _model(tied=True)
        with unfreeze_embeddings_for(model) as weights:
            assert len(weights) == 1  # same object once
            assert weights[0].requires_grad is True


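# `modules_to_save` here presumably feeds PEFT's `LoraConfig.modules_to_save`,
# which trains the named modules fully and saves them alongside the adapter;
# "embed_tokens" and "lm_head" are the usual Llama-style module names.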
class TestExtendModulesToSave:
    def test_zero_warmup_passes_through_none(self) -> None:
        assert extend_modules_to_save_for_embed_warmup(None, embed_warmup_steps=0) is None

    def test_zero_warmup_passes_through_list(self) -> None:
        assert extend_modules_to_save_for_embed_warmup(["embed_tokens"], embed_warmup_steps=0) == [
            "embed_tokens"
        ]

    def test_warmup_on_with_no_existing(self) -> None:
        out = extend_modules_to_save_for_embed_warmup(None, embed_warmup_steps=50)
        assert out == ["embed_tokens", "lm_head"]

    def test_warmup_on_with_existing_tokenizer_grew(self) -> None:
        # tokenizer_grew already added both — result is stable (no duplicates).
        out = extend_modules_to_save_for_embed_warmup(
            ["embed_tokens", "lm_head"], embed_warmup_steps=50
        )
        assert out == ["embed_tokens", "lm_head"]

    def test_warmup_on_preserves_order_of_existing(self) -> None:
        out = extend_modules_to_save_for_embed_warmup(
            ["my_module", "embed_tokens"], embed_warmup_steps=1
        )
        assert out == ["my_module", "embed_tokens", "lm_head"]


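# The hook names (`on_train_begin`, `on_step_end`, `on_train_end`) and
# `state.global_step` match the `transformers.TrainerCallback` interface, which
# the callback presumably implements; the tests drive the hooks directly with
# None placeholders instead of a real Trainer.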
class TestEmbedWarmupCallback:
    def test_rejects_negative_n_steps(self) -> None:
        with pytest.raises(ValueError, match="n_steps must be non-negative"):
            EmbedWarmupCallback(_model(), n_steps=-1)

    def test_on_train_begin_unfreezes_when_n_positive(self) -> None:
        model = _model(embed_frozen=True, head_frozen=True)
        cb = EmbedWarmupCallback(model, n_steps=10)
        cb.on_train_begin(args=None, state=None, control=None)
        assert model.get_input_embeddings().weight.requires_grad is True
        assert model.get_output_embeddings().weight.requires_grad is True

    def test_on_train_begin_is_noop_when_n_zero(self) -> None:
        model = _model(embed_frozen=True, head_frozen=True)
        cb = EmbedWarmupCallback(model, n_steps=0)
        cb.on_train_begin(args=None, state=None, control=None)
        assert model.get_input_embeddings().weight.requires_grad is False

    def test_step_end_at_budget_restores(self) -> None:
        model = _model(embed_frozen=True, head_frozen=True)
        cb = EmbedWarmupCallback(model, n_steps=5)
        cb.on_train_begin(args=None, state=None, control=None)
        # Steps before budget: still unfrozen.
        cb.on_step_end(args=None, state=SimpleNamespace(global_step=3), control=None)
        assert model.get_input_embeddings().weight.requires_grad is True
        # Step == budget: refreeze.
        cb.on_step_end(args=None, state=SimpleNamespace(global_step=5), control=None)
        assert model.get_input_embeddings().weight.requires_grad is False

    def test_train_end_restores_if_still_active(self) -> None:
        model = _model(embed_frozen=True, head_frozen=True)
        cb = EmbedWarmupCallback(model, n_steps=1000)  # will never fire
        cb.on_train_begin(args=None, state=None, control=None)
        cb.on_train_end(args=None, state=None, control=None)
        assert model.get_input_embeddings().weight.requires_grad is False

    def test_restore_is_idempotent(self) -> None:
        model = _model(embed_frozen=True, head_frozen=True)
        cb = EmbedWarmupCallback(model, n_steps=5)
        cb.on_train_begin(args=None, state=None, control=None)
        cb.on_step_end(args=None, state=SimpleNamespace(global_step=10), control=None)
        # Second restore via on_train_end: does not double-flip.
        cb.on_train_end(args=None, state=None, control=None)
        assert model.get_input_embeddings().weight.requires_grad is False