Live video on the AT Protocol
at natb/docs-table-wrapping 234 lines 6.3 kB view raw
1package api 2 3import ( 4 "context" 5 "encoding/json" 6 "net" 7 "net/http" 8 "time" 9 10 "github.com/google/uuid" 11 "github.com/gorilla/websocket" 12 "github.com/julienschmidt/httprouter" 13 14 apierrors "stream.place/streamplace/pkg/errors" 15 "stream.place/streamplace/pkg/log" 16 "stream.place/streamplace/pkg/renditions" 17 "stream.place/streamplace/pkg/spmetrics" 18 "stream.place/streamplace/pkg/streamplace" 19) 20 21// todo: does this mean a whole message has to fit within the buffer? 22var upgrader = websocket.Upgrader{ 23 ReadBufferSize: 1024, 24 WriteBufferSize: 1024, 25} 26 27var pingPeriod = 5 * time.Second 28 29func (a *StreamplaceAPI) HandleWebsocket(ctx context.Context) httprouter.Handle { 30 ctx = log.WithLogValues(ctx, "func", "HandleWebsocket") 31 return func(w http.ResponseWriter, req *http.Request, params httprouter.Params) { 32 ip, _, err := net.SplitHostPort(req.RemoteAddr) 33 if err != nil { 34 ip = req.RemoteAddr 35 } 36 37 if a.CLI.RateLimitWebsocket > 0 { 38 if !a.connTracker.AddConnection(ip) { 39 log.Warn(ctx, "rate limit exceeded", "ip", ip, "path", req.URL.Path) 40 apierrors.WriteHTTPTooManyRequests(w, "rate limit exceeded") 41 return 42 } 43 44 defer a.connTracker.RemoveConnection(ip) 45 } 46 47 uu, _ := uuid.NewV7() 48 connID := uu.String() 49 50 ctx = log.WithLogValues(ctx, "uuid", connID, "remoteAddr", req.RemoteAddr, "url", req.URL.String()) 51 log.Log(ctx, "websocket opened") 52 spmetrics.WebsocketsOpen.Inc() 53 defer spmetrics.WebsocketsOpen.Dec() 54 user := params.ByName("repoDID") 55 if user == "" { 56 apierrors.WriteHTTPBadRequest(w, "user required", nil) 57 return 58 } 59 repoDID, err := a.NormalizeUser(ctx, user) 60 if err != nil { 61 apierrors.WriteHTTPNotFound(w, "user not found", err) 62 return 63 } 64 conn, err := upgrader.Upgrade(w, req, nil) 65 if err != nil { 66 apierrors.WriteHTTPInternalServerError(w, "could not upgrade to websocket", err) 67 return 68 } 69 ctx, cancel := context.WithCancel(ctx) 70 defer cancel() 71 defer conn.Close() 72 73 initialBurst := make(chan any, 200) 74 err = conn.SetReadDeadline(time.Now().Add(30 * time.Second)) 75 if err != nil { 76 log.Error(ctx, "could not set read deadline", "error", err) 77 return 78 } 79 80 pongCh := make(chan struct{}) 81 82 go func() { 83 for { 84 select { 85 case <-ctx.Done(): 86 return 87 case <-pongCh: 88 err := conn.SetReadDeadline(time.Now().Add(30 * time.Second)) 89 if err != nil { 90 log.Error(ctx, "could not set read deadline", "error", err) 91 return 92 } 93 case <-time.After(30 * time.Second): 94 log.Log(ctx, "websocket timeout, closing connection") 95 // timeout! 96 conn.Close() 97 cancel() 98 return 99 } 100 } 101 }() 102 103 conn.SetPongHandler(func(appData string) error { 104 log.Debug(ctx, "received pong", "appData", appData) 105 pongCh <- struct{}{} 106 return nil 107 }) 108 go func() { 109 110 ch := a.Bus.Subscribe(repoDID) 111 defer a.Bus.Unsubscribe(repoDID, ch) 112 // Create a ticker that fires every 3 seconds 113 ticker := time.NewTicker(3 * time.Second) 114 pingTicker := time.NewTicker(pingPeriod) 115 defer ticker.Stop() 116 defer pingTicker.Stop() 117 118 send := func(msg any) { 119 bs, err := json.Marshal(msg) 120 if err != nil { 121 log.Error(ctx, "could not marshal message", "error", err) 122 return 123 } 124 log.Debug(ctx, "sending message", "message", string(bs)) 125 err = conn.WriteMessage(websocket.TextMessage, bs) 126 if err != nil { 127 log.Error(ctx, "could not write message", "error", err) 128 return 129 } 130 } 131 132 for { 133 select { 134 case msg := <-ch: 135 send(msg) 136 case msg := <-initialBurst: 137 send(msg) 138 case <-ticker.C: 139 count := spmetrics.GetViewCount(repoDID) 140 bs, err := json.Marshal(streamplace.Livestream_ViewerCount{Count: int64(count), LexiconTypeID: "place.stream.livestream#viewerCount"}) 141 if err != nil { 142 log.Error(ctx, "could not marshal view count", "error", err) 143 continue 144 } 145 err = conn.WriteMessage(websocket.TextMessage, bs) 146 if err != nil { 147 log.Error(ctx, "could not write ping message", "error", err) 148 return 149 } 150 case <-pingTicker.C: 151 err := conn.WriteMessage(websocket.PingMessage, []byte{}) 152 if err != nil { 153 log.Error(ctx, "could not write ping message", "error", err) 154 return 155 } 156 case <-ctx.Done(): 157 log.Debug(ctx, "context done, stopping websocket sender") 158 return 159 } 160 } 161 }() 162 163 go func() { 164 seg, err := a.Model.LatestSegmentForUser(repoDID) 165 if err != nil { 166 log.Error(ctx, "could not get replies", "error", err) 167 return 168 } 169 spSeg, err := seg.ToStreamplaceSegment() 170 if err != nil { 171 log.Error(ctx, "could not convert segment to streamplace segment", "error", err) 172 return 173 } 174 initialBurst <- spSeg 175 if a.CLI.LivepeerGatewayURL != "" { 176 renditions, err := renditions.GenerateRenditions(spSeg) 177 if err != nil { 178 log.Error(ctx, "could not generate renditions", "error", err) 179 return 180 } 181 outRs := streamplace.Defs_Renditions{ 182 LexiconTypeID: "place.stream.defs#renditions", 183 } 184 outRs.Renditions = []*streamplace.Defs_Rendition{} 185 for _, r := range renditions { 186 outRs.Renditions = append(outRs.Renditions, &streamplace.Defs_Rendition{ 187 LexiconTypeID: "place.stream.defs#rendition", 188 Name: r.Name, 189 }) 190 } 191 initialBurst <- outRs 192 } 193 }() 194 195 go func() { 196 ls, err := a.Model.GetLatestLivestreamForRepo(repoDID) 197 if err != nil { 198 log.Error(ctx, "could not get latest livestream", "error", err) 199 return 200 } 201 lsv, err := ls.ToLivestreamView() 202 if err != nil { 203 log.Error(ctx, "could not marshal livestream", "error", err) 204 return 205 } 206 initialBurst <- lsv 207 }() 208 209 go func() { 210 count := spmetrics.GetViewCount(repoDID) 211 initialBurst <- streamplace.Livestream_ViewerCount{Count: int64(count), LexiconTypeID: "place.stream.livestream#viewerCount"} 212 }() 213 214 go func() { 215 messages, err := a.Model.MostRecentChatMessages(repoDID) 216 if err != nil { 217 log.Error(ctx, "could not get chat messages", "error", err) 218 return 219 } 220 for _, message := range messages { 221 initialBurst <- message 222 } 223 }() 224 225 for { 226 messageType, message, err := conn.ReadMessage() 227 if err != nil { 228 log.Error(ctx, "error reading message", "error", err) 229 break 230 } 231 log.Log(ctx, "received message", "messageType", messageType, "message", string(message)) 232 } 233 } 234}