websocket-based lrcproto server
at main 17 kB view raw
1package lrcd 2 3import ( 4 "context" 5 "errors" 6 "fmt" 7 "github.com/gorilla/websocket" 8 "github.com/rachel-mp4/lrcproto/gen/go" 9 "google.golang.org/protobuf/proto" 10 "log" 11 "net" 12 "net/http" 13 "sync" 14 "time" 15 "unicode/utf16" 16) 17 18type Server struct { 19 secret string 20 uri string 21 eventBus chan clientEvent 22 ctx context.Context 23 cancel context.CancelFunc 24 clients map[*client]bool 25 clientsMu sync.Mutex 26 idmapsMu sync.Mutex 27 idToClient map[uint32]*client 28 lastID uint32 29 logger *log.Logger 30 debugLogger *log.Logger 31 welcomeEvt []byte 32 pongEvt []byte 33 initChan chan InitChanMsg 34 mediainitChan chan MediaInitChanMsg 35 resolver func(externalID string, ctx context.Context) *string 36 pubChan chan PubEvent 37} 38 39type PubEvent struct { 40 ID uint32 41 Body string 42} 43 44type client struct { 45 conn *websocket.Conn 46 dataChan chan []byte 47 ctx context.Context 48 cancel context.CancelFunc 49 muteMap map[*client]bool 50 mutedBy map[*client]bool 51 myIDs []uint32 52 textID *uint32 53 mediaID *uint32 54 post *string 55 nick *string 56 externID *string 57 resolvID *string 58 rcancel context.CancelFunc 59 color *uint32 60} 61 62type clientEvent struct { 63 client *client 64 event *lrcpb.Event 65} 66 67func NewServer(opts ...Option) (*Server, error) { 68 var options options 69 for _, opt := range opts { 70 err := opt(&options) 71 if err != nil { 72 return nil, err 73 } 74 } 75 76 s := Server{} 77 78 welcomeString := "Welcome to my lrc server!" 79 if options.welcome != nil { 80 welcomeString = *options.welcome 81 } 82 s.setDefaultEvents(welcomeString) 83 84 if options.writer != nil { 85 s.logger = log.New(*options.writer, "[log]", log.Ldate|log.Ltime) 86 if options.verbose { 87 s.debugLogger = log.New(*options.writer, "[debug]", log.Ldate|log.Ltime) 88 } 89 } 90 91 if options.initChan != nil { 92 s.initChan = options.initChan 93 } 94 if options.mediainitChan != nil { 95 s.mediainitChan = options.mediainitChan 96 } 97 if options.pubChan != nil { 98 s.pubChan = options.pubChan 99 } 100 if options.initialID != nil { 101 s.lastID = *options.initialID 102 } 103 s.uri = options.uri 104 s.secret = options.secret 105 s.resolver = options.resolver 106 107 s.clients = make(map[*client]bool) 108 s.clientsMu = sync.Mutex{} 109 s.idmapsMu = sync.Mutex{} 110 s.idToClient = make(map[uint32]*client) 111 s.eventBus = make(chan clientEvent, 100) 112 return &s, nil 113} 114 115func (s *Server) setDefaultEvents(welcome string) { 116 evt := &lrcpb.Event{Msg: &lrcpb.Event_Get{Get: &lrcpb.Get{Topic: &welcome}}} 117 we, _ := proto.Marshal(evt) 118 s.welcomeEvt = we 119 120 evt = &lrcpb.Event{Msg: &lrcpb.Event_Pong{Pong: &lrcpb.Pong{}}} 121 pe, _ := proto.Marshal(evt) 122 s.pongEvt = pe 123} 124 125// Start starts a server, and returns an error if it has ever been started before 126func (s *Server) Start() error { 127 if s.ctx != nil { 128 return errors.New("cannot start already started server") 129 } 130 s.ctx, s.cancel = context.WithCancel(context.Background()) 131 go s.broadcaster() 132 s.logDebug("Hello, world!") 133 return nil 134} 135 136// Stop stops a server if it has started, and returns an error if it is already stopped 137func (s *Server) Stop() (uint32, error) { 138 if s.ctx == nil { 139 return s.lastID, nil 140 } 141 select { 142 case <-s.ctx.Done(): 143 return s.lastID, errors.New("cannot stop already stopped server") 144 default: 145 s.cancel() 146 if s.initChan != nil { 147 close(s.initChan) 148 } 149 if s.pubChan != nil { 150 close(s.pubChan) 151 } 152 s.logDebug("Goodbye world :c") 153 return s.lastID, nil 154 } 155} 156 157// Connected returns how many clients are currently connected to the server 158func (s *Server) Connected() int { 159 return len(s.clients) 160} 161 162// StopIfEmpty stops the server if it is empty, returning true. 163func (s *Server) StopIfEmpty() bool { 164 if len(s.clients) == 0 { 165 s.Stop() 166 return true 167 } 168 return false 169} 170 171func (s *Server) WSHandler() http.HandlerFunc { 172 return func(w http.ResponseWriter, r *http.Request) { 173 upgrader := &websocket.Upgrader{ 174 Subprotocols: []string{"lrc.v1"}, 175 CheckOrigin: func(r *http.Request) bool { 176 return true 177 }, 178 } 179 // initialize 180 conn, err := upgrader.Upgrade(w, r, nil) 181 if err != nil { 182 log.Println("Upgrade failed:", err) 183 return 184 } 185 defer conn.Close() 186 187 if netConn := conn.UnderlyingConn(); netConn != nil { 188 if tcpConn, ok := netConn.(*net.TCPConn); ok { 189 if err := tcpConn.SetNoDelay(true); err != nil { 190 log.Println("failed to denagle") 191 } 192 } 193 } 194 195 ctx, cancel := context.WithCancel(context.Background()) 196 client := &client{ 197 conn: conn, 198 dataChan: make(chan []byte, 100), 199 ctx: ctx, 200 cancel: cancel, 201 muteMap: make(map[*client]bool, 0), 202 mutedBy: make(map[*client]bool, 0), 203 myIDs: make([]uint32, 0), 204 } 205 206 s.clientsMu.Lock() 207 s.clients[client] = true 208 s.clientsMu.Unlock() 209 210 // lifetime 211 var wg sync.WaitGroup 212 wg.Add(2) 213 go func() { defer wg.Done(); s.wsWriter(client) }() 214 go func() { defer wg.Done(); s.listenToWS(client) }() 215 s.logDebug("new ws connection!") 216 wg.Wait() 217 218 // clean up 219 s.clientsMu.Lock() 220 delete(s.clients, client) 221 close(client.dataChan) 222 s.clientsMu.Unlock() 223 s.handlePub(client) 224 225 s.idmapsMu.Lock() 226 for _, id := range client.myIDs { // remove myself from the idToClient map 227 delete(s.idToClient, id) 228 } 229 for mutedClient := range client.muteMap { // remove myself from everyone that I muted's backreference map 230 delete(mutedClient.mutedBy, client) 231 } 232 for mutingClient := range client.mutedBy { // remove myself from everyone who muted me 233 delete(mutingClient.muteMap, client) 234 } 235 s.idmapsMu.Unlock() 236 237 conn.Close() 238 s.logDebug("closed ws connection") 239 } 240} 241 242func (s *Server) listenToWS(client *client) { 243 for { 244 select { 245 case <-client.ctx.Done(): 246 s.logDebug("exiting listenToWS: client done") 247 return 248 case <-s.ctx.Done(): 249 s.logDebug("exiting listenToWS: server done") 250 return 251 default: 252 _, data, err := client.conn.ReadMessage() 253 if err != nil { 254 s.logDebug("canceling client: read error") 255 client.cancel() 256 return 257 } 258 var event lrcpb.Event 259 err = proto.Unmarshal(data, &event) 260 if err != nil { 261 s.logDebug(err.Error()) 262 client.cancel() 263 return 264 } 265 s.eventBus <- clientEvent{client: client, event: &event} 266 } 267 } 268} 269 270func (s *Server) wsWriter(client *client) { 271 ticker := time.NewTicker(15 * time.Second) 272 for { 273 select { 274 case <-ticker.C: 275 err := client.conn.WriteControl(websocket.PingMessage, nil, time.Now().Add(5*time.Second)) 276 if err != nil { 277 client.cancel() 278 return 279 } 280 case <-client.ctx.Done(): 281 s.logDebug("exiting wsWriter: client done") 282 return 283 case <-s.ctx.Done(): 284 s.logDebug("exiting wsWriter: server done") 285 return 286 case data, ok := <-client.dataChan: 287 if !ok { 288 s.logDebug("canceling client: dataChan closed") 289 client.cancel() 290 return 291 } 292 err := client.conn.WriteMessage(websocket.BinaryMessage, data) 293 if err != nil { 294 s.logDebug(err.Error()) 295 client.cancel() 296 return 297 } 298 } 299 } 300} 301 302// broadcaster takes an event from the events channel, and broadcasts it to all the connected clients individual event channels 303func (s *Server) broadcaster() { 304 for { 305 select { 306 case <-s.ctx.Done(): 307 return 308 case ce := <-s.eventBus: 309 client := ce.client 310 event := ce.event 311 switch msg := event.Msg.(type) { 312 case *lrcpb.Event_Ping: 313 client.dataChan <- s.pongEvt 314 case *lrcpb.Event_Pong: 315 continue 316 case *lrcpb.Event_Init: 317 s.handleInit(msg, client) 318 case *lrcpb.Event_Mediainit: 319 s.handleMediainit(msg, client) 320 case *lrcpb.Event_Pub: 321 s.handlePub(client) 322 case *lrcpb.Event_Mediapub: 323 s.handleMediapub(msg, client) 324 case *lrcpb.Event_Insert: 325 s.handleInsert(msg, client) 326 case *lrcpb.Event_Delete: 327 s.handleDelete(msg, client) 328 case *lrcpb.Event_Mute: 329 s.handleMute(msg, client) 330 case *lrcpb.Event_Unmute: 331 s.handleUnmute(msg, client) 332 case *lrcpb.Event_Set: 333 s.handleSet(msg, client) 334 case *lrcpb.Event_Get: 335 s.handleGet(msg, client) 336 case *lrcpb.Event_Editbatch: 337 s.handleEditBatch(msg, client) 338 } 339 340 } 341 } 342} 343 344func (s *Server) handleInit(msg *lrcpb.Event_Init, client *client) { 345 curID := client.textID 346 if curID != nil { 347 return 348 } 349 s.idmapsMu.Lock() 350 newID := s.lastID + 1 351 s.lastID = newID 352 s.idToClient[newID] = client 353 s.idmapsMu.Unlock() 354 client.textID = &newID 355 client.myIDs = append(client.myIDs, newID) 356 newpost := "" 357 client.post = &newpost 358 msg.Init.Id = &newID 359 msg.Init.Nick = client.nick 360 msg.Init.ExternalID = client.externID 361 msg.Init.Color = client.color 362 echoed := false 363 msg.Init.Echoed = &echoed 364 msg.Init.Nonce = nil 365 if s.initChan != nil { 366 select { 367 case s.initChan <- InitChanMsg{*msg, client.resolvID}: 368 default: 369 s.log("initchan blocked, closing channel") 370 close(s.initChan) 371 s.initChan = nil 372 } 373 } 374 s.broadcastInit(msg, client) 375} 376 377func (s *Server) broadcastInit(msg *lrcpb.Event_Init, client *client) { 378 stdEvent := &lrcpb.Event{Msg: msg} 379 stdData, _ := proto.Marshal(stdEvent) 380 echoed := true 381 msg.Init.Echoed = &echoed 382 msg.Init.Nonce = GenerateNonce(*msg.Init.Id, s.uri, s.secret) 383 echoEvent := &lrcpb.Event{Msg: msg} 384 echoData, _ := proto.Marshal(echoEvent) 385 muteEvent := &lrcpb.Event{Msg: &lrcpb.Event_Mute{Mute: &lrcpb.Mute{Id: msg.Init.GetId()}}} 386 muteData, _ := proto.Marshal(muteEvent) 387 s.clientsMu.Lock() 388 defer s.clientsMu.Unlock() 389 for c := range s.clients { 390 var dts []byte 391 if c == client { 392 dts = echoData 393 } else if client.mutedBy[c] { 394 dts = muteData 395 } else { 396 dts = stdData 397 } 398 select { 399 case c.dataChan <- dts: 400 s.logDebug("b init") 401 default: 402 s.log("kicked client") 403 client.cancel() 404 } 405 } 406} 407func (s *Server) handleMediainit(msg *lrcpb.Event_Mediainit, client *client) { 408 s.logDebug("want to handle media init") 409 curId := client.mediaID 410 if curId != nil { 411 return 412 } 413 s.idmapsMu.Lock() 414 s.logDebug("handling media init") 415 newID := s.lastID + 1 416 s.lastID = newID 417 s.idToClient[newID] = client 418 s.idmapsMu.Unlock() 419 client.mediaID = &newID 420 client.myIDs = append(client.myIDs, newID) 421 msg.Mediainit.Id = &newID 422 msg.Mediainit.Nick = client.nick 423 msg.Mediainit.ExternalID = client.externID 424 msg.Mediainit.Color = client.color 425 echoed := false 426 msg.Mediainit.Echoed = &echoed 427 msg.Mediainit.Nonce = nil 428 if s.mediainitChan != nil { 429 select { 430 case s.mediainitChan <- MediaInitChanMsg{*msg, client.resolvID}: 431 default: 432 s.log("initchan blocked, closing channel") 433 close(s.mediainitChan) 434 s.mediainitChan = nil 435 } 436 } 437 s.broadcastMediainit(msg, client) 438} 439 440func (s *Server) broadcastMediainit(msg *lrcpb.Event_Mediainit, client *client) { 441 stdEvent := &lrcpb.Event{Msg: msg} 442 stdData, _ := proto.Marshal(stdEvent) 443 echoed := true 444 msg.Mediainit.Echoed = &echoed 445 msg.Mediainit.Nonce = GenerateNonce(*msg.Mediainit.Id, s.uri, s.secret) 446 echoEvent := &lrcpb.Event{Msg: msg} 447 echoData, _ := proto.Marshal(echoEvent) 448 muteEvent := &lrcpb.Event{Msg: &lrcpb.Event_Mute{Mute: &lrcpb.Mute{Id: msg.Mediainit.GetId()}}} 449 muteData, _ := proto.Marshal(muteEvent) 450 s.clientsMu.Lock() 451 defer s.clientsMu.Unlock() 452 for c := range s.clients { 453 var dts []byte 454 if c == client { 455 dts = echoData 456 } else if client.mutedBy[c] { 457 dts = muteData 458 } else { 459 dts = stdData 460 } 461 select { 462 case c.dataChan <- dts: 463 s.logDebug("b mediainit") 464 default: 465 s.log("kicked client") 466 client.cancel() 467 } 468 } 469} 470 471func (s *Server) handlePub(client *client) { 472 curID := client.textID 473 if curID == nil { 474 return 475 } 476 client.textID = nil 477 event := &lrcpb.Event{Msg: &lrcpb.Event_Pub{Pub: &lrcpb.Pub{Id: curID}}} 478 if s.pubChan != nil { 479 select { 480 case s.pubChan <- PubEvent{ID: *curID, Body: *client.post}: 481 default: 482 s.log("pubchan blocked, closing channel") 483 close(s.pubChan) 484 s.pubChan = nil 485 } 486 } 487 client.post = nil 488 s.broadcast(event, client) 489} 490 491func (s *Server) handleMediapub(msg *lrcpb.Event_Mediapub, client *client) { 492 curID := client.mediaID 493 if curID == nil { 494 return 495 } 496 client.mediaID = nil 497 msg.Mediapub.Id = curID 498 body := "external media." 499 if msg.Mediapub.Alt != nil { 500 body += fmt.Sprintf(" alt=%s.", *msg.Mediapub.Alt) 501 } 502 if msg.Mediapub.ContentAddress != nil { 503 body += fmt.Sprintf(" cid=%s.", *msg.Mediapub.ContentAddress) 504 } 505 if s.pubChan != nil { 506 select { 507 case s.pubChan <- PubEvent{ID: *curID, Body: body}: 508 default: 509 s.log("pubchan blocked, closing channel") 510 close(s.pubChan) 511 s.pubChan = nil 512 } 513 } 514 event := &lrcpb.Event{Msg: msg} 515 s.broadcast(event, client) 516} 517 518func (s *Server) handleInsert(msg *lrcpb.Event_Insert, client *client) { 519 curID := client.textID 520 if curID == nil { 521 return 522 } 523 newpost, err := insertAtUTF16Index(*client.post, msg.Insert.GetUtf16Index(), msg.Insert.GetBody()) 524 if err != nil { 525 return 526 } 527 client.post = &newpost 528 msg.Insert.Id = curID 529 event := &lrcpb.Event{Msg: msg} 530 s.broadcast(event, client) 531} 532 533func insertAtUTF16Index(base string, index uint32, insert string) (string, error) { 534 runes := []rune(base) 535 baseUTF16Units := utf16.Encode(runes) 536 if uint32(len(baseUTF16Units)) < index { 537 return "", errors.New("index out of range") 538 } 539 540 insertRunes := []rune(insert) 541 insertUTF16Units := utf16.Encode(insertRunes) 542 result := make([]uint16, 0, len(baseUTF16Units)+len(insertUTF16Units)) 543 result = append(result, baseUTF16Units[:index]...) 544 result = append(result, insertUTF16Units...) 545 result = append(result, baseUTF16Units[index:]...) 546 resultRunes := utf16.Decode(result) 547 return string(resultRunes), nil 548} 549 550func (s *Server) handleDelete(msg *lrcpb.Event_Delete, client *client) { 551 curID := client.textID 552 if curID == nil { 553 return 554 } 555 newPost, err := deleteBtwnUTF16Indices(*client.post, msg.Delete.GetUtf16Start(), msg.Delete.GetUtf16End()) 556 if err != nil { 557 return 558 } 559 client.post = &newPost 560 msg.Delete.Id = curID 561 event := &lrcpb.Event{Msg: msg} 562 s.broadcast(event, client) 563} 564 565func deleteBtwnUTF16Indices(base string, start uint32, end uint32) (string, error) { 566 if end <= start { 567 return "", errors.New("end must come after start") 568 } 569 runes := []rune(base) 570 baseUTF16Units := utf16.Encode(runes) 571 if uint32(len(baseUTF16Units)) < end { 572 return "", errors.New("index out of range") 573 } 574 result := make([]uint16, 0, uint32(len(baseUTF16Units))+start-end) 575 result = append(result, baseUTF16Units[:start]...) 576 result = append(result, baseUTF16Units[end:]...) 577 resultRunes := utf16.Decode(result) 578 return string(resultRunes), nil 579} 580func (s *Server) broadcast(event *lrcpb.Event, client *client) { 581 data, _ := proto.Marshal(event) 582 s.clientsMu.Lock() 583 defer s.clientsMu.Unlock() 584 for c := range s.clients { 585 if client.mutedBy[c] { 586 continue 587 } 588 select { 589 case c.dataChan <- data: 590 s.logDebug("b") 591 default: 592 s.log("kicked client") 593 client.cancel() 594 } 595 } 596} 597 598func (s *Server) handleEditBatch(msg *lrcpb.Event_Editbatch, client *client) { 599 curID := client.textID 600 if curID == nil { 601 return 602 } 603 plorp := *client.post 604 var err error 605 for _, edit := range msg.Editbatch.Edits { 606 switch edit := edit.Edit.(type) { 607 case *lrcpb.Edit_Insert: 608 plorp, err = insertAtUTF16Index(plorp, edit.Insert.GetUtf16Index(), edit.Insert.GetBody()) 609 if err != nil { 610 return 611 } 612 case *lrcpb.Edit_Delete: 613 plorp, err = deleteBtwnUTF16Indices(plorp, edit.Delete.GetUtf16Start(), edit.Delete.GetUtf16End()) 614 if err != nil { 615 return 616 } 617 } 618 } 619 client.post = &plorp 620 event := &lrcpb.Event{Msg: msg, Id: curID} 621 data, _ := proto.Marshal(event) 622 s.clientsMu.Lock() 623 defer s.clientsMu.Unlock() 624 for c := range s.clients { 625 if client.mutedBy[c] { 626 continue 627 } 628 select { 629 case c.dataChan <- data: 630 s.logDebug("b") 631 default: 632 s.log("kicked client") 633 client.cancel() 634 } 635 } 636} 637 638func (s *Server) handleMute(msg *lrcpb.Event_Mute, client *client) { 639 toMute := msg.Mute.GetId() 640 s.idmapsMu.Lock() 641 defer s.idmapsMu.Unlock() 642 clientToMute, ok := s.idToClient[toMute] 643 if !ok { 644 return 645 } 646 if clientToMute == client { 647 return 648 } 649 clientToMute.mutedBy[client] = true 650 client.muteMap[clientToMute] = true 651 652} 653 654func (s *Server) handleUnmute(msg *lrcpb.Event_Unmute, client *client) { 655 toMute := msg.Unmute.GetId() 656 s.idmapsMu.Lock() 657 defer s.idmapsMu.Unlock() 658 clientToMute, ok := s.idToClient[toMute] 659 if !ok { 660 return 661 } 662 if clientToMute == client { 663 return 664 } 665 delete(clientToMute.mutedBy, client) 666 delete(client.muteMap, clientToMute) 667} 668 669func (s *Server) handleSet(msg *lrcpb.Event_Set, client *client) { 670 nick := msg.Set.Nick 671 if nick != nil { 672 nickname := *nick 673 if len(nickname) <= 16 { 674 client.nick = &nickname 675 } 676 } 677 externalId := msg.Set.ExternalID 678 if externalId != nil { 679 externid := *externalId 680 client.externID = &externid 681 client.rcancel() 682 if s.resolver != nil { 683 go func() { 684 ctx, cancel := context.WithCancel(client.ctx) 685 client.rcancel = cancel 686 resolvid := s.resolver(externid, ctx) 687 client.resolvID = resolvid 688 }() 689 } 690 } 691 color := msg.Set.Color 692 if color != nil { 693 c := *color 694 if c <= 0xffffff { 695 client.color = &c 696 } 697 } 698} 699 700func (s *Server) handleGet(msg *lrcpb.Event_Get, client *client) { 701 t := msg.Get.Topic 702 if t != nil { 703 client.dataChan <- s.welcomeEvt 704 } 705 c := msg.Get.Connected 706 if c != nil { 707 conncount := uint32(len(s.clients)) 708 e := &lrcpb.Event{Msg: &lrcpb.Event_Get{Get: &lrcpb.Get{Connected: &conncount}}} 709 data, _ := proto.Marshal(e) 710 client.dataChan <- data 711 } 712} 713 714// logDebug debugs unless in production 715func (server *Server) logDebug(s string) { 716 if server.debugLogger != nil { 717 server.debugLogger.Println(s) 718 } 719} 720 721func (server *Server) log(s string) { 722 if server.logger != nil { 723 server.logger.Println(s) 724 } 725}