Live video on the AT Protocol
at eli/nonfatal-errors 368 lines 9.1 kB view raw
1package cmd 2 3import ( 4 "context" 5 "fmt" 6 "io" 7 "net/http" 8 "strings" 9 "sync" 10 "time" 11 12 "github.com/pion/webrtc/v4" 13 "golang.org/x/sync/errgroup" 14 "stream.place/streamplace/pkg/log" 15) 16 17func WHEP(ctx context.Context, count int, duration time.Duration, endpoint string) error { 18 if duration > 0 { 19 var cancel context.CancelFunc 20 ctx, cancel = context.WithTimeout(ctx, duration) 21 defer cancel() 22 } 23 24 w := &WHEPClient{ 25 Endpoint: endpoint, 26 Count: count, 27 } 28 29 return w.WHEP(ctx) 30} 31 32type WHEPClient struct { 33 StreamKey string 34 File string 35 Endpoint string 36 Count int 37 FreezeAfter time.Duration 38 Stats []map[string]*TrackStats 39} 40 41type WHEPConnection struct { 42 peerConnection *webrtc.PeerConnection 43 audioTrack *webrtc.TrackLocalStaticSample 44 videoTrack *webrtc.TrackLocalStaticSample 45 did string 46 Done func() <-chan struct{} 47} 48 49type TrackStats struct { 50 Total int 51 lastTotal int 52 lastUpdate time.Time 53 mu sync.Mutex 54} 55 56func (w *WHEPClient) StartWHEPConnection(ctx context.Context, stats map[string]*TrackStats) (*WHEPConnection, error) { 57 58 // Prepare the configuration 59 config := webrtc.Configuration{} 60 61 // Create a new RTCPeerConnection 62 peerConnection, err := webrtc.NewPeerConnection(config) 63 if err != nil { 64 return nil, err 65 } 66 67 // Create a ticker to print combined bitrate every 5 seconds 68 ticker := time.NewTicker(5 * time.Second) 69 70 // Start a goroutine to print combined bitrate 71 go func() { 72 for { 73 select { 74 case <-ticker.C: 75 currentTime := time.Now() 76 77 // Lock both stats to get a consistent snapshot 78 for _, s := range stats { 79 s.mu.Lock() 80 } 81 82 videoStats := stats["video"] 83 audioStats := stats["audio"] 84 85 videoElapsed := currentTime.Sub(videoStats.lastUpdate).Seconds() 86 audioElapsed := currentTime.Sub(audioStats.lastUpdate).Seconds() 87 88 videoBytes := videoStats.Total - videoStats.lastTotal 89 audioBytes := audioStats.Total - audioStats.lastTotal 90 91 videoBitrate := float64(videoBytes) * 8 / videoElapsed / 1000 // kbps 92 audioBitrate := float64(audioBytes) * 8 / audioElapsed / 1000 // kbps 93 94 log.Log(ctx, "bitrate stats", 95 "video", fmt.Sprintf("%.2f kbps (%.2f KB)", videoBitrate, float64(videoBytes)/1000), 96 "audio", fmt.Sprintf("%.2f kbps (%.2f KB)", audioBitrate, float64(audioBytes)/1000), 97 "total", fmt.Sprintf("%.2f kbps", videoBitrate+audioBitrate)) 98 99 // Update last values 100 videoStats.lastTotal = videoStats.Total 101 videoStats.lastUpdate = currentTime 102 audioStats.lastTotal = audioStats.Total 103 audioStats.lastUpdate = currentTime 104 105 // Unlock stats 106 for _, s := range stats { 107 s.mu.Unlock() 108 } 109 110 case <-ctx.Done(): 111 ticker.Stop() 112 return 113 } 114 } 115 }() 116 117 go func() { 118 ctx, cancel := context.WithCancel(ctx) 119 peerConnection.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { 120 log.Log(ctx, "track received", "track", track.ID()) 121 122 // Determine track type 123 trackType := "video" 124 if track.Kind() == webrtc.RTPCodecTypeAudio { 125 trackType = "audio" 126 } 127 128 trackStat := stats[trackType] 129 130 for { 131 if ctx.Err() != nil { 132 return 133 } 134 rtp, _, err := track.ReadRTP() 135 if err != nil { 136 log.Log(ctx, "error reading RTP", "error", err) 137 cancel() 138 return 139 } 140 141 trackStat.mu.Lock() 142 trackStat.Total += len(rtp.Payload) 143 trackStat.mu.Unlock() 144 } 145 }) 146 peerConnection.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) { 147 log.Log(ctx, "WHEP connection State has changed", "state", connectionState.String()) 148 for _, state := range failureStates { 149 if connectionState == state { 150 log.Log(ctx, "connection failed, cancelling") 151 cancel() 152 } 153 } 154 }) 155 156 <-ctx.Done() 157 peerConnection.Close() 158 }() 159 peerConnection.OnICECandidate(func(candidate *webrtc.ICECandidate) { 160 log.Log(ctx, "ICE candidate", "candidate", candidate) 161 }) 162 if _, err := peerConnection.AddTransceiverFromKind(webrtc.RTPCodecTypeVideo, webrtc.RTPTransceiverInit{ 163 Direction: webrtc.RTPTransceiverDirectionRecvonly, 164 }); err != nil { 165 return nil, fmt.Errorf("failed to add video transceiver: %w", err) 166 } 167 if _, err := peerConnection.AddTransceiverFromKind(webrtc.RTPCodecTypeAudio, webrtc.RTPTransceiverInit{ 168 Direction: webrtc.RTPTransceiverDirectionRecvonly, 169 }); err != nil { 170 return nil, fmt.Errorf("failed to add audio transceiver: %w", err) 171 } 172 173 // Create an offer 174 offer, err := peerConnection.CreateOffer(nil) 175 if err != nil { 176 return nil, err 177 } 178 179 // Set the generated offer as our LocalDescription 180 err = peerConnection.SetLocalDescription(offer) 181 if err != nil { 182 return nil, err 183 } 184 185 // Wait for ICE gathering to complete 186 // gatherComplete := webrtc.GatheringCompletePromise(peerConnection) 187 // <-gatherComplete 188 189 // Create HTTP client and prepare the request 190 client := &http.Client{} 191 192 // Send the WHIP request to the server 193 req, err := http.NewRequest("POST", w.Endpoint, strings.NewReader(offer.SDP)) 194 if err != nil { 195 return nil, err 196 } 197 req.Header.Set("Content-Type", "application/sdp") 198 199 // Execute the request 200 resp, err := client.Do(req) 201 if err != nil { 202 return nil, err 203 } 204 defer resp.Body.Close() 205 if resp.StatusCode != 201 { 206 return nil, fmt.Errorf("status code: %d", resp.StatusCode) 207 } 208 209 // Read and process the answer 210 answerBytes, err := io.ReadAll(resp.Body) 211 if err != nil { 212 return nil, err 213 } 214 215 // Parse the SDP answer 216 var answer webrtc.SessionDescription 217 answer.Type = webrtc.SDPTypeAnswer 218 answer.SDP = string(answerBytes) 219 220 // Apply the answer as remote description 221 err = peerConnection.SetRemoteDescription(answer) 222 if err != nil { 223 return nil, err 224 } 225 226 gatherComplete := webrtc.GatheringCompletePromise(peerConnection) 227 <-gatherComplete 228 229 conn := &WHEPConnection{ 230 peerConnection: peerConnection, 231 Done: ctx.Done, 232 } 233 234 return conn, nil 235} 236 237func (w *WHEPClient) WHEP(ctx context.Context) error { 238 w.Stats = []map[string]*TrackStats{} 239 ctx, cancel := context.WithCancel(ctx) 240 defer cancel() 241 242 conns := make([]*WHEPConnection, w.Count) 243 g := &errgroup.Group{} 244 for i := 0; i < w.Count; i++ { 245 stats := map[string]*TrackStats{ 246 "video": {lastUpdate: time.Now()}, 247 "audio": {lastUpdate: time.Now()}, 248 } 249 w.Stats = append(w.Stats, stats) 250 g.Go(func() error { 251 conn, err := w.StartWHEPConnection(ctx, stats) 252 if err != nil { 253 return err 254 } 255 conns[i] = conn 256 257 <-conn.Done() 258 259 return nil 260 }) 261 } 262 263 err := g.Wait() 264 if err != nil { 265 return err 266 } 267 // if err := g.Wait(); err != nil { 268 // if err := g.Wait(); err != nil { 269 // return err 270 // } 271 // // Start a ticker to print elapsed duration every second 272 // go func() { 273 // ticker := time.NewTicker(time.Second) 274 // defer ticker.Stop() 275 276 // for { 277 // <-ticker.C 278 // for i, duration := range accumulators { 279 // trackType := "video" 280 // if i == 1 { 281 // trackType = "audio" 282 // } 283 // target := startTime.Add(time.Duration(accumulators[i])) 284 // diff := time.Since(target) 285 // log.Debug(ctx, "elapsed duration", "track", trackType, "duration", duration, "diff", diff) 286 // } 287 // } 288 // }() 289 290 // errCh := make(chan error, 1) 291 292 // for i, _ := range sinks { 293 // func(i int) { 294 // sink := sinks[i] 295 // trackType := "video" 296 // if i == 1 { 297 // trackType = "audio" 298 // } 299 300 // sink.SetCallbacks(&app.SinkCallbacks{ 301 // NewSampleFunc: func(sink *app.Sink) gst.FlowReturn { 302 303 // sample := sink.PullSample() 304 // if sample == nil { 305 // return gst.FlowEOS 306 // } 307 308 // buffer := sample.GetBuffer() 309 // if buffer == nil { 310 // return gst.FlowError 311 // } 312 313 // samples := buffer.Map(gst.MapRead).Bytes() 314 // defer buffer.Unmap() 315 316 // durationPtr := buffer.Duration().AsDuration() 317 // var duration time.Duration 318 // if durationPtr == nil { 319 // errCh <- fmt.Errorf("%v duration: nil", trackType) 320 // return gst.FlowError 321 // } else { 322 // // fmt.Printf("%v duration: %v\n", trackType, *durationPtr) 323 // duration = *durationPtr 324 // } 325 326 // accumulators[i] += duration 327 328 // if w.FreezeAfter == 0 || time.Since(startTime) < w.FreezeAfter { 329 // for _, conn := range conns { 330 // if trackType == "video" { 331 // if err := conn.videoTrack.WriteSample(pionmedia.Sample{Data: samples, Duration: duration}); err != nil { 332 // log.Log(ctx, "error writing video sample", "error", err) 333 // errCh <- err 334 // return gst.FlowError 335 // } 336 // } else { 337 // if err := conn.audioTrack.WriteSample(pionmedia.Sample{Data: samples, Duration: duration}); err != nil { 338 // log.Log(ctx, "error writing video sample", "error", err) 339 // errCh <- err 340 // return gst.FlowError 341 // } 342 // } 343 // } 344 // } 345 346 // return gst.FlowOK 347 // }, 348 // }) 349 // }(i) 350 // } 351 352 // go func() { 353 // media.HandleBusMessages(ctx, pipeline) 354 // cancel() 355 // }() 356 357 // if err = pipeline.SetState(gst.StatePlaying); err != nil { 358 // return err 359 // } 360 // select { 361 // case err := <-errCh: 362 // return err 363 // case <-ctx.Done(): 364 // return ctx.Err() 365 // } 366 367 return nil 368}