Move task classification into runtime
- SHA: bb48c80b9ad7f3e5bfdbbf60db557097ebd8e6cf (short: bb48c80)
- Parents: e36e64f
- Tree: 6ad44c9

| Status | File | + | - |
|---|---|---|---|
| M | src/loader/agent/reasoning.py | 5 | 108 |
| M | src/loader/runtime/conversation.py | 1 | 4 |
| A | src/loader/runtime/task_classification.py | 193 | 0 |
| A | tests/test_task_classification.py | 29 | 0 |

src/loader/agent/reasoning.py (modified) @@ -30,114 +30,11 @@ from ..runtime.reasoning_types import ( | |||
| 30 | TaskCompletionCheck, | 30 | TaskCompletionCheck, |
| 31 | TaskDecomposition, | 31 | TaskDecomposition, |
| 32 | ) | 32 | ) |
| 33 | - | 33 | +from ..runtime.task_classification import ( |
| 34 | - | 34 | + estimate_complexity, |
| 35 | -# === Query Classification === | 35 | + get_token_budget, |
| 36 | - | 36 | + is_conversational, |
| 37 | -def is_conversational(message: str) -> bool: | 37 | +) |
| 38 | - """Detect if a message is conversational rather than a task. | ||
| 39 | - | ||
| 40 | - Returns True for greetings, casual chat, simple questions about the agent. | ||
| 41 | - These don't need tool calling - just a quick response. | ||
| 42 | - """ | ||
| 43 | - msg = message.lower().strip() | ||
| 44 | - | ||
| 45 | - # Very short messages are usually conversational | ||
| 46 | - if len(msg) < 15: | ||
| 47 | - # Greetings | ||
| 48 | - greetings = [ | ||
| 49 | - "hi", "hello", "hey", "yo", "sup", "hiya", "howdy", | ||
| 50 | - "ello", "hallo", "greetings", "good morning", "good afternoon", | ||
| 51 | - "good evening", "morning", "evening", "afternoon", | ||
| 52 | - "what's up", "whats up", "wassup", "how are you", | ||
| 53 | - "how's it going", "hows it going", | ||
| 54 | - ] | ||
| 55 | - if any(msg.startswith(g) or msg == g for g in greetings): | ||
| 56 | - return True | ||
| 57 | - | ||
| 58 | - # Questions about the agent itself | ||
| 59 | - agent_questions = [ | ||
| 60 | - "who are you", "what are you", "what can you do", | ||
| 61 | - "how do you work", "what is loader", "what's loader", | ||
| 62 | - "help", "what is this", "how does this work", | ||
| 63 | - ] | ||
| 64 | - if any(q in msg for q in agent_questions): | ||
| 65 | - return True | ||
| 66 | - | ||
| 67 | - # Casual/social messages | ||
| 68 | - casual = [ | ||
| 69 | - "thanks", "thank you", "thx", "ty", | ||
| 70 | - "cool", "nice", "great", "awesome", "ok", "okay", | ||
| 71 | - "bye", "goodbye", "see you", "later", "cya", | ||
| 72 | - "lol", "haha", "hehe", "lmao", | ||
| 73 | - "please", "sorry", "oops", | ||
| 74 | - ] | ||
| 75 | - if msg in casual or any(msg == c for c in casual): | ||
| 76 | - return True | ||
| 77 | - | ||
| 78 | - # Messages that are clearly NOT conversational (tasks) | ||
| 79 | - task_indicators = [ | ||
| 80 | - "create", "make", "build", "write", "edit", "delete", "remove", | ||
| 81 | - "run", "execute", "install", "fix", "debug", "test", "check", | ||
| 82 | - "find", "search", "show", "list", "read", "open", "close", | ||
| 83 | - "add", "update", "change", "modify", "refactor", "implement", | ||
| 84 | - "file", "folder", "directory", "code", "function", "class", | ||
| 85 | - "git", "npm", "pip", "python", "node", "bash", "command", | ||
| 86 | - ] | ||
| 87 | - if any(ind in msg for ind in task_indicators): | ||
| 88 | - return False | ||
| 89 | - | ||
| 90 | - # Short messages without task indicators are likely conversational | ||
| 91 | - if len(msg) < 30 and not any(c in msg for c in [".", "/", "\\", "`"]): | ||
| 92 | - return True | ||
| 93 | - | ||
| 94 | - return False | ||
| 95 | - | ||
| 96 | - | ||
| 97 | -def estimate_complexity(message: str) -> str: | ||
| 98 | - """Estimate query complexity for token budgeting. | ||
| 99 | - | ||
| 100 | - Returns: "trivial", "simple", "moderate", or "complex" | ||
| 101 | - """ | ||
| 102 | - msg = message.lower() | ||
| 103 | - word_count = len(message.split()) | ||
| 104 | - | ||
| 105 | - # Trivial: greetings, thanks, very short | ||
| 106 | - if is_conversational(message) or word_count < 5: | ||
| 107 | - return "trivial" | ||
| 108 | - | ||
| 109 | - # Complex indicators | ||
| 110 | - complex_indicators = [ | ||
| 111 | - "project", "application", "website", "api", "database", | ||
| 112 | - "refactor", "migrate", "upgrade", "implement", "design", | ||
| 113 | - "multiple", "several", "all", "entire", "whole", | ||
| 114 | - "and then", "after that", "also", "as well", | ||
| 115 | - ] | ||
| 116 | - complex_count = sum(1 for ind in complex_indicators if ind in msg) | ||
| 117 | - | ||
| 118 | - if complex_count >= 2 or word_count > 50: | ||
| 119 | - return "complex" | ||
| 120 | - | ||
| 121 | - # Simple indicators | ||
| 122 | - simple_indicators = [ | ||
| 123 | - "what is", "how do", "show me", "list", "read", | ||
| 124 | - "single", "one", "just", "only", "quick", | ||
| 125 | - ] | ||
| 126 | - if any(ind in msg for ind in simple_indicators) and word_count < 20: | ||
| 127 | - return "simple" | ||
| 128 | - | ||
| 129 | - return "moderate" | ||
| 130 | - | ||
| 131 | - | ||
| 132 | -def get_token_budget(complexity: str) -> tuple[int, int]: | ||
| 133 | - """Get (max_tokens, context_tokens) for a complexity level.""" | ||
| 134 | - budgets = { | ||
| 135 | - "trivial": (256, 2048), # Quick response, minimal context | ||
| 136 | - "simple": (512, 4096), # Short response, some context | ||
| 137 | - "moderate": (1024, 8192), # Normal response | ||
| 138 | - "complex": (2048, 16384), # Full response, full context | ||
| 139 | - } | ||
| 140 | - return budgets.get(complexity, (1024, 8192)) | ||
| 141 | 38 | ||
| 142 | 39 | ||
| 143 | # Prompts for reasoning stages | 40 | # Prompts for reasoning stages |
src/loader/runtime/conversation.py (modified) @@ -7,10 +7,6 @@ from collections.abc import Awaitable, Callable | |||
| 7 | from pathlib import Path | 7 | from pathlib import Path |
| 8 | from typing import Any | 8 | from typing import Any |
| 9 | 9 | ||
| 10 | -from ..agent.reasoning import ( | ||
| 11 | - estimate_complexity, | ||
| 12 | - get_token_budget, | ||
| 13 | -) | ||
| 14 | from ..llm.base import Message, Role, ToolCall | 10 | from ..llm.base import Message, Role, ToolCall |
| 15 | from .assistant_turns import AssistantTurnRequester | 11 | from .assistant_turns import AssistantTurnRequester |
| 16 | from .completion_policy import CompletionPolicy | 12 | from .completion_policy import CompletionPolicy |
@@ -23,6 +19,7 @@ from .hooks import build_default_tool_hooks | |||
| 23 | from .phases import TurnPhase, TurnPhaseTracker | 19 | from .phases import TurnPhase, TurnPhaseTracker |
| 24 | from .repair import ResponseRepairer | 20 | from .repair import ResponseRepairer |
| 25 | from .rollback import RollbackPlan | 21 | from .rollback import RollbackPlan |
| 22 | +from .task_classification import estimate_complexity, get_token_budget | ||
| 26 | from .tool_batches import ToolBatchRunner | 23 | from .tool_batches import ToolBatchRunner |
| 27 | from .tracing import RuntimeTracer | 24 | from .tracing import RuntimeTracer |
| 28 | from .workflow import ( | 25 | from .workflow import ( |
src/loader/runtime/task_classification.py (added) @@ -0,0 +1,193 @@ | |||
| 1 | +"""Runtime-owned task classification helpers.""" | ||
| 2 | + | ||
| 3 | +from __future__ import annotations | ||
| 4 | + | ||
| 5 | + | ||
| 6 | +def is_conversational(message: str) -> bool: | ||
| 7 | + """Detect if a message is conversational rather than a task.""" | ||
| 8 | + | ||
| 9 | + msg = message.lower().strip() | ||
| 10 | + | ||
| 11 | + if len(msg) < 15: | ||
| 12 | + greetings = [ | ||
| 13 | + "hi", | ||
| 14 | + "hello", | ||
| 15 | + "hey", | ||
| 16 | + "yo", | ||
| 17 | + "sup", | ||
| 18 | + "hiya", | ||
| 19 | + "howdy", | ||
| 20 | + "ello", | ||
| 21 | + "hallo", | ||
| 22 | + "greetings", | ||
| 23 | + "good morning", | ||
| 24 | + "good afternoon", | ||
| 25 | + "good evening", | ||
| 26 | + "morning", | ||
| 27 | + "evening", | ||
| 28 | + "afternoon", | ||
| 29 | + "what's up", | ||
| 30 | + "whats up", | ||
| 31 | + "wassup", | ||
| 32 | + "how are you", | ||
| 33 | + "how's it going", | ||
| 34 | + "hows it going", | ||
| 35 | + ] | ||
| 36 | + if any(msg.startswith(greeting) or msg == greeting for greeting in greetings): | ||
| 37 | + return True | ||
| 38 | + | ||
| 39 | + agent_questions = [ | ||
| 40 | + "who are you", | ||
| 41 | + "what are you", | ||
| 42 | + "what can you do", | ||
| 43 | + "how do you work", | ||
| 44 | + "what is loader", | ||
| 45 | + "what's loader", | ||
| 46 | + "help", | ||
| 47 | + "what is this", | ||
| 48 | + "how does this work", | ||
| 49 | + ] | ||
| 50 | + if any(question in msg for question in agent_questions): | ||
| 51 | + return True | ||
| 52 | + | ||
| 53 | + casual = [ | ||
| 54 | + "thanks", | ||
| 55 | + "thank you", | ||
| 56 | + "thx", | ||
| 57 | + "ty", | ||
| 58 | + "cool", | ||
| 59 | + "nice", | ||
| 60 | + "great", | ||
| 61 | + "awesome", | ||
| 62 | + "ok", | ||
| 63 | + "okay", | ||
| 64 | + "bye", | ||
| 65 | + "goodbye", | ||
| 66 | + "see you", | ||
| 67 | + "later", | ||
| 68 | + "cya", | ||
| 69 | + "lol", | ||
| 70 | + "haha", | ||
| 71 | + "hehe", | ||
| 72 | + "lmao", | ||
| 73 | + "please", | ||
| 74 | + "sorry", | ||
| 75 | + "oops", | ||
| 76 | + ] | ||
| 77 | + if msg in casual or any(msg == item for item in casual): | ||
| 78 | + return True | ||
| 79 | + | ||
| 80 | + task_indicators = [ | ||
| 81 | + "create", | ||
| 82 | + "make", | ||
| 83 | + "build", | ||
| 84 | + "write", | ||
| 85 | + "edit", | ||
| 86 | + "delete", | ||
| 87 | + "remove", | ||
| 88 | + "run", | ||
| 89 | + "execute", | ||
| 90 | + "install", | ||
| 91 | + "fix", | ||
| 92 | + "debug", | ||
| 93 | + "test", | ||
| 94 | + "check", | ||
| 95 | + "find", | ||
| 96 | + "search", | ||
| 97 | + "show", | ||
| 98 | + "list", | ||
| 99 | + "read", | ||
| 100 | + "open", | ||
| 101 | + "close", | ||
| 102 | + "add", | ||
| 103 | + "update", | ||
| 104 | + "change", | ||
| 105 | + "modify", | ||
| 106 | + "refactor", | ||
| 107 | + "implement", | ||
| 108 | + "file", | ||
| 109 | + "folder", | ||
| 110 | + "directory", | ||
| 111 | + "code", | ||
| 112 | + "function", | ||
| 113 | + "class", | ||
| 114 | + "git", | ||
| 115 | + "npm", | ||
| 116 | + "pip", | ||
| 117 | + "python", | ||
| 118 | + "node", | ||
| 119 | + "bash", | ||
| 120 | + "command", | ||
| 121 | + ] | ||
| 122 | + if any(indicator in msg for indicator in task_indicators): | ||
| 123 | + return False | ||
| 124 | + | ||
| 125 | + if len(msg) < 30 and not any(char in msg for char in [".", "/", "\\", "`"]): | ||
| 126 | + return True | ||
| 127 | + | ||
| 128 | + return False | ||
| 129 | + | ||
| 130 | + | ||
| 131 | +def estimate_complexity(message: str) -> str: | ||
| 132 | + """Estimate query complexity for token budgeting.""" | ||
| 133 | + | ||
| 134 | + msg = message.lower() | ||
| 135 | + word_count = len(message.split()) | ||
| 136 | + | ||
| 137 | + if is_conversational(message) or word_count < 5: | ||
| 138 | + return "trivial" | ||
| 139 | + | ||
| 140 | + complex_indicators = [ | ||
| 141 | + "project", | ||
| 142 | + "application", | ||
| 143 | + "website", | ||
| 144 | + "api", | ||
| 145 | + "database", | ||
| 146 | + "refactor", | ||
| 147 | + "migrate", | ||
| 148 | + "upgrade", | ||
| 149 | + "implement", | ||
| 150 | + "design", | ||
| 151 | + "multiple", | ||
| 152 | + "several", | ||
| 153 | + "all", | ||
| 154 | + "entire", | ||
| 155 | + "whole", | ||
| 156 | + "and then", | ||
| 157 | + "after that", | ||
| 158 | + "also", | ||
| 159 | + "as well", | ||
| 160 | + ] | ||
| 161 | + complex_count = sum(1 for indicator in complex_indicators if indicator in msg) | ||
| 162 | + | ||
| 163 | + if complex_count >= 2 or word_count > 50: | ||
| 164 | + return "complex" | ||
| 165 | + | ||
| 166 | + simple_indicators = [ | ||
| 167 | + "what is", | ||
| 168 | + "how do", | ||
| 169 | + "show me", | ||
| 170 | + "list", | ||
| 171 | + "read", | ||
| 172 | + "single", | ||
| 173 | + "one", | ||
| 174 | + "just", | ||
| 175 | + "only", | ||
| 176 | + "quick", | ||
| 177 | + ] | ||
| 178 | + if any(indicator in msg for indicator in simple_indicators) and word_count < 20: | ||
| 179 | + return "simple" | ||
| 180 | + | ||
| 181 | + return "moderate" | ||
| 182 | + | ||
| 183 | + | ||
| 184 | +def get_token_budget(complexity: str) -> tuple[int, int]: | ||
| 185 | + """Get `(max_tokens, context_tokens)` for a complexity level.""" | ||
| 186 | + | ||
| 187 | + budgets = { | ||
| 188 | + "trivial": (256, 2048), | ||
| 189 | + "simple": (512, 4096), | ||
| 190 | + "moderate": (1024, 8192), | ||
| 191 | + "complex": (2048, 16384), | ||
| 192 | + } | ||
| 193 | + return budgets.get(complexity, (1024, 8192)) | ||
tests/test_task_classification.py (added) @@ -0,0 +1,29 @@ | |||
| 1 | +"""Tests for runtime-owned task classification helpers.""" | ||
| 2 | + | ||
| 3 | +from __future__ import annotations | ||
| 4 | + | ||
| 5 | +from loader.runtime.task_classification import ( | ||
| 6 | + estimate_complexity, | ||
| 7 | + get_token_budget, | ||
| 8 | + is_conversational, | ||
| 9 | +) | ||
| 10 | + | ||
| 11 | + | ||
| 12 | +def test_is_conversational_detects_small_talk() -> None: | ||
| 13 | + assert is_conversational("hi there") is True | ||
| 14 | + assert is_conversational("thanks") is True | ||
| 15 | + assert is_conversational("what can you do?") is True | ||
| 16 | + | ||
| 17 | + | ||
| 18 | +def test_is_conversational_rejects_actionable_tasks() -> None: | ||
| 19 | + assert is_conversational("create a new README for this repo") is False | ||
| 20 | + assert is_conversational("run the tests and fix failures") is False | ||
| 21 | + | ||
| 22 | + | ||
| 23 | +def test_estimate_complexity_and_token_budget_cover_common_paths() -> None: | ||
| 24 | + assert estimate_complexity("hi") == "trivial" | ||
| 25 | + assert estimate_complexity("show me the config file") == "simple" | ||
| 26 | + assert estimate_complexity( | ||
| 27 | + "implement a website and then refactor the database layer" | ||
| 28 | + ) == "complex" | ||
| 29 | + assert get_token_budget("simple") == (512, 4096) | ||