fork of indigo with slightly nicer lexgen
at main 6.3 kB view raw
1package visual 2 3import ( 4 "bytes" 5 "context" 6 "encoding/json" 7 "fmt" 8 "io" 9 "log/slog" 10 "mime/multipart" 11 "net/http" 12 "time" 13 14 lexutil "github.com/bluesky-social/indigo/lex/util" 15 "github.com/bluesky-social/indigo/util" 16 17 "github.com/carlmjohnson/versioninfo" 18) 19 20type HiveAIClient struct { 21 Client http.Client 22 ApiToken string 23 24 PreScreenClient *PreScreenClient 25} 26 27// schema: https://docs.thehive.ai/reference/classification 28type HiveAIResp struct { 29 Status []HiveAIResp_Status `json:"status"` 30} 31 32type HiveAIResp_Status struct { 33 Response HiveAIResp_Response `json:"response"` 34} 35 36type HiveAIResp_Response struct { 37 Output []HiveAIResp_Out `json:"output"` 38} 39 40type HiveAIResp_Out struct { 41 Time float64 `json:"time"` 42 Classes []HiveAIResp_Class `json:"classes"` 43} 44 45type HiveAIResp_Class struct { 46 Class string `json:"class"` 47 Score float64 `json:"score"` 48} 49 50func NewHiveAIClient(token string) HiveAIClient { 51 return HiveAIClient{ 52 Client: *util.RobustHTTPClient(), 53 ApiToken: token, 54 } 55} 56 57// Simple direct mappings from individual classes to to labels 58// 59// hive gore and violence: https://docs.thehive.ai/docs/class-descriptions-violence-gore 60func summarizeSimpleLabels(cl []HiveAIResp_Class) []string { 61 var labels []string 62 63 for _, cls := range cl { 64 if cls.Class == "very_bloody" && cls.Score >= 0.90 { 65 labels = append(labels, "graphic-media") 66 } 67 if cls.Class == "human_corpse" && cls.Score >= 0.90 { 68 labels = append(labels, "graphic-media") 69 } 70 if cls.Class == "hanging" && cls.Score >= 0.90 { 71 labels = append(labels, "graphic-media") 72 } 73 if cls.Class == "yes_self_harm" && cls.Score >= 0.96 { 74 labels = append(labels, "self-harm") 75 } 76 } 77 return labels 78} 79 80// Matches only one (or none) of: porn, sexual, nudity 81// 82// porn: sexual and nudity. including both explicit activity or full-frontal and suggestive/intent 83// sexual: sexually suggestive, not explicit; may include some forms of nudity 84// nudity: non-sexual nudity (eg, artistic, possibly some photographic) 85// 86// hive docs/definitions: https://docs.thehive.ai/docs/sexual-content 87func summarizeSexualLabels(cl []HiveAIResp_Class) string { 88 89 scores := make(map[string]float64) 90 for _, cls := range cl { 91 scores[cls.Class] = cls.Score 92 } 93 94 threshold := 0.90 95 96 // if this is furry art content, then require very high confidence when flagging for any sexual reason 97 // note that this is a custom model, not always returned in generic Hive responses 98 if furryScore, ok := scores["furry-yes_furry"]; ok && furryScore > 0.95 { 99 threshold = 0.99 100 } 101 102 // first check if porn... 103 for _, pornClass := range []string{"yes_sexual_activity", "animal_genitalia_and_human", "yes_realistic_nsfw"} { 104 if scores[pornClass] >= threshold { 105 return "porn" 106 } 107 } 108 if scores["general_nsfw"] >= threshold { 109 // special case for some anime examples 110 if scores["animated_animal_genitalia"] >= 0.5 { 111 return "porn" 112 } 113 114 // special case for some pornographic/explicit classic drawings 115 if scores["yes_undressed"] >= threshold && scores["yes_sexual_activity"] >= threshold { 116 return "porn" 117 } 118 } 119 120 // then check for sexual suggestive (which may include nudity)... 121 for _, sexualClass := range []string{"yes_sexual_intent", "yes_sex_toy"} { 122 if scores[sexualClass] >= threshold { 123 return "sexual" 124 } 125 } 126 if scores["yes_undressed"] >= threshold { 127 // special case for bondage examples 128 if scores["yes_sex_toy"] > 0.75 { 129 return "sexual" 130 } 131 } 132 133 // then non-sexual nudity... 134 for _, nudityClass := range []string{"yes_male_nudity", "yes_female_nudity", "yes_undressed"} { 135 if scores[nudityClass] >= threshold { 136 return "nudity" 137 } 138 } 139 140 // then finally flag remaining "underwear" images in to sexually suggestive 141 // (after non-sexual content already labeled above) 142 for _, underwearClass := range []string{"yes_male_underwear", "yes_female_underwear"} { 143 // TODO: experimenting with higher threshhold during traffic spike 144 //if scores[underwearClass] >= threshold { 145 if scores[underwearClass] >= 0.98 { 146 return "sexual" 147 } 148 } 149 150 return "" 151} 152 153func (resp *HiveAIResp) SummarizeLabels() []string { 154 var labels []string 155 156 for _, status := range resp.Status { 157 for _, out := range status.Response.Output { 158 simple := summarizeSimpleLabels(out.Classes) 159 if len(simple) > 0 { 160 labels = append(labels, simple...) 161 } 162 163 sexual := summarizeSexualLabels(out.Classes) 164 if sexual != "" { 165 labels = append(labels, sexual) 166 } 167 } 168 } 169 170 return labels 171} 172 173func (hal *HiveAIClient) LabelBlob(ctx context.Context, blob lexutil.LexBlob, blobBytes []byte) ([]string, error) { 174 175 slog.Debug("sending blob to Hive AI", "cid", blob.Ref.String(), "mimetype", blob.MimeType, "size", len(blobBytes)) 176 177 // generic HTTP form file upload, then parse the response JSON 178 body := &bytes.Buffer{} 179 writer := multipart.NewWriter(body) 180 part, err := writer.CreateFormFile("media", blob.Ref.String()) 181 if err != nil { 182 return nil, err 183 } 184 _, err = part.Write(blobBytes) 185 if err != nil { 186 return nil, err 187 } 188 err = writer.Close() 189 if err != nil { 190 return nil, err 191 } 192 193 req, err := http.NewRequest("POST", "https://api.thehive.ai/api/v2/task/sync", body) 194 if err != nil { 195 return nil, err 196 } 197 198 start := time.Now() 199 defer func() { 200 duration := time.Since(start) 201 hiveAPIDuration.Observe(duration.Seconds()) 202 }() 203 204 req.Header.Set("Authorization", fmt.Sprintf("Token %s", hal.ApiToken)) 205 req.Header.Add("Content-Type", writer.FormDataContentType()) 206 req.Header.Set("Accept", "application/json") 207 req.Header.Set("User-Agent", "indigo-automod/"+versioninfo.Short()) 208 209 req = req.WithContext(ctx) 210 res, err := hal.Client.Do(req) 211 if err != nil { 212 return nil, fmt.Errorf("HiveAI request failed: %v", err) 213 } 214 defer res.Body.Close() 215 216 hiveAPICount.WithLabelValues(fmt.Sprint(res.StatusCode)).Inc() 217 if res.StatusCode != 200 { 218 return nil, fmt.Errorf("HiveAI request failed statusCode=%d", res.StatusCode) 219 } 220 221 respBytes, err := io.ReadAll(res.Body) 222 if err != nil { 223 return nil, fmt.Errorf("failed to read HiveAI resp body: %v", err) 224 } 225 226 var respObj HiveAIResp 227 if err := json.Unmarshal(respBytes, &respObj); err != nil { 228 return nil, fmt.Errorf("failed to parse HiveAI resp JSON: %v", err) 229 } 230 slog.Info("hive-ai-response", "cid", blob.Ref.String(), "obj", respObj) 231 return respObj.SummarizeLabels(), nil 232}