Live video on the AT Protocol
1package cmd
2
3import (
4 "context"
5 "flag"
6 "fmt"
7 "io"
8 "net/http"
9 "strings"
10 "sync"
11 "time"
12
13 "github.com/pion/webrtc/v4"
14 "golang.org/x/sync/errgroup"
15 "stream.place/streamplace/pkg/log"
16)
17
18func WHEP(args []string) error {
19 fs := flag.NewFlagSet("whep", flag.ExitOnError)
20 count := fs.Int("count", 1, "number of concurrent streams (for load testing)")
21 duration := fs.Duration("duration", 0, "stop after this long")
22 endpoint := fs.String("endpoint", "", "endpoint to send the WHEP request to")
23 err := fs.Parse(args)
24
25 if err != nil {
26 return err
27 }
28
29 ctx := context.Background()
30 if *duration > 0 {
31 var cancel context.CancelFunc
32 ctx, cancel = context.WithTimeout(ctx, *duration)
33 defer cancel()
34 }
35
36 w := &WHEPClient{
37 Endpoint: *endpoint,
38 Count: *count,
39 }
40
41 return w.WHEP(ctx)
42}
43
44type WHEPClient struct {
45 StreamKey string
46 File string
47 Endpoint string
48 Count int
49 FreezeAfter time.Duration
50 Stats []map[string]*TrackStats
51}
52
53type WHEPConnection struct {
54 peerConnection *webrtc.PeerConnection
55 audioTrack *webrtc.TrackLocalStaticSample
56 videoTrack *webrtc.TrackLocalStaticSample
57 did string
58 Done func() <-chan struct{}
59}
60
// TrackStats accumulates the RTP payload bytes received on one track.
// Total is the running byte count; lastTotal and lastUpdate hold the values
// captured at the previous bitrate report so the next report can compute a
// delta. mu guards all three fields, so a TrackStats must be shared by
// pointer rather than copied.
type TrackStats struct {
	Total, lastTotal int
	lastUpdate       time.Time
	mu               sync.Mutex
}
67
68func (w *WHEPClient) StartWHEPConnection(ctx context.Context, stats map[string]*TrackStats) (*WHEPConnection, error) {
69
70 // Prepare the configuration
71 config := webrtc.Configuration{}
72
73 // Create a new RTCPeerConnection
74 peerConnection, err := webrtc.NewPeerConnection(config)
75 if err != nil {
76 return nil, err
77 }
78
79 // Create a ticker to print combined bitrate every 5 seconds
80 ticker := time.NewTicker(5 * time.Second)
81
82 // Start a goroutine to print combined bitrate
83 go func() {
84 for {
85 select {
86 case <-ticker.C:
87 currentTime := time.Now()
88
89 // Lock both stats to get a consistent snapshot
90 for _, s := range stats {
91 s.mu.Lock()
92 }
93
94 videoStats := stats["video"]
95 audioStats := stats["audio"]
96
97 videoElapsed := currentTime.Sub(videoStats.lastUpdate).Seconds()
98 audioElapsed := currentTime.Sub(audioStats.lastUpdate).Seconds()
99
100 videoBytes := videoStats.Total - videoStats.lastTotal
101 audioBytes := audioStats.Total - audioStats.lastTotal
102
103 videoBitrate := float64(videoBytes) * 8 / videoElapsed / 1000 // kbps
104 audioBitrate := float64(audioBytes) * 8 / audioElapsed / 1000 // kbps
105
106 log.Log(ctx, "bitrate stats",
107 "video", fmt.Sprintf("%.2f kbps (%.2f KB)", videoBitrate, float64(videoBytes)/1000),
108 "audio", fmt.Sprintf("%.2f kbps (%.2f KB)", audioBitrate, float64(audioBytes)/1000),
109 "total", fmt.Sprintf("%.2f kbps", videoBitrate+audioBitrate))
110
111 // Update last values
112 videoStats.lastTotal = videoStats.Total
113 videoStats.lastUpdate = currentTime
114 audioStats.lastTotal = audioStats.Total
115 audioStats.lastUpdate = currentTime
116
117 // Unlock stats
118 for _, s := range stats {
119 s.mu.Unlock()
120 }
121
122 case <-ctx.Done():
123 ticker.Stop()
124 return
125 }
126 }
127 }()
128
129 go func() {
130 ctx, cancel := context.WithCancel(ctx)
131 peerConnection.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
132 log.Log(ctx, "track received", "track", track.ID())
133
134 // Determine track type
135 trackType := "video"
136 if track.Kind() == webrtc.RTPCodecTypeAudio {
137 trackType = "audio"
138 }
139
140 trackStat := stats[trackType]
141
142 for {
143 if ctx.Err() != nil {
144 return
145 }
146 rtp, _, err := track.ReadRTP()
147 if err != nil {
148 log.Log(ctx, "error reading RTP", "error", err)
149 cancel()
150 return
151 }
152
153 trackStat.mu.Lock()
154 trackStat.Total += len(rtp.Payload)
155 trackStat.mu.Unlock()
156 }
157 })
158 peerConnection.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) {
159 log.Log(ctx, "WHEP connection State has changed", "state", connectionState.String())
160 for _, state := range failureStates {
161 if connectionState == state {
162 log.Log(ctx, "connection failed, cancelling")
163 cancel()
164 }
165 }
166 })
167
168 <-ctx.Done()
169 peerConnection.Close()
170 }()
171 peerConnection.OnICECandidate(func(candidate *webrtc.ICECandidate) {
172 log.Log(ctx, "ICE candidate", "candidate", candidate)
173 })
174 if _, err := peerConnection.AddTransceiverFromKind(webrtc.RTPCodecTypeVideo, webrtc.RTPTransceiverInit{
175 Direction: webrtc.RTPTransceiverDirectionRecvonly,
176 }); err != nil {
177 return nil, fmt.Errorf("failed to add video transceiver: %w", err)
178 }
179 if _, err := peerConnection.AddTransceiverFromKind(webrtc.RTPCodecTypeAudio, webrtc.RTPTransceiverInit{
180 Direction: webrtc.RTPTransceiverDirectionRecvonly,
181 }); err != nil {
182 return nil, fmt.Errorf("failed to add audio transceiver: %w", err)
183 }
184
185 // Create an offer
186 offer, err := peerConnection.CreateOffer(nil)
187 if err != nil {
188 return nil, err
189 }
190
191 // Set the generated offer as our LocalDescription
192 err = peerConnection.SetLocalDescription(offer)
193 if err != nil {
194 return nil, err
195 }
196
197 // Wait for ICE gathering to complete
198 // gatherComplete := webrtc.GatheringCompletePromise(peerConnection)
199 // <-gatherComplete
200
201 // Create HTTP client and prepare the request
202 client := &http.Client{}
203
204 // Send the WHIP request to the server
205 req, err := http.NewRequest("POST", w.Endpoint, strings.NewReader(offer.SDP))
206 if err != nil {
207 return nil, err
208 }
209 req.Header.Set("Content-Type", "application/sdp")
210
211 // Execute the request
212 resp, err := client.Do(req)
213 if err != nil {
214 return nil, err
215 }
216 defer resp.Body.Close()
217 if resp.StatusCode != 201 {
218 return nil, fmt.Errorf("status code: %d", resp.StatusCode)
219 }
220
221 // Read and process the answer
222 answerBytes, err := io.ReadAll(resp.Body)
223 if err != nil {
224 return nil, err
225 }
226
227 // Parse the SDP answer
228 var answer webrtc.SessionDescription
229 answer.Type = webrtc.SDPTypeAnswer
230 answer.SDP = string(answerBytes)
231
232 // Apply the answer as remote description
233 err = peerConnection.SetRemoteDescription(answer)
234 if err != nil {
235 return nil, err
236 }
237
238 gatherComplete := webrtc.GatheringCompletePromise(peerConnection)
239 <-gatherComplete
240
241 conn := &WHEPConnection{
242 peerConnection: peerConnection,
243 Done: ctx.Done,
244 }
245
246 return conn, nil
247}
248
249func (w *WHEPClient) WHEP(ctx context.Context) error {
250 w.Stats = []map[string]*TrackStats{}
251 ctx, cancel := context.WithCancel(ctx)
252 defer cancel()
253
254 conns := make([]*WHEPConnection, w.Count)
255 g := &errgroup.Group{}
256 for i := 0; i < w.Count; i++ {
257 stats := map[string]*TrackStats{
258 "video": {lastUpdate: time.Now()},
259 "audio": {lastUpdate: time.Now()},
260 }
261 w.Stats = append(w.Stats, stats)
262 g.Go(func() error {
263 conn, err := w.StartWHEPConnection(ctx, stats)
264 if err != nil {
265 return err
266 }
267 conns[i] = conn
268
269 <-conn.Done()
270
271 return nil
272 })
273 }
274
275 err := g.Wait()
276 if err != nil {
277 return err
278 }
279 // if err := g.Wait(); err != nil {
280 // if err := g.Wait(); err != nil {
281 // return err
282 // }
283 // // Start a ticker to print elapsed duration every second
284 // go func() {
285 // ticker := time.NewTicker(time.Second)
286 // defer ticker.Stop()
287
288 // for {
289 // <-ticker.C
290 // for i, duration := range accumulators {
291 // trackType := "video"
292 // if i == 1 {
293 // trackType = "audio"
294 // }
295 // target := startTime.Add(time.Duration(accumulators[i]))
296 // diff := time.Since(target)
297 // log.Debug(ctx, "elapsed duration", "track", trackType, "duration", duration, "diff", diff)
298 // }
299 // }
300 // }()
301
302 // errCh := make(chan error, 1)
303
304 // for i, _ := range sinks {
305 // func(i int) {
306 // sink := sinks[i]
307 // trackType := "video"
308 // if i == 1 {
309 // trackType = "audio"
310 // }
311
312 // sink.SetCallbacks(&app.SinkCallbacks{
313 // NewSampleFunc: func(sink *app.Sink) gst.FlowReturn {
314
315 // sample := sink.PullSample()
316 // if sample == nil {
317 // return gst.FlowEOS
318 // }
319
320 // buffer := sample.GetBuffer()
321 // if buffer == nil {
322 // return gst.FlowError
323 // }
324
325 // samples := buffer.Map(gst.MapRead).Bytes()
326 // defer buffer.Unmap()
327
328 // durationPtr := buffer.Duration().AsDuration()
329 // var duration time.Duration
330 // if durationPtr == nil {
331 // errCh <- fmt.Errorf("%v duration: nil", trackType)
332 // return gst.FlowError
333 // } else {
334 // // fmt.Printf("%v duration: %v\n", trackType, *durationPtr)
335 // duration = *durationPtr
336 // }
337
338 // accumulators[i] += duration
339
340 // if w.FreezeAfter == 0 || time.Since(startTime) < w.FreezeAfter {
341 // for _, conn := range conns {
342 // if trackType == "video" {
343 // if err := conn.videoTrack.WriteSample(pionmedia.Sample{Data: samples, Duration: duration}); err != nil {
344 // log.Log(ctx, "error writing video sample", "error", err)
345 // errCh <- err
346 // return gst.FlowError
347 // }
348 // } else {
349 // if err := conn.audioTrack.WriteSample(pionmedia.Sample{Data: samples, Duration: duration}); err != nil {
350 // log.Log(ctx, "error writing video sample", "error", err)
351 // errCh <- err
352 // return gst.FlowError
353 // }
354 // }
355 // }
356 // }
357
358 // return gst.FlowOK
359 // },
360 // })
361 // }(i)
362 // }
363
364 // go func() {
365 // media.HandleBusMessages(ctx, pipeline)
366 // cancel()
367 // }()
368
369 // if err = pipeline.SetState(gst.StatePlaying); err != nil {
370 // return err
371 // }
372 // select {
373 // case err := <-errCh:
374 // return err
375 // case <-ctx.Done():
376 // return ctx.Err()
377 // }
378
379 return nil
380}