Live video on the AT Protocol
1package cmd
2
3import (
4 "context"
5 "flag"
6 "fmt"
7 "io"
8 "net/http"
9 "strings"
10 "sync"
11 "time"
12
13 "github.com/pion/webrtc/v4"
14 "golang.org/x/sync/errgroup"
15 "stream.place/streamplace/pkg/log"
16)
17
18func WHEP(args []string) error {
19 fs := flag.NewFlagSet("whep", flag.ExitOnError)
20 count := fs.Int("count", 1, "number of concurrent streams (for load testing)")
21 duration := fs.Duration("duration", 0, "stop after this long")
22 endpoint := fs.String("endpoint", "", "endpoint to send the WHEP request to")
23 err := fs.Parse(args)
24
25 if err != nil {
26 return err
27 }
28
29 ctx := context.Background()
30 if *duration > 0 {
31 var cancel context.CancelFunc
32 ctx, cancel = context.WithTimeout(ctx, *duration)
33 defer cancel()
34 }
35
36 w := &WHEPClient{
37 Endpoint: *endpoint,
38 Count: *count,
39 }
40
41 return w.WHEP(ctx)
42}
43
// WHEPClient holds the settings for a WHEP playback run.
//
// The WHEP code path in this file reads Endpoint (the URL the SDP offer is
// POSTed to) and Count (how many concurrent connections to open).
// StreamKey, File and FreezeAfter are part of the exported shape but are not
// read by that code path.
type WHEPClient struct {
	StreamKey, File, Endpoint string
	Count                     int
	FreezeAfter               time.Duration
}
51
52type WHEPConnection struct {
53 peerConnection *webrtc.PeerConnection
54 audioTrack *webrtc.TrackLocalStaticSample
55 videoTrack *webrtc.TrackLocalStaticSample
56 did string
57 Done func() <-chan struct{}
58}
59
60func (w *WHEPClient) StartWHEPConnection(ctx context.Context) (*WHEPConnection, error) {
61
62 ctx, cancel := context.WithCancel(ctx)
63
64 // Prepare the configuration
65 config := webrtc.Configuration{}
66
67 // Create a new RTCPeerConnection
68 peerConnection, err := webrtc.NewPeerConnection(config)
69 if err != nil {
70 return nil, err
71 }
72
73 // Track statistics
74 type trackStats struct {
75 total int
76 lastTotal int
77 lastUpdate time.Time
78 mu sync.Mutex
79 }
80
81 stats := map[string]*trackStats{
82 "video": {lastUpdate: time.Now()},
83 "audio": {lastUpdate: time.Now()},
84 }
85
86 // Create a ticker to print combined bitrate every 5 seconds
87 ticker := time.NewTicker(5 * time.Second)
88
89 // Start a goroutine to print combined bitrate
90 go func() {
91 for {
92 select {
93 case <-ticker.C:
94 currentTime := time.Now()
95
96 // Lock both stats to get a consistent snapshot
97 for _, s := range stats {
98 s.mu.Lock()
99 }
100
101 videoStats := stats["video"]
102 audioStats := stats["audio"]
103
104 videoElapsed := currentTime.Sub(videoStats.lastUpdate).Seconds()
105 audioElapsed := currentTime.Sub(audioStats.lastUpdate).Seconds()
106
107 videoBytes := videoStats.total - videoStats.lastTotal
108 audioBytes := audioStats.total - audioStats.lastTotal
109
110 videoBitrate := float64(videoBytes) * 8 / videoElapsed / 1000 // kbps
111 audioBitrate := float64(audioBytes) * 8 / audioElapsed / 1000 // kbps
112
113 log.Log(ctx, "bitrate stats",
114 "video", fmt.Sprintf("%.2f kbps (%.2f KB)", videoBitrate, float64(videoBytes)/1000),
115 "audio", fmt.Sprintf("%.2f kbps (%.2f KB)", audioBitrate, float64(audioBytes)/1000),
116 "total", fmt.Sprintf("%.2f kbps", videoBitrate+audioBitrate))
117
118 // Update last values
119 videoStats.lastTotal = videoStats.total
120 videoStats.lastUpdate = currentTime
121 audioStats.lastTotal = audioStats.total
122 audioStats.lastUpdate = currentTime
123
124 // Unlock stats
125 for _, s := range stats {
126 s.mu.Unlock()
127 }
128
129 case <-ctx.Done():
130 ticker.Stop()
131 return
132 }
133 }
134 }()
135
136 peerConnection.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
137 log.Log(ctx, "track received", "track", track.ID())
138
139 // Determine track type
140 trackType := "video"
141 if track.Kind() == webrtc.RTPCodecTypeAudio {
142 trackType = "audio"
143 }
144
145 trackStat := stats[trackType]
146
147 for {
148 if ctx.Err() != nil {
149 return
150 }
151 rtp, _, err := track.ReadRTP()
152 if err != nil {
153 log.Log(ctx, "error reading RTP", "error", err)
154 cancel()
155 return
156 }
157
158 trackStat.mu.Lock()
159 trackStat.total += len(rtp.Payload)
160 trackStat.mu.Unlock()
161 }
162 })
163 peerConnection.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) {
164 log.Log(ctx, "WHEP connection State has changed", "state", connectionState.String())
165 for _, state := range failureStates {
166 if connectionState == state {
167 log.Log(ctx, "connection failed, cancelling")
168 cancel()
169 }
170 }
171 })
172 peerConnection.OnICECandidate(func(candidate *webrtc.ICECandidate) {
173 log.Log(ctx, "ICE candidate", "candidate", candidate)
174 })
175 peerConnection.AddTransceiverFromKind(webrtc.RTPCodecTypeVideo, webrtc.RTPTransceiverInit{
176 Direction: webrtc.RTPTransceiverDirectionRecvonly,
177 })
178 peerConnection.AddTransceiverFromKind(webrtc.RTPCodecTypeAudio, webrtc.RTPTransceiverInit{
179 Direction: webrtc.RTPTransceiverDirectionRecvonly,
180 })
181
182 // Create an offer
183 offer, err := peerConnection.CreateOffer(nil)
184 if err != nil {
185 return nil, err
186 }
187
188 // Set the generated offer as our LocalDescription
189 err = peerConnection.SetLocalDescription(offer)
190 if err != nil {
191 return nil, err
192 }
193
194 // Wait for ICE gathering to complete
195 // gatherComplete := webrtc.GatheringCompletePromise(peerConnection)
196 // <-gatherComplete
197
198 // Create HTTP client and prepare the request
199 client := &http.Client{}
200
201 // Send the WHIP request to the server
202 req, err := http.NewRequest("POST", w.Endpoint, strings.NewReader(offer.SDP))
203 if err != nil {
204 return nil, err
205 }
206 req.Header.Set("Content-Type", "application/sdp")
207
208 // Execute the request
209 resp, err := client.Do(req)
210 if err != nil {
211 return nil, err
212 }
213 defer resp.Body.Close()
214 if resp.StatusCode != 201 {
215 return nil, fmt.Errorf("status code: %d", resp.StatusCode)
216 }
217
218 // Read and process the answer
219 answerBytes, err := io.ReadAll(resp.Body)
220 if err != nil {
221 return nil, err
222 }
223
224 // Parse the SDP answer
225 var answer webrtc.SessionDescription
226 answer.Type = webrtc.SDPTypeAnswer
227 answer.SDP = string(answerBytes)
228
229 // Apply the answer as remote description
230 err = peerConnection.SetRemoteDescription(answer)
231 if err != nil {
232 return nil, err
233 }
234
235 gatherComplete := webrtc.GatheringCompletePromise(peerConnection)
236 <-gatherComplete
237
238 conn := &WHEPConnection{
239 peerConnection: peerConnection,
240 Done: ctx.Done,
241 }
242
243 return conn, nil
244}
245
246func (w *WHEPClient) WHEP(ctx context.Context) error {
247 ctx, cancel := context.WithCancel(ctx)
248 defer cancel()
249
250 conns := make([]*WHEPConnection, w.Count)
251 g := &errgroup.Group{}
252 for i := 0; i < w.Count; i++ {
253 g.Go(func() error {
254 conn, err := w.StartWHEPConnection(ctx)
255 if err != nil {
256 return err
257 }
258 conns[i] = conn
259
260 <-conn.Done()
261
262 return nil
263 })
264 }
265
266 err := g.Wait()
267 if err != nil {
268 return err
269 }
270 // if err := g.Wait(); err != nil {
271 // if err := g.Wait(); err != nil {
272 // return err
273 // }
274 // // Start a ticker to print elapsed duration every second
275 // go func() {
276 // ticker := time.NewTicker(time.Second)
277 // defer ticker.Stop()
278
279 // for {
280 // <-ticker.C
281 // for i, duration := range accumulators {
282 // trackType := "video"
283 // if i == 1 {
284 // trackType = "audio"
285 // }
286 // target := startTime.Add(time.Duration(accumulators[i]))
287 // diff := time.Since(target)
288 // log.Debug(ctx, "elapsed duration", "track", trackType, "duration", duration, "diff", diff)
289 // }
290 // }
291 // }()
292
293 // errCh := make(chan error, 1)
294
295 // for i, _ := range sinks {
296 // func(i int) {
297 // sink := sinks[i]
298 // trackType := "video"
299 // if i == 1 {
300 // trackType = "audio"
301 // }
302
303 // sink.SetCallbacks(&app.SinkCallbacks{
304 // NewSampleFunc: func(sink *app.Sink) gst.FlowReturn {
305
306 // sample := sink.PullSample()
307 // if sample == nil {
308 // return gst.FlowEOS
309 // }
310
311 // buffer := sample.GetBuffer()
312 // if buffer == nil {
313 // return gst.FlowError
314 // }
315
316 // samples := buffer.Map(gst.MapRead).Bytes()
317 // defer buffer.Unmap()
318
319 // durationPtr := buffer.Duration().AsDuration()
320 // var duration time.Duration
321 // if durationPtr == nil {
322 // errCh <- fmt.Errorf("%v duration: nil", trackType)
323 // return gst.FlowError
324 // } else {
325 // // fmt.Printf("%v duration: %v\n", trackType, *durationPtr)
326 // duration = *durationPtr
327 // }
328
329 // accumulators[i] += duration
330
331 // if w.FreezeAfter == 0 || time.Since(startTime) < w.FreezeAfter {
332 // for _, conn := range conns {
333 // if trackType == "video" {
334 // if err := conn.videoTrack.WriteSample(pionmedia.Sample{Data: samples, Duration: duration}); err != nil {
335 // log.Log(ctx, "error writing video sample", "error", err)
336 // errCh <- err
337 // return gst.FlowError
338 // }
339 // } else {
340 // if err := conn.audioTrack.WriteSample(pionmedia.Sample{Data: samples, Duration: duration}); err != nil {
341 // log.Log(ctx, "error writing video sample", "error", err)
342 // errCh <- err
343 // return gst.FlowError
344 // }
345 // }
346 // }
347 // }
348
349 // return gst.FlowOK
350 // },
351 // })
352 // }(i)
353 // }
354
355 // go func() {
356 // media.HandleBusMessages(ctx, pipeline)
357 // cancel()
358 // }()
359
360 // if err = pipeline.SetState(gst.StatePlaying); err != nil {
361 // return err
362 // }
363 // select {
364 // case err := <-errCh:
365 // return err
366 // case <-ctx.Done():
367 // return ctx.Err()
368 // }
369
370 return nil
371}