+26
.github/workflows/workflow.yaml
+26
.github/workflows/workflow.yaml
···
1
+
name: Go package
2
+
3
+
on: [push]
4
+
5
+
jobs:
6
+
build:
7
+
runs-on: ubuntu-latest
8
+
steps:
9
+
- uses: actions/checkout@v3
10
+
11
+
- name: Set up Go
12
+
uses: actions/setup-go@v4
13
+
with:
14
+
go-version: '1.21'
15
+
16
+
- name: golangci-lint
17
+
run: curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin
18
+
19
+
- name: Build
20
+
run: go build -v ./...
21
+
22
+
- name: lint
23
+
run: golangci-lint run
24
+
25
+
- name: Test
26
+
run: go test ./... -p 1 -count=1 -v
+37
README.md
+37
README.md
···
1
+
# Message Broker
2
+
3
+
Message broker is my attempt at building client / server pub/sub system written in Go using plain `net.Conn`.
4
+
5
+
I decided to try to build one to further my understanding of how such a common tool is used in software engineering.
6
+
7
+
I realised that I took for granted how complicated other pub / sub message brokers were such as NATs, RabbitMQ and Kafka. By creating my own, I hope to dive into and understand more of how message brokers work.
8
+
9
+
## Concept
10
+
11
+
The server accepts TCP connections. When a connection is first established, the client will send an "action" message which determines what type of client it is; subscribe or publish.
12
+
13
+
### Subscribing
14
+
Once a connection has declared itself as a subscribing connection, it will need to then send the list of topics it wishes to initally subscribe to. After that the connection will enter a loop where it can then send a new action; subscribe to new topic(s) or unsubscribe from topic(s).
15
+
16
+
### Publishing
17
+
Once a subscription has declared itself as a publisher, it will enter a loop where it can then send a message for a topic. Once a message has been received, the server will then send it to all connections that are subscribed to that topic.
18
+
19
+
### Sending data via a connection
20
+
21
+
When sending a message representing an action (subscribe, publish etc) then a uint16 binary message is sent.
22
+
23
+
When sending any other data, the length of the data is to be sent first using a binary uint32 and then the actual data sent afterwards.
24
+
25
+
## Running the server
26
+
27
+
There is a server that can be run using `docker-compose up message-server`. This will start a server running listening on port 3000.
28
+
29
+
## Example clients
30
+
There is an example application that implements the subscriber and publishers in the `example` directory.
31
+
32
+
Run `go build .` to build the file.
33
+
34
+
When running the example there are the following flags:
35
+
36
+
`publish` : settings this to true will allow messages to be sent every 500ms as well as consuming
37
+
`consume-from` : this allows you to specify what message to start from. If you don't set this or set it to be -1, you will start consuming from the next sent message.
+23
client/message.go
+23
client/message.go
···
1
+
package client
2
+
3
+
// Message represents a message that can be published or consumed
4
+
type Message struct {
5
+
Topic string `json:"topic"`
6
+
Data []byte `json:"data"`
7
+
8
+
ack chan bool
9
+
}
10
+
11
+
// NewMessage creates a new message
12
+
func NewMessage(topic string, data []byte) *Message {
13
+
return &Message{
14
+
Topic: topic,
15
+
Data: data,
16
+
ack: make(chan bool),
17
+
}
18
+
}
19
+
20
+
// Ack will send the provided value of the ack to the server
21
+
func (m *Message) Ack(ack bool) {
22
+
m.ack <- ack
23
+
}
+113
client/publisher.go
+113
client/publisher.go
···
1
+
package client
2
+
3
+
import (
4
+
"encoding/binary"
5
+
"errors"
6
+
"fmt"
7
+
"log/slog"
8
+
"net"
9
+
"sync"
10
+
"syscall"
11
+
12
+
"github.com/willdot/messagebroker/internal/server"
13
+
)
14
+
15
+
// Publisher allows messages to be published to a server
16
+
type Publisher struct {
17
+
conn net.Conn
18
+
connMu sync.Mutex
19
+
addr string
20
+
}
21
+
22
+
// NewPublisher connects to the server at the given address and registers as a publisher
23
+
func NewPublisher(addr string) (*Publisher, error) {
24
+
conn, err := connect(addr)
25
+
if err != nil {
26
+
return nil, fmt.Errorf("failed to connect to server: %w", err)
27
+
}
28
+
29
+
return &Publisher{
30
+
conn: conn,
31
+
addr: addr,
32
+
}, nil
33
+
}
34
+
35
+
func connect(addr string) (net.Conn, error) {
36
+
conn, err := net.Dial("tcp", addr)
37
+
if err != nil {
38
+
return nil, fmt.Errorf("failed to dial: %w", err)
39
+
}
40
+
41
+
err = binary.Write(conn, binary.BigEndian, server.Publish)
42
+
if err != nil {
43
+
conn.Close()
44
+
return nil, fmt.Errorf("failed to register publish to server: %w", err)
45
+
}
46
+
return conn, nil
47
+
}
48
+
49
+
// Close cleanly shuts down the publisher
50
+
func (p *Publisher) Close() error {
51
+
return p.conn.Close()
52
+
}
53
+
54
+
// Publish will publish the given message to the server
55
+
func (p *Publisher) PublishMessage(message *Message) error {
56
+
return p.publishMessageWithRetry(message, 0)
57
+
}
58
+
59
+
func (p *Publisher) publishMessageWithRetry(message *Message, attempt int) error {
60
+
op := func(conn net.Conn) error {
61
+
// send topic first
62
+
topic := fmt.Sprintf("topic:%s", message.Topic)
63
+
64
+
topicLenB := make([]byte, 2)
65
+
binary.BigEndian.PutUint16(topicLenB, uint16(len(topic)))
66
+
67
+
headers := append(topicLenB, []byte(topic)...)
68
+
69
+
messageLenB := make([]byte, 4)
70
+
binary.BigEndian.PutUint32(messageLenB, uint32(len(message.Data)))
71
+
headers = append(headers, messageLenB...)
72
+
73
+
_, err := conn.Write(append(headers, message.Data...))
74
+
if err != nil {
75
+
return fmt.Errorf("failed to publish data to server: %w", err)
76
+
}
77
+
return nil
78
+
}
79
+
80
+
err := p.connOperation(op)
81
+
if err == nil {
82
+
return nil
83
+
}
84
+
85
+
// we can handle a broken pipe by trying to reconnect, but if it's a different error return it
86
+
if !errors.Is(err, syscall.EPIPE) {
87
+
return err
88
+
}
89
+
90
+
slog.Info("error is broken pipe")
91
+
92
+
if attempt >= 5 {
93
+
return fmt.Errorf("failed to publish message after max attempts to reconnect (%d): %w", attempt, err)
94
+
}
95
+
96
+
slog.Error("failed to publish message", "error", err)
97
+
98
+
conn, connectErr := connect(p.addr)
99
+
if connectErr != nil {
100
+
return fmt.Errorf("failed to reconnect after failing to publish message: %w", connectErr)
101
+
}
102
+
103
+
p.conn = conn
104
+
105
+
return p.publishMessageWithRetry(message, attempt+1)
106
+
}
107
+
108
+
func (p *Publisher) connOperation(op connOpp) error {
109
+
p.connMu.Lock()
110
+
defer p.connMu.Unlock()
111
+
112
+
return op(p.conn)
113
+
}
+371
client/subscriber.go
+371
client/subscriber.go
···
1
+
package client
2
+
3
+
import (
4
+
"context"
5
+
"encoding/binary"
6
+
"encoding/json"
7
+
"errors"
8
+
"fmt"
9
+
"io"
10
+
"log/slog"
11
+
"net"
12
+
"sync"
13
+
"syscall"
14
+
"time"
15
+
16
+
"github.com/willdot/messagebroker/internal/server"
17
+
)
18
+
19
+
type connOpp func(conn net.Conn) error
20
+
21
+
// Subscriber allows subscriptions to a server and the consumption of messages
22
+
type Subscriber struct {
23
+
conn net.Conn
24
+
connMu sync.Mutex
25
+
subscribedTopics []string
26
+
addr string
27
+
}
28
+
29
+
// NewSubscriber will connect to the server at the given address
30
+
func NewSubscriber(addr string) (*Subscriber, error) {
31
+
conn, err := net.Dial("tcp", addr)
32
+
if err != nil {
33
+
return nil, fmt.Errorf("failed to dial: %w", err)
34
+
}
35
+
36
+
return &Subscriber{
37
+
conn: conn,
38
+
addr: addr,
39
+
}, nil
40
+
}
41
+
42
+
func (s *Subscriber) reconnect() error {
43
+
conn, err := net.Dial("tcp", s.addr)
44
+
if err != nil {
45
+
return fmt.Errorf("failed to dial: %w", err)
46
+
}
47
+
48
+
s.conn = conn
49
+
return nil
50
+
}
51
+
52
+
// Close cleanly shuts down the subscriber
53
+
func (s *Subscriber) Close() error {
54
+
return s.conn.Close()
55
+
}
56
+
57
+
// SubscribeToTopics will subscribe to the provided topics
58
+
func (s *Subscriber) SubscribeToTopics(topicNames []string, startAtType server.StartAtType, startAtIndex int) error {
59
+
op := func(conn net.Conn) error {
60
+
return subscribeToTopics(conn, topicNames, startAtType, startAtIndex)
61
+
}
62
+
63
+
err := s.connOperation(op)
64
+
if err != nil {
65
+
return fmt.Errorf("failed to subscribe to topics: %w", err)
66
+
}
67
+
68
+
s.addToSubscribedTopics(topicNames)
69
+
70
+
return nil
71
+
}
72
+
73
+
func (s *Subscriber) addToSubscribedTopics(topics []string) {
74
+
existingSubs := make(map[string]struct{})
75
+
for _, topic := range s.subscribedTopics {
76
+
existingSubs[topic] = struct{}{}
77
+
}
78
+
79
+
for _, topic := range topics {
80
+
existingSubs[topic] = struct{}{}
81
+
}
82
+
83
+
subs := make([]string, 0, len(existingSubs))
84
+
for topic := range existingSubs {
85
+
subs = append(subs, topic)
86
+
}
87
+
88
+
s.subscribedTopics = subs
89
+
}
90
+
91
+
func (s *Subscriber) removeTopicsFromSubscription(topics []string) {
92
+
existingSubs := make(map[string]struct{})
93
+
for _, topic := range s.subscribedTopics {
94
+
existingSubs[topic] = struct{}{}
95
+
}
96
+
97
+
for _, topic := range topics {
98
+
delete(existingSubs, topic)
99
+
}
100
+
101
+
subs := make([]string, 0, len(existingSubs))
102
+
for topic := range existingSubs {
103
+
subs = append(subs, topic)
104
+
}
105
+
106
+
s.subscribedTopics = subs
107
+
}
108
+
109
+
// UnsubscribeToTopics will unsubscribe to the provided topics
110
+
func (s *Subscriber) UnsubscribeToTopics(topicNames []string) error {
111
+
op := func(conn net.Conn) error {
112
+
return unsubscribeToTopics(conn, topicNames)
113
+
}
114
+
115
+
err := s.connOperation(op)
116
+
if err != nil {
117
+
return fmt.Errorf("failed to unsubscribe to topics: %w", err)
118
+
}
119
+
120
+
s.removeTopicsFromSubscription(topicNames)
121
+
122
+
return nil
123
+
}
124
+
125
+
func subscribeToTopics(conn net.Conn, topicNames []string, startAtType server.StartAtType, startAtIndex int) error {
126
+
actionB := make([]byte, 2)
127
+
binary.BigEndian.PutUint16(actionB, uint16(server.Subscribe))
128
+
headers := actionB
129
+
130
+
b, err := json.Marshal(topicNames)
131
+
if err != nil {
132
+
return fmt.Errorf("failed to marshal topic names: %w", err)
133
+
}
134
+
135
+
topicNamesB := make([]byte, 4)
136
+
binary.BigEndian.PutUint32(topicNamesB, uint32(len(b)))
137
+
headers = append(headers, topicNamesB...)
138
+
headers = append(headers, b...)
139
+
140
+
startAtTypeB := make([]byte, 2)
141
+
binary.BigEndian.PutUint16(startAtTypeB, uint16(startAtType))
142
+
headers = append(headers, startAtTypeB...)
143
+
144
+
if startAtType == server.From {
145
+
fromB := make([]byte, 2)
146
+
binary.BigEndian.PutUint16(fromB, uint16(startAtIndex))
147
+
headers = append(headers, fromB...)
148
+
}
149
+
150
+
_, err = conn.Write(headers)
151
+
if err != nil {
152
+
return fmt.Errorf("failed to subscribe to topics: %w", err)
153
+
}
154
+
155
+
var resp server.Status
156
+
err = binary.Read(conn, binary.BigEndian, &resp)
157
+
if err != nil {
158
+
return fmt.Errorf("failed to read confirmation of subscribe: %w", err)
159
+
}
160
+
161
+
if resp == server.Subscribed {
162
+
return nil
163
+
}
164
+
165
+
var dataLen uint16
166
+
err = binary.Read(conn, binary.BigEndian, &dataLen)
167
+
if err != nil {
168
+
return fmt.Errorf("received status %s:", resp)
169
+
}
170
+
171
+
buf := make([]byte, dataLen)
172
+
_, err = conn.Read(buf)
173
+
if err != nil {
174
+
return fmt.Errorf("received status %s:", resp)
175
+
}
176
+
177
+
return fmt.Errorf("received status %s - %s", resp, buf)
178
+
}
179
+
180
+
func unsubscribeToTopics(conn net.Conn, topicNames []string) error {
181
+
actionB := make([]byte, 2)
182
+
binary.BigEndian.PutUint16(actionB, uint16(server.Unsubscribe))
183
+
headers := actionB
184
+
185
+
b, err := json.Marshal(topicNames)
186
+
if err != nil {
187
+
return fmt.Errorf("failed to marshal topic names: %w", err)
188
+
}
189
+
190
+
topicNamesB := make([]byte, 4)
191
+
binary.BigEndian.PutUint32(topicNamesB, uint32(len(b)))
192
+
headers = append(headers, topicNamesB...)
193
+
194
+
_, err = conn.Write(append(headers, b...))
195
+
if err != nil {
196
+
return fmt.Errorf("failed to unsubscribe to topics: %w", err)
197
+
}
198
+
199
+
var resp server.Status
200
+
err = binary.Read(conn, binary.BigEndian, &resp)
201
+
if err != nil {
202
+
return fmt.Errorf("failed to read confirmation of unsubscribe: %w", err)
203
+
}
204
+
205
+
if resp == server.Unsubscribed {
206
+
return nil
207
+
}
208
+
209
+
var dataLen uint16
210
+
err = binary.Read(conn, binary.BigEndian, &dataLen)
211
+
if err != nil {
212
+
return fmt.Errorf("received status %s:", resp)
213
+
}
214
+
215
+
buf := make([]byte, dataLen)
216
+
_, err = conn.Read(buf)
217
+
if err != nil {
218
+
return fmt.Errorf("received status %s:", resp)
219
+
}
220
+
221
+
return fmt.Errorf("received status %s - %s", resp, buf)
222
+
}
223
+
224
+
// Consumer allows the consumption of messages. If during the consumer receiving messages from the
225
+
// server an error occurs, it will be stored in Err
226
+
type Consumer struct {
227
+
msgs chan *Message
228
+
// TODO: better error handling? Maybe a channel of errors?
229
+
Err error
230
+
}
231
+
232
+
// Messages returns a channel in which this consumer will put messages onto. It is safe to range over the channel since it will be closed once
233
+
// the consumer has finished either due to an error or from being cancelled.
234
+
func (c *Consumer) Messages() <-chan *Message {
235
+
return c.msgs
236
+
}
237
+
238
+
// Consume will create a consumer and start it running in a go routine. You can then use the Msgs channel of the consumer
239
+
// to read the messages
240
+
func (s *Subscriber) Consume(ctx context.Context) *Consumer {
241
+
consumer := &Consumer{
242
+
msgs: make(chan *Message),
243
+
}
244
+
245
+
go s.consume(ctx, consumer)
246
+
247
+
return consumer
248
+
}
249
+
250
+
func (s *Subscriber) consume(ctx context.Context, consumer *Consumer) {
251
+
defer close(consumer.msgs)
252
+
for {
253
+
if ctx.Err() != nil {
254
+
return
255
+
}
256
+
257
+
err := s.readMessage(ctx, consumer.msgs)
258
+
if err == nil {
259
+
continue
260
+
}
261
+
262
+
// if we couldn't connect to the server, attempt to reconnect
263
+
if !errors.Is(err, syscall.EPIPE) && !errors.Is(err, io.EOF) {
264
+
slog.Error("failed to read message", "error", err)
265
+
consumer.Err = err
266
+
return
267
+
}
268
+
269
+
slog.Info("attempting to reconnect")
270
+
271
+
for i := 0; i < 5; i++ {
272
+
time.Sleep(time.Millisecond * 500)
273
+
err = s.reconnect()
274
+
if err == nil {
275
+
break
276
+
}
277
+
278
+
slog.Error("Failed to reconnect", "error", err, "attempt", i)
279
+
}
280
+
281
+
slog.Info("attempting to resubscribe")
282
+
283
+
err = s.SubscribeToTopics(s.subscribedTopics, server.Current, 0)
284
+
if err != nil {
285
+
consumer.Err = fmt.Errorf("failed to subscribe to topics after reconnecting: %w", err)
286
+
return
287
+
}
288
+
289
+
}
290
+
}
291
+
292
+
func (s *Subscriber) readMessage(ctx context.Context, msgChan chan *Message) error {
293
+
op := func(conn net.Conn) error {
294
+
err := s.conn.SetReadDeadline(time.Now().Add(time.Millisecond * 300))
295
+
if err != nil {
296
+
return err
297
+
}
298
+
299
+
var topicLen uint16
300
+
err = binary.Read(s.conn, binary.BigEndian, &topicLen)
301
+
if err != nil {
302
+
// TODO: check if this is needed elsewhere. I'm not sure where the read deadline resets....
303
+
if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
304
+
return nil
305
+
}
306
+
return err
307
+
}
308
+
309
+
topicBuf := make([]byte, topicLen)
310
+
_, err = s.conn.Read(topicBuf)
311
+
if err != nil {
312
+
return err
313
+
}
314
+
315
+
var dataLen uint64
316
+
err = binary.Read(s.conn, binary.BigEndian, &dataLen)
317
+
if err != nil {
318
+
return err
319
+
}
320
+
321
+
if dataLen <= 0 {
322
+
return nil
323
+
}
324
+
325
+
dataBuf := make([]byte, dataLen)
326
+
_, err = s.conn.Read(dataBuf)
327
+
if err != nil {
328
+
return err
329
+
}
330
+
331
+
msg := NewMessage(string(topicBuf), dataBuf)
332
+
333
+
msgChan <- msg
334
+
335
+
var ack bool
336
+
select {
337
+
case <-ctx.Done():
338
+
return ctx.Err()
339
+
case ack = <-msg.ack:
340
+
}
341
+
ackMessage := server.Nack
342
+
if ack {
343
+
ackMessage = server.Ack
344
+
}
345
+
346
+
err = binary.Write(s.conn, binary.BigEndian, ackMessage)
347
+
if err != nil {
348
+
return fmt.Errorf("failed to ack/nack message: %w", err)
349
+
}
350
+
351
+
return nil
352
+
}
353
+
354
+
err := s.connOperation(op)
355
+
if err != nil {
356
+
var neterr net.Error
357
+
if errors.As(err, &neterr) && neterr.Timeout() {
358
+
return nil
359
+
}
360
+
return err
361
+
}
362
+
363
+
return err
364
+
}
365
+
366
+
func (s *Subscriber) connOperation(op connOpp) error {
367
+
s.connMu.Lock()
368
+
defer s.connMu.Unlock()
369
+
370
+
return op(s.conn)
371
+
}
+273
client/subscriber_test.go
+273
client/subscriber_test.go
···
1
+
package client
2
+
3
+
import (
4
+
"context"
5
+
"fmt"
6
+
"testing"
7
+
"time"
8
+
9
+
"github.com/stretchr/testify/assert"
10
+
"github.com/stretchr/testify/require"
11
+
"github.com/willdot/messagebroker/internal/server"
12
+
)
13
+
14
+
const (
15
+
serverAddr = ":9999"
16
+
topicA = "topic a"
17
+
topicB = "topic b"
18
+
)
19
+
20
+
func createServer(t *testing.T) {
21
+
server, err := server.New(serverAddr, time.Millisecond*100, time.Millisecond*100)
22
+
require.NoError(t, err)
23
+
24
+
t.Cleanup(func() {
25
+
_ = server.Shutdown()
26
+
})
27
+
}
28
+
29
+
func TestNewSubscriber(t *testing.T) {
30
+
createServer(t)
31
+
32
+
sub, err := NewSubscriber(serverAddr)
33
+
require.NoError(t, err)
34
+
35
+
t.Cleanup(func() {
36
+
sub.Close()
37
+
})
38
+
}
39
+
40
+
func TestNewSubscriberInvalidServerAddr(t *testing.T) {
41
+
createServer(t)
42
+
43
+
_, err := NewSubscriber(":123456")
44
+
require.Error(t, err)
45
+
}
46
+
47
+
func TestNewPublisher(t *testing.T) {
48
+
createServer(t)
49
+
50
+
sub, err := NewPublisher(serverAddr)
51
+
require.NoError(t, err)
52
+
53
+
t.Cleanup(func() {
54
+
sub.Close()
55
+
})
56
+
}
57
+
58
+
func TestNewPublisherInvalidServerAddr(t *testing.T) {
59
+
createServer(t)
60
+
61
+
_, err := NewPublisher(":123456")
62
+
require.Error(t, err)
63
+
}
64
+
65
+
func TestSubscribeToTopics(t *testing.T) {
66
+
createServer(t)
67
+
68
+
sub, err := NewSubscriber(serverAddr)
69
+
require.NoError(t, err)
70
+
71
+
t.Cleanup(func() {
72
+
sub.Close()
73
+
})
74
+
75
+
topics := []string{topicA, topicB}
76
+
77
+
err = sub.SubscribeToTopics(topics, server.Current, 0)
78
+
require.NoError(t, err)
79
+
}
80
+
81
+
func TestUnsubscribesFromTopic(t *testing.T) {
82
+
createServer(t)
83
+
84
+
sub, err := NewSubscriber(serverAddr)
85
+
require.NoError(t, err)
86
+
87
+
t.Cleanup(func() {
88
+
sub.Close()
89
+
})
90
+
91
+
topics := []string{topicA, topicB}
92
+
93
+
err = sub.SubscribeToTopics(topics, server.Current, 0)
94
+
require.NoError(t, err)
95
+
96
+
err = sub.UnsubscribeToTopics([]string{topicA})
97
+
require.NoError(t, err)
98
+
99
+
ctx, cancel := context.WithCancel(context.Background())
100
+
t.Cleanup(func() {
101
+
cancel()
102
+
})
103
+
104
+
consumer := sub.Consume(ctx)
105
+
require.NoError(t, err)
106
+
107
+
var receivedMessages []*Message
108
+
consumerFinCh := make(chan struct{})
109
+
go func() {
110
+
for msg := range consumer.Messages() {
111
+
msg.Ack(true)
112
+
receivedMessages = append(receivedMessages, msg)
113
+
}
114
+
115
+
consumerFinCh <- struct{}{}
116
+
}()
117
+
118
+
// publish a message to both topics and check the subscriber only gets the message from the 1 topic
119
+
// and not the unsubscribed topic
120
+
121
+
publisher, err := NewPublisher("localhost:9999")
122
+
require.NoError(t, err)
123
+
t.Cleanup(func() {
124
+
publisher.Close()
125
+
})
126
+
127
+
msg := NewMessage(topicA, []byte("hello world"))
128
+
129
+
err = publisher.PublishMessage(msg)
130
+
require.NoError(t, err)
131
+
132
+
msg.Topic = topicB
133
+
err = publisher.PublishMessage(msg)
134
+
require.NoError(t, err)
135
+
136
+
// give the consumer some time to read the messages -- TODO: make better!
137
+
time.Sleep(time.Millisecond * 300)
138
+
cancel()
139
+
140
+
select {
141
+
case <-consumerFinCh:
142
+
break
143
+
case <-time.After(time.Second):
144
+
t.Fatal("timed out waiting for consumer to read messages")
145
+
}
146
+
147
+
assert.Len(t, receivedMessages, 1)
148
+
assert.Equal(t, topicB, receivedMessages[0].Topic)
149
+
}
150
+
151
+
func TestPublishAndSubscribe(t *testing.T) {
152
+
consumer, cancel := setupConsumer(t)
153
+
154
+
var receivedMessages []*Message
155
+
156
+
consumerFinCh := make(chan struct{})
157
+
go func() {
158
+
for msg := range consumer.Messages() {
159
+
msg.Ack(true)
160
+
receivedMessages = append(receivedMessages, msg)
161
+
}
162
+
163
+
consumerFinCh <- struct{}{}
164
+
}()
165
+
166
+
publisher, err := NewPublisher("localhost:9999")
167
+
require.NoError(t, err)
168
+
t.Cleanup(func() {
169
+
publisher.Close()
170
+
})
171
+
172
+
// send some messages
173
+
sentMessages := make([]*Message, 0, 10)
174
+
for i := 0; i < 10; i++ {
175
+
msg := NewMessage(topicA, []byte(fmt.Sprintf("message %d", i)))
176
+
177
+
sentMessages = append(sentMessages, msg)
178
+
179
+
err = publisher.PublishMessage(msg)
180
+
require.NoError(t, err)
181
+
}
182
+
183
+
// give the consumer some time to read the messages -- TODO: make better!
184
+
time.Sleep(time.Millisecond * 300)
185
+
cancel()
186
+
187
+
select {
188
+
case <-consumerFinCh:
189
+
break
190
+
case <-time.After(time.Second * 5):
191
+
t.Fatal("timed out waiting for consumer to read messages")
192
+
}
193
+
194
+
// THIS IS SO HACKY
195
+
for _, msg := range receivedMessages {
196
+
msg.ack = nil
197
+
}
198
+
199
+
for _, msg := range sentMessages {
200
+
msg.ack = nil
201
+
}
202
+
203
+
assert.ElementsMatch(t, receivedMessages, sentMessages)
204
+
}
205
+
206
+
func TestPublishAndSubscribeNackMessage(t *testing.T) {
207
+
consumer, cancel := setupConsumer(t)
208
+
209
+
var receivedMessages []*Message
210
+
211
+
consumerFinCh := make(chan struct{})
212
+
timesMsgWasReceived := 0
213
+
go func() {
214
+
for msg := range consumer.Messages() {
215
+
msg.Ack(false)
216
+
timesMsgWasReceived++
217
+
}
218
+
219
+
consumerFinCh <- struct{}{}
220
+
}()
221
+
222
+
publisher, err := NewPublisher("localhost:9999")
223
+
require.NoError(t, err)
224
+
t.Cleanup(func() {
225
+
publisher.Close()
226
+
})
227
+
228
+
// send a message
229
+
msg := NewMessage(topicA, []byte("hello world"))
230
+
231
+
err = publisher.PublishMessage(msg)
232
+
require.NoError(t, err)
233
+
234
+
// give the consumer some time to read the messages -- TODO: make better!
235
+
time.Sleep(time.Second)
236
+
cancel()
237
+
238
+
select {
239
+
case <-consumerFinCh:
240
+
break
241
+
case <-time.After(time.Second * 5):
242
+
t.Fatal("timed out waiting for consumer to read messages")
243
+
}
244
+
245
+
assert.Empty(t, receivedMessages)
246
+
assert.Equal(t, 5, timesMsgWasReceived)
247
+
}
248
+
249
+
func setupConsumer(t *testing.T) (*Consumer, context.CancelFunc) {
250
+
createServer(t)
251
+
252
+
sub, err := NewSubscriber(serverAddr)
253
+
require.NoError(t, err)
254
+
255
+
t.Cleanup(func() {
256
+
sub.Close()
257
+
})
258
+
259
+
topics := []string{topicA, topicB}
260
+
261
+
err = sub.SubscribeToTopics(topics, server.Current, 0)
262
+
require.NoError(t, err)
263
+
264
+
ctx, cancel := context.WithCancel(context.Background())
265
+
t.Cleanup(func() {
266
+
cancel()
267
+
})
268
+
269
+
consumer := sub.Consume(ctx)
270
+
require.NoError(t, err)
271
+
272
+
return consumer, cancel
273
+
}
+27
cmd/server/main.go
+27
cmd/server/main.go
···
1
+
package main
2
+
3
+
import (
4
+
"log"
5
+
"os"
6
+
"os/signal"
7
+
"syscall"
8
+
"time"
9
+
10
+
"github.com/willdot/messagebroker/internal/server"
11
+
)
12
+
13
+
func main() {
14
+
srv, err := server.New(":3000", time.Second, time.Second*2)
15
+
if err != nil {
16
+
log.Fatal(err)
17
+
}
18
+
19
+
defer func() {
20
+
_ = srv.Shutdown()
21
+
}()
22
+
23
+
signals := make(chan os.Signal, 1)
24
+
signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
25
+
26
+
<-signals
27
+
}
+7
docker-compose.yaml
+7
docker-compose.yaml
+20
dockerfile.server
+20
dockerfile.server
···
1
+
FROM golang:latest as builder
2
+
3
+
WORKDIR /app
4
+
5
+
COPY go.mod go.sum ./
6
+
COPY cmd/server/ ./
7
+
RUN go mod download
8
+
9
+
COPY . .
10
+
11
+
RUN CGO_ENABLED=0 go build -o message-broker-server .
12
+
13
+
FROM alpine:latest
14
+
15
+
RUN apk --no-cache add ca-certificates
16
+
17
+
WORKDIR /root/
18
+
COPY --from=builder /app/message-broker-server .
19
+
20
+
CMD ["./message-broker-server"]
+98
example/main.go
+98
example/main.go
···
1
+
package main
2
+
3
+
import (
4
+
"context"
5
+
"flag"
6
+
"fmt"
7
+
"log/slog"
8
+
"time"
9
+
10
+
"github.com/willdot/messagebroker/client"
11
+
"github.com/willdot/messagebroker/internal/server"
12
+
)
13
+
14
+
// var publish *bool
15
+
// var consume *bool
16
+
var consumeFrom *int
17
+
var clientType *string
18
+
19
+
const (
20
+
topic = "topic-a"
21
+
)
22
+
23
+
func main() {
24
+
clientType = flag.String("client-type", "consume", "consume or publish (default consume)")
25
+
// publish = flag.Bool("publish", false, "will publish messages every 500ms until client is stopped")
26
+
// consume = flag.Bool("consume", false, "will consume messages until client is stopped")
27
+
consumeFrom = flag.Int("consume-from", -1, "index of message to start consuming from. If not set it will consume from the most recent")
28
+
flag.Parse()
29
+
30
+
switch *clientType {
31
+
case "consume":
32
+
consume()
33
+
case "publish":
34
+
sendMessages()
35
+
default:
36
+
fmt.Println("unknown client type")
37
+
}
38
+
}
39
+
40
+
func consume() {
41
+
sub, err := client.NewSubscriber(":3000")
42
+
if err != nil {
43
+
panic(err)
44
+
}
45
+
46
+
defer func() {
47
+
_ = sub.Close()
48
+
}()
49
+
startAt := 0
50
+
startAtType := server.Current
51
+
if *consumeFrom > -1 {
52
+
startAtType = server.From
53
+
startAt = *consumeFrom
54
+
}
55
+
56
+
err = sub.SubscribeToTopics([]string{topic}, startAtType, startAt)
57
+
if err != nil {
58
+
panic(err)
59
+
}
60
+
61
+
consumer := sub.Consume(context.Background())
62
+
if consumer.Err != nil {
63
+
panic(err)
64
+
}
65
+
66
+
for msg := range consumer.Messages() {
67
+
slog.Info("received message", "message", string(msg.Data))
68
+
msg.Ack(true)
69
+
}
70
+
}
71
+
72
+
func sendMessages() {
73
+
publisher, err := client.NewPublisher("localhost:3000")
74
+
if err != nil {
75
+
panic(err)
76
+
}
77
+
78
+
defer func() {
79
+
_ = publisher.Close()
80
+
}()
81
+
82
+
// send some messages
83
+
i := 0
84
+
for {
85
+
i++
86
+
msg := client.NewMessage(topic, []byte(fmt.Sprintf("message %d", i)))
87
+
88
+
err = publisher.PublishMessage(msg)
89
+
if err != nil {
90
+
slog.Error("failed to publish message", "error", err)
91
+
continue
92
+
}
93
+
94
+
slog.Info("message sent")
95
+
96
+
time.Sleep(time.Millisecond * 500)
97
+
}
98
+
}
+47
internal/messagestore/memory_store.go
+47
internal/messagestore/memory_store.go
···
1
+
package messagestore
2
+
3
+
import (
4
+
"sync"
5
+
6
+
"github.com/willdot/messagebroker/internal"
7
+
)
8
+
9
+
// MemoryStore allows messages to be stored in memory
10
+
type MemoryStore struct {
11
+
mu sync.Mutex
12
+
msgs map[int]internal.Message
13
+
nextOffset int
14
+
}
15
+
16
+
// NewMemoryStore initializes a new in memory store
17
+
func NewMemoryStore() *MemoryStore {
18
+
return &MemoryStore{
19
+
msgs: make(map[int]internal.Message),
20
+
}
21
+
}
22
+
23
+
// Write will write the provided message to the in memory store
24
+
func (m *MemoryStore) Write(msg internal.Message) error {
25
+
m.mu.Lock()
26
+
defer m.mu.Unlock()
27
+
28
+
m.msgs[m.nextOffset] = msg
29
+
30
+
m.nextOffset++
31
+
32
+
return nil
33
+
}
34
+
35
+
// ReadFrom will read messages from (and including) the provided offset and pass them to the provided handler
36
+
func (m *MemoryStore) ReadFrom(offset int, handleFunc func(msg internal.Message)) {
37
+
if offset < 0 || offset >= m.nextOffset {
38
+
return
39
+
}
40
+
41
+
m.mu.Lock()
42
+
defer m.mu.Unlock()
43
+
44
+
for i := offset; i < len(m.msgs); i++ {
45
+
handleFunc(m.msgs[i])
46
+
}
47
+
}
+12
internal/messge.go
+12
internal/messge.go
···
1
+
package internal
2
+
3
+
// Message represents a message that can be sent / received
4
+
type Message struct {
5
+
Data []byte
6
+
DeliveryCount int
7
+
}
8
+
9
+
// NewMessage intializes a new message
10
+
func NewMessage(data []byte) Message {
11
+
return Message{Data: data, DeliveryCount: 1}
12
+
}
+36
internal/server/peer.go
+36
internal/server/peer.go
···
1
+
package server
2
+
3
+
import (
4
+
"net"
5
+
"sync"
6
+
)
7
+
8
+
// Peer represents a remote connection to the server such as a publisher or subscriber.
9
+
type Peer struct {
10
+
conn net.Conn
11
+
connMu sync.Mutex
12
+
}
13
+
14
+
// New returns a new peer.
15
+
func NewPeer(conn net.Conn) *Peer {
16
+
return &Peer{
17
+
conn: conn,
18
+
}
19
+
}
20
+
21
+
// Addr returns the peers connections address.
22
+
func (p *Peer) Addr() net.Addr {
23
+
return p.conn.RemoteAddr()
24
+
}
25
+
26
+
// ConnOpp represents a set of actions on a connection that can be used synchrnously.
27
+
type ConnOpp func(conn net.Conn) error
28
+
29
+
// RunConnOperation will run the provided operation. It ensures that it is the only operation that is being
30
+
// run on the connection to ensure any other operations don't get mixed up.
31
+
func (p *Peer) RunConnOperation(op ConnOpp) error {
32
+
p.connMu.Lock()
33
+
defer p.connMu.Unlock()
34
+
35
+
return op(p.conn)
36
+
}
+513
internal/server/server.go
+513
internal/server/server.go
···
1
+
package server
2
+
3
+
import (
4
+
"encoding/binary"
5
+
"encoding/json"
6
+
"errors"
7
+
"fmt"
8
+
"io"
9
+
"log/slog"
10
+
"net"
11
+
"strings"
12
+
"sync"
13
+
"syscall"
14
+
"time"
15
+
16
+
"github.com/willdot/messagebroker/internal"
17
+
)
18
+
19
+
// Action represents the type of action that a peer requests to do
20
+
type Action uint16
21
+
22
+
const (
23
+
Subscribe Action = 1
24
+
Unsubscribe Action = 2
25
+
Publish Action = 3
26
+
Ack Action = 4
27
+
Nack Action = 5
28
+
)
29
+
30
+
// Status represents the status of a request
31
+
type Status uint16
32
+
33
+
const (
34
+
Subscribed Status = 1
35
+
Unsubscribed Status = 2
36
+
Error Status = 3
37
+
)
38
+
39
+
func (s Status) String() string {
40
+
switch s {
41
+
case Subscribed:
42
+
return "subscribed"
43
+
case Unsubscribed:
44
+
return "unsubscribed"
45
+
case Error:
46
+
return "error"
47
+
}
48
+
49
+
return ""
50
+
}
51
+
52
+
// StartAtType represents where the subcriber wishes to start subscribing to a topic from
53
+
type StartAtType uint16
54
+
55
+
const (
56
+
Beginning StartAtType = 0
57
+
Current StartAtType = 1
58
+
From StartAtType = 2
59
+
)
60
+
61
+
// Server accepts subscribe and publish connections and passes messages around
62
+
type Server struct {
63
+
Addr string
64
+
lis net.Listener
65
+
66
+
mu sync.Mutex
67
+
topics map[string]*topic
68
+
69
+
ackDelay time.Duration
70
+
ackTimeout time.Duration
71
+
}
72
+
73
+
// New creates and starts a new server
74
+
func New(Addr string, ackDelay, ackTimeout time.Duration) (*Server, error) {
75
+
lis, err := net.Listen("tcp", Addr)
76
+
if err != nil {
77
+
return nil, fmt.Errorf("failed to listen: %w", err)
78
+
}
79
+
80
+
srv := &Server{
81
+
lis: lis,
82
+
topics: map[string]*topic{},
83
+
ackDelay: ackDelay,
84
+
ackTimeout: ackTimeout,
85
+
}
86
+
87
+
go srv.start()
88
+
89
+
return srv, nil
90
+
}
91
+
92
+
// Shutdown will cleanly shutdown the server
93
+
func (s *Server) Shutdown() error {
94
+
return s.lis.Close()
95
+
}
96
+
97
+
func (s *Server) start() {
98
+
for {
99
+
conn, err := s.lis.Accept()
100
+
if err != nil {
101
+
if errors.Is(err, net.ErrClosed) {
102
+
slog.Info("listener closed")
103
+
return
104
+
}
105
+
slog.Error("listener failed to accept", "error", err)
106
+
continue
107
+
}
108
+
109
+
go s.handleConn(conn)
110
+
}
111
+
}
112
+
113
+
func (s *Server) handleConn(conn net.Conn) {
114
+
peer := NewPeer(conn)
115
+
116
+
slog.Info("handling connection", "peer", peer.Addr())
117
+
defer slog.Info("ending connection", "peer", peer.Addr())
118
+
119
+
action, err := readAction(peer, 0)
120
+
if err != nil {
121
+
if !errors.Is(err, io.EOF) {
122
+
slog.Error("failed to read action from peer", "error", err, "peer", peer.Addr())
123
+
}
124
+
return
125
+
}
126
+
127
+
switch action {
128
+
case Subscribe:
129
+
s.handleSubscribe(peer)
130
+
case Unsubscribe:
131
+
s.handleUnsubscribe(peer)
132
+
case Publish:
133
+
s.handlePublish(peer)
134
+
default:
135
+
slog.Error("unknown action", "action", action, "peer", peer.Addr())
136
+
writeInvalidAction(peer)
137
+
}
138
+
}
139
+
140
+
func (s *Server) handleSubscribe(peer *Peer) {
141
+
slog.Info("handling subscriber", "peer", peer.Addr())
142
+
// subscribe the peer to the topic
143
+
s.subscribePeerToTopic(peer)
144
+
145
+
s.waitForPeerAction(peer)
146
+
}
147
+
148
+
func (s *Server) waitForPeerAction(peer *Peer) {
149
+
// keep handling the peers connection, getting the action from the peer when it wishes to do something else.
150
+
// once the peers connection ends, it will be unsubscribed from all topics and returned
151
+
for {
152
+
action, err := readAction(peer, time.Millisecond*100)
153
+
if err != nil {
154
+
// if the error is a timeout, it means the peer hasn't sent an action indicating it wishes to do something so sleep
155
+
// for a little bit to allow for other actions to happen on the connection
156
+
var neterr net.Error
157
+
if errors.As(err, &neterr) && neterr.Timeout() {
158
+
time.Sleep(time.Millisecond * 500)
159
+
continue
160
+
}
161
+
162
+
if !errors.Is(err, io.EOF) {
163
+
slog.Error("failed to read action from subscriber", "error", err, "peer", peer.Addr())
164
+
}
165
+
166
+
s.unsubscribePeerFromAllTopics(peer)
167
+
168
+
return
169
+
}
170
+
171
+
switch action {
172
+
case Subscribe:
173
+
s.subscribePeerToTopic(peer)
174
+
case Unsubscribe:
175
+
s.handleUnsubscribe(peer)
176
+
default:
177
+
slog.Error("unknown action for subscriber", "action", action, "peer", peer.Addr())
178
+
writeInvalidAction(peer)
179
+
continue
180
+
}
181
+
}
182
+
}
183
+
184
+
func (s *Server) subscribePeerToTopic(peer *Peer) {
185
+
op := func(conn net.Conn) error {
186
+
// get the topics the peer wishes to subscribe to
187
+
dataLen, err := dataLengthUint32(conn)
188
+
if err != nil {
189
+
slog.Error(err.Error(), "peer", peer.Addr())
190
+
writeStatus(Error, "invalid data length of topics provided", conn)
191
+
return nil
192
+
}
193
+
if dataLen == 0 {
194
+
writeStatus(Error, "data length of topics is 0", conn)
195
+
return nil
196
+
}
197
+
198
+
buf := make([]byte, dataLen)
199
+
_, err = conn.Read(buf)
200
+
if err != nil {
201
+
slog.Error("failed to read subscibers topic data", "error", err, "peer", peer.Addr())
202
+
writeStatus(Error, "failed to read topic data", conn)
203
+
return nil
204
+
}
205
+
206
+
var topics []string
207
+
err = json.Unmarshal(buf, &topics)
208
+
if err != nil {
209
+
slog.Error("failed to unmarshal subscibers topic data", "error", err, "peer", peer.Addr())
210
+
writeStatus(Error, "invalid topic data provided", conn)
211
+
return nil
212
+
}
213
+
214
+
var startAtType StartAtType
215
+
err = binary.Read(conn, binary.BigEndian, &startAtType)
216
+
if err != nil {
217
+
slog.Error(err.Error(), "peer", peer.Addr())
218
+
writeStatus(Error, "invalid start at type provided", conn)
219
+
return nil
220
+
}
221
+
var startAt int
222
+
switch startAtType {
223
+
case From:
224
+
var s uint16
225
+
err = binary.Read(conn, binary.BigEndian, &s)
226
+
if err != nil {
227
+
slog.Error(err.Error(), "peer", peer.Addr())
228
+
writeStatus(Error, "invalid start at value provided", conn)
229
+
return nil
230
+
}
231
+
startAt = int(s)
232
+
case Beginning:
233
+
startAt = 0
234
+
case Current:
235
+
startAt = -1
236
+
default:
237
+
slog.Error("invalid start up type provided", "start up type", startAtType)
238
+
writeStatus(Error, "invalid start up type provided", conn)
239
+
return nil
240
+
}
241
+
242
+
s.subscribeToTopics(peer, topics, startAt)
243
+
writeStatus(Subscribed, "", conn)
244
+
245
+
return nil
246
+
}
247
+
248
+
_ = peer.RunConnOperation(op)
249
+
}
250
+
251
+
func (s *Server) handleUnsubscribe(peer *Peer) {
252
+
slog.Info("handling unsubscriber", "peer", peer.Addr())
253
+
op := func(conn net.Conn) error {
254
+
// get the topics the peer wishes to unsubscribe from
255
+
dataLen, err := dataLengthUint32(conn)
256
+
if err != nil {
257
+
slog.Error(err.Error(), "peer", peer.Addr())
258
+
writeStatus(Error, "invalid data length of topics provided", conn)
259
+
return nil
260
+
}
261
+
if dataLen == 0 {
262
+
writeStatus(Error, "data length of topics is 0", conn)
263
+
return nil
264
+
}
265
+
266
+
buf := make([]byte, dataLen)
267
+
_, err = conn.Read(buf)
268
+
if err != nil {
269
+
slog.Error("failed to read subscibers topic data", "error", err, "peer", peer.Addr())
270
+
writeStatus(Error, "failed to read topic data", conn)
271
+
return nil
272
+
}
273
+
274
+
var topics []string
275
+
err = json.Unmarshal(buf, &topics)
276
+
if err != nil {
277
+
slog.Error("failed to unmarshal subscibers topic data", "error", err, "peer", peer.Addr())
278
+
writeStatus(Error, "invalid topic data provided", conn)
279
+
return nil
280
+
}
281
+
282
+
s.unsubscribeToTopics(peer, topics)
283
+
writeStatus(Unsubscribed, "", conn)
284
+
285
+
return nil
286
+
}
287
+
288
+
_ = peer.RunConnOperation(op)
289
+
}
290
+
291
+
func (s *Server) handlePublish(peer *Peer) {
292
+
slog.Info("handling publisher", "peer", peer.Addr())
293
+
for {
294
+
op := func(conn net.Conn) error {
295
+
topicDataLen, err := dataLengthUint16(conn)
296
+
if err != nil {
297
+
if errors.Is(err, io.EOF) {
298
+
return nil
299
+
}
300
+
slog.Error("failed to read data length", "error", err, "peer", peer.Addr())
301
+
writeStatus(Error, "invalid data length of data provided", conn)
302
+
return nil
303
+
}
304
+
if topicDataLen == 0 {
305
+
return nil
306
+
}
307
+
topicBuf := make([]byte, topicDataLen)
308
+
_, err = conn.Read(topicBuf)
309
+
if err != nil {
310
+
slog.Error("failed to read topic from peer", "error", err, "peer", peer.Addr())
311
+
writeStatus(Error, "failed to read topic", conn)
312
+
return nil
313
+
}
314
+
315
+
topicStr := string(topicBuf)
316
+
if !strings.HasPrefix(topicStr, "topic:") {
317
+
slog.Error("topic data does not contain topic prefix", "peer", peer.Addr())
318
+
writeStatus(Error, "topic data does not contain 'topic:' prefix", conn)
319
+
return nil
320
+
}
321
+
topicStr = strings.TrimPrefix(topicStr, "topic:")
322
+
323
+
msgDataLen, err := dataLengthUint32(conn)
324
+
if err != nil {
325
+
slog.Error(err.Error(), "peer", peer.Addr())
326
+
writeStatus(Error, "invalid data length of data provided", conn)
327
+
return nil
328
+
}
329
+
if msgDataLen == 0 {
330
+
return nil
331
+
}
332
+
333
+
dataBuf := make([]byte, msgDataLen)
334
+
_, err = conn.Read(dataBuf)
335
+
if err != nil {
336
+
slog.Error("failed to read data from peer", "error", err, "peer", peer.Addr())
337
+
writeStatus(Error, "failed to read data", conn)
338
+
return nil
339
+
}
340
+
341
+
topic := s.getTopic(topicStr)
342
+
if topic == nil {
343
+
topic = newTopic(topicStr)
344
+
s.topics[topicStr] = topic
345
+
}
346
+
347
+
message := internal.NewMessage(dataBuf)
348
+
349
+
err = topic.sendMessageToSubscribers(message)
350
+
if err != nil {
351
+
slog.Error("failed to send message to subscribers", "error", err, "peer", peer.Addr())
352
+
writeStatus(Error, "failed to send message to subscribers", conn)
353
+
return nil
354
+
}
355
+
356
+
return nil
357
+
}
358
+
359
+
_ = peer.RunConnOperation(op)
360
+
}
361
+
}
362
+
363
+
func (s *Server) subscribeToTopics(peer *Peer, topics []string, startAt int) {
364
+
slog.Info("subscribing peer to topics", "topics", topics, "peer", peer.Addr())
365
+
for _, topic := range topics {
366
+
s.addSubsciberToTopic(topic, peer, startAt)
367
+
}
368
+
}
369
+
370
+
func (s *Server) addSubsciberToTopic(topicName string, peer *Peer, startAt int) {
371
+
s.mu.Lock()
372
+
defer s.mu.Unlock()
373
+
374
+
t, ok := s.topics[topicName]
375
+
if !ok {
376
+
t = newTopic(topicName)
377
+
}
378
+
379
+
t.mu.Lock()
380
+
t.subscriptions[peer.Addr()] = newSubscriber(peer, t, s.ackDelay, s.ackTimeout, startAt)
381
+
t.mu.Unlock()
382
+
383
+
s.topics[topicName] = t
384
+
}
385
+
386
+
func (s *Server) unsubscribeToTopics(peer *Peer, topics []string) {
387
+
slog.Info("unsubscribing peer from topics", "topics", topics, "peer", peer.Addr())
388
+
for _, topic := range topics {
389
+
s.removeSubsciberFromTopic(topic, peer)
390
+
}
391
+
}
392
+
393
+
func (s *Server) removeSubsciberFromTopic(topicName string, peer *Peer) {
394
+
s.mu.Lock()
395
+
defer s.mu.Unlock()
396
+
397
+
t, ok := s.topics[topicName]
398
+
if !ok {
399
+
return
400
+
}
401
+
402
+
sub := t.findSubscription(peer.Addr())
403
+
if sub == nil {
404
+
return
405
+
}
406
+
407
+
sub.unsubscribe()
408
+
t.removeSubscription(peer.Addr())
409
+
}
410
+
411
+
func (s *Server) unsubscribePeerFromAllTopics(peer *Peer) {
412
+
s.mu.Lock()
413
+
defer s.mu.Unlock()
414
+
415
+
for _, t := range s.topics {
416
+
sub := t.findSubscription(peer.Addr())
417
+
if sub == nil {
418
+
return
419
+
}
420
+
421
+
sub.unsubscribe()
422
+
t.removeSubscription(peer.Addr())
423
+
}
424
+
}
425
+
426
+
func (s *Server) getTopic(topicName string) *topic {
427
+
s.mu.Lock()
428
+
defer s.mu.Unlock()
429
+
430
+
if topic, ok := s.topics[topicName]; ok {
431
+
return topic
432
+
}
433
+
434
+
return nil
435
+
}
436
+
437
+
func readAction(peer *Peer, timeout time.Duration) (Action, error) {
438
+
var action Action
439
+
op := func(conn net.Conn) error {
440
+
if timeout > 0 {
441
+
if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil {
442
+
slog.Error("failed to set connection read deadline", "error", err, "peer", peer.Addr())
443
+
}
444
+
defer func() {
445
+
if err := conn.SetReadDeadline(time.Time{}); err != nil {
446
+
slog.Error("failed to reset connection read deadline", "error", err, "peer", peer.Addr())
447
+
}
448
+
}()
449
+
}
450
+
451
+
err := binary.Read(conn, binary.BigEndian, &action)
452
+
if err != nil {
453
+
return err
454
+
}
455
+
return nil
456
+
}
457
+
458
+
err := peer.RunConnOperation(op)
459
+
if err != nil {
460
+
return 0, fmt.Errorf("failed to read action from peer: %w", err)
461
+
}
462
+
463
+
return action, nil
464
+
}
465
+
466
+
func writeInvalidAction(peer *Peer) {
467
+
op := func(conn net.Conn) error {
468
+
writeStatus(Error, "unknown action", conn)
469
+
return nil
470
+
}
471
+
472
+
_ = peer.RunConnOperation(op)
473
+
}
474
+
475
+
func dataLengthUint32(conn net.Conn) (uint32, error) {
476
+
var dataLen uint32
477
+
err := binary.Read(conn, binary.BigEndian, &dataLen)
478
+
if err != nil {
479
+
return 0, err
480
+
}
481
+
return dataLen, nil
482
+
}
483
+
484
+
func dataLengthUint16(conn net.Conn) (uint16, error) {
485
+
var dataLen uint16
486
+
err := binary.Read(conn, binary.BigEndian, &dataLen)
487
+
if err != nil {
488
+
return 0, err
489
+
}
490
+
return dataLen, nil
491
+
}
492
+
493
+
func writeStatus(status Status, message string, conn net.Conn) {
494
+
statusB := make([]byte, 2)
495
+
binary.BigEndian.PutUint16(statusB, uint16(status))
496
+
497
+
headers := statusB
498
+
499
+
if len(message) > 0 {
500
+
sizeB := make([]byte, 2)
501
+
binary.BigEndian.PutUint16(sizeB, uint16(len(message)))
502
+
headers = append(headers, sizeB...)
503
+
}
504
+
505
+
msgBytes := []byte(message)
506
+
_, err := conn.Write(append(headers, msgBytes...))
507
+
if err != nil {
508
+
if !errors.Is(err, syscall.EPIPE) {
509
+
slog.Error("failed to write status to peers connection", "error", err, "peer", conn.RemoteAddr())
510
+
}
511
+
return
512
+
}
513
+
}
+632
internal/server/server_test.go
+632
internal/server/server_test.go
···
1
+
package server
2
+
3
+
import (
4
+
"encoding/binary"
5
+
"encoding/json"
6
+
"fmt"
7
+
"net"
8
+
"testing"
9
+
"time"
10
+
11
+
"github.com/stretchr/testify/assert"
12
+
"github.com/stretchr/testify/require"
13
+
"github.com/willdot/messagebroker/internal/messagestore"
14
+
)
15
+
16
+
const (
17
+
topicA = "topic a"
18
+
topicB = "topic b"
19
+
topicC = "topic c"
20
+
21
+
serverAddr = ":6666"
22
+
23
+
ackDelay = time.Millisecond * 100
24
+
ackTimeout = time.Millisecond * 100
25
+
)
26
+
27
+
func createServer(t *testing.T) *Server {
28
+
srv, err := New(serverAddr, ackDelay, ackTimeout)
29
+
require.NoError(t, err)
30
+
31
+
t.Cleanup(func() {
32
+
_ = srv.Shutdown()
33
+
})
34
+
35
+
return srv
36
+
}
37
+
38
+
func createServerWithExistingTopic(t *testing.T, topicName string) *Server {
39
+
srv := createServer(t)
40
+
srv.topics[topicName] = &topic{
41
+
name: topicName,
42
+
subscriptions: make(map[net.Addr]*subscriber),
43
+
messageStore: messagestore.NewMemoryStore(),
44
+
}
45
+
46
+
return srv
47
+
}
48
+
49
+
func createConnectionAndSubscribe(t *testing.T, topics []string, startAtType StartAtType, startAtIndex int) net.Conn {
50
+
conn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr))
51
+
require.NoError(t, err)
52
+
53
+
subscribeToTopics(t, conn, topics, startAtType, startAtIndex)
54
+
55
+
expectedRes := Subscribed
56
+
57
+
var resp Status
58
+
err = binary.Read(conn, binary.BigEndian, &resp)
59
+
require.NoError(t, err)
60
+
61
+
assert.Equal(t, expectedRes, resp)
62
+
63
+
return conn
64
+
}
65
+
66
+
func sendMessage(t *testing.T, conn net.Conn, topic string, message []byte) {
67
+
topicLenB := make([]byte, 4)
68
+
binary.BigEndian.PutUint32(topicLenB, uint32(len(topic)))
69
+
70
+
headers := topicLenB
71
+
headers = append(headers, []byte(topic)...)
72
+
73
+
messageLenB := make([]byte, 4)
74
+
binary.BigEndian.PutUint32(messageLenB, uint32(len(message)))
75
+
headers = append(headers, messageLenB...)
76
+
77
+
_, err := conn.Write(append(headers, message...))
78
+
require.NoError(t, err)
79
+
}
80
+
81
+
func subscribeToTopics(t *testing.T, conn net.Conn, topics []string, startAtType StartAtType, startAtIndex int) {
82
+
actionB := make([]byte, 2)
83
+
binary.BigEndian.PutUint16(actionB, uint16(Subscribe))
84
+
headers := actionB
85
+
86
+
b, err := json.Marshal(topics)
87
+
require.NoError(t, err)
88
+
89
+
topicNamesB := make([]byte, 4)
90
+
binary.BigEndian.PutUint32(topicNamesB, uint32(len(b)))
91
+
headers = append(headers, topicNamesB...)
92
+
headers = append(headers, b...)
93
+
94
+
startAtTypeB := make([]byte, 2)
95
+
binary.BigEndian.PutUint16(startAtTypeB, uint16(startAtType))
96
+
headers = append(headers, startAtTypeB...)
97
+
98
+
if startAtType == From {
99
+
fromB := make([]byte, 2)
100
+
binary.BigEndian.PutUint16(fromB, uint16(startAtIndex))
101
+
headers = append(headers, fromB...)
102
+
}
103
+
104
+
_, err = conn.Write(headers)
105
+
require.NoError(t, err)
106
+
}
107
+
108
+
func unsubscribetoTopics(t *testing.T, conn net.Conn, topics []string) {
109
+
actionB := make([]byte, 2)
110
+
binary.BigEndian.PutUint16(actionB, uint16(Unsubscribe))
111
+
headers := actionB
112
+
113
+
b, err := json.Marshal(topics)
114
+
require.NoError(t, err)
115
+
116
+
topicNamesB := make([]byte, 4)
117
+
binary.BigEndian.PutUint32(topicNamesB, uint32(len(b)))
118
+
headers = append(headers, topicNamesB...)
119
+
120
+
_, err = conn.Write(append(headers, b...))
121
+
require.NoError(t, err)
122
+
}
123
+
124
+
func TestSubscribeToTopics(t *testing.T) {
125
+
// create a server with an existing topic so we can test subscribing to a new and
126
+
// existing topic
127
+
srv := createServerWithExistingTopic(t, topicA)
128
+
129
+
_ = createConnectionAndSubscribe(t, []string{topicA, topicB}, Current, 0)
130
+
131
+
srv.mu.Lock()
132
+
defer srv.mu.Unlock()
133
+
assert.Len(t, srv.topics, 2)
134
+
assert.Len(t, srv.topics[topicA].subscriptions, 1)
135
+
assert.Len(t, srv.topics[topicB].subscriptions, 1)
136
+
}
137
+
138
+
func TestUnsubscribesFromTopic(t *testing.T) {
139
+
srv := createServerWithExistingTopic(t, topicA)
140
+
141
+
conn := createConnectionAndSubscribe(t, []string{topicA, topicB, topicC}, Current, 0)
142
+
143
+
assert.Len(t, srv.topics, 3)
144
+
145
+
srv.mu.Lock()
146
+
assert.Len(t, srv.topics[topicA].subscriptions, 1)
147
+
assert.Len(t, srv.topics[topicB].subscriptions, 1)
148
+
assert.Len(t, srv.topics[topicC].subscriptions, 1)
149
+
srv.mu.Unlock()
150
+
151
+
topics := []string{topicA, topicB}
152
+
153
+
unsubscribetoTopics(t, conn, topics)
154
+
155
+
expectedRes := Unsubscribed
156
+
157
+
var resp Status
158
+
err := binary.Read(conn, binary.BigEndian, &resp)
159
+
require.NoError(t, err)
160
+
161
+
assert.Equal(t, expectedRes, resp)
162
+
163
+
assert.Len(t, srv.topics, 3)
164
+
165
+
srv.mu.Lock()
166
+
assert.Len(t, srv.topics[topicA].subscriptions, 0)
167
+
assert.Len(t, srv.topics[topicB].subscriptions, 0)
168
+
assert.Len(t, srv.topics[topicC].subscriptions, 1)
169
+
srv.mu.Unlock()
170
+
}
171
+
172
+
func TestSubscriberClosesWithoutUnsubscribing(t *testing.T) {
173
+
srv := createServer(t)
174
+
175
+
conn := createConnectionAndSubscribe(t, []string{topicA, topicB}, Current, 0)
176
+
177
+
assert.Len(t, srv.topics, 2)
178
+
179
+
srv.mu.Lock()
180
+
assert.Len(t, srv.topics[topicA].subscriptions, 1)
181
+
assert.Len(t, srv.topics[topicB].subscriptions, 1)
182
+
srv.mu.Unlock()
183
+
184
+
// close the conn
185
+
err := conn.Close()
186
+
require.NoError(t, err)
187
+
188
+
publisherConn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr))
189
+
require.NoError(t, err)
190
+
191
+
err = binary.Write(publisherConn, binary.BigEndian, Publish)
192
+
require.NoError(t, err)
193
+
194
+
data := []byte("hello world")
195
+
196
+
sendMessage(t, publisherConn, topicA, data)
197
+
198
+
// the timeout for a connection is 100 milliseconds, so we should wait at least this long before checking the unsubscribe
199
+
// TODO: see if theres a better way, but without this, the test is flakey
200
+
time.Sleep(time.Millisecond * 100)
201
+
202
+
assert.Len(t, srv.topics, 2)
203
+
204
+
srv.mu.Lock()
205
+
assert.Len(t, srv.topics[topicA].subscriptions, 0)
206
+
assert.Len(t, srv.topics[topicB].subscriptions, 0)
207
+
srv.mu.Unlock()
208
+
}
209
+
210
+
func TestInvalidAction(t *testing.T) {
211
+
_ = createServer(t)
212
+
213
+
conn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr))
214
+
require.NoError(t, err)
215
+
216
+
err = binary.Write(conn, binary.BigEndian, uint16(99))
217
+
require.NoError(t, err)
218
+
219
+
expectedRes := Error
220
+
221
+
var resp Status
222
+
err = binary.Read(conn, binary.BigEndian, &resp)
223
+
require.NoError(t, err)
224
+
225
+
assert.Equal(t, expectedRes, resp)
226
+
227
+
expectedMessage := "unknown action"
228
+
229
+
var dataLen uint16
230
+
err = binary.Read(conn, binary.BigEndian, &dataLen)
231
+
require.NoError(t, err)
232
+
assert.Equal(t, len(expectedMessage), int(dataLen))
233
+
234
+
buf := make([]byte, dataLen)
235
+
_, err = conn.Read(buf)
236
+
require.NoError(t, err)
237
+
238
+
assert.Equal(t, expectedMessage, string(buf))
239
+
}
240
+
241
+
func TestInvalidTopicDataPublished(t *testing.T) {
242
+
_ = createServer(t)
243
+
244
+
publisherConn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr))
245
+
require.NoError(t, err)
246
+
247
+
err = binary.Write(publisherConn, binary.BigEndian, Publish)
248
+
require.NoError(t, err)
249
+
250
+
// send topic
251
+
topic := topicA
252
+
err = binary.Write(publisherConn, binary.BigEndian, uint32(len(topic)))
253
+
require.NoError(t, err)
254
+
_, err = publisherConn.Write([]byte(topic))
255
+
require.NoError(t, err)
256
+
257
+
expectedRes := Error
258
+
259
+
var resp Status
260
+
err = binary.Read(publisherConn, binary.BigEndian, &resp)
261
+
require.NoError(t, err)
262
+
263
+
assert.Equal(t, expectedRes, resp)
264
+
265
+
expectedMessage := "topic data does not contain 'topic:' prefix"
266
+
267
+
var dataLen uint16
268
+
err = binary.Read(publisherConn, binary.BigEndian, &dataLen)
269
+
require.NoError(t, err)
270
+
assert.Equal(t, len(expectedMessage), int(dataLen))
271
+
272
+
buf := make([]byte, dataLen)
273
+
_, err = publisherConn.Read(buf)
274
+
require.NoError(t, err)
275
+
276
+
assert.Equal(t, expectedMessage, string(buf))
277
+
}
278
+
279
+
func TestSendsDataToTopicSubscribers(t *testing.T) {
280
+
_ = createServer(t)
281
+
282
+
subscribers := make([]net.Conn, 0, 10)
283
+
for i := 0; i < 10; i++ {
284
+
subscriberConn := createConnectionAndSubscribe(t, []string{topicA, topicB}, Current, 0)
285
+
286
+
subscribers = append(subscribers, subscriberConn)
287
+
}
288
+
289
+
publisherConn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr))
290
+
require.NoError(t, err)
291
+
292
+
err = binary.Write(publisherConn, binary.BigEndian, Publish)
293
+
require.NoError(t, err)
294
+
295
+
topic := fmt.Sprintf("topic:%s", topicA)
296
+
messageData := "hello world"
297
+
298
+
sendMessage(t, publisherConn, topic, []byte(messageData))
299
+
300
+
// check the subsribers got the data
301
+
for _, conn := range subscribers {
302
+
msg := readMessage(t, conn)
303
+
assert.Equal(t, messageData, string(msg))
304
+
}
305
+
}
306
+
307
+
func TestPublishMultipleTimes(t *testing.T) {
308
+
_ = createServer(t)
309
+
310
+
publisherConn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr))
311
+
require.NoError(t, err)
312
+
313
+
err = binary.Write(publisherConn, binary.BigEndian, Publish)
314
+
require.NoError(t, err)
315
+
316
+
messages := make([]string, 0, 10)
317
+
for i := 0; i < 10; i++ {
318
+
messages = append(messages, fmt.Sprintf("message %d", i))
319
+
}
320
+
321
+
subscribeFinCh := make(chan struct{})
322
+
// create a subscriber that will read messages
323
+
subscriberConn := createConnectionAndSubscribe(t, []string{topicA, topicB}, Current, 0)
324
+
go func() {
325
+
// check subscriber got all messages
326
+
results := make([]string, 0, len(messages))
327
+
for i := 0; i < len(messages); i++ {
328
+
msg := readMessage(t, subscriberConn)
329
+
results = append(results, string(msg))
330
+
}
331
+
332
+
assert.ElementsMatch(t, results, messages)
333
+
334
+
subscribeFinCh <- struct{}{}
335
+
}()
336
+
337
+
topic := fmt.Sprintf("topic:%s", topicA)
338
+
339
+
// send multiple messages
340
+
for _, msg := range messages {
341
+
sendMessage(t, publisherConn, topic, []byte(msg))
342
+
}
343
+
344
+
select {
345
+
case <-subscribeFinCh:
346
+
break
347
+
case <-time.After(time.Second):
348
+
t.Fatal(fmt.Errorf("timed out waiting for subscriber to read messages"))
349
+
}
350
+
}
351
+
352
+
func TestSendsDataToTopicSubscriberNacksThenAcks(t *testing.T) {
353
+
_ = createServer(t)
354
+
355
+
subscriberConn := createConnectionAndSubscribe(t, []string{topicA, topicB}, Current, 0)
356
+
357
+
publisherConn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr))
358
+
require.NoError(t, err)
359
+
360
+
err = binary.Write(publisherConn, binary.BigEndian, Publish)
361
+
require.NoError(t, err)
362
+
363
+
topic := fmt.Sprintf("topic:%s", topicA)
364
+
messageData := "hello world"
365
+
366
+
sendMessage(t, publisherConn, topic, []byte(messageData))
367
+
368
+
// check the subsribers got the data
369
+
readMessage := func(conn net.Conn, ack Action) {
370
+
var topicLen uint16
371
+
err = binary.Read(conn, binary.BigEndian, &topicLen)
372
+
require.NoError(t, err)
373
+
374
+
topicBuf := make([]byte, topicLen)
375
+
_, err = conn.Read(topicBuf)
376
+
require.NoError(t, err)
377
+
assert.Equal(t, topicA, string(topicBuf))
378
+
379
+
var dataLen uint64
380
+
err = binary.Read(conn, binary.BigEndian, &dataLen)
381
+
require.NoError(t, err)
382
+
383
+
buf := make([]byte, dataLen)
384
+
n, err := conn.Read(buf)
385
+
require.NoError(t, err)
386
+
387
+
require.Equal(t, int(dataLen), n)
388
+
389
+
assert.Equal(t, messageData, string(buf))
390
+
391
+
err = binary.Write(conn, binary.BigEndian, ack)
392
+
require.NoError(t, err)
393
+
}
394
+
395
+
// NACK the message and then ack it
396
+
readMessage(subscriberConn, Nack)
397
+
readMessage(subscriberConn, Ack)
398
+
// reading for another message should now timeout but give enough time for the ack delay to kick in
399
+
// should the second read of the message not have been ack'd properly
400
+
var topicLen uint16
401
+
_ = subscriberConn.SetReadDeadline(time.Now().Add(ackDelay + time.Millisecond*100))
402
+
err = binary.Read(subscriberConn, binary.BigEndian, &topicLen)
403
+
require.Error(t, err)
404
+
}
405
+
406
+
func TestSendsDataToTopicSubscriberDoesntAckMessage(t *testing.T) {
407
+
_ = createServer(t)
408
+
409
+
subscriberConn := createConnectionAndSubscribe(t, []string{topicA, topicB}, Current, 0)
410
+
411
+
publisherConn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr))
412
+
require.NoError(t, err)
413
+
414
+
err = binary.Write(publisherConn, binary.BigEndian, Publish)
415
+
require.NoError(t, err)
416
+
417
+
topic := fmt.Sprintf("topic:%s", topicA)
418
+
messageData := "hello world"
419
+
420
+
sendMessage(t, publisherConn, topic, []byte(messageData))
421
+
422
+
// check the subsribers got the data
423
+
readMessage := func(conn net.Conn, ack bool) {
424
+
var topicLen uint16
425
+
err = binary.Read(conn, binary.BigEndian, &topicLen)
426
+
require.NoError(t, err)
427
+
428
+
topicBuf := make([]byte, topicLen)
429
+
_, err = conn.Read(topicBuf)
430
+
require.NoError(t, err)
431
+
assert.Equal(t, topicA, string(topicBuf))
432
+
433
+
var dataLen uint64
434
+
err = binary.Read(conn, binary.BigEndian, &dataLen)
435
+
require.NoError(t, err)
436
+
437
+
buf := make([]byte, dataLen)
438
+
n, err := conn.Read(buf)
439
+
require.NoError(t, err)
440
+
441
+
require.Equal(t, int(dataLen), n)
442
+
443
+
assert.Equal(t, messageData, string(buf))
444
+
445
+
if ack {
446
+
err = binary.Write(conn, binary.BigEndian, Ack)
447
+
require.NoError(t, err)
448
+
return
449
+
}
450
+
}
451
+
452
+
// don't send ack or nack and then ack on the second attempt
453
+
readMessage(subscriberConn, false)
454
+
readMessage(subscriberConn, true)
455
+
456
+
// reading for another message should now timeout but give enough time for the ack delay to kick in
457
+
// should the second read of the message not have been ack'd properly
458
+
var topicLen uint16
459
+
_ = subscriberConn.SetReadDeadline(time.Now().Add(ackDelay + time.Millisecond*100))
460
+
err = binary.Read(subscriberConn, binary.BigEndian, &topicLen)
461
+
require.Error(t, err)
462
+
}
463
+
464
+
func TestSendsDataToTopicSubscriberDeliveryCountTooHighWithNoAck(t *testing.T) {
465
+
_ = createServer(t)
466
+
467
+
subscriberConn := createConnectionAndSubscribe(t, []string{topicA, topicB}, Current, 0)
468
+
469
+
publisherConn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr))
470
+
require.NoError(t, err)
471
+
472
+
err = binary.Write(publisherConn, binary.BigEndian, Publish)
473
+
require.NoError(t, err)
474
+
475
+
topic := fmt.Sprintf("topic:%s", topicA)
476
+
messageData := "hello world"
477
+
478
+
sendMessage(t, publisherConn, topic, []byte(messageData))
479
+
480
+
// check the subsribers got the data
481
+
readMessage := func(conn net.Conn, ack bool) {
482
+
var topicLen uint16
483
+
err = binary.Read(conn, binary.BigEndian, &topicLen)
484
+
require.NoError(t, err)
485
+
486
+
topicBuf := make([]byte, topicLen)
487
+
_, err = conn.Read(topicBuf)
488
+
require.NoError(t, err)
489
+
assert.Equal(t, topicA, string(topicBuf))
490
+
491
+
var dataLen uint64
492
+
err = binary.Read(conn, binary.BigEndian, &dataLen)
493
+
require.NoError(t, err)
494
+
495
+
buf := make([]byte, dataLen)
496
+
n, err := conn.Read(buf)
497
+
require.NoError(t, err)
498
+
499
+
require.Equal(t, int(dataLen), n)
500
+
501
+
assert.Equal(t, messageData, string(buf))
502
+
503
+
if ack {
504
+
err = binary.Write(conn, binary.BigEndian, Ack)
505
+
require.NoError(t, err)
506
+
return
507
+
}
508
+
}
509
+
510
+
// nack the message 5 times
511
+
readMessage(subscriberConn, false)
512
+
readMessage(subscriberConn, false)
513
+
readMessage(subscriberConn, false)
514
+
readMessage(subscriberConn, false)
515
+
readMessage(subscriberConn, false)
516
+
517
+
// reading for the message should now timeout as we have nack'd the message too many times
518
+
var topicLen uint16
519
+
_ = subscriberConn.SetReadDeadline(time.Now().Add(ackDelay + time.Millisecond*100))
520
+
err = binary.Read(subscriberConn, binary.BigEndian, &topicLen)
521
+
require.Error(t, err)
522
+
}
523
+
524
+
func TestSubscribeAndReplaysFromStart(t *testing.T) {
525
+
_ = createServer(t)
526
+
527
+
publisherConn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr))
528
+
require.NoError(t, err)
529
+
530
+
err = binary.Write(publisherConn, binary.BigEndian, Publish)
531
+
require.NoError(t, err)
532
+
533
+
messages := make([]string, 0, 10)
534
+
for i := 0; i < 10; i++ {
535
+
messages = append(messages, fmt.Sprintf("message %d", i))
536
+
}
537
+
538
+
topic := fmt.Sprintf("topic:%s", topicA)
539
+
540
+
for _, msg := range messages {
541
+
sendMessage(t, publisherConn, topic, []byte(msg))
542
+
}
543
+
544
+
// send some messages for topic B as well
545
+
sendMessage(t, publisherConn, fmt.Sprintf("topic:%s", topicB), []byte("topic b message 1"))
546
+
sendMessage(t, publisherConn, fmt.Sprintf("topic:%s", topicB), []byte("topic b message 2"))
547
+
sendMessage(t, publisherConn, fmt.Sprintf("topic:%s", topicB), []byte("topic b message 3"))
548
+
549
+
subscriberConn := createConnectionAndSubscribe(t, []string{topicA}, From, 0)
550
+
results := make([]string, 0, len(messages))
551
+
for i := 0; i < len(messages); i++ {
552
+
msg := readMessage(t, subscriberConn)
553
+
results = append(results, string(msg))
554
+
}
555
+
assert.ElementsMatch(t, results, messages)
556
+
}
557
+
558
+
func TestSubscribeAndReplaysFromIndex(t *testing.T) {
559
+
_ = createServer(t)
560
+
561
+
publisherConn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr))
562
+
require.NoError(t, err)
563
+
564
+
err = binary.Write(publisherConn, binary.BigEndian, Publish)
565
+
require.NoError(t, err)
566
+
567
+
messages := make([]string, 0, 10)
568
+
for i := 0; i < 10; i++ {
569
+
messages = append(messages, fmt.Sprintf("message %d", i))
570
+
}
571
+
572
+
topic := fmt.Sprintf("topic:%s", topicA)
573
+
574
+
// send multiple messages
575
+
for _, msg := range messages {
576
+
sendMessage(t, publisherConn, topic, []byte(msg))
577
+
}
578
+
579
+
// send some messages for topic B as well
580
+
sendMessage(t, publisherConn, fmt.Sprintf("topic:%s", topicB), []byte("topic b message 1"))
581
+
sendMessage(t, publisherConn, fmt.Sprintf("topic:%s", topicB), []byte("topic b message 2"))
582
+
sendMessage(t, publisherConn, fmt.Sprintf("topic:%s", topicB), []byte("topic b message 3"))
583
+
584
+
subscriberConn := createConnectionAndSubscribe(t, []string{topicA}, From, 3)
585
+
586
+
// now that the subscriber has subecribed send another message that should arrive after all the other messages were consumed
587
+
sendMessage(t, publisherConn, topic, []byte("hello there"))
588
+
589
+
results := make([]string, 0, len(messages))
590
+
for i := 0; i < len(messages)-3; i++ {
591
+
msg := readMessage(t, subscriberConn)
592
+
results = append(results, string(msg))
593
+
}
594
+
require.Len(t, results, 7)
595
+
expMessages := make([]string, 0, 7)
596
+
for i, msg := range messages {
597
+
if i < 3 {
598
+
continue
599
+
}
600
+
expMessages = append(expMessages, msg)
601
+
}
602
+
assert.Equal(t, expMessages, results)
603
+
604
+
// now check we can get the message that was sent after the subscription was created
605
+
msg := readMessage(t, subscriberConn)
606
+
assert.Equal(t, "hello there", string(msg))
607
+
}
608
+
609
+
func readMessage(t *testing.T, subscriberConn net.Conn) []byte {
610
+
var topicLen uint16
611
+
err := binary.Read(subscriberConn, binary.BigEndian, &topicLen)
612
+
require.NoError(t, err)
613
+
614
+
topicBuf := make([]byte, topicLen)
615
+
_, err = subscriberConn.Read(topicBuf)
616
+
require.NoError(t, err)
617
+
assert.Equal(t, topicA, string(topicBuf))
618
+
619
+
var dataLen uint64
620
+
err = binary.Read(subscriberConn, binary.BigEndian, &dataLen)
621
+
require.NoError(t, err)
622
+
623
+
buf := make([]byte, dataLen)
624
+
n, err := subscriberConn.Read(buf)
625
+
require.NoError(t, err)
626
+
require.Equal(t, int(dataLen), n)
627
+
628
+
err = binary.Write(subscriberConn, binary.BigEndian, Ack)
629
+
require.NoError(t, err)
630
+
631
+
return buf
632
+
}
+136
internal/server/subscriber.go
+136
internal/server/subscriber.go
···
1
+
package server
2
+
3
+
import (
4
+
"encoding/binary"
5
+
"fmt"
6
+
"log/slog"
7
+
"net"
8
+
"time"
9
+
10
+
"github.com/willdot/messagebroker/internal"
11
+
)
12
+
13
+
type subscriber struct {
14
+
peer *Peer
15
+
topic string
16
+
messages chan internal.Message
17
+
unsubscribeCh chan struct{}
18
+
19
+
ackDelay time.Duration
20
+
ackTimeout time.Duration
21
+
}
22
+
23
+
func newSubscriber(peer *Peer, topic *topic, ackDelay, ackTimeout time.Duration, startAt int) *subscriber {
24
+
s := &subscriber{
25
+
peer: peer,
26
+
topic: topic.name,
27
+
messages: make(chan internal.Message),
28
+
ackDelay: ackDelay,
29
+
ackTimeout: ackTimeout,
30
+
unsubscribeCh: make(chan struct{}, 1),
31
+
}
32
+
33
+
go s.sendMessages()
34
+
35
+
go func() {
36
+
topic.messageStore.ReadFrom(startAt, func(msg internal.Message) {
37
+
select {
38
+
case s.messages <- msg:
39
+
return
40
+
case <-s.unsubscribeCh:
41
+
return
42
+
}
43
+
})
44
+
}()
45
+
46
+
return s
47
+
}
48
+
49
+
func (s *subscriber) sendMessages() {
50
+
for {
51
+
select {
52
+
case <-s.unsubscribeCh:
53
+
return
54
+
case msg := <-s.messages:
55
+
ack, err := s.sendMessage(s.topic, msg)
56
+
if err != nil {
57
+
slog.Error("failed to send to message", "error", err, "peer", s.peer.Addr())
58
+
}
59
+
60
+
if ack {
61
+
continue
62
+
}
63
+
64
+
if msg.DeliveryCount >= 5 {
65
+
slog.Error("max delivery count for message. Dropping", "peer", s.peer.Addr())
66
+
continue
67
+
}
68
+
69
+
msg.DeliveryCount++
70
+
s.addMessage(msg, s.ackDelay)
71
+
}
72
+
}
73
+
}
74
+
75
+
func (s *subscriber) addMessage(msg internal.Message, delay time.Duration) {
76
+
go func() {
77
+
timer := time.NewTimer(delay)
78
+
defer timer.Stop()
79
+
80
+
select {
81
+
case <-s.unsubscribeCh:
82
+
return
83
+
case <-timer.C:
84
+
s.messages <- msg
85
+
}
86
+
}()
87
+
}
88
+
89
+
func (s *subscriber) sendMessage(topic string, msg internal.Message) (bool, error) {
90
+
var ack bool
91
+
op := func(conn net.Conn) error {
92
+
topicB := make([]byte, 2)
93
+
binary.BigEndian.PutUint16(topicB, uint16(len(topic)))
94
+
95
+
headers := topicB
96
+
headers = append(headers, []byte(topic)...)
97
+
98
+
// TODO: if message is empty, return error?
99
+
dataLenB := make([]byte, 8)
100
+
binary.BigEndian.PutUint64(dataLenB, uint64(len(msg.Data)))
101
+
headers = append(headers, dataLenB...)
102
+
103
+
_, err := conn.Write(append(headers, msg.Data...))
104
+
if err != nil {
105
+
return fmt.Errorf("failed to write to peer: %w", err)
106
+
}
107
+
108
+
if err := conn.SetReadDeadline(time.Now().Add(s.ackTimeout)); err != nil {
109
+
slog.Error("failed to set connection read deadline", "error", err, "peer", s.peer.Addr())
110
+
}
111
+
defer func() {
112
+
if err := conn.SetReadDeadline(time.Time{}); err != nil {
113
+
slog.Error("failed to reset connection read deadline", "error", err, "peer", s.peer.Addr())
114
+
}
115
+
}()
116
+
var ackRes Action
117
+
err = binary.Read(conn, binary.BigEndian, &ackRes)
118
+
if err != nil {
119
+
return fmt.Errorf("failed to read ack from peer: %w", err)
120
+
}
121
+
122
+
if ackRes == Ack {
123
+
ack = true
124
+
}
125
+
126
+
return nil
127
+
}
128
+
129
+
err := s.peer.RunConnOperation(op)
130
+
131
+
return ack, err
132
+
}
133
+
134
+
func (s *subscriber) unsubscribe() {
135
+
close(s.unsubscribeCh)
136
+
}
+62
internal/server/topic.go
+62
internal/server/topic.go
···
1
+
package server
2
+
3
+
import (
4
+
"fmt"
5
+
"net"
6
+
"sync"
7
+
8
+
"github.com/willdot/messagebroker/internal"
9
+
"github.com/willdot/messagebroker/internal/messagestore"
10
+
)
11
+
12
+
type Store interface {
13
+
Write(msg internal.Message) error
14
+
ReadFrom(offset int, handleFunc func(msg internal.Message))
15
+
}
16
+
17
+
type topic struct {
18
+
name string
19
+
subscriptions map[net.Addr]*subscriber
20
+
mu sync.Mutex
21
+
messageStore Store
22
+
}
23
+
24
+
func newTopic(name string) *topic {
25
+
messageStore := messagestore.NewMemoryStore()
26
+
return &topic{
27
+
name: name,
28
+
subscriptions: make(map[net.Addr]*subscriber),
29
+
messageStore: messageStore,
30
+
}
31
+
}
32
+
33
+
func (t *topic) sendMessageToSubscribers(msg internal.Message) error {
34
+
err := t.messageStore.Write(msg)
35
+
if err != nil {
36
+
return fmt.Errorf("failed to write message to store: %w", err)
37
+
}
38
+
39
+
t.mu.Lock()
40
+
subscribers := t.subscriptions
41
+
t.mu.Unlock()
42
+
43
+
for _, subscriber := range subscribers {
44
+
subscriber.addMessage(msg, 0)
45
+
}
46
+
47
+
return nil
48
+
}
49
+
50
+
func (t *topic) findSubscription(addr net.Addr) *subscriber {
51
+
t.mu.Lock()
52
+
defer t.mu.Unlock()
53
+
54
+
return t.subscriptions[addr]
55
+
}
56
+
57
+
func (t *topic) removeSubscription(addr net.Addr) {
58
+
t.mu.Lock()
59
+
defer t.mu.Unlock()
60
+
61
+
delete(t.subscriptions, addr)
62
+
}
-299
server.go
-299
server.go
···
1
-
package messagebroker
2
-
3
-
import (
4
-
"context"
5
-
"encoding/binary"
6
-
"encoding/json"
7
-
"fmt"
8
-
"log/slog"
9
-
"net"
10
-
"strings"
11
-
"sync"
12
-
)
13
-
14
-
// Action represents the type of action that a connection requests to do
15
-
type Action uint8
16
-
17
-
const (
18
-
Subscribe Action = 1
19
-
Unsubscribe Action = 2
20
-
Publish Action = 3
21
-
)
22
-
23
-
type Message struct {
24
-
Topic string `json:"topic"`
25
-
Data []byte `json:"data"`
26
-
}
27
-
28
-
type Server struct {
29
-
addr string
30
-
lis net.Listener
31
-
32
-
mu sync.Mutex
33
-
topics map[string]topic
34
-
}
35
-
36
-
func NewServer(ctx context.Context, addr string) (*Server, error) {
37
-
lis, err := net.Listen("tcp", addr)
38
-
if err != nil {
39
-
return nil, fmt.Errorf("failed to listen: %w", err)
40
-
}
41
-
42
-
srv := &Server{
43
-
lis: lis,
44
-
topics: map[string]topic{},
45
-
}
46
-
47
-
go srv.start(ctx)
48
-
49
-
return srv, nil
50
-
}
51
-
52
-
func (s *Server) Shutdown() error {
53
-
return s.lis.Close()
54
-
}
55
-
56
-
func (s *Server) start(ctx context.Context) {
57
-
for {
58
-
conn, err := s.lis.Accept()
59
-
if err != nil {
60
-
slog.Error("listener failed to accept", "error", err)
61
-
// TODO: see if there's a better way to check for this error
62
-
if strings.Contains(err.Error(), "use of closed network connection") {
63
-
return
64
-
}
65
-
}
66
-
67
-
go s.handleConn(conn)
68
-
}
69
-
}
70
-
71
-
func getActionFromConn(conn net.Conn) (Action, error) {
72
-
var action Action
73
-
err := binary.Read(conn, binary.BigEndian, &action)
74
-
if err != nil {
75
-
return 0, err
76
-
}
77
-
78
-
return action, nil
79
-
}
80
-
81
-
func getDataLengthFromConn(conn net.Conn) (uint32, error) {
82
-
var dataLen uint32
83
-
err := binary.Read(conn, binary.BigEndian, &dataLen)
84
-
if err != nil {
85
-
return 0, fmt.Errorf("failed to read data length from conn: %w", err)
86
-
}
87
-
88
-
return dataLen, nil
89
-
}
90
-
91
-
func (s *Server) handleConn(conn net.Conn) {
92
-
action, err := getActionFromConn(conn)
93
-
if err != nil {
94
-
slog.Error("failed to read action from conn", "error", err, "conn", conn.LocalAddr())
95
-
return
96
-
}
97
-
98
-
switch action {
99
-
case Subscribe:
100
-
s.handleSubscribingConn(conn)
101
-
case Unsubscribe:
102
-
s.handleUnsubscribingConn(conn)
103
-
case Publish:
104
-
s.handlePublisherConn(conn)
105
-
default:
106
-
slog.Error("unknown action", "action", action, "conn", conn.LocalAddr())
107
-
_, _ = conn.Write([]byte("unknown action"))
108
-
}
109
-
}
110
-
111
-
func (s *Server) handleSubscribingConn(conn net.Conn) {
112
-
// subscribe the connection to the topic
113
-
s.subscribeConnToTopic(conn)
114
-
115
-
// keep handling the connection, getting the action from the conection when it wishes to do something else.
116
-
// once the connection ends, it will be unsubscribed from all topics and returned
117
-
for {
118
-
action, err := getActionFromConn(conn)
119
-
if err != nil {
120
-
// TODO: see if there's a way to check if the connection has been ended etc
121
-
slog.Error("failed to read action from subscriber", "error", err, "conn", conn.LocalAddr())
122
-
123
-
s.unsubscribeConnectionFromAllTopics(conn.LocalAddr())
124
-
125
-
return
126
-
}
127
-
128
-
switch action {
129
-
case Subscribe:
130
-
s.subscribeConnToTopic(conn)
131
-
case Unsubscribe:
132
-
s.handleUnsubscribingConn(conn)
133
-
default:
134
-
slog.Error("unknown action for subscriber", "action", action, "conn", conn.LocalAddr())
135
-
continue
136
-
}
137
-
}
138
-
}
139
-
140
-
func (s *Server) subscribeConnToTopic(conn net.Conn) {
141
-
// get the topics the connection wishes to subscribe to
142
-
dataLen, err := getDataLengthFromConn(conn)
143
-
if err != nil {
144
-
slog.Error(err.Error(), "conn", conn.LocalAddr())
145
-
_, _ = conn.Write([]byte("invalid data length of topics provided"))
146
-
return
147
-
}
148
-
if dataLen == 0 {
149
-
_, _ = conn.Write([]byte("data length of topics is 0"))
150
-
return
151
-
}
152
-
153
-
buf := make([]byte, dataLen)
154
-
_, err = conn.Read(buf)
155
-
if err != nil {
156
-
slog.Error("failed to read subscibers topic data", "error", err, "conn", conn.LocalAddr())
157
-
_, _ = conn.Write([]byte("failed to read topic data"))
158
-
return
159
-
}
160
-
161
-
var topics []string
162
-
err = json.Unmarshal(buf, &topics)
163
-
if err != nil {
164
-
slog.Error("failed to unmarshal subscibers topic data", "error", err, "conn", conn.LocalAddr())
165
-
_, _ = conn.Write([]byte("invalid topic data provided"))
166
-
return
167
-
}
168
-
169
-
s.subscribeToTopics(conn, topics)
170
-
_, _ = conn.Write([]byte("subscribed"))
171
-
}
172
-
173
-
func (s *Server) handleUnsubscribingConn(conn net.Conn) {
174
-
// get the topics the connection wishes to unsubscribe from
175
-
dataLen, err := getDataLengthFromConn(conn)
176
-
if err != nil {
177
-
slog.Error(err.Error(), "conn", conn.LocalAddr())
178
-
_, _ = conn.Write([]byte("invalid data length of topics provided"))
179
-
return
180
-
}
181
-
if dataLen == 0 {
182
-
_, _ = conn.Write([]byte("data length of topics is 0"))
183
-
return
184
-
}
185
-
186
-
buf := make([]byte, dataLen)
187
-
_, err = conn.Read(buf)
188
-
if err != nil {
189
-
slog.Error("failed to read subscibers topic data", "error", err, "conn", conn.LocalAddr())
190
-
_, _ = conn.Write([]byte("failed to read topic data"))
191
-
return
192
-
}
193
-
194
-
var topics []string
195
-
err = json.Unmarshal(buf, &topics)
196
-
if err != nil {
197
-
slog.Error("failed to unmarshal subscibers topic data", "error", err, "conn", conn.LocalAddr())
198
-
_, _ = conn.Write([]byte("invalid topic data provided"))
199
-
return
200
-
}
201
-
202
-
s.unsubscribeToTopics(conn, topics)
203
-
204
-
_, _ = conn.Write([]byte("unsubscribed"))
205
-
}
206
-
207
-
func (s *Server) handlePublisherConn(conn net.Conn) {
208
-
dataLen, err := getDataLengthFromConn(conn)
209
-
if err != nil {
210
-
slog.Error(err.Error(), "conn", conn.LocalAddr())
211
-
_, _ = conn.Write([]byte("invalid data length of data provided"))
212
-
return
213
-
}
214
-
if dataLen == 0 {
215
-
return
216
-
}
217
-
218
-
buf := make([]byte, dataLen)
219
-
_, err = conn.Read(buf)
220
-
if err != nil {
221
-
_, _ = conn.Write([]byte("failed to read data"))
222
-
slog.Error("failed to read data from conn", "error", err, "conn", conn.LocalAddr())
223
-
return
224
-
}
225
-
226
-
var msg Message
227
-
err = json.Unmarshal(buf, &msg)
228
-
if err != nil {
229
-
_, _ = conn.Write([]byte("invalid message"))
230
-
slog.Error("failed to unmarshal data to message", "error", err, "conn", conn.LocalAddr())
231
-
return
232
-
}
233
-
234
-
topic := s.getTopic(msg.Topic)
235
-
if topic != nil {
236
-
topic.sendMessageToSubscribers(msg)
237
-
}
238
-
}
239
-
240
-
func (s *Server) subscribeToTopics(conn net.Conn, topics []string) {
241
-
for _, topic := range topics {
242
-
s.addSubsciberToTopic(topic, conn)
243
-
}
244
-
}
245
-
246
-
func (s *Server) addSubsciberToTopic(topicName string, conn net.Conn) {
247
-
s.mu.Lock()
248
-
defer s.mu.Unlock()
249
-
250
-
t, ok := s.topics[topicName]
251
-
if !ok {
252
-
t = newTopic(topicName)
253
-
}
254
-
255
-
t.subscriptions[conn.LocalAddr()] = Subscriber{
256
-
conn: conn,
257
-
currentOffset: 0,
258
-
}
259
-
260
-
s.topics[topicName] = t
261
-
}
262
-
263
-
func (s *Server) unsubscribeToTopics(conn net.Conn, topics []string) {
264
-
for _, topic := range topics {
265
-
s.removeSubsciberFromTopic(topic, conn)
266
-
}
267
-
}
268
-
269
-
func (s *Server) removeSubsciberFromTopic(topicName string, conn net.Conn) {
270
-
s.mu.Lock()
271
-
defer s.mu.Unlock()
272
-
273
-
t, ok := s.topics[topicName]
274
-
if !ok {
275
-
return
276
-
}
277
-
278
-
delete(t.subscriptions, conn.LocalAddr())
279
-
}
280
-
281
-
func (s *Server) unsubscribeConnectionFromAllTopics(addr net.Addr) {
282
-
s.mu.Lock()
283
-
defer s.mu.Unlock()
284
-
285
-
for _, topic := range s.topics {
286
-
delete(topic.subscriptions, addr)
287
-
}
288
-
}
289
-
290
-
func (s *Server) getTopic(topicName string) *topic {
291
-
s.mu.Lock()
292
-
defer s.mu.Unlock()
293
-
294
-
if topic, ok := s.topics[topicName]; ok {
295
-
return &topic
296
-
}
297
-
298
-
return nil
299
-
}
-231
server_test.go
-231
server_test.go
···
1
-
package messagebroker
2
-
3
-
import (
4
-
"context"
5
-
"encoding/binary"
6
-
"encoding/json"
7
-
"net"
8
-
"testing"
9
-
10
-
"github.com/stretchr/testify/assert"
11
-
"github.com/stretchr/testify/require"
12
-
)
13
-
14
-
func createServer(t *testing.T) *Server {
15
-
srv, err := NewServer(context.Background(), ":3000")
16
-
require.NoError(t, err)
17
-
18
-
t.Cleanup(func() {
19
-
srv.Shutdown()
20
-
})
21
-
22
-
return srv
23
-
}
24
-
25
-
func createServerWithExistingTopic(t *testing.T, topicName string) *Server {
26
-
srv := createServer(t)
27
-
srv.topics[topicName] = topic{
28
-
name: topicName,
29
-
subscriptions: make(map[net.Addr]Subscriber),
30
-
}
31
-
32
-
return srv
33
-
}
34
-
35
-
func createConnectionAndSubscribe(t *testing.T, topics []string) net.Conn {
36
-
conn, err := net.Dial("tcp", "localhost:3000")
37
-
require.NoError(t, err)
38
-
39
-
err = binary.Write(conn, binary.BigEndian, Subscribe)
40
-
require.NoError(t, err)
41
-
42
-
rawTopics, err := json.Marshal(topics)
43
-
require.NoError(t, err)
44
-
45
-
err = binary.Write(conn, binary.BigEndian, uint32(len(rawTopics)))
46
-
require.NoError(t, err)
47
-
48
-
_, err = conn.Write(rawTopics)
49
-
require.NoError(t, err)
50
-
51
-
expectedRes := "subscribed"
52
-
53
-
buf := make([]byte, len(expectedRes))
54
-
n, err := conn.Read(buf)
55
-
require.NoError(t, err)
56
-
require.Equal(t, len(expectedRes), n)
57
-
58
-
assert.Equal(t, expectedRes, string(buf))
59
-
60
-
return conn
61
-
}
62
-
63
-
func TestSubscribeToTopics(t *testing.T) {
64
-
// create a server with an existing topic so we can test subscribing to a new and
65
-
// existing topic
66
-
srv := createServerWithExistingTopic(t, "topic a")
67
-
68
-
_ = createConnectionAndSubscribe(t, []string{"topic a", "topic b"})
69
-
70
-
assert.Len(t, srv.topics, 2)
71
-
assert.Len(t, srv.topics["topic a"].subscriptions, 1)
72
-
assert.Len(t, srv.topics["topic b"].subscriptions, 1)
73
-
}
74
-
75
-
func TestUnsubscribesFromTopic(t *testing.T) {
76
-
srv := createServerWithExistingTopic(t, "topic a")
77
-
78
-
conn := createConnectionAndSubscribe(t, []string{"topic a", "topic b", "topic c"})
79
-
80
-
assert.Len(t, srv.topics, 3)
81
-
assert.Len(t, srv.topics["topic a"].subscriptions, 1)
82
-
assert.Len(t, srv.topics["topic b"].subscriptions, 1)
83
-
assert.Len(t, srv.topics["topic c"].subscriptions, 1)
84
-
85
-
err := binary.Write(conn, binary.BigEndian, Unsubscribe)
86
-
require.NoError(t, err)
87
-
88
-
topics := []string{"topic a", "topic b"}
89
-
rawTopics, err := json.Marshal(topics)
90
-
require.NoError(t, err)
91
-
92
-
err = binary.Write(conn, binary.BigEndian, uint32(len(rawTopics)))
93
-
require.NoError(t, err)
94
-
95
-
_, err = conn.Write(rawTopics)
96
-
require.NoError(t, err)
97
-
98
-
expectedRes := "unsubscribed"
99
-
100
-
buf := make([]byte, len(expectedRes))
101
-
n, err := conn.Read(buf)
102
-
require.NoError(t, err)
103
-
require.Equal(t, len(expectedRes), n)
104
-
105
-
assert.Equal(t, expectedRes, string(buf))
106
-
107
-
assert.Len(t, srv.topics, 3)
108
-
assert.Len(t, srv.topics["topic a"].subscriptions, 0)
109
-
assert.Len(t, srv.topics["topic b"].subscriptions, 0)
110
-
assert.Len(t, srv.topics["topic c"].subscriptions, 1)
111
-
}
112
-
113
-
func TestSubscriberClosesWithoutUnsubscribing(t *testing.T) {
114
-
srv := createServer(t)
115
-
116
-
conn := createConnectionAndSubscribe(t, []string{"topic a", "topic b"})
117
-
118
-
assert.Len(t, srv.topics, 2)
119
-
assert.Len(t, srv.topics["topic a"].subscriptions, 1)
120
-
assert.Len(t, srv.topics["topic b"].subscriptions, 1)
121
-
122
-
// close the conn
123
-
err := conn.Close()
124
-
require.NoError(t, err)
125
-
126
-
publisherConn, err := net.Dial("tcp", "localhost:3000")
127
-
require.NoError(t, err)
128
-
129
-
err = binary.Write(publisherConn, binary.BigEndian, Publish)
130
-
require.NoError(t, err)
131
-
132
-
data := []byte("hello world")
133
-
// send data length first
134
-
err = binary.Write(publisherConn, binary.BigEndian, uint32(len(data)))
135
-
require.NoError(t, err)
136
-
n, err := publisherConn.Write(data)
137
-
require.NoError(t, err)
138
-
require.Equal(t, len(data), n)
139
-
140
-
assert.Len(t, srv.topics, 2)
141
-
assert.Len(t, srv.topics["topic a"].subscriptions, 0)
142
-
assert.Len(t, srv.topics["topic b"].subscriptions, 0)
143
-
}
144
-
145
-
func TestInvalidAction(t *testing.T) {
146
-
_ = createServer(t)
147
-
148
-
conn, err := net.Dial("tcp", "localhost:3000")
149
-
require.NoError(t, err)
150
-
151
-
err = binary.Write(conn, binary.BigEndian, uint8(99))
152
-
require.NoError(t, err)
153
-
154
-
expectedRes := "unknown action"
155
-
156
-
buf := make([]byte, len(expectedRes))
157
-
n, err := conn.Read(buf)
158
-
require.NoError(t, err)
159
-
require.Equal(t, len(expectedRes), n)
160
-
161
-
assert.Equal(t, expectedRes, string(buf))
162
-
}
163
-
164
-
func TestInvalidMessagePublished(t *testing.T) {
165
-
_ = createServer(t)
166
-
167
-
publisherConn, err := net.Dial("tcp", "localhost:3000")
168
-
require.NoError(t, err)
169
-
170
-
err = binary.Write(publisherConn, binary.BigEndian, Publish)
171
-
require.NoError(t, err)
172
-
173
-
// send some data
174
-
data := []byte("this isn't wrapped in a message type")
175
-
176
-
// send data length first
177
-
err = binary.Write(publisherConn, binary.BigEndian, uint32(len(data)))
178
-
require.NoError(t, err)
179
-
n, err := publisherConn.Write(data)
180
-
require.NoError(t, err)
181
-
require.Equal(t, len(data), n)
182
-
183
-
buf := make([]byte, 15)
184
-
_, err = publisherConn.Read(buf)
185
-
require.NoError(t, err)
186
-
assert.Equal(t, "invalid message", string(buf))
187
-
}
188
-
189
-
func TestSendsDataToTopicSubscribers(t *testing.T) {
190
-
_ = createServer(t)
191
-
192
-
subscribers := make([]net.Conn, 0, 5)
193
-
for i := 0; i < 5; i++ {
194
-
subscriberConn := createConnectionAndSubscribe(t, []string{"topic a", "topic b"})
195
-
196
-
subscribers = append(subscribers, subscriberConn)
197
-
}
198
-
199
-
publisherConn, err := net.Dial("tcp", "localhost:3000")
200
-
require.NoError(t, err)
201
-
202
-
err = binary.Write(publisherConn, binary.BigEndian, Publish)
203
-
require.NoError(t, err)
204
-
205
-
// send some data
206
-
data := []byte("hello world")
207
-
msg := Message{
208
-
Topic: "topic a",
209
-
Data: data,
210
-
}
211
-
212
-
rawMsg, err := json.Marshal(msg)
213
-
require.NoError(t, err)
214
-
215
-
// send data length first
216
-
err = binary.Write(publisherConn, binary.BigEndian, uint32(len(rawMsg)))
217
-
require.NoError(t, err)
218
-
n, err := publisherConn.Write(rawMsg)
219
-
require.NoError(t, err)
220
-
require.Equal(t, len(rawMsg), n)
221
-
222
-
// check the subsribers got the data
223
-
for _, conn := range subscribers {
224
-
buf := make([]byte, len(data))
225
-
n, err := conn.Read(buf)
226
-
require.NoError(t, err)
227
-
require.Equal(t, len(data), n)
228
-
229
-
assert.Equal(t, data, buf)
230
-
}
231
-
}
-19
subscriber.go
-19
subscriber.go
···
1
-
package messagebroker
2
-
3
-
import (
4
-
"fmt"
5
-
"net"
6
-
)
7
-
8
-
type Subscriber struct {
9
-
conn net.Conn
10
-
currentOffset int
11
-
}
12
-
13
-
func (s *Subscriber) SendMessage(data []byte) error {
14
-
_, err := s.conn.Write(data)
15
-
if err != nil {
16
-
return fmt.Errorf("failed to write to connection: %w", err)
17
-
}
18
-
return nil
19
-
}
-50
topic.go
-50
topic.go
···
1
-
package messagebroker
2
-
3
-
import (
4
-
"log/slog"
5
-
"net"
6
-
"sync"
7
-
)
8
-
9
-
type topic struct {
10
-
name string
11
-
subscriptions map[net.Addr]Subscriber
12
-
mu sync.Mutex
13
-
}
14
-
15
-
func newTopic(name string) topic {
16
-
return topic{
17
-
name: name,
18
-
subscriptions: make(map[net.Addr]Subscriber),
19
-
}
20
-
}
21
-
22
-
func (t *topic) addSubscriber(conn net.Conn) {
23
-
t.mu.Lock()
24
-
defer t.mu.Unlock()
25
-
26
-
slog.Info("adding subscriber", "conn", conn.LocalAddr())
27
-
t.subscriptions[conn.LocalAddr()] = Subscriber{conn: conn}
28
-
}
29
-
30
-
func (t *topic) removeSubscriber(addr net.Addr) {
31
-
t.mu.Lock()
32
-
defer t.mu.Unlock()
33
-
34
-
slog.Info("removing subscriber", "conn", addr)
35
-
delete(t.subscriptions, addr)
36
-
}
37
-
38
-
func (t *topic) sendMessageToSubscribers(msg Message) {
39
-
t.mu.Lock()
40
-
subscribers := t.subscriptions
41
-
t.mu.Unlock()
42
-
43
-
for addr, subscriber := range subscribers {
44
-
err := subscriber.SendMessage(msg.Data)
45
-
if err != nil {
46
-
slog.Error("failed to send to message", "error", err, "conn", addr)
47
-
continue
48
-
}
49
-
}
50
-
}