1package visual
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "fmt"
8 "io"
9 "log/slog"
10 "mime/multipart"
11 "net/http"
12 "time"
13
14 lexutil "github.com/bluesky-social/indigo/lex/util"
15 "github.com/bluesky-social/indigo/util"
16
17 "github.com/carlmjohnson/versioninfo"
18)
19
// HiveAIClient sends image blobs to the Hive AI synchronous classification
// endpoint and converts the response into moderation label strings (see
// LabelBlob).
type HiveAIClient struct {
	// Client performs the HTTP request to Hive. NewHiveAIClient fills it with a
	// copy of util.RobustHTTPClient().
	Client http.Client

	// ApiToken is sent on every request as "Authorization: Token <ApiToken>".
	ApiToken string

	// PreScreenClient is not referenced in this part of the file and is left
	// nil by NewHiveAIClient.
	// NOTE(review): presumably an optional pre-screening stage consulted before
	// calling Hive — confirm against its users.
	PreScreenClient *PreScreenClient
}
26
// HiveAIResp is the part of the Hive synchronous-task JSON response that this
// package decodes. encoding/json silently ignores any response fields that are
// not declared on these types.
//
// schema: https://docs.thehive.ai/reference/classification
type HiveAIResp struct {
	Status []HiveAIResp_Status `json:"status"`
}

// HiveAIResp_Status is one entry of the top-level "status" array.
type HiveAIResp_Status struct {
	Response HiveAIResp_Response `json:"response"`
}

// HiveAIResp_Response holds the model outputs for one status entry.
type HiveAIResp_Response struct {
	Output []HiveAIResp_Out `json:"output"`
}

// HiveAIResp_Out is a single model output. SummarizeLabels maps each output's
// Classes to labels independently of the other outputs.
//
// Time is decoded but not read by SummarizeLabels or the summarize* helpers.
// NOTE(review): possibly a frame/offset timestamp for multi-frame media —
// confirm against the Hive docs.
type HiveAIResp_Out struct {
	Time    float64            `json:"time"`
	Classes []HiveAIResp_Class `json:"classes"`
}

// HiveAIResp_Class is one (class name, score) pair from a model output. The
// summarizers compare Score against thresholds such as 0.90, so it is
// presumably a confidence in [0, 1] — TODO confirm against the Hive docs.
type HiveAIResp_Class struct {
	Class string  `json:"class"`
	Score float64 `json:"score"`
}
49
50func NewHiveAIClient(token string) HiveAIClient {
51 return HiveAIClient{
52 Client: *util.RobustHTTPClient(),
53 ApiToken: token,
54 }
55}
56
57// Simple direct mappings from individual classes to to labels
58//
59// hive gore and violence: https://docs.thehive.ai/docs/class-descriptions-violence-gore
60func summarizeSimpleLabels(cl []HiveAIResp_Class) []string {
61 var labels []string
62
63 for _, cls := range cl {
64 if cls.Class == "very_bloody" && cls.Score >= 0.90 {
65 labels = append(labels, "graphic-media")
66 }
67 if cls.Class == "human_corpse" && cls.Score >= 0.90 {
68 labels = append(labels, "graphic-media")
69 }
70 if cls.Class == "hanging" && cls.Score >= 0.90 {
71 labels = append(labels, "graphic-media")
72 }
73 if cls.Class == "yes_self_harm" && cls.Score >= 0.96 {
74 labels = append(labels, "self-harm")
75 }
76 }
77 return labels
78}
79
80// Matches only one (or none) of: porn, sexual, nudity
81//
82// porn: sexual and nudity. including both explicit activity or full-frontal and suggestive/intent
83// sexual: sexually suggestive, not explicit; may include some forms of nudity
84// nudity: non-sexual nudity (eg, artistic, possibly some photographic)
85//
86// hive docs/definitions: https://docs.thehive.ai/docs/sexual-content
87func summarizeSexualLabels(cl []HiveAIResp_Class) string {
88
89 scores := make(map[string]float64)
90 for _, cls := range cl {
91 scores[cls.Class] = cls.Score
92 }
93
94 threshold := 0.90
95
96 // if this is furry art content, then require very high confidence when flagging for any sexual reason
97 // note that this is a custom model, not always returned in generic Hive responses
98 if furryScore, ok := scores["furry-yes_furry"]; ok && furryScore > 0.95 {
99 threshold = 0.99
100 }
101
102 // first check if porn...
103 for _, pornClass := range []string{"yes_sexual_activity", "animal_genitalia_and_human", "yes_realistic_nsfw"} {
104 if scores[pornClass] >= threshold {
105 return "porn"
106 }
107 }
108 if scores["general_nsfw"] >= threshold {
109 // special case for some anime examples
110 if scores["animated_animal_genitalia"] >= 0.5 {
111 return "porn"
112 }
113
114 // special case for some pornographic/explicit classic drawings
115 if scores["yes_undressed"] >= threshold && scores["yes_sexual_activity"] >= threshold {
116 return "porn"
117 }
118 }
119
120 // then check for sexual suggestive (which may include nudity)...
121 for _, sexualClass := range []string{"yes_sexual_intent", "yes_sex_toy"} {
122 if scores[sexualClass] >= threshold {
123 return "sexual"
124 }
125 }
126 if scores["yes_undressed"] >= threshold {
127 // special case for bondage examples
128 if scores["yes_sex_toy"] > 0.75 {
129 return "sexual"
130 }
131 }
132
133 // then non-sexual nudity...
134 for _, nudityClass := range []string{"yes_male_nudity", "yes_female_nudity", "yes_undressed"} {
135 if scores[nudityClass] >= threshold {
136 return "nudity"
137 }
138 }
139
140 // then finally flag remaining "underwear" images in to sexually suggestive
141 // (after non-sexual content already labeled above)
142 for _, underwearClass := range []string{"yes_male_underwear", "yes_female_underwear"} {
143 // TODO: experimenting with higher threshhold during traffic spike
144 //if scores[underwearClass] >= threshold {
145 if scores[underwearClass] >= 0.98 {
146 return "sexual"
147 }
148 }
149
150 return ""
151}
152
153func (resp *HiveAIResp) SummarizeLabels() []string {
154 var labels []string
155
156 for _, status := range resp.Status {
157 for _, out := range status.Response.Output {
158 simple := summarizeSimpleLabels(out.Classes)
159 if len(simple) > 0 {
160 labels = append(labels, simple...)
161 }
162
163 sexual := summarizeSexualLabels(out.Classes)
164 if sexual != "" {
165 labels = append(labels, sexual)
166 }
167 }
168 }
169
170 return labels
171}
172
173func (hal *HiveAIClient) LabelBlob(ctx context.Context, blob lexutil.LexBlob, blobBytes []byte) ([]string, error) {
174
175 slog.Debug("sending blob to Hive AI", "cid", blob.Ref.String(), "mimetype", blob.MimeType, "size", len(blobBytes))
176
177 // generic HTTP form file upload, then parse the response JSON
178 body := &bytes.Buffer{}
179 writer := multipart.NewWriter(body)
180 part, err := writer.CreateFormFile("media", blob.Ref.String())
181 if err != nil {
182 return nil, err
183 }
184 _, err = part.Write(blobBytes)
185 if err != nil {
186 return nil, err
187 }
188 err = writer.Close()
189 if err != nil {
190 return nil, err
191 }
192
193 req, err := http.NewRequest("POST", "https://api.thehive.ai/api/v2/task/sync", body)
194 if err != nil {
195 return nil, err
196 }
197
198 start := time.Now()
199 defer func() {
200 duration := time.Since(start)
201 hiveAPIDuration.Observe(duration.Seconds())
202 }()
203
204 req.Header.Set("Authorization", fmt.Sprintf("Token %s", hal.ApiToken))
205 req.Header.Add("Content-Type", writer.FormDataContentType())
206 req.Header.Set("Accept", "application/json")
207 req.Header.Set("User-Agent", "indigo-automod/"+versioninfo.Short())
208
209 req = req.WithContext(ctx)
210 res, err := hal.Client.Do(req)
211 if err != nil {
212 return nil, fmt.Errorf("HiveAI request failed: %v", err)
213 }
214 defer res.Body.Close()
215
216 hiveAPICount.WithLabelValues(fmt.Sprint(res.StatusCode)).Inc()
217 if res.StatusCode != 200 {
218 return nil, fmt.Errorf("HiveAI request failed statusCode=%d", res.StatusCode)
219 }
220
221 respBytes, err := io.ReadAll(res.Body)
222 if err != nil {
223 return nil, fmt.Errorf("failed to read HiveAI resp body: %v", err)
224 }
225
226 var respObj HiveAIResp
227 if err := json.Unmarshal(respBytes, &respObj); err != nil {
228 return nil, fmt.Errorf("failed to parse HiveAI resp JSON: %v", err)
229 }
230 slog.Info("hive-ai-response", "cid", blob.Ref.String(), "obj", respObj)
231 return respObj.SummarizeLabels(), nil
232}