Live video on the AT Protocol
1package cmd
2
import (
	"context"
	"flag"
	"fmt"
	"io"
	"net/http"
	"strings"
	"sync/atomic"
	"time"

	"github.com/go-gst/go-gst/gst"
	"github.com/go-gst/go-gst/gst/app"
	"github.com/pion/webrtc/v4"
	pionmedia "github.com/pion/webrtc/v4/pkg/media"
	"golang.org/x/sync/errgroup"
	"stream.place/streamplace/pkg/crypto/spkey"
	"stream.place/streamplace/pkg/gstinit"
	"stream.place/streamplace/pkg/log"
	"stream.place/streamplace/pkg/media"
)
22
23func WHIP(args []string) error {
24 fs := flag.NewFlagSet("whip", flag.ExitOnError)
25 streamKey := fs.String("stream-key", "", "stream key")
26 count := fs.Int("count", 1, "number of concurrent streams (for load testing)")
27 viewers := fs.Int("viewers", 0, "number of viewers to simulate per stream")
28 duration := fs.Duration("duration", 0, "duration of the stream")
29 file := fs.String("file", "", "file to stream (needs to be an MP4 containing H264 video and Opus audio)")
30 endpoint := fs.String("endpoint", "http://127.0.0.1:38080", "endpoint to send the WHIP request to")
31 freezeAfter := fs.Duration("freeze-after", 0, "freeze the stream after the given duration")
32 err := fs.Parse(args)
33 if *file == "" {
34 return fmt.Errorf("file is required")
35 }
36 if err != nil {
37 return err
38 }
39 gstinit.InitGST()
40
41 ctx := context.Background()
42 if *duration > 0 {
43 var cancel context.CancelFunc
44 ctx, cancel = context.WithTimeout(ctx, *duration)
45 defer cancel()
46 }
47
48 w := &WHIPClient{
49 StreamKey: *streamKey,
50 File: *file,
51 Endpoint: *endpoint,
52 Count: *count,
53 FreezeAfter: *freezeAfter,
54 Viewers: *viewers,
55 }
56
57 return w.WHIP(ctx)
58}
59
// WHIPClient holds the settings for publishing one MP4 file over one or more
// WHIP connections, optionally with simulated WHEP viewers per stream.
type WHIPClient struct {
	// StreamKey, when non-empty, is the bearer token used for every
	// connection; when empty, each connection generates its own key.
	// File is the MP4 (H264 video + Opus audio) to stream, and Endpoint is
	// the URL the WHIP offer is POSTed to.
	StreamKey, File, Endpoint string

	// Count is the number of concurrent WHIP connections to open.
	Count int

	// FreezeAfter, when non-zero, stops forwarding samples to the
	// connections once that much wall-clock time has passed since WHIP
	// recorded its start time.
	FreezeAfter time.Duration

	// Viewers is the number of simulated WHEP viewers started per stream.
	Viewers int
}
68
69var failureStates = []webrtc.ICEConnectionState{
70 webrtc.ICEConnectionStateFailed,
71 webrtc.ICEConnectionStateDisconnected,
72 webrtc.ICEConnectionStateClosed,
73 webrtc.ICEConnectionStateCompleted,
74}
75
76type WHIPConnection struct {
77 peerConnection *webrtc.PeerConnection
78 audioTrack *webrtc.TrackLocalStaticSample
79 videoTrack *webrtc.TrackLocalStaticSample
80 did string
81}
82
83func (w *WHIPClient) WHIP(ctx context.Context) error {
84 ctx, cancel := context.WithCancel(ctx)
85 defer cancel()
86
87 pipelineSlice := []string{
88 "filesrc name=filesrc ! qtdemux name=demux",
89 "demux.video_0 ! tee name=video_tee",
90 "demux.audio_0 ! tee name=audio_tee",
91 "video_tee. ! queue ! h264parse config-interval=-1 ! video/x-h264,stream-format=byte-stream ! appsink name=videoappsink",
92 "audio_tee. ! queue ! opusparse ! appsink name=audioappsink",
93 // "matroskamux name=mux ! fakesink name=fakesink sync=true",
94 // "video_tee. ! mux.video_0",
95 // "audio_tee. ! mux.audio_0",
96 }
97
98 pipeline, err := gst.NewPipelineFromString(strings.Join(pipelineSlice, "\n"))
99 if err != nil {
100 return err
101 }
102
103 fileSrc, err := pipeline.GetElementByName("filesrc")
104 if err != nil {
105 return err
106 }
107
108 if err := fileSrc.Set("location", w.File); err != nil {
109 return err
110 }
111
112 videoSink, err := pipeline.GetElementByName("videoappsink")
113 if err != nil {
114 return err
115 }
116
117 audioSink, err := pipeline.GetElementByName("audioappsink")
118 if err != nil {
119 return err
120 }
121
122 startTime := time.Now()
123 sinks := []*app.Sink{
124 app.SinkFromElement(videoSink),
125 app.SinkFromElement(audioSink),
126 }
127 // Create accumulators for tracking elapsed duration
128 accumulators := make([]time.Duration, len(sinks))
129
130 conns := make([]*WHIPConnection, w.Count)
131 g := &errgroup.Group{}
132 for i := 0; i < w.Count; i++ {
133 ctx := ctx
134 // var streamKey string
135 var did string
136 var streamKey string
137 if w.StreamKey != "" {
138 streamKey = w.StreamKey
139 } else {
140 priv, pub, err := spkey.GenerateStreamKey()
141 if err != nil {
142 return err
143 }
144
145 did = pub.DIDKey()
146 ctx = log.WithLogValues(ctx, "did", did)
147 streamKey = priv.Multibase()
148 }
149
150 g.Go(func() error {
151 conn, err := w.StartWHIPConnection(ctx, streamKey, did)
152 if err != nil {
153 return err
154 }
155 conns[i] = conn
156 ctx := log.WithLogValues(ctx, "did", did)
157 conn.peerConnection.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) {
158 log.Log(ctx, "WHIP connection State has changed", "state", connectionState.String())
159 for _, state := range failureStates {
160 if connectionState == state {
161 log.Log(ctx, "connection failed, cancelling")
162 cancel()
163 }
164 }
165 })
166 go func() {
167 <-ctx.Done()
168 if conn.peerConnection != nil {
169 conn.peerConnection.Close()
170 }
171 }()
172 return nil
173 })
174 }
175
176 if err := g.Wait(); err != nil {
177 return err
178 }
179
180 // Start a ticker to print elapsed duration every second
181 go func() {
182 ticker := time.NewTicker(time.Second)
183 defer ticker.Stop()
184
185 for {
186 select {
187 case <-ctx.Done():
188 return
189 case <-ticker.C:
190 for i, duration := range accumulators {
191 trackType := "video"
192 if i == 1 {
193 trackType = "audio"
194 }
195 target := startTime.Add(time.Duration(accumulators[i]))
196 diff := time.Since(target)
197 log.Debug(ctx, "elapsed duration", "track", trackType, "duration", duration, "diff", diff)
198 }
199 }
200 }
201 }()
202
203 errCh := make(chan error, 1)
204
205 for i := range sinks {
206 func(i int) {
207 sink := sinks[i]
208 trackType := "video"
209 if i == 1 {
210 trackType = "audio"
211 }
212
213 sink.SetCallbacks(&app.SinkCallbacks{
214 NewSampleFunc: func(sink *app.Sink) gst.FlowReturn {
215
216 sample := sink.PullSample()
217 if sample == nil {
218 return gst.FlowEOS
219 }
220
221 buffer := sample.GetBuffer()
222 if buffer == nil {
223 return gst.FlowError
224 }
225
226 samples := buffer.Map(gst.MapRead).Bytes()
227 defer buffer.Unmap()
228
229 durationPtr := buffer.Duration().AsDuration()
230 var duration time.Duration
231 if durationPtr == nil {
232 errCh <- fmt.Errorf("%v duration: nil", trackType)
233 return gst.FlowError
234 } else {
235 // fmt.Printf("%v duration: %v\n", trackType, *durationPtr)
236 duration = *durationPtr
237 }
238
239 accumulators[i] += duration
240
241 if w.FreezeAfter == 0 || time.Since(startTime) < w.FreezeAfter {
242 for _, conn := range conns {
243 if trackType == "video" {
244 if err := conn.videoTrack.WriteSample(pionmedia.Sample{Data: samples, Duration: duration}); err != nil {
245 log.Log(ctx, "error writing video sample", "error", err)
246 errCh <- err
247 return gst.FlowError
248 }
249 } else {
250 if err := conn.audioTrack.WriteSample(pionmedia.Sample{Data: samples, Duration: duration}); err != nil {
251 log.Log(ctx, "error writing video sample", "error", err)
252 errCh <- err
253 return gst.FlowError
254 }
255 }
256 }
257 }
258
259 return gst.FlowOK
260 },
261 })
262 }(i)
263 }
264
265 go func() {
266 if err := media.HandleBusMessages(ctx, pipeline); err != nil {
267 log.Log(ctx, "pipeline error", "error", err)
268 }
269 cancel()
270 }()
271
272 if err = pipeline.SetState(gst.StatePlaying); err != nil {
273 return err
274 }
275 if w.Viewers > 0 {
276 whepG, ctx := errgroup.WithContext(ctx)
277 for i := 0; i < w.Count; i++ {
278 did := conns[i].did
279 w := &WHEPClient{
280 Endpoint: fmt.Sprintf("%s/api/playback/%s/webrtc", w.Endpoint, did),
281 Count: w.Viewers,
282 }
283 whepG.Go(func() error {
284 return w.WHEP(ctx)
285 })
286 }
287 if err := whepG.Wait(); err != nil {
288 return err
289 }
290 }
291
292 <-ctx.Done()
293 err = pipeline.BlockSetState(gst.StateNull)
294 if err != nil {
295 return err
296 }
297
298 select {
299 case err := <-errCh:
300 return err
301 case <-ctx.Done():
302 return ctx.Err()
303 }
304}
305
306func (w *WHIPClient) StartWHIPConnection(ctx context.Context, streamKey string, did string) (*WHIPConnection, error) {
307
308 // Prepare the configuration
309 config := webrtc.Configuration{}
310
311 // Create a new RTCPeerConnection
312 peerConnection, err := webrtc.NewPeerConnection(config)
313 if err != nil {
314 return nil, err
315 }
316
317 // Create a audio track
318 audioTrack, err := webrtc.NewTrackLocalStaticSample(webrtc.RTPCodecCapability{MimeType: "audio/opus"}, "audio", "pion1")
319 if err != nil {
320 return nil, err
321 }
322 _, err = peerConnection.AddTrack(audioTrack)
323 if err != nil {
324 return nil, err
325 }
326
327 // Create a video track
328 videoTrack, err := webrtc.NewTrackLocalStaticSample(webrtc.RTPCodecCapability{MimeType: "video/h264"}, "video", "pion2")
329 if err != nil {
330 return nil, err
331 }
332 _, err = peerConnection.AddTrack(videoTrack)
333 if err != nil {
334 return nil, err
335 }
336
337 // Create an offer
338 offer, err := peerConnection.CreateOffer(nil)
339 if err != nil {
340 return nil, err
341 }
342
343 // Set the generated offer as our LocalDescription
344 err = peerConnection.SetLocalDescription(offer)
345 if err != nil {
346 return nil, err
347 }
348
349 // Wait for ICE gathering to complete
350 // gatherComplete := webrtc.GatheringCompletePromise(peerConnection)
351 // <-gatherComplete
352
353 // Create HTTP client and prepare the request
354 client := &http.Client{}
355
356 // Send the WHIP request to the server
357 req, err := http.NewRequest("POST", w.Endpoint, strings.NewReader(offer.SDP))
358 if err != nil {
359 return nil, err
360 }
361 req.Header.Set("Authorization", "Bearer "+streamKey)
362 req.Header.Set("Content-Type", "application/sdp")
363
364 // Execute the request
365 resp, err := client.Do(req)
366 if err != nil {
367 return nil, err
368 }
369 defer resp.Body.Close()
370
371 // Read and process the answer
372 answerBytes, err := io.ReadAll(resp.Body)
373 if err != nil {
374 return nil, err
375 }
376
377 // Parse the SDP answer
378 var answer webrtc.SessionDescription
379 answer.Type = webrtc.SDPTypeAnswer
380 answer.SDP = string(answerBytes)
381
382 // Apply the answer as remote description
383 err = peerConnection.SetRemoteDescription(answer)
384 if err != nil {
385 return nil, err
386 }
387
388 gatherComplete := webrtc.GatheringCompletePromise(peerConnection)
389 <-gatherComplete
390
391 conn := &WHIPConnection{
392 peerConnection: peerConnection,
393 audioTrack: audioTrack,
394 videoTrack: videoTrack,
395 did: did,
396 }
397
398 return conn, nil
399}