1package search
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "fmt"
8 "io/ioutil"
9 "log/slog"
10 "strings"
11
12 "github.com/bluesky-social/indigo/atproto/identity"
13 "github.com/bluesky-social/indigo/atproto/syntax"
14
15 es "github.com/opensearch-project/opensearch-go/v2"
16 "go.opentelemetry.io/otel/attribute"
17)
18
// EsSearchHit is a single document hit within an elasticsearch/opensearch
// search response.
type EsSearchHit struct {
	Index string `json:"_index"` // index the document was found in
	ID    string `json:"_id"`
	Score float64 `json:"_score"` // relevancy score assigned by the search engine
	// Source is the raw indexed document; decoding is deferred to the caller
	Source json.RawMessage `json:"_source"`
}
25
// EsSearchHits is the "hits" section of an elasticsearch/opensearch search
// response: overall match metadata plus the page of individual hits.
type EsSearchHits struct {
	// Total describes the overall number of matching documents (not used)
	Total struct {
		Value    int
		Relation string // eg "eq" or "gte" when the count is a lower bound
	} `json:"total"`
	MaxScore float64       `json:"max_score"`
	Hits     []EsSearchHit `json:"hits"`
}
34
// EsSearchResponse is the top-level shape of an elasticsearch/opensearch
// search response, as decoded by doSearch.
type EsSearchResponse struct {
	Took     int          `json:"took"`      // server-side query time, milliseconds
	TimedOut bool         `json:"timed_out"`
	Hits     EsSearchHits `json:"hits"`
}
40
// UserResult identifies the account that authored a search result.
type UserResult struct {
	Did    string `json:"did"`
	Handle string `json:"handle"`
}
45
// PostSearchResult is a single post hit, as returned to API callers.
type PostSearchResult struct {
	Tid  string     `json:"tid"`
	Cid  string     `json:"cid"`
	User UserResult `json:"user"`
	// Post is the post record itself; left untyped here, schema depends on
	// what was indexed
	Post any `json:"post"`
}
52
// PostSearchParams holds all parameters for a post search: the free-text
// query plus structured filters. Pointer fields are nil when unset, which
// Update and Filters rely on.
type PostSearchParams struct {
	Query string `json:"q"`
	// NOTE(review): Sort is not referenced in this file's query construction
	// (DoSearchPosts always sorts by created_at desc) — confirm where it is
	// applied
	Sort     string           `json:"sort"`
	Author   *syntax.DID      `json:"author"`
	Since    *syntax.Datetime `json:"since"`
	Until    *syntax.Datetime `json:"until"`
	Mentions *syntax.DID      `json:"mentions"`
	Lang     *syntax.Language `json:"lang"`
	Domain   string           `json:"domain"`
	URL      string           `json:"url"`
	Tags     []string         `json:"tag"`
	// Viewer is the DID of the account performing the search, passed to
	// query parsing
	Viewer *syntax.DID `json:"viewer"`
	Offset int         `json:"offset"` // pagination start; bounds-checked by checkParams
	Size   int         `json:"size"`   // page size; bounds-checked by checkParams
}
68
// ActorSearchParams holds all parameters for a profile (actor) search.
type ActorSearchParams struct {
	Query string `json:"q"`
	// NOTE(review): Typeahead is not read in this file — presumably selects
	// the typeahead search path upstream; confirm
	Typeahead bool `json:"typeahead"`
	// Follows restricts results to these accounts (see Filters)
	Follows []syntax.DID `json:"follows"`
	Viewer  *syntax.DID  `json:"viewer"`
	Offset  int          `json:"offset"` // pagination start; bounds-checked by checkParams
	Size    int          `json:"size"`   // page size; bounds-checked by checkParams
}
77
78// Merges params from another param object in to this one. Intended to meld parsed query with HTTP query params, so not all functionality is supported, and priority is with the "current" object
79func (p *PostSearchParams) Update(other *PostSearchParams) {
80 p.Query = other.Query
81 if p.Author == nil {
82 p.Author = other.Author
83 }
84 if p.Since == nil {
85 p.Since = other.Since
86 }
87 if p.Until == nil {
88 p.Until = other.Until
89 }
90 if p.Mentions == nil {
91 p.Mentions = other.Mentions
92 }
93 if p.Lang == nil {
94 p.Lang = other.Lang
95 }
96 if p.Domain == "" {
97 p.Domain = other.Domain
98 }
99 if p.URL == "" {
100 p.URL = other.URL
101 }
102 if len(p.Tags) == 0 {
103 p.Tags = other.Tags
104 }
105}
106
107// Filters turns search params in to actual elasticsearch/opensearch filter DSL
108func (p *PostSearchParams) Filters() []map[string]interface{} {
109 var filters []map[string]interface{}
110
111 if p.Author != nil {
112 filters = append(filters, map[string]interface{}{
113 "term": map[string]interface{}{"did": map[string]interface{}{
114 "value": p.Author.String(),
115 "case_insensitive": true,
116 }},
117 })
118 }
119
120 if p.Mentions != nil {
121 filters = append(filters, map[string]interface{}{
122 "term": map[string]interface{}{"mention_did": map[string]interface{}{
123 "value": p.Mentions.String(),
124 "case_insensitive": true,
125 }},
126 })
127 }
128
129 if p.Lang != nil {
130 // TODO: extracting just the 2-char code would be good
131 filters = append(filters, map[string]interface{}{
132 "term": map[string]interface{}{"lang_code_iso2": map[string]interface{}{
133 "value": p.Lang.String(),
134 "case_insensitive": true,
135 }},
136 })
137 }
138
139 if p.Since != nil {
140 filters = append(filters, map[string]interface{}{
141 "range": map[string]interface{}{
142 "created_at": map[string]interface{}{
143 "gte": p.Since.String(),
144 },
145 },
146 })
147 }
148
149 if p.Until != nil {
150 filters = append(filters, map[string]interface{}{
151 "range": map[string]interface{}{
152 "created_at": map[string]interface{}{
153 "lt": p.Until.String(),
154 },
155 },
156 })
157 }
158
159 if p.URL != "" {
160 filters = append(filters, map[string]interface{}{
161 "term": map[string]interface{}{"url": map[string]interface{}{
162 "value": NormalizeLossyURL(p.URL),
163 "case_insensitive": true,
164 }},
165 })
166 }
167
168 if p.Domain != "" {
169 filters = append(filters, map[string]interface{}{
170 "term": map[string]interface{}{"domain": map[string]interface{}{
171 "value": p.Domain,
172 "case_insensitive": true,
173 }},
174 })
175 }
176
177 for _, tag := range p.Tags {
178 filters = append(filters, map[string]interface{}{
179 "term": map[string]interface{}{
180 "tag": map[string]interface{}{
181 "value": tag,
182 "case_insensitive": true,
183 },
184 },
185 })
186 }
187
188 return filters
189}
190
191// Filters turns search params in to actual elasticsearch/opensearch filter DSL
192func (p *ActorSearchParams) Filters() []map[string]interface{} {
193 var filters []map[string]interface{}
194
195 if p.Follows != nil && len(p.Follows) > 0 {
196 follows := make([]string, len(p.Follows))
197 for i, did := range p.Follows {
198 follows[i] = did.String()
199 }
200 filters = append(filters, map[string]interface{}{
201 "terms": map[string]interface{}{
202 "did": follows,
203 },
204 })
205 }
206
207 return filters
208}
209
// checkParams validates pagination parameters, rejecting negative values,
// oversized pages, and windows that reach past the search engine's 10k
// result-window limit.
func checkParams(offset, size int) error {
	switch {
	case offset < 0 || size < 0:
		return fmt.Errorf("disallowed size/offset parameters")
	case size > 250 || offset > 10000:
		return fmt.Errorf("disallowed size/offset parameters")
	case offset+size > 10000:
		return fmt.Errorf("disallowed size/offset parameters")
	}
	return nil
}
216
217func DoSearchPosts(ctx context.Context, dir identity.Directory, escli *es.Client, index string, params *PostSearchParams) (*EsSearchResponse, error) {
218 ctx, span := tracer.Start(ctx, "DoSearchPosts")
219 defer span.End()
220
221 if err := checkParams(params.Offset, params.Size); err != nil {
222 return nil, err
223 }
224 queryStringParams := ParsePostQuery(ctx, dir, params.Query, params.Viewer)
225 params.Update(&queryStringParams)
226 idx := "everything"
227 if containsJapanese(params.Query) {
228 idx = "everything_ja"
229 }
230 basic := map[string]interface{}{
231 "simple_query_string": map[string]interface{}{
232 "query": params.Query,
233 "fields": []string{idx},
234 "flags": "AND|NOT|OR|PHRASE|PRECEDENCE|WHITESPACE",
235 "default_operator": "and",
236 "lenient": true,
237 "analyze_wildcard": false,
238 },
239 }
240 filters := params.Filters()
241 // filter out future posts (TODO: temporary hack)
242 now := syntax.DatetimeNow()
243 filters = append(filters, map[string]interface{}{
244 "range": map[string]interface{}{
245 "created_at": map[string]interface{}{
246 "lte": now,
247 },
248 },
249 })
250 query := map[string]interface{}{
251 "query": map[string]interface{}{
252 "bool": map[string]interface{}{
253 "must": basic,
254 "filter": filters,
255 },
256 },
257 "sort": map[string]any{
258 "created_at": map[string]any{
259 "order": "desc",
260 },
261 },
262 "size": params.Size,
263 "from": params.Offset,
264 }
265
266 return doSearch(ctx, escli, index, query)
267}
268
269func DoSearchProfiles(ctx context.Context, dir identity.Directory, escli *es.Client, index string, params *ActorSearchParams) (*EsSearchResponse, error) {
270 ctx, span := tracer.Start(ctx, "DoSearchProfiles")
271 defer span.End()
272
273 if err := checkParams(params.Offset, params.Size); err != nil {
274 return nil, err
275 }
276
277 filters := params.Filters()
278
279 fulltext := map[string]interface{}{
280 "simple_query_string": map[string]interface{}{
281 "query": params.Query,
282 "fields": []string{"everything"},
283 "flags": "AND|NOT|OR|PHRASE|PRECEDENCE|WHITESPACE",
284 "default_operator": "and",
285 "lenient": true,
286 "analyze_wildcard": false,
287 },
288 }
289 primary := fulltext
290
291 // if the query string is just a single token (after parsing out filter
292 // syntax), then have the primary query be an "OR" of the basic fulltext
293 // query and the typeahead query
294 if len(strings.Split(params.Query, " ")) == 1 {
295 typeahead := map[string]interface{}{
296 "multi_match": map[string]interface{}{
297 "query": params.Query,
298 "type": "bool_prefix",
299 "operator": "and",
300 "fields": []string{
301 "typeahead",
302 "typeahead._2gram",
303 "typeahead._3gram",
304 },
305 },
306 }
307 primary = map[string]interface{}{
308 "bool": map[string]interface{}{
309 "should": []interface{}{
310 fulltext,
311 typeahead,
312 },
313 },
314 }
315 }
316
317 query := map[string]interface{}{
318 "query": map[string]interface{}{
319 "bool": map[string]interface{}{
320 "must": primary,
321 "should": []interface{}{
322 map[string]interface{}{"term": map[string]interface{}{"has_avatar": true}},
323 map[string]interface{}{"term": map[string]interface{}{"has_banner": true}},
324 },
325 "minimum_should_match": 0,
326 "boost": 0.5,
327 },
328 },
329 "size": params.Size,
330 "from": params.Offset,
331 }
332
333 if len(filters) > 0 {
334 query["query"].(map[string]interface{})["bool"].(map[string]interface{})["filter"] = filters
335 }
336
337 return doSearch(ctx, escli, index, query)
338}
339
340func DoSearchProfilesTypeahead(ctx context.Context, escli *es.Client, index string, params *ActorSearchParams) (*EsSearchResponse, error) {
341 ctx, span := tracer.Start(ctx, "DoSearchProfilesTypeahead")
342 defer span.End()
343
344 if err := checkParams(0, params.Size); err != nil {
345 return nil, err
346 }
347
348 filters := params.Filters()
349
350 query := map[string]interface{}{
351 "query": map[string]interface{}{
352 "bool": map[string]interface{}{
353 "must": map[string]interface{}{
354 "multi_match": map[string]interface{}{
355 "query": params.Query,
356 "type": "bool_prefix",
357 "operator": "and",
358 "fields": []string{
359 "typeahead",
360 "typeahead._2gram",
361 "typeahead._3gram",
362 },
363 },
364 },
365 },
366 },
367 "size": params.Size,
368 "from": params.Offset,
369 }
370
371 if len(filters) > 0 {
372 query["query"].(map[string]interface{})["bool"].(map[string]interface{})["filter"] = filters
373 }
374
375 return doSearch(ctx, escli, index, query)
376}
377
378// helper to do a full-featured Lucene query parser (query_string) search, with all possible facets. Not safe to expose publicly.
379func DoSearchGeneric(ctx context.Context, escli *es.Client, index, q string) (*EsSearchResponse, error) {
380 ctx, span := tracer.Start(ctx, "DoSearchGeneric")
381 defer span.End()
382
383 query := map[string]interface{}{
384 "query": map[string]interface{}{
385 "query_string": map[string]interface{}{
386 "query": q,
387 "default_operator": "and",
388 "analyze_wildcard": true,
389 "allow_leading_wildcard": false,
390 "lenient": true,
391 "default_field": "everything",
392 },
393 },
394 }
395
396 return doSearch(ctx, escli, index, query)
397}
398
399func doSearch(ctx context.Context, escli *es.Client, index string, query interface{}) (*EsSearchResponse, error) {
400 ctx, span := tracer.Start(ctx, "doSearch")
401 defer span.End()
402
403 span.SetAttributes(attribute.String("index", index), attribute.String("query", fmt.Sprintf("%+v", query)))
404
405 b, err := json.Marshal(query)
406 if err != nil {
407 return nil, fmt.Errorf("failed to serialize query: %w", err)
408 }
409 slog.Info("sending query", "index", index, "query", string(b))
410
411 // Perform the search request.
412 res, err := escli.Search(
413 escli.Search.WithContext(ctx),
414 escli.Search.WithIndex(index),
415 escli.Search.WithBody(bytes.NewBuffer(b)),
416 )
417 if err != nil {
418 return nil, fmt.Errorf("search query error: %w", err)
419 }
420 defer res.Body.Close()
421 if res.IsError() {
422 raw, err := ioutil.ReadAll(res.Body)
423 if nil == err {
424 slog.Warn("search query error", "resp", string(raw), "status_code", res.StatusCode)
425 }
426 return nil, fmt.Errorf("search query error, code=%d", res.StatusCode)
427 }
428
429 var out EsSearchResponse
430 if err := json.NewDecoder(res.Body).Decode(&out); err != nil {
431 return nil, fmt.Errorf("decoding search response: %w", err)
432 }
433
434 return &out, nil
435}