Live video on the AT Protocol
at natb/rust-testing 386 lines 9.6 kB view raw
1package cmd 2 3import ( 4 "context" 5 "flag" 6 "fmt" 7 "io" 8 "net/http" 9 "strings" 10 "sync" 11 "time" 12 13 "github.com/pion/webrtc/v4" 14 "golang.org/x/sync/errgroup" 15 "stream.place/streamplace/pkg/log" 16) 17 18func WHEP(args []string) error { 19 fs := flag.NewFlagSet("whep", flag.ExitOnError) 20 count := fs.Int("count", 1, "number of concurrent streams (for load testing)") 21 duration := fs.Duration("duration", 0, "stop after this long") 22 endpoint := fs.String("endpoint", "", "endpoint to send the WHEP request to") 23 err := fs.Parse(args) 24 25 if err != nil { 26 return err 27 } 28 29 ctx := context.Background() 30 if *duration > 0 { 31 var cancel context.CancelFunc 32 ctx, cancel = context.WithTimeout(ctx, *duration) 33 defer cancel() 34 } 35 36 w := &WHEPClient{ 37 Endpoint: *endpoint, 38 Count: *count, 39 } 40 41 return w.WHEP(ctx) 42} 43 44type WHEPClient struct { 45 StreamKey string 46 File string 47 Endpoint string 48 Count int 49 FreezeAfter time.Duration 50 Stats []map[string]*TrackStats 51} 52 53type WHEPConnection struct { 54 peerConnection *webrtc.PeerConnection 55 audioTrack *webrtc.TrackLocalStaticSample 56 videoTrack *webrtc.TrackLocalStaticSample 57 did string 58 Done func() <-chan struct{} 59} 60 61type TrackStats struct { 62 Total int 63 lastTotal int 64 lastUpdate time.Time 65 mu sync.Mutex 66} 67 68func (w *WHEPClient) StartWHEPConnection(ctx context.Context, stats map[string]*TrackStats) (*WHEPConnection, error) { 69 70 // Prepare the configuration with STUN servers for ICE connectivity 71 config := webrtc.Configuration{ 72 ICEServers: []webrtc.ICEServer{ 73 { 74 URLs: []string{"stun:stun.l.google.com:19302"}, 75 }, 76 }, 77 } 78 79 // Create a new RTCPeerConnection 80 peerConnection, err := webrtc.NewPeerConnection(config) 81 if err != nil { 82 return nil, err 83 } 84 85 // Create a ticker to print combined bitrate every 5 seconds 86 ticker := time.NewTicker(5 * time.Second) 87 88 // Start a goroutine to print combined bitrate 89 go func() { 90 for { 91 select { 92 case <-ticker.C: 93 currentTime := time.Now() 94 95 // Lock both stats to get a consistent snapshot 96 for _, s := range stats { 97 s.mu.Lock() 98 } 99 100 videoStats := stats["video"] 101 audioStats := stats["audio"] 102 103 videoElapsed := currentTime.Sub(videoStats.lastUpdate).Seconds() 104 audioElapsed := currentTime.Sub(audioStats.lastUpdate).Seconds() 105 106 videoBytes := videoStats.Total - videoStats.lastTotal 107 audioBytes := audioStats.Total - audioStats.lastTotal 108 109 videoBitrate := float64(videoBytes) * 8 / videoElapsed / 1000 // kbps 110 audioBitrate := float64(audioBytes) * 8 / audioElapsed / 1000 // kbps 111 112 log.Log(ctx, "bitrate stats", 113 "video", fmt.Sprintf("%.2f kbps (%.2f KB)", videoBitrate, float64(videoBytes)/1000), 114 "audio", fmt.Sprintf("%.2f kbps (%.2f KB)", audioBitrate, float64(audioBytes)/1000), 115 "total", fmt.Sprintf("%.2f kbps", videoBitrate+audioBitrate)) 116 117 // Update last values 118 videoStats.lastTotal = videoStats.Total 119 videoStats.lastUpdate = currentTime 120 audioStats.lastTotal = audioStats.Total 121 audioStats.lastUpdate = currentTime 122 123 // Unlock stats 124 for _, s := range stats { 125 s.mu.Unlock() 126 } 127 128 case <-ctx.Done(): 129 ticker.Stop() 130 return 131 } 132 } 133 }() 134 135 go func() { 136 ctx, cancel := context.WithCancel(ctx) 137 peerConnection.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { 138 log.Log(ctx, "track received", "track", track.ID()) 139 140 // Determine track type 141 trackType := "video" 142 if track.Kind() == webrtc.RTPCodecTypeAudio { 143 trackType = "audio" 144 } 145 146 trackStat := stats[trackType] 147 148 for { 149 if ctx.Err() != nil { 150 return 151 } 152 rtp, _, err := track.ReadRTP() 153 if err != nil { 154 log.Log(ctx, "error reading RTP", "error", err) 155 cancel() 156 return 157 } 158 159 trackStat.mu.Lock() 160 trackStat.Total += len(rtp.Payload) 161 trackStat.mu.Unlock() 162 } 163 }) 164 peerConnection.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) { 165 log.Log(ctx, "WHEP connection State has changed", "state", connectionState.String()) 166 for _, state := range failureStates { 167 if connectionState == state { 168 log.Log(ctx, "connection failed, cancelling") 169 cancel() 170 } 171 } 172 }) 173 174 <-ctx.Done() 175 peerConnection.Close() 176 }() 177 peerConnection.OnICECandidate(func(candidate *webrtc.ICECandidate) { 178 log.Log(ctx, "ICE candidate", "candidate", candidate) 179 }) 180 if _, err := peerConnection.AddTransceiverFromKind(webrtc.RTPCodecTypeVideo, webrtc.RTPTransceiverInit{ 181 Direction: webrtc.RTPTransceiverDirectionRecvonly, 182 }); err != nil { 183 return nil, fmt.Errorf("failed to add video transceiver: %w", err) 184 } 185 if _, err := peerConnection.AddTransceiverFromKind(webrtc.RTPCodecTypeAudio, webrtc.RTPTransceiverInit{ 186 Direction: webrtc.RTPTransceiverDirectionRecvonly, 187 }); err != nil { 188 return nil, fmt.Errorf("failed to add audio transceiver: %w", err) 189 } 190 191 // Create an offer 192 offer, err := peerConnection.CreateOffer(nil) 193 if err != nil { 194 return nil, err 195 } 196 197 // Set the generated offer as our LocalDescription 198 err = peerConnection.SetLocalDescription(offer) 199 if err != nil { 200 return nil, err 201 } 202 203 // Wait for ICE gathering to complete 204 // gatherComplete := webrtc.GatheringCompletePromise(peerConnection) 205 // <-gatherComplete 206 207 // Create HTTP client and prepare the request 208 client := &http.Client{} 209 210 // Send the WHIP request to the server 211 req, err := http.NewRequest("POST", w.Endpoint, strings.NewReader(offer.SDP)) 212 if err != nil { 213 return nil, err 214 } 215 req.Header.Set("Content-Type", "application/sdp") 216 217 // Execute the request 218 resp, err := client.Do(req) 219 if err != nil { 220 return nil, err 221 } 222 defer resp.Body.Close() 223 if resp.StatusCode != 201 { 224 return nil, fmt.Errorf("status code: %d", resp.StatusCode) 225 } 226 227 // Read and process the answer 228 answerBytes, err := io.ReadAll(resp.Body) 229 if err != nil { 230 return nil, err 231 } 232 233 // Parse the SDP answer 234 var answer webrtc.SessionDescription 235 answer.Type = webrtc.SDPTypeAnswer 236 answer.SDP = string(answerBytes) 237 238 // Apply the answer as remote description 239 err = peerConnection.SetRemoteDescription(answer) 240 if err != nil { 241 return nil, err 242 } 243 244 gatherComplete := webrtc.GatheringCompletePromise(peerConnection) 245 <-gatherComplete 246 247 conn := &WHEPConnection{ 248 peerConnection: peerConnection, 249 Done: ctx.Done, 250 } 251 252 return conn, nil 253} 254 255func (w *WHEPClient) WHEP(ctx context.Context) error { 256 w.Stats = []map[string]*TrackStats{} 257 ctx, cancel := context.WithCancel(ctx) 258 defer cancel() 259 260 conns := make([]*WHEPConnection, w.Count) 261 g := &errgroup.Group{} 262 for i := 0; i < w.Count; i++ { 263 stats := map[string]*TrackStats{ 264 "video": {lastUpdate: time.Now()}, 265 "audio": {lastUpdate: time.Now()}, 266 } 267 w.Stats = append(w.Stats, stats) 268 g.Go(func() error { 269 conn, err := w.StartWHEPConnection(ctx, stats) 270 if err != nil { 271 return err 272 } 273 conns[i] = conn 274 275 <-conn.Done() 276 277 return nil 278 }) 279 } 280 281 err := g.Wait() 282 if err != nil { 283 return err 284 } 285 // if err := g.Wait(); err != nil { 286 // if err := g.Wait(); err != nil { 287 // return err 288 // } 289 // // Start a ticker to print elapsed duration every second 290 // go func() { 291 // ticker := time.NewTicker(time.Second) 292 // defer ticker.Stop() 293 294 // for { 295 // <-ticker.C 296 // for i, duration := range accumulators { 297 // trackType := "video" 298 // if i == 1 { 299 // trackType = "audio" 300 // } 301 // target := startTime.Add(time.Duration(accumulators[i])) 302 // diff := time.Since(target) 303 // log.Debug(ctx, "elapsed duration", "track", trackType, "duration", duration, "diff", diff) 304 // } 305 // } 306 // }() 307 308 // errCh := make(chan error, 1) 309 310 // for i, _ := range sinks { 311 // func(i int) { 312 // sink := sinks[i] 313 // trackType := "video" 314 // if i == 1 { 315 // trackType = "audio" 316 // } 317 318 // sink.SetCallbacks(&app.SinkCallbacks{ 319 // NewSampleFunc: func(sink *app.Sink) gst.FlowReturn { 320 321 // sample := sink.PullSample() 322 // if sample == nil { 323 // return gst.FlowEOS 324 // } 325 326 // buffer := sample.GetBuffer() 327 // if buffer == nil { 328 // return gst.FlowError 329 // } 330 331 // samples := buffer.Map(gst.MapRead).Bytes() 332 // defer buffer.Unmap() 333 334 // durationPtr := buffer.Duration().AsDuration() 335 // var duration time.Duration 336 // if durationPtr == nil { 337 // errCh <- fmt.Errorf("%v duration: nil", trackType) 338 // return gst.FlowError 339 // } else { 340 // // fmt.Printf("%v duration: %v\n", trackType, *durationPtr) 341 // duration = *durationPtr 342 // } 343 344 // accumulators[i] += duration 345 346 // if w.FreezeAfter == 0 || time.Since(startTime) < w.FreezeAfter { 347 // for _, conn := range conns { 348 // if trackType == "video" { 349 // if err := conn.videoTrack.WriteSample(pionmedia.Sample{Data: samples, Duration: duration}); err != nil { 350 // log.Log(ctx, "error writing video sample", "error", err) 351 // errCh <- err 352 // return gst.FlowError 353 // } 354 // } else { 355 // if err := conn.audioTrack.WriteSample(pionmedia.Sample{Data: samples, Duration: duration}); err != nil { 356 // log.Log(ctx, "error writing video sample", "error", err) 357 // errCh <- err 358 // return gst.FlowError 359 // } 360 // } 361 // } 362 // } 363 364 // return gst.FlowOK 365 // }, 366 // }) 367 // }(i) 368 // } 369 370 // go func() { 371 // media.HandleBusMessages(ctx, pipeline) 372 // cancel() 373 // }() 374 375 // if err = pipeline.SetState(gst.StatePlaying); err != nil { 376 // return err 377 // } 378 // select { 379 // case err := <-errCh: 380 // return err 381 // case <-ctx.Done(): 382 // return ctx.Err() 383 // } 384 385 return nil 386}