tenseleyflow/sway / ef438ac

Browse files

tests/test_cluster_kl_prove_value: same sklearn-free kmeans stub

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA
ef438aca4712769de13b2126843dc1e8a6fabae0
Parents
5c9e091
Tree
064fa89

1 changed file

StatusFile+-
M tests/unit/test_cluster_kl_prove_value.py 11 0
tests/unit/test_cluster_kl_prove_value.pymodified
@@ -57,6 +57,13 @@ def _stub_embedder(text_to_vec: dict[str, np.ndarray]): # type: ignore[no-untyp
5757
     return _encode
5858
 
5959
 
60
def _argmax_kmeans(embeddings: np.ndarray, *, k: int, seed: int) -> np.ndarray:
    """Deterministic, sklearn-free replacement for a k-means call.

    Every row of ``embeddings`` is labelled with the position of its
    largest component (for one-hot test vectors, the hot position), and
    that position is then reduced modulo ``k``.

    Args:
        embeddings: 2-D array; the argmax is taken along axis 1.
        k: Number of clusters; labels are taken modulo this value.
        seed: Accepted but ignored — the labelling has no randomness.

    Returns:
        An ``int64`` array holding one label per row.
    """
    del seed  # unused: the result depends only on the input rows
    hot_index = np.argmax(embeddings, axis=1)
    return np.mod(hot_index.astype(np.int64), k)
65
+
66
+
6067
 @pytest.fixture
6168
 def monkeyed_embed(monkeypatch: pytest.MonkeyPatch) -> dict[str, np.ndarray]:
6269
     table: dict[str, np.ndarray] = {}
@@ -64,6 +71,10 @@ def monkeyed_embed(monkeypatch: pytest.MonkeyPatch) -> dict[str, np.ndarray]:
6471
         "dlm_sway.probes.cluster_kl._load_embedder",
6572
         lambda _model_id: _stub_embedder(table),  # type: ignore[arg-type]
6673
     )
74
+    monkeypatch.setattr(
75
+        "dlm_sway.probes.cluster_kl._kmeans_cluster",
76
+        _argmax_kmeans,
77
+    )
6778
     return table
6879
 
6980