Live video on the AT Protocol
At branch natb/urfave · 291 lines · 7.9 kB · view raw
1package api 2 3import ( 4 "context" 5 "encoding/json" 6 "net" 7 "net/http" 8 "time" 9 10 bsky "github.com/bluesky-social/indigo/api/bsky" 11 "github.com/google/uuid" 12 "github.com/gorilla/websocket" 13 "github.com/julienschmidt/httprouter" 14 15 apierrors "stream.place/streamplace/pkg/errors" 16 "stream.place/streamplace/pkg/log" 17 "stream.place/streamplace/pkg/renditions" 18 "stream.place/streamplace/pkg/spmetrics" 19 "stream.place/streamplace/pkg/streamplace" 20) 21 22// todo: does this mean a whole message has to fit within the buffer? 23var upgrader = websocket.Upgrader{ 24 ReadBufferSize: 1024, 25 WriteBufferSize: 1024, 26 CheckOrigin: func(r *http.Request) bool { 27 return true 28 }, 29} 30 31var pingPeriod = 5 * time.Second 32 33func (a *StreamplaceAPI) HandleWebsocket(ctx context.Context) httprouter.Handle { 34 ctx = log.WithLogValues(ctx, "func", "HandleWebsocket") 35 return func(w http.ResponseWriter, req *http.Request, params httprouter.Params) { 36 ip, _, err := net.SplitHostPort(req.RemoteAddr) 37 if err != nil { 38 ip = req.RemoteAddr 39 } 40 41 if a.CLI.RateLimitWebsocket > 0 { 42 if !a.connTracker.AddConnection(ip) { 43 log.Warn(ctx, "rate limit exceeded", "ip", ip, "path", req.URL.Path) 44 apierrors.WriteHTTPTooManyRequests(w, "rate limit exceeded") 45 return 46 } 47 48 defer a.connTracker.RemoveConnection(ip) 49 } 50 51 uu, _ := uuid.NewV7() 52 connID := uu.String() 53 54 ctx = log.WithLogValues(ctx, "uuid", connID, "remoteAddr", req.RemoteAddr, "url", req.URL.String()) 55 log.Log(ctx, "websocket opened") 56 spmetrics.WebsocketsOpen.Inc() 57 defer spmetrics.WebsocketsOpen.Dec() 58 user := params.ByName("repoDID") 59 if user == "" { 60 apierrors.WriteHTTPBadRequest(w, "user required", nil) 61 return 62 } 63 repoDID, err := a.NormalizeUser(ctx, user) 64 if err != nil { 65 apierrors.WriteHTTPNotFound(w, "user not found", err) 66 return 67 } 68 conn, err := upgrader.Upgrade(w, req, nil) 69 if err != nil { 70 
apierrors.WriteHTTPInternalServerError(w, "could not upgrade to websocket", err) 71 return 72 } 73 ctx, cancel := context.WithCancel(ctx) 74 defer cancel() 75 defer conn.Close() 76 77 initialBurst := make(chan any, 200) 78 err = conn.SetReadDeadline(time.Now().Add(30 * time.Second)) 79 if err != nil { 80 log.Error(ctx, "could not set read deadline", "error", err) 81 return 82 } 83 84 pongCh := make(chan struct{}) 85 86 go func() { 87 for { 88 select { 89 case <-ctx.Done(): 90 return 91 case <-pongCh: 92 err := conn.SetReadDeadline(time.Now().Add(30 * time.Second)) 93 if err != nil { 94 log.Error(ctx, "could not set read deadline", "error", err) 95 return 96 } 97 case <-time.After(30 * time.Second): 98 log.Log(ctx, "websocket timeout, closing connection") 99 // timeout! 100 conn.Close() 101 cancel() 102 return 103 } 104 } 105 }() 106 107 conn.SetPongHandler(func(appData string) error { 108 log.Debug(ctx, "received pong", "appData", appData) 109 pongCh <- struct{}{} 110 return nil 111 }) 112 go func() { 113 114 ch := a.Bus.Subscribe(repoDID) 115 defer a.Bus.Unsubscribe(repoDID, ch) 116 // Create a ticker that fires every 3 seconds 117 ticker := time.NewTicker(3 * time.Second) 118 pingTicker := time.NewTicker(pingPeriod) 119 defer ticker.Stop() 120 defer pingTicker.Stop() 121 122 send := func(msg any) { 123 bs, err := json.Marshal(msg) 124 if err != nil { 125 log.Error(ctx, "could not marshal message", "error", err) 126 return 127 } 128 log.Debug(ctx, "sending message", "message", string(bs)) 129 err = conn.WriteMessage(websocket.TextMessage, bs) 130 if err != nil { 131 log.Error(ctx, "could not write message", "error", err) 132 return 133 } 134 } 135 136 for { 137 select { 138 case msg := <-ch: 139 send(msg) 140 case msg := <-initialBurst: 141 send(msg) 142 case <-ticker.C: 143 count := a.Bus.GetViewerCount(repoDID) 144 bs, err := json.Marshal(streamplace.Livestream_ViewerCount{Count: int64(count), LexiconTypeID: "place.stream.livestream#viewerCount"}) 145 if err != 
nil { 146 log.Error(ctx, "could not marshal view count", "error", err) 147 continue 148 } 149 err = conn.WriteMessage(websocket.TextMessage, bs) 150 if err != nil { 151 log.Error(ctx, "could not write ping message", "error", err) 152 return 153 } 154 case <-pingTicker.C: 155 err := conn.WriteMessage(websocket.PingMessage, []byte{}) 156 if err != nil { 157 log.Error(ctx, "could not write ping message", "error", err) 158 return 159 } 160 case <-ctx.Done(): 161 log.Debug(ctx, "context done, stopping websocket sender") 162 return 163 } 164 } 165 }() 166 167 go func() { 168 profile, err := a.Model.GetRepo(repoDID) 169 if err != nil { 170 log.Error(ctx, "could not get profile", "error", err) 171 return 172 } 173 if profile != nil { 174 p := map[string]any{ 175 "$type": "app.bsky.actor.defs#profileViewBasic", 176 "did": repoDID, 177 "handle": profile.Handle, 178 } 179 initialBurst <- p 180 } 181 }() 182 183 go func() { 184 seg, err := a.LocalDB.LatestSegmentForUser(repoDID) 185 if err != nil { 186 log.Error(ctx, "could not get replies", "error", err) 187 return 188 } 189 spSeg, err := seg.ToStreamplaceSegment() 190 if err != nil { 191 log.Error(ctx, "could not convert segment to streamplace segment", "error", err) 192 return 193 } 194 initialBurst <- spSeg 195 if a.CLI.LivepeerGatewayURL != "" { 196 renditions, err := renditions.GenerateRenditions(spSeg) 197 if err != nil { 198 log.Error(ctx, "could not generate renditions", "error", err) 199 return 200 } 201 outRs := streamplace.Defs_Renditions{ 202 LexiconTypeID: "place.stream.defs#renditions", 203 } 204 outRs.Renditions = []*streamplace.Defs_Rendition{} 205 for _, r := range renditions { 206 outRs.Renditions = append(outRs.Renditions, &streamplace.Defs_Rendition{ 207 LexiconTypeID: "place.stream.defs#rendition", 208 Name: r.Name, 209 }) 210 } 211 initialBurst <- outRs 212 } 213 }() 214 215 go func() { 216 ls, err := a.Model.GetLatestLivestreamForRepo(repoDID) 217 if err != nil { 218 log.Error(ctx, "could not get latest 
livestream", "error", err) 219 return 220 } 221 lsv, err := ls.ToLivestreamView() 222 if err != nil { 223 log.Error(ctx, "could not marshal livestream", "error", err) 224 return 225 } 226 initialBurst <- lsv 227 }() 228 229 go func() { 230 count := a.Bus.GetViewerCount(repoDID) 231 initialBurst <- streamplace.Livestream_ViewerCount{Count: int64(count), LexiconTypeID: "place.stream.livestream#viewerCount"} 232 }() 233 234 go func() { 235 messages, err := a.Model.MostRecentChatMessages(repoDID) 236 if err != nil { 237 log.Error(ctx, "could not get chat messages", "error", err) 238 return 239 } 240 for _, message := range messages { 241 initialBurst <- message 242 } 243 }() 244 245 go func() { 246 teleports, err := a.Model.GetActiveTeleportsToRepo(repoDID) 247 if err != nil { 248 log.Error(ctx, "could not get active teleports", "error", err) 249 return 250 } 251 // just send the latest one if it started <3m ago 252 if len(teleports) > 0 && teleports[0].StartsAt.After(time.Now().Add(-3*time.Minute)) { 253 tp := teleports[0] 254 if tp.Repo == nil { 255 log.Error(ctx, "teleportee repo is nil", "uri", tp.URI) 256 } 257 viewerCount := a.Bus.GetViewerCount(tp.RepoDID) 258 arrivalMsg := streamplace.Livestream_TeleportArrival{ 259 LexiconTypeID: "place.stream.livestream#teleportArrival", 260 TeleportUri: tp.URI, 261 Source: &bsky.ActorDefs_ProfileViewBasic{ 262 Did: tp.RepoDID, 263 Handle: tp.Repo.Handle, 264 }, 265 ViewerCount: int64(viewerCount), 266 StartsAt: tp.StartsAt.Format(time.RFC3339), 267 } 268 269 // get the source chat profile 270 chatProfile, err := a.Model.GetChatProfile(ctx, tp.RepoDID) 271 if err == nil && chatProfile != nil { 272 spcp, err := chatProfile.ToStreamplaceChatProfile() 273 if err == nil { 274 arrivalMsg.ChatProfile = spcp 275 } 276 } 277 278 initialBurst <- arrivalMsg 279 } 280 }() 281 282 for { 283 messageType, message, err := conn.ReadMessage() 284 if err != nil { 285 log.Error(ctx, "error reading message", "error", err) 286 break 287 } 288 
log.Log(ctx, "received message", "messageType", messageType, "message", string(message)) 289 } 290 } 291}