Live video on the AT Protocol

Merge pull request #151 from bluwumeaway/feature/ratelimit-middleware

add rate limiting middleware

authored by Eli Mallon and committed by GitHub 595fe0a7 a2a57148

+109 -1
+86
pkg/api/api.go
··· 14 "os" 15 "slices" 16 "strings" 17 "time" 18 19 "github.com/NYTimes/gziphandler" ··· 21 "github.com/julienschmidt/httprouter" 22 "github.com/rs/cors" 23 sloghttp "github.com/samber/slog-http" 24 25 "stream.place/streamplace/js/app" 26 "stream.place/streamplace/pkg/atproto" ··· 54 Bus *bus.Bus 55 ATSync *atproto.ATProtoSynchronizer 56 Director *director.Director 57 } 58 59 func MakeStreamplaceAPI(cli *config.CLI, mod model.Model, signer *eip712.EIP712Signer, noter notifications.FirebaseNotifier, mm *media.MediaManager, ms media.MediaSigner, bus *bus.Bus, atsync *atproto.ATProtoSynchronizer, d *director.Director) (*StreamplaceAPI, error) { ··· 72 Bus: bus, 73 ATSync: atsync, 74 Director: d, 75 } 76 a.Mimes, err = updater.GetMimes() 77 if err != nil { ··· 214 handler := sloghttp.Recovery(router) 215 handler = cors.AllowAll().Handler(handler) 216 handler = sloghttp.New(slog.Default())(handler) 217 218 return handler, nil 219 } ··· 669 } 670 } 671 672 func (a *StreamplaceAPI) ServeHTTP(ctx context.Context) error { 673 handler, err := a.Handler(ctx) 674 if err != nil { ··· 733 w.WriteHeader(200) 734 } 735 }
··· 14 "os" 15 "slices" 16 "strings" 17 + "sync" 18 "time" 19 20 "github.com/NYTimes/gziphandler" ··· 22 "github.com/julienschmidt/httprouter" 23 "github.com/rs/cors" 24 sloghttp "github.com/samber/slog-http" 25 + "golang.org/x/time/rate" 26 27 "stream.place/streamplace/js/app" 28 "stream.place/streamplace/pkg/atproto" ··· 56 Bus *bus.Bus 57 ATSync *atproto.ATProtoSynchronizer 58 Director *director.Director 59 + 60 + connTracker *WebsocketTracker 61 + 62 + limiters map[string]*rate.Limiter 63 + limitersMu sync.Mutex 64 + } 65 + 66 + type WebsocketTracker struct { 67 + connections map[string]int 68 + maxConnsPerIP int 69 + mu sync.RWMutex 70 } 71 72 func MakeStreamplaceAPI(cli *config.CLI, mod model.Model, signer *eip712.EIP712Signer, noter notifications.FirebaseNotifier, mm *media.MediaManager, ms media.MediaSigner, bus *bus.Bus, atsync *atproto.ATProtoSynchronizer, d *director.Director) (*StreamplaceAPI, error) { ··· 85 Bus: bus, 86 ATSync: atsync, 87 Director: d, 88 + connTracker: NewWebsocketTracker(5), 89 + limiters: make(map[string]*rate.Limiter), 90 } 91 a.Mimes, err = updater.GetMimes() 92 if err != nil { ··· 229 handler := sloghttp.Recovery(router) 230 handler = cors.AllowAll().Handler(handler) 231 handler = sloghttp.New(slog.Default())(handler) 232 + handler = a.RateLimitMiddleware(ctx)(handler) 233 234 return handler, nil 235 } ··· 685 } 686 } 687 688 + func (a *StreamplaceAPI) RateLimitMiddleware(ctx context.Context) func(http.Handler) http.Handler { 689 + return func(next http.Handler) http.Handler { 690 + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 691 + ip, _, err := net.SplitHostPort(req.RemoteAddr) 692 + if err != nil { 693 + ip = req.RemoteAddr 694 + } 695 + 696 + limiter := a.getLimiter(ip) 697 + 698 + if !limiter.Allow() { 699 + log.Warn(ctx, "rate limit exceeded", "ip", ip, "path", req.URL.Path) 700 + apierrors.WriteHTTPTooManyRequests(w, "rate limit exceeded") 701 + return 702 + } 703 + 704 + next.ServeHTTP(w, req) 705 + }) 706 + } 707 + } 708 + 709 func (a *StreamplaceAPI) ServeHTTP(ctx context.Context) error { 710 handler, err := a.Handler(ctx) 711 if err != nil { ··· 770 w.WriteHeader(200) 771 } 772 } 773 + 774 + func (a *StreamplaceAPI) getLimiter(ip string) *rate.Limiter { 775 + a.limitersMu.Lock() 776 + defer a.limitersMu.Unlock() 777 + 778 + limiter, exists := a.limiters[ip] 779 + if !exists { 780 + // 5 actions per second with a burst of 3 781 + limiter = rate.NewLimiter(rate.Limit(10.0), 8) 782 + a.limiters[ip] = limiter 783 + } 784 + 785 + return limiter 786 + } 787 + 788 + func NewWebsocketTracker(maxConns int) *WebsocketTracker { 789 + return &WebsocketTracker{ 790 + connections: make(map[string]int), 791 + maxConnsPerIP: maxConns, 792 + } 793 + } 794 + 795 + func (t *WebsocketTracker) AddConnection(ip string) bool { 796 + t.mu.Lock() 797 + defer t.mu.Unlock() 798 + 799 + count := t.connections[ip] 800 + 801 + if count >= t.maxConnsPerIP { 802 + return false 803 + } 804 + 805 + t.connections[ip] = count + 1 806 + return true 807 + } 808 + 809 + func (t *WebsocketTracker) RemoveConnection(ip string) { 810 + t.mu.Lock() 811 + defer t.mu.Unlock() 812 + 813 + count := t.connections[ip] 814 + if count > 0 { 815 + t.connections[ip] = count - 1 816 + } 817 + 818 + if t.connections[ip] == 0 { 819 + delete(t.connections, ip) 820 + } 821 + }
+19 -1
pkg/api/websocket.go
··· 3 import ( 4 "context" 5 "encoding/json" 6 "net/http" 7 "time" 8 9 "github.com/google/uuid" 10 "github.com/gorilla/websocket" 11 "github.com/julienschmidt/httprouter" 12 apierrors "stream.place/streamplace/pkg/errors" 13 "stream.place/streamplace/pkg/log" 14 "stream.place/streamplace/pkg/renditions" ··· 27 func (a *StreamplaceAPI) HandleWebsocket(ctx context.Context) httprouter.Handle { 28 ctx = log.WithLogValues(ctx, "func", "HandleWebsocket") 29 return func(w http.ResponseWriter, req *http.Request, params httprouter.Params) { 30 uu, _ := uuid.NewV7() 31 - ctx = log.WithLogValues(ctx, "uuid", uu.String(), "remoteAddr", req.RemoteAddr, "url", req.URL.String()) 32 log.Log(ctx, "websocket opened") 33 spmetrics.WebsocketsOpen.Inc() 34 defer spmetrics.WebsocketsOpen.Dec() ··· 50 ctx, cancel := context.WithCancel(ctx) 51 defer cancel() 52 defer conn.Close() 53 initialBurst := make(chan any, 200) 54 err = conn.SetReadDeadline(time.Now().Add(30 * time.Second)) 55 if err != nil {
··· 3 import ( 4 "context" 5 "encoding/json" 6 + "net" 7 "net/http" 8 "time" 9 10 "github.com/google/uuid" 11 "github.com/gorilla/websocket" 12 "github.com/julienschmidt/httprouter" 13 + 14 apierrors "stream.place/streamplace/pkg/errors" 15 "stream.place/streamplace/pkg/log" 16 "stream.place/streamplace/pkg/renditions" ··· 29 func (a *StreamplaceAPI) HandleWebsocket(ctx context.Context) httprouter.Handle { 30 ctx = log.WithLogValues(ctx, "func", "HandleWebsocket") 31 return func(w http.ResponseWriter, req *http.Request, params httprouter.Params) { 32 + ip, _, err := net.SplitHostPort(req.RemoteAddr) 33 + if err != nil { 34 + ip = req.RemoteAddr 35 + } 36 + 37 + if !a.connTracker.AddConnection(ip) { 38 + log.Warn(ctx, "rate limit exceeded", "ip", ip, "path", req.URL.Path) 39 + apierrors.WriteHTTPTooManyRequests(w, "rate limit exceeded") 40 + return 41 + } 42 + 43 + defer a.connTracker.RemoveConnection(ip) 44 + 45 uu, _ := uuid.NewV7() 46 + connID := uu.String() 47 + 48 + ctx = log.WithLogValues(ctx, "uuid", connID, "remoteAddr", req.RemoteAddr, "url", req.URL.String()) 49 log.Log(ctx, "websocket opened") 50 spmetrics.WebsocketsOpen.Inc() 51 defer spmetrics.WebsocketsOpen.Dec() ··· 67 ctx, cancel := context.WithCancel(ctx) 68 defer cancel() 69 defer conn.Close() 70 + 71 initialBurst := make(chan any, 200) 72 err = conn.SetReadDeadline(time.Now().Add(30 * time.Second)) 73 if err != nil {
+4
pkg/errors/errors.go
··· 61 func WriteHTTPNotImplemented(w http.ResponseWriter, msg string, err error) APIError { 62 return writeHttpError(w, msg, http.StatusNotImplemented, err) 63 }
··· 61 func WriteHTTPNotImplemented(w http.ResponseWriter, msg string, err error) APIError { 62 return writeHttpError(w, msg, http.StatusNotImplemented, err) 63 } 64 + 65 + func WriteHTTPTooManyRequests(w http.ResponseWriter, msg string) APIError { 66 + return writeHttpError(w, msg, http.StatusTooManyRequests, nil) 67 + }