+2
automod/visual/hiveai_client.go
+2
automod/visual/hiveai_client.go
+18
automod/visual/hiveai_rule.go
+18
automod/visual/hiveai_rule.go
···
13
13
return nil
14
14
}
15
15
16
+
var psclabel string
17
+
if hal.PreScreenClient != nil {
18
+
val, err := hal.PreScreenClient.PreScreenImage(c.Ctx, data)
19
+
if err != nil {
20
+
c.Logger.Info("prescreen-request-error", "err", err)
21
+
} else {
22
+
psclabel = val
23
+
}
24
+
}
25
+
16
26
labels, err := hal.LabelBlob(c.Ctx, blob, data)
17
27
if err != nil {
18
28
return err
29
+
}
30
+
31
+
if psclabel == "sfw" {
32
+
if len(labels) > 0 {
33
+
c.Logger.Info("prescreen-safe-failure", "uri", c.RecordOp.ATURI())
34
+
} else {
35
+
c.Logger.Info("prescreen-safe-success", "uri", c.RecordOp.ATURI())
36
+
}
19
37
}
20
38
21
39
for _, l := range labels {
+130
automod/visual/prescreen.go
+130
automod/visual/prescreen.go
···
1
+
package visual
2
+
3
+
import (
4
+
"bytes"
5
+
"context"
6
+
"encoding/json"
7
+
"fmt"
8
+
"mime/multipart"
9
+
"net/http"
10
+
"sync"
11
+
"time"
12
+
)
13
+
14
+
const failureThresh = 10
15
+
16
+
type PreScreenClient struct {
17
+
Host string
18
+
Token string
19
+
20
+
breakerEOL time.Time
21
+
breakerLk sync.Mutex
22
+
failures int
23
+
24
+
c *http.Client
25
+
}
26
+
27
+
func NewPreScreenClient(host, token string) *PreScreenClient {
28
+
c := &http.Client{
29
+
Timeout: time.Second * 5,
30
+
}
31
+
32
+
return &PreScreenClient{
33
+
Host: host,
34
+
Token: token,
35
+
c: c,
36
+
}
37
+
}
38
+
39
+
func (c *PreScreenClient) available() bool {
40
+
c.breakerLk.Lock()
41
+
defer c.breakerLk.Unlock()
42
+
if c.breakerEOL.IsZero() {
43
+
return true
44
+
}
45
+
46
+
if time.Now().After(c.breakerEOL) {
47
+
c.breakerEOL = time.Time{}
48
+
return true
49
+
}
50
+
51
+
return false
52
+
}
53
+
54
+
func (c *PreScreenClient) recordCallResult(success bool) {
55
+
c.breakerLk.Lock()
56
+
defer c.breakerLk.Unlock()
57
+
if !c.breakerEOL.IsZero() {
58
+
return
59
+
}
60
+
61
+
if success {
62
+
c.failures = 0
63
+
} else {
64
+
c.failures++
65
+
if c.failures > failureThresh {
66
+
c.breakerEOL = time.Now().Add(time.Minute)
67
+
c.failures = 0
68
+
}
69
+
}
70
+
}
71
+
72
+
func (c *PreScreenClient) PreScreenImage(ctx context.Context, blob []byte) (string, error) {
73
+
if !c.available() {
74
+
return "", fmt.Errorf("pre-screening temporarily unavailable")
75
+
}
76
+
77
+
res, err := c.checkImage(ctx, blob)
78
+
if err != nil {
79
+
c.recordCallResult(false)
80
+
return "", err
81
+
}
82
+
83
+
c.recordCallResult(true)
84
+
return res, nil
85
+
}
86
+
87
+
type PreScreenResult struct {
88
+
Result string `json:"result"`
89
+
}
90
+
91
+
func (c *PreScreenClient) checkImage(ctx context.Context, data []byte) (string, error) {
92
+
url := c.Host + "/predict"
93
+
94
+
body := new(bytes.Buffer)
95
+
writer := multipart.NewWriter(body)
96
+
97
+
part, err := writer.CreateFormFile("files", "image")
98
+
if err != nil {
99
+
return "", err
100
+
}
101
+
102
+
part.Write(data)
103
+
104
+
if err := writer.Close(); err != nil {
105
+
return "", err
106
+
}
107
+
108
+
req, err := http.NewRequest("POST", url, body)
109
+
if err != nil {
110
+
return "", err
111
+
}
112
+
113
+
req = req.WithContext(ctx)
114
+
115
+
req.Header.Set("Content-Type", writer.FormDataContentType())
116
+
req.Header.Set("Authorization", "Bearer "+c.Token)
117
+
118
+
resp, err := c.c.Do(req)
119
+
if err != nil {
120
+
return "", err
121
+
}
122
+
defer resp.Body.Close()
123
+
124
+
var out PreScreenResult
125
+
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
126
+
return "", err
127
+
}
128
+
129
+
return out.Result, nil
130
+
}
+14
cmd/hepa/main.go
+14
cmd/hepa/main.go
···
138
138
Usage: "force a fixed number of parallel firehose workers. default (or 0) for auto-scaling; 200 works for a large instance",
139
139
EnvVars: []string{"HEPA_FIREHOSE_PARALLELISM"},
140
140
},
141
+
&cli.StringFlag{
142
+
Name: "prescreen-host",
143
+
Usage: "hostname of prescreen server",
144
+
EnvVars: []string{"HEPA_PRESCREEN_HOST"},
145
+
},
146
+
&cli.StringFlag{
147
+
Name: "prescreen-token",
148
+
Usage: "secret token for prescreen server",
149
+
EnvVars: []string{"HEPA_PRESCREEN_TOKEN"},
150
+
},
141
151
}
142
152
143
153
app.Commands = []*cli.Command{
···
242
252
RatelimitBypass: cctx.String("ratelimit-bypass"),
243
253
RulesetName: cctx.String("ruleset"),
244
254
FirehoseParallelism: cctx.Int("firehose-parallelism"),
255
+
PreScreenHost: cctx.String("prescreen-host"),
256
+
PreScreenToken: cctx.String("prescreen-token"),
245
257
},
246
258
)
247
259
if err != nil {
···
316
328
RatelimitBypass: cctx.String("ratelimit-bypass"),
317
329
RulesetName: cctx.String("ruleset"),
318
330
FirehoseParallelism: cctx.Int("firehose-parallelism"),
331
+
PreScreenHost: cctx.String("prescreen-host"),
332
+
PreScreenToken: cctx.String("prescreen-token"),
319
333
},
320
334
)
321
335
}
+7
cmd/hepa/server.go
+7
cmd/hepa/server.go
···
61
61
RulesetName string
62
62
RatelimitBypass string
63
63
FirehoseParallelism int
64
+
PreScreenHost string
65
+
PreScreenToken string
64
66
}
65
67
66
68
func NewServer(dir identity.Directory, config Config) (*Server, error) {
···
169
171
logger.Info("configuring Hive AI image labeler")
170
172
hc := visual.NewHiveAIClient(config.HiveAPIToken)
171
173
extraBlobRules = append(extraBlobRules, hc.HiveLabelBlobRule)
174
+
175
+
if config.PreScreenHost != "" {
176
+
psc := visual.NewPreScreenClient(config.PreScreenHost, config.PreScreenToken)
177
+
hc.PreScreenClient = psc
178
+
}
172
179
}
173
180
174
181
if config.AbyssHost != "" && config.AbyssPassword != "" {