Margin is an open annotation layer for the internet. Powered by the AT Protocol.
margin.at
extension
web
atproto
comments
1package embeddings
2
import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"os"
	"strings"
	"sync"
	"time"
	"unicode/utf8"

	"margin.at/internal/logger"
)
16
// Settings for the embeddings client.
17const (
// Model is the model name placed in every embeddings request.
18 Model = "text-embedding-3-small"
// Dimensions is presumably the vector length produced by Model; it is not
// referenced by the visible code (requests leave the dimensions field unset)
// — TODO confirm against callers.
19 Dimensions = 1536
// MaxTokens is presumably the model's input token limit; it is not
// referenced by the visible code — TODO confirm.
20 MaxTokens = 8191
// MaxInputChars is the limit passed to truncateText for each input before it
// is sent. truncateText compares it against len(text), i.e. a byte count.
21 MaxInputChars = 8000
// BatchSize is the maximum number of inputs sent in a single API call.
22 BatchSize = 64
// openAIEndpoint is the URL that embedding requests are POSTed to.
23 openAIEndpoint = "https://api.openai.com/v1/embeddings"
24)
25
// Client requests text embeddings from the OpenAI embeddings endpoint.
// Construct it with NewClient.
26type Client struct {
// apiKey is sent as the Bearer token on each request; when it is empty the
// client reports itself as disabled.
27 apiKey string
// httpClient performs the HTTP calls.
28 httpClient *http.Client
// NOTE(review): mu is not used by any code visible in this file — confirm
// whether it guards something elsewhere or is left over.
29 mu sync.Mutex
30}
31
// embeddingRequest is the JSON body sent to the embeddings endpoint.
32type embeddingRequest struct {
// Model is the name of the embedding model to use.
33 Model string `json:"model"`
// Input is the list of texts to embed in this call.
34 Input []string `json:"input"`
// Dimensions is omitted from the JSON when zero.
35 Dimensions int `json:"dimensions,omitempty"`
36}
37
// embeddingResponse is the JSON body decoded from the embeddings endpoint.
38type embeddingResponse struct {
// Data holds one entry per embedded input.
39 Data []embeddingData `json:"data"`
// Usage carries the token count reported by the API; the visible code
// decodes it but does not read it.
40 Usage struct {
41 TotalTokens int `json:"total_tokens"`
42 } `json:"usage"`
// Error is non-nil when the response body contains an "error" object.
43 Error *struct {
44 Message string `json:"message"`
45 } `json:"error,omitempty"`
46}
47
// embeddingData is a single embedding entry from the API response.
48type embeddingData struct {
// Index is used by EmbedBatch as the position of the corresponding input
// within the batch that was sent.
49 Index int `json:"index"`
// Embedding is the embedding vector for that input.
50 Embedding []float32 `json:"embedding"`
51}
52
53func NewClient() *Client {
54 apiKey := os.Getenv("OPENAI_API_KEY")
55 if apiKey == "" {
56 logger.Info("OPENAI_API_KEY not set — embedding generation will be disabled")
57 }
58 return &Client{
59 apiKey: apiKey,
60 httpClient: &http.Client{
61 Timeout: 30 * time.Second,
62 },
63 }
64}
65
66func (c *Client) IsEnabled() bool {
67 return c.apiKey != ""
68}
69
70func (c *Client) Embed(text string) ([]float32, error) {
71 results, err := c.EmbedBatch([]string{text})
72 if err != nil {
73 return nil, err
74 }
75 if len(results) == 0 {
76 return nil, fmt.Errorf("empty embedding response")
77 }
78 return results[0], nil
79}
80
81func (c *Client) EmbedBatch(texts []string) ([][]float32, error) {
82 if !c.IsEnabled() {
83 return nil, fmt.Errorf("OpenAI API key not configured")
84 }
85
86 truncated := make([]string, len(texts))
87 for i, t := range texts {
88 t = truncateText(t, MaxInputChars)
89 if strings.TrimSpace(t) == "" {
90 t = " "
91 }
92 truncated[i] = t
93 }
94
95 results := make([][]float32, len(texts))
96
97 for start := 0; start < len(truncated); start += BatchSize {
98 end := start + BatchSize
99 if end > len(truncated) {
100 end = len(truncated)
101 }
102 batch := truncated[start:end]
103
104 embeddings, err := c.callAPI(batch)
105 if err != nil {
106 return nil, fmt.Errorf("embedding batch %d-%d failed: %w", start, end, err)
107 }
108
109 for _, emb := range embeddings {
110 idx := start + emb.Index
111 if idx < len(results) {
112 results[idx] = emb.Embedding
113 }
114 }
115 }
116
117 return results, nil
118}
119
120func (c *Client) callAPI(inputs []string) ([]embeddingData, error) {
121 reqBody := embeddingRequest{
122 Model: Model,
123 Input: inputs,
124 }
125
126 body, err := json.Marshal(reqBody)
127 if err != nil {
128 return nil, fmt.Errorf("marshal request: %w", err)
129 }
130
131 req, err := http.NewRequest("POST", openAIEndpoint, bytes.NewReader(body))
132 if err != nil {
133 return nil, fmt.Errorf("create request: %w", err)
134 }
135 req.Header.Set("Content-Type", "application/json")
136 req.Header.Set("Authorization", "Bearer "+c.apiKey)
137
138 resp, err := c.httpClient.Do(req)
139 if err != nil {
140 return nil, fmt.Errorf("API request: %w", err)
141 }
142 defer resp.Body.Close()
143
144 respBody, err := io.ReadAll(resp.Body)
145 if err != nil {
146 return nil, fmt.Errorf("read response: %w", err)
147 }
148
149 if resp.StatusCode != http.StatusOK {
150 return nil, fmt.Errorf("API returned %d: %s", resp.StatusCode, string(respBody))
151 }
152
153 var result embeddingResponse
154 if err := json.Unmarshal(respBody, &result); err != nil {
155 return nil, fmt.Errorf("unmarshal response: %w", err)
156 }
157
158 if result.Error != nil {
159 return nil, fmt.Errorf("API error: %s", result.Error.Message)
160 }
161
162 return result.Data, nil
163}
164
// truncateText returns text shortened to at most maxChars bytes.
//
// The limit is measured in bytes, not runes. When the limit falls inside a
// multi-byte UTF-8 sequence the cut is moved back to the start of that
// sequence, so a valid UTF-8 input never yields an invalid UTF-8 result; the
// returned string may therefore be slightly shorter than maxChars. A
// non-positive maxChars yields the empty string.
func truncateText(text string, maxChars int) string {
	if maxChars <= 0 {
		return ""
	}
	if len(text) <= maxChars {
		return text
	}

	// text[cut] is the first byte that would be dropped. If it is a
	// continuation byte, the rune it belongs to would be split, so back up
	// until the cut sits on a rune boundary.
	cut := maxChars
	for cut > 0 && !utf8.RuneStart(text[cut]) {
		cut--
	}
	return text[:cut]
}
171
// BuildAnnotationText assembles the text to embed for an annotation.
//
// Parts are joined with " | " in this order: the selector's "exact" quote,
// the annotation body, the target title, and the tags joined with ", ".
// A nil or empty argument contributes nothing. selectorJSON is decoded as a
// JSON object and tagsJSON as a JSON array of strings; input that fails to
// decode is skipped silently.
func BuildAnnotationText(bodyValue, selectorJSON, targetTitle, tagsJSON *string) string {
	// present reports whether an optional field was supplied with content.
	present := func(s *string) bool { return s != nil && len(*s) > 0 }

	segments := make([]string, 0, 4)

	if present(selectorJSON) {
		// All three fields stay declared even though only Exact is read:
		// a type mismatch on any of them makes decoding fail, which decides
		// whether the quote is used.
		var quote struct {
			Exact  string `json:"exact"`
			Prefix string `json:"prefix"`
			Suffix string `json:"suffix"`
		}
		decodeErr := json.Unmarshal([]byte(*selectorJSON), &quote)
		if decodeErr == nil && quote.Exact != "" {
			segments = append(segments, quote.Exact)
		}
	}

	if present(bodyValue) {
		segments = append(segments, *bodyValue)
	}

	if present(targetTitle) {
		segments = append(segments, *targetTitle)
	}

	if present(tagsJSON) {
		var labels []string
		decodeErr := json.Unmarshal([]byte(*tagsJSON), &labels)
		if decodeErr == nil && len(labels) != 0 {
			segments = append(segments, strings.Join(labels, ", "))
		}
	}

	return strings.Join(segments, " | ")
}
203
// BuildDocumentText assembles the text to embed for a document.
//
// Parts are joined with " | " in this order: the title, the tags joined with
// ", ", and then the main text. textContent is used as the main text when it
// is non-empty; otherwise description is used. Empty parts are left out, so
// the result never starts with a dangling separator when title is empty.
func BuildDocumentText(title, description, textContent string, tags []string) string {
	parts := make([]string, 0, 3)

	if title != "" {
		parts = append(parts, title)
	}

	if len(tags) > 0 {
		parts = append(parts, strings.Join(tags, ", "))
	}

	// The full text takes precedence; the description is only a fallback.
	switch {
	case textContent != "":
		parts = append(parts, textContent)
	case description != "":
		parts = append(parts, description)
	}

	return strings.Join(parts, " | ")
}