fork of indigo with slightly nicer lexgen

add rigging for hive to pre-screen images before sending to hive (#728)

for now, just going to log if the pre-screen fails at its job

authored by Whyrusleeping and committed by GitHub 05d42107 ce365b16

Changed files
+171
automod
cmd
+2
automod/visual/hiveai_client.go
··· 20 20 type HiveAIClient struct { 21 21 Client http.Client 22 22 ApiToken string 23 + 24 + PreScreenClient *PreScreenClient 23 25 } 24 26 25 27 // schema: https://docs.thehive.ai/reference/classification
+18
automod/visual/hiveai_rule.go
··· 13 13 return nil 14 14 } 15 15 16 + var psclabel string 17 + if hal.PreScreenClient != nil { 18 + val, err := hal.PreScreenClient.PreScreenImage(c.Ctx, data) 19 + if err != nil { 20 + c.Logger.Info("prescreen-request-error", "err", err) 21 + } else { 22 + psclabel = val 23 + } 24 + } 25 + 16 26 labels, err := hal.LabelBlob(c.Ctx, blob, data) 17 27 if err != nil { 18 28 return err 29 + } 30 + 31 + if psclabel == "sfw" { 32 + if len(labels) > 0 { 33 + c.Logger.Info("prescreen-safe-failure", "uri", c.RecordOp.ATURI()) 34 + } else { 35 + c.Logger.Info("prescreen-safe-success", "uri", c.RecordOp.ATURI()) 36 + } 19 37 } 20 38 21 39 for _, l := range labels {
+130
automod/visual/prescreen.go
··· 1 + package visual 2 + 3 + import ( 4 + "bytes" 5 + "context" 6 + "encoding/json" 7 + "fmt" 8 + "mime/multipart" 9 + "net/http" 10 + "sync" 11 + "time" 12 + ) 13 + 14 + const failureThresh = 10 15 + 16 + type PreScreenClient struct { 17 + Host string 18 + Token string 19 + 20 + breakerEOL time.Time 21 + breakerLk sync.Mutex 22 + failures int 23 + 24 + c *http.Client 25 + } 26 + 27 + func NewPreScreenClient(host, token string) *PreScreenClient { 28 + c := &http.Client{ 29 + Timeout: time.Second * 5, 30 + } 31 + 32 + return &PreScreenClient{ 33 + Host: host, 34 + Token: token, 35 + c: c, 36 + } 37 + } 38 + 39 + func (c *PreScreenClient) available() bool { 40 + c.breakerLk.Lock() 41 + defer c.breakerLk.Unlock() 42 + if c.breakerEOL.IsZero() { 43 + return true 44 + } 45 + 46 + if time.Now().After(c.breakerEOL) { 47 + c.breakerEOL = time.Time{} 48 + return true 49 + } 50 + 51 + return false 52 + } 53 + 54 + func (c *PreScreenClient) recordCallResult(success bool) { 55 + c.breakerLk.Lock() 56 + defer c.breakerLk.Unlock() 57 + if !c.breakerEOL.IsZero() { 58 + return 59 + } 60 + 61 + if success { 62 + c.failures = 0 63 + } else { 64 + c.failures++ 65 + if c.failures > failureThresh { 66 + c.breakerEOL = time.Now().Add(time.Minute) 67 + c.failures = 0 68 + } 69 + } 70 + } 71 + 72 + func (c *PreScreenClient) PreScreenImage(ctx context.Context, blob []byte) (string, error) { 73 + if !c.available() { 74 + return "", fmt.Errorf("pre-screening temporarily unavailable") 75 + } 76 + 77 + res, err := c.checkImage(ctx, blob) 78 + if err != nil { 79 + c.recordCallResult(false) 80 + return "", err 81 + } 82 + 83 + c.recordCallResult(true) 84 + return res, nil 85 + } 86 + 87 + type PreScreenResult struct { 88 + Result string `json:"result"` 89 + } 90 + 91 + func (c *PreScreenClient) checkImage(ctx context.Context, data []byte) (string, error) { 92 + url := c.Host + "/predict" 93 + 94 + body := new(bytes.Buffer) 95 + writer := multipart.NewWriter(body) 96 + 97 + part, err := writer.CreateFormFile("files", "image") 98 + if err != nil { 99 + return "", err 100 + } 101 + 102 + part.Write(data) 103 + 104 + if err := writer.Close(); err != nil { 105 + return "", err 106 + } 107 + 108 + req, err := http.NewRequest("POST", url, body) 109 + if err != nil { 110 + return "", err 111 + } 112 + 113 + req = req.WithContext(ctx) 114 + 115 + req.Header.Set("Content-Type", writer.FormDataContentType()) 116 + req.Header.Set("Authorization", "Bearer "+c.Token) 117 + 118 + resp, err := c.c.Do(req) 119 + if err != nil { 120 + return "", err 121 + } 122 + defer resp.Body.Close() 123 + 124 + var out PreScreenResult 125 + if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { 126 + return "", err 127 + } 128 + 129 + return out.Result, nil 130 + }
+14
cmd/hepa/main.go
··· 138 138 Usage: "force a fixed number of parallel firehose workers. default (or 0) for auto-scaling; 200 works for a large instance", 139 139 EnvVars: []string{"HEPA_FIREHOSE_PARALLELISM"}, 140 140 }, 141 + &cli.StringFlag{ 142 + Name: "prescreen-host", 143 + Usage: "hostname of prescreen server", 144 + EnvVars: []string{"HEPA_PRESCREEN_HOST"}, 145 + }, 146 + &cli.StringFlag{ 147 + Name: "prescreen-token", 148 + Usage: "secret token for prescreen server", 149 + EnvVars: []string{"HEPA_PRESCREEN_TOKEN"}, 150 + }, 141 151 } 142 152 143 153 app.Commands = []*cli.Command{ ··· 242 252 RatelimitBypass: cctx.String("ratelimit-bypass"), 243 253 RulesetName: cctx.String("ruleset"), 244 254 FirehoseParallelism: cctx.Int("firehose-parallelism"), 255 + PreScreenHost: cctx.String("prescreen-host"), 256 + PreScreenToken: cctx.String("prescreen-token"), 245 257 }, 246 258 ) 247 259 if err != nil { ··· 316 328 RatelimitBypass: cctx.String("ratelimit-bypass"), 317 329 RulesetName: cctx.String("ruleset"), 318 330 FirehoseParallelism: cctx.Int("firehose-parallelism"), 331 + PreScreenHost: cctx.String("prescreen-host"), 332 + PreScreenToken: cctx.String("prescreen-token"), 319 333 }, 320 334 ) 321 335 }
+7
cmd/hepa/server.go
··· 61 61 RulesetName string 62 62 RatelimitBypass string 63 63 FirehoseParallelism int 64 + PreScreenHost string 65 + PreScreenToken string 64 66 } 65 67 66 68 func NewServer(dir identity.Directory, config Config) (*Server, error) { ··· 169 171 logger.Info("configuring Hive AI image labeler") 170 172 hc := visual.NewHiveAIClient(config.HiveAPIToken) 171 173 extraBlobRules = append(extraBlobRules, hc.HiveLabelBlobRule) 174 + 175 + if config.PreScreenHost != "" { 176 + psc := visual.NewPreScreenClient(config.PreScreenHost, config.PreScreenToken) 177 + hc.PreScreenClient = psc 178 + } 172 179 } 173 180 174 181 if config.AbyssHost != "" && config.AbyssPassword != "" {