Live video on the AT Protocol
at issue-784 334 lines 9.2 kB view raw
1package api 2 3import ( 4 "context" 5 "encoding/json" 6 "net" 7 "net/http" 8 "time" 9 10 bsky "github.com/bluesky-social/indigo/api/bsky" 11 "github.com/google/uuid" 12 "github.com/gorilla/websocket" 13 "github.com/julienschmidt/httprouter" 14 15 "stream.place/streamplace/pkg/bus" 16 apierrors "stream.place/streamplace/pkg/errors" 17 "stream.place/streamplace/pkg/log" 18 "stream.place/streamplace/pkg/renditions" 19 "stream.place/streamplace/pkg/spmetrics" 20 "stream.place/streamplace/pkg/streamplace" 21) 22 23// todo: does this mean a whole message has to fit within the buffer? 24var upgrader = websocket.Upgrader{ 25 ReadBufferSize: 1024, 26 WriteBufferSize: 1024, 27 CheckOrigin: func(r *http.Request) bool { 28 return true 29 }, 30} 31 32var pingPeriod = 5 * time.Second 33 34func (a *StreamplaceAPI) HandleWebsocket(ctx context.Context) httprouter.Handle { 35 ctx = log.WithLogValues(ctx, "func", "HandleWebsocket") 36 return func(w http.ResponseWriter, req *http.Request, params httprouter.Params) { 37 ip, _, err := net.SplitHostPort(req.RemoteAddr) 38 if err != nil { 39 ip = req.RemoteAddr 40 } 41 42 if a.CLI.RateLimitWebsocket > 0 { 43 if !a.connTracker.AddConnection(ip) { 44 log.Warn(ctx, "rate limit exceeded", "ip", ip, "path", req.URL.Path) 45 apierrors.WriteHTTPTooManyRequests(w, "rate limit exceeded") 46 return 47 } 48 49 defer a.connTracker.RemoveConnection(ip) 50 } 51 52 uu, _ := uuid.NewV7() 53 connID := uu.String() 54 55 ctx = log.WithLogValues(ctx, "uuid", connID, "remoteAddr", req.RemoteAddr, "url", req.URL.String()) 56 log.Log(ctx, "websocket opened") 57 spmetrics.WebsocketsOpen.Inc() 58 defer spmetrics.WebsocketsOpen.Dec() 59 user := params.ByName("repoDID") 60 if user == "" { 61 apierrors.WriteHTTPBadRequest(w, "user required", nil) 62 return 63 } 64 repoDID, err := a.NormalizeUser(ctx, user) 65 if err != nil { 66 apierrors.WriteHTTPNotFound(w, "user not found", err) 67 return 68 } 69 viewerDID := req.URL.Query().Get("viewer") 70 if viewerDID != "" { 71 if _, err := a.ATSync.SyncBlueskyRepoCached(ctx, viewerDID, a.Model); err != nil { 72 log.Error(ctx, "could not sync viewer repo", "error", err, "viewerDID", viewerDID) 73 } 74 } 75 conn, err := upgrader.Upgrade(w, req, nil) 76 if err != nil { 77 apierrors.WriteHTTPInternalServerError(w, "could not upgrade to websocket", err) 78 return 79 } 80 ctx, cancel := context.WithCancel(ctx) 81 defer cancel() 82 defer conn.Close() 83 84 initialBurst := make(chan any, 200) 85 err = conn.SetReadDeadline(time.Now().Add(30 * time.Second)) 86 if err != nil { 87 log.Error(ctx, "could not set read deadline", "error", err) 88 return 89 } 90 91 pongCh := make(chan struct{}) 92 93 go func() { 94 for { 95 select { 96 case <-ctx.Done(): 97 return 98 case <-pongCh: 99 err := conn.SetReadDeadline(time.Now().Add(30 * time.Second)) 100 if err != nil { 101 log.Error(ctx, "could not set read deadline", "error", err) 102 return 103 } 104 case <-time.After(30 * time.Second): 105 log.Log(ctx, "websocket timeout, closing connection") 106 // timeout! 107 conn.Close() 108 cancel() 109 return 110 } 111 } 112 }() 113 114 conn.SetPongHandler(func(appData string) error { 115 log.Debug(ctx, "received pong", "appData", appData) 116 pongCh <- struct{}{} 117 return nil 118 }) 119 go func() { 120 blockedDIDs := make(map[string]bool) 121 if viewerDID != "" { 122 dids, err := a.Model.GetBlockedDIDs(ctx, viewerDID) 123 if err != nil { 124 log.Error(ctx, "could not get blocked DIDs", "error", err) 125 } 126 for _, did := range dids { 127 blockedDIDs[did] = true 128 } 129 } 130 131 ch := a.Bus.Subscribe(repoDID) 132 defer a.Bus.Unsubscribe(repoDID, ch) 133 134 var viewerCh <-chan bus.Message 135 if viewerDID != "" { 136 viewerCh = a.Bus.Subscribe(viewerDID) 137 defer a.Bus.Unsubscribe(viewerDID, viewerCh) 138 } 139 140 ticker := time.NewTicker(3 * time.Second) 141 pingTicker := time.NewTicker(pingPeriod) 142 defer ticker.Stop() 143 defer pingTicker.Stop() 144 145 send := func(msg any) { 146 bs, err := json.Marshal(msg) 147 if err != nil { 148 log.Error(ctx, "could not marshal message", "error", err) 149 return 150 } 151 log.Debug(ctx, "sending message", "message", string(bs)) 152 err = conn.WriteMessage(websocket.TextMessage, bs) 153 if err != nil { 154 log.Error(ctx, "could not write message", "error", err) 155 return 156 } 157 } 158 159 isBlockedChat := func(msg any) bool { 160 if len(blockedDIDs) == 0 { 161 return false 162 } 163 if chatMsg, ok := msg.(*streamplace.ChatDefs_MessageView); ok { 164 return chatMsg.Author != nil && blockedDIDs[chatMsg.Author.Did] 165 } 166 return false 167 } 168 169 for { 170 select { 171 case msg := <-ch: 172 if isBlockedChat(msg) { 173 continue 174 } 175 send(msg) 176 case msg := <-viewerCh: 177 if blockView, ok := msg.(*streamplace.Defs_BlockView); ok { 178 if blockView.Blocker != nil && blockView.Blocker.Did == viewerDID && blockView.Record != nil { 179 blockedDIDs[blockView.Record.Subject] = true 180 send(blockView) 181 } 182 } 183 case msg := <-initialBurst: 184 send(msg) 185 case <-ticker.C: 186 count := a.Bus.GetViewerCount(repoDID) 187 bs, err := json.Marshal(streamplace.Livestream_ViewerCount{Count: int64(count), LexiconTypeID: "place.stream.livestream#viewerCount"}) 188 if err != nil { 189 log.Error(ctx, "could not marshal view count", "error", err) 190 continue 191 } 192 err = conn.WriteMessage(websocket.TextMessage, bs) 193 if err != nil { 194 log.Error(ctx, "could not write ping message", "error", err) 195 return 196 } 197 case <-pingTicker.C: 198 err := conn.WriteMessage(websocket.PingMessage, []byte{}) 199 if err != nil { 200 log.Error(ctx, "could not write ping message", "error", err) 201 return 202 } 203 case <-ctx.Done(): 204 log.Debug(ctx, "context done, stopping websocket sender") 205 return 206 } 207 } 208 }() 209 210 go func() { 211 profile, err := a.Model.GetRepo(repoDID) 212 if err != nil { 213 log.Error(ctx, "could not get profile", "error", err) 214 return 215 } 216 if profile != nil { 217 p := map[string]any{ 218 "$type": "app.bsky.actor.defs#profileViewBasic", 219 "did": repoDID, 220 "handle": profile.Handle, 221 } 222 initialBurst <- p 223 } 224 }() 225 226 go func() { 227 seg, err := a.LocalDB.LatestSegmentForUser(repoDID) 228 if err != nil { 229 log.Error(ctx, "could not get replies", "error", err) 230 return 231 } 232 spSeg, err := seg.ToStreamplaceSegment() 233 if err != nil { 234 log.Error(ctx, "could not convert segment to streamplace segment", "error", err) 235 return 236 } 237 initialBurst <- spSeg 238 if a.CLI.LivepeerGatewayURL != "" { 239 renditions, err := renditions.GenerateRenditions(spSeg) 240 if err != nil { 241 log.Error(ctx, "could not generate renditions", "error", err) 242 return 243 } 244 outRs := streamplace.Defs_Renditions{ 245 LexiconTypeID: "place.stream.defs#renditions", 246 } 247 outRs.Renditions = []*streamplace.Defs_Rendition{} 248 for _, r := range renditions { 249 outRs.Renditions = append(outRs.Renditions, &streamplace.Defs_Rendition{ 250 LexiconTypeID: "place.stream.defs#rendition", 251 Name: r.Name, 252 }) 253 } 254 initialBurst <- outRs 255 } 256 }() 257 258 go func() { 259 ls, err := a.Model.GetLatestLivestreamForRepo(repoDID) 260 if err != nil { 261 log.Error(ctx, "could not get latest livestream", "error", err) 262 return 263 } 264 lsv, err := ls.ToLivestreamView() 265 if err != nil { 266 log.Error(ctx, "could not marshal livestream", "error", err) 267 return 268 } 269 initialBurst <- lsv 270 }() 271 272 go func() { 273 count := a.Bus.GetViewerCount(repoDID) 274 initialBurst <- streamplace.Livestream_ViewerCount{Count: int64(count), LexiconTypeID: "place.stream.livestream#viewerCount"} 275 }() 276 277 go func() { 278 messages, err := a.Model.MostRecentChatMessagesForViewer(repoDID, viewerDID) 279 if err != nil { 280 log.Error(ctx, "could not get chat messages", "error", err) 281 return 282 } 283 for _, message := range messages { 284 initialBurst <- message 285 } 286 }() 287 288 go func() { 289 teleports, err := a.Model.GetActiveTeleportsToRepo(repoDID) 290 if err != nil { 291 log.Error(ctx, "could not get active teleports", "error", err) 292 return 293 } 294 // just send the latest one if it started <3m ago 295 if len(teleports) > 0 && teleports[0].StartsAt.After(time.Now().Add(-3*time.Minute)) { 296 tp := teleports[0] 297 if tp.Repo == nil { 298 log.Error(ctx, "teleportee repo is nil", "uri", tp.URI) 299 } 300 viewerCount := a.Bus.GetViewerCount(tp.RepoDID) 301 arrivalMsg := streamplace.Livestream_TeleportArrival{ 302 LexiconTypeID: "place.stream.livestream#teleportArrival", 303 TeleportUri: tp.URI, 304 Source: &bsky.ActorDefs_ProfileViewBasic{ 305 Did: tp.RepoDID, 306 Handle: tp.Repo.Handle, 307 }, 308 ViewerCount: int64(viewerCount), 309 StartsAt: tp.StartsAt.Format(time.RFC3339), 310 } 311 312 // get the source chat profile 313 chatProfile, err := a.Model.GetChatProfile(ctx, tp.RepoDID) 314 if err == nil && chatProfile != nil { 315 spcp, err := chatProfile.ToStreamplaceChatProfile() 316 if err == nil { 317 arrivalMsg.ChatProfile = spcp 318 } 319 } 320 321 initialBurst <- arrivalMsg 322 } 323 }() 324 325 for { 326 messageType, message, err := conn.ReadMessage() 327 if err != nil { 328 log.Error(ctx, "error reading message", "error", err) 329 break 330 } 331 log.Log(ctx, "received message", "messageType", messageType, "message", string(message)) 332 } 333 } 334}