Live video on the AT Protocol
1package api
2
3import (
4 "context"
5 "encoding/json"
6 "net"
7 "net/http"
8 "time"
9
10 "github.com/google/uuid"
11 "github.com/gorilla/websocket"
12 "github.com/julienschmidt/httprouter"
13
14 apierrors "stream.place/streamplace/pkg/errors"
15 "stream.place/streamplace/pkg/log"
16 "stream.place/streamplace/pkg/renditions"
17 "stream.place/streamplace/pkg/spmetrics"
18 "stream.place/streamplace/pkg/streamplace"
19)
20
21// todo: does this mean a whole message has to fit within the buffer?
22var upgrader = websocket.Upgrader{
23 ReadBufferSize: 1024,
24 WriteBufferSize: 1024,
25}
26
27var pingPeriod = 5 * time.Second
28
29func (a *StreamplaceAPI) HandleWebsocket(ctx context.Context) httprouter.Handle {
30 ctx = log.WithLogValues(ctx, "func", "HandleWebsocket")
31 return func(w http.ResponseWriter, req *http.Request, params httprouter.Params) {
32 ip, _, err := net.SplitHostPort(req.RemoteAddr)
33 if err != nil {
34 ip = req.RemoteAddr
35 }
36
37 if a.CLI.RateLimitWebsocket > 0 {
38 if !a.connTracker.AddConnection(ip) {
39 log.Warn(ctx, "rate limit exceeded", "ip", ip, "path", req.URL.Path)
40 apierrors.WriteHTTPTooManyRequests(w, "rate limit exceeded")
41 return
42 }
43
44 defer a.connTracker.RemoveConnection(ip)
45 }
46
47 uu, _ := uuid.NewV7()
48 connID := uu.String()
49
50 ctx = log.WithLogValues(ctx, "uuid", connID, "remoteAddr", req.RemoteAddr, "url", req.URL.String())
51 log.Log(ctx, "websocket opened")
52 spmetrics.WebsocketsOpen.Inc()
53 defer spmetrics.WebsocketsOpen.Dec()
54 user := params.ByName("repoDID")
55 if user == "" {
56 apierrors.WriteHTTPBadRequest(w, "user required", nil)
57 return
58 }
59 repoDID, err := a.NormalizeUser(ctx, user)
60 if err != nil {
61 apierrors.WriteHTTPNotFound(w, "user not found", err)
62 return
63 }
64 conn, err := upgrader.Upgrade(w, req, nil)
65 if err != nil {
66 apierrors.WriteHTTPInternalServerError(w, "could not upgrade to websocket", err)
67 return
68 }
69 ctx, cancel := context.WithCancel(ctx)
70 defer cancel()
71 defer conn.Close()
72
73 initialBurst := make(chan any, 200)
74 err = conn.SetReadDeadline(time.Now().Add(30 * time.Second))
75 if err != nil {
76 log.Error(ctx, "could not set read deadline", "error", err)
77 return
78 }
79
80 pongCh := make(chan struct{})
81
82 go func() {
83 for {
84 select {
85 case <-ctx.Done():
86 return
87 case <-pongCh:
88 err := conn.SetReadDeadline(time.Now().Add(30 * time.Second))
89 if err != nil {
90 log.Error(ctx, "could not set read deadline", "error", err)
91 return
92 }
93 case <-time.After(30 * time.Second):
94 log.Log(ctx, "websocket timeout, closing connection")
95 // timeout!
96 conn.Close()
97 cancel()
98 return
99 }
100 }
101 }()
102
103 conn.SetPongHandler(func(appData string) error {
104 log.Debug(ctx, "received pong", "appData", appData)
105 pongCh <- struct{}{}
106 return nil
107 })
108 go func() {
109
110 ch := a.Bus.Subscribe(repoDID)
111 defer a.Bus.Unsubscribe(repoDID, ch)
112 // Create a ticker that fires every 3 seconds
113 ticker := time.NewTicker(3 * time.Second)
114 pingTicker := time.NewTicker(pingPeriod)
115 defer ticker.Stop()
116 defer pingTicker.Stop()
117
118 send := func(msg any) {
119 bs, err := json.Marshal(msg)
120 if err != nil {
121 log.Error(ctx, "could not marshal message", "error", err)
122 return
123 }
124 log.Debug(ctx, "sending message", "message", string(bs))
125 err = conn.WriteMessage(websocket.TextMessage, bs)
126 if err != nil {
127 log.Error(ctx, "could not write message", "error", err)
128 return
129 }
130 }
131
132 for {
133 select {
134 case msg := <-ch:
135 send(msg)
136 case msg := <-initialBurst:
137 send(msg)
138 case <-ticker.C:
139 count := spmetrics.GetViewCount(repoDID)
140 bs, err := json.Marshal(streamplace.Livestream_ViewerCount{Count: int64(count), LexiconTypeID: "place.stream.livestream#viewerCount"})
141 if err != nil {
142 log.Error(ctx, "could not marshal view count", "error", err)
143 continue
144 }
145 err = conn.WriteMessage(websocket.TextMessage, bs)
146 if err != nil {
147 log.Error(ctx, "could not write ping message", "error", err)
148 return
149 }
150 case <-pingTicker.C:
151 err := conn.WriteMessage(websocket.PingMessage, []byte{})
152 if err != nil {
153 log.Error(ctx, "could not write ping message", "error", err)
154 return
155 }
156 case <-ctx.Done():
157 log.Debug(ctx, "context done, stopping websocket sender")
158 return
159 }
160 }
161 }()
162
163 go func() {
164 seg, err := a.Model.LatestSegmentForUser(repoDID)
165 if err != nil {
166 log.Error(ctx, "could not get replies", "error", err)
167 return
168 }
169 spSeg, err := seg.ToStreamplaceSegment()
170 if err != nil {
171 log.Error(ctx, "could not convert segment to streamplace segment", "error", err)
172 return
173 }
174 initialBurst <- spSeg
175 if a.CLI.LivepeerGatewayURL != "" {
176 renditions, err := renditions.GenerateRenditions(spSeg)
177 if err != nil {
178 log.Error(ctx, "could not generate renditions", "error", err)
179 return
180 }
181 outRs := streamplace.Defs_Renditions{
182 LexiconTypeID: "place.stream.defs#renditions",
183 }
184 outRs.Renditions = []*streamplace.Defs_Rendition{}
185 for _, r := range renditions {
186 outRs.Renditions = append(outRs.Renditions, &streamplace.Defs_Rendition{
187 LexiconTypeID: "place.stream.defs#rendition",
188 Name: r.Name,
189 })
190 }
191 initialBurst <- outRs
192 }
193 }()
194
195 go func() {
196 ls, err := a.Model.GetLatestLivestreamForRepo(repoDID)
197 if err != nil {
198 log.Error(ctx, "could not get latest livestream", "error", err)
199 return
200 }
201 lsv, err := ls.ToLivestreamView()
202 if err != nil {
203 log.Error(ctx, "could not marshal livestream", "error", err)
204 return
205 }
206 initialBurst <- lsv
207 }()
208
209 go func() {
210 count := spmetrics.GetViewCount(repoDID)
211 initialBurst <- streamplace.Livestream_ViewerCount{Count: int64(count), LexiconTypeID: "place.stream.livestream#viewerCount"}
212 }()
213
214 go func() {
215 messages, err := a.Model.MostRecentChatMessages(repoDID)
216 if err != nil {
217 log.Error(ctx, "could not get chat messages", "error", err)
218 return
219 }
220 for _, message := range messages {
221 initialBurst <- message
222 }
223 }()
224
225 for {
226 messageType, message, err := conn.ReadMessage()
227 if err != nil {
228 log.Error(ctx, "error reading message", "error", err)
229 break
230 }
231 log.Log(ctx, "received message", "messageType", messageType, "message", string(message))
232 }
233 }
234}