Live video on the AT Protocol
at eli/routing-cleanup 309 lines 8.6 kB view raw
1package api 2 3import ( 4 "context" 5 "encoding/json" 6 "fmt" 7 "net" 8 "net/http" 9 "time" 10 11 bsky "github.com/bluesky-social/indigo/api/bsky" 12 "github.com/google/uuid" 13 "github.com/gorilla/websocket" 14 "github.com/julienschmidt/httprouter" 15 16 "stream.place/streamplace/pkg/atproto" 17 apierrors "stream.place/streamplace/pkg/errors" 18 "stream.place/streamplace/pkg/log" 19 "stream.place/streamplace/pkg/renditions" 20 "stream.place/streamplace/pkg/spmetrics" 21 "stream.place/streamplace/pkg/streamplace" 22) 23 24// todo: does this mean a whole message has to fit within the buffer? 25var upgrader = websocket.Upgrader{ 26 ReadBufferSize: 1024, 27 WriteBufferSize: 1024, 28 CheckOrigin: func(r *http.Request) bool { 29 return true 30 }, 31} 32 33var pingPeriod = 5 * time.Second 34 35func (a *StreamplaceAPI) HandleWebsocket(ctx context.Context) httprouter.Handle { 36 ctx = log.WithLogValues(ctx, "func", "HandleWebsocket") 37 return func(w http.ResponseWriter, req *http.Request, params httprouter.Params) { 38 ip, _, err := net.SplitHostPort(req.RemoteAddr) 39 if err != nil { 40 ip = req.RemoteAddr 41 } 42 43 if a.CLI.RateLimitWebsocket > 0 { 44 if !a.connTracker.AddConnection(ip) { 45 log.Warn(ctx, "rate limit exceeded", "ip", ip, "path", req.URL.Path) 46 apierrors.WriteHTTPTooManyRequests(w, "rate limit exceeded") 47 return 48 } 49 50 defer a.connTracker.RemoveConnection(ip) 51 } 52 53 uu, _ := uuid.NewV7() 54 connID := uu.String() 55 56 ctx = log.WithLogValues(ctx, "uuid", connID, "remoteAddr", req.RemoteAddr, "url", req.URL.String()) 57 log.Log(ctx, "websocket opened") 58 spmetrics.WebsocketsOpen.Inc() 59 defer spmetrics.WebsocketsOpen.Dec() 60 user := params.ByName("repoDID") 61 if user == "" { 62 apierrors.WriteHTTPBadRequest(w, "user required", nil) 63 return 64 } 65 repoDID, err := a.NormalizeUser(ctx, user) 66 if err != nil { 67 apierrors.WriteHTTPNotFound(w, "user not found", err) 68 return 69 } 70 conn, err := upgrader.Upgrade(w, req, nil) 71 if err != nil { 72 apierrors.WriteHTTPInternalServerError(w, "could not upgrade to websocket", err) 73 return 74 } 75 ctx, cancel := context.WithCancel(ctx) 76 defer cancel() 77 defer conn.Close() 78 79 initialBurst := make(chan any, 200) 80 err = conn.SetReadDeadline(time.Now().Add(30 * time.Second)) 81 if err != nil { 82 log.Error(ctx, "could not set read deadline", "error", err) 83 return 84 } 85 86 pongCh := make(chan struct{}) 87 88 go func() { 89 for { 90 select { 91 case <-ctx.Done(): 92 return 93 case <-pongCh: 94 err := conn.SetReadDeadline(time.Now().Add(30 * time.Second)) 95 if err != nil { 96 log.Error(ctx, "could not set read deadline", "error", err) 97 return 98 } 99 case <-time.After(30 * time.Second): 100 log.Log(ctx, "websocket timeout, closing connection") 101 // timeout! 102 conn.Close() 103 cancel() 104 return 105 } 106 } 107 }() 108 109 conn.SetPongHandler(func(appData string) error { 110 log.Debug(ctx, "received pong", "appData", appData) 111 pongCh <- struct{}{} 112 return nil 113 }) 114 go func() { 115 116 ch := a.Bus.Subscribe(repoDID) 117 defer a.Bus.Unsubscribe(repoDID, ch) 118 // Create a ticker that fires every 3 seconds 119 ticker := time.NewTicker(3 * time.Second) 120 pingTicker := time.NewTicker(pingPeriod) 121 defer ticker.Stop() 122 defer pingTicker.Stop() 123 124 send := func(msg any) { 125 bs, err := json.Marshal(msg) 126 if err != nil { 127 log.Error(ctx, "could not marshal message", "error", err) 128 return 129 } 130 log.Debug(ctx, "sending message", "message", string(bs)) 131 err = conn.WriteMessage(websocket.TextMessage, bs) 132 if err != nil { 133 log.Error(ctx, "could not write message", "error", err) 134 return 135 } 136 } 137 138 for { 139 select { 140 case msg := <-ch: 141 send(msg) 142 case msg := <-initialBurst: 143 send(msg) 144 case <-ticker.C: 145 count := a.Bus.GetViewerCount(repoDID) 146 bs, err := json.Marshal(streamplace.Livestream_ViewerCount{Count: int64(count), LexiconTypeID: "place.stream.livestream#viewerCount"}) 147 if err != nil { 148 log.Error(ctx, "could not marshal view count", "error", err) 149 continue 150 } 151 err = conn.WriteMessage(websocket.TextMessage, bs) 152 if err != nil { 153 log.Error(ctx, "could not write ping message", "error", err) 154 return 155 } 156 case <-pingTicker.C: 157 err := conn.WriteMessage(websocket.PingMessage, []byte{}) 158 if err != nil { 159 log.Error(ctx, "could not write ping message", "error", err) 160 return 161 } 162 case <-ctx.Done(): 163 log.Debug(ctx, "context done, stopping websocket sender") 164 return 165 } 166 } 167 }() 168 169 go func() { 170 profile, err := a.Model.GetRepo(repoDID) 171 if err != nil { 172 log.Error(ctx, "could not get profile", "error", err) 173 return 174 } 175 if profile != nil { 176 p := map[string]any{ 177 "$type": "app.bsky.actor.defs#profileViewBasic", 178 "did": repoDID, 179 "handle": profile.Handle, 180 } 181 initialBurst <- p 182 } 183 }() 184 185 go func() { 186 seg, err := a.LocalDB.LatestSegmentForUser(repoDID) 187 if err != nil { 188 log.Error(ctx, "could not get replies", "error", err) 189 return 190 } 191 spSeg, err := seg.ToStreamplaceSegment() 192 if err != nil { 193 log.Error(ctx, "could not convert segment to streamplace segment", "error", err) 194 return 195 } 196 initialBurst <- spSeg 197 outRs := streamplace.Defs_Renditions{ 198 LexiconTypeID: "place.stream.defs#renditions", 199 Renditions: []*streamplace.Defs_Rendition{}, 200 } 201 if a.CLI.LivepeerGatewayURL != "" { 202 videoRenditions, err := renditions.GenerateRenditions(spSeg) 203 if err != nil { 204 log.Error(ctx, "could not generate renditions", "error", err) 205 return 206 } 207 for _, r := range videoRenditions { 208 outRs.Renditions = append(outRs.Renditions, &streamplace.Defs_Rendition{ 209 LexiconTypeID: "place.stream.defs#rendition", 210 Name: r.Name, 211 }) 212 } 213 } 214 outRs.Renditions = append(outRs.Renditions, &streamplace.Defs_Rendition{ 215 LexiconTypeID: "place.stream.defs#rendition", 216 Name: renditions.AudioRendition.Name, 217 }) 218 initialBurst <- outRs 219 }() 220 221 go func() { 222 ls, err := a.Model.GetLatestLivestreamForRepo(repoDID) 223 if err != nil { 224 log.Error(ctx, "could not get latest livestream", "error", err) 225 return 226 } 227 if ls == nil { 228 log.Error(ctx, "no livestream found", "repoDID", repoDID) 229 return 230 } 231 lsv, err := ls.ToLivestreamView() 232 if err != nil { 233 log.Error(ctx, "could not marshal livestream", "error", err) 234 return 235 } 236 initialBurst <- lsv 237 }() 238 239 go func() { 240 count := a.Bus.GetViewerCount(repoDID) 241 initialBurst <- streamplace.Livestream_ViewerCount{Count: int64(count), LexiconTypeID: "place.stream.livestream#viewerCount"} 242 }() 243 244 go func() { 245 messages, err := a.Model.MostRecentChatMessages(repoDID) 246 if err != nil { 247 log.Error(ctx, "could not get chat messages", "error", err) 248 return 249 } 250 251 // Add mod badges to messages 252 issuerDID := fmt.Sprintf("did:web:%s", a.CLI.BroadcasterHost) 253 for _, message := range messages { 254 err := atproto.AddModBadgeIfApplicable(ctx, message, repoDID, issuerDID, a.Model) 255 if err != nil { 256 log.Error(ctx, "failed to add mod badge to message", "error", err) 257 } 258 initialBurst <- message 259 } 260 }() 261 262 go func() { 263 teleports, err := a.Model.GetActiveTeleportsToRepo(repoDID) 264 if err != nil { 265 log.Error(ctx, "could not get active teleports", "error", err) 266 return 267 } 268 // just send the latest one if it started <3m ago 269 if len(teleports) > 0 && teleports[0].StartsAt.After(time.Now().Add(-3*time.Minute)) { 270 tp := teleports[0] 271 if tp.Repo == nil { 272 log.Error(ctx, "teleportee repo is nil", "uri", tp.URI) 273 return 274 } 275 viewerCount := a.Bus.GetViewerCount(tp.RepoDID) 276 arrivalMsg := streamplace.Livestream_TeleportArrival{ 277 LexiconTypeID: "place.stream.livestream#teleportArrival", 278 TeleportUri: tp.URI, 279 Source: &bsky.ActorDefs_ProfileViewBasic{ 280 Did: tp.RepoDID, 281 Handle: tp.Repo.Handle, 282 }, 283 ViewerCount: int64(viewerCount), 284 StartsAt: tp.StartsAt.Format(time.RFC3339), 285 } 286 287 // get the source chat profile 288 chatProfile, err := a.Model.GetChatProfile(ctx, tp.RepoDID) 289 if err == nil && chatProfile != nil { 290 spcp, err := chatProfile.ToStreamplaceChatProfile() 291 if err == nil { 292 arrivalMsg.ChatProfile = spcp 293 } 294 } 295 296 initialBurst <- arrivalMsg 297 } 298 }() 299 300 for { 301 messageType, message, err := conn.ReadMessage() 302 if err != nil { 303 log.Error(ctx, "error reading message", "error", err) 304 break 305 } 306 log.Log(ctx, "received message", "messageType", messageType, "message", string(message)) 307 } 308 } 309}