package llm

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"time"
)

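// OllamaClient is a minimal client for the Ollama HTTP generate API.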
type OllamaClient struct {
	BaseURL string
	Model   string
	Mode    string // "snappy" (fast) or "spicy" (quality)
	client  *http.Client
}

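// GenerateRequest is the JSON body sent to Ollama's /api/generate endpoint.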
type GenerateRequest struct {
	Model     string           `json:"model"`
	Prompt    string           `json:"prompt"`
	Stream    bool             `json:"stream"`
	KeepAlive string           `json:"keep_alive,omitempty"`
	Options   *GenerateOptions `json:"options,omitempty"`
}

// GenerateOptions controls Ollama generation behavior, mainly to trade quality for speed.
type GenerateOptions struct {
	NumPredict  int     `json:"num_predict,omitempty"` // Maximum number of tokens to generate
	NumCtx      int     `json:"num_ctx,omitempty"`     // Context window size
	Temperature float64 `json:"temperature,omitempty"` // Sampling temperature (higher is more creative)
}
| 34 | |
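// GenerateResponse is the non-streaming JSON response returned by /api/generate.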
type GenerateResponse struct {
	Response string `json:"response"`
	Done     bool   `json:"done"`
}
| 39 | |
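// NewOllamaClient returns a client in "snappy" mode. Empty arguments fall back
// to the defaults below (local server, llama3.2:3b).
//
// A minimal usage sketch, assuming an Ollama server is reachable at the
// default address:
//
//	client := NewOllamaClient("", "")
//	if client.IsAvailable() {
//		reply, err := client.Generate(context.Background(), "Say something witty.")
//		_ = reply
//		_ = err
//	}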
func NewOllamaClient(baseURL, model string) *OllamaClient {
	if baseURL == "" {
		baseURL = "http://127.0.0.1:11434" // Use IPv4 explicitly to avoid IPv6 issues
	}
	if model == "" {
		model = "llama3.2:3b"
	}

	return &OllamaClient{
		BaseURL: baseURL,
		Model:   model,
		Mode:    "snappy", // Default to fast mode
		client: &http.Client{
			// Upper bound only; per-request deadlines are set via context.
			// Keep this at least as long as the 90s warmup deadline in WarmupModel.
			Timeout: 120 * time.Second,
		},
	}
}

// SetMode sets the generation mode ("snappy" for speed, "spicy" for quality)
func (c *OllamaClient) SetMode(mode string) {
	if mode == "spicy" || mode == "snappy" {
		c.Mode = mode
	}
}

// getOptionsForMode returns optimized generation options based on mode
func (c *OllamaClient) getOptionsForMode() *GenerateOptions {
	if c.Mode == "spicy" {
		// Spicy mode: richer responses, more creative, willing to wait
		return &GenerateOptions{
			NumPredict:  80,   // Longer responses
			NumCtx:      1024, // Rich context window
			Temperature: 0.85, // More creative
		}
	}
	// Snappy mode (default): fast and punchy
	return &GenerateOptions{
		NumPredict:  40,  // Short and punchy
		NumCtx:      256, // Minimal context
		Temperature: 0.6, // Faster convergence
	}
}

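// Generate sends a single non-streaming completion request to /api/generate
// and returns the model's response text.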
func (c *OllamaClient) Generate(ctx context.Context, prompt string) (string, error) {
	u, err := url.JoinPath(c.BaseURL, "/api/generate")
	if err != nil {
		return "", fmt.Errorf("invalid base URL: %w", err)
	}

	// Keep the model loaded between requests; longer in quality mode.
	keepAlive := "5m"
	if c.Mode == "spicy" {
		keepAlive = "15m" // Keep model warm longer for quality mode
	}

	req := GenerateRequest{
		Model:     c.Model,
		Prompt:    prompt,
		Stream:    false,
		KeepAlive: keepAlive,
		Options:   c.getOptionsForMode(),
	}

	reqBody, err := json.Marshal(req)
	if err != nil {
		return "", fmt.Errorf("failed to marshal request: %w", err)
	}

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, u, bytes.NewReader(reqBody))
	if err != nil {
		return "", fmt.Errorf("failed to create request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")

	resp, err := c.client.Do(httpReq)
	if err != nil {
		return "", fmt.Errorf("failed to send request to %s: %w", c.BaseURL, err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// A single Read may return a partial body; read up to 512 bytes for the error message.
		bodyBytes, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
		return "", fmt.Errorf("ollama API returned status %d: %s", resp.StatusCode, string(bodyBytes))
	}

	var genResp GenerateResponse
	if err := json.NewDecoder(resp.Body).Decode(&genResp); err != nil {
		return "", fmt.Errorf("failed to decode response: %w", err)
	}

	return genResp.Response, nil
}

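// IsAvailable reports whether the Ollama server responds on /api/version within 5 seconds.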
func (c *OllamaClient) IsAvailable() bool {
	u, err := url.JoinPath(c.BaseURL, "/api/version")
	if err != nil {
		return false
	}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
	if err != nil {
		return false
	}

	resp, err := c.client.Do(req)
	if err != nil {
		return false
	}
	defer resp.Body.Close()

	return resp.StatusCode == http.StatusOK
}

// WarmupModel preloads the model to avoid cold start delays
func (c *OllamaClient) WarmupModel() error {
	u, err := url.JoinPath(c.BaseURL, "/api/generate")
	if err != nil {
		return fmt.Errorf("invalid base URL: %w", err)
	}

	req := GenerateRequest{
		Model:     c.Model,
		Prompt:    "Say OK", // Minimal prompt to load model
		Stream:    false,
		KeepAlive: "10m",
		Options: &GenerateOptions{
			NumPredict: 5,   // Minimal output for warmup
			NumCtx:     256, // Minimal context
		},
	}

	reqBody, err := json.Marshal(req)
	if err != nil {
		return fmt.Errorf("failed to marshal request: %w", err)
	}

	// Use a longer deadline for the initial model load; stays under the client's 120s cap.
	ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second)
	defer cancel()

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, u, bytes.NewReader(reqBody))
	if err != nil {
		return fmt.Errorf("failed to create request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")

	resp, err := c.client.Do(httpReq)
	if err != nil {
		return fmt.Errorf("failed to warm up model: %w", err)
	}
	defer resp.Body.Close()

	// Surface errors such as a missing model instead of silently succeeding.
	if resp.StatusCode != http.StatusOK {
		bodyBytes, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
		return fmt.Errorf("warmup returned status %d: %s", resp.StatusCode, string(bodyBytes))
	}

	return nil
}