Live video on the AT Protocol
1package api
2
3import (
4 "context"
5 "encoding/json"
6 "net"
7 "net/http"
8 "time"
9
10 bsky "github.com/bluesky-social/indigo/api/bsky"
11 "github.com/google/uuid"
12 "github.com/gorilla/websocket"
13 "github.com/julienschmidt/httprouter"
14
15 apierrors "stream.place/streamplace/pkg/errors"
16 "stream.place/streamplace/pkg/log"
17 "stream.place/streamplace/pkg/renditions"
18 "stream.place/streamplace/pkg/spmetrics"
19 "stream.place/streamplace/pkg/streamplace"
20)
21
22// todo: does this mean a whole message has to fit within the buffer?
23var upgrader = websocket.Upgrader{
24 ReadBufferSize: 1024,
25 WriteBufferSize: 1024,
26 CheckOrigin: func(r *http.Request) bool {
27 return true
28 },
29}
30
// pingPeriod is the interval at which HandleWebsocket's sender goroutine writes
// a websocket ping frame to the client. Each pong refreshes the read deadline.
var pingPeriod = time.Second * 5
32
33func (a *StreamplaceAPI) HandleWebsocket(ctx context.Context) httprouter.Handle {
34 ctx = log.WithLogValues(ctx, "func", "HandleWebsocket")
35 return func(w http.ResponseWriter, req *http.Request, params httprouter.Params) {
36 ip, _, err := net.SplitHostPort(req.RemoteAddr)
37 if err != nil {
38 ip = req.RemoteAddr
39 }
40
41 if a.CLI.RateLimitWebsocket > 0 {
42 if !a.connTracker.AddConnection(ip) {
43 log.Warn(ctx, "rate limit exceeded", "ip", ip, "path", req.URL.Path)
44 apierrors.WriteHTTPTooManyRequests(w, "rate limit exceeded")
45 return
46 }
47
48 defer a.connTracker.RemoveConnection(ip)
49 }
50
51 uu, _ := uuid.NewV7()
52 connID := uu.String()
53
54 ctx = log.WithLogValues(ctx, "uuid", connID, "remoteAddr", req.RemoteAddr, "url", req.URL.String())
55 log.Log(ctx, "websocket opened")
56 spmetrics.WebsocketsOpen.Inc()
57 defer spmetrics.WebsocketsOpen.Dec()
58 user := params.ByName("repoDID")
59 if user == "" {
60 apierrors.WriteHTTPBadRequest(w, "user required", nil)
61 return
62 }
63 repoDID, err := a.NormalizeUser(ctx, user)
64 if err != nil {
65 apierrors.WriteHTTPNotFound(w, "user not found", err)
66 return
67 }
68 conn, err := upgrader.Upgrade(w, req, nil)
69 if err != nil {
70 apierrors.WriteHTTPInternalServerError(w, "could not upgrade to websocket", err)
71 return
72 }
73 ctx, cancel := context.WithCancel(ctx)
74 defer cancel()
75 defer conn.Close()
76
77 initialBurst := make(chan any, 200)
78 err = conn.SetReadDeadline(time.Now().Add(30 * time.Second))
79 if err != nil {
80 log.Error(ctx, "could not set read deadline", "error", err)
81 return
82 }
83
84 pongCh := make(chan struct{})
85
86 go func() {
87 for {
88 select {
89 case <-ctx.Done():
90 return
91 case <-pongCh:
92 err := conn.SetReadDeadline(time.Now().Add(30 * time.Second))
93 if err != nil {
94 log.Error(ctx, "could not set read deadline", "error", err)
95 return
96 }
97 case <-time.After(30 * time.Second):
98 log.Log(ctx, "websocket timeout, closing connection")
99 // timeout!
100 conn.Close()
101 cancel()
102 return
103 }
104 }
105 }()
106
107 conn.SetPongHandler(func(appData string) error {
108 log.Debug(ctx, "received pong", "appData", appData)
109 pongCh <- struct{}{}
110 return nil
111 })
112 go func() {
113
114 ch := a.Bus.Subscribe(repoDID)
115 defer a.Bus.Unsubscribe(repoDID, ch)
116 // Create a ticker that fires every 3 seconds
117 ticker := time.NewTicker(3 * time.Second)
118 pingTicker := time.NewTicker(pingPeriod)
119 defer ticker.Stop()
120 defer pingTicker.Stop()
121
122 send := func(msg any) {
123 bs, err := json.Marshal(msg)
124 if err != nil {
125 log.Error(ctx, "could not marshal message", "error", err)
126 return
127 }
128 log.Debug(ctx, "sending message", "message", string(bs))
129 err = conn.WriteMessage(websocket.TextMessage, bs)
130 if err != nil {
131 log.Error(ctx, "could not write message", "error", err)
132 return
133 }
134 }
135
136 for {
137 select {
138 case msg := <-ch:
139 send(msg)
140 case msg := <-initialBurst:
141 send(msg)
142 case <-ticker.C:
143 count := a.Bus.GetViewerCount(repoDID)
144 bs, err := json.Marshal(streamplace.Livestream_ViewerCount{Count: int64(count), LexiconTypeID: "place.stream.livestream#viewerCount"})
145 if err != nil {
146 log.Error(ctx, "could not marshal view count", "error", err)
147 continue
148 }
149 err = conn.WriteMessage(websocket.TextMessage, bs)
150 if err != nil {
151 log.Error(ctx, "could not write ping message", "error", err)
152 return
153 }
154 case <-pingTicker.C:
155 err := conn.WriteMessage(websocket.PingMessage, []byte{})
156 if err != nil {
157 log.Error(ctx, "could not write ping message", "error", err)
158 return
159 }
160 case <-ctx.Done():
161 log.Debug(ctx, "context done, stopping websocket sender")
162 return
163 }
164 }
165 }()
166
167 go func() {
168 profile, err := a.Model.GetRepo(repoDID)
169 if err != nil {
170 log.Error(ctx, "could not get profile", "error", err)
171 return
172 }
173 if profile != nil {
174 p := map[string]any{
175 "$type": "app.bsky.actor.defs#profileViewBasic",
176 "did": repoDID,
177 "handle": profile.Handle,
178 }
179 initialBurst <- p
180 }
181 }()
182
183 go func() {
184 seg, err := a.LocalDB.LatestSegmentForUser(repoDID)
185 if err != nil {
186 log.Error(ctx, "could not get replies", "error", err)
187 return
188 }
189 spSeg, err := seg.ToStreamplaceSegment()
190 if err != nil {
191 log.Error(ctx, "could not convert segment to streamplace segment", "error", err)
192 return
193 }
194 initialBurst <- spSeg
195 if a.CLI.LivepeerGatewayURL != "" {
196 renditions, err := renditions.GenerateRenditions(spSeg)
197 if err != nil {
198 log.Error(ctx, "could not generate renditions", "error", err)
199 return
200 }
201 outRs := streamplace.Defs_Renditions{
202 LexiconTypeID: "place.stream.defs#renditions",
203 }
204 outRs.Renditions = []*streamplace.Defs_Rendition{}
205 for _, r := range renditions {
206 outRs.Renditions = append(outRs.Renditions, &streamplace.Defs_Rendition{
207 LexiconTypeID: "place.stream.defs#rendition",
208 Name: r.Name,
209 })
210 }
211 initialBurst <- outRs
212 }
213 }()
214
215 go func() {
216 ls, err := a.Model.GetLatestLivestreamForRepo(repoDID)
217 if err != nil {
218 log.Error(ctx, "could not get latest livestream", "error", err)
219 return
220 }
221 lsv, err := ls.ToLivestreamView()
222 if err != nil {
223 log.Error(ctx, "could not marshal livestream", "error", err)
224 return
225 }
226 initialBurst <- lsv
227 }()
228
229 go func() {
230 count := a.Bus.GetViewerCount(repoDID)
231 initialBurst <- streamplace.Livestream_ViewerCount{Count: int64(count), LexiconTypeID: "place.stream.livestream#viewerCount"}
232 }()
233
234 go func() {
235 messages, err := a.Model.MostRecentChatMessages(repoDID)
236 if err != nil {
237 log.Error(ctx, "could not get chat messages", "error", err)
238 return
239 }
240 for _, message := range messages {
241 initialBurst <- message
242 }
243 }()
244
245 go func() {
246 teleports, err := a.Model.GetActiveTeleportsToRepo(repoDID)
247 if err != nil {
248 log.Error(ctx, "could not get active teleports", "error", err)
249 return
250 }
251 // just send the latest one if it started <3m ago
252 if len(teleports) > 0 && teleports[0].StartsAt.After(time.Now().Add(-3*time.Minute)) {
253 tp := teleports[0]
254 if tp.Repo == nil {
255 log.Error(ctx, "teleportee repo is nil", "uri", tp.URI)
256 }
257 viewerCount := a.Bus.GetViewerCount(tp.RepoDID)
258 arrivalMsg := streamplace.Livestream_TeleportArrival{
259 LexiconTypeID: "place.stream.livestream#teleportArrival",
260 TeleportUri: tp.URI,
261 Source: &bsky.ActorDefs_ProfileViewBasic{
262 Did: tp.RepoDID,
263 Handle: tp.Repo.Handle,
264 },
265 ViewerCount: int64(viewerCount),
266 StartsAt: tp.StartsAt.Format(time.RFC3339),
267 }
268
269 // get the source chat profile
270 chatProfile, err := a.Model.GetChatProfile(ctx, tp.RepoDID)
271 if err == nil && chatProfile != nil {
272 spcp, err := chatProfile.ToStreamplaceChatProfile()
273 if err == nil {
274 arrivalMsg.ChatProfile = spcp
275 }
276 }
277
278 initialBurst <- arrivalMsg
279 }
280 }()
281
282 for {
283 messageType, message, err := conn.ReadMessage()
284 if err != nil {
285 log.Error(ctx, "error reading message", "error", err)
286 break
287 }
288 log.Log(ctx, "received message", "messageType", messageType, "message", string(message))
289 }
290 }
291}