Live video on the AT Protocol
at natb/badges 300 lines 8.3 kB view raw
1package api 2 3import ( 4 "context" 5 "encoding/json" 6 "fmt" 7 "net" 8 "net/http" 9 "time" 10 11 bsky "github.com/bluesky-social/indigo/api/bsky" 12 "github.com/google/uuid" 13 "github.com/gorilla/websocket" 14 "github.com/julienschmidt/httprouter" 15 16 "stream.place/streamplace/pkg/atproto" 17 apierrors "stream.place/streamplace/pkg/errors" 18 "stream.place/streamplace/pkg/log" 19 "stream.place/streamplace/pkg/renditions" 20 "stream.place/streamplace/pkg/spmetrics" 21 "stream.place/streamplace/pkg/streamplace" 22) 23 24// todo: does this mean a whole message has to fit within the buffer? 25var upgrader = websocket.Upgrader{ 26 ReadBufferSize: 1024, 27 WriteBufferSize: 1024, 28 CheckOrigin: func(r *http.Request) bool { 29 return true 30 }, 31} 32 33var pingPeriod = 5 * time.Second 34 35func (a *StreamplaceAPI) HandleWebsocket(ctx context.Context) httprouter.Handle { 36 ctx = log.WithLogValues(ctx, "func", "HandleWebsocket") 37 return func(w http.ResponseWriter, req *http.Request, params httprouter.Params) { 38 ip, _, err := net.SplitHostPort(req.RemoteAddr) 39 if err != nil { 40 ip = req.RemoteAddr 41 } 42 43 if a.CLI.RateLimitWebsocket > 0 { 44 if !a.connTracker.AddConnection(ip) { 45 log.Warn(ctx, "rate limit exceeded", "ip", ip, "path", req.URL.Path) 46 apierrors.WriteHTTPTooManyRequests(w, "rate limit exceeded") 47 return 48 } 49 50 defer a.connTracker.RemoveConnection(ip) 51 } 52 53 uu, _ := uuid.NewV7() 54 connID := uu.String() 55 56 ctx = log.WithLogValues(ctx, "uuid", connID, "remoteAddr", req.RemoteAddr, "url", req.URL.String()) 57 log.Log(ctx, "websocket opened") 58 spmetrics.WebsocketsOpen.Inc() 59 defer spmetrics.WebsocketsOpen.Dec() 60 user := params.ByName("repoDID") 61 if user == "" { 62 apierrors.WriteHTTPBadRequest(w, "user required", nil) 63 return 64 } 65 repoDID, err := a.NormalizeUser(ctx, user) 66 if err != nil { 67 apierrors.WriteHTTPNotFound(w, "user not found", err) 68 return 69 } 70 conn, err := upgrader.Upgrade(w, req, nil) 71 if err != nil { 72 apierrors.WriteHTTPInternalServerError(w, "could not upgrade to websocket", err) 73 return 74 } 75 ctx, cancel := context.WithCancel(ctx) 76 defer cancel() 77 defer conn.Close() 78 79 initialBurst := make(chan any, 200) 80 err = conn.SetReadDeadline(time.Now().Add(30 * time.Second)) 81 if err != nil { 82 log.Error(ctx, "could not set read deadline", "error", err) 83 return 84 } 85 86 pongCh := make(chan struct{}) 87 88 go func() { 89 for { 90 select { 91 case <-ctx.Done(): 92 return 93 case <-pongCh: 94 err := conn.SetReadDeadline(time.Now().Add(30 * time.Second)) 95 if err != nil { 96 log.Error(ctx, "could not set read deadline", "error", err) 97 return 98 } 99 case <-time.After(30 * time.Second): 100 log.Log(ctx, "websocket timeout, closing connection") 101 // timeout! 102 conn.Close() 103 cancel() 104 return 105 } 106 } 107 }() 108 109 conn.SetPongHandler(func(appData string) error { 110 log.Debug(ctx, "received pong", "appData", appData) 111 pongCh <- struct{}{} 112 return nil 113 }) 114 go func() { 115 116 ch := a.Bus.Subscribe(repoDID) 117 defer a.Bus.Unsubscribe(repoDID, ch) 118 // Create a ticker that fires every 3 seconds 119 ticker := time.NewTicker(3 * time.Second) 120 pingTicker := time.NewTicker(pingPeriod) 121 defer ticker.Stop() 122 defer pingTicker.Stop() 123 124 send := func(msg any) { 125 bs, err := json.Marshal(msg) 126 if err != nil { 127 log.Error(ctx, "could not marshal message", "error", err) 128 return 129 } 130 log.Debug(ctx, "sending message", "message", string(bs)) 131 err = conn.WriteMessage(websocket.TextMessage, bs) 132 if err != nil { 133 log.Error(ctx, "could not write message", "error", err) 134 return 135 } 136 } 137 138 for { 139 select { 140 case msg := <-ch: 141 send(msg) 142 case msg := <-initialBurst: 143 send(msg) 144 case <-ticker.C: 145 count := a.Bus.GetViewerCount(repoDID) 146 bs, err := json.Marshal(streamplace.Livestream_ViewerCount{Count: int64(count), LexiconTypeID: "place.stream.livestream#viewerCount"}) 147 if err != nil { 148 log.Error(ctx, "could not marshal view count", "error", err) 149 continue 150 } 151 err = conn.WriteMessage(websocket.TextMessage, bs) 152 if err != nil { 153 log.Error(ctx, "could not write ping message", "error", err) 154 return 155 } 156 case <-pingTicker.C: 157 err := conn.WriteMessage(websocket.PingMessage, []byte{}) 158 if err != nil { 159 log.Error(ctx, "could not write ping message", "error", err) 160 return 161 } 162 case <-ctx.Done(): 163 log.Debug(ctx, "context done, stopping websocket sender") 164 return 165 } 166 } 167 }() 168 169 go func() { 170 profile, err := a.Model.GetRepo(repoDID) 171 if err != nil { 172 log.Error(ctx, "could not get profile", "error", err) 173 return 174 } 175 if profile != nil { 176 p := map[string]any{ 177 "$type": "app.bsky.actor.defs#profileViewBasic", 178 "did": repoDID, 179 "handle": profile.Handle, 180 } 181 initialBurst <- p 182 } 183 }() 184 185 go func() { 186 seg, err := a.LocalDB.LatestSegmentForUser(repoDID) 187 if err != nil { 188 log.Error(ctx, "could not get replies", "error", err) 189 return 190 } 191 spSeg, err := seg.ToStreamplaceSegment() 192 if err != nil { 193 log.Error(ctx, "could not convert segment to streamplace segment", "error", err) 194 return 195 } 196 initialBurst <- spSeg 197 if a.CLI.LivepeerGatewayURL != "" { 198 renditions, err := renditions.GenerateRenditions(spSeg) 199 if err != nil { 200 log.Error(ctx, "could not generate renditions", "error", err) 201 return 202 } 203 outRs := streamplace.Defs_Renditions{ 204 LexiconTypeID: "place.stream.defs#renditions", 205 } 206 outRs.Renditions = []*streamplace.Defs_Rendition{} 207 for _, r := range renditions { 208 outRs.Renditions = append(outRs.Renditions, &streamplace.Defs_Rendition{ 209 LexiconTypeID: "place.stream.defs#rendition", 210 Name: r.Name, 211 }) 212 } 213 initialBurst <- outRs 214 } 215 }() 216 217 go func() { 218 ls, err := a.Model.GetLatestLivestreamForRepo(repoDID) 219 if err != nil { 220 log.Error(ctx, "could not get latest livestream", "error", err) 221 return 222 } 223 lsv, err := ls.ToLivestreamView() 224 if err != nil { 225 log.Error(ctx, "could not marshal livestream", "error", err) 226 return 227 } 228 initialBurst <- lsv 229 }() 230 231 go func() { 232 count := a.Bus.GetViewerCount(repoDID) 233 initialBurst <- streamplace.Livestream_ViewerCount{Count: int64(count), LexiconTypeID: "place.stream.livestream#viewerCount"} 234 }() 235 236 go func() { 237 messages, err := a.Model.MostRecentChatMessages(repoDID) 238 if err != nil { 239 log.Error(ctx, "could not get chat messages", "error", err) 240 return 241 } 242 243 // Add mod badges to messages 244 issuerDID := fmt.Sprintf("did:web:%s", a.CLI.BroadcasterHost) 245 for _, message := range messages { 246 err := atproto.AddModBadgeIfApplicable(ctx, message, repoDID, issuerDID, a.Model) 247 if err != nil { 248 log.Error(ctx, "failed to add mod badge to message", "error", err) 249 } 250 initialBurst <- message 251 } 252 }() 253 254 go func() { 255 teleports, err := a.Model.GetActiveTeleportsToRepo(repoDID) 256 if err != nil { 257 log.Error(ctx, "could not get active teleports", "error", err) 258 return 259 } 260 // just send the latest one if it started <3m ago 261 if len(teleports) > 0 && teleports[0].StartsAt.After(time.Now().Add(-3*time.Minute)) { 262 tp := teleports[0] 263 if tp.Repo == nil { 264 log.Error(ctx, "teleportee repo is nil", "uri", tp.URI) 265 } 266 viewerCount := a.Bus.GetViewerCount(tp.RepoDID) 267 arrivalMsg := streamplace.Livestream_TeleportArrival{ 268 LexiconTypeID: "place.stream.livestream#teleportArrival", 269 TeleportUri: tp.URI, 270 Source: &bsky.ActorDefs_ProfileViewBasic{ 271 Did: tp.RepoDID, 272 Handle: tp.Repo.Handle, 273 }, 274 ViewerCount: int64(viewerCount), 275 StartsAt: tp.StartsAt.Format(time.RFC3339), 276 } 277 278 // get the source chat profile 279 chatProfile, err := a.Model.GetChatProfile(ctx, tp.RepoDID) 280 if err == nil && chatProfile != nil { 281 spcp, err := chatProfile.ToStreamplaceChatProfile() 282 if err == nil { 283 arrivalMsg.ChatProfile = spcp 284 } 285 } 286 287 initialBurst <- arrivalMsg 288 } 289 }() 290 291 for { 292 messageType, message, err := conn.ReadMessage() 293 if err != nil { 294 log.Error(ctx, "error reading message", "error", err) 295 break 296 } 297 log.Log(ctx, "received message", "messageType", messageType, "message", string(message)) 298 } 299 } 300}