Live video on the AT Protocol
1package api
2
3import (
4 "context"
5 "encoding/json"
6 "net"
7 "net/http"
8 "time"
9
10 "github.com/google/uuid"
11 "github.com/gorilla/websocket"
12 "github.com/julienschmidt/httprouter"
13
14 apierrors "stream.place/streamplace/pkg/errors"
15 "stream.place/streamplace/pkg/log"
16 "stream.place/streamplace/pkg/renditions"
17 "stream.place/streamplace/pkg/spmetrics"
18 "stream.place/streamplace/pkg/streamplace"
19)
20
21// todo: does this mean a whole message has to fit within the buffer?
22var upgrader = websocket.Upgrader{
23 ReadBufferSize: 1024,
24 WriteBufferSize: 1024,
25}
26
27var pingPeriod = 5 * time.Second
28
29func (a *StreamplaceAPI) HandleWebsocket(ctx context.Context) httprouter.Handle {
30 ctx = log.WithLogValues(ctx, "func", "HandleWebsocket")
31 return func(w http.ResponseWriter, req *http.Request, params httprouter.Params) {
32 ip, _, err := net.SplitHostPort(req.RemoteAddr)
33 if err != nil {
34 ip = req.RemoteAddr
35 }
36
37 if !a.connTracker.AddConnection(ip) {
38 log.Warn(ctx, "rate limit exceeded", "ip", ip, "path", req.URL.Path)
39 apierrors.WriteHTTPTooManyRequests(w, "rate limit exceeded")
40 return
41 }
42
43 defer a.connTracker.RemoveConnection(ip)
44
45 uu, _ := uuid.NewV7()
46 connID := uu.String()
47
48 ctx = log.WithLogValues(ctx, "uuid", connID, "remoteAddr", req.RemoteAddr, "url", req.URL.String())
49 log.Log(ctx, "websocket opened")
50 spmetrics.WebsocketsOpen.Inc()
51 defer spmetrics.WebsocketsOpen.Dec()
52 user := params.ByName("repoDID")
53 if user == "" {
54 apierrors.WriteHTTPBadRequest(w, "user required", nil)
55 return
56 }
57 repoDID, err := a.NormalizeUser(ctx, user)
58 if err != nil {
59 apierrors.WriteHTTPNotFound(w, "user not found", err)
60 return
61 }
62 conn, err := upgrader.Upgrade(w, req, nil)
63 if err != nil {
64 apierrors.WriteHTTPInternalServerError(w, "could not upgrade to websocket", err)
65 return
66 }
67 ctx, cancel := context.WithCancel(ctx)
68 defer cancel()
69 defer conn.Close()
70
71 initialBurst := make(chan any, 200)
72 err = conn.SetReadDeadline(time.Now().Add(30 * time.Second))
73 if err != nil {
74 log.Error(ctx, "could not set read deadline", "error", err)
75 return
76 }
77
78 pongCh := make(chan struct{})
79
80 go func() {
81 for {
82 select {
83 case <-ctx.Done():
84 return
85 case <-pongCh:
86 err := conn.SetReadDeadline(time.Now().Add(30 * time.Second))
87 if err != nil {
88 log.Error(ctx, "could not set read deadline", "error", err)
89 return
90 }
91 case <-time.After(30 * time.Second):
92 log.Log(ctx, "websocket timeout, closing connection")
93 // timeout!
94 conn.Close()
95 cancel()
96 return
97 }
98 }
99 }()
100
101 conn.SetPongHandler(func(appData string) error {
102 log.Debug(ctx, "received pong", "appData", appData)
103 pongCh <- struct{}{}
104 return nil
105 })
106 go func() {
107
108 ch := a.Bus.Subscribe(repoDID)
109 defer a.Bus.Unsubscribe(repoDID, ch)
110 // Create a ticker that fires every 3 seconds
111 ticker := time.NewTicker(3 * time.Second)
112 pingTicker := time.NewTicker(pingPeriod)
113 defer ticker.Stop()
114 defer pingTicker.Stop()
115
116 send := func(msg any) {
117 bs, err := json.Marshal(msg)
118 if err != nil {
119 log.Error(ctx, "could not marshal message", "error", err)
120 return
121 }
122 log.Debug(ctx, "sending message", "message", string(bs))
123 err = conn.WriteMessage(websocket.TextMessage, bs)
124 if err != nil {
125 log.Error(ctx, "could not write message", "error", err)
126 return
127 }
128 }
129
130 for {
131 select {
132 case msg := <-ch:
133 send(msg)
134 case msg := <-initialBurst:
135 send(msg)
136 case <-ticker.C:
137 count := spmetrics.GetViewCount(repoDID)
138 bs, err := json.Marshal(streamplace.Livestream_ViewerCount{Count: int64(count), LexiconTypeID: "place.stream.livestream#viewerCount"})
139 if err != nil {
140 log.Error(ctx, "could not marshal view count", "error", err)
141 continue
142 }
143 err = conn.WriteMessage(websocket.TextMessage, bs)
144 if err != nil {
145 log.Error(ctx, "could not write ping message", "error", err)
146 return
147 }
148 case <-pingTicker.C:
149 err := conn.WriteMessage(websocket.PingMessage, []byte{})
150 if err != nil {
151 log.Error(ctx, "could not write ping message", "error", err)
152 return
153 }
154 case <-ctx.Done():
155 log.Debug(ctx, "context done, stopping websocket sender")
156 return
157 }
158 }
159 }()
160
161 go func() {
162 seg, err := a.Model.LatestSegmentForUser(repoDID)
163 if err != nil {
164 log.Error(ctx, "could not get replies", "error", err)
165 return
166 }
167 spSeg, err := seg.ToStreamplaceSegment()
168 if err != nil {
169 log.Error(ctx, "could not convert segment to streamplace segment", "error", err)
170 return
171 }
172 initialBurst <- spSeg
173 if a.CLI.LivepeerGatewayURL != "" {
174 renditions, err := renditions.GenerateRenditions(spSeg)
175 if err != nil {
176 log.Error(ctx, "could not generate renditions", "error", err)
177 return
178 }
179 outRs := streamplace.Defs_Renditions{
180 LexiconTypeID: "place.stream.defs#renditions",
181 }
182 outRs.Renditions = []*streamplace.Defs_Rendition{}
183 for _, r := range renditions {
184 outRs.Renditions = append(outRs.Renditions, &streamplace.Defs_Rendition{
185 LexiconTypeID: "place.stream.defs#rendition",
186 Name: r.Name,
187 })
188 }
189 initialBurst <- outRs
190 }
191 }()
192
193 go func() {
194 ls, err := a.Model.GetLatestLivestreamForRepo(repoDID)
195 if err != nil {
196 log.Error(ctx, "could not get latest livestream", "error", err)
197 return
198 }
199 lsv, err := ls.ToLivestreamView()
200 if err != nil {
201 log.Error(ctx, "could not marshal livestream", "error", err)
202 return
203 }
204 initialBurst <- lsv
205 }()
206
207 go func() {
208 count := spmetrics.GetViewCount(repoDID)
209 initialBurst <- streamplace.Livestream_ViewerCount{Count: int64(count), LexiconTypeID: "place.stream.livestream#viewerCount"}
210 }()
211
212 go func() {
213 messages, err := a.Model.MostRecentChatMessages(repoDID)
214 if err != nil {
215 log.Error(ctx, "could not get chat messages", "error", err)
216 return
217 }
218 for _, message := range messages {
219 initialBurst <- message
220 }
221 }()
222
223 for {
224 messageType, message, err := conn.ReadMessage()
225 if err != nil {
226 log.Error(ctx, "error reading message", "error", err)
227 break
228 }
229 log.Log(ctx, "received message", "messageType", messageType, "message", string(message))
230 }
231 }
232}