Live video on the AT Protocol
1package cmd
2
3import (
4 "context"
5 "flag"
6 "fmt"
7 "io"
8 "net/http"
9 "strings"
10 "sync"
11 "time"
12
13 "github.com/pion/webrtc/v4"
14 "golang.org/x/sync/errgroup"
15 "stream.place/streamplace/pkg/log"
16)
17
18func WHEP(args []string) error {
19 fs := flag.NewFlagSet("whep", flag.ExitOnError)
20 count := fs.Int("count", 1, "number of concurrent streams (for load testing)")
21 duration := fs.Duration("duration", 0, "stop after this long")
22 endpoint := fs.String("endpoint", "", "endpoint to send the WHEP request to")
23 err := fs.Parse(args)
24
25 if err != nil {
26 return err
27 }
28
29 ctx := context.Background()
30 if *duration > 0 {
31 var cancel context.CancelFunc
32 ctx, cancel = context.WithTimeout(ctx, *duration)
33 defer cancel()
34 }
35
36 w := &WHEPClient{
37 Endpoint: *endpoint,
38 Count: *count,
39 }
40
41 return w.WHEP(ctx)
42}
43
// WHEPClient is the configuration for pulling streams over WHEP: where to
// send the SDP offer and how many simultaneous sessions to open.
//
// StreamKey, File and FreezeAfter are not read by WHEP or
// StartWHEPConnection.
// NOTE(review): they look like leftovers from the WHIP client — confirm no
// other caller relies on them before removing.
type WHEPClient struct {
	// Endpoint is the URL the SDP offer is POSTed to.
	StreamKey, File, Endpoint string
	// Count is the number of concurrent connections WHEP opens.
	Count       int
	FreezeAfter time.Duration
}
51
52type WHEPConnection struct {
53 peerConnection *webrtc.PeerConnection
54 audioTrack *webrtc.TrackLocalStaticSample
55 videoTrack *webrtc.TrackLocalStaticSample
56 did string
57 Done func() <-chan struct{}
58}
59
60func (w *WHEPClient) StartWHEPConnection(ctx context.Context) (*WHEPConnection, error) {
61
62 // Prepare the configuration
63 config := webrtc.Configuration{}
64
65 // Create a new RTCPeerConnection
66 peerConnection, err := webrtc.NewPeerConnection(config)
67 if err != nil {
68 return nil, err
69 }
70
71 // Track statistics
72 type trackStats struct {
73 total int
74 lastTotal int
75 lastUpdate time.Time
76 mu sync.Mutex
77 }
78
79 stats := map[string]*trackStats{
80 "video": {lastUpdate: time.Now()},
81 "audio": {lastUpdate: time.Now()},
82 }
83
84 // Create a ticker to print combined bitrate every 5 seconds
85 ticker := time.NewTicker(5 * time.Second)
86
87 // Start a goroutine to print combined bitrate
88 go func() {
89 for {
90 select {
91 case <-ticker.C:
92 currentTime := time.Now()
93
94 // Lock both stats to get a consistent snapshot
95 for _, s := range stats {
96 s.mu.Lock()
97 }
98
99 videoStats := stats["video"]
100 audioStats := stats["audio"]
101
102 videoElapsed := currentTime.Sub(videoStats.lastUpdate).Seconds()
103 audioElapsed := currentTime.Sub(audioStats.lastUpdate).Seconds()
104
105 videoBytes := videoStats.total - videoStats.lastTotal
106 audioBytes := audioStats.total - audioStats.lastTotal
107
108 videoBitrate := float64(videoBytes) * 8 / videoElapsed / 1000 // kbps
109 audioBitrate := float64(audioBytes) * 8 / audioElapsed / 1000 // kbps
110
111 log.Log(ctx, "bitrate stats",
112 "video", fmt.Sprintf("%.2f kbps (%.2f KB)", videoBitrate, float64(videoBytes)/1000),
113 "audio", fmt.Sprintf("%.2f kbps (%.2f KB)", audioBitrate, float64(audioBytes)/1000),
114 "total", fmt.Sprintf("%.2f kbps", videoBitrate+audioBitrate))
115
116 // Update last values
117 videoStats.lastTotal = videoStats.total
118 videoStats.lastUpdate = currentTime
119 audioStats.lastTotal = audioStats.total
120 audioStats.lastUpdate = currentTime
121
122 // Unlock stats
123 for _, s := range stats {
124 s.mu.Unlock()
125 }
126
127 case <-ctx.Done():
128 ticker.Stop()
129 return
130 }
131 }
132 }()
133
134 go func() {
135 ctx, cancel := context.WithCancel(ctx)
136 peerConnection.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
137 log.Log(ctx, "track received", "track", track.ID())
138
139 // Determine track type
140 trackType := "video"
141 if track.Kind() == webrtc.RTPCodecTypeAudio {
142 trackType = "audio"
143 }
144
145 trackStat := stats[trackType]
146
147 for {
148 if ctx.Err() != nil {
149 return
150 }
151 rtp, _, err := track.ReadRTP()
152 if err != nil {
153 log.Log(ctx, "error reading RTP", "error", err)
154 cancel()
155 return
156 }
157
158 trackStat.mu.Lock()
159 trackStat.total += len(rtp.Payload)
160 trackStat.mu.Unlock()
161 }
162 })
163 peerConnection.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) {
164 log.Log(ctx, "WHEP connection State has changed", "state", connectionState.String())
165 for _, state := range failureStates {
166 if connectionState == state {
167 log.Log(ctx, "connection failed, cancelling")
168 cancel()
169 }
170 }
171 })
172
173 <-ctx.Done()
174 peerConnection.Close()
175 }()
176 peerConnection.OnICECandidate(func(candidate *webrtc.ICECandidate) {
177 log.Log(ctx, "ICE candidate", "candidate", candidate)
178 })
179 if _, err := peerConnection.AddTransceiverFromKind(webrtc.RTPCodecTypeVideo, webrtc.RTPTransceiverInit{
180 Direction: webrtc.RTPTransceiverDirectionRecvonly,
181 }); err != nil {
182 return nil, fmt.Errorf("failed to add video transceiver: %w", err)
183 }
184 if _, err := peerConnection.AddTransceiverFromKind(webrtc.RTPCodecTypeAudio, webrtc.RTPTransceiverInit{
185 Direction: webrtc.RTPTransceiverDirectionRecvonly,
186 }); err != nil {
187 return nil, fmt.Errorf("failed to add audio transceiver: %w", err)
188 }
189
190 // Create an offer
191 offer, err := peerConnection.CreateOffer(nil)
192 if err != nil {
193 return nil, err
194 }
195
196 // Set the generated offer as our LocalDescription
197 err = peerConnection.SetLocalDescription(offer)
198 if err != nil {
199 return nil, err
200 }
201
202 // Wait for ICE gathering to complete
203 // gatherComplete := webrtc.GatheringCompletePromise(peerConnection)
204 // <-gatherComplete
205
206 // Create HTTP client and prepare the request
207 client := &http.Client{}
208
209 // Send the WHIP request to the server
210 req, err := http.NewRequest("POST", w.Endpoint, strings.NewReader(offer.SDP))
211 if err != nil {
212 return nil, err
213 }
214 req.Header.Set("Content-Type", "application/sdp")
215
216 // Execute the request
217 resp, err := client.Do(req)
218 if err != nil {
219 return nil, err
220 }
221 defer resp.Body.Close()
222 if resp.StatusCode != 201 {
223 return nil, fmt.Errorf("status code: %d", resp.StatusCode)
224 }
225
226 // Read and process the answer
227 answerBytes, err := io.ReadAll(resp.Body)
228 if err != nil {
229 return nil, err
230 }
231
232 // Parse the SDP answer
233 var answer webrtc.SessionDescription
234 answer.Type = webrtc.SDPTypeAnswer
235 answer.SDP = string(answerBytes)
236
237 // Apply the answer as remote description
238 err = peerConnection.SetRemoteDescription(answer)
239 if err != nil {
240 return nil, err
241 }
242
243 gatherComplete := webrtc.GatheringCompletePromise(peerConnection)
244 <-gatherComplete
245
246 conn := &WHEPConnection{
247 peerConnection: peerConnection,
248 Done: ctx.Done,
249 }
250
251 return conn, nil
252}
253
254func (w *WHEPClient) WHEP(ctx context.Context) error {
255 ctx, cancel := context.WithCancel(ctx)
256 defer cancel()
257
258 conns := make([]*WHEPConnection, w.Count)
259 g := &errgroup.Group{}
260 for i := 0; i < w.Count; i++ {
261 g.Go(func() error {
262 conn, err := w.StartWHEPConnection(ctx)
263 if err != nil {
264 return err
265 }
266 conns[i] = conn
267
268 <-conn.Done()
269
270 return nil
271 })
272 }
273
274 err := g.Wait()
275 if err != nil {
276 return err
277 }
278 // if err := g.Wait(); err != nil {
279 // if err := g.Wait(); err != nil {
280 // return err
281 // }
282 // // Start a ticker to print elapsed duration every second
283 // go func() {
284 // ticker := time.NewTicker(time.Second)
285 // defer ticker.Stop()
286
287 // for {
288 // <-ticker.C
289 // for i, duration := range accumulators {
290 // trackType := "video"
291 // if i == 1 {
292 // trackType = "audio"
293 // }
294 // target := startTime.Add(time.Duration(accumulators[i]))
295 // diff := time.Since(target)
296 // log.Debug(ctx, "elapsed duration", "track", trackType, "duration", duration, "diff", diff)
297 // }
298 // }
299 // }()
300
301 // errCh := make(chan error, 1)
302
303 // for i, _ := range sinks {
304 // func(i int) {
305 // sink := sinks[i]
306 // trackType := "video"
307 // if i == 1 {
308 // trackType = "audio"
309 // }
310
311 // sink.SetCallbacks(&app.SinkCallbacks{
312 // NewSampleFunc: func(sink *app.Sink) gst.FlowReturn {
313
314 // sample := sink.PullSample()
315 // if sample == nil {
316 // return gst.FlowEOS
317 // }
318
319 // buffer := sample.GetBuffer()
320 // if buffer == nil {
321 // return gst.FlowError
322 // }
323
324 // samples := buffer.Map(gst.MapRead).Bytes()
325 // defer buffer.Unmap()
326
327 // durationPtr := buffer.Duration().AsDuration()
328 // var duration time.Duration
329 // if durationPtr == nil {
330 // errCh <- fmt.Errorf("%v duration: nil", trackType)
331 // return gst.FlowError
332 // } else {
333 // // fmt.Printf("%v duration: %v\n", trackType, *durationPtr)
334 // duration = *durationPtr
335 // }
336
337 // accumulators[i] += duration
338
339 // if w.FreezeAfter == 0 || time.Since(startTime) < w.FreezeAfter {
340 // for _, conn := range conns {
341 // if trackType == "video" {
342 // if err := conn.videoTrack.WriteSample(pionmedia.Sample{Data: samples, Duration: duration}); err != nil {
343 // log.Log(ctx, "error writing video sample", "error", err)
344 // errCh <- err
345 // return gst.FlowError
346 // }
347 // } else {
348 // if err := conn.audioTrack.WriteSample(pionmedia.Sample{Data: samples, Duration: duration}); err != nil {
349 // log.Log(ctx, "error writing video sample", "error", err)
350 // errCh <- err
351 // return gst.FlowError
352 // }
353 // }
354 // }
355 // }
356
357 // return gst.FlowOK
358 // },
359 // })
360 // }(i)
361 // }
362
363 // go func() {
364 // media.HandleBusMessages(ctx, pipeline)
365 // cancel()
366 // }()
367
368 // if err = pipeline.SetState(gst.StatePlaying); err != nil {
369 // return err
370 // }
371 // select {
372 // case err := <-errCh:
373 // return err
374 // case <-ctx.Done():
375 // return ctx.Err()
376 // }
377
378 return nil
379}