Live video on the AT Protocol
at natb/loading-overlay 253 lines 6.7 kB view raw
1package api 2 3import ( 4 "context" 5 "encoding/json" 6 "net" 7 "net/http" 8 "time" 9 10 "github.com/google/uuid" 11 "github.com/gorilla/websocket" 12 "github.com/julienschmidt/httprouter" 13 14 apierrors "stream.place/streamplace/pkg/errors" 15 "stream.place/streamplace/pkg/log" 16 "stream.place/streamplace/pkg/renditions" 17 "stream.place/streamplace/pkg/spmetrics" 18 "stream.place/streamplace/pkg/streamplace" 19) 20 21// todo: does this mean a whole message has to fit within the buffer? 22var upgrader = websocket.Upgrader{ 23 ReadBufferSize: 1024, 24 WriteBufferSize: 1024, 25 CheckOrigin: func(r *http.Request) bool { 26 return true 27 }, 28} 29 30var pingPeriod = 5 * time.Second 31 32func (a *StreamplaceAPI) HandleWebsocket(ctx context.Context) httprouter.Handle { 33 ctx = log.WithLogValues(ctx, "func", "HandleWebsocket") 34 return func(w http.ResponseWriter, req *http.Request, params httprouter.Params) { 35 ip, _, err := net.SplitHostPort(req.RemoteAddr) 36 if err != nil { 37 ip = req.RemoteAddr 38 } 39 40 if a.CLI.RateLimitWebsocket > 0 { 41 if !a.connTracker.AddConnection(ip) { 42 log.Warn(ctx, "rate limit exceeded", "ip", ip, "path", req.URL.Path) 43 apierrors.WriteHTTPTooManyRequests(w, "rate limit exceeded") 44 return 45 } 46 47 defer a.connTracker.RemoveConnection(ip) 48 } 49 50 uu, _ := uuid.NewV7() 51 connID := uu.String() 52 53 ctx = log.WithLogValues(ctx, "uuid", connID, "remoteAddr", req.RemoteAddr, "url", req.URL.String()) 54 log.Log(ctx, "websocket opened") 55 spmetrics.WebsocketsOpen.Inc() 56 defer spmetrics.WebsocketsOpen.Dec() 57 user := params.ByName("repoDID") 58 if user == "" { 59 apierrors.WriteHTTPBadRequest(w, "user required", nil) 60 return 61 } 62 repoDID, err := a.NormalizeUser(ctx, user) 63 if err != nil { 64 apierrors.WriteHTTPNotFound(w, "user not found", err) 65 return 66 } 67 conn, err := upgrader.Upgrade(w, req, nil) 68 if err != nil { 69 apierrors.WriteHTTPInternalServerError(w, "could not upgrade to websocket", err) 70 return 71 } 72 ctx, cancel := context.WithCancel(ctx) 73 defer cancel() 74 defer conn.Close() 75 76 initialBurst := make(chan any, 200) 77 err = conn.SetReadDeadline(time.Now().Add(30 * time.Second)) 78 if err != nil { 79 log.Error(ctx, "could not set read deadline", "error", err) 80 return 81 } 82 83 pongCh := make(chan struct{}) 84 85 go func() { 86 for { 87 select { 88 case <-ctx.Done(): 89 return 90 case <-pongCh: 91 err := conn.SetReadDeadline(time.Now().Add(30 * time.Second)) 92 if err != nil { 93 log.Error(ctx, "could not set read deadline", "error", err) 94 return 95 } 96 case <-time.After(30 * time.Second): 97 log.Log(ctx, "websocket timeout, closing connection") 98 // timeout! 99 conn.Close() 100 cancel() 101 return 102 } 103 } 104 }() 105 106 conn.SetPongHandler(func(appData string) error { 107 log.Debug(ctx, "received pong", "appData", appData) 108 pongCh <- struct{}{} 109 return nil 110 }) 111 go func() { 112 113 ch := a.Bus.Subscribe(repoDID) 114 defer a.Bus.Unsubscribe(repoDID, ch) 115 // Create a ticker that fires every 3 seconds 116 ticker := time.NewTicker(3 * time.Second) 117 pingTicker := time.NewTicker(pingPeriod) 118 defer ticker.Stop() 119 defer pingTicker.Stop() 120 121 send := func(msg any) { 122 bs, err := json.Marshal(msg) 123 if err != nil { 124 log.Error(ctx, "could not marshal message", "error", err) 125 return 126 } 127 log.Debug(ctx, "sending message", "message", string(bs)) 128 err = conn.WriteMessage(websocket.TextMessage, bs) 129 if err != nil { 130 log.Error(ctx, "could not write message", "error", err) 131 return 132 } 133 } 134 135 for { 136 select { 137 case msg := <-ch: 138 send(msg) 139 case msg := <-initialBurst: 140 send(msg) 141 case <-ticker.C: 142 count := spmetrics.GetViewCount(repoDID) 143 bs, err := json.Marshal(streamplace.Livestream_ViewerCount{Count: int64(count), LexiconTypeID: "place.stream.livestream#viewerCount"}) 144 if err != nil { 145 log.Error(ctx, "could not marshal view count", "error", err) 146 continue 147 } 148 err = conn.WriteMessage(websocket.TextMessage, bs) 149 if err != nil { 150 log.Error(ctx, "could not write ping message", "error", err) 151 return 152 } 153 case <-pingTicker.C: 154 err := conn.WriteMessage(websocket.PingMessage, []byte{}) 155 if err != nil { 156 log.Error(ctx, "could not write ping message", "error", err) 157 return 158 } 159 case <-ctx.Done(): 160 log.Debug(ctx, "context done, stopping websocket sender") 161 return 162 } 163 } 164 }() 165 166 go func() { 167 profile, err := a.Model.GetRepo(repoDID) 168 if err != nil { 169 log.Error(ctx, "could not get profile", "error", err) 170 return 171 } 172 if profile != nil { 173 p := map[string]any{ 174 "$type": "app.bsky.actor.defs#profileViewBasic", 175 "did": repoDID, 176 "handle": profile.Handle, 177 } 178 initialBurst <- p 179 } 180 }() 181 182 go func() { 183 seg, err := a.Model.LatestSegmentForUser(repoDID) 184 if err != nil { 185 log.Error(ctx, "could not get replies", "error", err) 186 return 187 } 188 spSeg, err := seg.ToStreamplaceSegment() 189 if err != nil { 190 log.Error(ctx, "could not convert segment to streamplace segment", "error", err) 191 return 192 } 193 initialBurst <- spSeg 194 if a.CLI.LivepeerGatewayURL != "" { 195 renditions, err := renditions.GenerateRenditions(spSeg) 196 if err != nil { 197 log.Error(ctx, "could not generate renditions", "error", err) 198 return 199 } 200 outRs := streamplace.Defs_Renditions{ 201 LexiconTypeID: "place.stream.defs#renditions", 202 } 203 outRs.Renditions = []*streamplace.Defs_Rendition{} 204 for _, r := range renditions { 205 outRs.Renditions = append(outRs.Renditions, &streamplace.Defs_Rendition{ 206 LexiconTypeID: "place.stream.defs#rendition", 207 Name: r.Name, 208 }) 209 } 210 initialBurst <- outRs 211 } 212 }() 213 214 go func() { 215 ls, err := a.Model.GetLatestLivestreamForRepo(repoDID) 216 if err != nil { 217 log.Error(ctx, "could not get latest livestream", "error", err) 218 return 219 } 220 lsv, err := ls.ToLivestreamView() 221 if err != nil { 222 log.Error(ctx, "could not marshal livestream", "error", err) 223 return 224 } 225 initialBurst <- lsv 226 }() 227 228 go func() { 229 count := spmetrics.GetViewCount(repoDID) 230 initialBurst <- streamplace.Livestream_ViewerCount{Count: int64(count), LexiconTypeID: "place.stream.livestream#viewerCount"} 231 }() 232 233 go func() { 234 messages, err := a.Model.MostRecentChatMessages(repoDID) 235 if err != nil { 236 log.Error(ctx, "could not get chat messages", "error", err) 237 return 238 } 239 for _, message := range messages { 240 initialBurst <- message 241 } 242 }() 243 244 for { 245 messageType, message, err := conn.ReadMessage() 246 if err != nil { 247 log.Error(ctx, "error reading message", "error", err) 248 break 249 } 250 log.Log(ctx, "received message", "messageType", messageType, "message", string(message)) 251 } 252 } 253}