@@ -2,6 +2,7 @@
| 2 | 2 | |
| 3 | 3 | import importlib |
| 4 | 4 | import json |
| 5 | +import logging |
| 5 | 6 | import os |
| 6 | 7 | import urllib.error |
| 7 | 8 | import urllib.request |
@@ -17,6 +18,8 @@ from dlm.synth.errors import ( |
| 17 | 18 | TeacherUnavailableError, |
| 18 | 19 | ) |
| 19 | 20 | |
| 21 | +_log = logging.getLogger(__name__) |
| 22 | + |
| 20 | 23 | TeacherKind = Literal["self", "hf", "openai", "anthropic", "vllm-server"] |
| 21 | 24 | |
| 22 | 25 | _DEFAULT_MAX_NEW_TOKENS = 512 |
@@ -25,6 +28,32 @@ _OPENAI_API_KEY_ENV = "OPENAI_API_KEY" |
| 25 | 28 | _ANTHROPIC_API_KEY_ENV = "ANTHROPIC_API_KEY" |
| 26 | 29 | |
| 27 | 30 | |
| 31 | +@dataclass |
| 32 | +class TeacherUsage: |
| 33 | + """Accumulated token usage from API-backed teachers.""" |
| 34 | + |
| 35 | + prompt_tokens: int = 0 |
| 36 | + completion_tokens: int = 0 |
| 37 | + requests: int = 0 |
| 38 | + |
| 39 | + @property |
| 40 | + def total_tokens(self) -> int: |
| 41 | + return self.prompt_tokens + self.completion_tokens |
| 42 | + |
| 43 | + def log_summary(self, teacher_name: str) -> None: |
| 44 | + if self.requests == 0: |
| 45 | + return |
| 46 | + _log.info( |
| 47 | + "teacher %s usage: %d requests, %d prompt tokens, " |
| 48 | + "%d completion tokens, %d total tokens", |
| 49 | + teacher_name, |
| 50 | + self.requests, |
| 51 | + self.prompt_tokens, |
| 52 | + self.completion_tokens, |
| 53 | + self.total_tokens, |
| 54 | + ) |
| 55 | + |
| 56 | + |
| 28 | 57 | @dataclass(frozen=True) |
| 29 | 58 | class TeacherRef: |
| 30 | 59 | """Parsed `--teacher` selector from the CLI.""" |
@@ -185,6 +214,7 @@ class OpenAiTeacher: |
| 185 | 214 | client_factory: OpenAiClientFactory | None = field(default=None, repr=False, compare=False) |
| 186 | 215 | api_key_env: str = field(default=_OPENAI_API_KEY_ENV, repr=False, compare=False) |
| 187 | 216 | name: str = field(init=False) |
| 217 | + usage: TeacherUsage = field(default_factory=TeacherUsage, init=False, repr=False, compare=False) |
| 188 | 218 | _client: Any = field(default=None, init=False, repr=False, compare=False) |
| 189 | 219 | |
| 190 | 220 | def __post_init__(self) -> None: |
@@ -222,6 +252,7 @@ class OpenAiTeacher: |
| 222 | 252 | response = client.chat.completions.create(**payload) |
| 223 | 253 | except Exception as exc: |
| 224 | 254 | raise TeacherInvocationError(f"{self.name} request failed: {exc}") from exc |
| 255 | + _accumulate_openai_usage(self.usage, response) |
| 225 | 256 | return _require_non_empty_teacher_output( |
| 226 | 257 | _extract_openai_message_text(response), |
| 227 | 258 | teacher=self.name, |
@@ -247,6 +278,7 @@ class AnthropicTeacher: |
| 247 | 278 | client_factory: AnthropicClientFactory | None = field(default=None, repr=False, compare=False) |
| 248 | 279 | api_key_env: str = field(default=_ANTHROPIC_API_KEY_ENV, repr=False, compare=False) |
| 249 | 280 | name: str = field(init=False) |
| 281 | + usage: TeacherUsage = field(default_factory=TeacherUsage, init=False, repr=False, compare=False) |
| 250 | 282 | _client: Any = field(default=None, init=False, repr=False, compare=False) |
| 251 | 283 | |
| 252 | 284 | def __post_init__(self) -> None: |
@@ -281,6 +313,7 @@ class AnthropicTeacher: |
| 281 | 313 | response = client.messages.create(**payload) |
| 282 | 314 | except Exception as exc: |
| 283 | 315 | raise TeacherInvocationError(f"{self.name} request failed: {exc}") from exc |
| 316 | + _accumulate_anthropic_usage(self.usage, response) |
| 284 | 317 | return _require_non_empty_teacher_output( |
| 285 | 318 | _extract_anthropic_text(response), |
| 286 | 319 | teacher=self.name, |
@@ -618,6 +651,32 @@ def _obj_get(obj: object, name: str) -> object: |
| 618 | 651 | return getattr(obj, name, None) |
| 619 | 652 | |
| 620 | 653 | |
| 654 | +def _accumulate_openai_usage(usage: TeacherUsage, response: Any) -> None: |
| 655 | + usage.requests += 1 |
| 656 | + u = _obj_get(response, "usage") |
| 657 | + if u is None: |
| 658 | + return |
| 659 | + pt = _obj_get(u, "prompt_tokens") |
| 660 | + ct = _obj_get(u, "completion_tokens") |
| 661 | + if isinstance(pt, int): |
| 662 | + usage.prompt_tokens += pt |
| 663 | + if isinstance(ct, int): |
| 664 | + usage.completion_tokens += ct |
| 665 | + |
| 666 | + |
| 667 | +def _accumulate_anthropic_usage(usage: TeacherUsage, response: Any) -> None: |
| 668 | + usage.requests += 1 |
| 669 | + u = _obj_get(response, "usage") |
| 670 | + if u is None: |
| 671 | + return |
| 672 | + pt = _obj_get(u, "input_tokens") |
| 673 | + ct = _obj_get(u, "output_tokens") |
| 674 | + if isinstance(pt, int): |
| 675 | + usage.prompt_tokens += pt |
| 676 | + if isinstance(ct, int): |
| 677 | + usage.completion_tokens += ct |
| 678 | + |
| 679 | + |
| 621 | 680 | def _normalize_openai_compat_base_url(url: str) -> str: |
| 622 | 681 | stripped = url.rstrip("/") |
| 623 | 682 | if stripped.endswith("/v1/chat/completions"): |