Margin is an open annotation layer for the internet. Powered by the AT Protocol. margin.at
extension web atproto comments
at main 220 lines 4.9 kB view raw
1package embeddings 2 3import ( 4 "bytes" 5 "encoding/json" 6 "fmt" 7 "io" 8 "net/http" 9 "os" 10 "strings" 11 "sync" 12 "time" 13 14 "margin.at/internal/logger" 15) 16 17const ( 18 Model = "text-embedding-3-small" 19 Dimensions = 1536 20 MaxTokens = 8191 21 MaxInputChars = 8000 22 BatchSize = 64 23 openAIEndpoint = "https://api.openai.com/v1/embeddings" 24) 25 26type Client struct { 27 apiKey string 28 httpClient *http.Client 29 mu sync.Mutex 30} 31 32type embeddingRequest struct { 33 Model string `json:"model"` 34 Input []string `json:"input"` 35 Dimensions int `json:"dimensions,omitempty"` 36} 37 38type embeddingResponse struct { 39 Data []embeddingData `json:"data"` 40 Usage struct { 41 TotalTokens int `json:"total_tokens"` 42 } `json:"usage"` 43 Error *struct { 44 Message string `json:"message"` 45 } `json:"error,omitempty"` 46} 47 48type embeddingData struct { 49 Index int `json:"index"` 50 Embedding []float32 `json:"embedding"` 51} 52 53func NewClient() *Client { 54 apiKey := os.Getenv("OPENAI_API_KEY") 55 if apiKey == "" { 56 logger.Info("OPENAI_API_KEY not set — embedding generation will be disabled") 57 } 58 return &Client{ 59 apiKey: apiKey, 60 httpClient: &http.Client{ 61 Timeout: 30 * time.Second, 62 }, 63 } 64} 65 66func (c *Client) IsEnabled() bool { 67 return c.apiKey != "" 68} 69 70func (c *Client) Embed(text string) ([]float32, error) { 71 results, err := c.EmbedBatch([]string{text}) 72 if err != nil { 73 return nil, err 74 } 75 if len(results) == 0 { 76 return nil, fmt.Errorf("empty embedding response") 77 } 78 return results[0], nil 79} 80 81func (c *Client) EmbedBatch(texts []string) ([][]float32, error) { 82 if !c.IsEnabled() { 83 return nil, fmt.Errorf("OpenAI API key not configured") 84 } 85 86 truncated := make([]string, len(texts)) 87 for i, t := range texts { 88 t = truncateText(t, MaxInputChars) 89 if strings.TrimSpace(t) == "" { 90 t = " " 91 } 92 truncated[i] = t 93 } 94 95 results := make([][]float32, len(texts)) 96 97 for start := 0; start < len(truncated); start += BatchSize { 98 end := start + BatchSize 99 if end > len(truncated) { 100 end = len(truncated) 101 } 102 batch := truncated[start:end] 103 104 embeddings, err := c.callAPI(batch) 105 if err != nil { 106 return nil, fmt.Errorf("embedding batch %d-%d failed: %w", start, end, err) 107 } 108 109 for _, emb := range embeddings { 110 idx := start + emb.Index 111 if idx < len(results) { 112 results[idx] = emb.Embedding 113 } 114 } 115 } 116 117 return results, nil 118} 119 120func (c *Client) callAPI(inputs []string) ([]embeddingData, error) { 121 reqBody := embeddingRequest{ 122 Model: Model, 123 Input: inputs, 124 } 125 126 body, err := json.Marshal(reqBody) 127 if err != nil { 128 return nil, fmt.Errorf("marshal request: %w", err) 129 } 130 131 req, err := http.NewRequest("POST", openAIEndpoint, bytes.NewReader(body)) 132 if err != nil { 133 return nil, fmt.Errorf("create request: %w", err) 134 } 135 req.Header.Set("Content-Type", "application/json") 136 req.Header.Set("Authorization", "Bearer "+c.apiKey) 137 138 resp, err := c.httpClient.Do(req) 139 if err != nil { 140 return nil, fmt.Errorf("API request: %w", err) 141 } 142 defer resp.Body.Close() 143 144 respBody, err := io.ReadAll(resp.Body) 145 if err != nil { 146 return nil, fmt.Errorf("read response: %w", err) 147 } 148 149 if resp.StatusCode != http.StatusOK { 150 return nil, fmt.Errorf("API returned %d: %s", resp.StatusCode, string(respBody)) 151 } 152 153 var result embeddingResponse 154 if err := json.Unmarshal(respBody, &result); err != nil { 155 return nil, fmt.Errorf("unmarshal response: %w", err) 156 } 157 158 if result.Error != nil { 159 return nil, fmt.Errorf("API error: %s", result.Error.Message) 160 } 161 162 return result.Data, nil 163} 164 165func truncateText(text string, maxChars int) string { 166 if len(text) <= maxChars { 167 return text 168 } 169 return text[:maxChars] 170} 171 172func BuildAnnotationText(bodyValue, selectorJSON, targetTitle, tagsJSON *string) string { 173 var parts []string 174 175 if selectorJSON != nil && *selectorJSON != "" { 176 var selector struct { 177 Exact string `json:"exact"` 178 Prefix string `json:"prefix"` 179 Suffix string `json:"suffix"` 180 } 181 if err := json.Unmarshal([]byte(*selectorJSON), &selector); err == nil && selector.Exact != "" { 182 parts = append(parts, selector.Exact) 183 } 184 } 185 186 if bodyValue != nil && *bodyValue != "" { 187 parts = append(parts, *bodyValue) 188 } 189 190 if targetTitle != nil && *targetTitle != "" { 191 parts = append(parts, *targetTitle) 192 } 193 194 if tagsJSON != nil && *tagsJSON != "" { 195 var tags []string 196 if err := json.Unmarshal([]byte(*tagsJSON), &tags); err == nil && len(tags) > 0 { 197 parts = append(parts, strings.Join(tags, ", ")) 198 } 199 } 200 201 return strings.Join(parts, " | ") 202} 203 204func BuildDocumentText(title, description, textContent string, tags []string) string { 205 var parts []string 206 207 parts = append(parts, title) 208 209 if len(tags) > 0 { 210 parts = append(parts, strings.Join(tags, ", ")) 211 } 212 213 if textContent != "" { 214 parts = append(parts, textContent) 215 } else if description != "" { 216 parts = append(parts, description) 217 } 218 219 return strings.Join(parts, " | ") 220}