websocket-based lrcproto server
1package lrcd
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "github.com/gorilla/websocket"
8 "github.com/rachel-mp4/lrcproto/gen/go"
9 "google.golang.org/protobuf/proto"
10 "log"
11 "net"
12 "net/http"
13 "sync"
14 "time"
15 "unicode/utf16"
16)
17
// Server is a websocket-based lrc server. Clients connect through WSHandler;
// the events they send are decoded as lrcpb.Event and relayed to the other
// connected clients by a single broadcaster goroutine.
type Server struct {
	// secret and uri are passed to GenerateNonce when an init event is echoed
	// back to the client that sent it.
	secret string
	uri    string
	// eventBus carries decoded events from every client's reader goroutine to
	// the broadcaster goroutine.
	eventBus chan clientEvent
	// ctx and cancel are created by Start and cancelled by Stop.
	ctx    context.Context
	cancel context.CancelFunc
	// clients is the set of connected clients; guarded by clientsMu.
	clients   map[*client]bool
	clientsMu sync.Mutex
	// idmapsMu is held when allocating IDs (lastID), editing idToClient, and
	// editing clients' mute maps.
	idmapsMu   sync.Mutex
	idToClient map[uint32]*client
	lastID     uint32
	// logger receives normal log lines; debugLogger, when non-nil, receives
	// verbose debug lines.
	logger      *log.Logger
	debugLogger *log.Logger
	// welcomeEvt and pongEvt are pre-marshaled replies built by setDefaultEvents.
	welcomeEvt []byte
	pongEvt    []byte
	// initChan and mediainitChan are optional notification channels. Sends to
	// them are non-blocking, and a channel found full is closed and dropped.
	initChan      chan InitChanMsg
	mediainitChan chan MediaInitChanMsg
	// resolver, when set, is called in the background after a client sets an
	// external ID; its result is stored on the client as resolvID.
	resolver func(externalID string, ctx context.Context) *string
	// pubChan is an optional channel that receives a PubEvent when a message is
	// published; a full channel is closed and dropped.
	pubChan chan PubEvent
}
38
// PubEvent is sent on the server's optional pub channel when a client's text
// or media message is published.
type PubEvent struct {
	// ID is the published message's ID.
	ID uint32
	// Body is the final text of a text message, or a generated description
	// beginning with "external media." for a media message.
	Body string
}
43
// client is the per-connection state for one websocket peer.
type client struct {
	// conn is the upgraded websocket connection.
	conn *websocket.Conn
	// dataChan queues pre-marshaled events for wsWriter to send. WSHandler
	// closes it after removing the client from Server.clients.
	dataChan chan []byte
	// ctx and cancel bound this client's lifetime. Cancelling ctx signals the
	// client's reader and writer goroutines to exit; this is also how a slow
	// client is kicked.
	ctx    context.Context
	cancel context.CancelFunc
	// muteMap holds the clients this client has muted; mutedBy holds the
	// clients that have muted this client. Both are edited under
	// Server.idmapsMu.
	muteMap map[*client]bool
	mutedBy map[*client]bool
	// myIDs lists every message ID allocated to this client, so that the IDs
	// can be removed from Server.idToClient on disconnect.
	myIDs []uint32
	// textID and mediaID are the IDs of the client's open text and media
	// messages; each is nil when none is open.
	textID  *uint32
	mediaID *uint32
	// post is the current body of the open text message, kept in sync by
	// applying insert/delete edits.
	post *string
	// nick and externID are profile fields set via Set events and attached to
	// this client's init events.
	nick     *string
	externID *string
	// resolvID holds the value returned by Server.resolver for externID.
	resolvID *string
	// rcancel cancels an in-flight resolver call.
	rcancel context.CancelFunc
	// color is a profile field set via Set events; handleSet only accepts
	// values up to 0xffffff.
	color *uint32
}
61
// clientEvent pairs a decoded event with the client that sent it. It is
// passed from listenToWS to the broadcaster over Server.eventBus.
type clientEvent struct {
	client *client
	event  *lrcpb.Event
}
66
67func NewServer(opts ...Option) (*Server, error) {
68 var options options
69 for _, opt := range opts {
70 err := opt(&options)
71 if err != nil {
72 return nil, err
73 }
74 }
75
76 s := Server{}
77
78 welcomeString := "Welcome to my lrc server!"
79 if options.welcome != nil {
80 welcomeString = *options.welcome
81 }
82 s.setDefaultEvents(welcomeString)
83
84 if options.writer != nil {
85 s.logger = log.New(*options.writer, "[log]", log.Ldate|log.Ltime)
86 if options.verbose {
87 s.debugLogger = log.New(*options.writer, "[debug]", log.Ldate|log.Ltime)
88 }
89 }
90
91 if options.initChan != nil {
92 s.initChan = options.initChan
93 }
94 if options.mediainitChan != nil {
95 s.mediainitChan = options.mediainitChan
96 }
97 if options.pubChan != nil {
98 s.pubChan = options.pubChan
99 }
100 if options.initialID != nil {
101 s.lastID = *options.initialID
102 }
103 s.uri = options.uri
104 s.secret = options.secret
105 s.resolver = options.resolver
106
107 s.clients = make(map[*client]bool)
108 s.clientsMu = sync.Mutex{}
109 s.idmapsMu = sync.Mutex{}
110 s.idToClient = make(map[uint32]*client)
111 s.eventBus = make(chan clientEvent, 100)
112 return &s, nil
113}
114
115func (s *Server) setDefaultEvents(welcome string) {
116 evt := &lrcpb.Event{Msg: &lrcpb.Event_Get{Get: &lrcpb.Get{Topic: &welcome}}}
117 we, _ := proto.Marshal(evt)
118 s.welcomeEvt = we
119
120 evt = &lrcpb.Event{Msg: &lrcpb.Event_Pong{Pong: &lrcpb.Pong{}}}
121 pe, _ := proto.Marshal(evt)
122 s.pongEvt = pe
123}
124
125// Start starts a server, and returns an error if it has ever been started before
126func (s *Server) Start() error {
127 if s.ctx != nil {
128 return errors.New("cannot start already started server")
129 }
130 s.ctx, s.cancel = context.WithCancel(context.Background())
131 go s.broadcaster()
132 s.logDebug("Hello, world!")
133 return nil
134}
135
136// Stop stops a server if it has started, and returns an error if it is already stopped
137func (s *Server) Stop() (uint32, error) {
138 if s.ctx == nil {
139 return s.lastID, nil
140 }
141 select {
142 case <-s.ctx.Done():
143 return s.lastID, errors.New("cannot stop already stopped server")
144 default:
145 s.cancel()
146 if s.initChan != nil {
147 close(s.initChan)
148 }
149 if s.pubChan != nil {
150 close(s.pubChan)
151 }
152 s.logDebug("Goodbye world :c")
153 return s.lastID, nil
154 }
155}
156
157// Connected returns how many clients are currently connected to the server
158func (s *Server) Connected() int {
159 return len(s.clients)
160}
161
162// StopIfEmpty stops the server if it is empty, returning true.
163func (s *Server) StopIfEmpty() bool {
164 if len(s.clients) == 0 {
165 s.Stop()
166 return true
167 }
168 return false
169}
170
171func (s *Server) WSHandler() http.HandlerFunc {
172 return func(w http.ResponseWriter, r *http.Request) {
173 upgrader := &websocket.Upgrader{
174 Subprotocols: []string{"lrc.v1"},
175 CheckOrigin: func(r *http.Request) bool {
176 return true
177 },
178 }
179 // initialize
180 conn, err := upgrader.Upgrade(w, r, nil)
181 if err != nil {
182 log.Println("Upgrade failed:", err)
183 return
184 }
185 defer conn.Close()
186
187 if netConn := conn.UnderlyingConn(); netConn != nil {
188 if tcpConn, ok := netConn.(*net.TCPConn); ok {
189 if err := tcpConn.SetNoDelay(true); err != nil {
190 log.Println("failed to denagle")
191 }
192 }
193 }
194
195 ctx, cancel := context.WithCancel(context.Background())
196 client := &client{
197 conn: conn,
198 dataChan: make(chan []byte, 100),
199 ctx: ctx,
200 cancel: cancel,
201 muteMap: make(map[*client]bool, 0),
202 mutedBy: make(map[*client]bool, 0),
203 myIDs: make([]uint32, 0),
204 }
205
206 s.clientsMu.Lock()
207 s.clients[client] = true
208 s.clientsMu.Unlock()
209
210 // lifetime
211 var wg sync.WaitGroup
212 wg.Add(2)
213 go func() { defer wg.Done(); s.wsWriter(client) }()
214 go func() { defer wg.Done(); s.listenToWS(client) }()
215 s.logDebug("new ws connection!")
216 wg.Wait()
217
218 // clean up
219 s.clientsMu.Lock()
220 delete(s.clients, client)
221 close(client.dataChan)
222 s.clientsMu.Unlock()
223 s.handlePub(client)
224
225 s.idmapsMu.Lock()
226 for _, id := range client.myIDs { // remove myself from the idToClient map
227 delete(s.idToClient, id)
228 }
229 for mutedClient := range client.muteMap { // remove myself from everyone that I muted's backreference map
230 delete(mutedClient.mutedBy, client)
231 }
232 for mutingClient := range client.mutedBy { // remove myself from everyone who muted me
233 delete(mutingClient.muteMap, client)
234 }
235 s.idmapsMu.Unlock()
236
237 conn.Close()
238 s.logDebug("closed ws connection")
239 }
240}
241
242func (s *Server) listenToWS(client *client) {
243 for {
244 select {
245 case <-client.ctx.Done():
246 s.logDebug("exiting listenToWS: client done")
247 return
248 case <-s.ctx.Done():
249 s.logDebug("exiting listenToWS: server done")
250 return
251 default:
252 _, data, err := client.conn.ReadMessage()
253 if err != nil {
254 s.logDebug("canceling client: read error")
255 client.cancel()
256 return
257 }
258 var event lrcpb.Event
259 err = proto.Unmarshal(data, &event)
260 if err != nil {
261 s.logDebug(err.Error())
262 client.cancel()
263 return
264 }
265 s.eventBus <- clientEvent{client: client, event: &event}
266 }
267 }
268}
269
270func (s *Server) wsWriter(client *client) {
271 ticker := time.NewTicker(15 * time.Second)
272 for {
273 select {
274 case <-ticker.C:
275 err := client.conn.WriteControl(websocket.PingMessage, nil, time.Now().Add(5*time.Second))
276 if err != nil {
277 client.cancel()
278 return
279 }
280 case <-client.ctx.Done():
281 s.logDebug("exiting wsWriter: client done")
282 return
283 case <-s.ctx.Done():
284 s.logDebug("exiting wsWriter: server done")
285 return
286 case data, ok := <-client.dataChan:
287 if !ok {
288 s.logDebug("canceling client: dataChan closed")
289 client.cancel()
290 return
291 }
292 err := client.conn.WriteMessage(websocket.BinaryMessage, data)
293 if err != nil {
294 s.logDebug(err.Error())
295 client.cancel()
296 return
297 }
298 }
299 }
300}
301
302// broadcaster takes an event from the events channel, and broadcasts it to all the connected clients individual event channels
303func (s *Server) broadcaster() {
304 for {
305 select {
306 case <-s.ctx.Done():
307 return
308 case ce := <-s.eventBus:
309 client := ce.client
310 event := ce.event
311 switch msg := event.Msg.(type) {
312 case *lrcpb.Event_Ping:
313 client.dataChan <- s.pongEvt
314 case *lrcpb.Event_Pong:
315 continue
316 case *lrcpb.Event_Init:
317 s.handleInit(msg, client)
318 case *lrcpb.Event_Mediainit:
319 s.handleMediainit(msg, client)
320 case *lrcpb.Event_Pub:
321 s.handlePub(client)
322 case *lrcpb.Event_Mediapub:
323 s.handleMediapub(msg, client)
324 case *lrcpb.Event_Insert:
325 s.handleInsert(msg, client)
326 case *lrcpb.Event_Delete:
327 s.handleDelete(msg, client)
328 case *lrcpb.Event_Mute:
329 s.handleMute(msg, client)
330 case *lrcpb.Event_Unmute:
331 s.handleUnmute(msg, client)
332 case *lrcpb.Event_Set:
333 s.handleSet(msg, client)
334 case *lrcpb.Event_Get:
335 s.handleGet(msg, client)
336 case *lrcpb.Event_Editbatch:
337 s.handleEditBatch(msg, client)
338 }
339
340 }
341 }
342}
343
344func (s *Server) handleInit(msg *lrcpb.Event_Init, client *client) {
345 curID := client.textID
346 if curID != nil {
347 return
348 }
349 s.idmapsMu.Lock()
350 newID := s.lastID + 1
351 s.lastID = newID
352 s.idToClient[newID] = client
353 s.idmapsMu.Unlock()
354 client.textID = &newID
355 client.myIDs = append(client.myIDs, newID)
356 newpost := ""
357 client.post = &newpost
358 msg.Init.Id = &newID
359 msg.Init.Nick = client.nick
360 msg.Init.ExternalID = client.externID
361 msg.Init.Color = client.color
362 echoed := false
363 msg.Init.Echoed = &echoed
364 msg.Init.Nonce = nil
365 if s.initChan != nil {
366 select {
367 case s.initChan <- InitChanMsg{*msg, client.resolvID}:
368 default:
369 s.log("initchan blocked, closing channel")
370 close(s.initChan)
371 s.initChan = nil
372 }
373 }
374 s.broadcastInit(msg, client)
375}
376
377func (s *Server) broadcastInit(msg *lrcpb.Event_Init, client *client) {
378 stdEvent := &lrcpb.Event{Msg: msg}
379 stdData, _ := proto.Marshal(stdEvent)
380 echoed := true
381 msg.Init.Echoed = &echoed
382 msg.Init.Nonce = GenerateNonce(*msg.Init.Id, s.uri, s.secret)
383 echoEvent := &lrcpb.Event{Msg: msg}
384 echoData, _ := proto.Marshal(echoEvent)
385 muteEvent := &lrcpb.Event{Msg: &lrcpb.Event_Mute{Mute: &lrcpb.Mute{Id: msg.Init.GetId()}}}
386 muteData, _ := proto.Marshal(muteEvent)
387 s.clientsMu.Lock()
388 defer s.clientsMu.Unlock()
389 for c := range s.clients {
390 var dts []byte
391 if c == client {
392 dts = echoData
393 } else if client.mutedBy[c] {
394 dts = muteData
395 } else {
396 dts = stdData
397 }
398 select {
399 case c.dataChan <- dts:
400 s.logDebug("b init")
401 default:
402 s.log("kicked client")
403 client.cancel()
404 }
405 }
406}
407func (s *Server) handleMediainit(msg *lrcpb.Event_Mediainit, client *client) {
408 s.logDebug("want to handle media init")
409 curId := client.mediaID
410 if curId != nil {
411 return
412 }
413 s.idmapsMu.Lock()
414 s.logDebug("handling media init")
415 newID := s.lastID + 1
416 s.lastID = newID
417 s.idToClient[newID] = client
418 s.idmapsMu.Unlock()
419 client.mediaID = &newID
420 client.myIDs = append(client.myIDs, newID)
421 msg.Mediainit.Id = &newID
422 msg.Mediainit.Nick = client.nick
423 msg.Mediainit.ExternalID = client.externID
424 msg.Mediainit.Color = client.color
425 echoed := false
426 msg.Mediainit.Echoed = &echoed
427 msg.Mediainit.Nonce = nil
428 if s.mediainitChan != nil {
429 select {
430 case s.mediainitChan <- MediaInitChanMsg{*msg, client.resolvID}:
431 default:
432 s.log("initchan blocked, closing channel")
433 close(s.mediainitChan)
434 s.mediainitChan = nil
435 }
436 }
437 s.broadcastMediainit(msg, client)
438}
439
440func (s *Server) broadcastMediainit(msg *lrcpb.Event_Mediainit, client *client) {
441 stdEvent := &lrcpb.Event{Msg: msg}
442 stdData, _ := proto.Marshal(stdEvent)
443 echoed := true
444 msg.Mediainit.Echoed = &echoed
445 msg.Mediainit.Nonce = GenerateNonce(*msg.Mediainit.Id, s.uri, s.secret)
446 echoEvent := &lrcpb.Event{Msg: msg}
447 echoData, _ := proto.Marshal(echoEvent)
448 muteEvent := &lrcpb.Event{Msg: &lrcpb.Event_Mute{Mute: &lrcpb.Mute{Id: msg.Mediainit.GetId()}}}
449 muteData, _ := proto.Marshal(muteEvent)
450 s.clientsMu.Lock()
451 defer s.clientsMu.Unlock()
452 for c := range s.clients {
453 var dts []byte
454 if c == client {
455 dts = echoData
456 } else if client.mutedBy[c] {
457 dts = muteData
458 } else {
459 dts = stdData
460 }
461 select {
462 case c.dataChan <- dts:
463 s.logDebug("b mediainit")
464 default:
465 s.log("kicked client")
466 client.cancel()
467 }
468 }
469}
470
471func (s *Server) handlePub(client *client) {
472 curID := client.textID
473 if curID == nil {
474 return
475 }
476 client.textID = nil
477 event := &lrcpb.Event{Msg: &lrcpb.Event_Pub{Pub: &lrcpb.Pub{Id: curID}}}
478 if s.pubChan != nil {
479 select {
480 case s.pubChan <- PubEvent{ID: *curID, Body: *client.post}:
481 default:
482 s.log("pubchan blocked, closing channel")
483 close(s.pubChan)
484 s.pubChan = nil
485 }
486 }
487 client.post = nil
488 s.broadcast(event, client)
489}
490
491func (s *Server) handleMediapub(msg *lrcpb.Event_Mediapub, client *client) {
492 curID := client.mediaID
493 if curID == nil {
494 return
495 }
496 client.mediaID = nil
497 msg.Mediapub.Id = curID
498 body := "external media."
499 if msg.Mediapub.Alt != nil {
500 body += fmt.Sprintf(" alt=%s.", *msg.Mediapub.Alt)
501 }
502 if msg.Mediapub.ContentAddress != nil {
503 body += fmt.Sprintf(" cid=%s.", *msg.Mediapub.ContentAddress)
504 }
505 if s.pubChan != nil {
506 select {
507 case s.pubChan <- PubEvent{ID: *curID, Body: body}:
508 default:
509 s.log("pubchan blocked, closing channel")
510 close(s.pubChan)
511 s.pubChan = nil
512 }
513 }
514 event := &lrcpb.Event{Msg: msg}
515 s.broadcast(event, client)
516}
517
518func (s *Server) handleInsert(msg *lrcpb.Event_Insert, client *client) {
519 curID := client.textID
520 if curID == nil {
521 return
522 }
523 newpost, err := insertAtUTF16Index(*client.post, msg.Insert.GetUtf16Index(), msg.Insert.GetBody())
524 if err != nil {
525 return
526 }
527 client.post = &newpost
528 msg.Insert.Id = curID
529 event := &lrcpb.Event{Msg: msg}
530 s.broadcast(event, client)
531}
532
// insertAtUTF16Index returns base with insert spliced in at index, where
// index counts UTF-16 code units rather than bytes or runes.
//
// index may equal the length (append), but it returns "" and an error when
// index is past the end. An index that falls between the two halves of a
// surrogate pair splits it; the halves decode to U+FFFD.
func insertAtUTF16Index(base string, index uint32, insert string) (string, error) {
	units := utf16.Encode([]rune(base))
	if index > uint32(len(units)) {
		return "", errors.New("index out of range")
	}
	added := utf16.Encode([]rune(insert))

	merged := make([]uint16, len(units)+len(added))
	n := copy(merged, units[:index])
	n += copy(merged[n:], added)
	copy(merged[n:], units[index:])
	return string(utf16.Decode(merged)), nil
}
549
550func (s *Server) handleDelete(msg *lrcpb.Event_Delete, client *client) {
551 curID := client.textID
552 if curID == nil {
553 return
554 }
555 newPost, err := deleteBtwnUTF16Indices(*client.post, msg.Delete.GetUtf16Start(), msg.Delete.GetUtf16End())
556 if err != nil {
557 return
558 }
559 client.post = &newPost
560 msg.Delete.Id = curID
561 event := &lrcpb.Event{Msg: msg}
562 s.broadcast(event, client)
563}
564
// deleteBtwnUTF16Indices returns base with the half-open range [start, end)
// removed, where both bounds count UTF-16 code units.
//
// It returns "" and an error when the range is empty or reversed
// (end <= start), or when end is past the end of base. Cutting through a
// surrogate pair leaves a lone half, which decodes to U+FFFD.
func deleteBtwnUTF16Indices(base string, start uint32, end uint32) (string, error) {
	if start >= end {
		return "", errors.New("end must come after start")
	}
	units := utf16.Encode([]rune(base))
	if end > uint32(len(units)) {
		return "", errors.New("index out of range")
	}
	// units is a private buffer, so it can be compacted in place.
	units = append(units[:start], units[end:]...)
	return string(utf16.Decode(units)), nil
}
580func (s *Server) broadcast(event *lrcpb.Event, client *client) {
581 data, _ := proto.Marshal(event)
582 s.clientsMu.Lock()
583 defer s.clientsMu.Unlock()
584 for c := range s.clients {
585 if client.mutedBy[c] {
586 continue
587 }
588 select {
589 case c.dataChan <- data:
590 s.logDebug("b")
591 default:
592 s.log("kicked client")
593 client.cancel()
594 }
595 }
596}
597
598func (s *Server) handleEditBatch(msg *lrcpb.Event_Editbatch, client *client) {
599 curID := client.textID
600 if curID == nil {
601 return
602 }
603 plorp := *client.post
604 var err error
605 for _, edit := range msg.Editbatch.Edits {
606 switch edit := edit.Edit.(type) {
607 case *lrcpb.Edit_Insert:
608 plorp, err = insertAtUTF16Index(plorp, edit.Insert.GetUtf16Index(), edit.Insert.GetBody())
609 if err != nil {
610 return
611 }
612 case *lrcpb.Edit_Delete:
613 plorp, err = deleteBtwnUTF16Indices(plorp, edit.Delete.GetUtf16Start(), edit.Delete.GetUtf16End())
614 if err != nil {
615 return
616 }
617 }
618 }
619 client.post = &plorp
620 event := &lrcpb.Event{Msg: msg, Id: curID}
621 data, _ := proto.Marshal(event)
622 s.clientsMu.Lock()
623 defer s.clientsMu.Unlock()
624 for c := range s.clients {
625 if client.mutedBy[c] {
626 continue
627 }
628 select {
629 case c.dataChan <- data:
630 s.logDebug("b")
631 default:
632 s.log("kicked client")
633 client.cancel()
634 }
635 }
636}
637
638func (s *Server) handleMute(msg *lrcpb.Event_Mute, client *client) {
639 toMute := msg.Mute.GetId()
640 s.idmapsMu.Lock()
641 defer s.idmapsMu.Unlock()
642 clientToMute, ok := s.idToClient[toMute]
643 if !ok {
644 return
645 }
646 if clientToMute == client {
647 return
648 }
649 clientToMute.mutedBy[client] = true
650 client.muteMap[clientToMute] = true
651
652}
653
654func (s *Server) handleUnmute(msg *lrcpb.Event_Unmute, client *client) {
655 toMute := msg.Unmute.GetId()
656 s.idmapsMu.Lock()
657 defer s.idmapsMu.Unlock()
658 clientToMute, ok := s.idToClient[toMute]
659 if !ok {
660 return
661 }
662 if clientToMute == client {
663 return
664 }
665 delete(clientToMute.mutedBy, client)
666 delete(client.muteMap, clientToMute)
667}
668
669func (s *Server) handleSet(msg *lrcpb.Event_Set, client *client) {
670 nick := msg.Set.Nick
671 if nick != nil {
672 nickname := *nick
673 if len(nickname) <= 16 {
674 client.nick = &nickname
675 }
676 }
677 externalId := msg.Set.ExternalID
678 if externalId != nil {
679 externid := *externalId
680 client.externID = &externid
681 client.rcancel()
682 if s.resolver != nil {
683 go func() {
684 ctx, cancel := context.WithCancel(client.ctx)
685 client.rcancel = cancel
686 resolvid := s.resolver(externid, ctx)
687 client.resolvID = resolvid
688 }()
689 }
690 }
691 color := msg.Set.Color
692 if color != nil {
693 c := *color
694 if c <= 0xffffff {
695 client.color = &c
696 }
697 }
698}
699
700func (s *Server) handleGet(msg *lrcpb.Event_Get, client *client) {
701 t := msg.Get.Topic
702 if t != nil {
703 client.dataChan <- s.welcomeEvt
704 }
705 c := msg.Get.Connected
706 if c != nil {
707 conncount := uint32(len(s.clients))
708 e := &lrcpb.Event{Msg: &lrcpb.Event_Get{Get: &lrcpb.Get{Connected: &conncount}}}
709 data, _ := proto.Marshal(e)
710 client.dataChan <- data
711 }
712}
713
714// logDebug debugs unless in production
715func (server *Server) logDebug(s string) {
716 if server.debugLogger != nil {
717 server.debugLogger.Println(s)
718 }
719}
720
721func (server *Server) log(s string) {
722 if server.logger != nil {
723 server.logger.Println(s)
724 }
725}