Live video on the AT Protocol
1package api
2
3import (
4 "context"
5 "encoding/json"
6 "net"
7 "net/http"
8 "time"
9
10 bsky "github.com/bluesky-social/indigo/api/bsky"
11 "github.com/google/uuid"
12 "github.com/gorilla/websocket"
13 "github.com/julienschmidt/httprouter"
14
15 "stream.place/streamplace/pkg/bus"
16 apierrors "stream.place/streamplace/pkg/errors"
17 "stream.place/streamplace/pkg/log"
18 "stream.place/streamplace/pkg/renditions"
19 "stream.place/streamplace/pkg/spmetrics"
20 "stream.place/streamplace/pkg/streamplace"
21)
22
23// todo: does this mean a whole message has to fit within the buffer?
24var upgrader = websocket.Upgrader{
25 ReadBufferSize: 1024,
26 WriteBufferSize: 1024,
27 CheckOrigin: func(r *http.Request) bool {
28 return true
29 },
30}
31
// pingPeriod is the interval of the ticker that drives websocket ping frames
// in HandleWebsocket's sender loop.
var pingPeriod = time.Second * 5
33
34func (a *StreamplaceAPI) HandleWebsocket(ctx context.Context) httprouter.Handle {
35 ctx = log.WithLogValues(ctx, "func", "HandleWebsocket")
36 return func(w http.ResponseWriter, req *http.Request, params httprouter.Params) {
37 ip, _, err := net.SplitHostPort(req.RemoteAddr)
38 if err != nil {
39 ip = req.RemoteAddr
40 }
41
42 if a.CLI.RateLimitWebsocket > 0 {
43 if !a.connTracker.AddConnection(ip) {
44 log.Warn(ctx, "rate limit exceeded", "ip", ip, "path", req.URL.Path)
45 apierrors.WriteHTTPTooManyRequests(w, "rate limit exceeded")
46 return
47 }
48
49 defer a.connTracker.RemoveConnection(ip)
50 }
51
52 uu, _ := uuid.NewV7()
53 connID := uu.String()
54
55 ctx = log.WithLogValues(ctx, "uuid", connID, "remoteAddr", req.RemoteAddr, "url", req.URL.String())
56 log.Log(ctx, "websocket opened")
57 spmetrics.WebsocketsOpen.Inc()
58 defer spmetrics.WebsocketsOpen.Dec()
59 user := params.ByName("repoDID")
60 if user == "" {
61 apierrors.WriteHTTPBadRequest(w, "user required", nil)
62 return
63 }
64 repoDID, err := a.NormalizeUser(ctx, user)
65 if err != nil {
66 apierrors.WriteHTTPNotFound(w, "user not found", err)
67 return
68 }
69 viewerDID := req.URL.Query().Get("viewer")
70 if viewerDID != "" {
71 if _, err := a.ATSync.SyncBlueskyRepoCached(ctx, viewerDID, a.Model); err != nil {
72 log.Error(ctx, "could not sync viewer repo", "error", err, "viewerDID", viewerDID)
73 }
74 }
75 conn, err := upgrader.Upgrade(w, req, nil)
76 if err != nil {
77 apierrors.WriteHTTPInternalServerError(w, "could not upgrade to websocket", err)
78 return
79 }
80 ctx, cancel := context.WithCancel(ctx)
81 defer cancel()
82 defer conn.Close()
83
84 initialBurst := make(chan any, 200)
85 err = conn.SetReadDeadline(time.Now().Add(30 * time.Second))
86 if err != nil {
87 log.Error(ctx, "could not set read deadline", "error", err)
88 return
89 }
90
91 pongCh := make(chan struct{})
92
93 go func() {
94 for {
95 select {
96 case <-ctx.Done():
97 return
98 case <-pongCh:
99 err := conn.SetReadDeadline(time.Now().Add(30 * time.Second))
100 if err != nil {
101 log.Error(ctx, "could not set read deadline", "error", err)
102 return
103 }
104 case <-time.After(30 * time.Second):
105 log.Log(ctx, "websocket timeout, closing connection")
106 // timeout!
107 conn.Close()
108 cancel()
109 return
110 }
111 }
112 }()
113
114 conn.SetPongHandler(func(appData string) error {
115 log.Debug(ctx, "received pong", "appData", appData)
116 pongCh <- struct{}{}
117 return nil
118 })
119 go func() {
120 blockedDIDs := make(map[string]bool)
121 if viewerDID != "" {
122 dids, err := a.Model.GetBlockedDIDs(ctx, viewerDID)
123 if err != nil {
124 log.Error(ctx, "could not get blocked DIDs", "error", err)
125 }
126 for _, did := range dids {
127 blockedDIDs[did] = true
128 }
129 }
130
131 ch := a.Bus.Subscribe(repoDID)
132 defer a.Bus.Unsubscribe(repoDID, ch)
133
134 var viewerCh <-chan bus.Message
135 if viewerDID != "" {
136 viewerCh = a.Bus.Subscribe(viewerDID)
137 defer a.Bus.Unsubscribe(viewerDID, viewerCh)
138 }
139
140 ticker := time.NewTicker(3 * time.Second)
141 pingTicker := time.NewTicker(pingPeriod)
142 defer ticker.Stop()
143 defer pingTicker.Stop()
144
145 send := func(msg any) {
146 bs, err := json.Marshal(msg)
147 if err != nil {
148 log.Error(ctx, "could not marshal message", "error", err)
149 return
150 }
151 log.Debug(ctx, "sending message", "message", string(bs))
152 err = conn.WriteMessage(websocket.TextMessage, bs)
153 if err != nil {
154 log.Error(ctx, "could not write message", "error", err)
155 return
156 }
157 }
158
159 isBlockedChat := func(msg any) bool {
160 if len(blockedDIDs) == 0 {
161 return false
162 }
163 if chatMsg, ok := msg.(*streamplace.ChatDefs_MessageView); ok {
164 return chatMsg.Author != nil && blockedDIDs[chatMsg.Author.Did]
165 }
166 return false
167 }
168
169 for {
170 select {
171 case msg := <-ch:
172 if isBlockedChat(msg) {
173 continue
174 }
175 send(msg)
176 case msg := <-viewerCh:
177 if blockView, ok := msg.(*streamplace.Defs_BlockView); ok {
178 if blockView.Blocker != nil && blockView.Blocker.Did == viewerDID && blockView.Record != nil {
179 blockedDIDs[blockView.Record.Subject] = true
180 send(blockView)
181 }
182 }
183 case msg := <-initialBurst:
184 send(msg)
185 case <-ticker.C:
186 count := a.Bus.GetViewerCount(repoDID)
187 bs, err := json.Marshal(streamplace.Livestream_ViewerCount{Count: int64(count), LexiconTypeID: "place.stream.livestream#viewerCount"})
188 if err != nil {
189 log.Error(ctx, "could not marshal view count", "error", err)
190 continue
191 }
192 err = conn.WriteMessage(websocket.TextMessage, bs)
193 if err != nil {
194 log.Error(ctx, "could not write ping message", "error", err)
195 return
196 }
197 case <-pingTicker.C:
198 err := conn.WriteMessage(websocket.PingMessage, []byte{})
199 if err != nil {
200 log.Error(ctx, "could not write ping message", "error", err)
201 return
202 }
203 case <-ctx.Done():
204 log.Debug(ctx, "context done, stopping websocket sender")
205 return
206 }
207 }
208 }()
209
210 go func() {
211 profile, err := a.Model.GetRepo(repoDID)
212 if err != nil {
213 log.Error(ctx, "could not get profile", "error", err)
214 return
215 }
216 if profile != nil {
217 p := map[string]any{
218 "$type": "app.bsky.actor.defs#profileViewBasic",
219 "did": repoDID,
220 "handle": profile.Handle,
221 }
222 initialBurst <- p
223 }
224 }()
225
226 go func() {
227 seg, err := a.LocalDB.LatestSegmentForUser(repoDID)
228 if err != nil {
229 log.Error(ctx, "could not get replies", "error", err)
230 return
231 }
232 spSeg, err := seg.ToStreamplaceSegment()
233 if err != nil {
234 log.Error(ctx, "could not convert segment to streamplace segment", "error", err)
235 return
236 }
237 initialBurst <- spSeg
238 if a.CLI.LivepeerGatewayURL != "" {
239 renditions, err := renditions.GenerateRenditions(spSeg)
240 if err != nil {
241 log.Error(ctx, "could not generate renditions", "error", err)
242 return
243 }
244 outRs := streamplace.Defs_Renditions{
245 LexiconTypeID: "place.stream.defs#renditions",
246 }
247 outRs.Renditions = []*streamplace.Defs_Rendition{}
248 for _, r := range renditions {
249 outRs.Renditions = append(outRs.Renditions, &streamplace.Defs_Rendition{
250 LexiconTypeID: "place.stream.defs#rendition",
251 Name: r.Name,
252 })
253 }
254 initialBurst <- outRs
255 }
256 }()
257
258 go func() {
259 ls, err := a.Model.GetLatestLivestreamForRepo(repoDID)
260 if err != nil {
261 log.Error(ctx, "could not get latest livestream", "error", err)
262 return
263 }
264 lsv, err := ls.ToLivestreamView()
265 if err != nil {
266 log.Error(ctx, "could not marshal livestream", "error", err)
267 return
268 }
269 initialBurst <- lsv
270 }()
271
272 go func() {
273 count := a.Bus.GetViewerCount(repoDID)
274 initialBurst <- streamplace.Livestream_ViewerCount{Count: int64(count), LexiconTypeID: "place.stream.livestream#viewerCount"}
275 }()
276
277 go func() {
278 messages, err := a.Model.MostRecentChatMessagesForViewer(repoDID, viewerDID)
279 if err != nil {
280 log.Error(ctx, "could not get chat messages", "error", err)
281 return
282 }
283 for _, message := range messages {
284 initialBurst <- message
285 }
286 }()
287
288 go func() {
289 teleports, err := a.Model.GetActiveTeleportsToRepo(repoDID)
290 if err != nil {
291 log.Error(ctx, "could not get active teleports", "error", err)
292 return
293 }
294 // just send the latest one if it started <3m ago
295 if len(teleports) > 0 && teleports[0].StartsAt.After(time.Now().Add(-3*time.Minute)) {
296 tp := teleports[0]
297 if tp.Repo == nil {
298 log.Error(ctx, "teleportee repo is nil", "uri", tp.URI)
299 }
300 viewerCount := a.Bus.GetViewerCount(tp.RepoDID)
301 arrivalMsg := streamplace.Livestream_TeleportArrival{
302 LexiconTypeID: "place.stream.livestream#teleportArrival",
303 TeleportUri: tp.URI,
304 Source: &bsky.ActorDefs_ProfileViewBasic{
305 Did: tp.RepoDID,
306 Handle: tp.Repo.Handle,
307 },
308 ViewerCount: int64(viewerCount),
309 StartsAt: tp.StartsAt.Format(time.RFC3339),
310 }
311
312 // get the source chat profile
313 chatProfile, err := a.Model.GetChatProfile(ctx, tp.RepoDID)
314 if err == nil && chatProfile != nil {
315 spcp, err := chatProfile.ToStreamplaceChatProfile()
316 if err == nil {
317 arrivalMsg.ChatProfile = spcp
318 }
319 }
320
321 initialBurst <- arrivalMsg
322 }
323 }()
324
325 for {
326 messageType, message, err := conn.ReadMessage()
327 if err != nil {
328 log.Error(ctx, "error reading message", "error", err)
329 break
330 }
331 log.Log(ctx, "received message", "messageType", messageType, "message", string(message))
332 }
333 }
334}