Live video on the AT Protocol
1package api
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "net"
8 "net/http"
9 "time"
10
11 bsky "github.com/bluesky-social/indigo/api/bsky"
12 "github.com/google/uuid"
13 "github.com/gorilla/websocket"
14 "github.com/julienschmidt/httprouter"
15
16 "stream.place/streamplace/pkg/atproto"
17 apierrors "stream.place/streamplace/pkg/errors"
18 "stream.place/streamplace/pkg/log"
19 "stream.place/streamplace/pkg/renditions"
20 "stream.place/streamplace/pkg/spmetrics"
21 "stream.place/streamplace/pkg/streamplace"
22)
23
24// todo: does this mean a whole message has to fit within the buffer?
25var upgrader = websocket.Upgrader{
26 ReadBufferSize: 1024,
27 WriteBufferSize: 1024,
28 CheckOrigin: func(r *http.Request) bool {
29 return true
30 },
31}
32
// pingPeriod is how often HandleWebsocket sends a websocket ping to the client.
var pingPeriod = time.Duration(5) * time.Second
34
35func (a *StreamplaceAPI) HandleWebsocket(ctx context.Context) httprouter.Handle {
36 ctx = log.WithLogValues(ctx, "func", "HandleWebsocket")
37 return func(w http.ResponseWriter, req *http.Request, params httprouter.Params) {
38 ip, _, err := net.SplitHostPort(req.RemoteAddr)
39 if err != nil {
40 ip = req.RemoteAddr
41 }
42
43 if a.CLI.RateLimitWebsocket > 0 {
44 if !a.connTracker.AddConnection(ip) {
45 log.Warn(ctx, "rate limit exceeded", "ip", ip, "path", req.URL.Path)
46 apierrors.WriteHTTPTooManyRequests(w, "rate limit exceeded")
47 return
48 }
49
50 defer a.connTracker.RemoveConnection(ip)
51 }
52
53 uu, _ := uuid.NewV7()
54 connID := uu.String()
55
56 ctx = log.WithLogValues(ctx, "uuid", connID, "remoteAddr", req.RemoteAddr, "url", req.URL.String())
57 log.Log(ctx, "websocket opened")
58 spmetrics.WebsocketsOpen.Inc()
59 defer spmetrics.WebsocketsOpen.Dec()
60 user := params.ByName("repoDID")
61 if user == "" {
62 apierrors.WriteHTTPBadRequest(w, "user required", nil)
63 return
64 }
65 repoDID, err := a.NormalizeUser(ctx, user)
66 if err != nil {
67 apierrors.WriteHTTPNotFound(w, "user not found", err)
68 return
69 }
70 conn, err := upgrader.Upgrade(w, req, nil)
71 if err != nil {
72 apierrors.WriteHTTPInternalServerError(w, "could not upgrade to websocket", err)
73 return
74 }
75 ctx, cancel := context.WithCancel(ctx)
76 defer cancel()
77 defer conn.Close()
78
79 initialBurst := make(chan any, 200)
80 err = conn.SetReadDeadline(time.Now().Add(30 * time.Second))
81 if err != nil {
82 log.Error(ctx, "could not set read deadline", "error", err)
83 return
84 }
85
86 pongCh := make(chan struct{})
87
88 go func() {
89 for {
90 select {
91 case <-ctx.Done():
92 return
93 case <-pongCh:
94 err := conn.SetReadDeadline(time.Now().Add(30 * time.Second))
95 if err != nil {
96 log.Error(ctx, "could not set read deadline", "error", err)
97 return
98 }
99 case <-time.After(30 * time.Second):
100 log.Log(ctx, "websocket timeout, closing connection")
101 // timeout!
102 conn.Close()
103 cancel()
104 return
105 }
106 }
107 }()
108
109 conn.SetPongHandler(func(appData string) error {
110 log.Debug(ctx, "received pong", "appData", appData)
111 pongCh <- struct{}{}
112 return nil
113 })
114 go func() {
115
116 ch := a.Bus.Subscribe(repoDID)
117 defer a.Bus.Unsubscribe(repoDID, ch)
118 // Create a ticker that fires every 3 seconds
119 ticker := time.NewTicker(3 * time.Second)
120 pingTicker := time.NewTicker(pingPeriod)
121 defer ticker.Stop()
122 defer pingTicker.Stop()
123
124 send := func(msg any) {
125 bs, err := json.Marshal(msg)
126 if err != nil {
127 log.Error(ctx, "could not marshal message", "error", err)
128 return
129 }
130 log.Debug(ctx, "sending message", "message", string(bs))
131 err = conn.WriteMessage(websocket.TextMessage, bs)
132 if err != nil {
133 log.Error(ctx, "could not write message", "error", err)
134 return
135 }
136 }
137
138 for {
139 select {
140 case msg := <-ch:
141 send(msg)
142 case msg := <-initialBurst:
143 send(msg)
144 case <-ticker.C:
145 count := a.Bus.GetViewerCount(repoDID)
146 bs, err := json.Marshal(streamplace.Livestream_ViewerCount{Count: int64(count), LexiconTypeID: "place.stream.livestream#viewerCount"})
147 if err != nil {
148 log.Error(ctx, "could not marshal view count", "error", err)
149 continue
150 }
151 err = conn.WriteMessage(websocket.TextMessage, bs)
152 if err != nil {
153 log.Error(ctx, "could not write ping message", "error", err)
154 return
155 }
156 case <-pingTicker.C:
157 err := conn.WriteMessage(websocket.PingMessage, []byte{})
158 if err != nil {
159 log.Error(ctx, "could not write ping message", "error", err)
160 return
161 }
162 case <-ctx.Done():
163 log.Debug(ctx, "context done, stopping websocket sender")
164 return
165 }
166 }
167 }()
168
169 go func() {
170 profile, err := a.Model.GetRepo(repoDID)
171 if err != nil {
172 log.Error(ctx, "could not get profile", "error", err)
173 return
174 }
175 if profile != nil {
176 p := map[string]any{
177 "$type": "app.bsky.actor.defs#profileViewBasic",
178 "did": repoDID,
179 "handle": profile.Handle,
180 }
181 initialBurst <- p
182 }
183 }()
184
185 go func() {
186 seg, err := a.LocalDB.LatestSegmentForUser(repoDID)
187 if err != nil {
188 log.Error(ctx, "could not get replies", "error", err)
189 return
190 }
191 spSeg, err := seg.ToStreamplaceSegment()
192 if err != nil {
193 log.Error(ctx, "could not convert segment to streamplace segment", "error", err)
194 return
195 }
196 initialBurst <- spSeg
197 outRs := streamplace.Defs_Renditions{
198 LexiconTypeID: "place.stream.defs#renditions",
199 Renditions: []*streamplace.Defs_Rendition{},
200 }
201 if a.CLI.LivepeerGatewayURL != "" {
202 videoRenditions, err := renditions.GenerateRenditions(spSeg)
203 if err != nil {
204 log.Error(ctx, "could not generate renditions", "error", err)
205 return
206 }
207 for _, r := range videoRenditions {
208 outRs.Renditions = append(outRs.Renditions, &streamplace.Defs_Rendition{
209 LexiconTypeID: "place.stream.defs#rendition",
210 Name: r.Name,
211 })
212 }
213 }
214 outRs.Renditions = append(outRs.Renditions, &streamplace.Defs_Rendition{
215 LexiconTypeID: "place.stream.defs#rendition",
216 Name: renditions.AudioRendition.Name,
217 })
218 initialBurst <- outRs
219 }()
220
221 go func() {
222 ls, err := a.Model.GetLatestLivestreamForRepo(repoDID)
223 if err != nil {
224 log.Error(ctx, "could not get latest livestream", "error", err)
225 return
226 }
227 if ls == nil {
228 log.Error(ctx, "no livestream found", "repoDID", repoDID)
229 return
230 }
231 lsv, err := ls.ToLivestreamView()
232 if err != nil {
233 log.Error(ctx, "could not marshal livestream", "error", err)
234 return
235 }
236 initialBurst <- lsv
237 }()
238
239 go func() {
240 count := a.Bus.GetViewerCount(repoDID)
241 initialBurst <- streamplace.Livestream_ViewerCount{Count: int64(count), LexiconTypeID: "place.stream.livestream#viewerCount"}
242 }()
243
244 go func() {
245 messages, err := a.Model.MostRecentChatMessages(repoDID)
246 if err != nil {
247 log.Error(ctx, "could not get chat messages", "error", err)
248 return
249 }
250
251 // Add mod badges to messages
252 issuerDID := fmt.Sprintf("did:web:%s", a.CLI.BroadcasterHost)
253 for _, message := range messages {
254 err := atproto.AddModBadgeIfApplicable(ctx, message, repoDID, issuerDID, a.Model)
255 if err != nil {
256 log.Error(ctx, "failed to add mod badge to message", "error", err)
257 }
258 initialBurst <- message
259 }
260 }()
261
262 go func() {
263 teleports, err := a.Model.GetActiveTeleportsToRepo(repoDID)
264 if err != nil {
265 log.Error(ctx, "could not get active teleports", "error", err)
266 return
267 }
268 // just send the latest one if it started <3m ago
269 if len(teleports) > 0 && teleports[0].StartsAt.After(time.Now().Add(-3*time.Minute)) {
270 tp := teleports[0]
271 if tp.Repo == nil {
272 log.Error(ctx, "teleportee repo is nil", "uri", tp.URI)
273 return
274 }
275 viewerCount := a.Bus.GetViewerCount(tp.RepoDID)
276 arrivalMsg := streamplace.Livestream_TeleportArrival{
277 LexiconTypeID: "place.stream.livestream#teleportArrival",
278 TeleportUri: tp.URI,
279 Source: &bsky.ActorDefs_ProfileViewBasic{
280 Did: tp.RepoDID,
281 Handle: tp.Repo.Handle,
282 },
283 ViewerCount: int64(viewerCount),
284 StartsAt: tp.StartsAt.Format(time.RFC3339),
285 }
286
287 // get the source chat profile
288 chatProfile, err := a.Model.GetChatProfile(ctx, tp.RepoDID)
289 if err == nil && chatProfile != nil {
290 spcp, err := chatProfile.ToStreamplaceChatProfile()
291 if err == nil {
292 arrivalMsg.ChatProfile = spcp
293 }
294 }
295
296 initialBurst <- arrivalMsg
297 }
298 }()
299
300 for {
301 messageType, message, err := conn.ReadMessage()
302 if err != nil {
303 log.Error(ctx, "error reading message", "error", err)
304 break
305 }
306 log.Log(ctx, "received message", "messageType", messageType, "message", string(message))
307 }
308 }
309}