tenseleyflow/sway / 198cd55

Browse files

ruff auto-format after S23 edits (line-length-driven, no logic)

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA
198cd55595f118c6fe6ad03dece560aa8c649e6b
Parents
9f82456
Tree
423ff69

4 changed files

Status  File  Additions  Deletions
M src/dlm_sway/backends/hf.py 4 4
M src/dlm_sway/core/scoring.py 1 3
M src/dlm_sway/probes/cluster_kl.py 2 1
M src/dlm_sway/probes/delta_kl.py 2 1
src/dlm_sway/backends/hf.py (modified)
@@ -254,9 +254,7 @@ class _HFView:
254254
         log_probs = F.log_softmax(logits.float(), dim=-1).squeeze(0)
255255
         return _topk_to_token_dist(log_probs, top_k=top_k)
256256
 
257
-    def next_token_dist_batch(
258
-        self, prompts: Sequence[str], *, top_k: int = 256
259
-    ) -> list[TokenDist]:
257
+    def next_token_dist_batch(self, prompts: Sequence[str], *, top_k: int = 256) -> list[TokenDist]:
260258
         """Batched forward via tokenizer left-padding.
261259
 
262260
         Decoder-only LMs need left-padding because the last-token
@@ -294,7 +292,9 @@ class _HFView:
294292
                     attention_mask=tokens.get("attention_mask"),
295293
                 ).logits[:, -1, :]  # (B, V) — left-pad makes "last" always the real last token
296294
             log_probs = F.log_softmax(logits.float(), dim=-1)  # (B, V)
297
-            return [_topk_to_token_dist(log_probs[row], top_k=top_k) for row in range(len(miss_prompts))]
295
+            return [
296
+                _topk_to_token_dist(log_probs[row], top_k=top_k) for row in range(len(miss_prompts))
297
+            ]
298298
 
299299
         return self._inst.cached_batch(
300300
             "next_token_dist", self.id, list(prompts), top_k, compute_misses
src/dlm_sway/core/scoring.py (modified)
@@ -129,9 +129,7 @@ class ScoringBackend(Protocol):
129129
         """
130130
         ...
131131
 
132
-    def next_token_dist_batch(
133
-        self, prompts: Sequence[str], *, top_k: int = 256
134
-    ) -> list[TokenDist]:
132
+    def next_token_dist_batch(self, prompts: Sequence[str], *, top_k: int = 256) -> list[TokenDist]:
135133
         """Batched variant of :meth:`next_token_dist`.
136134
 
137135
         Returns one :class:`TokenDist` per entry in ``prompts``, in the
src/dlm_sway/probes/cluster_kl.py (modified)
@@ -197,7 +197,8 @@ class ClusterKLProbe(Probe):
197197
         with ctx.backend.as_finetuned() as ft_view:
198198
             ft_dists = ft_view.next_token_dist_batch(list(spec.prompts), top_k=top_k)
199199
         divergences: list[float] = [
200
-            divergence(b, f, kind=spec.divergence) for b, f in zip(base_dists, ft_dists, strict=True)
200
+            divergence(b, f, kind=spec.divergence)
201
+            for b, f in zip(base_dists, ft_dists, strict=True)
201202
         ]
202203
 
203204
         # Aggregate per-cluster means + variances. A cluster that
src/dlm_sway/probes/delta_kl.py (modified)
@@ -98,7 +98,8 @@ class DeltaKLProbe(Probe):
9898
         with ctx.backend.as_finetuned() as ft_view:
9999
             ft_dists = ft_view.next_token_dist_batch(list(spec.prompts), top_k=top_k)
100100
         divergences: list[float] = [
101
-            divergence(b, f, kind=spec.divergence) for b, f in zip(base_dists, ft_dists, strict=True)
101
+            divergence(b, f, kind=spec.divergence)
102
+            for b, f in zip(base_dists, ft_dists, strict=True)
102103
         ]
103104
 
104105
         raw_mean = statistics.fmean(divergences)