package llm

import (
	"context"
	"fmt"
	"strings"
	"time"

	"parrot/internal/config"
)

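// LLMManager routes generation requests across the configured backends:
// a remote API, a local Ollama instance, and a hardcoded fallback set,
// with an optional response cache in front.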
type LLMManager struct {
	config       *config.Config
	apiClient    *APIClient
	ollamaClient *OllamaClient
	cache        *ResponseCache
}

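// Backend identifies which generation path produced a response.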
type Backend string

const (
	BackendAPI      Backend = "api"
	BackendLocal    Backend = "local"
	BackendFallback Backend = "fallback"
)

// getLocalTimeout returns the appropriate timeout based on generation mode
func getLocalTimeout(cfg *config.Config) time.Duration {
	if cfg.General.GenerationMode == "spicy" {
		return 5 * time.Second // Patient timeout for quality mode
	}
	return 3 * time.Second // Snappy timeout (raw Ollama ~1.4s, needs headroom)
}

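// NewLLMManager builds a manager from the given config, wiring up whichever
// API and Ollama clients are enabled and warming up the local model in the
// background when it is reachable.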
func NewLLMManager(cfg *config.Config) *LLMManager {
	manager := &LLMManager{
		config: cfg,
		cache:  GetResponseCache(),
	}

	// Initialize API client if enabled
	if cfg.API.Enabled && cfg.API.APIKey != "" {
		manager.apiClient = NewAPIClient(
			cfg.API.Endpoint,
			cfg.API.APIKey,
			cfg.API.Model,
			cfg.API.Timeout,
		)
	}

	// Initialize Ollama client if enabled
	if cfg.Local.Enabled {
		manager.ollamaClient = NewOllamaClient(
			cfg.Local.Endpoint,
			cfg.Local.Model,
		)

		// Set generation mode (snappy = fast, spicy = quality)
		manager.ollamaClient.SetMode(cfg.General.GenerationMode)

		// Warm up the model in the background for better performance
		if manager.ollamaClient.IsAvailable() {
			go func() {
				if err := manager.ollamaClient.WarmupModel(); err != nil && cfg.General.Debug {
					fmt.Printf("🔥 Model warmup failed: %v\n", err)
				} else if cfg.General.Debug {
					fmt.Printf("🔥 Model warmed up successfully\n")
				}
			}()
		}
	}

	return manager
}

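// Generate produces a short response for the given prompt, trying backends
// in priority order (API, then local Ollama, then hardcoded fallbacks) and
// returning which backend answered.
//
// Illustrative usage (assumes cfg and prompt are in scope; the commandType
// value "git" is hypothetical):
//
//	mgr := NewLLMManager(cfg)
//	line, backend := mgr.Generate(context.Background(), prompt, "git")
//	fmt.Println(line, backend)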
func (m *LLMManager) Generate(ctx context.Context, prompt string, commandType string) (string, Backend) {
	// If fallback mode is enabled, skip LLM backends
	if m.config.General.FallbackMode {
		return m.generateFallback(commandType, "", ""), BackendFallback
	}

	// Try backends in priority order: API -> Local -> Fallback

	// 1. Try API first (if available)
	if m.apiClient != nil && m.config.API.Enabled {
		if m.config.General.Debug {
			fmt.Printf("🔍 Trying API backend...\n")
		}

		// Create timeout context for API calls
		timeoutDuration := time.Duration(m.config.API.Timeout) * time.Second
		apiCtx, cancel := context.WithTimeout(ctx, timeoutDuration)
		defer cancel()

		response, err := m.apiClient.Generate(apiCtx, prompt)
		if err == nil && response != "" {
			response = m.cleanResponse(response)
			if m.config.General.Debug {
				fmt.Printf("✅ API backend succeeded\n")
			}
			return response, BackendAPI
		}

		if m.config.General.Debug {
			fmt.Printf("❌ API backend failed: %v\n", err)
		}
	}

	// 2. Try local Ollama (if available)
	if m.ollamaClient != nil && m.config.Local.Enabled {
		if m.config.General.Debug {
			fmt.Printf("🔍 Trying local backend (%s mode)...\n", m.config.General.GenerationMode)
		}

		// Create timeout context based on generation mode
		timeoutDuration := getLocalTimeout(m.config)
		localCtx, cancel := context.WithTimeout(ctx, timeoutDuration)
		defer cancel()

		response, err := m.ollamaClient.Generate(localCtx, prompt)
		if m.config.General.Debug {
			fmt.Printf("🐛 Raw Ollama response: '%s', error: %v\n", response, err)
		}
		if err == nil && response != "" {
			response = m.cleanResponse(response)
			if m.config.General.Debug {
				fmt.Printf("✅ Local backend succeeded with: '%s'\n", response)
			}
			return response, BackendLocal
		}

		if m.config.General.Debug {
			fmt.Printf("❌ Local backend failed: %v\n", err)
		}
	}

	// 3. Fallback to hardcoded responses
	if m.config.General.Debug {
		fmt.Printf("🔄 Using fallback backend\n")
	}
	return m.generateFallback(commandType, "", ""), BackendFallback
}

// GenerateWithContext generates a response with full context for intelligent fallbacks
func (m *LLMManager) GenerateWithContext(ctx context.Context, prompt string, commandType string, fullCommand string, exitCode string) (string, Backend) {
	// If fallback mode is enabled, skip LLM backends
	if m.config.General.FallbackMode {
		return m.generateFallback(commandType, fullCommand, exitCode), BackendFallback
	}

	// Check cache first for repeated failures
	if m.cache != nil {
		if cached, found := m.cache.Get(fullCommand, commandType, exitCode, m.config.General.GenerationMode); found {
			if m.config.General.Debug {
				fmt.Printf("⚡ Cache hit!\n")
			}
			return cached, BackendLocal // Treat cache as local backend
		}
	}

	// Try backends in priority order: API -> Local -> Fallback

	// 1. Try API first (if available)
	if m.apiClient != nil && m.config.API.Enabled {
		if m.config.General.Debug {
			fmt.Printf("🔍 Trying API backend...\n")
		}

		// Create timeout context for API calls
		timeoutDuration := time.Duration(m.config.API.Timeout) * time.Second
		apiCtx, cancel := context.WithTimeout(ctx, timeoutDuration)
		defer cancel()

		response, err := m.apiClient.Generate(apiCtx, prompt)
		if err == nil && response != "" {
			response = m.cleanResponse(response)
			if m.config.General.Debug {
				fmt.Printf("✅ API backend succeeded\n")
			}
			// Cache successful response
			if m.cache != nil {
				m.cache.Set(fullCommand, commandType, exitCode, m.config.General.GenerationMode, response)
			}
			return response, BackendAPI
		}

		if m.config.General.Debug {
			fmt.Printf("❌ API backend failed: %v\n", err)
		}
	}

	// 2. Try local Ollama (if available)
	if m.ollamaClient != nil && m.config.Local.Enabled {
		if m.config.General.Debug {
			fmt.Printf("🔍 Trying local backend (%s mode)...\n", m.config.General.GenerationMode)
		}

		// Create timeout context based on generation mode
		timeoutDuration := getLocalTimeout(m.config)
		localCtx, cancel := context.WithTimeout(ctx, timeoutDuration)
		defer cancel()

		response, err := m.ollamaClient.Generate(localCtx, prompt)
		if m.config.General.Debug {
			fmt.Printf("🐛 Raw Ollama response: '%s', error: %v\n", response, err)
		}
		if err == nil && response != "" {
			response = m.cleanResponse(response)
			if m.config.General.Debug {
				fmt.Printf("✅ Local backend succeeded with: '%s'\n", response)
			}
			// Cache successful response
			if m.cache != nil {
				m.cache.Set(fullCommand, commandType, exitCode, m.config.General.GenerationMode, response)
			}
			return response, BackendLocal
		}

		if m.config.General.Debug {
			fmt.Printf("❌ Local backend failed: %v\n", err)
		}
	}

	// 3. Fallback to smart context-aware responses
	if m.config.General.Debug {
		fmt.Printf("🔄 Using smart fallback backend\n")
	}
	return m.generateFallback(commandType, fullCommand, exitCode), BackendFallback
}

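// cleanResponse normalizes raw model output into a single short line: it
// keeps only the first line, strips common prefixes and trailing
// annotations, unquotes fully quoted responses, and truncates anything
// over 150 bytes.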
func (m *LLMManager) cleanResponse(response string) string {
	// Clean up the response
	response = strings.TrimSpace(response)

	// Split at newlines and only keep the first meaningful part
	lines := strings.Split(response, "\n")
	if len(lines) > 1 {
		// Keep only the first line, discard any commentary after newlines
		response = strings.TrimSpace(lines[0])
	}

	// Remove common prefixes from LLMs
	prefixes := []string{
		"Response:",
		"Parrot says:",
		"🦜",
	}

	for _, prefix := range prefixes {
		if strings.HasPrefix(response, prefix) {
			response = strings.TrimSpace(response[len(prefix):])
		}
	}

	// Remove character count annotations like "(97 characters)"
	if idx := strings.Index(response, " ("); idx != -1 {
		remaining := response[idx:]
		if strings.Contains(remaining, "character") && strings.Contains(remaining, ")") {
			response = strings.TrimSpace(response[:idx])
		}
	}

	// Remove "Note:" annotations and similar commentary
	if idx := strings.Index(response, "Note:"); idx != -1 {
		response = strings.TrimSpace(response[:idx])
	}
	if idx := strings.Index(response, " *"); idx != -1 {
		// Remove asterisk annotations like "* This is a note"
		remaining := response[idx:]
		if strings.HasPrefix(strings.TrimSpace(remaining), "* ") {
			response = strings.TrimSpace(response[:idx])
		}
	}

	// Remove quotes if the entire response is quoted
	if len(response) >= 2 && response[0] == '"' && response[len(response)-1] == '"' {
		response = response[1 : len(response)-1]
	}

	// Ensure response isn't too long (keep it snappy)
	if len(response) > 150 {
		// Try to cut at sentence boundary
		if idx := strings.LastIndex(response[:150], "."); idx > 50 {
			response = response[:idx+1]
		} else {
			response = response[:147] + "..."
		}
	}

	return strings.TrimSpace(response)
}

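// generateFallback returns a canned response, preferring a context-aware
// line when the failing command or exit code is known.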
func (m *LLMManager) generateFallback(commandType string, fullCommand string, exitCode string) string {
	// Use smart context-aware fallback when we have context
	if fullCommand != "" || exitCode != "" {
		ctx := ParseCommandContext(fullCommand, commandType, exitCode)
		insult := GenerateSmartFallback(ctx)
		if insult != "" {
			return insult
		}
	}

	// Fall back to expanded database if no smart match
	return GetExpandedFallback(commandType, fullCommand)
}

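// GetStatus reports backend configuration and availability as a map,
// suitable for rendering a status display.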
func (m *LLMManager) GetStatus() map[string]interface{} {
	status := map[string]interface{}{
		"fallback_mode": m.config.General.FallbackMode,
		"debug":         m.config.General.Debug,
		"personality":   m.config.General.Personality,
	}

	// Check API status
	if m.apiClient != nil && m.config.API.Enabled {
		status["api_enabled"] = true
		status["api_provider"] = m.config.API.Provider
		status["api_model"] = m.config.API.Model
		status["api_available"] = m.apiClient.IsAvailable()
	} else {
		status["api_enabled"] = false
		status["api_available"] = false
	}

	// Check local status
	if m.ollamaClient != nil && m.config.Local.Enabled {
		status["local_enabled"] = true
		status["local_provider"] = m.config.Local.Provider
		status["local_model"] = m.config.Local.Model
		status["local_available"] = m.ollamaClient.IsAvailable()
	} else {
		status["local_enabled"] = false
		status["local_available"] = false
	}

	return status
}