Live video on the AT Protocol
at eli/oatproxy-http-client 380 lines 9.4 kB view raw
1package cmd 2 3import ( 4 "context" 5 "flag" 6 "fmt" 7 "io" 8 "net/http" 9 "strings" 10 "sync" 11 "time" 12 13 "github.com/pion/webrtc/v4" 14 "golang.org/x/sync/errgroup" 15 "stream.place/streamplace/pkg/log" 16) 17 18func WHEP(args []string) error { 19 fs := flag.NewFlagSet("whep", flag.ExitOnError) 20 count := fs.Int("count", 1, "number of concurrent streams (for load testing)") 21 duration := fs.Duration("duration", 0, "stop after this long") 22 endpoint := fs.String("endpoint", "", "endpoint to send the WHEP request to") 23 err := fs.Parse(args) 24 25 if err != nil { 26 return err 27 } 28 29 ctx := context.Background() 30 if *duration > 0 { 31 var cancel context.CancelFunc 32 ctx, cancel = context.WithTimeout(ctx, *duration) 33 defer cancel() 34 } 35 36 w := &WHEPClient{ 37 Endpoint: *endpoint, 38 Count: *count, 39 } 40 41 return w.WHEP(ctx) 42} 43 44type WHEPClient struct { 45 StreamKey string 46 File string 47 Endpoint string 48 Count int 49 FreezeAfter time.Duration 50 Stats []map[string]*TrackStats 51} 52 53type WHEPConnection struct { 54 peerConnection *webrtc.PeerConnection 55 audioTrack *webrtc.TrackLocalStaticSample 56 videoTrack *webrtc.TrackLocalStaticSample 57 did string 58 Done func() <-chan struct{} 59} 60 61type TrackStats struct { 62 Total int 63 lastTotal int 64 lastUpdate time.Time 65 mu sync.Mutex 66} 67 68func (w *WHEPClient) StartWHEPConnection(ctx context.Context, stats map[string]*TrackStats) (*WHEPConnection, error) { 69 70 // Prepare the configuration 71 config := webrtc.Configuration{} 72 73 // Create a new RTCPeerConnection 74 peerConnection, err := webrtc.NewPeerConnection(config) 75 if err != nil { 76 return nil, err 77 } 78 79 // Create a ticker to print combined bitrate every 5 seconds 80 ticker := time.NewTicker(5 * time.Second) 81 82 // Start a goroutine to print combined bitrate 83 go func() { 84 for { 85 select { 86 case <-ticker.C: 87 currentTime := time.Now() 88 89 // Lock both stats to get a consistent snapshot 90 for _, s := range stats { 91 s.mu.Lock() 92 } 93 94 videoStats := stats["video"] 95 audioStats := stats["audio"] 96 97 videoElapsed := currentTime.Sub(videoStats.lastUpdate).Seconds() 98 audioElapsed := currentTime.Sub(audioStats.lastUpdate).Seconds() 99 100 videoBytes := videoStats.Total - videoStats.lastTotal 101 audioBytes := audioStats.Total - audioStats.lastTotal 102 103 videoBitrate := float64(videoBytes) * 8 / videoElapsed / 1000 // kbps 104 audioBitrate := float64(audioBytes) * 8 / audioElapsed / 1000 // kbps 105 106 log.Log(ctx, "bitrate stats", 107 "video", fmt.Sprintf("%.2f kbps (%.2f KB)", videoBitrate, float64(videoBytes)/1000), 108 "audio", fmt.Sprintf("%.2f kbps (%.2f KB)", audioBitrate, float64(audioBytes)/1000), 109 "total", fmt.Sprintf("%.2f kbps", videoBitrate+audioBitrate)) 110 111 // Update last values 112 videoStats.lastTotal = videoStats.Total 113 videoStats.lastUpdate = currentTime 114 audioStats.lastTotal = audioStats.Total 115 audioStats.lastUpdate = currentTime 116 117 // Unlock stats 118 for _, s := range stats { 119 s.mu.Unlock() 120 } 121 122 case <-ctx.Done(): 123 ticker.Stop() 124 return 125 } 126 } 127 }() 128 129 go func() { 130 ctx, cancel := context.WithCancel(ctx) 131 peerConnection.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { 132 log.Log(ctx, "track received", "track", track.ID()) 133 134 // Determine track type 135 trackType := "video" 136 if track.Kind() == webrtc.RTPCodecTypeAudio { 137 trackType = "audio" 138 } 139 140 trackStat := stats[trackType] 141 142 for { 143 if ctx.Err() != nil { 144 return 145 } 146 rtp, _, err := track.ReadRTP() 147 if err != nil { 148 log.Log(ctx, "error reading RTP", "error", err) 149 cancel() 150 return 151 } 152 153 trackStat.mu.Lock() 154 trackStat.Total += len(rtp.Payload) 155 trackStat.mu.Unlock() 156 } 157 }) 158 peerConnection.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) { 159 log.Log(ctx, "WHEP connection State has changed", "state", connectionState.String()) 160 for _, state := range failureStates { 161 if connectionState == state { 162 log.Log(ctx, "connection failed, cancelling") 163 cancel() 164 } 165 } 166 }) 167 168 <-ctx.Done() 169 peerConnection.Close() 170 }() 171 peerConnection.OnICECandidate(func(candidate *webrtc.ICECandidate) { 172 log.Log(ctx, "ICE candidate", "candidate", candidate) 173 }) 174 if _, err := peerConnection.AddTransceiverFromKind(webrtc.RTPCodecTypeVideo, webrtc.RTPTransceiverInit{ 175 Direction: webrtc.RTPTransceiverDirectionRecvonly, 176 }); err != nil { 177 return nil, fmt.Errorf("failed to add video transceiver: %w", err) 178 } 179 if _, err := peerConnection.AddTransceiverFromKind(webrtc.RTPCodecTypeAudio, webrtc.RTPTransceiverInit{ 180 Direction: webrtc.RTPTransceiverDirectionRecvonly, 181 }); err != nil { 182 return nil, fmt.Errorf("failed to add audio transceiver: %w", err) 183 } 184 185 // Create an offer 186 offer, err := peerConnection.CreateOffer(nil) 187 if err != nil { 188 return nil, err 189 } 190 191 // Set the generated offer as our LocalDescription 192 err = peerConnection.SetLocalDescription(offer) 193 if err != nil { 194 return nil, err 195 } 196 197 // Wait for ICE gathering to complete 198 // gatherComplete := webrtc.GatheringCompletePromise(peerConnection) 199 // <-gatherComplete 200 201 // Create HTTP client and prepare the request 202 client := &http.Client{} 203 204 // Send the WHIP request to the server 205 req, err := http.NewRequest("POST", w.Endpoint, strings.NewReader(offer.SDP)) 206 if err != nil { 207 return nil, err 208 } 209 req.Header.Set("Content-Type", "application/sdp") 210 211 // Execute the request 212 resp, err := client.Do(req) 213 if err != nil { 214 return nil, err 215 } 216 defer resp.Body.Close() 217 if resp.StatusCode != 201 { 218 return nil, fmt.Errorf("status code: %d", resp.StatusCode) 219 } 220 221 // Read and process the answer 222 answerBytes, err := io.ReadAll(resp.Body) 223 if err != nil { 224 return nil, err 225 } 226 227 // Parse the SDP answer 228 var answer webrtc.SessionDescription 229 answer.Type = webrtc.SDPTypeAnswer 230 answer.SDP = string(answerBytes) 231 232 // Apply the answer as remote description 233 err = peerConnection.SetRemoteDescription(answer) 234 if err != nil { 235 return nil, err 236 } 237 238 gatherComplete := webrtc.GatheringCompletePromise(peerConnection) 239 <-gatherComplete 240 241 conn := &WHEPConnection{ 242 peerConnection: peerConnection, 243 Done: ctx.Done, 244 } 245 246 return conn, nil 247} 248 249func (w *WHEPClient) WHEP(ctx context.Context) error { 250 w.Stats = []map[string]*TrackStats{} 251 ctx, cancel := context.WithCancel(ctx) 252 defer cancel() 253 254 conns := make([]*WHEPConnection, w.Count) 255 g := &errgroup.Group{} 256 for i := 0; i < w.Count; i++ { 257 stats := map[string]*TrackStats{ 258 "video": {lastUpdate: time.Now()}, 259 "audio": {lastUpdate: time.Now()}, 260 } 261 w.Stats = append(w.Stats, stats) 262 g.Go(func() error { 263 conn, err := w.StartWHEPConnection(ctx, stats) 264 if err != nil { 265 return err 266 } 267 conns[i] = conn 268 269 <-conn.Done() 270 271 return nil 272 }) 273 } 274 275 err := g.Wait() 276 if err != nil { 277 return err 278 } 279 // if err := g.Wait(); err != nil { 280 // if err := g.Wait(); err != nil { 281 // return err 282 // } 283 // // Start a ticker to print elapsed duration every second 284 // go func() { 285 // ticker := time.NewTicker(time.Second) 286 // defer ticker.Stop() 287 288 // for { 289 // <-ticker.C 290 // for i, duration := range accumulators { 291 // trackType := "video" 292 // if i == 1 { 293 // trackType = "audio" 294 // } 295 // target := startTime.Add(time.Duration(accumulators[i])) 296 // diff := time.Since(target) 297 // log.Debug(ctx, "elapsed duration", "track", trackType, "duration", duration, "diff", diff) 298 // } 299 // } 300 // }() 301 302 // errCh := make(chan error, 1) 303 304 // for i, _ := range sinks { 305 // func(i int) { 306 // sink := sinks[i] 307 // trackType := "video" 308 // if i == 1 { 309 // trackType = "audio" 310 // } 311 312 // sink.SetCallbacks(&app.SinkCallbacks{ 313 // NewSampleFunc: func(sink *app.Sink) gst.FlowReturn { 314 315 // sample := sink.PullSample() 316 // if sample == nil { 317 // return gst.FlowEOS 318 // } 319 320 // buffer := sample.GetBuffer() 321 // if buffer == nil { 322 // return gst.FlowError 323 // } 324 325 // samples := buffer.Map(gst.MapRead).Bytes() 326 // defer buffer.Unmap() 327 328 // durationPtr := buffer.Duration().AsDuration() 329 // var duration time.Duration 330 // if durationPtr == nil { 331 // errCh <- fmt.Errorf("%v duration: nil", trackType) 332 // return gst.FlowError 333 // } else { 334 // // fmt.Printf("%v duration: %v\n", trackType, *durationPtr) 335 // duration = *durationPtr 336 // } 337 338 // accumulators[i] += duration 339 340 // if w.FreezeAfter == 0 || time.Since(startTime) < w.FreezeAfter { 341 // for _, conn := range conns { 342 // if trackType == "video" { 343 // if err := conn.videoTrack.WriteSample(pionmedia.Sample{Data: samples, Duration: duration}); err != nil { 344 // log.Log(ctx, "error writing video sample", "error", err) 345 // errCh <- err 346 // return gst.FlowError 347 // } 348 // } else { 349 // if err := conn.audioTrack.WriteSample(pionmedia.Sample{Data: samples, Duration: duration}); err != nil { 350 // log.Log(ctx, "error writing video sample", "error", err) 351 // errCh <- err 352 // return gst.FlowError 353 // } 354 // } 355 // } 356 // } 357 358 // return gst.FlowOK 359 // }, 360 // }) 361 // }(i) 362 // } 363 364 // go func() { 365 // media.HandleBusMessages(ctx, pipeline) 366 // cancel() 367 // }() 368 369 // if err = pipeline.SetState(gst.StatePlaying); err != nil { 370 // return err 371 // } 372 // select { 373 // case err := <-errCh: 374 // return err 375 // case <-ctx.Done(): 376 // return ctx.Err() 377 // } 378 379 return nil 380}