Live video on the AT Protocol
1package cmd
2
3import (
4 "context"
5 "fmt"
6 "io"
7 "net/http"
8 "strings"
9 "sync"
10 "time"
11
12 "github.com/pion/webrtc/v4"
13 "golang.org/x/sync/errgroup"
14 "stream.place/streamplace/pkg/log"
15)
16
17func WHEP(ctx context.Context, count int, duration time.Duration, endpoint string) error {
18 if duration > 0 {
19 var cancel context.CancelFunc
20 ctx, cancel = context.WithTimeout(ctx, duration)
21 defer cancel()
22 }
23
24 w := &WHEPClient{
25 Endpoint: endpoint,
26 Count: count,
27 }
28
29 return w.WHEP(ctx)
30}
31
32type WHEPClient struct {
33 StreamKey string
34 File string
35 Endpoint string
36 Count int
37 FreezeAfter time.Duration
38 Stats []map[string]*TrackStats
39}
40
41type WHEPConnection struct {
42 peerConnection *webrtc.PeerConnection
43 audioTrack *webrtc.TrackLocalStaticSample
44 videoTrack *webrtc.TrackLocalStaticSample
45 did string
46 Done func() <-chan struct{}
47}
48
// TrackStats counts RTP payload bytes received on a single track.
//
// Total is the running byte count. lastTotal and lastUpdate capture the
// snapshot taken at the previous bitrate report, so the bytes received since
// then are Total-lastTotal. Every field is guarded by mu, so values of this
// type must be shared by pointer and never copied.
type TrackStats struct {
	Total, lastTotal int
	lastUpdate       time.Time
	mu               sync.Mutex
}
55
56func (w *WHEPClient) StartWHEPConnection(ctx context.Context, stats map[string]*TrackStats) (*WHEPConnection, error) {
57
58 // Prepare the configuration
59 config := webrtc.Configuration{}
60
61 // Create a new RTCPeerConnection
62 peerConnection, err := webrtc.NewPeerConnection(config)
63 if err != nil {
64 return nil, err
65 }
66
67 // Create a ticker to print combined bitrate every 5 seconds
68 ticker := time.NewTicker(5 * time.Second)
69
70 // Start a goroutine to print combined bitrate
71 go func() {
72 for {
73 select {
74 case <-ticker.C:
75 currentTime := time.Now()
76
77 // Lock both stats to get a consistent snapshot
78 for _, s := range stats {
79 s.mu.Lock()
80 }
81
82 videoStats := stats["video"]
83 audioStats := stats["audio"]
84
85 videoElapsed := currentTime.Sub(videoStats.lastUpdate).Seconds()
86 audioElapsed := currentTime.Sub(audioStats.lastUpdate).Seconds()
87
88 videoBytes := videoStats.Total - videoStats.lastTotal
89 audioBytes := audioStats.Total - audioStats.lastTotal
90
91 videoBitrate := float64(videoBytes) * 8 / videoElapsed / 1000 // kbps
92 audioBitrate := float64(audioBytes) * 8 / audioElapsed / 1000 // kbps
93
94 log.Log(ctx, "bitrate stats",
95 "video", fmt.Sprintf("%.2f kbps (%.2f KB)", videoBitrate, float64(videoBytes)/1000),
96 "audio", fmt.Sprintf("%.2f kbps (%.2f KB)", audioBitrate, float64(audioBytes)/1000),
97 "total", fmt.Sprintf("%.2f kbps", videoBitrate+audioBitrate))
98
99 // Update last values
100 videoStats.lastTotal = videoStats.Total
101 videoStats.lastUpdate = currentTime
102 audioStats.lastTotal = audioStats.Total
103 audioStats.lastUpdate = currentTime
104
105 // Unlock stats
106 for _, s := range stats {
107 s.mu.Unlock()
108 }
109
110 case <-ctx.Done():
111 ticker.Stop()
112 return
113 }
114 }
115 }()
116
117 go func() {
118 ctx, cancel := context.WithCancel(ctx)
119 peerConnection.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
120 log.Log(ctx, "track received", "track", track.ID())
121
122 // Determine track type
123 trackType := "video"
124 if track.Kind() == webrtc.RTPCodecTypeAudio {
125 trackType = "audio"
126 }
127
128 trackStat := stats[trackType]
129
130 for {
131 if ctx.Err() != nil {
132 return
133 }
134 rtp, _, err := track.ReadRTP()
135 if err != nil {
136 log.Log(ctx, "error reading RTP", "error", err)
137 cancel()
138 return
139 }
140
141 trackStat.mu.Lock()
142 trackStat.Total += len(rtp.Payload)
143 trackStat.mu.Unlock()
144 }
145 })
146 peerConnection.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) {
147 log.Log(ctx, "WHEP connection State has changed", "state", connectionState.String())
148 for _, state := range failureStates {
149 if connectionState == state {
150 log.Log(ctx, "connection failed, cancelling")
151 cancel()
152 }
153 }
154 })
155
156 <-ctx.Done()
157 peerConnection.Close()
158 }()
159 peerConnection.OnICECandidate(func(candidate *webrtc.ICECandidate) {
160 log.Log(ctx, "ICE candidate", "candidate", candidate)
161 })
162 if _, err := peerConnection.AddTransceiverFromKind(webrtc.RTPCodecTypeVideo, webrtc.RTPTransceiverInit{
163 Direction: webrtc.RTPTransceiverDirectionRecvonly,
164 }); err != nil {
165 return nil, fmt.Errorf("failed to add video transceiver: %w", err)
166 }
167 if _, err := peerConnection.AddTransceiverFromKind(webrtc.RTPCodecTypeAudio, webrtc.RTPTransceiverInit{
168 Direction: webrtc.RTPTransceiverDirectionRecvonly,
169 }); err != nil {
170 return nil, fmt.Errorf("failed to add audio transceiver: %w", err)
171 }
172
173 // Create an offer
174 offer, err := peerConnection.CreateOffer(nil)
175 if err != nil {
176 return nil, err
177 }
178
179 // Set the generated offer as our LocalDescription
180 err = peerConnection.SetLocalDescription(offer)
181 if err != nil {
182 return nil, err
183 }
184
185 // Wait for ICE gathering to complete
186 // gatherComplete := webrtc.GatheringCompletePromise(peerConnection)
187 // <-gatherComplete
188
189 // Create HTTP client and prepare the request
190 client := &http.Client{}
191
192 // Send the WHIP request to the server
193 req, err := http.NewRequest("POST", w.Endpoint, strings.NewReader(offer.SDP))
194 if err != nil {
195 return nil, err
196 }
197 req.Header.Set("Content-Type", "application/sdp")
198
199 // Execute the request
200 resp, err := client.Do(req)
201 if err != nil {
202 return nil, err
203 }
204 defer resp.Body.Close()
205 if resp.StatusCode != 201 {
206 return nil, fmt.Errorf("status code: %d", resp.StatusCode)
207 }
208
209 // Read and process the answer
210 answerBytes, err := io.ReadAll(resp.Body)
211 if err != nil {
212 return nil, err
213 }
214
215 // Parse the SDP answer
216 var answer webrtc.SessionDescription
217 answer.Type = webrtc.SDPTypeAnswer
218 answer.SDP = string(answerBytes)
219
220 // Apply the answer as remote description
221 err = peerConnection.SetRemoteDescription(answer)
222 if err != nil {
223 return nil, err
224 }
225
226 gatherComplete := webrtc.GatheringCompletePromise(peerConnection)
227 <-gatherComplete
228
229 conn := &WHEPConnection{
230 peerConnection: peerConnection,
231 Done: ctx.Done,
232 }
233
234 return conn, nil
235}
236
237func (w *WHEPClient) WHEP(ctx context.Context) error {
238 w.Stats = []map[string]*TrackStats{}
239 ctx, cancel := context.WithCancel(ctx)
240 defer cancel()
241
242 conns := make([]*WHEPConnection, w.Count)
243 g := &errgroup.Group{}
244 for i := 0; i < w.Count; i++ {
245 stats := map[string]*TrackStats{
246 "video": {lastUpdate: time.Now()},
247 "audio": {lastUpdate: time.Now()},
248 }
249 w.Stats = append(w.Stats, stats)
250 g.Go(func() error {
251 conn, err := w.StartWHEPConnection(ctx, stats)
252 if err != nil {
253 return err
254 }
255 conns[i] = conn
256
257 <-conn.Done()
258
259 return nil
260 })
261 }
262
263 err := g.Wait()
264 if err != nil {
265 return err
266 }
267 // if err := g.Wait(); err != nil {
268 // if err := g.Wait(); err != nil {
269 // return err
270 // }
271 // // Start a ticker to print elapsed duration every second
272 // go func() {
273 // ticker := time.NewTicker(time.Second)
274 // defer ticker.Stop()
275
276 // for {
277 // <-ticker.C
278 // for i, duration := range accumulators {
279 // trackType := "video"
280 // if i == 1 {
281 // trackType = "audio"
282 // }
283 // target := startTime.Add(time.Duration(accumulators[i]))
284 // diff := time.Since(target)
285 // log.Debug(ctx, "elapsed duration", "track", trackType, "duration", duration, "diff", diff)
286 // }
287 // }
288 // }()
289
290 // errCh := make(chan error, 1)
291
292 // for i, _ := range sinks {
293 // func(i int) {
294 // sink := sinks[i]
295 // trackType := "video"
296 // if i == 1 {
297 // trackType = "audio"
298 // }
299
300 // sink.SetCallbacks(&app.SinkCallbacks{
301 // NewSampleFunc: func(sink *app.Sink) gst.FlowReturn {
302
303 // sample := sink.PullSample()
304 // if sample == nil {
305 // return gst.FlowEOS
306 // }
307
308 // buffer := sample.GetBuffer()
309 // if buffer == nil {
310 // return gst.FlowError
311 // }
312
313 // samples := buffer.Map(gst.MapRead).Bytes()
314 // defer buffer.Unmap()
315
316 // durationPtr := buffer.Duration().AsDuration()
317 // var duration time.Duration
318 // if durationPtr == nil {
319 // errCh <- fmt.Errorf("%v duration: nil", trackType)
320 // return gst.FlowError
321 // } else {
322 // // fmt.Printf("%v duration: %v\n", trackType, *durationPtr)
323 // duration = *durationPtr
324 // }
325
326 // accumulators[i] += duration
327
328 // if w.FreezeAfter == 0 || time.Since(startTime) < w.FreezeAfter {
329 // for _, conn := range conns {
330 // if trackType == "video" {
331 // if err := conn.videoTrack.WriteSample(pionmedia.Sample{Data: samples, Duration: duration}); err != nil {
332 // log.Log(ctx, "error writing video sample", "error", err)
333 // errCh <- err
334 // return gst.FlowError
335 // }
336 // } else {
337 // if err := conn.audioTrack.WriteSample(pionmedia.Sample{Data: samples, Duration: duration}); err != nil {
338 // log.Log(ctx, "error writing video sample", "error", err)
339 // errCh <- err
340 // return gst.FlowError
341 // }
342 // }
343 // }
344 // }
345
346 // return gst.FlowOK
347 // },
348 // })
349 // }(i)
350 // }
351
352 // go func() {
353 // media.HandleBusMessages(ctx, pipeline)
354 // cancel()
355 // }()
356
357 // if err = pipeline.SetState(gst.StatePlaying); err != nil {
358 // return err
359 // }
360 // select {
361 // case err := <-errCh:
362 // return err
363 // case <-ctx.Done():
364 // return ctx.Err()
365 // }
366
367 return nil
368}