Live video on the AT Protocol
1package cmd
2
3import (
4 "context"
5 "flag"
6 "fmt"
7 "io"
8 "net/http"
9 "strings"
10 "sync"
11 "time"
12
13 "github.com/pion/webrtc/v4"
14 "golang.org/x/sync/errgroup"
15 "stream.place/streamplace/pkg/log"
16)
17
18func WHEP(args []string) error {
19 fs := flag.NewFlagSet("whep", flag.ExitOnError)
20 count := fs.Int("count", 1, "number of concurrent streams (for load testing)")
21 duration := fs.Duration("duration", 0, "stop after this long")
22 endpoint := fs.String("endpoint", "", "endpoint to send the WHEP request to")
23 err := fs.Parse(args)
24
25 if err != nil {
26 return err
27 }
28
29 ctx := context.Background()
30 if *duration > 0 {
31 var cancel context.CancelFunc
32 ctx, cancel = context.WithTimeout(ctx, *duration)
33 defer cancel()
34 }
35
36 w := &WHEPClient{
37 Endpoint: *endpoint,
38 Count: *count,
39 }
40
41 return w.WHEP(ctx)
42}
43
44type WHEPClient struct {
45 StreamKey string
46 File string
47 Endpoint string
48 Count int
49 FreezeAfter time.Duration
50 Stats []map[string]*TrackStats
51}
52
53type WHEPConnection struct {
54 peerConnection *webrtc.PeerConnection
55 audioTrack *webrtc.TrackLocalStaticSample
56 videoTrack *webrtc.TrackLocalStaticSample
57 did string
58 Done func() <-chan struct{}
59}
60
// TrackStats accumulates the RTP payload bytes received on one track.
// mu must be held when reading or writing any of the other fields.
type TrackStats struct {
	// Total is the running byte count; lastTotal is its value at the
	// previous bitrate report.
	Total, lastTotal int
	// lastUpdate is when the previous bitrate report was taken.
	lastUpdate time.Time
	mu         sync.Mutex
}
67
68func (w *WHEPClient) StartWHEPConnection(ctx context.Context, stats map[string]*TrackStats) (*WHEPConnection, error) {
69
70 // Prepare the configuration with STUN servers for ICE connectivity
71 config := webrtc.Configuration{
72 ICEServers: []webrtc.ICEServer{
73 {
74 URLs: []string{"stun:stun.l.google.com:19302"},
75 },
76 },
77 }
78
79 // Create a new RTCPeerConnection
80 peerConnection, err := webrtc.NewPeerConnection(config)
81 if err != nil {
82 return nil, err
83 }
84
85 // Create a ticker to print combined bitrate every 5 seconds
86 ticker := time.NewTicker(5 * time.Second)
87
88 // Start a goroutine to print combined bitrate
89 go func() {
90 for {
91 select {
92 case <-ticker.C:
93 currentTime := time.Now()
94
95 // Lock both stats to get a consistent snapshot
96 for _, s := range stats {
97 s.mu.Lock()
98 }
99
100 videoStats := stats["video"]
101 audioStats := stats["audio"]
102
103 videoElapsed := currentTime.Sub(videoStats.lastUpdate).Seconds()
104 audioElapsed := currentTime.Sub(audioStats.lastUpdate).Seconds()
105
106 videoBytes := videoStats.Total - videoStats.lastTotal
107 audioBytes := audioStats.Total - audioStats.lastTotal
108
109 videoBitrate := float64(videoBytes) * 8 / videoElapsed / 1000 // kbps
110 audioBitrate := float64(audioBytes) * 8 / audioElapsed / 1000 // kbps
111
112 log.Log(ctx, "bitrate stats",
113 "video", fmt.Sprintf("%.2f kbps (%.2f KB)", videoBitrate, float64(videoBytes)/1000),
114 "audio", fmt.Sprintf("%.2f kbps (%.2f KB)", audioBitrate, float64(audioBytes)/1000),
115 "total", fmt.Sprintf("%.2f kbps", videoBitrate+audioBitrate))
116
117 // Update last values
118 videoStats.lastTotal = videoStats.Total
119 videoStats.lastUpdate = currentTime
120 audioStats.lastTotal = audioStats.Total
121 audioStats.lastUpdate = currentTime
122
123 // Unlock stats
124 for _, s := range stats {
125 s.mu.Unlock()
126 }
127
128 case <-ctx.Done():
129 ticker.Stop()
130 return
131 }
132 }
133 }()
134
135 go func() {
136 ctx, cancel := context.WithCancel(ctx)
137 peerConnection.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
138 log.Log(ctx, "track received", "track", track.ID())
139
140 // Determine track type
141 trackType := "video"
142 if track.Kind() == webrtc.RTPCodecTypeAudio {
143 trackType = "audio"
144 }
145
146 trackStat := stats[trackType]
147
148 for {
149 if ctx.Err() != nil {
150 return
151 }
152 rtp, _, err := track.ReadRTP()
153 if err != nil {
154 log.Log(ctx, "error reading RTP", "error", err)
155 cancel()
156 return
157 }
158
159 trackStat.mu.Lock()
160 trackStat.Total += len(rtp.Payload)
161 trackStat.mu.Unlock()
162 }
163 })
164 peerConnection.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) {
165 log.Log(ctx, "WHEP connection State has changed", "state", connectionState.String())
166 for _, state := range failureStates {
167 if connectionState == state {
168 log.Log(ctx, "connection failed, cancelling")
169 cancel()
170 }
171 }
172 })
173
174 <-ctx.Done()
175 peerConnection.Close()
176 }()
177 peerConnection.OnICECandidate(func(candidate *webrtc.ICECandidate) {
178 log.Log(ctx, "ICE candidate", "candidate", candidate)
179 })
180 if _, err := peerConnection.AddTransceiverFromKind(webrtc.RTPCodecTypeVideo, webrtc.RTPTransceiverInit{
181 Direction: webrtc.RTPTransceiverDirectionRecvonly,
182 }); err != nil {
183 return nil, fmt.Errorf("failed to add video transceiver: %w", err)
184 }
185 if _, err := peerConnection.AddTransceiverFromKind(webrtc.RTPCodecTypeAudio, webrtc.RTPTransceiverInit{
186 Direction: webrtc.RTPTransceiverDirectionRecvonly,
187 }); err != nil {
188 return nil, fmt.Errorf("failed to add audio transceiver: %w", err)
189 }
190
191 // Create an offer
192 offer, err := peerConnection.CreateOffer(nil)
193 if err != nil {
194 return nil, err
195 }
196
197 // Set the generated offer as our LocalDescription
198 err = peerConnection.SetLocalDescription(offer)
199 if err != nil {
200 return nil, err
201 }
202
203 // Wait for ICE gathering to complete
204 // gatherComplete := webrtc.GatheringCompletePromise(peerConnection)
205 // <-gatherComplete
206
207 // Create HTTP client and prepare the request
208 client := &http.Client{}
209
210 // Send the WHIP request to the server
211 req, err := http.NewRequest("POST", w.Endpoint, strings.NewReader(offer.SDP))
212 if err != nil {
213 return nil, err
214 }
215 req.Header.Set("Content-Type", "application/sdp")
216
217 // Execute the request
218 resp, err := client.Do(req)
219 if err != nil {
220 return nil, err
221 }
222 defer resp.Body.Close()
223 if resp.StatusCode != 201 {
224 return nil, fmt.Errorf("status code: %d", resp.StatusCode)
225 }
226
227 // Read and process the answer
228 answerBytes, err := io.ReadAll(resp.Body)
229 if err != nil {
230 return nil, err
231 }
232
233 // Parse the SDP answer
234 var answer webrtc.SessionDescription
235 answer.Type = webrtc.SDPTypeAnswer
236 answer.SDP = string(answerBytes)
237
238 // Apply the answer as remote description
239 err = peerConnection.SetRemoteDescription(answer)
240 if err != nil {
241 return nil, err
242 }
243
244 gatherComplete := webrtc.GatheringCompletePromise(peerConnection)
245 <-gatherComplete
246
247 conn := &WHEPConnection{
248 peerConnection: peerConnection,
249 Done: ctx.Done,
250 }
251
252 return conn, nil
253}
254
255func (w *WHEPClient) WHEP(ctx context.Context) error {
256 w.Stats = []map[string]*TrackStats{}
257 ctx, cancel := context.WithCancel(ctx)
258 defer cancel()
259
260 conns := make([]*WHEPConnection, w.Count)
261 g := &errgroup.Group{}
262 for i := 0; i < w.Count; i++ {
263 stats := map[string]*TrackStats{
264 "video": {lastUpdate: time.Now()},
265 "audio": {lastUpdate: time.Now()},
266 }
267 w.Stats = append(w.Stats, stats)
268 g.Go(func() error {
269 conn, err := w.StartWHEPConnection(ctx, stats)
270 if err != nil {
271 return err
272 }
273 conns[i] = conn
274
275 <-conn.Done()
276
277 return nil
278 })
279 }
280
281 err := g.Wait()
282 if err != nil {
283 return err
284 }
285 // if err := g.Wait(); err != nil {
286 // if err := g.Wait(); err != nil {
287 // return err
288 // }
289 // // Start a ticker to print elapsed duration every second
290 // go func() {
291 // ticker := time.NewTicker(time.Second)
292 // defer ticker.Stop()
293
294 // for {
295 // <-ticker.C
296 // for i, duration := range accumulators {
297 // trackType := "video"
298 // if i == 1 {
299 // trackType = "audio"
300 // }
301 // target := startTime.Add(time.Duration(accumulators[i]))
302 // diff := time.Since(target)
303 // log.Debug(ctx, "elapsed duration", "track", trackType, "duration", duration, "diff", diff)
304 // }
305 // }
306 // }()
307
308 // errCh := make(chan error, 1)
309
310 // for i, _ := range sinks {
311 // func(i int) {
312 // sink := sinks[i]
313 // trackType := "video"
314 // if i == 1 {
315 // trackType = "audio"
316 // }
317
318 // sink.SetCallbacks(&app.SinkCallbacks{
319 // NewSampleFunc: func(sink *app.Sink) gst.FlowReturn {
320
321 // sample := sink.PullSample()
322 // if sample == nil {
323 // return gst.FlowEOS
324 // }
325
326 // buffer := sample.GetBuffer()
327 // if buffer == nil {
328 // return gst.FlowError
329 // }
330
331 // samples := buffer.Map(gst.MapRead).Bytes()
332 // defer buffer.Unmap()
333
334 // durationPtr := buffer.Duration().AsDuration()
335 // var duration time.Duration
336 // if durationPtr == nil {
337 // errCh <- fmt.Errorf("%v duration: nil", trackType)
338 // return gst.FlowError
339 // } else {
340 // // fmt.Printf("%v duration: %v\n", trackType, *durationPtr)
341 // duration = *durationPtr
342 // }
343
344 // accumulators[i] += duration
345
346 // if w.FreezeAfter == 0 || time.Since(startTime) < w.FreezeAfter {
347 // for _, conn := range conns {
348 // if trackType == "video" {
349 // if err := conn.videoTrack.WriteSample(pionmedia.Sample{Data: samples, Duration: duration}); err != nil {
350 // log.Log(ctx, "error writing video sample", "error", err)
351 // errCh <- err
352 // return gst.FlowError
353 // }
354 // } else {
355 // if err := conn.audioTrack.WriteSample(pionmedia.Sample{Data: samples, Duration: duration}); err != nil {
356 // log.Log(ctx, "error writing video sample", "error", err)
357 // errCh <- err
358 // return gst.FlowError
359 // }
360 // }
361 // }
362 // }
363
364 // return gst.FlowOK
365 // },
366 // })
367 // }(i)
368 // }
369
370 // go func() {
371 // media.HandleBusMessages(ctx, pipeline)
372 // cancel()
373 // }()
374
375 // if err = pipeline.SetState(gst.StatePlaying); err != nil {
376 // return err
377 // }
378 // select {
379 // case err := <-errCh:
380 // return err
381 // case <-ctx.Done():
382 // return ctx.Err()
383 // }
384
385 return nil
386}