1 """Unit tests for embed_warmup.
2
3 Uses fake model objects — we're testing requires_grad flipping and
4 the modules_to_save extension logic, not PyTorch.
5 """
6
7 from __future__ import annotations
8
9 from types import SimpleNamespace
10 from typing import Any
11
12 import pytest
13
14 from dlm.train.cpt.embed_warmup import (
15 EmbedWarmupCallback,
16 extend_modules_to_save_for_embed_warmup,
17 unfreeze_embeddings_for,
18 )
19
20
21 class _FakeParam:
22 def __init__(self, requires_grad: bool = False) -> None:
23 self.requires_grad = requires_grad
24
25
def _model(*, embed_frozen: bool = True, head_frozen: bool = True, tied: bool = False) -> Any:
    """Build a fake model exposing get_input_embeddings / get_output_embeddings.

    When ``tied`` is True the output head shares the input embedding's
    parameter object, mimicking weight tying.
    """
    input_weight = _FakeParam(requires_grad=not embed_frozen)
    if tied:
        output_weight = input_weight
    else:
        output_weight = _FakeParam(requires_grad=not head_frozen)
    input_module = SimpleNamespace(weight=input_weight)
    output_module = SimpleNamespace(weight=output_weight)
    return SimpleNamespace(
        get_input_embeddings=lambda: input_module,
        get_output_embeddings=lambda: output_module,
    )
35
36
class TestUnfreezeContextManager:
    """Behaviour of the `unfreeze_embeddings_for` context manager."""

    def test_missing_embedding_modules_yield_empty_list(self) -> None:
        # A model with neither embedding module yields no weights to flip.
        fake = SimpleNamespace(
            get_input_embeddings=lambda: None,
            get_output_embeddings=lambda: None,
        )
        with unfreeze_embeddings_for(fake) as touched:
            assert touched == []

    def test_modules_without_weight_are_skipped(self) -> None:
        # Modules whose .weight is None contribute nothing.
        fake = SimpleNamespace(
            get_input_embeddings=lambda: SimpleNamespace(weight=None),
            get_output_embeddings=lambda: SimpleNamespace(weight=None),
        )
        with unfreeze_embeddings_for(fake) as touched:
            assert touched == []

    def test_unfreezes_both_embeddings(self) -> None:
        fake = _model(embed_frozen=True, head_frozen=True)
        with unfreeze_embeddings_for(fake) as touched:
            assert len(touched) == 2
            assert all(w.requires_grad for w in touched)

    def test_restores_originals_on_exit(self) -> None:
        fake = _model(embed_frozen=True, head_frozen=True)
        with unfreeze_embeddings_for(fake):
            pass
        assert fake.get_input_embeddings().weight.requires_grad is False
        assert fake.get_output_embeddings().weight.requires_grad is False

    def test_preserves_per_weight_original_state(self) -> None:
        # Input was frozen, head was not; each returns to its own prior state.
        fake = _model(embed_frozen=True, head_frozen=False)
        with unfreeze_embeddings_for(fake):
            pass
        assert fake.get_input_embeddings().weight.requires_grad is False
        assert fake.get_output_embeddings().weight.requires_grad is True

    def test_restores_on_exception(self) -> None:
        fake = _model(embed_frozen=True, head_frozen=True)
        with pytest.raises(RuntimeError, match="boom"):
            with unfreeze_embeddings_for(fake):
                raise RuntimeError("boom")
        assert fake.get_input_embeddings().weight.requires_grad is False
        assert fake.get_output_embeddings().weight.requires_grad is False

    def test_tied_weights_deduplicated(self) -> None:
        # Tied input/output share one parameter; it must appear exactly once.
        fake = _model(tied=True)
        with unfreeze_embeddings_for(fake) as touched:
            assert len(touched) == 1
            assert touched[0].requires_grad is True
87
88
class TestExtendModulesToSave:
    """Behaviour of `extend_modules_to_save_for_embed_warmup`."""

    def test_zero_warmup_passes_through_none(self) -> None:
        result = extend_modules_to_save_for_embed_warmup(None, embed_warmup_steps=0)
        assert result is None

    def test_zero_warmup_passes_through_list(self) -> None:
        result = extend_modules_to_save_for_embed_warmup(
            ["embed_tokens"], embed_warmup_steps=0
        )
        assert result == ["embed_tokens"]

    def test_warmup_on_with_no_existing(self) -> None:
        result = extend_modules_to_save_for_embed_warmup(None, embed_warmup_steps=50)
        assert result == ["embed_tokens", "lm_head"]

    def test_warmup_on_with_existing_tokenizer_grew(self) -> None:
        # tokenizer_grew already added both — result is stable (no duplicates).
        result = extend_modules_to_save_for_embed_warmup(
            ["embed_tokens", "lm_head"], embed_warmup_steps=50
        )
        assert result == ["embed_tokens", "lm_head"]

    def test_warmup_on_preserves_order_of_existing(self) -> None:
        # Pre-existing entries keep their position; missing ones are appended.
        result = extend_modules_to_save_for_embed_warmup(
            ["my_module", "embed_tokens"], embed_warmup_steps=1
        )
        assert result == ["my_module", "embed_tokens", "lm_head"]
114
115
class TestEmbedWarmupCallback:
    """Lifecycle behaviour of `EmbedWarmupCallback` (train begin/step/end hooks)."""

    def test_rejects_negative_n_steps(self) -> None:
        with pytest.raises(ValueError, match="n_steps must be non-negative"):
            EmbedWarmupCallback(_model(), n_steps=-1)

    def test_on_train_begin_unfreezes_when_n_positive(self) -> None:
        fake = _model(embed_frozen=True, head_frozen=True)
        callback = EmbedWarmupCallback(fake, n_steps=10)
        callback.on_train_begin(args=None, state=None, control=None)
        assert fake.get_input_embeddings().weight.requires_grad is True
        assert fake.get_output_embeddings().weight.requires_grad is True

    def test_on_train_begin_is_noop_when_n_zero(self) -> None:
        fake = _model(embed_frozen=True, head_frozen=True)
        callback = EmbedWarmupCallback(fake, n_steps=0)
        callback.on_train_begin(args=None, state=None, control=None)
        assert fake.get_input_embeddings().weight.requires_grad is False

    def test_step_end_at_budget_restores(self) -> None:
        fake = _model(embed_frozen=True, head_frozen=True)
        callback = EmbedWarmupCallback(fake, n_steps=5)
        callback.on_train_begin(args=None, state=None, control=None)
        # Before the budget is reached the embeddings stay trainable.
        callback.on_step_end(args=None, state=SimpleNamespace(global_step=3), control=None)
        assert fake.get_input_embeddings().weight.requires_grad is True
        # Hitting the budget refreezes them.
        callback.on_step_end(args=None, state=SimpleNamespace(global_step=5), control=None)
        assert fake.get_input_embeddings().weight.requires_grad is False

    def test_train_end_restores_if_still_active(self) -> None:
        fake = _model(embed_frozen=True, head_frozen=True)
        callback = EmbedWarmupCallback(fake, n_steps=1000)  # budget never reached
        callback.on_train_begin(args=None, state=None, control=None)
        callback.on_train_end(args=None, state=None, control=None)
        assert fake.get_input_embeddings().weight.requires_grad is False

    def test_restore_is_idempotent(self) -> None:
        fake = _model(embed_frozen=True, head_frozen=True)
        callback = EmbedWarmupCallback(fake, n_steps=5)
        callback.on_train_begin(args=None, state=None, control=None)
        callback.on_step_end(args=None, state=SimpleNamespace(global_step=10), control=None)
        # A second restore via on_train_end must not double-flip the flags.
        callback.on_train_end(args=None, state=None, control=None)
        assert fake.get_input_embeddings().weight.requires_grad is False