tenseleyflow/documentlanguagemodel / 5ee947d

Browse files

Add token usage tracking to OpenAI and Anthropic teachers

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA
5ee947d5d21792013fc7e6843286b6dd344834a7
Parents
a016dea
Tree
2712546

3 changed files

StatusFile+-
M src/dlm/synth/__init__.py 2 0
M src/dlm/synth/teachers.py 59 0
M tests/unit/synth/test_teachers.py 61 0
src/dlm/synth/__init__.pymodified
@@ -59,6 +59,7 @@ from dlm.synth.teachers import (
5959
     SynthTeacher,
6060
     TeacherKind,
6161
     TeacherRef,
62
+    TeacherUsage,
6263
     VllmServerTeacher,
6364
     build_teacher,
6465
     parse_teacher_ref,
@@ -98,6 +99,7 @@ __all__ = [
9899
     "TeacherKind",
99100
     "TeacherInvocationError",
100101
     "TeacherRef",
102
+    "TeacherUsage",
101103
     "TeacherUnavailableError",
102104
     "VllmServerTeacher",
103105
     "apply_plan",
src/dlm/synth/teachers.pymodified
@@ -2,6 +2,7 @@
22
 
33
 import importlib
44
 import json
5
+import logging
56
 import os
67
 import urllib.error
78
 import urllib.request
@@ -17,6 +18,8 @@ from dlm.synth.errors import (
1718
     TeacherUnavailableError,
1819
 )
1920
 
21
+_log = logging.getLogger(__name__)
22
+
2023
 TeacherKind = Literal["self", "hf", "openai", "anthropic", "vllm-server"]
2124
 
2225
 _DEFAULT_MAX_NEW_TOKENS = 512
@@ -25,6 +28,32 @@ _OPENAI_API_KEY_ENV = "OPENAI_API_KEY"
2528
 _ANTHROPIC_API_KEY_ENV = "ANTHROPIC_API_KEY"
2629
 
2730
 
31
@dataclass
class TeacherUsage:
    """Running tally of requests and tokens consumed by an API-backed teacher.

    Every counter starts at zero; the instance is mutable so callers can
    add to the counters in place after each response.
    """

    # Field order is part of the positional constructor signature.
    prompt_tokens: int = 0
    completion_tokens: int = 0
    requests: int = 0

    @property
    def total_tokens(self) -> int:
        """Return prompt tokens plus completion tokens."""
        combined = self.prompt_tokens + self.completion_tokens
        return combined

    def log_summary(self, teacher_name: str) -> None:
        """Log one INFO line with the counters, labelled with *teacher_name*.

        Nothing is logged while no request has been recorded.
        """
        if self.requests != 0:
            template = (
                "teacher %s usage: %d requests, "
                "%d prompt tokens, %d completion tokens, "
                "%d total tokens"
            )
            _log.info(
                template,
                teacher_name,
                self.requests,
                self.prompt_tokens,
                self.completion_tokens,
                self.total_tokens,
            )
55
+
56
+
2857
 @dataclass(frozen=True)
2958
 class TeacherRef:
3059
     """Parsed `--teacher` selector from the CLI."""
@@ -185,6 +214,7 @@ class OpenAiTeacher:
185214
     client_factory: OpenAiClientFactory | None = field(default=None, repr=False, compare=False)
186215
     api_key_env: str = field(default=_OPENAI_API_KEY_ENV, repr=False, compare=False)
187216
     name: str = field(init=False)
217
+    usage: TeacherUsage = field(default_factory=TeacherUsage, init=False, repr=False, compare=False)
188218
     _client: Any = field(default=None, init=False, repr=False, compare=False)
189219
 
190220
     def __post_init__(self) -> None:
@@ -222,6 +252,7 @@ class OpenAiTeacher:
222252
             response = client.chat.completions.create(**payload)
223253
         except Exception as exc:
224254
             raise TeacherInvocationError(f"{self.name} request failed: {exc}") from exc
255
+        _accumulate_openai_usage(self.usage, response)
225256
         return _require_non_empty_teacher_output(
226257
             _extract_openai_message_text(response),
227258
             teacher=self.name,
@@ -247,6 +278,7 @@ class AnthropicTeacher:
247278
     client_factory: AnthropicClientFactory | None = field(default=None, repr=False, compare=False)
248279
     api_key_env: str = field(default=_ANTHROPIC_API_KEY_ENV, repr=False, compare=False)
249280
     name: str = field(init=False)
281
+    usage: TeacherUsage = field(default_factory=TeacherUsage, init=False, repr=False, compare=False)
250282
     _client: Any = field(default=None, init=False, repr=False, compare=False)
251283
 
252284
     def __post_init__(self) -> None:
@@ -281,6 +313,7 @@ class AnthropicTeacher:
281313
             response = client.messages.create(**payload)
282314
         except Exception as exc:
283315
             raise TeacherInvocationError(f"{self.name} request failed: {exc}") from exc
316
+        _accumulate_anthropic_usage(self.usage, response)
284317
         return _require_non_empty_teacher_output(
285318
             _extract_anthropic_text(response),
286319
             teacher=self.name,
@@ -618,6 +651,32 @@ def _obj_get(obj: object, name: str) -> object:
618651
     return getattr(obj, name, None)
619652
 
620653
 
654
def _accumulate_usage(
    usage: TeacherUsage,
    response: Any,
    *,
    prompt_field: str,
    completion_field: str,
) -> None:
    """Record one request on *usage* and add the response's token counts.

    The request counter is bumped unconditionally, so a response that
    carries no usage information still counts as a request.

    Args:
        usage: Accumulator mutated in place.
        response: Provider response; its ``usage`` member is read through
            ``_obj_get``.
        prompt_field: Name of the input-side token count on ``response.usage``.
        completion_field: Name of the output-side token count on
            ``response.usage``.
    """
    usage.requests += 1
    reported = _obj_get(response, "usage")
    if reported is None:
        return
    prompt_count = _obj_get(reported, prompt_field)
    completion_count = _obj_get(reported, completion_field)
    # bool is a subclass of int; a stray True/False is not a token count and
    # must not be summed as 1/0. Any other non-int value is skipped as well.
    if isinstance(prompt_count, int) and not isinstance(prompt_count, bool):
        usage.prompt_tokens += prompt_count
    if isinstance(completion_count, int) and not isinstance(completion_count, bool):
        usage.completion_tokens += completion_count


def _accumulate_openai_usage(usage: TeacherUsage, response: Any) -> None:
    """Add an OpenAI-style response's ``prompt_tokens``/``completion_tokens`` to *usage*."""
    _accumulate_usage(
        usage,
        response,
        prompt_field="prompt_tokens",
        completion_field="completion_tokens",
    )


def _accumulate_anthropic_usage(usage: TeacherUsage, response: Any) -> None:
    """Add an Anthropic-style response's ``input_tokens``/``output_tokens`` to *usage*.

    Input tokens are recorded as prompt tokens and output tokens as
    completion tokens so both providers share one accumulator shape.
    """
    _accumulate_usage(
        usage,
        response,
        prompt_field="input_tokens",
        completion_field="output_tokens",
    )
678
+
679
+
621680
 def _normalize_openai_compat_base_url(url: str) -> str:
622681
     stripped = url.rstrip("/")
623682
     if stripped.endswith("/v1/chat/completions"):
tests/unit/synth/test_teachers.pymodified
@@ -21,6 +21,7 @@ from dlm.synth import (
2121
     SelfTeacher,
2222
     TeacherInvocationError,
2323
     TeacherUnavailableError,
24
+    TeacherUsage,
2425
     VllmServerTeacher,
2526
     build_teacher,
2627
     parse_teacher_ref,
@@ -228,6 +229,24 @@ class TestOpenAiTeacher:
228229
         assert payloads[0]["seed"] == 5
229230
         assert factories == ["secret"]
230231
 
232
+    def test_openai_teacher_accumulates_usage(self, monkeypatch: pytest.MonkeyPatch) -> None:
233
+        monkeypatch.setenv("OPENAI_API_KEY", "secret")
234
+
235
+        def _create(**_kwargs: Any) -> Any:
236
+            return SimpleNamespace(
237
+                choices=[SimpleNamespace(message=SimpleNamespace(content="ok"))],
238
+                usage=SimpleNamespace(prompt_tokens=10, completion_tokens=5),
239
+            )
240
+
241
+        client = SimpleNamespace(chat=SimpleNamespace(completions=SimpleNamespace(create=_create)))
242
+        teacher = OpenAiTeacher("gpt-4o-mini", client_factory=lambda _k: client)
243
+        teacher.generate("sys", "usr")
244
+        teacher.generate("sys", "usr")
245
+        assert teacher.usage.requests == 2
246
+        assert teacher.usage.prompt_tokens == 20
247
+        assert teacher.usage.completion_tokens == 10
248
+        assert teacher.usage.total_tokens == 30
249
+
231250
     def test_openai_teacher_wraps_request_failures(self, monkeypatch: pytest.MonkeyPatch) -> None:
232251
         monkeypatch.setenv("OPENAI_API_KEY", "secret")
233252
 
@@ -291,6 +310,28 @@ class TestAnthropicTeacher:
291310
         assert captured["payload"]["model"] == "claude-3-5-haiku-latest"
292311
         assert factories == ["secret"]
293312
 
313
+    def test_anthropic_teacher_accumulates_usage(self, monkeypatch: pytest.MonkeyPatch) -> None:
314
+        monkeypatch.setenv("ANTHROPIC_API_KEY", "secret")
315
+
316
+        class _Messages:
317
+            @staticmethod
318
+            def create(**_kwargs: Any) -> Any:
319
+                return SimpleNamespace(
320
+                    content=[SimpleNamespace(type="text", text="ok")],
321
+                    usage=SimpleNamespace(input_tokens=8, output_tokens=3),
322
+                )
323
+
324
+        teacher = AnthropicTeacher(
325
+            "claude-3-5-haiku-latest",
326
+            client_factory=lambda _k: SimpleNamespace(messages=_Messages()),
327
+        )
328
+        teacher.generate("sys", "usr")
329
+        teacher.generate("sys", "usr")
330
+        assert teacher.usage.requests == 2
331
+        assert teacher.usage.prompt_tokens == 16
332
+        assert teacher.usage.completion_tokens == 6
333
+        assert teacher.usage.total_tokens == 22
334
+
294335
     def test_anthropic_teacher_wraps_request_failures(
295336
         self, monkeypatch: pytest.MonkeyPatch
296337
     ) -> None:
@@ -374,6 +415,26 @@ class TestVllmServerTeacher:
374415
         assert completion_calls[0][3:] == (29, 0.4, 0.75, 9, 30.0)
375416
 
376417
 
418
class TestTeacherUsage:
    """Unit checks for the `TeacherUsage` counters and their log summary."""

    def test_total_tokens(self) -> None:
        """`total_tokens` is the sum of prompt and completion tokens."""
        tally = TeacherUsage(prompt_tokens=10, completion_tokens=5, requests=1)
        assert tally.total_tokens == 15

    def test_log_summary_skips_zero_requests(self, caplog: pytest.LogCaptureFixture) -> None:
        """With no requests recorded, the teacher name never reaches the log."""
        TeacherUsage().log_summary("test")
        assert "test" not in caplog.text

    def test_log_summary_emits_on_nonzero(self, caplog: pytest.LogCaptureFixture) -> None:
        """With requests recorded, the summary names the teacher and the total."""
        import logging

        tally = TeacherUsage(prompt_tokens=100, completion_tokens=50, requests=3)
        with caplog.at_level(logging.INFO, logger="dlm.synth.teachers"):
            tally.log_summary("openai:gpt-4o")

        captured = caplog.text
        assert "openai:gpt-4o" in captured
        assert "150 total tokens" in captured
436
+
437
+
377438
 class TestTeacherHelpers:
378439
     def test_flatten_teacher_prompt_handles_partial_inputs(self) -> None:
379440
         assert teachers_mod._flatten_teacher_prompt("system", "user").startswith("System:\n")