Python · 5795 bytes Raw Blame History
1 """Turn-phase tracking for Loader runtime execution."""
2
from __future__ import annotations

from collections.abc import Awaitable, Callable, Mapping
from dataclasses import dataclass
from enum import StrEnum
from types import MappingProxyType
from typing import ClassVar

from .context import RuntimeContext
from .events import AgentEvent
from .tracing import RuntimeTracer
12
# Event sink signature: a callable that takes one ``AgentEvent`` and returns an
# awaitable; ``TurnPhaseTracker.enter`` awaits it once per accepted transition.
EventSink = Callable[[AgentEvent], Awaitable[None]]
14
15
16 class TurnPhase(StrEnum):
17 """Named phases of one runtime turn."""
18
19 PREPARE = "prepare"
20 ASSISTANT = "assistant"
21 REPAIR = "repair"
22 TOOLS = "tools"
23 CRITIQUE = "critique"
24 COMPLETION = "completion"
25 FINALIZE = "finalize"
26
27
28 class TurnTransitionKind(StrEnum):
29 """Classification for why one turn-state transition occurred."""
30
31 NORMAL = "normal"
32 RETRY = "retry"
33 REROUTE = "reroute"
34 RECOVERY = "recovery"
35 TERMINAL = "terminal"
36
37
38 @dataclass(slots=True)
39 class TurnTransition:
40 """One validated turn-state transition."""
41
42 from_phase: str | None
43 to_phase: str
44 reason_code: str
45 reason_summary: str
46 kind: TurnTransitionKind
47
48 @property
49 def summary(self) -> str:
50 source = self.from_phase or "start"
51 return (
52 f"{source} -> {self.to_phase} "
53 f"[{self.kind.value}] {self.reason_summary}"
54 )
55
56
class TurnStateMachine:
    """Validate allowed turn-state transitions.

    ``current_phase`` holds the value of the active ``TurnPhase``. It is
    ``None`` before the first transition of a turn and again after
    :meth:`clear`. ``last_transition`` holds the most recently accepted
    ``TurnTransition``; :meth:`clear` does not reset it.
    """

    # Adjacency table: source phase value (``None`` = start of a turn) -> the
    # phase values allowed to follow it. It lives on the class and is shared by
    # every instance. It is therefore exposed read-only (a mapping proxy over
    # frozensets) so no caller can silently change the rules for all others.
    _ALLOWED_TRANSITIONS: ClassVar[Mapping[str | None, frozenset[str]]] = (
        MappingProxyType(
            {
                None: frozenset({TurnPhase.PREPARE.value}),
                TurnPhase.PREPARE.value: frozenset(
                    {
                        TurnPhase.ASSISTANT.value,
                        TurnPhase.FINALIZE.value,
                    }
                ),
                TurnPhase.ASSISTANT.value: frozenset(
                    {
                        TurnPhase.REPAIR.value,
                        TurnPhase.TOOLS.value,
                        TurnPhase.CRITIQUE.value,
                        TurnPhase.COMPLETION.value,
                        TurnPhase.FINALIZE.value,
                    }
                ),
                TurnPhase.REPAIR.value: frozenset(
                    {
                        TurnPhase.ASSISTANT.value,
                        TurnPhase.TOOLS.value,
                        TurnPhase.COMPLETION.value,
                        TurnPhase.FINALIZE.value,
                    }
                ),
                TurnPhase.TOOLS.value: frozenset(
                    {
                        TurnPhase.ASSISTANT.value,
                        TurnPhase.CRITIQUE.value,
                        TurnPhase.COMPLETION.value,
                        TurnPhase.FINALIZE.value,
                    }
                ),
                TurnPhase.CRITIQUE.value: frozenset(
                    {
                        TurnPhase.ASSISTANT.value,
                        TurnPhase.COMPLETION.value,
                        TurnPhase.FINALIZE.value,
                    }
                ),
                TurnPhase.COMPLETION.value: frozenset(
                    {
                        TurnPhase.ASSISTANT.value,
                        TurnPhase.FINALIZE.value,
                    }
                ),
                # Terminal: only clear() leaves FINALIZE.
                TurnPhase.FINALIZE.value: frozenset(),
            }
        )
    )

    def __init__(self) -> None:
        self.current_phase: str | None = None
        self.last_transition: TurnTransition | None = None

    def transition(
        self,
        phase: TurnPhase | str,
        *,
        reason_code: str,
        reason_summary: str,
        kind: TurnTransitionKind | str = TurnTransitionKind.NORMAL,
    ) -> TurnTransition | None:
        """Validate and record a transition to the target phase.

        Args:
            phase: Target phase, as a ``TurnPhase`` member or its string value.
            reason_code: Machine-readable reason stored on the transition.
            reason_summary: Human-readable reason stored on the transition.
            kind: Transition classification, as a ``TurnTransitionKind``
                member or its string value.

        Returns:
            The recorded ``TurnTransition``, or ``None`` when ``phase`` is
            already the current phase (re-entering a phase is a no-op).

        Raises:
            ValueError: If ``phase`` or ``kind`` is not a valid member value,
                or the move is not allowed from the current phase.
        """
        # Normalize so plain strings are accepted. This also guarantees the
        # stored ``kind`` is an enum, because ``TurnTransition.summary`` and
        # downstream consumers read ``kind.value``.
        target = TurnPhase(phase)
        transition_kind = TurnTransitionKind(kind)

        # Re-entering the active phase is a no-op, not an error.
        if target.value == self.current_phase:
            return None

        allowed = self._ALLOWED_TRANSITIONS.get(self.current_phase, frozenset())
        if target.value not in allowed:
            raise ValueError(
                "Invalid turn-state transition: "
                f"{self.current_phase or 'start'} -> {target.value}"
            )

        transition = TurnTransition(
            from_phase=self.current_phase,
            to_phase=target.value,
            reason_code=reason_code,
            reason_summary=reason_summary,
            kind=transition_kind,
        )
        self.current_phase = target.value
        self.last_transition = transition
        return transition

    def clear(self) -> None:
        """Reset the active phase after a turn completes.

        Only ``current_phase`` is reset, so the next turn must start from
        PREPARE. ``last_transition`` is intentionally kept.
        """

        self.current_phase = None
136
137
class TurnPhaseTracker:
    """Persist and emit turn-phase transitions.

    Owns a ``TurnStateMachine``. For every accepted transition it does three
    things, in order:

    1. Writes the new phase to the session's runtime state.
    2. Records a ``turn.phase_changed`` tracer entry.
    3. Sends a ``turn_phase`` ``AgentEvent`` to the caller's event sink.
    """

    def __init__(self, context: RuntimeContext, tracer: RuntimeTracer) -> None:
        self.context = context
        self.tracer = tracer
        self.state_machine = TurnStateMachine()

    async def enter(
        self,
        phase: TurnPhase,
        emit: EventSink,
        *,
        detail: str | None = None,
        reason_code: str | None = None,
        kind: TurnTransitionKind = TurnTransitionKind.NORMAL,
    ) -> None:
        """Move the runtime into ``phase`` and publish the transition.

        Args:
            phase: Phase to enter.
            emit: Sink awaited once with the resulting ``turn_phase`` event.
            detail: Human-readable reason; ``"Phase: <phase>"`` is used when
                this is empty or ``None``.
            reason_code: Machine-readable reason; the phase value is used when
                this is empty or ``None``.
            kind: Classification stored on the transition.

        Nothing happens when ``phase`` is already active. A ``ValueError``
        from the state machine propagates before anything is persisted,
        traced, or emitted.
        """
        phase_value = phase.value
        summary = detail if detail else f"Phase: {phase_value}"

        transition = self.state_machine.transition(
            phase,
            reason_code=reason_code if reason_code else phase_value,
            reason_summary=summary,
            kind=kind,
        )
        if transition is None:
            # Same phase as before: no state change to persist or report.
            return

        kind_value = transition.kind.value
        code = transition.reason_code
        line = transition.summary

        # Publish in a fixed order: session state, then tracer, then event sink.
        self.context.session.update_runtime_state(
            active_turn_phase=phase_value,
            last_turn_transition_summary=line,
            last_turn_transition_kind=kind_value,
            last_turn_transition_reason_code=code,
        )
        self.tracer.record(
            "turn.phase_changed",
            phase=phase_value,
            detail=summary,
            from_phase=transition.from_phase,
            transition_kind=kind_value,
            reason_code=code,
        )
        event = AgentEvent(
            type="turn_phase",
            content=line,
            turn_phase=phase_value,
            transition_kind=kind_value,
            transition_summary=line,
            transition_reason_code=code,
        )
        await emit(event)

    def clear(self) -> None:
        """Reset the state machine and blank the session's active phase.

        Only ``active_turn_phase`` is passed to the session here. The
        ``last_turn_transition_*`` values written by ``enter`` are not
        touched by this call.
        """
        self.state_machine.clear()
        self.context.session.update_runtime_state(active_turn_phase=None)