An experimental pub/sub client and server project.


+37
README.md
# Message Broker

Message Broker is my attempt at building a client / server pub/sub system written in Go using plain `net.Conn`.

I decided to build one to further my understanding of how such a common tool is used in software engineering.

I realised that I had taken for granted how complicated other pub/sub message brokers such as NATS, RabbitMQ and Kafka are. By creating my own, I hope to dive in and understand more about how message brokers work.

## Concept

The server accepts TCP connections. When a connection is first established, the client sends an "action" message which determines what type of client it is: subscribe or publish.

### Subscribing
Once a connection has declared itself as a subscribing connection, it then needs to send the list of topics it wishes to initially subscribe to. After that, the connection enters a loop in which it can send further actions: subscribe to new topic(s) or unsubscribe from topic(s).

### Publishing
Once a connection has declared itself as a publisher, it enters a loop in which it can send messages for a topic. When a message is received, the server sends it to all connections that are subscribed to that topic.

### Sending data via a connection

When sending a message representing an action (subscribe, publish etc.), a binary uint16 is sent.

When sending any other data, the length of the data is sent first as a binary uint32, followed by the data itself. (A short sketch of this framing follows the README.)

## Running the server

The server can be run using `docker-compose up message-server`. This starts a server listening on port 3000.

## Example clients
There is an example application that implements the subscriber and publisher in the `example` directory.

Run `go build .` in that directory to build the binary.

When running the example, the following flags are available:

- `client-type` : either `consume` or `publish` (defaults to `consume`). In `publish` mode a message is sent every 500ms until the client is stopped.
- `consume-from` : the index of the message to start consuming from. If you don't set this, or set it to -1, you will start consuming from the next sent message.
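To make the framing described above concrete, here is a minimal, hypothetical sketch of a publisher talking to the broker over a raw `net.Conn`. The action value (`Publish = 3`), the `topic:` prefix and the length-prefixed layout are taken from `client/publisher.go` and `internal/server/server.go` in this change; the address, topic and payload are illustrative only.

```go
package main

import (
	"encoding/binary"
	"net"
)

func main() {
	// Assumes a broker is listening locally on port 3000 (see docker-compose.yaml).
	conn, err := net.Dial("tcp", "localhost:3000")
	if err != nil {
		panic(err)
	}
	defer conn.Close()

	// Declare this connection as a publisher. Actions are big-endian uint16 values;
	// Publish is 3 in internal/server/server.go.
	if err := binary.Write(conn, binary.BigEndian, uint16(3)); err != nil {
		panic(err)
	}

	topic := "topic:topic-a" // the server expects the "topic:" prefix
	data := []byte("hello world")

	// Frame the message: uint16 topic length, topic bytes, uint32 data length, data bytes.
	topicLenB := make([]byte, 2)
	binary.BigEndian.PutUint16(topicLenB, uint16(len(topic)))
	frame := append(topicLenB, topic...)

	dataLenB := make([]byte, 4)
	binary.BigEndian.PutUint32(dataLenB, uint32(len(data)))
	frame = append(frame, dataLenB...)
	frame = append(frame, data...)

	if _, err := conn.Write(frame); err != nil {
		panic(err)
	}
}
```

In practice the `client` package's `Publisher` performs the same framing (and adds reconnect-on-broken-pipe handling), so applications would normally use `client.NewPublisher` rather than hand-rolling frames as above.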
+23
client/message.go
```go
package client

// Message represents a message that can be published or consumed
type Message struct {
	Topic string `json:"topic"`
	Data  []byte `json:"data"`

	ack chan bool
}

// NewMessage creates a new message
func NewMessage(topic string, data []byte) *Message {
	return &Message{
		Topic: topic,
		Data:  data,
		ack:   make(chan bool),
	}
}

// Ack will send the provided value of the ack to the server
func (m *Message) Ack(ack bool) {
	m.ack <- ack
}
```
+113
client/publisher.go
```go
package client

import (
	"encoding/binary"
	"errors"
	"fmt"
	"log/slog"
	"net"
	"sync"
	"syscall"

	"github.com/willdot/messagebroker/internal/server"
)

// Publisher allows messages to be published to a server
type Publisher struct {
	conn   net.Conn
	connMu sync.Mutex
	addr   string
}

// NewPublisher connects to the server at the given address and registers as a publisher
func NewPublisher(addr string) (*Publisher, error) {
	conn, err := connect(addr)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to server: %w", err)
	}

	return &Publisher{
		conn: conn,
		addr: addr,
	}, nil
}

func connect(addr string) (net.Conn, error) {
	conn, err := net.Dial("tcp", addr)
	if err != nil {
		return nil, fmt.Errorf("failed to dial: %w", err)
	}

	err = binary.Write(conn, binary.BigEndian, server.Publish)
	if err != nil {
		conn.Close()
		return nil, fmt.Errorf("failed to register publish to server: %w", err)
	}
	return conn, nil
}

// Close cleanly shuts down the publisher
func (p *Publisher) Close() error {
	return p.conn.Close()
}

// PublishMessage will publish the given message to the server
func (p *Publisher) PublishMessage(message *Message) error {
	return p.publishMessageWithRetry(message, 0)
}

func (p *Publisher) publishMessageWithRetry(message *Message, attempt int) error {
	op := func(conn net.Conn) error {
		// send topic first
		topic := fmt.Sprintf("topic:%s", message.Topic)

		topicLenB := make([]byte, 2)
		binary.BigEndian.PutUint16(topicLenB, uint16(len(topic)))

		headers := append(topicLenB, []byte(topic)...)

		messageLenB := make([]byte, 4)
		binary.BigEndian.PutUint32(messageLenB, uint32(len(message.Data)))
		headers = append(headers, messageLenB...)

		_, err := conn.Write(append(headers, message.Data...))
		if err != nil {
			return fmt.Errorf("failed to publish data to server: %w", err)
		}
		return nil
	}

	err := p.connOperation(op)
	if err == nil {
		return nil
	}

	// we can handle a broken pipe by trying to reconnect, but if it's a different error return it
	if !errors.Is(err, syscall.EPIPE) {
		return err
	}

	slog.Info("error is broken pipe")

	if attempt >= 5 {
		return fmt.Errorf("failed to publish message after max attempts to reconnect (%d): %w", attempt, err)
	}

	slog.Error("failed to publish message", "error", err)

	conn, connectErr := connect(p.addr)
	if connectErr != nil {
		return fmt.Errorf("failed to reconnect after failing to publish message: %w", connectErr)
	}

	p.conn = conn

	return p.publishMessageWithRetry(message, attempt+1)
}

func (p *Publisher) connOperation(op connOpp) error {
	p.connMu.Lock()
	defer p.connMu.Unlock()

	return op(p.conn)
}
```
+371
client/subscriber.go
··· 1 + package client 2 + 3 + import ( 4 + "context" 5 + "encoding/binary" 6 + "encoding/json" 7 + "errors" 8 + "fmt" 9 + "io" 10 + "log/slog" 11 + "net" 12 + "sync" 13 + "syscall" 14 + "time" 15 + 16 + "github.com/willdot/messagebroker/internal/server" 17 + ) 18 + 19 + type connOpp func(conn net.Conn) error 20 + 21 + // Subscriber allows subscriptions to a server and the consumption of messages 22 + type Subscriber struct { 23 + conn net.Conn 24 + connMu sync.Mutex 25 + subscribedTopics []string 26 + addr string 27 + } 28 + 29 + // NewSubscriber will connect to the server at the given address 30 + func NewSubscriber(addr string) (*Subscriber, error) { 31 + conn, err := net.Dial("tcp", addr) 32 + if err != nil { 33 + return nil, fmt.Errorf("failed to dial: %w", err) 34 + } 35 + 36 + return &Subscriber{ 37 + conn: conn, 38 + addr: addr, 39 + }, nil 40 + } 41 + 42 + func (s *Subscriber) reconnect() error { 43 + conn, err := net.Dial("tcp", s.addr) 44 + if err != nil { 45 + return fmt.Errorf("failed to dial: %w", err) 46 + } 47 + 48 + s.conn = conn 49 + return nil 50 + } 51 + 52 + // Close cleanly shuts down the subscriber 53 + func (s *Subscriber) Close() error { 54 + return s.conn.Close() 55 + } 56 + 57 + // SubscribeToTopics will subscribe to the provided topics 58 + func (s *Subscriber) SubscribeToTopics(topicNames []string, startAtType server.StartAtType, startAtIndex int) error { 59 + op := func(conn net.Conn) error { 60 + return subscribeToTopics(conn, topicNames, startAtType, startAtIndex) 61 + } 62 + 63 + err := s.connOperation(op) 64 + if err != nil { 65 + return fmt.Errorf("failed to subscribe to topics: %w", err) 66 + } 67 + 68 + s.addToSubscribedTopics(topicNames) 69 + 70 + return nil 71 + } 72 + 73 + func (s *Subscriber) addToSubscribedTopics(topics []string) { 74 + existingSubs := make(map[string]struct{}) 75 + for _, topic := range s.subscribedTopics { 76 + existingSubs[topic] = struct{}{} 77 + } 78 + 79 + for _, topic := range topics { 80 + existingSubs[topic] = struct{}{} 81 + } 82 + 83 + subs := make([]string, 0, len(existingSubs)) 84 + for topic := range existingSubs { 85 + subs = append(subs, topic) 86 + } 87 + 88 + s.subscribedTopics = subs 89 + } 90 + 91 + func (s *Subscriber) removeTopicsFromSubscription(topics []string) { 92 + existingSubs := make(map[string]struct{}) 93 + for _, topic := range s.subscribedTopics { 94 + existingSubs[topic] = struct{}{} 95 + } 96 + 97 + for _, topic := range topics { 98 + delete(existingSubs, topic) 99 + } 100 + 101 + subs := make([]string, 0, len(existingSubs)) 102 + for topic := range existingSubs { 103 + subs = append(subs, topic) 104 + } 105 + 106 + s.subscribedTopics = subs 107 + } 108 + 109 + // UnsubscribeToTopics will unsubscribe to the provided topics 110 + func (s *Subscriber) UnsubscribeToTopics(topicNames []string) error { 111 + op := func(conn net.Conn) error { 112 + return unsubscribeToTopics(conn, topicNames) 113 + } 114 + 115 + err := s.connOperation(op) 116 + if err != nil { 117 + return fmt.Errorf("failed to unsubscribe to topics: %w", err) 118 + } 119 + 120 + s.removeTopicsFromSubscription(topicNames) 121 + 122 + return nil 123 + } 124 + 125 + func subscribeToTopics(conn net.Conn, topicNames []string, startAtType server.StartAtType, startAtIndex int) error { 126 + actionB := make([]byte, 2) 127 + binary.BigEndian.PutUint16(actionB, uint16(server.Subscribe)) 128 + headers := actionB 129 + 130 + b, err := json.Marshal(topicNames) 131 + if err != nil { 132 + return fmt.Errorf("failed to marshal topic names: %w", err) 
133 + } 134 + 135 + topicNamesB := make([]byte, 4) 136 + binary.BigEndian.PutUint32(topicNamesB, uint32(len(b))) 137 + headers = append(headers, topicNamesB...) 138 + headers = append(headers, b...) 139 + 140 + startAtTypeB := make([]byte, 2) 141 + binary.BigEndian.PutUint16(startAtTypeB, uint16(startAtType)) 142 + headers = append(headers, startAtTypeB...) 143 + 144 + if startAtType == server.From { 145 + fromB := make([]byte, 2) 146 + binary.BigEndian.PutUint16(fromB, uint16(startAtIndex)) 147 + headers = append(headers, fromB...) 148 + } 149 + 150 + _, err = conn.Write(headers) 151 + if err != nil { 152 + return fmt.Errorf("failed to subscribe to topics: %w", err) 153 + } 154 + 155 + var resp server.Status 156 + err = binary.Read(conn, binary.BigEndian, &resp) 157 + if err != nil { 158 + return fmt.Errorf("failed to read confirmation of subscribe: %w", err) 159 + } 160 + 161 + if resp == server.Subscribed { 162 + return nil 163 + } 164 + 165 + var dataLen uint16 166 + err = binary.Read(conn, binary.BigEndian, &dataLen) 167 + if err != nil { 168 + return fmt.Errorf("received status %s:", resp) 169 + } 170 + 171 + buf := make([]byte, dataLen) 172 + _, err = conn.Read(buf) 173 + if err != nil { 174 + return fmt.Errorf("received status %s:", resp) 175 + } 176 + 177 + return fmt.Errorf("received status %s - %s", resp, buf) 178 + } 179 + 180 + func unsubscribeToTopics(conn net.Conn, topicNames []string) error { 181 + actionB := make([]byte, 2) 182 + binary.BigEndian.PutUint16(actionB, uint16(server.Unsubscribe)) 183 + headers := actionB 184 + 185 + b, err := json.Marshal(topicNames) 186 + if err != nil { 187 + return fmt.Errorf("failed to marshal topic names: %w", err) 188 + } 189 + 190 + topicNamesB := make([]byte, 4) 191 + binary.BigEndian.PutUint32(topicNamesB, uint32(len(b))) 192 + headers = append(headers, topicNamesB...) 193 + 194 + _, err = conn.Write(append(headers, b...)) 195 + if err != nil { 196 + return fmt.Errorf("failed to unsubscribe to topics: %w", err) 197 + } 198 + 199 + var resp server.Status 200 + err = binary.Read(conn, binary.BigEndian, &resp) 201 + if err != nil { 202 + return fmt.Errorf("failed to read confirmation of unsubscribe: %w", err) 203 + } 204 + 205 + if resp == server.Unsubscribed { 206 + return nil 207 + } 208 + 209 + var dataLen uint16 210 + err = binary.Read(conn, binary.BigEndian, &dataLen) 211 + if err != nil { 212 + return fmt.Errorf("received status %s:", resp) 213 + } 214 + 215 + buf := make([]byte, dataLen) 216 + _, err = conn.Read(buf) 217 + if err != nil { 218 + return fmt.Errorf("received status %s:", resp) 219 + } 220 + 221 + return fmt.Errorf("received status %s - %s", resp, buf) 222 + } 223 + 224 + // Consumer allows the consumption of messages. If during the consumer receiving messages from the 225 + // server an error occurs, it will be stored in Err 226 + type Consumer struct { 227 + msgs chan *Message 228 + // TODO: better error handling? Maybe a channel of errors? 229 + Err error 230 + } 231 + 232 + // Messages returns a channel in which this consumer will put messages onto. It is safe to range over the channel since it will be closed once 233 + // the consumer has finished either due to an error or from being cancelled. 234 + func (c *Consumer) Messages() <-chan *Message { 235 + return c.msgs 236 + } 237 + 238 + // Consume will create a consumer and start it running in a go routine. 
You can then use the Msgs channel of the consumer 239 + // to read the messages 240 + func (s *Subscriber) Consume(ctx context.Context) *Consumer { 241 + consumer := &Consumer{ 242 + msgs: make(chan *Message), 243 + } 244 + 245 + go s.consume(ctx, consumer) 246 + 247 + return consumer 248 + } 249 + 250 + func (s *Subscriber) consume(ctx context.Context, consumer *Consumer) { 251 + defer close(consumer.msgs) 252 + for { 253 + if ctx.Err() != nil { 254 + return 255 + } 256 + 257 + err := s.readMessage(ctx, consumer.msgs) 258 + if err == nil { 259 + continue 260 + } 261 + 262 + // if we couldn't connect to the server, attempt to reconnect 263 + if !errors.Is(err, syscall.EPIPE) && !errors.Is(err, io.EOF) { 264 + slog.Error("failed to read message", "error", err) 265 + consumer.Err = err 266 + return 267 + } 268 + 269 + slog.Info("attempting to reconnect") 270 + 271 + for i := 0; i < 5; i++ { 272 + time.Sleep(time.Millisecond * 500) 273 + err = s.reconnect() 274 + if err == nil { 275 + break 276 + } 277 + 278 + slog.Error("Failed to reconnect", "error", err, "attempt", i) 279 + } 280 + 281 + slog.Info("attempting to resubscribe") 282 + 283 + err = s.SubscribeToTopics(s.subscribedTopics, server.Current, 0) 284 + if err != nil { 285 + consumer.Err = fmt.Errorf("failed to subscribe to topics after reconnecting: %w", err) 286 + return 287 + } 288 + 289 + } 290 + } 291 + 292 + func (s *Subscriber) readMessage(ctx context.Context, msgChan chan *Message) error { 293 + op := func(conn net.Conn) error { 294 + err := s.conn.SetReadDeadline(time.Now().Add(time.Millisecond * 300)) 295 + if err != nil { 296 + return err 297 + } 298 + 299 + var topicLen uint16 300 + err = binary.Read(s.conn, binary.BigEndian, &topicLen) 301 + if err != nil { 302 + // TODO: check if this is needed elsewhere. I'm not sure where the read deadline resets.... 303 + if neterr, ok := err.(net.Error); ok && neterr.Timeout() { 304 + return nil 305 + } 306 + return err 307 + } 308 + 309 + topicBuf := make([]byte, topicLen) 310 + _, err = s.conn.Read(topicBuf) 311 + if err != nil { 312 + return err 313 + } 314 + 315 + var dataLen uint64 316 + err = binary.Read(s.conn, binary.BigEndian, &dataLen) 317 + if err != nil { 318 + return err 319 + } 320 + 321 + if dataLen <= 0 { 322 + return nil 323 + } 324 + 325 + dataBuf := make([]byte, dataLen) 326 + _, err = s.conn.Read(dataBuf) 327 + if err != nil { 328 + return err 329 + } 330 + 331 + msg := NewMessage(string(topicBuf), dataBuf) 332 + 333 + msgChan <- msg 334 + 335 + var ack bool 336 + select { 337 + case <-ctx.Done(): 338 + return ctx.Err() 339 + case ack = <-msg.ack: 340 + } 341 + ackMessage := server.Nack 342 + if ack { 343 + ackMessage = server.Ack 344 + } 345 + 346 + err = binary.Write(s.conn, binary.BigEndian, ackMessage) 347 + if err != nil { 348 + return fmt.Errorf("failed to ack/nack message: %w", err) 349 + } 350 + 351 + return nil 352 + } 353 + 354 + err := s.connOperation(op) 355 + if err != nil { 356 + var neterr net.Error 357 + if errors.As(err, &neterr) && neterr.Timeout() { 358 + return nil 359 + } 360 + return err 361 + } 362 + 363 + return err 364 + } 365 + 366 + func (s *Subscriber) connOperation(op connOpp) error { 367 + s.connMu.Lock() 368 + defer s.connMu.Unlock() 369 + 370 + return op(s.conn) 371 + }
+273
client/subscriber_test.go
··· 1 + package client 2 + 3 + import ( 4 + "context" 5 + "fmt" 6 + "testing" 7 + "time" 8 + 9 + "github.com/stretchr/testify/assert" 10 + "github.com/stretchr/testify/require" 11 + "github.com/willdot/messagebroker/internal/server" 12 + ) 13 + 14 + const ( 15 + serverAddr = ":9999" 16 + topicA = "topic a" 17 + topicB = "topic b" 18 + ) 19 + 20 + func createServer(t *testing.T) { 21 + server, err := server.New(serverAddr, time.Millisecond*100, time.Millisecond*100) 22 + require.NoError(t, err) 23 + 24 + t.Cleanup(func() { 25 + _ = server.Shutdown() 26 + }) 27 + } 28 + 29 + func TestNewSubscriber(t *testing.T) { 30 + createServer(t) 31 + 32 + sub, err := NewSubscriber(serverAddr) 33 + require.NoError(t, err) 34 + 35 + t.Cleanup(func() { 36 + sub.Close() 37 + }) 38 + } 39 + 40 + func TestNewSubscriberInvalidServerAddr(t *testing.T) { 41 + createServer(t) 42 + 43 + _, err := NewSubscriber(":123456") 44 + require.Error(t, err) 45 + } 46 + 47 + func TestNewPublisher(t *testing.T) { 48 + createServer(t) 49 + 50 + sub, err := NewPublisher(serverAddr) 51 + require.NoError(t, err) 52 + 53 + t.Cleanup(func() { 54 + sub.Close() 55 + }) 56 + } 57 + 58 + func TestNewPublisherInvalidServerAddr(t *testing.T) { 59 + createServer(t) 60 + 61 + _, err := NewPublisher(":123456") 62 + require.Error(t, err) 63 + } 64 + 65 + func TestSubscribeToTopics(t *testing.T) { 66 + createServer(t) 67 + 68 + sub, err := NewSubscriber(serverAddr) 69 + require.NoError(t, err) 70 + 71 + t.Cleanup(func() { 72 + sub.Close() 73 + }) 74 + 75 + topics := []string{topicA, topicB} 76 + 77 + err = sub.SubscribeToTopics(topics, server.Current, 0) 78 + require.NoError(t, err) 79 + } 80 + 81 + func TestUnsubscribesFromTopic(t *testing.T) { 82 + createServer(t) 83 + 84 + sub, err := NewSubscriber(serverAddr) 85 + require.NoError(t, err) 86 + 87 + t.Cleanup(func() { 88 + sub.Close() 89 + }) 90 + 91 + topics := []string{topicA, topicB} 92 + 93 + err = sub.SubscribeToTopics(topics, server.Current, 0) 94 + require.NoError(t, err) 95 + 96 + err = sub.UnsubscribeToTopics([]string{topicA}) 97 + require.NoError(t, err) 98 + 99 + ctx, cancel := context.WithCancel(context.Background()) 100 + t.Cleanup(func() { 101 + cancel() 102 + }) 103 + 104 + consumer := sub.Consume(ctx) 105 + require.NoError(t, err) 106 + 107 + var receivedMessages []*Message 108 + consumerFinCh := make(chan struct{}) 109 + go func() { 110 + for msg := range consumer.Messages() { 111 + msg.Ack(true) 112 + receivedMessages = append(receivedMessages, msg) 113 + } 114 + 115 + consumerFinCh <- struct{}{} 116 + }() 117 + 118 + // publish a message to both topics and check the subscriber only gets the message from the 1 topic 119 + // and not the unsubscribed topic 120 + 121 + publisher, err := NewPublisher("localhost:9999") 122 + require.NoError(t, err) 123 + t.Cleanup(func() { 124 + publisher.Close() 125 + }) 126 + 127 + msg := NewMessage(topicA, []byte("hello world")) 128 + 129 + err = publisher.PublishMessage(msg) 130 + require.NoError(t, err) 131 + 132 + msg.Topic = topicB 133 + err = publisher.PublishMessage(msg) 134 + require.NoError(t, err) 135 + 136 + // give the consumer some time to read the messages -- TODO: make better! 
137 + time.Sleep(time.Millisecond * 300) 138 + cancel() 139 + 140 + select { 141 + case <-consumerFinCh: 142 + break 143 + case <-time.After(time.Second): 144 + t.Fatal("timed out waiting for consumer to read messages") 145 + } 146 + 147 + assert.Len(t, receivedMessages, 1) 148 + assert.Equal(t, topicB, receivedMessages[0].Topic) 149 + } 150 + 151 + func TestPublishAndSubscribe(t *testing.T) { 152 + consumer, cancel := setupConsumer(t) 153 + 154 + var receivedMessages []*Message 155 + 156 + consumerFinCh := make(chan struct{}) 157 + go func() { 158 + for msg := range consumer.Messages() { 159 + msg.Ack(true) 160 + receivedMessages = append(receivedMessages, msg) 161 + } 162 + 163 + consumerFinCh <- struct{}{} 164 + }() 165 + 166 + publisher, err := NewPublisher("localhost:9999") 167 + require.NoError(t, err) 168 + t.Cleanup(func() { 169 + publisher.Close() 170 + }) 171 + 172 + // send some messages 173 + sentMessages := make([]*Message, 0, 10) 174 + for i := 0; i < 10; i++ { 175 + msg := NewMessage(topicA, []byte(fmt.Sprintf("message %d", i))) 176 + 177 + sentMessages = append(sentMessages, msg) 178 + 179 + err = publisher.PublishMessage(msg) 180 + require.NoError(t, err) 181 + } 182 + 183 + // give the consumer some time to read the messages -- TODO: make better! 184 + time.Sleep(time.Millisecond * 300) 185 + cancel() 186 + 187 + select { 188 + case <-consumerFinCh: 189 + break 190 + case <-time.After(time.Second * 5): 191 + t.Fatal("timed out waiting for consumer to read messages") 192 + } 193 + 194 + // THIS IS SO HACKY 195 + for _, msg := range receivedMessages { 196 + msg.ack = nil 197 + } 198 + 199 + for _, msg := range sentMessages { 200 + msg.ack = nil 201 + } 202 + 203 + assert.ElementsMatch(t, receivedMessages, sentMessages) 204 + } 205 + 206 + func TestPublishAndSubscribeNackMessage(t *testing.T) { 207 + consumer, cancel := setupConsumer(t) 208 + 209 + var receivedMessages []*Message 210 + 211 + consumerFinCh := make(chan struct{}) 212 + timesMsgWasReceived := 0 213 + go func() { 214 + for msg := range consumer.Messages() { 215 + msg.Ack(false) 216 + timesMsgWasReceived++ 217 + } 218 + 219 + consumerFinCh <- struct{}{} 220 + }() 221 + 222 + publisher, err := NewPublisher("localhost:9999") 223 + require.NoError(t, err) 224 + t.Cleanup(func() { 225 + publisher.Close() 226 + }) 227 + 228 + // send a message 229 + msg := NewMessage(topicA, []byte("hello world")) 230 + 231 + err = publisher.PublishMessage(msg) 232 + require.NoError(t, err) 233 + 234 + // give the consumer some time to read the messages -- TODO: make better! 235 + time.Sleep(time.Second) 236 + cancel() 237 + 238 + select { 239 + case <-consumerFinCh: 240 + break 241 + case <-time.After(time.Second * 5): 242 + t.Fatal("timed out waiting for consumer to read messages") 243 + } 244 + 245 + assert.Empty(t, receivedMessages) 246 + assert.Equal(t, 5, timesMsgWasReceived) 247 + } 248 + 249 + func setupConsumer(t *testing.T) (*Consumer, context.CancelFunc) { 250 + createServer(t) 251 + 252 + sub, err := NewSubscriber(serverAddr) 253 + require.NoError(t, err) 254 + 255 + t.Cleanup(func() { 256 + sub.Close() 257 + }) 258 + 259 + topics := []string{topicA, topicB} 260 + 261 + err = sub.SubscribeToTopics(topics, server.Current, 0) 262 + require.NoError(t, err) 263 + 264 + ctx, cancel := context.WithCancel(context.Background()) 265 + t.Cleanup(func() { 266 + cancel() 267 + }) 268 + 269 + consumer := sub.Consume(ctx) 270 + require.NoError(t, err) 271 + 272 + return consumer, cancel 273 + }
+27
cmd/server/main.go
```go
package main

import (
	"log"
	"os"
	"os/signal"
	"syscall"
	"time"

	"github.com/willdot/messagebroker/internal/server"
)

func main() {
	srv, err := server.New(":3000", time.Second, time.Second*2)
	if err != nil {
		log.Fatal(err)
	}

	defer func() {
		_ = srv.Shutdown()
	}()

	signals := make(chan os.Signal, 1)
	signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)

	<-signals
}
```
+7
docker-compose.yaml
```yaml
version: "3.7"
services:
  message-server:
    build:
      context: .
      dockerfile: dockerfile.server
    ports: [ "3000:3000" ]
```
-20
dockerfile.example-server
```dockerfile
FROM golang:latest as builder

WORKDIR /app

COPY go.mod go.sum ./
COPY example/server/ ./
RUN go mod download

COPY . .

RUN CGO_ENABLED=0 go build -o message-broker-server .

FROM alpine:latest

RUN apk --no-cache add ca-certificates

WORKDIR /root/
COPY --from=builder /app/message-broker-server .

CMD ["./message-broker-server"]
```
+20
dockerfile.server
```dockerfile
FROM golang:latest as builder

WORKDIR /app

COPY go.mod go.sum ./
COPY cmd/server/ ./
RUN go mod download

COPY . .

RUN CGO_ENABLED=0 go build -o message-broker-server .

FROM alpine:latest

RUN apk --no-cache add ca-certificates

WORKDIR /root/
COPY --from=builder /app/message-broker-server .

CMD ["./message-broker-server"]
```
+36 -13
example/main.go
```diff
···
 	"log/slog"
 	"time"

-	"github.com/willdot/messagebroker/pubsub"
+	"github.com/willdot/messagebroker/client"
+	"github.com/willdot/messagebroker/internal/server"
 )

-var consumeOnly *bool
+// var publish *bool
+// var consume *bool
+var consumeFrom *int
+var clientType *string
+
+const (
+	topic = "topic-a"
+)

 func main() {
-	consumeOnly = flag.Bool("consume-only", false, "just consumes (doesn't start server and doesn't publish)")
+	clientType = flag.String("client-type", "consume", "consume or publish (default consume)")
+	// publish = flag.Bool("publish", false, "will publish messages every 500ms until client is stopped")
+	// consume = flag.Bool("consume", false, "will consume messages until client is stopped")
+	consumeFrom = flag.Int("consume-from", -1, "index of message to start consuming from. If not set it will consume from the most recent")
 	flag.Parse()

-	if !*consumeOnly {
-		go sendMessages()
+	switch *clientType {
+	case "consume":
+		consume()
+	case "publish":
+		sendMessages()
+	default:
+		fmt.Println("unknown client type")
 	}
+}

-	sub, err := pubsub.NewSubscriber(":3000")
+func consume() {
+	sub, err := client.NewSubscriber(":3000")
 	if err != nil {
 		panic(err)
 	}
···
 	defer func() {
 		_ = sub.Close()
 	}()
+	startAt := 0
+	startAtType := server.Current
+	if *consumeFrom > -1 {
+		startAtType = server.From
+		startAt = *consumeFrom
+	}

-	err = sub.SubscribeToTopics([]string{"topic a"})
+	err = sub.SubscribeToTopics([]string{topic}, startAtType, startAt)
 	if err != nil {
 		panic(err)
 	}
···

 	for msg := range consumer.Messages() {
 		slog.Info("received message", "message", string(msg.Data))
+		msg.Ack(true)
 	}
-
 }

 func sendMessages() {
-	publisher, err := pubsub.NewPublisher("localhost:3000")
+	publisher, err := client.NewPublisher("localhost:3000")
 	if err != nil {
 		panic(err)
 	}
···
 	i := 0
 	for {
 		i++
-		msg := pubsub.Message{
-			Topic: "topic a",
-			Data:  []byte(fmt.Sprintf("message %d", i)),
-		}
+		msg := client.NewMessage(topic, []byte(fmt.Sprintf("message %d", i)))

 		err = publisher.PublishMessage(msg)
 		if err != nil {
 			slog.Error("failed to publish message", "error", err)
 			continue
 		}
+
+		slog.Info("message sent")

 		time.Sleep(time.Millisecond * 500)
 	}
```
-26
example/server/main.go
```go
package main

import (
	"log"
	"os"
	"os/signal"
	"syscall"

	"github.com/willdot/messagebroker/server"
)

func main() {
	srv, err := server.New(":3000")
	if err != nil {
		log.Fatal(err)
	}

	defer func() {
		_ = srv.Shutdown()
	}()

	signals := make(chan os.Signal, 1)
	signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)

	<-signals
}
```
+1 -5
go.mod
```diff
···

 go 1.21.0

-require (
-	github.com/docker/distribution v2.8.3+incompatible
-	github.com/google/uuid v1.4.0
-	github.com/stretchr/testify v1.8.4
-)
+require github.com/stretchr/testify v1.8.4

 require (
 	github.com/davecgh/go-spew v1.1.1 // indirect
```
-4
go.sum
```diff
 github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/docker/distribution v2.8.3+incompatible h1:AtKxIZ36LoNK51+Z6RpzLpddBirtxJnzDrHLEKxTAYk=
-github.com/docker/distribution v2.8.3+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w=
-github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4=
-github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
 github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
```
+47
internal/messagestore/memory_store.go
```go
package messagestore

import (
	"sync"

	"github.com/willdot/messagebroker/internal"
)

// MemoryStore allows messages to be stored in memory
type MemoryStore struct {
	mu         sync.Mutex
	msgs       map[int]internal.Message
	nextOffset int
}

// NewMemoryStore initializes a new in memory store
func NewMemoryStore() *MemoryStore {
	return &MemoryStore{
		msgs: make(map[int]internal.Message),
	}
}

// Write will write the provided message to the in memory store
func (m *MemoryStore) Write(msg internal.Message) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	m.msgs[m.nextOffset] = msg

	m.nextOffset++

	return nil
}

// ReadFrom will read messages from (and including) the provided offset and pass them to the provided handler
func (m *MemoryStore) ReadFrom(offset int, handleFunc func(msg internal.Message)) {
	if offset < 0 || offset >= m.nextOffset {
		return
	}

	m.mu.Lock()
	defer m.mu.Unlock()

	for i := offset; i < len(m.msgs); i++ {
		handleFunc(m.msgs[i])
	}
}
```
+12
internal/messge.go
```go
package internal

// Message represents a message that can be sent / received
type Message struct {
	Data          []byte
	DeliveryCount int
}

// NewMessage initializes a new message
func NewMessage(data []byte) Message {
	return Message{Data: data, DeliveryCount: 1}
}
```
+36
internal/server/peer.go
```go
package server

import (
	"net"
	"sync"
)

// Peer represents a remote connection to the server such as a publisher or subscriber.
type Peer struct {
	conn   net.Conn
	connMu sync.Mutex
}

// NewPeer returns a new peer.
func NewPeer(conn net.Conn) *Peer {
	return &Peer{
		conn: conn,
	}
}

// Addr returns the peer's connection address.
func (p *Peer) Addr() net.Addr {
	return p.conn.RemoteAddr()
}

// ConnOpp represents a set of actions on a connection that can be used synchronously.
type ConnOpp func(conn net.Conn) error

// RunConnOperation will run the provided operation. It ensures that it is the only operation being
// run on the connection so that other operations don't get mixed up.
func (p *Peer) RunConnOperation(op ConnOpp) error {
	p.connMu.Lock()
	defer p.connMu.Unlock()

	return op(p.conn)
}
```
+513
internal/server/server.go
··· 1 + package server 2 + 3 + import ( 4 + "encoding/binary" 5 + "encoding/json" 6 + "errors" 7 + "fmt" 8 + "io" 9 + "log/slog" 10 + "net" 11 + "strings" 12 + "sync" 13 + "syscall" 14 + "time" 15 + 16 + "github.com/willdot/messagebroker/internal" 17 + ) 18 + 19 + // Action represents the type of action that a peer requests to do 20 + type Action uint16 21 + 22 + const ( 23 + Subscribe Action = 1 24 + Unsubscribe Action = 2 25 + Publish Action = 3 26 + Ack Action = 4 27 + Nack Action = 5 28 + ) 29 + 30 + // Status represents the status of a request 31 + type Status uint16 32 + 33 + const ( 34 + Subscribed Status = 1 35 + Unsubscribed Status = 2 36 + Error Status = 3 37 + ) 38 + 39 + func (s Status) String() string { 40 + switch s { 41 + case Subscribed: 42 + return "subscribed" 43 + case Unsubscribed: 44 + return "unsubscribed" 45 + case Error: 46 + return "error" 47 + } 48 + 49 + return "" 50 + } 51 + 52 + // StartAtType represents where the subcriber wishes to start subscribing to a topic from 53 + type StartAtType uint16 54 + 55 + const ( 56 + Beginning StartAtType = 0 57 + Current StartAtType = 1 58 + From StartAtType = 2 59 + ) 60 + 61 + // Server accepts subscribe and publish connections and passes messages around 62 + type Server struct { 63 + Addr string 64 + lis net.Listener 65 + 66 + mu sync.Mutex 67 + topics map[string]*topic 68 + 69 + ackDelay time.Duration 70 + ackTimeout time.Duration 71 + } 72 + 73 + // New creates and starts a new server 74 + func New(Addr string, ackDelay, ackTimeout time.Duration) (*Server, error) { 75 + lis, err := net.Listen("tcp", Addr) 76 + if err != nil { 77 + return nil, fmt.Errorf("failed to listen: %w", err) 78 + } 79 + 80 + srv := &Server{ 81 + lis: lis, 82 + topics: map[string]*topic{}, 83 + ackDelay: ackDelay, 84 + ackTimeout: ackTimeout, 85 + } 86 + 87 + go srv.start() 88 + 89 + return srv, nil 90 + } 91 + 92 + // Shutdown will cleanly shutdown the server 93 + func (s *Server) Shutdown() error { 94 + return s.lis.Close() 95 + } 96 + 97 + func (s *Server) start() { 98 + for { 99 + conn, err := s.lis.Accept() 100 + if err != nil { 101 + if errors.Is(err, net.ErrClosed) { 102 + slog.Info("listener closed") 103 + return 104 + } 105 + slog.Error("listener failed to accept", "error", err) 106 + continue 107 + } 108 + 109 + go s.handleConn(conn) 110 + } 111 + } 112 + 113 + func (s *Server) handleConn(conn net.Conn) { 114 + peer := NewPeer(conn) 115 + 116 + slog.Info("handling connection", "peer", peer.Addr()) 117 + defer slog.Info("ending connection", "peer", peer.Addr()) 118 + 119 + action, err := readAction(peer, 0) 120 + if err != nil { 121 + if !errors.Is(err, io.EOF) { 122 + slog.Error("failed to read action from peer", "error", err, "peer", peer.Addr()) 123 + } 124 + return 125 + } 126 + 127 + switch action { 128 + case Subscribe: 129 + s.handleSubscribe(peer) 130 + case Unsubscribe: 131 + s.handleUnsubscribe(peer) 132 + case Publish: 133 + s.handlePublish(peer) 134 + default: 135 + slog.Error("unknown action", "action", action, "peer", peer.Addr()) 136 + writeInvalidAction(peer) 137 + } 138 + } 139 + 140 + func (s *Server) handleSubscribe(peer *Peer) { 141 + slog.Info("handling subscriber", "peer", peer.Addr()) 142 + // subscribe the peer to the topic 143 + s.subscribePeerToTopic(peer) 144 + 145 + s.waitForPeerAction(peer) 146 + } 147 + 148 + func (s *Server) waitForPeerAction(peer *Peer) { 149 + // keep handling the peers connection, getting the action from the peer when it wishes to do something else. 
150 + // once the peers connection ends, it will be unsubscribed from all topics and returned 151 + for { 152 + action, err := readAction(peer, time.Millisecond*100) 153 + if err != nil { 154 + // if the error is a timeout, it means the peer hasn't sent an action indicating it wishes to do something so sleep 155 + // for a little bit to allow for other actions to happen on the connection 156 + var neterr net.Error 157 + if errors.As(err, &neterr) && neterr.Timeout() { 158 + time.Sleep(time.Millisecond * 500) 159 + continue 160 + } 161 + 162 + if !errors.Is(err, io.EOF) { 163 + slog.Error("failed to read action from subscriber", "error", err, "peer", peer.Addr()) 164 + } 165 + 166 + s.unsubscribePeerFromAllTopics(peer) 167 + 168 + return 169 + } 170 + 171 + switch action { 172 + case Subscribe: 173 + s.subscribePeerToTopic(peer) 174 + case Unsubscribe: 175 + s.handleUnsubscribe(peer) 176 + default: 177 + slog.Error("unknown action for subscriber", "action", action, "peer", peer.Addr()) 178 + writeInvalidAction(peer) 179 + continue 180 + } 181 + } 182 + } 183 + 184 + func (s *Server) subscribePeerToTopic(peer *Peer) { 185 + op := func(conn net.Conn) error { 186 + // get the topics the peer wishes to subscribe to 187 + dataLen, err := dataLengthUint32(conn) 188 + if err != nil { 189 + slog.Error(err.Error(), "peer", peer.Addr()) 190 + writeStatus(Error, "invalid data length of topics provided", conn) 191 + return nil 192 + } 193 + if dataLen == 0 { 194 + writeStatus(Error, "data length of topics is 0", conn) 195 + return nil 196 + } 197 + 198 + buf := make([]byte, dataLen) 199 + _, err = conn.Read(buf) 200 + if err != nil { 201 + slog.Error("failed to read subscibers topic data", "error", err, "peer", peer.Addr()) 202 + writeStatus(Error, "failed to read topic data", conn) 203 + return nil 204 + } 205 + 206 + var topics []string 207 + err = json.Unmarshal(buf, &topics) 208 + if err != nil { 209 + slog.Error("failed to unmarshal subscibers topic data", "error", err, "peer", peer.Addr()) 210 + writeStatus(Error, "invalid topic data provided", conn) 211 + return nil 212 + } 213 + 214 + var startAtType StartAtType 215 + err = binary.Read(conn, binary.BigEndian, &startAtType) 216 + if err != nil { 217 + slog.Error(err.Error(), "peer", peer.Addr()) 218 + writeStatus(Error, "invalid start at type provided", conn) 219 + return nil 220 + } 221 + var startAt int 222 + switch startAtType { 223 + case From: 224 + var s uint16 225 + err = binary.Read(conn, binary.BigEndian, &s) 226 + if err != nil { 227 + slog.Error(err.Error(), "peer", peer.Addr()) 228 + writeStatus(Error, "invalid start at value provided", conn) 229 + return nil 230 + } 231 + startAt = int(s) 232 + case Beginning: 233 + startAt = 0 234 + case Current: 235 + startAt = -1 236 + default: 237 + slog.Error("invalid start up type provided", "start up type", startAtType) 238 + writeStatus(Error, "invalid start up type provided", conn) 239 + return nil 240 + } 241 + 242 + s.subscribeToTopics(peer, topics, startAt) 243 + writeStatus(Subscribed, "", conn) 244 + 245 + return nil 246 + } 247 + 248 + _ = peer.RunConnOperation(op) 249 + } 250 + 251 + func (s *Server) handleUnsubscribe(peer *Peer) { 252 + slog.Info("handling unsubscriber", "peer", peer.Addr()) 253 + op := func(conn net.Conn) error { 254 + // get the topics the peer wishes to unsubscribe from 255 + dataLen, err := dataLengthUint32(conn) 256 + if err != nil { 257 + slog.Error(err.Error(), "peer", peer.Addr()) 258 + writeStatus(Error, "invalid data length of topics provided", conn) 259 + 
return nil 260 + } 261 + if dataLen == 0 { 262 + writeStatus(Error, "data length of topics is 0", conn) 263 + return nil 264 + } 265 + 266 + buf := make([]byte, dataLen) 267 + _, err = conn.Read(buf) 268 + if err != nil { 269 + slog.Error("failed to read subscibers topic data", "error", err, "peer", peer.Addr()) 270 + writeStatus(Error, "failed to read topic data", conn) 271 + return nil 272 + } 273 + 274 + var topics []string 275 + err = json.Unmarshal(buf, &topics) 276 + if err != nil { 277 + slog.Error("failed to unmarshal subscibers topic data", "error", err, "peer", peer.Addr()) 278 + writeStatus(Error, "invalid topic data provided", conn) 279 + return nil 280 + } 281 + 282 + s.unsubscribeToTopics(peer, topics) 283 + writeStatus(Unsubscribed, "", conn) 284 + 285 + return nil 286 + } 287 + 288 + _ = peer.RunConnOperation(op) 289 + } 290 + 291 + func (s *Server) handlePublish(peer *Peer) { 292 + slog.Info("handling publisher", "peer", peer.Addr()) 293 + for { 294 + op := func(conn net.Conn) error { 295 + topicDataLen, err := dataLengthUint16(conn) 296 + if err != nil { 297 + if errors.Is(err, io.EOF) { 298 + return nil 299 + } 300 + slog.Error("failed to read data length", "error", err, "peer", peer.Addr()) 301 + writeStatus(Error, "invalid data length of data provided", conn) 302 + return nil 303 + } 304 + if topicDataLen == 0 { 305 + return nil 306 + } 307 + topicBuf := make([]byte, topicDataLen) 308 + _, err = conn.Read(topicBuf) 309 + if err != nil { 310 + slog.Error("failed to read topic from peer", "error", err, "peer", peer.Addr()) 311 + writeStatus(Error, "failed to read topic", conn) 312 + return nil 313 + } 314 + 315 + topicStr := string(topicBuf) 316 + if !strings.HasPrefix(topicStr, "topic:") { 317 + slog.Error("topic data does not contain topic prefix", "peer", peer.Addr()) 318 + writeStatus(Error, "topic data does not contain 'topic:' prefix", conn) 319 + return nil 320 + } 321 + topicStr = strings.TrimPrefix(topicStr, "topic:") 322 + 323 + msgDataLen, err := dataLengthUint32(conn) 324 + if err != nil { 325 + slog.Error(err.Error(), "peer", peer.Addr()) 326 + writeStatus(Error, "invalid data length of data provided", conn) 327 + return nil 328 + } 329 + if msgDataLen == 0 { 330 + return nil 331 + } 332 + 333 + dataBuf := make([]byte, msgDataLen) 334 + _, err = conn.Read(dataBuf) 335 + if err != nil { 336 + slog.Error("failed to read data from peer", "error", err, "peer", peer.Addr()) 337 + writeStatus(Error, "failed to read data", conn) 338 + return nil 339 + } 340 + 341 + topic := s.getTopic(topicStr) 342 + if topic == nil { 343 + topic = newTopic(topicStr) 344 + s.topics[topicStr] = topic 345 + } 346 + 347 + message := internal.NewMessage(dataBuf) 348 + 349 + err = topic.sendMessageToSubscribers(message) 350 + if err != nil { 351 + slog.Error("failed to send message to subscribers", "error", err, "peer", peer.Addr()) 352 + writeStatus(Error, "failed to send message to subscribers", conn) 353 + return nil 354 + } 355 + 356 + return nil 357 + } 358 + 359 + _ = peer.RunConnOperation(op) 360 + } 361 + } 362 + 363 + func (s *Server) subscribeToTopics(peer *Peer, topics []string, startAt int) { 364 + slog.Info("subscribing peer to topics", "topics", topics, "peer", peer.Addr()) 365 + for _, topic := range topics { 366 + s.addSubsciberToTopic(topic, peer, startAt) 367 + } 368 + } 369 + 370 + func (s *Server) addSubsciberToTopic(topicName string, peer *Peer, startAt int) { 371 + s.mu.Lock() 372 + defer s.mu.Unlock() 373 + 374 + t, ok := s.topics[topicName] 375 + if !ok { 376 + t 
= newTopic(topicName) 377 + } 378 + 379 + t.mu.Lock() 380 + t.subscriptions[peer.Addr()] = newSubscriber(peer, t, s.ackDelay, s.ackTimeout, startAt) 381 + t.mu.Unlock() 382 + 383 + s.topics[topicName] = t 384 + } 385 + 386 + func (s *Server) unsubscribeToTopics(peer *Peer, topics []string) { 387 + slog.Info("unsubscribing peer from topics", "topics", topics, "peer", peer.Addr()) 388 + for _, topic := range topics { 389 + s.removeSubsciberFromTopic(topic, peer) 390 + } 391 + } 392 + 393 + func (s *Server) removeSubsciberFromTopic(topicName string, peer *Peer) { 394 + s.mu.Lock() 395 + defer s.mu.Unlock() 396 + 397 + t, ok := s.topics[topicName] 398 + if !ok { 399 + return 400 + } 401 + 402 + sub := t.findSubscription(peer.Addr()) 403 + if sub == nil { 404 + return 405 + } 406 + 407 + sub.unsubscribe() 408 + t.removeSubscription(peer.Addr()) 409 + } 410 + 411 + func (s *Server) unsubscribePeerFromAllTopics(peer *Peer) { 412 + s.mu.Lock() 413 + defer s.mu.Unlock() 414 + 415 + for _, t := range s.topics { 416 + sub := t.findSubscription(peer.Addr()) 417 + if sub == nil { 418 + return 419 + } 420 + 421 + sub.unsubscribe() 422 + t.removeSubscription(peer.Addr()) 423 + } 424 + } 425 + 426 + func (s *Server) getTopic(topicName string) *topic { 427 + s.mu.Lock() 428 + defer s.mu.Unlock() 429 + 430 + if topic, ok := s.topics[topicName]; ok { 431 + return topic 432 + } 433 + 434 + return nil 435 + } 436 + 437 + func readAction(peer *Peer, timeout time.Duration) (Action, error) { 438 + var action Action 439 + op := func(conn net.Conn) error { 440 + if timeout > 0 { 441 + if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil { 442 + slog.Error("failed to set connection read deadline", "error", err, "peer", peer.Addr()) 443 + } 444 + defer func() { 445 + if err := conn.SetReadDeadline(time.Time{}); err != nil { 446 + slog.Error("failed to reset connection read deadline", "error", err, "peer", peer.Addr()) 447 + } 448 + }() 449 + } 450 + 451 + err := binary.Read(conn, binary.BigEndian, &action) 452 + if err != nil { 453 + return err 454 + } 455 + return nil 456 + } 457 + 458 + err := peer.RunConnOperation(op) 459 + if err != nil { 460 + return 0, fmt.Errorf("failed to read action from peer: %w", err) 461 + } 462 + 463 + return action, nil 464 + } 465 + 466 + func writeInvalidAction(peer *Peer) { 467 + op := func(conn net.Conn) error { 468 + writeStatus(Error, "unknown action", conn) 469 + return nil 470 + } 471 + 472 + _ = peer.RunConnOperation(op) 473 + } 474 + 475 + func dataLengthUint32(conn net.Conn) (uint32, error) { 476 + var dataLen uint32 477 + err := binary.Read(conn, binary.BigEndian, &dataLen) 478 + if err != nil { 479 + return 0, err 480 + } 481 + return dataLen, nil 482 + } 483 + 484 + func dataLengthUint16(conn net.Conn) (uint16, error) { 485 + var dataLen uint16 486 + err := binary.Read(conn, binary.BigEndian, &dataLen) 487 + if err != nil { 488 + return 0, err 489 + } 490 + return dataLen, nil 491 + } 492 + 493 + func writeStatus(status Status, message string, conn net.Conn) { 494 + statusB := make([]byte, 2) 495 + binary.BigEndian.PutUint16(statusB, uint16(status)) 496 + 497 + headers := statusB 498 + 499 + if len(message) > 0 { 500 + sizeB := make([]byte, 2) 501 + binary.BigEndian.PutUint16(sizeB, uint16(len(message))) 502 + headers = append(headers, sizeB...) 
503 + } 504 + 505 + msgBytes := []byte(message) 506 + _, err := conn.Write(append(headers, msgBytes...)) 507 + if err != nil { 508 + if !errors.Is(err, syscall.EPIPE) { 509 + slog.Error("failed to write status to peers connection", "error", err, "peer", conn.RemoteAddr()) 510 + } 511 + return 512 + } 513 + }
+632
internal/server/server_test.go
··· 1 + package server 2 + 3 + import ( 4 + "encoding/binary" 5 + "encoding/json" 6 + "fmt" 7 + "net" 8 + "testing" 9 + "time" 10 + 11 + "github.com/stretchr/testify/assert" 12 + "github.com/stretchr/testify/require" 13 + "github.com/willdot/messagebroker/internal/messagestore" 14 + ) 15 + 16 + const ( 17 + topicA = "topic a" 18 + topicB = "topic b" 19 + topicC = "topic c" 20 + 21 + serverAddr = ":6666" 22 + 23 + ackDelay = time.Millisecond * 100 24 + ackTimeout = time.Millisecond * 100 25 + ) 26 + 27 + func createServer(t *testing.T) *Server { 28 + srv, err := New(serverAddr, ackDelay, ackTimeout) 29 + require.NoError(t, err) 30 + 31 + t.Cleanup(func() { 32 + _ = srv.Shutdown() 33 + }) 34 + 35 + return srv 36 + } 37 + 38 + func createServerWithExistingTopic(t *testing.T, topicName string) *Server { 39 + srv := createServer(t) 40 + srv.topics[topicName] = &topic{ 41 + name: topicName, 42 + subscriptions: make(map[net.Addr]*subscriber), 43 + messageStore: messagestore.NewMemoryStore(), 44 + } 45 + 46 + return srv 47 + } 48 + 49 + func createConnectionAndSubscribe(t *testing.T, topics []string, startAtType StartAtType, startAtIndex int) net.Conn { 50 + conn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr)) 51 + require.NoError(t, err) 52 + 53 + subscribeToTopics(t, conn, topics, startAtType, startAtIndex) 54 + 55 + expectedRes := Subscribed 56 + 57 + var resp Status 58 + err = binary.Read(conn, binary.BigEndian, &resp) 59 + require.NoError(t, err) 60 + 61 + assert.Equal(t, expectedRes, resp) 62 + 63 + return conn 64 + } 65 + 66 + func sendMessage(t *testing.T, conn net.Conn, topic string, message []byte) { 67 + topicLenB := make([]byte, 4) 68 + binary.BigEndian.PutUint32(topicLenB, uint32(len(topic))) 69 + 70 + headers := topicLenB 71 + headers = append(headers, []byte(topic)...) 72 + 73 + messageLenB := make([]byte, 4) 74 + binary.BigEndian.PutUint32(messageLenB, uint32(len(message))) 75 + headers = append(headers, messageLenB...) 76 + 77 + _, err := conn.Write(append(headers, message...)) 78 + require.NoError(t, err) 79 + } 80 + 81 + func subscribeToTopics(t *testing.T, conn net.Conn, topics []string, startAtType StartAtType, startAtIndex int) { 82 + actionB := make([]byte, 2) 83 + binary.BigEndian.PutUint16(actionB, uint16(Subscribe)) 84 + headers := actionB 85 + 86 + b, err := json.Marshal(topics) 87 + require.NoError(t, err) 88 + 89 + topicNamesB := make([]byte, 4) 90 + binary.BigEndian.PutUint32(topicNamesB, uint32(len(b))) 91 + headers = append(headers, topicNamesB...) 92 + headers = append(headers, b...) 93 + 94 + startAtTypeB := make([]byte, 2) 95 + binary.BigEndian.PutUint16(startAtTypeB, uint16(startAtType)) 96 + headers = append(headers, startAtTypeB...) 97 + 98 + if startAtType == From { 99 + fromB := make([]byte, 2) 100 + binary.BigEndian.PutUint16(fromB, uint16(startAtIndex)) 101 + headers = append(headers, fromB...) 102 + } 103 + 104 + _, err = conn.Write(headers) 105 + require.NoError(t, err) 106 + } 107 + 108 + func unsubscribetoTopics(t *testing.T, conn net.Conn, topics []string) { 109 + actionB := make([]byte, 2) 110 + binary.BigEndian.PutUint16(actionB, uint16(Unsubscribe)) 111 + headers := actionB 112 + 113 + b, err := json.Marshal(topics) 114 + require.NoError(t, err) 115 + 116 + topicNamesB := make([]byte, 4) 117 + binary.BigEndian.PutUint32(topicNamesB, uint32(len(b))) 118 + headers = append(headers, topicNamesB...) 
119 + 120 + _, err = conn.Write(append(headers, b...)) 121 + require.NoError(t, err) 122 + } 123 + 124 + func TestSubscribeToTopics(t *testing.T) { 125 + // create a server with an existing topic so we can test subscribing to a new and 126 + // existing topic 127 + srv := createServerWithExistingTopic(t, topicA) 128 + 129 + _ = createConnectionAndSubscribe(t, []string{topicA, topicB}, Current, 0) 130 + 131 + srv.mu.Lock() 132 + defer srv.mu.Unlock() 133 + assert.Len(t, srv.topics, 2) 134 + assert.Len(t, srv.topics[topicA].subscriptions, 1) 135 + assert.Len(t, srv.topics[topicB].subscriptions, 1) 136 + } 137 + 138 + func TestUnsubscribesFromTopic(t *testing.T) { 139 + srv := createServerWithExistingTopic(t, topicA) 140 + 141 + conn := createConnectionAndSubscribe(t, []string{topicA, topicB, topicC}, Current, 0) 142 + 143 + assert.Len(t, srv.topics, 3) 144 + 145 + srv.mu.Lock() 146 + assert.Len(t, srv.topics[topicA].subscriptions, 1) 147 + assert.Len(t, srv.topics[topicB].subscriptions, 1) 148 + assert.Len(t, srv.topics[topicC].subscriptions, 1) 149 + srv.mu.Unlock() 150 + 151 + topics := []string{topicA, topicB} 152 + 153 + unsubscribetoTopics(t, conn, topics) 154 + 155 + expectedRes := Unsubscribed 156 + 157 + var resp Status 158 + err := binary.Read(conn, binary.BigEndian, &resp) 159 + require.NoError(t, err) 160 + 161 + assert.Equal(t, expectedRes, resp) 162 + 163 + assert.Len(t, srv.topics, 3) 164 + 165 + srv.mu.Lock() 166 + assert.Len(t, srv.topics[topicA].subscriptions, 0) 167 + assert.Len(t, srv.topics[topicB].subscriptions, 0) 168 + assert.Len(t, srv.topics[topicC].subscriptions, 1) 169 + srv.mu.Unlock() 170 + } 171 + 172 + func TestSubscriberClosesWithoutUnsubscribing(t *testing.T) { 173 + srv := createServer(t) 174 + 175 + conn := createConnectionAndSubscribe(t, []string{topicA, topicB}, Current, 0) 176 + 177 + assert.Len(t, srv.topics, 2) 178 + 179 + srv.mu.Lock() 180 + assert.Len(t, srv.topics[topicA].subscriptions, 1) 181 + assert.Len(t, srv.topics[topicB].subscriptions, 1) 182 + srv.mu.Unlock() 183 + 184 + // close the conn 185 + err := conn.Close() 186 + require.NoError(t, err) 187 + 188 + publisherConn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr)) 189 + require.NoError(t, err) 190 + 191 + err = binary.Write(publisherConn, binary.BigEndian, Publish) 192 + require.NoError(t, err) 193 + 194 + data := []byte("hello world") 195 + 196 + sendMessage(t, publisherConn, topicA, data) 197 + 198 + // the timeout for a connection is 100 milliseconds, so we should wait at least this long before checking the unsubscribe 199 + // TODO: see if theres a better way, but without this, the test is flakey 200 + time.Sleep(time.Millisecond * 100) 201 + 202 + assert.Len(t, srv.topics, 2) 203 + 204 + srv.mu.Lock() 205 + assert.Len(t, srv.topics[topicA].subscriptions, 0) 206 + assert.Len(t, srv.topics[topicB].subscriptions, 0) 207 + srv.mu.Unlock() 208 + } 209 + 210 + func TestInvalidAction(t *testing.T) { 211 + _ = createServer(t) 212 + 213 + conn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr)) 214 + require.NoError(t, err) 215 + 216 + err = binary.Write(conn, binary.BigEndian, uint16(99)) 217 + require.NoError(t, err) 218 + 219 + expectedRes := Error 220 + 221 + var resp Status 222 + err = binary.Read(conn, binary.BigEndian, &resp) 223 + require.NoError(t, err) 224 + 225 + assert.Equal(t, expectedRes, resp) 226 + 227 + expectedMessage := "unknown action" 228 + 229 + var dataLen uint16 230 + err = binary.Read(conn, binary.BigEndian, &dataLen) 231 + 
require.NoError(t, err) 232 + assert.Equal(t, len(expectedMessage), int(dataLen)) 233 + 234 + buf := make([]byte, dataLen) 235 + _, err = conn.Read(buf) 236 + require.NoError(t, err) 237 + 238 + assert.Equal(t, expectedMessage, string(buf)) 239 + } 240 + 241 + func TestInvalidTopicDataPublished(t *testing.T) { 242 + _ = createServer(t) 243 + 244 + publisherConn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr)) 245 + require.NoError(t, err) 246 + 247 + err = binary.Write(publisherConn, binary.BigEndian, Publish) 248 + require.NoError(t, err) 249 + 250 + // send topic 251 + topic := topicA 252 + err = binary.Write(publisherConn, binary.BigEndian, uint32(len(topic))) 253 + require.NoError(t, err) 254 + _, err = publisherConn.Write([]byte(topic)) 255 + require.NoError(t, err) 256 + 257 + expectedRes := Error 258 + 259 + var resp Status 260 + err = binary.Read(publisherConn, binary.BigEndian, &resp) 261 + require.NoError(t, err) 262 + 263 + assert.Equal(t, expectedRes, resp) 264 + 265 + expectedMessage := "topic data does not contain 'topic:' prefix" 266 + 267 + var dataLen uint16 268 + err = binary.Read(publisherConn, binary.BigEndian, &dataLen) 269 + require.NoError(t, err) 270 + assert.Equal(t, len(expectedMessage), int(dataLen)) 271 + 272 + buf := make([]byte, dataLen) 273 + _, err = publisherConn.Read(buf) 274 + require.NoError(t, err) 275 + 276 + assert.Equal(t, expectedMessage, string(buf)) 277 + } 278 + 279 + func TestSendsDataToTopicSubscribers(t *testing.T) { 280 + _ = createServer(t) 281 + 282 + subscribers := make([]net.Conn, 0, 10) 283 + for i := 0; i < 10; i++ { 284 + subscriberConn := createConnectionAndSubscribe(t, []string{topicA, topicB}, Current, 0) 285 + 286 + subscribers = append(subscribers, subscriberConn) 287 + } 288 + 289 + publisherConn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr)) 290 + require.NoError(t, err) 291 + 292 + err = binary.Write(publisherConn, binary.BigEndian, Publish) 293 + require.NoError(t, err) 294 + 295 + topic := fmt.Sprintf("topic:%s", topicA) 296 + messageData := "hello world" 297 + 298 + sendMessage(t, publisherConn, topic, []byte(messageData)) 299 + 300 + // check the subsribers got the data 301 + for _, conn := range subscribers { 302 + msg := readMessage(t, conn) 303 + assert.Equal(t, messageData, string(msg)) 304 + } 305 + } 306 + 307 + func TestPublishMultipleTimes(t *testing.T) { 308 + _ = createServer(t) 309 + 310 + publisherConn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr)) 311 + require.NoError(t, err) 312 + 313 + err = binary.Write(publisherConn, binary.BigEndian, Publish) 314 + require.NoError(t, err) 315 + 316 + messages := make([]string, 0, 10) 317 + for i := 0; i < 10; i++ { 318 + messages = append(messages, fmt.Sprintf("message %d", i)) 319 + } 320 + 321 + subscribeFinCh := make(chan struct{}) 322 + // create a subscriber that will read messages 323 + subscriberConn := createConnectionAndSubscribe(t, []string{topicA, topicB}, Current, 0) 324 + go func() { 325 + // check subscriber got all messages 326 + results := make([]string, 0, len(messages)) 327 + for i := 0; i < len(messages); i++ { 328 + msg := readMessage(t, subscriberConn) 329 + results = append(results, string(msg)) 330 + } 331 + 332 + assert.ElementsMatch(t, results, messages) 333 + 334 + subscribeFinCh <- struct{}{} 335 + }() 336 + 337 + topic := fmt.Sprintf("topic:%s", topicA) 338 + 339 + // send multiple messages 340 + for _, msg := range messages { 341 + sendMessage(t, publisherConn, topic, []byte(msg)) 342 + } 343 + 
344 + select { 345 + case <-subscribeFinCh: 346 + break 347 + case <-time.After(time.Second): 348 + t.Fatal(fmt.Errorf("timed out waiting for subscriber to read messages")) 349 + } 350 + } 351 + 352 + func TestSendsDataToTopicSubscriberNacksThenAcks(t *testing.T) { 353 + _ = createServer(t) 354 + 355 + subscriberConn := createConnectionAndSubscribe(t, []string{topicA, topicB}, Current, 0) 356 + 357 + publisherConn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr)) 358 + require.NoError(t, err) 359 + 360 + err = binary.Write(publisherConn, binary.BigEndian, Publish) 361 + require.NoError(t, err) 362 + 363 + topic := fmt.Sprintf("topic:%s", topicA) 364 + messageData := "hello world" 365 + 366 + sendMessage(t, publisherConn, topic, []byte(messageData)) 367 + 368 + // check the subsribers got the data 369 + readMessage := func(conn net.Conn, ack Action) { 370 + var topicLen uint16 371 + err = binary.Read(conn, binary.BigEndian, &topicLen) 372 + require.NoError(t, err) 373 + 374 + topicBuf := make([]byte, topicLen) 375 + _, err = conn.Read(topicBuf) 376 + require.NoError(t, err) 377 + assert.Equal(t, topicA, string(topicBuf)) 378 + 379 + var dataLen uint64 380 + err = binary.Read(conn, binary.BigEndian, &dataLen) 381 + require.NoError(t, err) 382 + 383 + buf := make([]byte, dataLen) 384 + n, err := conn.Read(buf) 385 + require.NoError(t, err) 386 + 387 + require.Equal(t, int(dataLen), n) 388 + 389 + assert.Equal(t, messageData, string(buf)) 390 + 391 + err = binary.Write(conn, binary.BigEndian, ack) 392 + require.NoError(t, err) 393 + } 394 + 395 + // NACK the message and then ack it 396 + readMessage(subscriberConn, Nack) 397 + readMessage(subscriberConn, Ack) 398 + // reading for another message should now timeout but give enough time for the ack delay to kick in 399 + // should the second read of the message not have been ack'd properly 400 + var topicLen uint16 401 + _ = subscriberConn.SetReadDeadline(time.Now().Add(ackDelay + time.Millisecond*100)) 402 + err = binary.Read(subscriberConn, binary.BigEndian, &topicLen) 403 + require.Error(t, err) 404 + } 405 + 406 + func TestSendsDataToTopicSubscriberDoesntAckMessage(t *testing.T) { 407 + _ = createServer(t) 408 + 409 + subscriberConn := createConnectionAndSubscribe(t, []string{topicA, topicB}, Current, 0) 410 + 411 + publisherConn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr)) 412 + require.NoError(t, err) 413 + 414 + err = binary.Write(publisherConn, binary.BigEndian, Publish) 415 + require.NoError(t, err) 416 + 417 + topic := fmt.Sprintf("topic:%s", topicA) 418 + messageData := "hello world" 419 + 420 + sendMessage(t, publisherConn, topic, []byte(messageData)) 421 + 422 + // check the subsribers got the data 423 + readMessage := func(conn net.Conn, ack bool) { 424 + var topicLen uint16 425 + err = binary.Read(conn, binary.BigEndian, &topicLen) 426 + require.NoError(t, err) 427 + 428 + topicBuf := make([]byte, topicLen) 429 + _, err = conn.Read(topicBuf) 430 + require.NoError(t, err) 431 + assert.Equal(t, topicA, string(topicBuf)) 432 + 433 + var dataLen uint64 434 + err = binary.Read(conn, binary.BigEndian, &dataLen) 435 + require.NoError(t, err) 436 + 437 + buf := make([]byte, dataLen) 438 + n, err := conn.Read(buf) 439 + require.NoError(t, err) 440 + 441 + require.Equal(t, int(dataLen), n) 442 + 443 + assert.Equal(t, messageData, string(buf)) 444 + 445 + if ack { 446 + err = binary.Write(conn, binary.BigEndian, Ack) 447 + require.NoError(t, err) 448 + return 449 + } 450 + } 451 + 452 + // don't send ack or 
nack and then ack on the second attempt 453 + readMessage(subscriberConn, false) 454 + readMessage(subscriberConn, true) 455 + 456 + // reading for another message should now timeout but give enough time for the ack delay to kick in 457 + // should the second read of the message not have been ack'd properly 458 + var topicLen uint16 459 + _ = subscriberConn.SetReadDeadline(time.Now().Add(ackDelay + time.Millisecond*100)) 460 + err = binary.Read(subscriberConn, binary.BigEndian, &topicLen) 461 + require.Error(t, err) 462 + } 463 + 464 + func TestSendsDataToTopicSubscriberDeliveryCountTooHighWithNoAck(t *testing.T) { 465 + _ = createServer(t) 466 + 467 + subscriberConn := createConnectionAndSubscribe(t, []string{topicA, topicB}, Current, 0) 468 + 469 + publisherConn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr)) 470 + require.NoError(t, err) 471 + 472 + err = binary.Write(publisherConn, binary.BigEndian, Publish) 473 + require.NoError(t, err) 474 + 475 + topic := fmt.Sprintf("topic:%s", topicA) 476 + messageData := "hello world" 477 + 478 + sendMessage(t, publisherConn, topic, []byte(messageData)) 479 + 480 + // check the subsribers got the data 481 + readMessage := func(conn net.Conn, ack bool) { 482 + var topicLen uint16 483 + err = binary.Read(conn, binary.BigEndian, &topicLen) 484 + require.NoError(t, err) 485 + 486 + topicBuf := make([]byte, topicLen) 487 + _, err = conn.Read(topicBuf) 488 + require.NoError(t, err) 489 + assert.Equal(t, topicA, string(topicBuf)) 490 + 491 + var dataLen uint64 492 + err = binary.Read(conn, binary.BigEndian, &dataLen) 493 + require.NoError(t, err) 494 + 495 + buf := make([]byte, dataLen) 496 + n, err := conn.Read(buf) 497 + require.NoError(t, err) 498 + 499 + require.Equal(t, int(dataLen), n) 500 + 501 + assert.Equal(t, messageData, string(buf)) 502 + 503 + if ack { 504 + err = binary.Write(conn, binary.BigEndian, Ack) 505 + require.NoError(t, err) 506 + return 507 + } 508 + } 509 + 510 + // nack the message 5 times 511 + readMessage(subscriberConn, false) 512 + readMessage(subscriberConn, false) 513 + readMessage(subscriberConn, false) 514 + readMessage(subscriberConn, false) 515 + readMessage(subscriberConn, false) 516 + 517 + // reading for the message should now timeout as we have nack'd the message too many times 518 + var topicLen uint16 519 + _ = subscriberConn.SetReadDeadline(time.Now().Add(ackDelay + time.Millisecond*100)) 520 + err = binary.Read(subscriberConn, binary.BigEndian, &topicLen) 521 + require.Error(t, err) 522 + } 523 + 524 + func TestSubscribeAndReplaysFromStart(t *testing.T) { 525 + _ = createServer(t) 526 + 527 + publisherConn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr)) 528 + require.NoError(t, err) 529 + 530 + err = binary.Write(publisherConn, binary.BigEndian, Publish) 531 + require.NoError(t, err) 532 + 533 + messages := make([]string, 0, 10) 534 + for i := 0; i < 10; i++ { 535 + messages = append(messages, fmt.Sprintf("message %d", i)) 536 + } 537 + 538 + topic := fmt.Sprintf("topic:%s", topicA) 539 + 540 + for _, msg := range messages { 541 + sendMessage(t, publisherConn, topic, []byte(msg)) 542 + } 543 + 544 + // send some messages for topic B as well 545 + sendMessage(t, publisherConn, fmt.Sprintf("topic:%s", topicB), []byte("topic b message 1")) 546 + sendMessage(t, publisherConn, fmt.Sprintf("topic:%s", topicB), []byte("topic b message 2")) 547 + sendMessage(t, publisherConn, fmt.Sprintf("topic:%s", topicB), []byte("topic b message 3")) 548 + 549 + subscriberConn := 
createConnectionAndSubscribe(t, []string{topicA}, From, 0) 550 + results := make([]string, 0, len(messages)) 551 + for i := 0; i < len(messages); i++ { 552 + msg := readMessage(t, subscriberConn) 553 + results = append(results, string(msg)) 554 + } 555 + assert.ElementsMatch(t, results, messages) 556 + } 557 + 558 + func TestSubscribeAndReplaysFromIndex(t *testing.T) { 559 + _ = createServer(t) 560 + 561 + publisherConn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr)) 562 + require.NoError(t, err) 563 + 564 + err = binary.Write(publisherConn, binary.BigEndian, Publish) 565 + require.NoError(t, err) 566 + 567 + messages := make([]string, 0, 10) 568 + for i := 0; i < 10; i++ { 569 + messages = append(messages, fmt.Sprintf("message %d", i)) 570 + } 571 + 572 + topic := fmt.Sprintf("topic:%s", topicA) 573 + 574 + // send multiple messages 575 + for _, msg := range messages { 576 + sendMessage(t, publisherConn, topic, []byte(msg)) 577 + } 578 + 579 + // send some messages for topic B as well 580 + sendMessage(t, publisherConn, fmt.Sprintf("topic:%s", topicB), []byte("topic b message 1")) 581 + sendMessage(t, publisherConn, fmt.Sprintf("topic:%s", topicB), []byte("topic b message 2")) 582 + sendMessage(t, publisherConn, fmt.Sprintf("topic:%s", topicB), []byte("topic b message 3")) 583 + 584 + subscriberConn := createConnectionAndSubscribe(t, []string{topicA}, From, 3) 585 + 586 + // now that the subscriber has subscribed, send another message that should arrive after all the other messages were consumed 587 + sendMessage(t, publisherConn, topic, []byte("hello there")) 588 + 589 + results := make([]string, 0, len(messages)) 590 + for i := 0; i < len(messages)-3; i++ { 591 + msg := readMessage(t, subscriberConn) 592 + results = append(results, string(msg)) 593 + } 594 + require.Len(t, results, 7) 595 + expMessages := make([]string, 0, 7) 596 + for i, msg := range messages { 597 + if i < 3 { 598 + continue 599 + } 600 + expMessages = append(expMessages, msg) 601 + } 602 + assert.Equal(t, expMessages, results) 603 + 604 + // now check we can get the message that was sent after the subscription was created 605 + msg := readMessage(t, subscriberConn) 606 + assert.Equal(t, "hello there", string(msg)) 607 + } 608 + 609 + func readMessage(t *testing.T, subscriberConn net.Conn) []byte { 610 + var topicLen uint16 611 + err := binary.Read(subscriberConn, binary.BigEndian, &topicLen) 612 + require.NoError(t, err) 613 + 614 + topicBuf := make([]byte, topicLen) 615 + _, err = subscriberConn.Read(topicBuf) 616 + require.NoError(t, err) 617 + assert.Equal(t, topicA, string(topicBuf)) 618 + 619 + var dataLen uint64 620 + err = binary.Read(subscriberConn, binary.BigEndian, &dataLen) 621 + require.NoError(t, err) 622 + 623 + buf := make([]byte, dataLen) 624 + n, err := subscriberConn.Read(buf) 625 + require.NoError(t, err) 626 + require.Equal(t, int(dataLen), n) 627 + 628 + err = binary.Write(subscriberConn, binary.BigEndian, Ack) 629 + require.NoError(t, err) 630 + 631 + return buf 632 + }
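The nack/redelivery tests above drive the server purely through the single Action value a subscriber writes back after each frame. As a rough sketch only (assuming it sits in the same server package that defines the Ack and Nack constants; the helper name is illustrative and not code from this diff), the decision a consumer makes could look like:

```go
package server

import (
	"encoding/binary"
	"net"
)

// handleDelivery is an illustrative helper (not part of this diff): it runs
// the caller's processing function over a delivered payload and writes Ack
// back on success or Nack on failure, which the server uses to decide
// whether to redeliver the message after its ack delay.
func handleDelivery(conn net.Conn, data []byte, process func([]byte) error) error {
	response := Ack
	if err := process(data); err != nil {
		// a Nack asks the server to hold on to the message and try again later
		response = Nack
	}
	return binary.Write(conn, binary.BigEndian, response)
}
```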
+136
internal/server/subscriber.go
··· 1 + package server 2 + 3 + import ( 4 + "encoding/binary" 5 + "fmt" 6 + "log/slog" 7 + "net" 8 + "time" 9 + 10 + "github.com/willdot/messagebroker/internal" 11 + ) 12 + 13 + type subscriber struct { 14 + peer *Peer 15 + topic string 16 + messages chan internal.Message 17 + unsubscribeCh chan struct{} 18 + 19 + ackDelay time.Duration 20 + ackTimeout time.Duration 21 + } 22 + 23 + func newSubscriber(peer *Peer, topic *topic, ackDelay, ackTimeout time.Duration, startAt int) *subscriber { 24 + s := &subscriber{ 25 + peer: peer, 26 + topic: topic.name, 27 + messages: make(chan internal.Message), 28 + ackDelay: ackDelay, 29 + ackTimeout: ackTimeout, 30 + unsubscribeCh: make(chan struct{}, 1), 31 + } 32 + 33 + go s.sendMessages() 34 + 35 + go func() { 36 + topic.messageStore.ReadFrom(startAt, func(msg internal.Message) { 37 + select { 38 + case s.messages <- msg: 39 + return 40 + case <-s.unsubscribeCh: 41 + return 42 + } 43 + }) 44 + }() 45 + 46 + return s 47 + } 48 + 49 + func (s *subscriber) sendMessages() { 50 + for { 51 + select { 52 + case <-s.unsubscribeCh: 53 + return 54 + case msg := <-s.messages: 55 + ack, err := s.sendMessage(s.topic, msg) 56 + if err != nil { 57 + slog.Error("failed to send to message", "error", err, "peer", s.peer.Addr()) 58 + } 59 + 60 + if ack { 61 + continue 62 + } 63 + 64 + if msg.DeliveryCount >= 5 { 65 + slog.Error("max delivery count for message. Dropping", "peer", s.peer.Addr()) 66 + continue 67 + } 68 + 69 + msg.DeliveryCount++ 70 + s.addMessage(msg, s.ackDelay) 71 + } 72 + } 73 + } 74 + 75 + func (s *subscriber) addMessage(msg internal.Message, delay time.Duration) { 76 + go func() { 77 + timer := time.NewTimer(delay) 78 + defer timer.Stop() 79 + 80 + select { 81 + case <-s.unsubscribeCh: 82 + return 83 + case <-timer.C: 84 + s.messages <- msg 85 + } 86 + }() 87 + } 88 + 89 + func (s *subscriber) sendMessage(topic string, msg internal.Message) (bool, error) { 90 + var ack bool 91 + op := func(conn net.Conn) error { 92 + topicB := make([]byte, 2) 93 + binary.BigEndian.PutUint16(topicB, uint16(len(topic))) 94 + 95 + headers := topicB 96 + headers = append(headers, []byte(topic)...) 97 + 98 + // TODO: if message is empty, return error? 99 + dataLenB := make([]byte, 8) 100 + binary.BigEndian.PutUint64(dataLenB, uint64(len(msg.Data))) 101 + headers = append(headers, dataLenB...) 102 + 103 + _, err := conn.Write(append(headers, msg.Data...)) 104 + if err != nil { 105 + return fmt.Errorf("failed to write to peer: %w", err) 106 + } 107 + 108 + if err := conn.SetReadDeadline(time.Now().Add(s.ackTimeout)); err != nil { 109 + slog.Error("failed to set connection read deadline", "error", err, "peer", s.peer.Addr()) 110 + } 111 + defer func() { 112 + if err := conn.SetReadDeadline(time.Time{}); err != nil { 113 + slog.Error("failed to reset connection read deadline", "error", err, "peer", s.peer.Addr()) 114 + } 115 + }() 116 + var ackRes Action 117 + err = binary.Read(conn, binary.BigEndian, &ackRes) 118 + if err != nil { 119 + return fmt.Errorf("failed to read ack from peer: %w", err) 120 + } 121 + 122 + if ackRes == Ack { 123 + ack = true 124 + } 125 + 126 + return nil 127 + } 128 + 129 + err := s.peer.RunConnOperation(op) 130 + 131 + return ack, err 132 + } 133 + 134 + func (s *subscriber) unsubscribe() { 135 + close(s.unsubscribeCh) 136 + }
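sendMessage above fixes the per-delivery frame: a big-endian uint16 topic length, the topic bytes, a big-endian uint64 payload length, the payload, and then a single Action read back as the acknowledgement. A minimal sketch of the matching read on the receiving side (the helper name and the use of io.ReadFull are assumptions for illustration, not this repo's client code):

```go
package server

import (
	"encoding/binary"
	"io"
	"net"
)

// readDelivery is an illustrative sketch of parsing one delivered message
// frame as written by subscriber.sendMessage, then acknowledging it.
func readDelivery(conn net.Conn) (topic string, data []byte, err error) {
	// uint16 topic length, followed by the topic bytes
	var topicLen uint16
	if err = binary.Read(conn, binary.BigEndian, &topicLen); err != nil {
		return "", nil, err
	}
	topicBuf := make([]byte, topicLen)
	if _, err = io.ReadFull(conn, topicBuf); err != nil {
		return "", nil, err
	}

	// uint64 payload length, followed by the payload bytes
	var dataLen uint64
	if err = binary.Read(conn, binary.BigEndian, &dataLen); err != nil {
		return "", nil, err
	}
	data = make([]byte, dataLen)
	if _, err = io.ReadFull(conn, data); err != nil {
		return "", nil, err
	}

	// tell the server the message was handled so it is not redelivered
	if err = binary.Write(conn, binary.BigEndian, Ack); err != nil {
		return "", nil, err
	}

	return string(topicBuf), data, nil
}
```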
+62
internal/server/topic.go
··· 1 + package server 2 + 3 + import ( 4 + "fmt" 5 + "net" 6 + "sync" 7 + 8 + "github.com/willdot/messagebroker/internal" 9 + "github.com/willdot/messagebroker/internal/messagestore" 10 + ) 11 + 12 + type Store interface { 13 + Write(msg internal.Message) error 14 + ReadFrom(offset int, handleFunc func(msg internal.Message)) 15 + } 16 + 17 + type topic struct { 18 + name string 19 + subscriptions map[net.Addr]*subscriber 20 + mu sync.Mutex 21 + messageStore Store 22 + } 23 + 24 + func newTopic(name string) *topic { 25 + messageStore := messagestore.NewMemoryStore() 26 + return &topic{ 27 + name: name, 28 + subscriptions: make(map[net.Addr]*subscriber), 29 + messageStore: messageStore, 30 + } 31 + } 32 + 33 + func (t *topic) sendMessageToSubscribers(msg internal.Message) error { 34 + err := t.messageStore.Write(msg) 35 + if err != nil { 36 + return fmt.Errorf("failed to write message to store: %w", err) 37 + } 38 + 39 + t.mu.Lock() 40 + subscribers := t.subscriptions 41 + t.mu.Unlock() 42 + 43 + for _, subscriber := range subscribers { 44 + subscriber.addMessage(msg, 0) 45 + } 46 + 47 + return nil 48 + } 49 + 50 + func (t *topic) findSubscription(addr net.Addr) *subscriber { 51 + t.mu.Lock() 52 + defer t.mu.Unlock() 53 + 54 + return t.subscriptions[addr] 55 + } 56 + 57 + func (t *topic) removeSubscription(addr net.Addr) { 58 + t.mu.Lock() 59 + defer t.mu.Unlock() 60 + 61 + delete(t.subscriptions, addr) 62 + }
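The topic hands persistence and replay off to the Store interface: Write appends a message and ReadFrom replays stored messages from an offset, which is what lets newSubscriber start a subscriber from Current or from an earlier index. The concrete messagestore.MemoryStore is not shown in this diff, so the following is only a rough sketch of the shape such an implementation could take; the sliceStore and newSliceStore names are placeholders, not the repo's actual code.

```go
package messagestore

import (
	"sync"

	"github.com/willdot/messagebroker/internal"
)

// sliceStore is an illustrative, slice-backed store satisfying the server's
// Store interface; the real MemoryStore in this repo may differ.
type sliceStore struct {
	mu       sync.Mutex
	messages []internal.Message
}

func newSliceStore() *sliceStore {
	return &sliceStore{}
}

// Write appends the message to the in-memory log.
func (s *sliceStore) Write(msg internal.Message) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.messages = append(s.messages, msg)
	return nil
}

// ReadFrom replays every stored message from the given offset, in order,
// through handleFunc. An offset beyond the end of the log replays nothing.
func (s *sliceStore) ReadFrom(offset int, handleFunc func(msg internal.Message)) {
	s.mu.Lock()
	msgs := make([]internal.Message, len(s.messages))
	copy(msgs, s.messages)
	s.mu.Unlock()

	if offset < 0 {
		offset = 0
	}
	for i := offset; i < len(msgs); i++ {
		handleFunc(msgs[i])
	}
}
```

Copying the slice under the lock in this sketch keeps a slow replay to one subscriber from blocking concurrent writes from publishers.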
-7
pubsub/message.go
··· 1 - package pubsub 2 - 3 - // Message represents a message that can be published or consumed 4 - type Message struct { 5 - Topic string `json:"topic"` 6 - Data []byte `json:"data"` 7 - }
-76
pubsub/publisher.go
··· 1 - package pubsub 2 - 3 - import ( 4 - "encoding/binary" 5 - "fmt" 6 - "net" 7 - "sync" 8 - 9 - "github.com/willdot/messagebroker/server" 10 - ) 11 - 12 - // Publisher allows messages to be published to a server 13 - type Publisher struct { 14 - conn net.Conn 15 - connMu sync.Mutex 16 - } 17 - 18 - // NewPublisher connects to the server at the given address and registers as a publisher 19 - func NewPublisher(addr string) (*Publisher, error) { 20 - conn, err := net.Dial("tcp", addr) 21 - if err != nil { 22 - return nil, fmt.Errorf("failed to dial: %w", err) 23 - } 24 - 25 - err = binary.Write(conn, binary.BigEndian, server.Publish) 26 - if err != nil { 27 - conn.Close() 28 - return nil, fmt.Errorf("failed to register publish to server: %w", err) 29 - } 30 - 31 - return &Publisher{ 32 - conn: conn, 33 - }, nil 34 - } 35 - 36 - // Close cleanly shuts down the publisher 37 - func (p *Publisher) Close() error { 38 - return p.conn.Close() 39 - } 40 - 41 - // Publish will publish the given message to the server 42 - func (p *Publisher) PublishMessage(message Message) error { 43 - op := func(conn net.Conn) error { 44 - // send topic first 45 - topic := fmt.Sprintf("topic:%s", message.Topic) 46 - err := binary.Write(p.conn, binary.BigEndian, uint32(len(topic))) 47 - if err != nil { 48 - return fmt.Errorf("failed to write topic size to server") 49 - } 50 - 51 - _, err = p.conn.Write([]byte(topic)) 52 - if err != nil { 53 - return fmt.Errorf("failed to write topic to server") 54 - } 55 - 56 - err = binary.Write(p.conn, binary.BigEndian, uint32(len(message.Data))) 57 - if err != nil { 58 - return fmt.Errorf("failed to write message size to server") 59 - } 60 - 61 - _, err = p.conn.Write(message.Data) 62 - if err != nil { 63 - return fmt.Errorf("failed to publish data to server") 64 - } 65 - return nil 66 - } 67 - 68 - return p.connOperation(op) 69 - } 70 - 71 - func (p *Publisher) connOperation(op connOpp) error { 72 - p.connMu.Lock() 73 - defer p.connMu.Unlock() 74 - 75 - return op(p.conn) 76 - }
-254
pubsub/subscriber.go
··· 1 - package pubsub 2 - 3 - import ( 4 - "context" 5 - "encoding/binary" 6 - "encoding/json" 7 - "errors" 8 - "fmt" 9 - "net" 10 - "sync" 11 - "time" 12 - 13 - "github.com/willdot/messagebroker/server" 14 - ) 15 - 16 - type connOpp func(conn net.Conn) error 17 - 18 - // Subscriber allows subscriptions to a server and the consumption of messages 19 - type Subscriber struct { 20 - conn net.Conn 21 - connMu sync.Mutex 22 - } 23 - 24 - // NewSubscriber will connect to the server at the given address 25 - func NewSubscriber(addr string) (*Subscriber, error) { 26 - conn, err := net.Dial("tcp", addr) 27 - if err != nil { 28 - return nil, fmt.Errorf("failed to dial: %w", err) 29 - } 30 - 31 - return &Subscriber{ 32 - conn: conn, 33 - }, nil 34 - } 35 - 36 - // Close cleanly shuts down the subscriber 37 - func (s *Subscriber) Close() error { 38 - return s.conn.Close() 39 - } 40 - 41 - // SubscribeToTopics will subscribe to the provided topics 42 - func (s *Subscriber) SubscribeToTopics(topicNames []string) error { 43 - op := func(conn net.Conn) error { 44 - err := binary.Write(conn, binary.BigEndian, server.Subscribe) 45 - if err != nil { 46 - return fmt.Errorf("failed to subscribe: %w", err) 47 - } 48 - 49 - b, err := json.Marshal(topicNames) 50 - if err != nil { 51 - return fmt.Errorf("failed to marshal topic names: %w", err) 52 - } 53 - 54 - err = binary.Write(conn, binary.BigEndian, uint32(len(b))) 55 - if err != nil { 56 - return fmt.Errorf("failed to write topic data length: %w", err) 57 - } 58 - 59 - _, err = conn.Write(b) 60 - if err != nil { 61 - return fmt.Errorf("failed to subscribe to topics: %w", err) 62 - } 63 - 64 - var resp server.Status 65 - err = binary.Read(conn, binary.BigEndian, &resp) 66 - if err != nil { 67 - return fmt.Errorf("failed to read confirmation of subscription: %w", err) 68 - } 69 - 70 - if resp == server.Subscribed { 71 - return nil 72 - } 73 - 74 - var dataLen uint32 75 - err = binary.Read(conn, binary.BigEndian, &dataLen) 76 - if err != nil { 77 - return fmt.Errorf("received status %s:", resp) 78 - } 79 - 80 - buf := make([]byte, dataLen) 81 - _, err = conn.Read(buf) 82 - if err != nil { 83 - return fmt.Errorf("received status %s:", resp) 84 - } 85 - 86 - return fmt.Errorf("received status %s - %s", resp, buf) 87 - } 88 - 89 - return s.connOperation(op) 90 - } 91 - 92 - // UnsubscribeToTopics will unsubscribe to the provided topics 93 - func (s *Subscriber) UnsubscribeToTopics(topicNames []string) error { 94 - op := func(conn net.Conn) error { 95 - err := binary.Write(conn, binary.BigEndian, server.Unsubscribe) 96 - if err != nil { 97 - return fmt.Errorf("failed to unsubscribe: %w", err) 98 - } 99 - 100 - b, err := json.Marshal(topicNames) 101 - if err != nil { 102 - return fmt.Errorf("failed to marshal topic names: %w", err) 103 - } 104 - 105 - err = binary.Write(conn, binary.BigEndian, uint32(len(b))) 106 - if err != nil { 107 - return fmt.Errorf("failed to write topic data length: %w", err) 108 - } 109 - 110 - _, err = conn.Write(b) 111 - if err != nil { 112 - return fmt.Errorf("failed to unsubscribe to topics: %w", err) 113 - } 114 - 115 - var resp server.Status 116 - err = binary.Read(conn, binary.BigEndian, &resp) 117 - if err != nil { 118 - return fmt.Errorf("failed to read confirmation of unsubscription: %w", err) 119 - } 120 - 121 - if resp == server.Unsubscribed { 122 - return nil 123 - } 124 - 125 - var dataLen uint32 126 - err = binary.Read(conn, binary.BigEndian, &dataLen) 127 - if err != nil { 128 - return fmt.Errorf("received status %s:", resp) 
129 - } 130 - 131 - buf := make([]byte, dataLen) 132 - _, err = conn.Read(buf) 133 - if err != nil { 134 - return fmt.Errorf("received status %s:", resp) 135 - } 136 - 137 - return fmt.Errorf("received status %s - %s", resp, buf) 138 - } 139 - 140 - return s.connOperation(op) 141 - } 142 - 143 - // Consumer allows the consumption of messages. If during the consumer receiving messages from the 144 - // server an error occurs, it will be stored in Err 145 - type Consumer struct { 146 - msgs chan Message 147 - // TODO: better error handling? Maybe a channel of errors? 148 - Err error 149 - } 150 - 151 - // Messages returns a channel in which this consumer will put messages onto. It is safe to range over the channel since it will be closed once 152 - // the consumer has finished either due to an error or from being cancelled. 153 - func (c *Consumer) Messages() <-chan Message { 154 - return c.msgs 155 - } 156 - 157 - // Consume will create a consumer and start it running in a go routine. You can then use the Msgs channel of the consumer 158 - // to read the messages 159 - func (s *Subscriber) Consume(ctx context.Context) *Consumer { 160 - consumer := &Consumer{ 161 - msgs: make(chan Message), 162 - } 163 - 164 - go s.consume(ctx, consumer) 165 - 166 - return consumer 167 - } 168 - 169 - func (s *Subscriber) consume(ctx context.Context, consumer *Consumer) { 170 - defer close(consumer.msgs) 171 - for { 172 - if ctx.Err() != nil { 173 - return 174 - } 175 - 176 - msg, err := s.readMessage() 177 - if err != nil { 178 - consumer.Err = err 179 - return 180 - } 181 - 182 - if msg != nil { 183 - consumer.msgs <- *msg 184 - } 185 - } 186 - } 187 - 188 - func (s *Subscriber) readMessage() (*Message, error) { 189 - var msg *Message 190 - op := func(conn net.Conn) error { 191 - err := s.conn.SetReadDeadline(time.Now().Add(time.Second)) 192 - if err != nil { 193 - return err 194 - } 195 - 196 - var topicLen uint64 197 - err = binary.Read(s.conn, binary.BigEndian, &topicLen) 198 - if err != nil { 199 - // TODO: check if this is needed elsewhere. I'm not sure where the read deadline resets.... 200 - if neterr, ok := err.(net.Error); ok && neterr.Timeout() { 201 - return nil 202 - } 203 - return err 204 - } 205 - 206 - topicBuf := make([]byte, topicLen) 207 - _, err = s.conn.Read(topicBuf) 208 - if err != nil { 209 - return err 210 - } 211 - 212 - var dataLen uint64 213 - err = binary.Read(s.conn, binary.BigEndian, &dataLen) 214 - if err != nil { 215 - return err 216 - } 217 - 218 - if dataLen <= 0 { 219 - return nil 220 - } 221 - 222 - dataBuf := make([]byte, dataLen) 223 - _, err = s.conn.Read(dataBuf) 224 - if err != nil { 225 - return err 226 - } 227 - 228 - msg = &Message{ 229 - Data: dataBuf, 230 - Topic: string(topicBuf), 231 - } 232 - 233 - return nil 234 - 235 - } 236 - 237 - err := s.connOperation(op) 238 - if err != nil { 239 - var neterr net.Error 240 - if errors.As(err, &neterr) && neterr.Timeout() { 241 - return nil, nil 242 - } 243 - return nil, err 244 - } 245 - 246 - return msg, err 247 - } 248 - 249 - func (s *Subscriber) connOperation(op connOpp) error { 250 - s.connMu.Lock() 251 - defer s.connMu.Unlock() 252 - 253 - return op(s.conn) 254 - }
-220
pubsub/subscriber_test.go
··· 1 - package pubsub 2 - 3 - import ( 4 - "context" 5 - "fmt" 6 - "testing" 7 - "time" 8 - 9 - "github.com/stretchr/testify/assert" 10 - "github.com/stretchr/testify/require" 11 - 12 - "github.com/willdot/messagebroker/server" 13 - ) 14 - 15 - const ( 16 - serverAddr = ":9999" 17 - topicA = "topic a" 18 - topicB = "topic b" 19 - ) 20 - 21 - func createServer(t *testing.T) { 22 - server, err := server.New(serverAddr) 23 - require.NoError(t, err) 24 - 25 - t.Cleanup(func() { 26 - _ = server.Shutdown() 27 - }) 28 - } 29 - 30 - func TestNewSubscriber(t *testing.T) { 31 - createServer(t) 32 - 33 - sub, err := NewSubscriber(serverAddr) 34 - require.NoError(t, err) 35 - 36 - t.Cleanup(func() { 37 - sub.Close() 38 - }) 39 - } 40 - 41 - func TestNewSubscriberInvalidServerAddr(t *testing.T) { 42 - createServer(t) 43 - 44 - _, err := NewSubscriber(":123456") 45 - require.Error(t, err) 46 - } 47 - 48 - func TestNewPublisher(t *testing.T) { 49 - createServer(t) 50 - 51 - sub, err := NewPublisher(serverAddr) 52 - require.NoError(t, err) 53 - 54 - t.Cleanup(func() { 55 - sub.Close() 56 - }) 57 - } 58 - 59 - func TestNewPublisherInvalidServerAddr(t *testing.T) { 60 - createServer(t) 61 - 62 - _, err := NewPublisher(":123456") 63 - require.Error(t, err) 64 - } 65 - 66 - func TestSubscribeToTopics(t *testing.T) { 67 - createServer(t) 68 - 69 - sub, err := NewSubscriber(serverAddr) 70 - require.NoError(t, err) 71 - 72 - t.Cleanup(func() { 73 - sub.Close() 74 - }) 75 - 76 - topics := []string{topicA, topicB} 77 - 78 - err = sub.SubscribeToTopics(topics) 79 - require.NoError(t, err) 80 - } 81 - 82 - func TestUnsubscribesFromTopic(t *testing.T) { 83 - createServer(t) 84 - 85 - sub, err := NewSubscriber(serverAddr) 86 - require.NoError(t, err) 87 - 88 - t.Cleanup(func() { 89 - sub.Close() 90 - }) 91 - 92 - topics := []string{topicA, topicB} 93 - 94 - err = sub.SubscribeToTopics(topics) 95 - require.NoError(t, err) 96 - 97 - err = sub.UnsubscribeToTopics([]string{topicA}) 98 - require.NoError(t, err) 99 - 100 - ctx, cancel := context.WithCancel(context.Background()) 101 - t.Cleanup(func() { 102 - cancel() 103 - }) 104 - 105 - consumer := sub.Consume(ctx) 106 - require.NoError(t, err) 107 - 108 - var receivedMessages []Message 109 - consumerFinCh := make(chan struct{}) 110 - go func() { 111 - for msg := range consumer.Messages() { 112 - receivedMessages = append(receivedMessages, msg) 113 - } 114 - 115 - require.NoError(t, err) 116 - consumerFinCh <- struct{}{} 117 - }() 118 - 119 - // publish a message to both topics and check the subscriber only gets the message from the 1 topic 120 - // and not the unsubscribed topic 121 - 122 - publisher, err := NewPublisher("localhost:9999") 123 - require.NoError(t, err) 124 - t.Cleanup(func() { 125 - publisher.Close() 126 - }) 127 - 128 - msg := Message{ 129 - Topic: topicA, 130 - Data: []byte("hello world"), 131 - } 132 - 133 - err = publisher.PublishMessage(msg) 134 - require.NoError(t, err) 135 - 136 - msg.Topic = topicB 137 - err = publisher.PublishMessage(msg) 138 - require.NoError(t, err) 139 - 140 - cancel() 141 - 142 - select { 143 - case <-consumerFinCh: 144 - break 145 - case <-time.After(time.Second): 146 - t.Fatal("timed out waiting for consumer to read messages") 147 - } 148 - 149 - assert.Len(t, receivedMessages, 1) 150 - assert.Equal(t, topicB, receivedMessages[0].Topic) 151 - } 152 - 153 - func TestPublishAndSubscribe(t *testing.T) { 154 - createServer(t) 155 - 156 - sub, err := NewSubscriber(serverAddr) 157 - require.NoError(t, err) 158 - 159 - 
t.Cleanup(func() { 160 - sub.Close() 161 - }) 162 - 163 - topics := []string{topicA, topicB} 164 - 165 - err = sub.SubscribeToTopics(topics) 166 - require.NoError(t, err) 167 - 168 - ctx, cancel := context.WithCancel(context.Background()) 169 - t.Cleanup(func() { 170 - cancel() 171 - }) 172 - 173 - consumer := sub.Consume(ctx) 174 - require.NoError(t, err) 175 - 176 - var receivedMessages []Message 177 - 178 - consumerFinCh := make(chan struct{}) 179 - go func() { 180 - for msg := range consumer.Messages() { 181 - receivedMessages = append(receivedMessages, msg) 182 - } 183 - 184 - require.NoError(t, err) 185 - consumerFinCh <- struct{}{} 186 - }() 187 - 188 - publisher, err := NewPublisher("localhost:9999") 189 - require.NoError(t, err) 190 - t.Cleanup(func() { 191 - publisher.Close() 192 - }) 193 - 194 - // send some messages 195 - sentMessages := make([]Message, 0, 10) 196 - for i := 0; i < 10; i++ { 197 - msg := Message{ 198 - Topic: topicA, 199 - Data: []byte(fmt.Sprintf("message %d", i)), 200 - } 201 - 202 - sentMessages = append(sentMessages, msg) 203 - 204 - err = publisher.PublishMessage(msg) 205 - require.NoError(t, err) 206 - } 207 - 208 - // give the consumer some time to read the messages -- TODO: make better! 209 - time.Sleep(time.Millisecond * 500) 210 - cancel() 211 - 212 - select { 213 - case <-consumerFinCh: 214 - break 215 - case <-time.After(time.Second): 216 - t.Fatal("timed out waiting for consumer to read messages") 217 - } 218 - 219 - assert.ElementsMatch(t, receivedMessages, sentMessages) 220 - }
-36
server/peer/peer.go
··· 1 - package peer 2 - 3 - import ( 4 - "net" 5 - "sync" 6 - ) 7 - 8 - // Peer represents a remote connection to the server such as a publisher or subscriber. 9 - type Peer struct { 10 - conn net.Conn 11 - connMu sync.Mutex 12 - } 13 - 14 - // New returns a new peer. 15 - func New(conn net.Conn) *Peer { 16 - return &Peer{ 17 - conn: conn, 18 - } 19 - } 20 - 21 - // Addr returns the peers connections address. 22 - func (p *Peer) Addr() net.Addr { 23 - return p.conn.RemoteAddr() 24 - } 25 - 26 - // ConnOpp represents a set of actions on a connection that can be used synchrnously. 27 - type ConnOpp func(conn net.Conn) error 28 - 29 - // RunConnOperation will run the provided operation. It ensures that it is the only operation that is being 30 - // run on the connection to ensure any other operations don't get mixed up. 31 - func (p *Peer) RunConnOperation(op ConnOpp) error { 32 - p.connMu.Lock() 33 - defer p.connMu.Unlock() 34 - 35 - return op(p.conn) 36 - }
-443
server/server.go
··· 1 - package server 2 - 3 - import ( 4 - "encoding/binary" 5 - "encoding/json" 6 - "errors" 7 - "fmt" 8 - "io" 9 - "log/slog" 10 - "net" 11 - "strings" 12 - "sync" 13 - "syscall" 14 - "time" 15 - 16 - "github.com/willdot/messagebroker/server/peer" 17 - ) 18 - 19 - // Action represents the type of action that a peer requests to do 20 - type Action uint8 21 - 22 - const ( 23 - Subscribe Action = 1 24 - Unsubscribe Action = 2 25 - Publish Action = 3 26 - ) 27 - 28 - // Status represents the status of a request 29 - type Status uint8 30 - 31 - const ( 32 - Subscribed = 1 33 - Unsubscribed = 2 34 - Error = 3 35 - ) 36 - 37 - func (s Status) String() string { 38 - switch s { 39 - case Subscribed: 40 - return "subsribed" 41 - case Unsubscribed: 42 - return "unsubscribed" 43 - case Error: 44 - return "error" 45 - } 46 - 47 - return "" 48 - } 49 - 50 - // Server accepts subscribe and publish connections and passes messages around 51 - type Server struct { 52 - Addr string 53 - lis net.Listener 54 - 55 - mu sync.Mutex 56 - topics map[string]*topic 57 - } 58 - 59 - // New creates and starts a new server 60 - func New(Addr string) (*Server, error) { 61 - lis, err := net.Listen("tcp", Addr) 62 - if err != nil { 63 - return nil, fmt.Errorf("failed to listen: %w", err) 64 - } 65 - 66 - srv := &Server{ 67 - lis: lis, 68 - topics: map[string]*topic{}, 69 - } 70 - 71 - go srv.start() 72 - 73 - return srv, nil 74 - } 75 - 76 - // Shutdown will cleanly shutdown the server 77 - func (s *Server) Shutdown() error { 78 - return s.lis.Close() 79 - } 80 - 81 - func (s *Server) start() { 82 - for { 83 - conn, err := s.lis.Accept() 84 - if err != nil { 85 - if errors.Is(err, net.ErrClosed) { 86 - slog.Info("listener closed") 87 - return 88 - } 89 - slog.Error("listener failed to accept", "error", err) 90 - continue 91 - } 92 - 93 - go s.handleConn(conn) 94 - } 95 - } 96 - 97 - func (s *Server) handleConn(conn net.Conn) { 98 - peer := peer.New(conn) 99 - 100 - action, err := readAction(peer, 0) 101 - if err != nil { 102 - if !errors.Is(err, io.EOF) { 103 - slog.Error("failed to read action from peer", "error", err, "peer", peer.Addr()) 104 - } 105 - return 106 - } 107 - 108 - switch action { 109 - case Subscribe: 110 - s.handleSubscribe(peer) 111 - case Unsubscribe: 112 - s.handleUnsubscribe(peer) 113 - case Publish: 114 - s.handlePublish(peer) 115 - default: 116 - slog.Error("unknown action", "action", action, "peer", peer.Addr()) 117 - writeInvalidAction(peer) 118 - } 119 - } 120 - 121 - func (s *Server) handleSubscribe(peer *peer.Peer) { 122 - // subscribe the peer to the topic 123 - s.subscribePeerToTopic(peer) 124 - 125 - // keep handling the peers connection, getting the action from the peer when it wishes to do something else. 
126 - // once the peers connection ends, it will be unsubscribed from all topics and returned 127 - for { 128 - action, err := readAction(peer, time.Millisecond*100) 129 - if err != nil { 130 - var neterr net.Error 131 - if errors.As(err, &neterr) && neterr.Timeout() { 132 - time.Sleep(time.Second) 133 - continue 134 - } 135 - 136 - if !errors.Is(err, io.EOF) { 137 - slog.Error("failed to read action from subscriber", "error", err, "peer", peer.Addr()) 138 - } 139 - 140 - s.unsubscribePeerFromAllTopics(peer) 141 - 142 - return 143 - } 144 - 145 - switch action { 146 - case Subscribe: 147 - s.subscribePeerToTopic(peer) 148 - case Unsubscribe: 149 - s.handleUnsubscribe(peer) 150 - default: 151 - slog.Error("unknown action for subscriber", "action", action, "peer", peer.Addr()) 152 - writeInvalidAction(peer) 153 - continue 154 - } 155 - } 156 - } 157 - 158 - func (s *Server) subscribePeerToTopic(peer *peer.Peer) { 159 - op := func(conn net.Conn) error { 160 - // get the topics the peer wishes to subscribe to 161 - dataLen, err := dataLength(conn) 162 - if err != nil { 163 - slog.Error(err.Error(), "peer", peer.Addr()) 164 - writeStatus(Error, "invalid data length of topics provided", conn) 165 - return nil 166 - } 167 - if dataLen == 0 { 168 - writeStatus(Error, "data length of topics is 0", conn) 169 - return nil 170 - } 171 - 172 - buf := make([]byte, dataLen) 173 - _, err = conn.Read(buf) 174 - if err != nil { 175 - slog.Error("failed to read subscibers topic data", "error", err, "peer", peer.Addr()) 176 - writeStatus(Error, "failed to read topic data", conn) 177 - return nil 178 - } 179 - 180 - var topics []string 181 - err = json.Unmarshal(buf, &topics) 182 - if err != nil { 183 - slog.Error("failed to unmarshal subscibers topic data", "error", err, "peer", peer.Addr()) 184 - writeStatus(Error, "invalid topic data provided", conn) 185 - return nil 186 - } 187 - 188 - s.subscribeToTopics(peer, topics) 189 - writeStatus(Subscribed, "", conn) 190 - 191 - return nil 192 - } 193 - 194 - _ = peer.RunConnOperation(op) 195 - } 196 - 197 - func (s *Server) handleUnsubscribe(peer *peer.Peer) { 198 - op := func(conn net.Conn) error { 199 - // get the topics the peer wishes to unsubscribe from 200 - dataLen, err := dataLength(conn) 201 - if err != nil { 202 - slog.Error(err.Error(), "peer", peer.Addr()) 203 - writeStatus(Error, "invalid data length of topics provided", conn) 204 - return nil 205 - } 206 - if dataLen == 0 { 207 - writeStatus(Error, "data length of topics is 0", conn) 208 - return nil 209 - } 210 - 211 - buf := make([]byte, dataLen) 212 - _, err = conn.Read(buf) 213 - if err != nil { 214 - slog.Error("failed to read subscibers topic data", "error", err, "peer", peer.Addr()) 215 - writeStatus(Error, "failed to read topic data", conn) 216 - return nil 217 - } 218 - 219 - var topics []string 220 - err = json.Unmarshal(buf, &topics) 221 - if err != nil { 222 - slog.Error("failed to unmarshal subscibers topic data", "error", err, "peer", peer.Addr()) 223 - writeStatus(Error, "invalid topic data provided", conn) 224 - return nil 225 - } 226 - 227 - s.unsubscribeToTopics(peer, topics) 228 - writeStatus(Unsubscribed, "", conn) 229 - 230 - return nil 231 - } 232 - 233 - _ = peer.RunConnOperation(op) 234 - } 235 - 236 - type messageToSend struct { 237 - topic string 238 - data []byte 239 - } 240 - 241 - func (s *Server) handlePublish(peer *peer.Peer) { 242 - for { 243 - var message *messageToSend 244 - 245 - op := func(conn net.Conn) error { 246 - dataLen, err := dataLength(conn) 247 - if err 
!= nil { 248 - if errors.Is(err, io.EOF) { 249 - return nil 250 - } 251 - slog.Error("failed to read data length", "error", err, "peer", peer.Addr()) 252 - writeStatus(Error, "invalid data length of data provided", conn) 253 - return nil 254 - } 255 - if dataLen == 0 { 256 - return nil 257 - } 258 - topicBuf := make([]byte, dataLen) 259 - _, err = conn.Read(topicBuf) 260 - if err != nil { 261 - slog.Error("failed to read topic from peer", "error", err, "peer", peer.Addr()) 262 - writeStatus(Error, "failed to read topic", conn) 263 - return nil 264 - } 265 - 266 - topicStr := string(topicBuf) 267 - if !strings.HasPrefix(topicStr, "topic:") { 268 - slog.Error("topic data does not contain topic prefix", "peer", peer.Addr()) 269 - writeStatus(Error, "topic data does not contain 'topic:' prefix", conn) 270 - return nil 271 - } 272 - topicStr = strings.TrimPrefix(topicStr, "topic:") 273 - 274 - dataLen, err = dataLength(conn) 275 - if err != nil { 276 - slog.Error(err.Error(), "peer", peer.Addr()) 277 - writeStatus(Error, "invalid data length of data provided", conn) 278 - return nil 279 - } 280 - if dataLen == 0 { 281 - return nil 282 - } 283 - 284 - dataBuf := make([]byte, dataLen) 285 - _, err = conn.Read(dataBuf) 286 - if err != nil { 287 - slog.Error("failed to read data from peer", "error", err, "peer", peer.Addr()) 288 - writeStatus(Error, "failed to read data", conn) 289 - return nil 290 - } 291 - 292 - message = &messageToSend{ 293 - topic: topicStr, 294 - data: dataBuf, 295 - } 296 - return nil 297 - } 298 - 299 - _ = peer.RunConnOperation(op) 300 - 301 - if message == nil { 302 - continue 303 - } 304 - // TODO: this can be done in a go routine because once we've got the message from the publisher, the publisher 305 - // doesn't need to wait for us to send the message to all peers 306 - 307 - topic := s.getTopic(message.topic) 308 - if topic != nil { 309 - topic.sendMessageToSubscribers(message.data) 310 - } 311 - } 312 - } 313 - 314 - func (s *Server) subscribeToTopics(peer *peer.Peer, topics []string) { 315 - for _, topic := range topics { 316 - s.addSubsciberToTopic(topic, peer) 317 - } 318 - } 319 - 320 - func (s *Server) addSubsciberToTopic(topicName string, peer *peer.Peer) { 321 - s.mu.Lock() 322 - defer s.mu.Unlock() 323 - 324 - t, ok := s.topics[topicName] 325 - if !ok { 326 - t = newTopic(topicName) 327 - } 328 - 329 - t.subscriptions[peer.Addr()] = subscriber{ 330 - peer: peer, 331 - currentOffset: 0, 332 - } 333 - 334 - s.topics[topicName] = t 335 - } 336 - 337 - func (s *Server) unsubscribeToTopics(peer *peer.Peer, topics []string) { 338 - for _, topic := range topics { 339 - s.removeSubsciberFromTopic(topic, peer) 340 - } 341 - } 342 - 343 - func (s *Server) removeSubsciberFromTopic(topicName string, peer *peer.Peer) { 344 - s.mu.Lock() 345 - defer s.mu.Unlock() 346 - 347 - t, ok := s.topics[topicName] 348 - if !ok { 349 - return 350 - } 351 - 352 - delete(t.subscriptions, peer.Addr()) 353 - } 354 - 355 - func (s *Server) unsubscribePeerFromAllTopics(peer *peer.Peer) { 356 - s.mu.Lock() 357 - defer s.mu.Unlock() 358 - 359 - for _, topic := range s.topics { 360 - delete(topic.subscriptions, peer.Addr()) 361 - } 362 - } 363 - 364 - func (s *Server) getTopic(topicName string) *topic { 365 - s.mu.Lock() 366 - defer s.mu.Unlock() 367 - 368 - if topic, ok := s.topics[topicName]; ok { 369 - return topic 370 - } 371 - 372 - return nil 373 - } 374 - 375 - func readAction(peer *peer.Peer, timeout time.Duration) (Action, error) { 376 - var action Action 377 - op := func(conn 
net.Conn) error { 378 - if timeout > 0 { 379 - err := conn.SetReadDeadline(time.Now().Add(timeout)) 380 - if err != nil { 381 - slog.Error("failed to set connection read deadline", "error", err, "peer", peer.Addr()) 382 - } 383 - } 384 - 385 - err := binary.Read(conn, binary.BigEndian, &action) 386 - if err != nil { 387 - return err 388 - } 389 - return nil 390 - } 391 - 392 - err := peer.RunConnOperation(op) 393 - if err != nil { 394 - return 0, fmt.Errorf("failed to read action from peer: %w", err) 395 - } 396 - 397 - return action, nil 398 - } 399 - 400 - func writeInvalidAction(peer *peer.Peer) { 401 - op := func(conn net.Conn) error { 402 - writeStatus(Error, "unknown action", conn) 403 - return nil 404 - } 405 - 406 - _ = peer.RunConnOperation(op) 407 - } 408 - 409 - func dataLength(conn net.Conn) (uint32, error) { 410 - var dataLen uint32 411 - err := binary.Read(conn, binary.BigEndian, &dataLen) 412 - if err != nil { 413 - return 0, err 414 - } 415 - return dataLen, nil 416 - } 417 - 418 - func writeStatus(status Status, message string, conn net.Conn) { 419 - err := binary.Write(conn, binary.BigEndian, status) 420 - if err != nil { 421 - if !errors.Is(err, syscall.EPIPE) { 422 - slog.Error("failed to write status to peers connection", "error", err, "peer", conn.RemoteAddr()) 423 - } 424 - return 425 - } 426 - 427 - if message == "" { 428 - return 429 - } 430 - 431 - msgBytes := []byte(message) 432 - err = binary.Write(conn, binary.BigEndian, uint32(len(msgBytes))) 433 - if err != nil { 434 - slog.Error("failed to write message length to peers connection", "error", err, "peer", conn.RemoteAddr()) 435 - return 436 - } 437 - 438 - _, err = conn.Write(msgBytes) 439 - if err != nil { 440 - slog.Error("failed to write message to peers connection", "error", err, "peer", conn.RemoteAddr()) 441 - return 442 - } 443 - }
-345
server/server_test.go
··· 1 - package server 2 - 3 - import ( 4 - "encoding/binary" 5 - "encoding/json" 6 - "fmt" 7 - "net" 8 - "testing" 9 - "time" 10 - 11 - "github.com/stretchr/testify/assert" 12 - "github.com/stretchr/testify/require" 13 - ) 14 - 15 - const ( 16 - topicA = "topic a" 17 - topicB = "topic b" 18 - topicC = "topic c" 19 - 20 - serverAddr = ":6666" 21 - ) 22 - 23 - func createServer(t *testing.T) *Server { 24 - srv, err := New(serverAddr) 25 - require.NoError(t, err) 26 - 27 - t.Cleanup(func() { 28 - _ = srv.Shutdown() 29 - }) 30 - 31 - return srv 32 - } 33 - 34 - func createServerWithExistingTopic(t *testing.T, topicName string) *Server { 35 - srv := createServer(t) 36 - srv.topics[topicName] = &topic{ 37 - name: topicName, 38 - subscriptions: make(map[net.Addr]subscriber), 39 - } 40 - 41 - return srv 42 - } 43 - 44 - func createConnectionAndSubscribe(t *testing.T, topics []string) net.Conn { 45 - conn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr)) 46 - require.NoError(t, err) 47 - 48 - err = binary.Write(conn, binary.BigEndian, Subscribe) 49 - require.NoError(t, err) 50 - 51 - rawTopics, err := json.Marshal(topics) 52 - require.NoError(t, err) 53 - 54 - err = binary.Write(conn, binary.BigEndian, uint32(len(rawTopics))) 55 - require.NoError(t, err) 56 - 57 - _, err = conn.Write(rawTopics) 58 - require.NoError(t, err) 59 - 60 - expectedRes := Subscribed 61 - 62 - var resp Status 63 - err = binary.Read(conn, binary.BigEndian, &resp) 64 - require.NoError(t, err) 65 - 66 - assert.Equal(t, expectedRes, int(resp)) 67 - 68 - return conn 69 - } 70 - 71 - func TestSubscribeToTopics(t *testing.T) { 72 - // create a server with an existing topic so we can test subscribing to a new and 73 - // existing topic 74 - srv := createServerWithExistingTopic(t, topicA) 75 - 76 - _ = createConnectionAndSubscribe(t, []string{topicA, topicB}) 77 - 78 - assert.Len(t, srv.topics, 2) 79 - assert.Len(t, srv.topics[topicA].subscriptions, 1) 80 - assert.Len(t, srv.topics[topicB].subscriptions, 1) 81 - } 82 - 83 - func TestUnsubscribesFromTopic(t *testing.T) { 84 - srv := createServerWithExistingTopic(t, topicA) 85 - 86 - conn := createConnectionAndSubscribe(t, []string{topicA, topicB, topicC}) 87 - 88 - assert.Len(t, srv.topics, 3) 89 - assert.Len(t, srv.topics[topicA].subscriptions, 1) 90 - assert.Len(t, srv.topics[topicB].subscriptions, 1) 91 - assert.Len(t, srv.topics[topicC].subscriptions, 1) 92 - 93 - err := binary.Write(conn, binary.BigEndian, Unsubscribe) 94 - require.NoError(t, err) 95 - 96 - topics := []string{topicA, topicB} 97 - rawTopics, err := json.Marshal(topics) 98 - require.NoError(t, err) 99 - 100 - err = binary.Write(conn, binary.BigEndian, uint32(len(rawTopics))) 101 - require.NoError(t, err) 102 - 103 - _, err = conn.Write(rawTopics) 104 - require.NoError(t, err) 105 - 106 - expectedRes := Unsubscribed 107 - 108 - var resp Status 109 - err = binary.Read(conn, binary.BigEndian, &resp) 110 - require.NoError(t, err) 111 - 112 - assert.Equal(t, expectedRes, int(resp)) 113 - 114 - assert.Len(t, srv.topics, 3) 115 - assert.Len(t, srv.topics[topicA].subscriptions, 0) 116 - assert.Len(t, srv.topics[topicB].subscriptions, 0) 117 - assert.Len(t, srv.topics[topicC].subscriptions, 1) 118 - } 119 - 120 - func TestSubscriberClosesWithoutUnsubscribing(t *testing.T) { 121 - srv := createServer(t) 122 - 123 - conn := createConnectionAndSubscribe(t, []string{topicA, topicB}) 124 - 125 - assert.Len(t, srv.topics, 2) 126 - assert.Len(t, srv.topics[topicA].subscriptions, 1) 127 - assert.Len(t, 
srv.topics[topicB].subscriptions, 1) 128 - 129 - // close the conn 130 - err := conn.Close() 131 - require.NoError(t, err) 132 - 133 - publisherConn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr)) 134 - require.NoError(t, err) 135 - 136 - err = binary.Write(publisherConn, binary.BigEndian, Publish) 137 - require.NoError(t, err) 138 - 139 - data := []byte("hello world") 140 - // send data length first 141 - err = binary.Write(publisherConn, binary.BigEndian, uint32(len(data))) 142 - require.NoError(t, err) 143 - n, err := publisherConn.Write(data) 144 - require.NoError(t, err) 145 - require.Equal(t, len(data), n) 146 - 147 - assert.Len(t, srv.topics, 2) 148 - assert.Len(t, srv.topics[topicA].subscriptions, 0) 149 - assert.Len(t, srv.topics[topicB].subscriptions, 0) 150 - } 151 - 152 - func TestInvalidAction(t *testing.T) { 153 - _ = createServer(t) 154 - 155 - conn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr)) 156 - require.NoError(t, err) 157 - 158 - err = binary.Write(conn, binary.BigEndian, uint8(99)) 159 - require.NoError(t, err) 160 - 161 - expectedRes := Error 162 - 163 - var resp Status 164 - err = binary.Read(conn, binary.BigEndian, &resp) 165 - require.NoError(t, err) 166 - 167 - assert.Equal(t, expectedRes, int(resp)) 168 - 169 - expectedMessage := "unknown action" 170 - 171 - var dataLen uint32 172 - err = binary.Read(conn, binary.BigEndian, &dataLen) 173 - require.NoError(t, err) 174 - assert.Equal(t, len(expectedMessage), int(dataLen)) 175 - 176 - buf := make([]byte, dataLen) 177 - _, err = conn.Read(buf) 178 - require.NoError(t, err) 179 - 180 - assert.Equal(t, expectedMessage, string(buf)) 181 - } 182 - 183 - func TestInvalidTopicDataPublished(t *testing.T) { 184 - _ = createServer(t) 185 - 186 - publisherConn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr)) 187 - require.NoError(t, err) 188 - 189 - err = binary.Write(publisherConn, binary.BigEndian, Publish) 190 - require.NoError(t, err) 191 - 192 - // send topic 193 - topic := topicA 194 - err = binary.Write(publisherConn, binary.BigEndian, uint32(len(topic))) 195 - require.NoError(t, err) 196 - _, err = publisherConn.Write([]byte(topic)) 197 - require.NoError(t, err) 198 - 199 - expectedRes := Error 200 - 201 - var resp Status 202 - err = binary.Read(publisherConn, binary.BigEndian, &resp) 203 - require.NoError(t, err) 204 - 205 - assert.Equal(t, expectedRes, int(resp)) 206 - 207 - expectedMessage := "topic data does not contain 'topic:' prefix" 208 - 209 - var dataLen uint32 210 - err = binary.Read(publisherConn, binary.BigEndian, &dataLen) 211 - require.NoError(t, err) 212 - assert.Equal(t, len(expectedMessage), int(dataLen)) 213 - 214 - buf := make([]byte, dataLen) 215 - _, err = publisherConn.Read(buf) 216 - require.NoError(t, err) 217 - 218 - assert.Equal(t, expectedMessage, string(buf)) 219 - } 220 - 221 - func TestSendsDataToTopicSubscribers(t *testing.T) { 222 - _ = createServer(t) 223 - 224 - subscribers := make([]net.Conn, 0, 10) 225 - for i := 0; i < 10; i++ { 226 - subscriberConn := createConnectionAndSubscribe(t, []string{topicA, topicB}) 227 - 228 - subscribers = append(subscribers, subscriberConn) 229 - } 230 - 231 - publisherConn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr)) 232 - require.NoError(t, err) 233 - 234 - err = binary.Write(publisherConn, binary.BigEndian, Publish) 235 - require.NoError(t, err) 236 - 237 - topic := fmt.Sprintf("topic:%s", topicA) 238 - messageData := "hello world" 239 - 240 - // send topic first 241 - err = 
binary.Write(publisherConn, binary.BigEndian, uint32(len(topic))) 242 - require.NoError(t, err) 243 - _, err = publisherConn.Write([]byte(topic)) 244 - require.NoError(t, err) 245 - 246 - // now send the data 247 - err = binary.Write(publisherConn, binary.BigEndian, uint32(len(messageData))) 248 - require.NoError(t, err) 249 - n, err := publisherConn.Write([]byte(messageData)) 250 - require.NoError(t, err) 251 - require.Equal(t, len(messageData), n) 252 - 253 - // check the subsribers got the data 254 - for _, conn := range subscribers { 255 - var topicLen uint64 256 - err = binary.Read(conn, binary.BigEndian, &topicLen) 257 - require.NoError(t, err) 258 - 259 - topicBuf := make([]byte, topicLen) 260 - _, err = conn.Read(topicBuf) 261 - require.NoError(t, err) 262 - assert.Equal(t, topicA, string(topicBuf)) 263 - 264 - var dataLen uint64 265 - err = binary.Read(conn, binary.BigEndian, &dataLen) 266 - require.NoError(t, err) 267 - 268 - buf := make([]byte, dataLen) 269 - n, err := conn.Read(buf) 270 - require.NoError(t, err) 271 - require.Equal(t, int(dataLen), n) 272 - 273 - assert.Equal(t, messageData, string(buf)) 274 - } 275 - } 276 - 277 - func TestPublishMultipleTimes(t *testing.T) { 278 - _ = createServer(t) 279 - 280 - publisherConn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr)) 281 - require.NoError(t, err) 282 - 283 - err = binary.Write(publisherConn, binary.BigEndian, Publish) 284 - require.NoError(t, err) 285 - 286 - messages := make([][]byte, 0, 10) 287 - for i := 0; i < 10; i++ { 288 - messages = append(messages, []byte(fmt.Sprintf("message %d", i))) 289 - } 290 - 291 - subscribeFinCh := make(chan struct{}) 292 - // create a subscriber that will read messages 293 - subscriberConn := createConnectionAndSubscribe(t, []string{topicA, topicB}) 294 - go func() { 295 - // check subscriber got all messages 296 - for _, msg := range messages { 297 - var topicLen uint64 298 - err = binary.Read(subscriberConn, binary.BigEndian, &topicLen) 299 - require.NoError(t, err) 300 - 301 - topicBuf := make([]byte, topicLen) 302 - _, err = subscriberConn.Read(topicBuf) 303 - require.NoError(t, err) 304 - assert.Equal(t, topicA, string(topicBuf)) 305 - 306 - var dataLen uint64 307 - err = binary.Read(subscriberConn, binary.BigEndian, &dataLen) 308 - require.NoError(t, err) 309 - 310 - buf := make([]byte, dataLen) 311 - n, err := subscriberConn.Read(buf) 312 - require.NoError(t, err) 313 - require.Equal(t, int(dataLen), n) 314 - 315 - assert.Equal(t, msg, buf) 316 - } 317 - 318 - subscribeFinCh <- struct{}{} 319 - }() 320 - 321 - topic := fmt.Sprintf("topic:%s", topicA) 322 - 323 - // send multiple messages 324 - for _, msg := range messages { 325 - // send topic first 326 - err = binary.Write(publisherConn, binary.BigEndian, uint32(len(topic))) 327 - require.NoError(t, err) 328 - _, err = publisherConn.Write([]byte(topic)) 329 - require.NoError(t, err) 330 - 331 - // now send the data 332 - err = binary.Write(publisherConn, binary.BigEndian, uint32(len(msg))) 333 - require.NoError(t, err) 334 - n, err := publisherConn.Write([]byte(msg)) 335 - require.NoError(t, err) 336 - require.Equal(t, len(msg), n) 337 - } 338 - 339 - select { 340 - case <-subscribeFinCh: 341 - break 342 - case <-time.After(time.Second): 343 - t.Fatal(fmt.Errorf("timed out waiting for subscriber to read messages")) 344 - } 345 - }
-70
server/topic.go
··· 1 - package server 2 - 3 - import ( 4 - "encoding/binary" 5 - "fmt" 6 - "log/slog" 7 - "net" 8 - "sync" 9 - 10 - "github.com/willdot/messagebroker/server/peer" 11 - ) 12 - 13 - type topic struct { 14 - name string 15 - subscriptions map[net.Addr]subscriber 16 - mu sync.Mutex 17 - } 18 - 19 - type subscriber struct { 20 - peer *peer.Peer 21 - currentOffset int 22 - } 23 - 24 - func newTopic(name string) *topic { 25 - return &topic{ 26 - name: name, 27 - subscriptions: make(map[net.Addr]subscriber), 28 - } 29 - } 30 - 31 - func (t *topic) sendMessageToSubscribers(msgData []byte) { 32 - t.mu.Lock() 33 - subscribers := t.subscriptions 34 - t.mu.Unlock() 35 - 36 - for addr, subscriber := range subscribers { 37 - err := subscriber.peer.RunConnOperation(sendMessageOp(t.name, msgData)) 38 - if err != nil { 39 - slog.Error("failed to send to message", "error", err, "peer", addr) 40 - return 41 - } 42 - } 43 - } 44 - 45 - func sendMessageOp(topic string, data []byte) peer.ConnOpp { 46 - return func(conn net.Conn) error { 47 - topicLen := uint64(len(topic)) 48 - err := binary.Write(conn, binary.BigEndian, topicLen) 49 - if err != nil { 50 - return fmt.Errorf("failed to send topic length: %w", err) 51 - } 52 - _, err = conn.Write([]byte(topic)) 53 - if err != nil { 54 - return fmt.Errorf("failed to send topic: %w", err) 55 - } 56 - 57 - dataLen := uint64(len(data)) 58 - 59 - err = binary.Write(conn, binary.BigEndian, dataLen) 60 - if err != nil { 61 - return fmt.Errorf("failed to send data length: %w", err) 62 - } 63 - 64 - _, err = conn.Write(data) 65 - if err != nil { 66 - return fmt.Errorf("failed to write to peer: %w", err) 67 - } 68 - return nil 69 - } 70 - }