An experimental pub/sub client and server project.

Compare changes

Choose any two refs to compare.

+26
.github/workflows/workflow.yaml
··· 1 + name: Go package 2 + 3 + on: [push] 4 + 5 + jobs: 6 + build: 7 + runs-on: ubuntu-latest 8 + steps: 9 + - uses: actions/checkout@v3 10 + 11 + - name: Set up Go 12 + uses: actions/setup-go@v4 13 + with: 14 + go-version: '1.21' 15 + 16 + - name: golangci-lint 17 + run: curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin 18 + 19 + - name: Build 20 + run: go build -v ./... 21 + 22 + - name: lint 23 + run: golangci-lint run 24 + 25 + - name: Test 26 + run: go test ./... -p 1 -count=1 -v
+2 -1
.gitignore
··· 1 - .DS_STORE 1 + .DS_STORE 2 + example/example
+37
README.md
··· 1 + # Message Broker 2 + 3 + Message broker is my attempt at building client / server pub/sub system written in Go using plain `net.Conn`. 4 + 5 + I decided to try to build one to further my understanding of how such a common tool is used in software engineering. 6 + 7 + I realised that I took for granted how complicated other pub / sub message brokers were such as NATs, RabbitMQ and Kafka. By creating my own, I hope to dive into and understand more of how message brokers work. 8 + 9 + ## Concept 10 + 11 + The server accepts TCP connections. When a connection is first established, the client will send an "action" message which determines what type of client it is; subscribe or publish. 12 + 13 + ### Subscribing 14 + Once a connection has declared itself as a subscribing connection, it will need to then send the list of topics it wishes to initally subscribe to. After that the connection will enter a loop where it can then send a new action; subscribe to new topic(s) or unsubscribe from topic(s). 15 + 16 + ### Publishing 17 + Once a subscription has declared itself as a publisher, it will enter a loop where it can then send a message for a topic. Once a message has been received, the server will then send it to all connections that are subscribed to that topic. 18 + 19 + ### Sending data via a connection 20 + 21 + When sending a message representing an action (subscribe, publish etc) then a uint16 binary message is sent. 22 + 23 + When sending any other data, the length of the data is to be sent first using a binary uint32 and then the actual data sent afterwards. 24 + 25 + ## Running the server 26 + 27 + There is a server that can be run using `docker-compose up message-server`. This will start a server running listening on port 3000. 28 + 29 + ## Example clients 30 + There is an example application that implements the subscriber and publishers in the `example` directory. 31 + 32 + Run `go build .` to build the file. 33 + 34 + When running the example there are the following flags: 35 + 36 + `publish` : settings this to true will allow messages to be sent every 500ms as well as consuming 37 + `consume-from` : this allows you to specify what message to start from. If you don't set this or set it to be -1, you will start consuming from the next sent message.
+23
client/message.go
··· 1 + package client 2 + 3 + // Message represents a message that can be published or consumed 4 + type Message struct { 5 + Topic string `json:"topic"` 6 + Data []byte `json:"data"` 7 + 8 + ack chan bool 9 + } 10 + 11 + // NewMessage creates a new message 12 + func NewMessage(topic string, data []byte) *Message { 13 + return &Message{ 14 + Topic: topic, 15 + Data: data, 16 + ack: make(chan bool), 17 + } 18 + } 19 + 20 + // Ack will send the provided value of the ack to the server 21 + func (m *Message) Ack(ack bool) { 22 + m.ack <- ack 23 + }
+113
client/publisher.go
··· 1 + package client 2 + 3 + import ( 4 + "encoding/binary" 5 + "errors" 6 + "fmt" 7 + "log/slog" 8 + "net" 9 + "sync" 10 + "syscall" 11 + 12 + "github.com/willdot/messagebroker/internal/server" 13 + ) 14 + 15 + // Publisher allows messages to be published to a server 16 + type Publisher struct { 17 + conn net.Conn 18 + connMu sync.Mutex 19 + addr string 20 + } 21 + 22 + // NewPublisher connects to the server at the given address and registers as a publisher 23 + func NewPublisher(addr string) (*Publisher, error) { 24 + conn, err := connect(addr) 25 + if err != nil { 26 + return nil, fmt.Errorf("failed to connect to server: %w", err) 27 + } 28 + 29 + return &Publisher{ 30 + conn: conn, 31 + addr: addr, 32 + }, nil 33 + } 34 + 35 + func connect(addr string) (net.Conn, error) { 36 + conn, err := net.Dial("tcp", addr) 37 + if err != nil { 38 + return nil, fmt.Errorf("failed to dial: %w", err) 39 + } 40 + 41 + err = binary.Write(conn, binary.BigEndian, server.Publish) 42 + if err != nil { 43 + conn.Close() 44 + return nil, fmt.Errorf("failed to register publish to server: %w", err) 45 + } 46 + return conn, nil 47 + } 48 + 49 + // Close cleanly shuts down the publisher 50 + func (p *Publisher) Close() error { 51 + return p.conn.Close() 52 + } 53 + 54 + // Publish will publish the given message to the server 55 + func (p *Publisher) PublishMessage(message *Message) error { 56 + return p.publishMessageWithRetry(message, 0) 57 + } 58 + 59 + func (p *Publisher) publishMessageWithRetry(message *Message, attempt int) error { 60 + op := func(conn net.Conn) error { 61 + // send topic first 62 + topic := fmt.Sprintf("topic:%s", message.Topic) 63 + 64 + topicLenB := make([]byte, 2) 65 + binary.BigEndian.PutUint16(topicLenB, uint16(len(topic))) 66 + 67 + headers := append(topicLenB, []byte(topic)...) 68 + 69 + messageLenB := make([]byte, 4) 70 + binary.BigEndian.PutUint32(messageLenB, uint32(len(message.Data))) 71 + headers = append(headers, messageLenB...) 72 + 73 + _, err := conn.Write(append(headers, message.Data...)) 74 + if err != nil { 75 + return fmt.Errorf("failed to publish data to server: %w", err) 76 + } 77 + return nil 78 + } 79 + 80 + err := p.connOperation(op) 81 + if err == nil { 82 + return nil 83 + } 84 + 85 + // we can handle a broken pipe by trying to reconnect, but if it's a different error return it 86 + if !errors.Is(err, syscall.EPIPE) { 87 + return err 88 + } 89 + 90 + slog.Info("error is broken pipe") 91 + 92 + if attempt >= 5 { 93 + return fmt.Errorf("failed to publish message after max attempts to reconnect (%d): %w", attempt, err) 94 + } 95 + 96 + slog.Error("failed to publish message", "error", err) 97 + 98 + conn, connectErr := connect(p.addr) 99 + if connectErr != nil { 100 + return fmt.Errorf("failed to reconnect after failing to publish message: %w", connectErr) 101 + } 102 + 103 + p.conn = conn 104 + 105 + return p.publishMessageWithRetry(message, attempt+1) 106 + } 107 + 108 + func (p *Publisher) connOperation(op connOpp) error { 109 + p.connMu.Lock() 110 + defer p.connMu.Unlock() 111 + 112 + return op(p.conn) 113 + }
+371
client/subscriber.go
··· 1 + package client 2 + 3 + import ( 4 + "context" 5 + "encoding/binary" 6 + "encoding/json" 7 + "errors" 8 + "fmt" 9 + "io" 10 + "log/slog" 11 + "net" 12 + "sync" 13 + "syscall" 14 + "time" 15 + 16 + "github.com/willdot/messagebroker/internal/server" 17 + ) 18 + 19 + type connOpp func(conn net.Conn) error 20 + 21 + // Subscriber allows subscriptions to a server and the consumption of messages 22 + type Subscriber struct { 23 + conn net.Conn 24 + connMu sync.Mutex 25 + subscribedTopics []string 26 + addr string 27 + } 28 + 29 + // NewSubscriber will connect to the server at the given address 30 + func NewSubscriber(addr string) (*Subscriber, error) { 31 + conn, err := net.Dial("tcp", addr) 32 + if err != nil { 33 + return nil, fmt.Errorf("failed to dial: %w", err) 34 + } 35 + 36 + return &Subscriber{ 37 + conn: conn, 38 + addr: addr, 39 + }, nil 40 + } 41 + 42 + func (s *Subscriber) reconnect() error { 43 + conn, err := net.Dial("tcp", s.addr) 44 + if err != nil { 45 + return fmt.Errorf("failed to dial: %w", err) 46 + } 47 + 48 + s.conn = conn 49 + return nil 50 + } 51 + 52 + // Close cleanly shuts down the subscriber 53 + func (s *Subscriber) Close() error { 54 + return s.conn.Close() 55 + } 56 + 57 + // SubscribeToTopics will subscribe to the provided topics 58 + func (s *Subscriber) SubscribeToTopics(topicNames []string, startAtType server.StartAtType, startAtIndex int) error { 59 + op := func(conn net.Conn) error { 60 + return subscribeToTopics(conn, topicNames, startAtType, startAtIndex) 61 + } 62 + 63 + err := s.connOperation(op) 64 + if err != nil { 65 + return fmt.Errorf("failed to subscribe to topics: %w", err) 66 + } 67 + 68 + s.addToSubscribedTopics(topicNames) 69 + 70 + return nil 71 + } 72 + 73 + func (s *Subscriber) addToSubscribedTopics(topics []string) { 74 + existingSubs := make(map[string]struct{}) 75 + for _, topic := range s.subscribedTopics { 76 + existingSubs[topic] = struct{}{} 77 + } 78 + 79 + for _, topic := range topics { 80 + existingSubs[topic] = struct{}{} 81 + } 82 + 83 + subs := make([]string, 0, len(existingSubs)) 84 + for topic := range existingSubs { 85 + subs = append(subs, topic) 86 + } 87 + 88 + s.subscribedTopics = subs 89 + } 90 + 91 + func (s *Subscriber) removeTopicsFromSubscription(topics []string) { 92 + existingSubs := make(map[string]struct{}) 93 + for _, topic := range s.subscribedTopics { 94 + existingSubs[topic] = struct{}{} 95 + } 96 + 97 + for _, topic := range topics { 98 + delete(existingSubs, topic) 99 + } 100 + 101 + subs := make([]string, 0, len(existingSubs)) 102 + for topic := range existingSubs { 103 + subs = append(subs, topic) 104 + } 105 + 106 + s.subscribedTopics = subs 107 + } 108 + 109 + // UnsubscribeToTopics will unsubscribe to the provided topics 110 + func (s *Subscriber) UnsubscribeToTopics(topicNames []string) error { 111 + op := func(conn net.Conn) error { 112 + return unsubscribeToTopics(conn, topicNames) 113 + } 114 + 115 + err := s.connOperation(op) 116 + if err != nil { 117 + return fmt.Errorf("failed to unsubscribe to topics: %w", err) 118 + } 119 + 120 + s.removeTopicsFromSubscription(topicNames) 121 + 122 + return nil 123 + } 124 + 125 + func subscribeToTopics(conn net.Conn, topicNames []string, startAtType server.StartAtType, startAtIndex int) error { 126 + actionB := make([]byte, 2) 127 + binary.BigEndian.PutUint16(actionB, uint16(server.Subscribe)) 128 + headers := actionB 129 + 130 + b, err := json.Marshal(topicNames) 131 + if err != nil { 132 + return fmt.Errorf("failed to marshal topic names: %w", err) 133 + } 134 + 135 + topicNamesB := make([]byte, 4) 136 + binary.BigEndian.PutUint32(topicNamesB, uint32(len(b))) 137 + headers = append(headers, topicNamesB...) 138 + headers = append(headers, b...) 139 + 140 + startAtTypeB := make([]byte, 2) 141 + binary.BigEndian.PutUint16(startAtTypeB, uint16(startAtType)) 142 + headers = append(headers, startAtTypeB...) 143 + 144 + if startAtType == server.From { 145 + fromB := make([]byte, 2) 146 + binary.BigEndian.PutUint16(fromB, uint16(startAtIndex)) 147 + headers = append(headers, fromB...) 148 + } 149 + 150 + _, err = conn.Write(headers) 151 + if err != nil { 152 + return fmt.Errorf("failed to subscribe to topics: %w", err) 153 + } 154 + 155 + var resp server.Status 156 + err = binary.Read(conn, binary.BigEndian, &resp) 157 + if err != nil { 158 + return fmt.Errorf("failed to read confirmation of subscribe: %w", err) 159 + } 160 + 161 + if resp == server.Subscribed { 162 + return nil 163 + } 164 + 165 + var dataLen uint16 166 + err = binary.Read(conn, binary.BigEndian, &dataLen) 167 + if err != nil { 168 + return fmt.Errorf("received status %s:", resp) 169 + } 170 + 171 + buf := make([]byte, dataLen) 172 + _, err = conn.Read(buf) 173 + if err != nil { 174 + return fmt.Errorf("received status %s:", resp) 175 + } 176 + 177 + return fmt.Errorf("received status %s - %s", resp, buf) 178 + } 179 + 180 + func unsubscribeToTopics(conn net.Conn, topicNames []string) error { 181 + actionB := make([]byte, 2) 182 + binary.BigEndian.PutUint16(actionB, uint16(server.Unsubscribe)) 183 + headers := actionB 184 + 185 + b, err := json.Marshal(topicNames) 186 + if err != nil { 187 + return fmt.Errorf("failed to marshal topic names: %w", err) 188 + } 189 + 190 + topicNamesB := make([]byte, 4) 191 + binary.BigEndian.PutUint32(topicNamesB, uint32(len(b))) 192 + headers = append(headers, topicNamesB...) 193 + 194 + _, err = conn.Write(append(headers, b...)) 195 + if err != nil { 196 + return fmt.Errorf("failed to unsubscribe to topics: %w", err) 197 + } 198 + 199 + var resp server.Status 200 + err = binary.Read(conn, binary.BigEndian, &resp) 201 + if err != nil { 202 + return fmt.Errorf("failed to read confirmation of unsubscribe: %w", err) 203 + } 204 + 205 + if resp == server.Unsubscribed { 206 + return nil 207 + } 208 + 209 + var dataLen uint16 210 + err = binary.Read(conn, binary.BigEndian, &dataLen) 211 + if err != nil { 212 + return fmt.Errorf("received status %s:", resp) 213 + } 214 + 215 + buf := make([]byte, dataLen) 216 + _, err = conn.Read(buf) 217 + if err != nil { 218 + return fmt.Errorf("received status %s:", resp) 219 + } 220 + 221 + return fmt.Errorf("received status %s - %s", resp, buf) 222 + } 223 + 224 + // Consumer allows the consumption of messages. If during the consumer receiving messages from the 225 + // server an error occurs, it will be stored in Err 226 + type Consumer struct { 227 + msgs chan *Message 228 + // TODO: better error handling? Maybe a channel of errors? 229 + Err error 230 + } 231 + 232 + // Messages returns a channel in which this consumer will put messages onto. It is safe to range over the channel since it will be closed once 233 + // the consumer has finished either due to an error or from being cancelled. 234 + func (c *Consumer) Messages() <-chan *Message { 235 + return c.msgs 236 + } 237 + 238 + // Consume will create a consumer and start it running in a go routine. You can then use the Msgs channel of the consumer 239 + // to read the messages 240 + func (s *Subscriber) Consume(ctx context.Context) *Consumer { 241 + consumer := &Consumer{ 242 + msgs: make(chan *Message), 243 + } 244 + 245 + go s.consume(ctx, consumer) 246 + 247 + return consumer 248 + } 249 + 250 + func (s *Subscriber) consume(ctx context.Context, consumer *Consumer) { 251 + defer close(consumer.msgs) 252 + for { 253 + if ctx.Err() != nil { 254 + return 255 + } 256 + 257 + err := s.readMessage(ctx, consumer.msgs) 258 + if err == nil { 259 + continue 260 + } 261 + 262 + // if we couldn't connect to the server, attempt to reconnect 263 + if !errors.Is(err, syscall.EPIPE) && !errors.Is(err, io.EOF) { 264 + slog.Error("failed to read message", "error", err) 265 + consumer.Err = err 266 + return 267 + } 268 + 269 + slog.Info("attempting to reconnect") 270 + 271 + for i := 0; i < 5; i++ { 272 + time.Sleep(time.Millisecond * 500) 273 + err = s.reconnect() 274 + if err == nil { 275 + break 276 + } 277 + 278 + slog.Error("Failed to reconnect", "error", err, "attempt", i) 279 + } 280 + 281 + slog.Info("attempting to resubscribe") 282 + 283 + err = s.SubscribeToTopics(s.subscribedTopics, server.Current, 0) 284 + if err != nil { 285 + consumer.Err = fmt.Errorf("failed to subscribe to topics after reconnecting: %w", err) 286 + return 287 + } 288 + 289 + } 290 + } 291 + 292 + func (s *Subscriber) readMessage(ctx context.Context, msgChan chan *Message) error { 293 + op := func(conn net.Conn) error { 294 + err := s.conn.SetReadDeadline(time.Now().Add(time.Millisecond * 300)) 295 + if err != nil { 296 + return err 297 + } 298 + 299 + var topicLen uint16 300 + err = binary.Read(s.conn, binary.BigEndian, &topicLen) 301 + if err != nil { 302 + // TODO: check if this is needed elsewhere. I'm not sure where the read deadline resets.... 303 + if neterr, ok := err.(net.Error); ok && neterr.Timeout() { 304 + return nil 305 + } 306 + return err 307 + } 308 + 309 + topicBuf := make([]byte, topicLen) 310 + _, err = s.conn.Read(topicBuf) 311 + if err != nil { 312 + return err 313 + } 314 + 315 + var dataLen uint64 316 + err = binary.Read(s.conn, binary.BigEndian, &dataLen) 317 + if err != nil { 318 + return err 319 + } 320 + 321 + if dataLen <= 0 { 322 + return nil 323 + } 324 + 325 + dataBuf := make([]byte, dataLen) 326 + _, err = s.conn.Read(dataBuf) 327 + if err != nil { 328 + return err 329 + } 330 + 331 + msg := NewMessage(string(topicBuf), dataBuf) 332 + 333 + msgChan <- msg 334 + 335 + var ack bool 336 + select { 337 + case <-ctx.Done(): 338 + return ctx.Err() 339 + case ack = <-msg.ack: 340 + } 341 + ackMessage := server.Nack 342 + if ack { 343 + ackMessage = server.Ack 344 + } 345 + 346 + err = binary.Write(s.conn, binary.BigEndian, ackMessage) 347 + if err != nil { 348 + return fmt.Errorf("failed to ack/nack message: %w", err) 349 + } 350 + 351 + return nil 352 + } 353 + 354 + err := s.connOperation(op) 355 + if err != nil { 356 + var neterr net.Error 357 + if errors.As(err, &neterr) && neterr.Timeout() { 358 + return nil 359 + } 360 + return err 361 + } 362 + 363 + return err 364 + } 365 + 366 + func (s *Subscriber) connOperation(op connOpp) error { 367 + s.connMu.Lock() 368 + defer s.connMu.Unlock() 369 + 370 + return op(s.conn) 371 + }
+273
client/subscriber_test.go
··· 1 + package client 2 + 3 + import ( 4 + "context" 5 + "fmt" 6 + "testing" 7 + "time" 8 + 9 + "github.com/stretchr/testify/assert" 10 + "github.com/stretchr/testify/require" 11 + "github.com/willdot/messagebroker/internal/server" 12 + ) 13 + 14 + const ( 15 + serverAddr = ":9999" 16 + topicA = "topic a" 17 + topicB = "topic b" 18 + ) 19 + 20 + func createServer(t *testing.T) { 21 + server, err := server.New(serverAddr, time.Millisecond*100, time.Millisecond*100) 22 + require.NoError(t, err) 23 + 24 + t.Cleanup(func() { 25 + _ = server.Shutdown() 26 + }) 27 + } 28 + 29 + func TestNewSubscriber(t *testing.T) { 30 + createServer(t) 31 + 32 + sub, err := NewSubscriber(serverAddr) 33 + require.NoError(t, err) 34 + 35 + t.Cleanup(func() { 36 + sub.Close() 37 + }) 38 + } 39 + 40 + func TestNewSubscriberInvalidServerAddr(t *testing.T) { 41 + createServer(t) 42 + 43 + _, err := NewSubscriber(":123456") 44 + require.Error(t, err) 45 + } 46 + 47 + func TestNewPublisher(t *testing.T) { 48 + createServer(t) 49 + 50 + sub, err := NewPublisher(serverAddr) 51 + require.NoError(t, err) 52 + 53 + t.Cleanup(func() { 54 + sub.Close() 55 + }) 56 + } 57 + 58 + func TestNewPublisherInvalidServerAddr(t *testing.T) { 59 + createServer(t) 60 + 61 + _, err := NewPublisher(":123456") 62 + require.Error(t, err) 63 + } 64 + 65 + func TestSubscribeToTopics(t *testing.T) { 66 + createServer(t) 67 + 68 + sub, err := NewSubscriber(serverAddr) 69 + require.NoError(t, err) 70 + 71 + t.Cleanup(func() { 72 + sub.Close() 73 + }) 74 + 75 + topics := []string{topicA, topicB} 76 + 77 + err = sub.SubscribeToTopics(topics, server.Current, 0) 78 + require.NoError(t, err) 79 + } 80 + 81 + func TestUnsubscribesFromTopic(t *testing.T) { 82 + createServer(t) 83 + 84 + sub, err := NewSubscriber(serverAddr) 85 + require.NoError(t, err) 86 + 87 + t.Cleanup(func() { 88 + sub.Close() 89 + }) 90 + 91 + topics := []string{topicA, topicB} 92 + 93 + err = sub.SubscribeToTopics(topics, server.Current, 0) 94 + require.NoError(t, err) 95 + 96 + err = sub.UnsubscribeToTopics([]string{topicA}) 97 + require.NoError(t, err) 98 + 99 + ctx, cancel := context.WithCancel(context.Background()) 100 + t.Cleanup(func() { 101 + cancel() 102 + }) 103 + 104 + consumer := sub.Consume(ctx) 105 + require.NoError(t, err) 106 + 107 + var receivedMessages []*Message 108 + consumerFinCh := make(chan struct{}) 109 + go func() { 110 + for msg := range consumer.Messages() { 111 + msg.Ack(true) 112 + receivedMessages = append(receivedMessages, msg) 113 + } 114 + 115 + consumerFinCh <- struct{}{} 116 + }() 117 + 118 + // publish a message to both topics and check the subscriber only gets the message from the 1 topic 119 + // and not the unsubscribed topic 120 + 121 + publisher, err := NewPublisher("localhost:9999") 122 + require.NoError(t, err) 123 + t.Cleanup(func() { 124 + publisher.Close() 125 + }) 126 + 127 + msg := NewMessage(topicA, []byte("hello world")) 128 + 129 + err = publisher.PublishMessage(msg) 130 + require.NoError(t, err) 131 + 132 + msg.Topic = topicB 133 + err = publisher.PublishMessage(msg) 134 + require.NoError(t, err) 135 + 136 + // give the consumer some time to read the messages -- TODO: make better! 137 + time.Sleep(time.Millisecond * 300) 138 + cancel() 139 + 140 + select { 141 + case <-consumerFinCh: 142 + break 143 + case <-time.After(time.Second): 144 + t.Fatal("timed out waiting for consumer to read messages") 145 + } 146 + 147 + assert.Len(t, receivedMessages, 1) 148 + assert.Equal(t, topicB, receivedMessages[0].Topic) 149 + } 150 + 151 + func TestPublishAndSubscribe(t *testing.T) { 152 + consumer, cancel := setupConsumer(t) 153 + 154 + var receivedMessages []*Message 155 + 156 + consumerFinCh := make(chan struct{}) 157 + go func() { 158 + for msg := range consumer.Messages() { 159 + msg.Ack(true) 160 + receivedMessages = append(receivedMessages, msg) 161 + } 162 + 163 + consumerFinCh <- struct{}{} 164 + }() 165 + 166 + publisher, err := NewPublisher("localhost:9999") 167 + require.NoError(t, err) 168 + t.Cleanup(func() { 169 + publisher.Close() 170 + }) 171 + 172 + // send some messages 173 + sentMessages := make([]*Message, 0, 10) 174 + for i := 0; i < 10; i++ { 175 + msg := NewMessage(topicA, []byte(fmt.Sprintf("message %d", i))) 176 + 177 + sentMessages = append(sentMessages, msg) 178 + 179 + err = publisher.PublishMessage(msg) 180 + require.NoError(t, err) 181 + } 182 + 183 + // give the consumer some time to read the messages -- TODO: make better! 184 + time.Sleep(time.Millisecond * 300) 185 + cancel() 186 + 187 + select { 188 + case <-consumerFinCh: 189 + break 190 + case <-time.After(time.Second * 5): 191 + t.Fatal("timed out waiting for consumer to read messages") 192 + } 193 + 194 + // THIS IS SO HACKY 195 + for _, msg := range receivedMessages { 196 + msg.ack = nil 197 + } 198 + 199 + for _, msg := range sentMessages { 200 + msg.ack = nil 201 + } 202 + 203 + assert.ElementsMatch(t, receivedMessages, sentMessages) 204 + } 205 + 206 + func TestPublishAndSubscribeNackMessage(t *testing.T) { 207 + consumer, cancel := setupConsumer(t) 208 + 209 + var receivedMessages []*Message 210 + 211 + consumerFinCh := make(chan struct{}) 212 + timesMsgWasReceived := 0 213 + go func() { 214 + for msg := range consumer.Messages() { 215 + msg.Ack(false) 216 + timesMsgWasReceived++ 217 + } 218 + 219 + consumerFinCh <- struct{}{} 220 + }() 221 + 222 + publisher, err := NewPublisher("localhost:9999") 223 + require.NoError(t, err) 224 + t.Cleanup(func() { 225 + publisher.Close() 226 + }) 227 + 228 + // send a message 229 + msg := NewMessage(topicA, []byte("hello world")) 230 + 231 + err = publisher.PublishMessage(msg) 232 + require.NoError(t, err) 233 + 234 + // give the consumer some time to read the messages -- TODO: make better! 235 + time.Sleep(time.Second) 236 + cancel() 237 + 238 + select { 239 + case <-consumerFinCh: 240 + break 241 + case <-time.After(time.Second * 5): 242 + t.Fatal("timed out waiting for consumer to read messages") 243 + } 244 + 245 + assert.Empty(t, receivedMessages) 246 + assert.Equal(t, 5, timesMsgWasReceived) 247 + } 248 + 249 + func setupConsumer(t *testing.T) (*Consumer, context.CancelFunc) { 250 + createServer(t) 251 + 252 + sub, err := NewSubscriber(serverAddr) 253 + require.NoError(t, err) 254 + 255 + t.Cleanup(func() { 256 + sub.Close() 257 + }) 258 + 259 + topics := []string{topicA, topicB} 260 + 261 + err = sub.SubscribeToTopics(topics, server.Current, 0) 262 + require.NoError(t, err) 263 + 264 + ctx, cancel := context.WithCancel(context.Background()) 265 + t.Cleanup(func() { 266 + cancel() 267 + }) 268 + 269 + consumer := sub.Consume(ctx) 270 + require.NoError(t, err) 271 + 272 + return consumer, cancel 273 + }
+27
cmd/server/main.go
··· 1 + package main 2 + 3 + import ( 4 + "log" 5 + "os" 6 + "os/signal" 7 + "syscall" 8 + "time" 9 + 10 + "github.com/willdot/messagebroker/internal/server" 11 + ) 12 + 13 + func main() { 14 + srv, err := server.New(":3000", time.Second, time.Second*2) 15 + if err != nil { 16 + log.Fatal(err) 17 + } 18 + 19 + defer func() { 20 + _ = srv.Shutdown() 21 + }() 22 + 23 + signals := make(chan os.Signal, 1) 24 + signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT) 25 + 26 + <-signals 27 + }
+7
docker-compose.yaml
··· 1 + version: "3.7" 2 + services: 3 + message-server: 4 + build: 5 + context: . 6 + dockerfile: dockerfile.server 7 + ports: [ "3000:3000" ]
+20
dockerfile.server
··· 1 + FROM golang:latest as builder 2 + 3 + WORKDIR /app 4 + 5 + COPY go.mod go.sum ./ 6 + COPY cmd/server/ ./ 7 + RUN go mod download 8 + 9 + COPY . . 10 + 11 + RUN CGO_ENABLED=0 go build -o message-broker-server . 12 + 13 + FROM alpine:latest 14 + 15 + RUN apk --no-cache add ca-certificates 16 + 17 + WORKDIR /root/ 18 + COPY --from=builder /app/message-broker-server . 19 + 20 + CMD ["./message-broker-server"]
+98
example/main.go
··· 1 + package main 2 + 3 + import ( 4 + "context" 5 + "flag" 6 + "fmt" 7 + "log/slog" 8 + "time" 9 + 10 + "github.com/willdot/messagebroker/client" 11 + "github.com/willdot/messagebroker/internal/server" 12 + ) 13 + 14 + // var publish *bool 15 + // var consume *bool 16 + var consumeFrom *int 17 + var clientType *string 18 + 19 + const ( 20 + topic = "topic-a" 21 + ) 22 + 23 + func main() { 24 + clientType = flag.String("client-type", "consume", "consume or publish (default consume)") 25 + // publish = flag.Bool("publish", false, "will publish messages every 500ms until client is stopped") 26 + // consume = flag.Bool("consume", false, "will consume messages until client is stopped") 27 + consumeFrom = flag.Int("consume-from", -1, "index of message to start consuming from. If not set it will consume from the most recent") 28 + flag.Parse() 29 + 30 + switch *clientType { 31 + case "consume": 32 + consume() 33 + case "publish": 34 + sendMessages() 35 + default: 36 + fmt.Println("unknown client type") 37 + } 38 + } 39 + 40 + func consume() { 41 + sub, err := client.NewSubscriber(":3000") 42 + if err != nil { 43 + panic(err) 44 + } 45 + 46 + defer func() { 47 + _ = sub.Close() 48 + }() 49 + startAt := 0 50 + startAtType := server.Current 51 + if *consumeFrom > -1 { 52 + startAtType = server.From 53 + startAt = *consumeFrom 54 + } 55 + 56 + err = sub.SubscribeToTopics([]string{topic}, startAtType, startAt) 57 + if err != nil { 58 + panic(err) 59 + } 60 + 61 + consumer := sub.Consume(context.Background()) 62 + if consumer.Err != nil { 63 + panic(err) 64 + } 65 + 66 + for msg := range consumer.Messages() { 67 + slog.Info("received message", "message", string(msg.Data)) 68 + msg.Ack(true) 69 + } 70 + } 71 + 72 + func sendMessages() { 73 + publisher, err := client.NewPublisher("localhost:3000") 74 + if err != nil { 75 + panic(err) 76 + } 77 + 78 + defer func() { 79 + _ = publisher.Close() 80 + }() 81 + 82 + // send some messages 83 + i := 0 84 + for { 85 + i++ 86 + msg := client.NewMessage(topic, []byte(fmt.Sprintf("message %d", i))) 87 + 88 + err = publisher.PublishMessage(msg) 89 + if err != nil { 90 + slog.Error("failed to publish message", "error", err) 91 + continue 92 + } 93 + 94 + slog.Info("message sent") 95 + 96 + time.Sleep(time.Millisecond * 500) 97 + } 98 + }
+1 -1
go.mod
··· 1 - module github.com/willdot/message-broker 1 + module github.com/willdot/messagebroker 2 2 3 3 go 1.21.0 4 4
+47
internal/messagestore/memory_store.go
··· 1 + package messagestore 2 + 3 + import ( 4 + "sync" 5 + 6 + "github.com/willdot/messagebroker/internal" 7 + ) 8 + 9 + // MemoryStore allows messages to be stored in memory 10 + type MemoryStore struct { 11 + mu sync.Mutex 12 + msgs map[int]internal.Message 13 + nextOffset int 14 + } 15 + 16 + // NewMemoryStore initializes a new in memory store 17 + func NewMemoryStore() *MemoryStore { 18 + return &MemoryStore{ 19 + msgs: make(map[int]internal.Message), 20 + } 21 + } 22 + 23 + // Write will write the provided message to the in memory store 24 + func (m *MemoryStore) Write(msg internal.Message) error { 25 + m.mu.Lock() 26 + defer m.mu.Unlock() 27 + 28 + m.msgs[m.nextOffset] = msg 29 + 30 + m.nextOffset++ 31 + 32 + return nil 33 + } 34 + 35 + // ReadFrom will read messages from (and including) the provided offset and pass them to the provided handler 36 + func (m *MemoryStore) ReadFrom(offset int, handleFunc func(msg internal.Message)) { 37 + if offset < 0 || offset >= m.nextOffset { 38 + return 39 + } 40 + 41 + m.mu.Lock() 42 + defer m.mu.Unlock() 43 + 44 + for i := offset; i < len(m.msgs); i++ { 45 + handleFunc(m.msgs[i]) 46 + } 47 + }
+12
internal/messge.go
··· 1 + package internal 2 + 3 + // Message represents a message that can be sent / received 4 + type Message struct { 5 + Data []byte 6 + DeliveryCount int 7 + } 8 + 9 + // NewMessage intializes a new message 10 + func NewMessage(data []byte) Message { 11 + return Message{Data: data, DeliveryCount: 1} 12 + }
+36
internal/server/peer.go
··· 1 + package server 2 + 3 + import ( 4 + "net" 5 + "sync" 6 + ) 7 + 8 + // Peer represents a remote connection to the server such as a publisher or subscriber. 9 + type Peer struct { 10 + conn net.Conn 11 + connMu sync.Mutex 12 + } 13 + 14 + // New returns a new peer. 15 + func NewPeer(conn net.Conn) *Peer { 16 + return &Peer{ 17 + conn: conn, 18 + } 19 + } 20 + 21 + // Addr returns the peers connections address. 22 + func (p *Peer) Addr() net.Addr { 23 + return p.conn.RemoteAddr() 24 + } 25 + 26 + // ConnOpp represents a set of actions on a connection that can be used synchrnously. 27 + type ConnOpp func(conn net.Conn) error 28 + 29 + // RunConnOperation will run the provided operation. It ensures that it is the only operation that is being 30 + // run on the connection to ensure any other operations don't get mixed up. 31 + func (p *Peer) RunConnOperation(op ConnOpp) error { 32 + p.connMu.Lock() 33 + defer p.connMu.Unlock() 34 + 35 + return op(p.conn) 36 + }
+513
internal/server/server.go
··· 1 + package server 2 + 3 + import ( 4 + "encoding/binary" 5 + "encoding/json" 6 + "errors" 7 + "fmt" 8 + "io" 9 + "log/slog" 10 + "net" 11 + "strings" 12 + "sync" 13 + "syscall" 14 + "time" 15 + 16 + "github.com/willdot/messagebroker/internal" 17 + ) 18 + 19 + // Action represents the type of action that a peer requests to do 20 + type Action uint16 21 + 22 + const ( 23 + Subscribe Action = 1 24 + Unsubscribe Action = 2 25 + Publish Action = 3 26 + Ack Action = 4 27 + Nack Action = 5 28 + ) 29 + 30 + // Status represents the status of a request 31 + type Status uint16 32 + 33 + const ( 34 + Subscribed Status = 1 35 + Unsubscribed Status = 2 36 + Error Status = 3 37 + ) 38 + 39 + func (s Status) String() string { 40 + switch s { 41 + case Subscribed: 42 + return "subscribed" 43 + case Unsubscribed: 44 + return "unsubscribed" 45 + case Error: 46 + return "error" 47 + } 48 + 49 + return "" 50 + } 51 + 52 + // StartAtType represents where the subcriber wishes to start subscribing to a topic from 53 + type StartAtType uint16 54 + 55 + const ( 56 + Beginning StartAtType = 0 57 + Current StartAtType = 1 58 + From StartAtType = 2 59 + ) 60 + 61 + // Server accepts subscribe and publish connections and passes messages around 62 + type Server struct { 63 + Addr string 64 + lis net.Listener 65 + 66 + mu sync.Mutex 67 + topics map[string]*topic 68 + 69 + ackDelay time.Duration 70 + ackTimeout time.Duration 71 + } 72 + 73 + // New creates and starts a new server 74 + func New(Addr string, ackDelay, ackTimeout time.Duration) (*Server, error) { 75 + lis, err := net.Listen("tcp", Addr) 76 + if err != nil { 77 + return nil, fmt.Errorf("failed to listen: %w", err) 78 + } 79 + 80 + srv := &Server{ 81 + lis: lis, 82 + topics: map[string]*topic{}, 83 + ackDelay: ackDelay, 84 + ackTimeout: ackTimeout, 85 + } 86 + 87 + go srv.start() 88 + 89 + return srv, nil 90 + } 91 + 92 + // Shutdown will cleanly shutdown the server 93 + func (s *Server) Shutdown() error { 94 + return s.lis.Close() 95 + } 96 + 97 + func (s *Server) start() { 98 + for { 99 + conn, err := s.lis.Accept() 100 + if err != nil { 101 + if errors.Is(err, net.ErrClosed) { 102 + slog.Info("listener closed") 103 + return 104 + } 105 + slog.Error("listener failed to accept", "error", err) 106 + continue 107 + } 108 + 109 + go s.handleConn(conn) 110 + } 111 + } 112 + 113 + func (s *Server) handleConn(conn net.Conn) { 114 + peer := NewPeer(conn) 115 + 116 + slog.Info("handling connection", "peer", peer.Addr()) 117 + defer slog.Info("ending connection", "peer", peer.Addr()) 118 + 119 + action, err := readAction(peer, 0) 120 + if err != nil { 121 + if !errors.Is(err, io.EOF) { 122 + slog.Error("failed to read action from peer", "error", err, "peer", peer.Addr()) 123 + } 124 + return 125 + } 126 + 127 + switch action { 128 + case Subscribe: 129 + s.handleSubscribe(peer) 130 + case Unsubscribe: 131 + s.handleUnsubscribe(peer) 132 + case Publish: 133 + s.handlePublish(peer) 134 + default: 135 + slog.Error("unknown action", "action", action, "peer", peer.Addr()) 136 + writeInvalidAction(peer) 137 + } 138 + } 139 + 140 + func (s *Server) handleSubscribe(peer *Peer) { 141 + slog.Info("handling subscriber", "peer", peer.Addr()) 142 + // subscribe the peer to the topic 143 + s.subscribePeerToTopic(peer) 144 + 145 + s.waitForPeerAction(peer) 146 + } 147 + 148 + func (s *Server) waitForPeerAction(peer *Peer) { 149 + // keep handling the peers connection, getting the action from the peer when it wishes to do something else. 150 + // once the peers connection ends, it will be unsubscribed from all topics and returned 151 + for { 152 + action, err := readAction(peer, time.Millisecond*100) 153 + if err != nil { 154 + // if the error is a timeout, it means the peer hasn't sent an action indicating it wishes to do something so sleep 155 + // for a little bit to allow for other actions to happen on the connection 156 + var neterr net.Error 157 + if errors.As(err, &neterr) && neterr.Timeout() { 158 + time.Sleep(time.Millisecond * 500) 159 + continue 160 + } 161 + 162 + if !errors.Is(err, io.EOF) { 163 + slog.Error("failed to read action from subscriber", "error", err, "peer", peer.Addr()) 164 + } 165 + 166 + s.unsubscribePeerFromAllTopics(peer) 167 + 168 + return 169 + } 170 + 171 + switch action { 172 + case Subscribe: 173 + s.subscribePeerToTopic(peer) 174 + case Unsubscribe: 175 + s.handleUnsubscribe(peer) 176 + default: 177 + slog.Error("unknown action for subscriber", "action", action, "peer", peer.Addr()) 178 + writeInvalidAction(peer) 179 + continue 180 + } 181 + } 182 + } 183 + 184 + func (s *Server) subscribePeerToTopic(peer *Peer) { 185 + op := func(conn net.Conn) error { 186 + // get the topics the peer wishes to subscribe to 187 + dataLen, err := dataLengthUint32(conn) 188 + if err != nil { 189 + slog.Error(err.Error(), "peer", peer.Addr()) 190 + writeStatus(Error, "invalid data length of topics provided", conn) 191 + return nil 192 + } 193 + if dataLen == 0 { 194 + writeStatus(Error, "data length of topics is 0", conn) 195 + return nil 196 + } 197 + 198 + buf := make([]byte, dataLen) 199 + _, err = conn.Read(buf) 200 + if err != nil { 201 + slog.Error("failed to read subscibers topic data", "error", err, "peer", peer.Addr()) 202 + writeStatus(Error, "failed to read topic data", conn) 203 + return nil 204 + } 205 + 206 + var topics []string 207 + err = json.Unmarshal(buf, &topics) 208 + if err != nil { 209 + slog.Error("failed to unmarshal subscibers topic data", "error", err, "peer", peer.Addr()) 210 + writeStatus(Error, "invalid topic data provided", conn) 211 + return nil 212 + } 213 + 214 + var startAtType StartAtType 215 + err = binary.Read(conn, binary.BigEndian, &startAtType) 216 + if err != nil { 217 + slog.Error(err.Error(), "peer", peer.Addr()) 218 + writeStatus(Error, "invalid start at type provided", conn) 219 + return nil 220 + } 221 + var startAt int 222 + switch startAtType { 223 + case From: 224 + var s uint16 225 + err = binary.Read(conn, binary.BigEndian, &s) 226 + if err != nil { 227 + slog.Error(err.Error(), "peer", peer.Addr()) 228 + writeStatus(Error, "invalid start at value provided", conn) 229 + return nil 230 + } 231 + startAt = int(s) 232 + case Beginning: 233 + startAt = 0 234 + case Current: 235 + startAt = -1 236 + default: 237 + slog.Error("invalid start up type provided", "start up type", startAtType) 238 + writeStatus(Error, "invalid start up type provided", conn) 239 + return nil 240 + } 241 + 242 + s.subscribeToTopics(peer, topics, startAt) 243 + writeStatus(Subscribed, "", conn) 244 + 245 + return nil 246 + } 247 + 248 + _ = peer.RunConnOperation(op) 249 + } 250 + 251 + func (s *Server) handleUnsubscribe(peer *Peer) { 252 + slog.Info("handling unsubscriber", "peer", peer.Addr()) 253 + op := func(conn net.Conn) error { 254 + // get the topics the peer wishes to unsubscribe from 255 + dataLen, err := dataLengthUint32(conn) 256 + if err != nil { 257 + slog.Error(err.Error(), "peer", peer.Addr()) 258 + writeStatus(Error, "invalid data length of topics provided", conn) 259 + return nil 260 + } 261 + if dataLen == 0 { 262 + writeStatus(Error, "data length of topics is 0", conn) 263 + return nil 264 + } 265 + 266 + buf := make([]byte, dataLen) 267 + _, err = conn.Read(buf) 268 + if err != nil { 269 + slog.Error("failed to read subscibers topic data", "error", err, "peer", peer.Addr()) 270 + writeStatus(Error, "failed to read topic data", conn) 271 + return nil 272 + } 273 + 274 + var topics []string 275 + err = json.Unmarshal(buf, &topics) 276 + if err != nil { 277 + slog.Error("failed to unmarshal subscibers topic data", "error", err, "peer", peer.Addr()) 278 + writeStatus(Error, "invalid topic data provided", conn) 279 + return nil 280 + } 281 + 282 + s.unsubscribeToTopics(peer, topics) 283 + writeStatus(Unsubscribed, "", conn) 284 + 285 + return nil 286 + } 287 + 288 + _ = peer.RunConnOperation(op) 289 + } 290 + 291 + func (s *Server) handlePublish(peer *Peer) { 292 + slog.Info("handling publisher", "peer", peer.Addr()) 293 + for { 294 + op := func(conn net.Conn) error { 295 + topicDataLen, err := dataLengthUint16(conn) 296 + if err != nil { 297 + if errors.Is(err, io.EOF) { 298 + return nil 299 + } 300 + slog.Error("failed to read data length", "error", err, "peer", peer.Addr()) 301 + writeStatus(Error, "invalid data length of data provided", conn) 302 + return nil 303 + } 304 + if topicDataLen == 0 { 305 + return nil 306 + } 307 + topicBuf := make([]byte, topicDataLen) 308 + _, err = conn.Read(topicBuf) 309 + if err != nil { 310 + slog.Error("failed to read topic from peer", "error", err, "peer", peer.Addr()) 311 + writeStatus(Error, "failed to read topic", conn) 312 + return nil 313 + } 314 + 315 + topicStr := string(topicBuf) 316 + if !strings.HasPrefix(topicStr, "topic:") { 317 + slog.Error("topic data does not contain topic prefix", "peer", peer.Addr()) 318 + writeStatus(Error, "topic data does not contain 'topic:' prefix", conn) 319 + return nil 320 + } 321 + topicStr = strings.TrimPrefix(topicStr, "topic:") 322 + 323 + msgDataLen, err := dataLengthUint32(conn) 324 + if err != nil { 325 + slog.Error(err.Error(), "peer", peer.Addr()) 326 + writeStatus(Error, "invalid data length of data provided", conn) 327 + return nil 328 + } 329 + if msgDataLen == 0 { 330 + return nil 331 + } 332 + 333 + dataBuf := make([]byte, msgDataLen) 334 + _, err = conn.Read(dataBuf) 335 + if err != nil { 336 + slog.Error("failed to read data from peer", "error", err, "peer", peer.Addr()) 337 + writeStatus(Error, "failed to read data", conn) 338 + return nil 339 + } 340 + 341 + topic := s.getTopic(topicStr) 342 + if topic == nil { 343 + topic = newTopic(topicStr) 344 + s.topics[topicStr] = topic 345 + } 346 + 347 + message := internal.NewMessage(dataBuf) 348 + 349 + err = topic.sendMessageToSubscribers(message) 350 + if err != nil { 351 + slog.Error("failed to send message to subscribers", "error", err, "peer", peer.Addr()) 352 + writeStatus(Error, "failed to send message to subscribers", conn) 353 + return nil 354 + } 355 + 356 + return nil 357 + } 358 + 359 + _ = peer.RunConnOperation(op) 360 + } 361 + } 362 + 363 + func (s *Server) subscribeToTopics(peer *Peer, topics []string, startAt int) { 364 + slog.Info("subscribing peer to topics", "topics", topics, "peer", peer.Addr()) 365 + for _, topic := range topics { 366 + s.addSubsciberToTopic(topic, peer, startAt) 367 + } 368 + } 369 + 370 + func (s *Server) addSubsciberToTopic(topicName string, peer *Peer, startAt int) { 371 + s.mu.Lock() 372 + defer s.mu.Unlock() 373 + 374 + t, ok := s.topics[topicName] 375 + if !ok { 376 + t = newTopic(topicName) 377 + } 378 + 379 + t.mu.Lock() 380 + t.subscriptions[peer.Addr()] = newSubscriber(peer, t, s.ackDelay, s.ackTimeout, startAt) 381 + t.mu.Unlock() 382 + 383 + s.topics[topicName] = t 384 + } 385 + 386 + func (s *Server) unsubscribeToTopics(peer *Peer, topics []string) { 387 + slog.Info("unsubscribing peer from topics", "topics", topics, "peer", peer.Addr()) 388 + for _, topic := range topics { 389 + s.removeSubsciberFromTopic(topic, peer) 390 + } 391 + } 392 + 393 + func (s *Server) removeSubsciberFromTopic(topicName string, peer *Peer) { 394 + s.mu.Lock() 395 + defer s.mu.Unlock() 396 + 397 + t, ok := s.topics[topicName] 398 + if !ok { 399 + return 400 + } 401 + 402 + sub := t.findSubscription(peer.Addr()) 403 + if sub == nil { 404 + return 405 + } 406 + 407 + sub.unsubscribe() 408 + t.removeSubscription(peer.Addr()) 409 + } 410 + 411 + func (s *Server) unsubscribePeerFromAllTopics(peer *Peer) { 412 + s.mu.Lock() 413 + defer s.mu.Unlock() 414 + 415 + for _, t := range s.topics { 416 + sub := t.findSubscription(peer.Addr()) 417 + if sub == nil { 418 + return 419 + } 420 + 421 + sub.unsubscribe() 422 + t.removeSubscription(peer.Addr()) 423 + } 424 + } 425 + 426 + func (s *Server) getTopic(topicName string) *topic { 427 + s.mu.Lock() 428 + defer s.mu.Unlock() 429 + 430 + if topic, ok := s.topics[topicName]; ok { 431 + return topic 432 + } 433 + 434 + return nil 435 + } 436 + 437 + func readAction(peer *Peer, timeout time.Duration) (Action, error) { 438 + var action Action 439 + op := func(conn net.Conn) error { 440 + if timeout > 0 { 441 + if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil { 442 + slog.Error("failed to set connection read deadline", "error", err, "peer", peer.Addr()) 443 + } 444 + defer func() { 445 + if err := conn.SetReadDeadline(time.Time{}); err != nil { 446 + slog.Error("failed to reset connection read deadline", "error", err, "peer", peer.Addr()) 447 + } 448 + }() 449 + } 450 + 451 + err := binary.Read(conn, binary.BigEndian, &action) 452 + if err != nil { 453 + return err 454 + } 455 + return nil 456 + } 457 + 458 + err := peer.RunConnOperation(op) 459 + if err != nil { 460 + return 0, fmt.Errorf("failed to read action from peer: %w", err) 461 + } 462 + 463 + return action, nil 464 + } 465 + 466 + func writeInvalidAction(peer *Peer) { 467 + op := func(conn net.Conn) error { 468 + writeStatus(Error, "unknown action", conn) 469 + return nil 470 + } 471 + 472 + _ = peer.RunConnOperation(op) 473 + } 474 + 475 + func dataLengthUint32(conn net.Conn) (uint32, error) { 476 + var dataLen uint32 477 + err := binary.Read(conn, binary.BigEndian, &dataLen) 478 + if err != nil { 479 + return 0, err 480 + } 481 + return dataLen, nil 482 + } 483 + 484 + func dataLengthUint16(conn net.Conn) (uint16, error) { 485 + var dataLen uint16 486 + err := binary.Read(conn, binary.BigEndian, &dataLen) 487 + if err != nil { 488 + return 0, err 489 + } 490 + return dataLen, nil 491 + } 492 + 493 + func writeStatus(status Status, message string, conn net.Conn) { 494 + statusB := make([]byte, 2) 495 + binary.BigEndian.PutUint16(statusB, uint16(status)) 496 + 497 + headers := statusB 498 + 499 + if len(message) > 0 { 500 + sizeB := make([]byte, 2) 501 + binary.BigEndian.PutUint16(sizeB, uint16(len(message))) 502 + headers = append(headers, sizeB...) 503 + } 504 + 505 + msgBytes := []byte(message) 506 + _, err := conn.Write(append(headers, msgBytes...)) 507 + if err != nil { 508 + if !errors.Is(err, syscall.EPIPE) { 509 + slog.Error("failed to write status to peers connection", "error", err, "peer", conn.RemoteAddr()) 510 + } 511 + return 512 + } 513 + }
+632
internal/server/server_test.go
··· 1 + package server 2 + 3 + import ( 4 + "encoding/binary" 5 + "encoding/json" 6 + "fmt" 7 + "net" 8 + "testing" 9 + "time" 10 + 11 + "github.com/stretchr/testify/assert" 12 + "github.com/stretchr/testify/require" 13 + "github.com/willdot/messagebroker/internal/messagestore" 14 + ) 15 + 16 + const ( 17 + topicA = "topic a" 18 + topicB = "topic b" 19 + topicC = "topic c" 20 + 21 + serverAddr = ":6666" 22 + 23 + ackDelay = time.Millisecond * 100 24 + ackTimeout = time.Millisecond * 100 25 + ) 26 + 27 + func createServer(t *testing.T) *Server { 28 + srv, err := New(serverAddr, ackDelay, ackTimeout) 29 + require.NoError(t, err) 30 + 31 + t.Cleanup(func() { 32 + _ = srv.Shutdown() 33 + }) 34 + 35 + return srv 36 + } 37 + 38 + func createServerWithExistingTopic(t *testing.T, topicName string) *Server { 39 + srv := createServer(t) 40 + srv.topics[topicName] = &topic{ 41 + name: topicName, 42 + subscriptions: make(map[net.Addr]*subscriber), 43 + messageStore: messagestore.NewMemoryStore(), 44 + } 45 + 46 + return srv 47 + } 48 + 49 + func createConnectionAndSubscribe(t *testing.T, topics []string, startAtType StartAtType, startAtIndex int) net.Conn { 50 + conn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr)) 51 + require.NoError(t, err) 52 + 53 + subscribeToTopics(t, conn, topics, startAtType, startAtIndex) 54 + 55 + expectedRes := Subscribed 56 + 57 + var resp Status 58 + err = binary.Read(conn, binary.BigEndian, &resp) 59 + require.NoError(t, err) 60 + 61 + assert.Equal(t, expectedRes, resp) 62 + 63 + return conn 64 + } 65 + 66 + func sendMessage(t *testing.T, conn net.Conn, topic string, message []byte) { 67 + topicLenB := make([]byte, 4) 68 + binary.BigEndian.PutUint32(topicLenB, uint32(len(topic))) 69 + 70 + headers := topicLenB 71 + headers = append(headers, []byte(topic)...) 72 + 73 + messageLenB := make([]byte, 4) 74 + binary.BigEndian.PutUint32(messageLenB, uint32(len(message))) 75 + headers = append(headers, messageLenB...) 76 + 77 + _, err := conn.Write(append(headers, message...)) 78 + require.NoError(t, err) 79 + } 80 + 81 + func subscribeToTopics(t *testing.T, conn net.Conn, topics []string, startAtType StartAtType, startAtIndex int) { 82 + actionB := make([]byte, 2) 83 + binary.BigEndian.PutUint16(actionB, uint16(Subscribe)) 84 + headers := actionB 85 + 86 + b, err := json.Marshal(topics) 87 + require.NoError(t, err) 88 + 89 + topicNamesB := make([]byte, 4) 90 + binary.BigEndian.PutUint32(topicNamesB, uint32(len(b))) 91 + headers = append(headers, topicNamesB...) 92 + headers = append(headers, b...) 93 + 94 + startAtTypeB := make([]byte, 2) 95 + binary.BigEndian.PutUint16(startAtTypeB, uint16(startAtType)) 96 + headers = append(headers, startAtTypeB...) 97 + 98 + if startAtType == From { 99 + fromB := make([]byte, 2) 100 + binary.BigEndian.PutUint16(fromB, uint16(startAtIndex)) 101 + headers = append(headers, fromB...) 102 + } 103 + 104 + _, err = conn.Write(headers) 105 + require.NoError(t, err) 106 + } 107 + 108 + func unsubscribetoTopics(t *testing.T, conn net.Conn, topics []string) { 109 + actionB := make([]byte, 2) 110 + binary.BigEndian.PutUint16(actionB, uint16(Unsubscribe)) 111 + headers := actionB 112 + 113 + b, err := json.Marshal(topics) 114 + require.NoError(t, err) 115 + 116 + topicNamesB := make([]byte, 4) 117 + binary.BigEndian.PutUint32(topicNamesB, uint32(len(b))) 118 + headers = append(headers, topicNamesB...) 119 + 120 + _, err = conn.Write(append(headers, b...)) 121 + require.NoError(t, err) 122 + } 123 + 124 + func TestSubscribeToTopics(t *testing.T) { 125 + // create a server with an existing topic so we can test subscribing to a new and 126 + // existing topic 127 + srv := createServerWithExistingTopic(t, topicA) 128 + 129 + _ = createConnectionAndSubscribe(t, []string{topicA, topicB}, Current, 0) 130 + 131 + srv.mu.Lock() 132 + defer srv.mu.Unlock() 133 + assert.Len(t, srv.topics, 2) 134 + assert.Len(t, srv.topics[topicA].subscriptions, 1) 135 + assert.Len(t, srv.topics[topicB].subscriptions, 1) 136 + } 137 + 138 + func TestUnsubscribesFromTopic(t *testing.T) { 139 + srv := createServerWithExistingTopic(t, topicA) 140 + 141 + conn := createConnectionAndSubscribe(t, []string{topicA, topicB, topicC}, Current, 0) 142 + 143 + assert.Len(t, srv.topics, 3) 144 + 145 + srv.mu.Lock() 146 + assert.Len(t, srv.topics[topicA].subscriptions, 1) 147 + assert.Len(t, srv.topics[topicB].subscriptions, 1) 148 + assert.Len(t, srv.topics[topicC].subscriptions, 1) 149 + srv.mu.Unlock() 150 + 151 + topics := []string{topicA, topicB} 152 + 153 + unsubscribetoTopics(t, conn, topics) 154 + 155 + expectedRes := Unsubscribed 156 + 157 + var resp Status 158 + err := binary.Read(conn, binary.BigEndian, &resp) 159 + require.NoError(t, err) 160 + 161 + assert.Equal(t, expectedRes, resp) 162 + 163 + assert.Len(t, srv.topics, 3) 164 + 165 + srv.mu.Lock() 166 + assert.Len(t, srv.topics[topicA].subscriptions, 0) 167 + assert.Len(t, srv.topics[topicB].subscriptions, 0) 168 + assert.Len(t, srv.topics[topicC].subscriptions, 1) 169 + srv.mu.Unlock() 170 + } 171 + 172 + func TestSubscriberClosesWithoutUnsubscribing(t *testing.T) { 173 + srv := createServer(t) 174 + 175 + conn := createConnectionAndSubscribe(t, []string{topicA, topicB}, Current, 0) 176 + 177 + assert.Len(t, srv.topics, 2) 178 + 179 + srv.mu.Lock() 180 + assert.Len(t, srv.topics[topicA].subscriptions, 1) 181 + assert.Len(t, srv.topics[topicB].subscriptions, 1) 182 + srv.mu.Unlock() 183 + 184 + // close the conn 185 + err := conn.Close() 186 + require.NoError(t, err) 187 + 188 + publisherConn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr)) 189 + require.NoError(t, err) 190 + 191 + err = binary.Write(publisherConn, binary.BigEndian, Publish) 192 + require.NoError(t, err) 193 + 194 + data := []byte("hello world") 195 + 196 + sendMessage(t, publisherConn, topicA, data) 197 + 198 + // the timeout for a connection is 100 milliseconds, so we should wait at least this long before checking the unsubscribe 199 + // TODO: see if theres a better way, but without this, the test is flakey 200 + time.Sleep(time.Millisecond * 100) 201 + 202 + assert.Len(t, srv.topics, 2) 203 + 204 + srv.mu.Lock() 205 + assert.Len(t, srv.topics[topicA].subscriptions, 0) 206 + assert.Len(t, srv.topics[topicB].subscriptions, 0) 207 + srv.mu.Unlock() 208 + } 209 + 210 + func TestInvalidAction(t *testing.T) { 211 + _ = createServer(t) 212 + 213 + conn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr)) 214 + require.NoError(t, err) 215 + 216 + err = binary.Write(conn, binary.BigEndian, uint16(99)) 217 + require.NoError(t, err) 218 + 219 + expectedRes := Error 220 + 221 + var resp Status 222 + err = binary.Read(conn, binary.BigEndian, &resp) 223 + require.NoError(t, err) 224 + 225 + assert.Equal(t, expectedRes, resp) 226 + 227 + expectedMessage := "unknown action" 228 + 229 + var dataLen uint16 230 + err = binary.Read(conn, binary.BigEndian, &dataLen) 231 + require.NoError(t, err) 232 + assert.Equal(t, len(expectedMessage), int(dataLen)) 233 + 234 + buf := make([]byte, dataLen) 235 + _, err = conn.Read(buf) 236 + require.NoError(t, err) 237 + 238 + assert.Equal(t, expectedMessage, string(buf)) 239 + } 240 + 241 + func TestInvalidTopicDataPublished(t *testing.T) { 242 + _ = createServer(t) 243 + 244 + publisherConn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr)) 245 + require.NoError(t, err) 246 + 247 + err = binary.Write(publisherConn, binary.BigEndian, Publish) 248 + require.NoError(t, err) 249 + 250 + // send topic 251 + topic := topicA 252 + err = binary.Write(publisherConn, binary.BigEndian, uint32(len(topic))) 253 + require.NoError(t, err) 254 + _, err = publisherConn.Write([]byte(topic)) 255 + require.NoError(t, err) 256 + 257 + expectedRes := Error 258 + 259 + var resp Status 260 + err = binary.Read(publisherConn, binary.BigEndian, &resp) 261 + require.NoError(t, err) 262 + 263 + assert.Equal(t, expectedRes, resp) 264 + 265 + expectedMessage := "topic data does not contain 'topic:' prefix" 266 + 267 + var dataLen uint16 268 + err = binary.Read(publisherConn, binary.BigEndian, &dataLen) 269 + require.NoError(t, err) 270 + assert.Equal(t, len(expectedMessage), int(dataLen)) 271 + 272 + buf := make([]byte, dataLen) 273 + _, err = publisherConn.Read(buf) 274 + require.NoError(t, err) 275 + 276 + assert.Equal(t, expectedMessage, string(buf)) 277 + } 278 + 279 + func TestSendsDataToTopicSubscribers(t *testing.T) { 280 + _ = createServer(t) 281 + 282 + subscribers := make([]net.Conn, 0, 10) 283 + for i := 0; i < 10; i++ { 284 + subscriberConn := createConnectionAndSubscribe(t, []string{topicA, topicB}, Current, 0) 285 + 286 + subscribers = append(subscribers, subscriberConn) 287 + } 288 + 289 + publisherConn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr)) 290 + require.NoError(t, err) 291 + 292 + err = binary.Write(publisherConn, binary.BigEndian, Publish) 293 + require.NoError(t, err) 294 + 295 + topic := fmt.Sprintf("topic:%s", topicA) 296 + messageData := "hello world" 297 + 298 + sendMessage(t, publisherConn, topic, []byte(messageData)) 299 + 300 + // check the subsribers got the data 301 + for _, conn := range subscribers { 302 + msg := readMessage(t, conn) 303 + assert.Equal(t, messageData, string(msg)) 304 + } 305 + } 306 + 307 + func TestPublishMultipleTimes(t *testing.T) { 308 + _ = createServer(t) 309 + 310 + publisherConn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr)) 311 + require.NoError(t, err) 312 + 313 + err = binary.Write(publisherConn, binary.BigEndian, Publish) 314 + require.NoError(t, err) 315 + 316 + messages := make([]string, 0, 10) 317 + for i := 0; i < 10; i++ { 318 + messages = append(messages, fmt.Sprintf("message %d", i)) 319 + } 320 + 321 + subscribeFinCh := make(chan struct{}) 322 + // create a subscriber that will read messages 323 + subscriberConn := createConnectionAndSubscribe(t, []string{topicA, topicB}, Current, 0) 324 + go func() { 325 + // check subscriber got all messages 326 + results := make([]string, 0, len(messages)) 327 + for i := 0; i < len(messages); i++ { 328 + msg := readMessage(t, subscriberConn) 329 + results = append(results, string(msg)) 330 + } 331 + 332 + assert.ElementsMatch(t, results, messages) 333 + 334 + subscribeFinCh <- struct{}{} 335 + }() 336 + 337 + topic := fmt.Sprintf("topic:%s", topicA) 338 + 339 + // send multiple messages 340 + for _, msg := range messages { 341 + sendMessage(t, publisherConn, topic, []byte(msg)) 342 + } 343 + 344 + select { 345 + case <-subscribeFinCh: 346 + break 347 + case <-time.After(time.Second): 348 + t.Fatal(fmt.Errorf("timed out waiting for subscriber to read messages")) 349 + } 350 + } 351 + 352 + func TestSendsDataToTopicSubscriberNacksThenAcks(t *testing.T) { 353 + _ = createServer(t) 354 + 355 + subscriberConn := createConnectionAndSubscribe(t, []string{topicA, topicB}, Current, 0) 356 + 357 + publisherConn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr)) 358 + require.NoError(t, err) 359 + 360 + err = binary.Write(publisherConn, binary.BigEndian, Publish) 361 + require.NoError(t, err) 362 + 363 + topic := fmt.Sprintf("topic:%s", topicA) 364 + messageData := "hello world" 365 + 366 + sendMessage(t, publisherConn, topic, []byte(messageData)) 367 + 368 + // check the subsribers got the data 369 + readMessage := func(conn net.Conn, ack Action) { 370 + var topicLen uint16 371 + err = binary.Read(conn, binary.BigEndian, &topicLen) 372 + require.NoError(t, err) 373 + 374 + topicBuf := make([]byte, topicLen) 375 + _, err = conn.Read(topicBuf) 376 + require.NoError(t, err) 377 + assert.Equal(t, topicA, string(topicBuf)) 378 + 379 + var dataLen uint64 380 + err = binary.Read(conn, binary.BigEndian, &dataLen) 381 + require.NoError(t, err) 382 + 383 + buf := make([]byte, dataLen) 384 + n, err := conn.Read(buf) 385 + require.NoError(t, err) 386 + 387 + require.Equal(t, int(dataLen), n) 388 + 389 + assert.Equal(t, messageData, string(buf)) 390 + 391 + err = binary.Write(conn, binary.BigEndian, ack) 392 + require.NoError(t, err) 393 + } 394 + 395 + // NACK the message and then ack it 396 + readMessage(subscriberConn, Nack) 397 + readMessage(subscriberConn, Ack) 398 + // reading for another message should now timeout but give enough time for the ack delay to kick in 399 + // should the second read of the message not have been ack'd properly 400 + var topicLen uint16 401 + _ = subscriberConn.SetReadDeadline(time.Now().Add(ackDelay + time.Millisecond*100)) 402 + err = binary.Read(subscriberConn, binary.BigEndian, &topicLen) 403 + require.Error(t, err) 404 + } 405 + 406 + func TestSendsDataToTopicSubscriberDoesntAckMessage(t *testing.T) { 407 + _ = createServer(t) 408 + 409 + subscriberConn := createConnectionAndSubscribe(t, []string{topicA, topicB}, Current, 0) 410 + 411 + publisherConn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr)) 412 + require.NoError(t, err) 413 + 414 + err = binary.Write(publisherConn, binary.BigEndian, Publish) 415 + require.NoError(t, err) 416 + 417 + topic := fmt.Sprintf("topic:%s", topicA) 418 + messageData := "hello world" 419 + 420 + sendMessage(t, publisherConn, topic, []byte(messageData)) 421 + 422 + // check the subsribers got the data 423 + readMessage := func(conn net.Conn, ack bool) { 424 + var topicLen uint16 425 + err = binary.Read(conn, binary.BigEndian, &topicLen) 426 + require.NoError(t, err) 427 + 428 + topicBuf := make([]byte, topicLen) 429 + _, err = conn.Read(topicBuf) 430 + require.NoError(t, err) 431 + assert.Equal(t, topicA, string(topicBuf)) 432 + 433 + var dataLen uint64 434 + err = binary.Read(conn, binary.BigEndian, &dataLen) 435 + require.NoError(t, err) 436 + 437 + buf := make([]byte, dataLen) 438 + n, err := conn.Read(buf) 439 + require.NoError(t, err) 440 + 441 + require.Equal(t, int(dataLen), n) 442 + 443 + assert.Equal(t, messageData, string(buf)) 444 + 445 + if ack { 446 + err = binary.Write(conn, binary.BigEndian, Ack) 447 + require.NoError(t, err) 448 + return 449 + } 450 + } 451 + 452 + // don't send ack or nack and then ack on the second attempt 453 + readMessage(subscriberConn, false) 454 + readMessage(subscriberConn, true) 455 + 456 + // reading for another message should now timeout but give enough time for the ack delay to kick in 457 + // should the second read of the message not have been ack'd properly 458 + var topicLen uint16 459 + _ = subscriberConn.SetReadDeadline(time.Now().Add(ackDelay + time.Millisecond*100)) 460 + err = binary.Read(subscriberConn, binary.BigEndian, &topicLen) 461 + require.Error(t, err) 462 + } 463 + 464 + func TestSendsDataToTopicSubscriberDeliveryCountTooHighWithNoAck(t *testing.T) { 465 + _ = createServer(t) 466 + 467 + subscriberConn := createConnectionAndSubscribe(t, []string{topicA, topicB}, Current, 0) 468 + 469 + publisherConn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr)) 470 + require.NoError(t, err) 471 + 472 + err = binary.Write(publisherConn, binary.BigEndian, Publish) 473 + require.NoError(t, err) 474 + 475 + topic := fmt.Sprintf("topic:%s", topicA) 476 + messageData := "hello world" 477 + 478 + sendMessage(t, publisherConn, topic, []byte(messageData)) 479 + 480 + // check the subsribers got the data 481 + readMessage := func(conn net.Conn, ack bool) { 482 + var topicLen uint16 483 + err = binary.Read(conn, binary.BigEndian, &topicLen) 484 + require.NoError(t, err) 485 + 486 + topicBuf := make([]byte, topicLen) 487 + _, err = conn.Read(topicBuf) 488 + require.NoError(t, err) 489 + assert.Equal(t, topicA, string(topicBuf)) 490 + 491 + var dataLen uint64 492 + err = binary.Read(conn, binary.BigEndian, &dataLen) 493 + require.NoError(t, err) 494 + 495 + buf := make([]byte, dataLen) 496 + n, err := conn.Read(buf) 497 + require.NoError(t, err) 498 + 499 + require.Equal(t, int(dataLen), n) 500 + 501 + assert.Equal(t, messageData, string(buf)) 502 + 503 + if ack { 504 + err = binary.Write(conn, binary.BigEndian, Ack) 505 + require.NoError(t, err) 506 + return 507 + } 508 + } 509 + 510 + // nack the message 5 times 511 + readMessage(subscriberConn, false) 512 + readMessage(subscriberConn, false) 513 + readMessage(subscriberConn, false) 514 + readMessage(subscriberConn, false) 515 + readMessage(subscriberConn, false) 516 + 517 + // reading for the message should now timeout as we have nack'd the message too many times 518 + var topicLen uint16 519 + _ = subscriberConn.SetReadDeadline(time.Now().Add(ackDelay + time.Millisecond*100)) 520 + err = binary.Read(subscriberConn, binary.BigEndian, &topicLen) 521 + require.Error(t, err) 522 + } 523 + 524 + func TestSubscribeAndReplaysFromStart(t *testing.T) { 525 + _ = createServer(t) 526 + 527 + publisherConn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr)) 528 + require.NoError(t, err) 529 + 530 + err = binary.Write(publisherConn, binary.BigEndian, Publish) 531 + require.NoError(t, err) 532 + 533 + messages := make([]string, 0, 10) 534 + for i := 0; i < 10; i++ { 535 + messages = append(messages, fmt.Sprintf("message %d", i)) 536 + } 537 + 538 + topic := fmt.Sprintf("topic:%s", topicA) 539 + 540 + for _, msg := range messages { 541 + sendMessage(t, publisherConn, topic, []byte(msg)) 542 + } 543 + 544 + // send some messages for topic B as well 545 + sendMessage(t, publisherConn, fmt.Sprintf("topic:%s", topicB), []byte("topic b message 1")) 546 + sendMessage(t, publisherConn, fmt.Sprintf("topic:%s", topicB), []byte("topic b message 2")) 547 + sendMessage(t, publisherConn, fmt.Sprintf("topic:%s", topicB), []byte("topic b message 3")) 548 + 549 + subscriberConn := createConnectionAndSubscribe(t, []string{topicA}, From, 0) 550 + results := make([]string, 0, len(messages)) 551 + for i := 0; i < len(messages); i++ { 552 + msg := readMessage(t, subscriberConn) 553 + results = append(results, string(msg)) 554 + } 555 + assert.ElementsMatch(t, results, messages) 556 + } 557 + 558 + func TestSubscribeAndReplaysFromIndex(t *testing.T) { 559 + _ = createServer(t) 560 + 561 + publisherConn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr)) 562 + require.NoError(t, err) 563 + 564 + err = binary.Write(publisherConn, binary.BigEndian, Publish) 565 + require.NoError(t, err) 566 + 567 + messages := make([]string, 0, 10) 568 + for i := 0; i < 10; i++ { 569 + messages = append(messages, fmt.Sprintf("message %d", i)) 570 + } 571 + 572 + topic := fmt.Sprintf("topic:%s", topicA) 573 + 574 + // send multiple messages 575 + for _, msg := range messages { 576 + sendMessage(t, publisherConn, topic, []byte(msg)) 577 + } 578 + 579 + // send some messages for topic B as well 580 + sendMessage(t, publisherConn, fmt.Sprintf("topic:%s", topicB), []byte("topic b message 1")) 581 + sendMessage(t, publisherConn, fmt.Sprintf("topic:%s", topicB), []byte("topic b message 2")) 582 + sendMessage(t, publisherConn, fmt.Sprintf("topic:%s", topicB), []byte("topic b message 3")) 583 + 584 + subscriberConn := createConnectionAndSubscribe(t, []string{topicA}, From, 3) 585 + 586 + // now that the subscriber has subecribed send another message that should arrive after all the other messages were consumed 587 + sendMessage(t, publisherConn, topic, []byte("hello there")) 588 + 589 + results := make([]string, 0, len(messages)) 590 + for i := 0; i < len(messages)-3; i++ { 591 + msg := readMessage(t, subscriberConn) 592 + results = append(results, string(msg)) 593 + } 594 + require.Len(t, results, 7) 595 + expMessages := make([]string, 0, 7) 596 + for i, msg := range messages { 597 + if i < 3 { 598 + continue 599 + } 600 + expMessages = append(expMessages, msg) 601 + } 602 + assert.Equal(t, expMessages, results) 603 + 604 + // now check we can get the message that was sent after the subscription was created 605 + msg := readMessage(t, subscriberConn) 606 + assert.Equal(t, "hello there", string(msg)) 607 + } 608 + 609 + func readMessage(t *testing.T, subscriberConn net.Conn) []byte { 610 + var topicLen uint16 611 + err := binary.Read(subscriberConn, binary.BigEndian, &topicLen) 612 + require.NoError(t, err) 613 + 614 + topicBuf := make([]byte, topicLen) 615 + _, err = subscriberConn.Read(topicBuf) 616 + require.NoError(t, err) 617 + assert.Equal(t, topicA, string(topicBuf)) 618 + 619 + var dataLen uint64 620 + err = binary.Read(subscriberConn, binary.BigEndian, &dataLen) 621 + require.NoError(t, err) 622 + 623 + buf := make([]byte, dataLen) 624 + n, err := subscriberConn.Read(buf) 625 + require.NoError(t, err) 626 + require.Equal(t, int(dataLen), n) 627 + 628 + err = binary.Write(subscriberConn, binary.BigEndian, Ack) 629 + require.NoError(t, err) 630 + 631 + return buf 632 + }
+136
internal/server/subscriber.go
··· 1 + package server 2 + 3 + import ( 4 + "encoding/binary" 5 + "fmt" 6 + "log/slog" 7 + "net" 8 + "time" 9 + 10 + "github.com/willdot/messagebroker/internal" 11 + ) 12 + 13 + type subscriber struct { 14 + peer *Peer 15 + topic string 16 + messages chan internal.Message 17 + unsubscribeCh chan struct{} 18 + 19 + ackDelay time.Duration 20 + ackTimeout time.Duration 21 + } 22 + 23 + func newSubscriber(peer *Peer, topic *topic, ackDelay, ackTimeout time.Duration, startAt int) *subscriber { 24 + s := &subscriber{ 25 + peer: peer, 26 + topic: topic.name, 27 + messages: make(chan internal.Message), 28 + ackDelay: ackDelay, 29 + ackTimeout: ackTimeout, 30 + unsubscribeCh: make(chan struct{}, 1), 31 + } 32 + 33 + go s.sendMessages() 34 + 35 + go func() { 36 + topic.messageStore.ReadFrom(startAt, func(msg internal.Message) { 37 + select { 38 + case s.messages <- msg: 39 + return 40 + case <-s.unsubscribeCh: 41 + return 42 + } 43 + }) 44 + }() 45 + 46 + return s 47 + } 48 + 49 + func (s *subscriber) sendMessages() { 50 + for { 51 + select { 52 + case <-s.unsubscribeCh: 53 + return 54 + case msg := <-s.messages: 55 + ack, err := s.sendMessage(s.topic, msg) 56 + if err != nil { 57 + slog.Error("failed to send to message", "error", err, "peer", s.peer.Addr()) 58 + } 59 + 60 + if ack { 61 + continue 62 + } 63 + 64 + if msg.DeliveryCount >= 5 { 65 + slog.Error("max delivery count for message. Dropping", "peer", s.peer.Addr()) 66 + continue 67 + } 68 + 69 + msg.DeliveryCount++ 70 + s.addMessage(msg, s.ackDelay) 71 + } 72 + } 73 + } 74 + 75 + func (s *subscriber) addMessage(msg internal.Message, delay time.Duration) { 76 + go func() { 77 + timer := time.NewTimer(delay) 78 + defer timer.Stop() 79 + 80 + select { 81 + case <-s.unsubscribeCh: 82 + return 83 + case <-timer.C: 84 + s.messages <- msg 85 + } 86 + }() 87 + } 88 + 89 + func (s *subscriber) sendMessage(topic string, msg internal.Message) (bool, error) { 90 + var ack bool 91 + op := func(conn net.Conn) error { 92 + topicB := make([]byte, 2) 93 + binary.BigEndian.PutUint16(topicB, uint16(len(topic))) 94 + 95 + headers := topicB 96 + headers = append(headers, []byte(topic)...) 97 + 98 + // TODO: if message is empty, return error? 99 + dataLenB := make([]byte, 8) 100 + binary.BigEndian.PutUint64(dataLenB, uint64(len(msg.Data))) 101 + headers = append(headers, dataLenB...) 102 + 103 + _, err := conn.Write(append(headers, msg.Data...)) 104 + if err != nil { 105 + return fmt.Errorf("failed to write to peer: %w", err) 106 + } 107 + 108 + if err := conn.SetReadDeadline(time.Now().Add(s.ackTimeout)); err != nil { 109 + slog.Error("failed to set connection read deadline", "error", err, "peer", s.peer.Addr()) 110 + } 111 + defer func() { 112 + if err := conn.SetReadDeadline(time.Time{}); err != nil { 113 + slog.Error("failed to reset connection read deadline", "error", err, "peer", s.peer.Addr()) 114 + } 115 + }() 116 + var ackRes Action 117 + err = binary.Read(conn, binary.BigEndian, &ackRes) 118 + if err != nil { 119 + return fmt.Errorf("failed to read ack from peer: %w", err) 120 + } 121 + 122 + if ackRes == Ack { 123 + ack = true 124 + } 125 + 126 + return nil 127 + } 128 + 129 + err := s.peer.RunConnOperation(op) 130 + 131 + return ack, err 132 + } 133 + 134 + func (s *subscriber) unsubscribe() { 135 + close(s.unsubscribeCh) 136 + }
+62
internal/server/topic.go
··· 1 + package server 2 + 3 + import ( 4 + "fmt" 5 + "net" 6 + "sync" 7 + 8 + "github.com/willdot/messagebroker/internal" 9 + "github.com/willdot/messagebroker/internal/messagestore" 10 + ) 11 + 12 + type Store interface { 13 + Write(msg internal.Message) error 14 + ReadFrom(offset int, handleFunc func(msg internal.Message)) 15 + } 16 + 17 + type topic struct { 18 + name string 19 + subscriptions map[net.Addr]*subscriber 20 + mu sync.Mutex 21 + messageStore Store 22 + } 23 + 24 + func newTopic(name string) *topic { 25 + messageStore := messagestore.NewMemoryStore() 26 + return &topic{ 27 + name: name, 28 + subscriptions: make(map[net.Addr]*subscriber), 29 + messageStore: messageStore, 30 + } 31 + } 32 + 33 + func (t *topic) sendMessageToSubscribers(msg internal.Message) error { 34 + err := t.messageStore.Write(msg) 35 + if err != nil { 36 + return fmt.Errorf("failed to write message to store: %w", err) 37 + } 38 + 39 + t.mu.Lock() 40 + subscribers := t.subscriptions 41 + t.mu.Unlock() 42 + 43 + for _, subscriber := range subscribers { 44 + subscriber.addMessage(msg, 0) 45 + } 46 + 47 + return nil 48 + } 49 + 50 + func (t *topic) findSubscription(addr net.Addr) *subscriber { 51 + t.mu.Lock() 52 + defer t.mu.Unlock() 53 + 54 + return t.subscriptions[addr] 55 + } 56 + 57 + func (t *topic) removeSubscription(addr net.Addr) { 58 + t.mu.Lock() 59 + defer t.mu.Unlock() 60 + 61 + delete(t.subscriptions, addr) 62 + }
-299
server.go
··· 1 - package main 2 - 3 - import ( 4 - "context" 5 - "encoding/binary" 6 - "encoding/json" 7 - "fmt" 8 - "log/slog" 9 - "net" 10 - "strings" 11 - "sync" 12 - ) 13 - 14 - // Action represents the type of action that a connection requests to do 15 - type Action uint8 16 - 17 - const ( 18 - Subscribe Action = 1 19 - Unsubscribe Action = 2 20 - Publish Action = 3 21 - ) 22 - 23 - type message struct { 24 - Topic string `json:"topic"` 25 - Data []byte `json:"data"` 26 - } 27 - 28 - type Server struct { 29 - addr string 30 - lis net.Listener 31 - 32 - mu sync.Mutex 33 - topics map[string]topic 34 - } 35 - 36 - func NewServer(ctx context.Context, addr string) (*Server, error) { 37 - lis, err := net.Listen("tcp", addr) 38 - if err != nil { 39 - return nil, fmt.Errorf("failed to listen: %w", err) 40 - } 41 - 42 - srv := &Server{ 43 - lis: lis, 44 - topics: map[string]topic{}, 45 - } 46 - 47 - go srv.start(ctx) 48 - 49 - return srv, nil 50 - } 51 - 52 - func (s *Server) Shutdown() error { 53 - return s.lis.Close() 54 - } 55 - 56 - func (s *Server) start(ctx context.Context) { 57 - for { 58 - conn, err := s.lis.Accept() 59 - if err != nil { 60 - slog.Error("listener failed to accept", "error", err) 61 - // TODO: see if there's a better way to check for this error 62 - if strings.Contains(err.Error(), "use of closed network connection") { 63 - return 64 - } 65 - } 66 - 67 - go s.handleConn(conn) 68 - } 69 - } 70 - 71 - func getActionFromConn(conn net.Conn) (Action, error) { 72 - var action Action 73 - err := binary.Read(conn, binary.BigEndian, &action) 74 - if err != nil { 75 - return 0, err 76 - } 77 - 78 - return action, nil 79 - } 80 - 81 - func getDataLengthFromConn(conn net.Conn) (uint32, error) { 82 - var dataLen uint32 83 - err := binary.Read(conn, binary.BigEndian, &dataLen) 84 - if err != nil { 85 - return 0, fmt.Errorf("failed to read data length from conn: %w", err) 86 - } 87 - 88 - return dataLen, nil 89 - } 90 - 91 - func (s *Server) handleConn(conn net.Conn) { 92 - action, err := getActionFromConn(conn) 93 - if err != nil { 94 - slog.Error("failed to read action from conn", "error", err, "conn", conn.LocalAddr()) 95 - return 96 - } 97 - 98 - switch action { 99 - case Subscribe: 100 - s.handleSubscribingConn(conn) 101 - case Unsubscribe: 102 - s.handleUnsubscribingConn(conn) 103 - case Publish: 104 - s.handlePublisherConn(conn) 105 - default: 106 - slog.Error("unknown action", "action", action, "conn", conn.LocalAddr()) 107 - _, _ = conn.Write([]byte("unknown action")) 108 - } 109 - } 110 - 111 - func (s *Server) handleSubscribingConn(conn net.Conn) { 112 - // subscribe the connection to the topic 113 - s.subscribeConnToTopic(conn) 114 - 115 - // keep handling the connection, getting the action from the conection when it wishes to do something else. 116 - // once the connection ends, it will be unsubscribed from all topics and returned 117 - for { 118 - action, err := getActionFromConn(conn) 119 - if err != nil { 120 - // TODO: see if there's a way to check if the connection has been ended etc 121 - slog.Error("failed to read action from subscriber", "error", err, "conn", conn.LocalAddr()) 122 - 123 - s.unsubscribeConnectionFromAllTopics(conn.LocalAddr()) 124 - 125 - return 126 - } 127 - 128 - switch action { 129 - case Subscribe: 130 - s.subscribeConnToTopic(conn) 131 - case Unsubscribe: 132 - s.handleUnsubscribingConn(conn) 133 - default: 134 - slog.Error("unknown action for subscriber", "action", action, "conn", conn.LocalAddr()) 135 - continue 136 - } 137 - } 138 - } 139 - 140 - func (s *Server) subscribeConnToTopic(conn net.Conn) { 141 - // get the topics the connection wishes to subscribe to 142 - dataLen, err := getDataLengthFromConn(conn) 143 - if err != nil { 144 - slog.Error(err.Error(), "conn", conn.LocalAddr()) 145 - _, _ = conn.Write([]byte("invalid data length of topics provided")) 146 - return 147 - } 148 - if dataLen == 0 { 149 - _, _ = conn.Write([]byte("data length of topics is 0")) 150 - return 151 - } 152 - 153 - buf := make([]byte, dataLen) 154 - _, err = conn.Read(buf) 155 - if err != nil { 156 - slog.Error("failed to read subscibers topic data", "error", err, "conn", conn.LocalAddr()) 157 - _, _ = conn.Write([]byte("failed to read topic data")) 158 - return 159 - } 160 - 161 - var topics []string 162 - err = json.Unmarshal(buf, &topics) 163 - if err != nil { 164 - slog.Error("failed to unmarshal subscibers topic data", "error", err, "conn", conn.LocalAddr()) 165 - _, _ = conn.Write([]byte("invalid topic data provided")) 166 - return 167 - } 168 - 169 - s.subscribeToTopics(conn, topics) 170 - _, _ = conn.Write([]byte("subscribed")) 171 - } 172 - 173 - func (s *Server) handleUnsubscribingConn(conn net.Conn) { 174 - // get the topics the connection wishes to unsubscribe from 175 - dataLen, err := getDataLengthFromConn(conn) 176 - if err != nil { 177 - slog.Error(err.Error(), "conn", conn.LocalAddr()) 178 - _, _ = conn.Write([]byte("invalid data length of topics provided")) 179 - return 180 - } 181 - if dataLen == 0 { 182 - _, _ = conn.Write([]byte("data length of topics is 0")) 183 - return 184 - } 185 - 186 - buf := make([]byte, dataLen) 187 - _, err = conn.Read(buf) 188 - if err != nil { 189 - slog.Error("failed to read subscibers topic data", "error", err, "conn", conn.LocalAddr()) 190 - _, _ = conn.Write([]byte("failed to read topic data")) 191 - return 192 - } 193 - 194 - var topics []string 195 - err = json.Unmarshal(buf, &topics) 196 - if err != nil { 197 - slog.Error("failed to unmarshal subscibers topic data", "error", err, "conn", conn.LocalAddr()) 198 - _, _ = conn.Write([]byte("invalid topic data provided")) 199 - return 200 - } 201 - 202 - s.unsubscribeToTopics(conn, topics) 203 - 204 - _, _ = conn.Write([]byte("unsubscribed")) 205 - } 206 - 207 - func (s *Server) handlePublisherConn(conn net.Conn) { 208 - dataLen, err := getDataLengthFromConn(conn) 209 - if err != nil { 210 - slog.Error(err.Error(), "conn", conn.LocalAddr()) 211 - _, _ = conn.Write([]byte("invalid data length of data provided")) 212 - return 213 - } 214 - if dataLen == 0 { 215 - return 216 - } 217 - 218 - buf := make([]byte, dataLen) 219 - _, err = conn.Read(buf) 220 - if err != nil { 221 - _, _ = conn.Write([]byte("failed to read data")) 222 - slog.Error("failed to read data from conn", "error", err, "conn", conn.LocalAddr()) 223 - return 224 - } 225 - 226 - var msg message 227 - err = json.Unmarshal(buf, &msg) 228 - if err != nil { 229 - _, _ = conn.Write([]byte("invalid message")) 230 - slog.Error("failed to unmarshal data to message", "error", err, "conn", conn.LocalAddr()) 231 - return 232 - } 233 - 234 - topic := s.getTopic(msg.Topic) 235 - if topic != nil { 236 - topic.sendMessageToSubscribers(msg) 237 - } 238 - } 239 - 240 - func (s *Server) subscribeToTopics(conn net.Conn, topics []string) { 241 - for _, topic := range topics { 242 - s.addSubsciberToTopic(topic, conn) 243 - } 244 - } 245 - 246 - func (s *Server) addSubsciberToTopic(topicName string, conn net.Conn) { 247 - s.mu.Lock() 248 - defer s.mu.Unlock() 249 - 250 - t, ok := s.topics[topicName] 251 - if !ok { 252 - t = newTopic(topicName) 253 - } 254 - 255 - t.subscriptions[conn.LocalAddr()] = Subscriber{ 256 - conn: conn, 257 - currentOffset: 0, 258 - } 259 - 260 - s.topics[topicName] = t 261 - } 262 - 263 - func (s *Server) unsubscribeToTopics(conn net.Conn, topics []string) { 264 - for _, topic := range topics { 265 - s.removeSubsciberFromTopic(topic, conn) 266 - } 267 - } 268 - 269 - func (s *Server) removeSubsciberFromTopic(topicName string, conn net.Conn) { 270 - s.mu.Lock() 271 - defer s.mu.Unlock() 272 - 273 - t, ok := s.topics[topicName] 274 - if !ok { 275 - return 276 - } 277 - 278 - delete(t.subscriptions, conn.LocalAddr()) 279 - } 280 - 281 - func (s *Server) unsubscribeConnectionFromAllTopics(addr net.Addr) { 282 - s.mu.Lock() 283 - defer s.mu.Unlock() 284 - 285 - for _, topic := range s.topics { 286 - delete(topic.subscriptions, addr) 287 - } 288 - } 289 - 290 - func (s *Server) getTopic(topicName string) *topic { 291 - s.mu.Lock() 292 - defer s.mu.Unlock() 293 - 294 - if topic, ok := s.topics[topicName]; ok { 295 - return &topic 296 - } 297 - 298 - return nil 299 - }
-231
server_test.go
··· 1 - package main 2 - 3 - import ( 4 - "context" 5 - "encoding/binary" 6 - "encoding/json" 7 - "net" 8 - "testing" 9 - 10 - "github.com/stretchr/testify/assert" 11 - "github.com/stretchr/testify/require" 12 - ) 13 - 14 - func createServer(t *testing.T) *Server { 15 - srv, err := NewServer(context.Background(), ":3000") 16 - require.NoError(t, err) 17 - 18 - t.Cleanup(func() { 19 - srv.Shutdown() 20 - }) 21 - 22 - return srv 23 - } 24 - 25 - func createServerWithExistingTopic(t *testing.T, topicName string) *Server { 26 - srv := createServer(t) 27 - srv.topics[topicName] = topic{ 28 - name: topicName, 29 - subscriptions: make(map[net.Addr]Subscriber), 30 - } 31 - 32 - return srv 33 - } 34 - 35 - func createConnectionAndSubscribe(t *testing.T, topics []string) net.Conn { 36 - conn, err := net.Dial("tcp", "localhost:3000") 37 - require.NoError(t, err) 38 - 39 - err = binary.Write(conn, binary.BigEndian, Subscribe) 40 - require.NoError(t, err) 41 - 42 - rawTopics, err := json.Marshal(topics) 43 - require.NoError(t, err) 44 - 45 - err = binary.Write(conn, binary.BigEndian, uint32(len(rawTopics))) 46 - require.NoError(t, err) 47 - 48 - _, err = conn.Write(rawTopics) 49 - require.NoError(t, err) 50 - 51 - expectedRes := "subscribed" 52 - 53 - buf := make([]byte, len(expectedRes)) 54 - n, err := conn.Read(buf) 55 - require.NoError(t, err) 56 - require.Equal(t, len(expectedRes), n) 57 - 58 - assert.Equal(t, expectedRes, string(buf)) 59 - 60 - return conn 61 - } 62 - 63 - func TestSubscribeToTopics(t *testing.T) { 64 - // create a server with an existing topic so we can test subscribing to a new and 65 - // existing topic 66 - srv := createServerWithExistingTopic(t, "topic a") 67 - 68 - _ = createConnectionAndSubscribe(t, []string{"topic a", "topic b"}) 69 - 70 - assert.Len(t, srv.topics, 2) 71 - assert.Len(t, srv.topics["topic a"].subscriptions, 1) 72 - assert.Len(t, srv.topics["topic b"].subscriptions, 1) 73 - } 74 - 75 - func TestUnsubscribesFromTopic(t *testing.T) { 76 - srv := createServerWithExistingTopic(t, "topic a") 77 - 78 - conn := createConnectionAndSubscribe(t, []string{"topic a", "topic b", "topic c"}) 79 - 80 - assert.Len(t, srv.topics, 3) 81 - assert.Len(t, srv.topics["topic a"].subscriptions, 1) 82 - assert.Len(t, srv.topics["topic b"].subscriptions, 1) 83 - assert.Len(t, srv.topics["topic c"].subscriptions, 1) 84 - 85 - err := binary.Write(conn, binary.BigEndian, Unsubscribe) 86 - require.NoError(t, err) 87 - 88 - topics := []string{"topic a", "topic b"} 89 - rawTopics, err := json.Marshal(topics) 90 - require.NoError(t, err) 91 - 92 - err = binary.Write(conn, binary.BigEndian, uint32(len(rawTopics))) 93 - require.NoError(t, err) 94 - 95 - _, err = conn.Write(rawTopics) 96 - require.NoError(t, err) 97 - 98 - expectedRes := "unsubscribed" 99 - 100 - buf := make([]byte, len(expectedRes)) 101 - n, err := conn.Read(buf) 102 - require.NoError(t, err) 103 - require.Equal(t, len(expectedRes), n) 104 - 105 - assert.Equal(t, expectedRes, string(buf)) 106 - 107 - assert.Len(t, srv.topics, 3) 108 - assert.Len(t, srv.topics["topic a"].subscriptions, 0) 109 - assert.Len(t, srv.topics["topic b"].subscriptions, 0) 110 - assert.Len(t, srv.topics["topic c"].subscriptions, 1) 111 - } 112 - 113 - func TestSubscriberClosesWithoutUnsubscribing(t *testing.T) { 114 - srv := createServer(t) 115 - 116 - conn := createConnectionAndSubscribe(t, []string{"topic a", "topic b"}) 117 - 118 - assert.Len(t, srv.topics, 2) 119 - assert.Len(t, srv.topics["topic a"].subscriptions, 1) 120 - assert.Len(t, srv.topics["topic b"].subscriptions, 1) 121 - 122 - // close the conn 123 - err := conn.Close() 124 - require.NoError(t, err) 125 - 126 - publisherConn, err := net.Dial("tcp", "localhost:3000") 127 - require.NoError(t, err) 128 - 129 - err = binary.Write(publisherConn, binary.BigEndian, Publish) 130 - require.NoError(t, err) 131 - 132 - data := []byte("hello world") 133 - // send data length first 134 - err = binary.Write(publisherConn, binary.BigEndian, uint32(len(data))) 135 - require.NoError(t, err) 136 - n, err := publisherConn.Write(data) 137 - require.NoError(t, err) 138 - require.Equal(t, len(data), n) 139 - 140 - assert.Len(t, srv.topics, 2) 141 - assert.Len(t, srv.topics["topic a"].subscriptions, 0) 142 - assert.Len(t, srv.topics["topic b"].subscriptions, 0) 143 - } 144 - 145 - func TestInvalidAction(t *testing.T) { 146 - _ = createServer(t) 147 - 148 - conn, err := net.Dial("tcp", "localhost:3000") 149 - require.NoError(t, err) 150 - 151 - err = binary.Write(conn, binary.BigEndian, uint8(99)) 152 - require.NoError(t, err) 153 - 154 - expectedRes := "unknown action" 155 - 156 - buf := make([]byte, len(expectedRes)) 157 - n, err := conn.Read(buf) 158 - require.NoError(t, err) 159 - require.Equal(t, len(expectedRes), n) 160 - 161 - assert.Equal(t, expectedRes, string(buf)) 162 - } 163 - 164 - func TestInvalidMessagePublished(t *testing.T) { 165 - _ = createServer(t) 166 - 167 - publisherConn, err := net.Dial("tcp", "localhost:3000") 168 - require.NoError(t, err) 169 - 170 - err = binary.Write(publisherConn, binary.BigEndian, Publish) 171 - require.NoError(t, err) 172 - 173 - // send some data 174 - data := []byte("this isn't wrapped in a message type") 175 - 176 - // send data length first 177 - err = binary.Write(publisherConn, binary.BigEndian, uint32(len(data))) 178 - require.NoError(t, err) 179 - n, err := publisherConn.Write(data) 180 - require.NoError(t, err) 181 - require.Equal(t, len(data), n) 182 - 183 - buf := make([]byte, 15) 184 - _, err = publisherConn.Read(buf) 185 - require.NoError(t, err) 186 - assert.Equal(t, "invalid message", string(buf)) 187 - } 188 - 189 - func TestSendsDataToTopicSubscribers(t *testing.T) { 190 - _ = createServer(t) 191 - 192 - subscribers := make([]net.Conn, 0, 5) 193 - for i := 0; i < 5; i++ { 194 - subscriberConn := createConnectionAndSubscribe(t, []string{"topic a", "topic b"}) 195 - 196 - subscribers = append(subscribers, subscriberConn) 197 - } 198 - 199 - publisherConn, err := net.Dial("tcp", "localhost:3000") 200 - require.NoError(t, err) 201 - 202 - err = binary.Write(publisherConn, binary.BigEndian, Publish) 203 - require.NoError(t, err) 204 - 205 - // send some data 206 - data := []byte("hello world") 207 - msg := message{ 208 - Topic: "topic a", 209 - Data: data, 210 - } 211 - 212 - rawMsg, err := json.Marshal(msg) 213 - require.NoError(t, err) 214 - 215 - // send data length first 216 - err = binary.Write(publisherConn, binary.BigEndian, uint32(len(rawMsg))) 217 - require.NoError(t, err) 218 - n, err := publisherConn.Write(rawMsg) 219 - require.NoError(t, err) 220 - require.Equal(t, len(rawMsg), n) 221 - 222 - // check the subsribers got the data 223 - for _, conn := range subscribers { 224 - buf := make([]byte, len(data)) 225 - n, err := conn.Read(buf) 226 - require.NoError(t, err) 227 - require.Equal(t, len(data), n) 228 - 229 - assert.Equal(t, data, buf) 230 - } 231 - }
-19
subscriber.go
··· 1 - package main 2 - 3 - import ( 4 - "fmt" 5 - "net" 6 - ) 7 - 8 - type Subscriber struct { 9 - conn net.Conn 10 - currentOffset int 11 - } 12 - 13 - func (s *Subscriber) SendMessage(data []byte) error { 14 - _, err := s.conn.Write(data) 15 - if err != nil { 16 - return fmt.Errorf("failed to write to connection: %w", err) 17 - } 18 - return nil 19 - }
-50
topic.go
··· 1 - package main 2 - 3 - import ( 4 - "log/slog" 5 - "net" 6 - "sync" 7 - ) 8 - 9 - type topic struct { 10 - name string 11 - subscriptions map[net.Addr]Subscriber 12 - mu sync.Mutex 13 - } 14 - 15 - func newTopic(name string) topic { 16 - return topic{ 17 - name: name, 18 - subscriptions: make(map[net.Addr]Subscriber), 19 - } 20 - } 21 - 22 - func (t *topic) addSubscriber(conn net.Conn) { 23 - t.mu.Lock() 24 - defer t.mu.Unlock() 25 - 26 - slog.Info("adding subscriber", "conn", conn.LocalAddr()) 27 - t.subscriptions[conn.LocalAddr()] = Subscriber{conn: conn} 28 - } 29 - 30 - func (t *topic) removeSubscriber(addr net.Addr) { 31 - t.mu.Lock() 32 - defer t.mu.Unlock() 33 - 34 - slog.Info("removing subscriber", "conn", addr) 35 - delete(t.subscriptions, addr) 36 - } 37 - 38 - func (t *topic) sendMessageToSubscribers(msg message) { 39 - t.mu.Lock() 40 - subscribers := t.subscriptions 41 - t.mu.Unlock() 42 - 43 - for addr, subscriber := range subscribers { 44 - err := subscriber.SendMessage(msg.Data) 45 - if err != nil { 46 - slog.Error("failed to send to message", "error", err, "conn", addr) 47 - continue 48 - } 49 - } 50 - }