+51
peer.go
+51
peer.go
···
1
+
package messagebroker
2
+
3
+
import (
4
+
"encoding/binary"
5
+
"fmt"
6
+
"net"
7
+
)
8
+
9
+
type peer struct {
10
+
conn net.Conn
11
+
}
12
+
13
+
func newPeer(conn net.Conn) peer {
14
+
return peer{
15
+
conn: conn,
16
+
}
17
+
}
18
+
19
+
// Read wraps the peers underlying connections Read function to satisfy io.Reader
20
+
func (p *peer) Read(b []byte) (n int, err error) {
21
+
return p.conn.Read(b)
22
+
}
23
+
24
+
// Write wraps the peers underlying connections Write function to satisfy io.Writer
25
+
func (p *peer) Write(b []byte) (n int, err error) {
26
+
return p.conn.Write(b)
27
+
}
28
+
29
+
func (p *peer) addr() net.Addr {
30
+
return p.conn.LocalAddr()
31
+
}
32
+
33
+
func (p *peer) readAction() (Action, error) {
34
+
var action Action
35
+
err := binary.Read(p.conn, binary.BigEndian, &action)
36
+
if err != nil {
37
+
return 0, fmt.Errorf("failed to read action from peer: %w", err)
38
+
}
39
+
40
+
return action, nil
41
+
}
42
+
43
+
func (p *peer) readDataLength() (uint32, error) {
44
+
var dataLen uint32
45
+
err := binary.Read(p.conn, binary.BigEndian, &dataLen)
46
+
if err != nil {
47
+
return 0, fmt.Errorf("failed to read data length from peer: %w", err)
48
+
}
49
+
50
+
return dataLen, nil
51
+
}
+87
-105
server.go
+87
-105
server.go
···
2
2
3
3
import (
4
4
"context"
5
-
"encoding/binary"
6
5
"encoding/json"
7
6
"fmt"
8
7
"log/slog"
···
11
10
"sync"
12
11
)
13
12
14
-
// Action represents the type of action that a connection requests to do
13
+
// Action represents the type of action that a peer requests to do
15
14
type Action uint8
16
15
17
16
const (
···
68
67
}
69
68
}
70
69
71
-
func getActionFromConn(conn net.Conn) (Action, error) {
72
-
var action Action
73
-
err := binary.Read(conn, binary.BigEndian, &action)
74
-
if err != nil {
75
-
return 0, err
76
-
}
77
-
78
-
return action, nil
79
-
}
80
-
81
-
func getDataLengthFromConn(conn net.Conn) (uint32, error) {
82
-
var dataLen uint32
83
-
err := binary.Read(conn, binary.BigEndian, &dataLen)
84
-
if err != nil {
85
-
return 0, fmt.Errorf("failed to read data length from conn: %w", err)
86
-
}
87
-
88
-
return dataLen, nil
89
-
}
90
-
91
70
func (s *Server) handleConn(conn net.Conn) {
92
-
action, err := getActionFromConn(conn)
71
+
peer := newPeer(conn)
72
+
action, err := peer.readAction()
93
73
if err != nil {
94
-
slog.Error("failed to read action from conn", "error", err, "conn", conn.LocalAddr())
74
+
slog.Error("failed to read action from peer", "error", err, "peer", peer.addr())
95
75
return
96
76
}
97
77
98
78
switch action {
99
79
case Subscribe:
100
-
s.handleSubscribingConn(conn)
80
+
s.handleSubscribe(peer)
101
81
case Unsubscribe:
102
-
s.handleUnsubscribingConn(conn)
82
+
s.handleUnsubscribe(peer)
103
83
case Publish:
104
-
s.handlePublisherConn(conn)
84
+
s.handlePublish(peer)
105
85
default:
106
-
slog.Error("unknown action", "action", action, "conn", conn.LocalAddr())
107
-
_, _ = conn.Write([]byte("unknown action"))
86
+
slog.Error("unknown action", "action", action, "peer", peer.addr())
87
+
_, _ = peer.Write([]byte("unknown action"))
108
88
}
109
89
}
110
90
111
-
func (s *Server) handleSubscribingConn(conn net.Conn) {
112
-
// subscribe the connection to the topic
113
-
s.subscribeConnToTopic(conn)
91
+
func (s *Server) handleSubscribe(peer peer) {
92
+
// subscribe the peer to the topic
93
+
s.subscribePeerToTopic(peer)
114
94
115
-
// keep handling the connection, getting the action from the conection when it wishes to do something else.
116
-
// once the connection ends, it will be unsubscribed from all topics and returned
95
+
// keep handling the peers connection, getting the action from the peer when it wishes to do something else.
96
+
// once the peers connection ends, it will be unsubscribed from all topics and returned
117
97
for {
118
-
action, err := getActionFromConn(conn)
98
+
action, err := peer.readAction()
119
99
if err != nil {
120
-
// TODO: see if there's a way to check if the connection has been ended etc
121
-
slog.Error("failed to read action from subscriber", "error", err, "conn", conn.LocalAddr())
100
+
// TODO: see if there's a way to check if the peers connection has been ended etc
101
+
slog.Error("failed to read action from subscriber", "error", err, "peer", peer.addr())
122
102
123
-
s.unsubscribeConnectionFromAllTopics(conn.LocalAddr())
103
+
s.unsubscribePeerFromAllTopics(peer)
124
104
125
105
return
126
106
}
127
107
128
108
switch action {
129
109
case Subscribe:
130
-
s.subscribeConnToTopic(conn)
110
+
s.subscribePeerToTopic(peer)
131
111
case Unsubscribe:
132
-
s.handleUnsubscribingConn(conn)
112
+
s.handleUnsubscribe(peer)
133
113
default:
134
-
slog.Error("unknown action for subscriber", "action", action, "conn", conn.LocalAddr())
114
+
slog.Error("unknown action for subscriber", "action", action, "peer", peer.addr())
135
115
continue
136
116
}
137
117
}
138
118
}
139
119
140
-
func (s *Server) subscribeConnToTopic(conn net.Conn) {
141
-
// get the topics the connection wishes to subscribe to
142
-
dataLen, err := getDataLengthFromConn(conn)
120
+
func (s *Server) subscribePeerToTopic(peer peer) {
121
+
// get the topics the peer wishes to subscribe to
122
+
dataLen, err := peer.readDataLength()
143
123
if err != nil {
144
-
slog.Error(err.Error(), "conn", conn.LocalAddr())
145
-
_, _ = conn.Write([]byte("invalid data length of topics provided"))
124
+
slog.Error(err.Error(), "peer", peer.addr())
125
+
_, _ = peer.Write([]byte("invalid data length of topics provided"))
146
126
return
147
127
}
148
128
if dataLen == 0 {
149
-
_, _ = conn.Write([]byte("data length of topics is 0"))
129
+
_, _ = peer.Write([]byte("data length of topics is 0"))
150
130
return
151
131
}
152
132
153
133
buf := make([]byte, dataLen)
154
-
_, err = conn.Read(buf)
134
+
_, err = peer.Read(buf)
155
135
if err != nil {
156
-
slog.Error("failed to read subscibers topic data", "error", err, "conn", conn.LocalAddr())
157
-
_, _ = conn.Write([]byte("failed to read topic data"))
136
+
slog.Error("failed to read subscibers topic data", "error", err, "peer", peer.addr())
137
+
_, _ = peer.Write([]byte("failed to read topic data"))
158
138
return
159
139
}
160
140
161
141
var topics []string
162
142
err = json.Unmarshal(buf, &topics)
163
143
if err != nil {
164
-
slog.Error("failed to unmarshal subscibers topic data", "error", err, "conn", conn.LocalAddr())
165
-
_, _ = conn.Write([]byte("invalid topic data provided"))
144
+
slog.Error("failed to unmarshal subscibers topic data", "error", err, "peer", peer.addr())
145
+
_, _ = peer.Write([]byte("invalid topic data provided"))
166
146
return
167
147
}
168
148
169
-
s.subscribeToTopics(conn, topics)
170
-
_, _ = conn.Write([]byte("subscribed"))
149
+
s.subscribeToTopics(peer, topics)
150
+
_, _ = peer.Write([]byte("subscribed"))
171
151
}
172
152
173
-
func (s *Server) handleUnsubscribingConn(conn net.Conn) {
174
-
// get the topics the connection wishes to unsubscribe from
175
-
dataLen, err := getDataLengthFromConn(conn)
153
+
func (s *Server) handleUnsubscribe(peer peer) {
154
+
// get the topics the peer wishes to unsubscribe from
155
+
dataLen, err := peer.readDataLength()
176
156
if err != nil {
177
-
slog.Error(err.Error(), "conn", conn.LocalAddr())
178
-
_, _ = conn.Write([]byte("invalid data length of topics provided"))
157
+
slog.Error(err.Error(), "peer", peer.addr())
158
+
_, _ = peer.Write([]byte("invalid data length of topics provided"))
179
159
return
180
160
}
181
161
if dataLen == 0 {
182
-
_, _ = conn.Write([]byte("data length of topics is 0"))
162
+
_, _ = peer.Write([]byte("data length of topics is 0"))
183
163
return
184
164
}
185
165
186
166
buf := make([]byte, dataLen)
187
-
_, err = conn.Read(buf)
167
+
_, err = peer.Read(buf)
188
168
if err != nil {
189
-
slog.Error("failed to read subscibers topic data", "error", err, "conn", conn.LocalAddr())
190
-
_, _ = conn.Write([]byte("failed to read topic data"))
169
+
slog.Error("failed to read subscibers topic data", "error", err, "peer", peer.addr())
170
+
_, _ = peer.Write([]byte("failed to read topic data"))
191
171
return
192
172
}
193
173
194
174
var topics []string
195
175
err = json.Unmarshal(buf, &topics)
196
176
if err != nil {
197
-
slog.Error("failed to unmarshal subscibers topic data", "error", err, "conn", conn.LocalAddr())
198
-
_, _ = conn.Write([]byte("invalid topic data provided"))
177
+
slog.Error("failed to unmarshal subscibers topic data", "error", err, "peer", peer.addr())
178
+
_, _ = peer.Write([]byte("invalid topic data provided"))
199
179
return
200
180
}
201
181
202
-
s.unsubscribeToTopics(conn, topics)
182
+
s.unsubscribeToTopics(peer, topics)
203
183
204
-
_, _ = conn.Write([]byte("unsubscribed"))
184
+
_, _ = peer.Write([]byte("unsubscribed"))
205
185
}
206
186
207
-
func (s *Server) handlePublisherConn(conn net.Conn) {
208
-
dataLen, err := getDataLengthFromConn(conn)
209
-
if err != nil {
210
-
slog.Error(err.Error(), "conn", conn.LocalAddr())
211
-
_, _ = conn.Write([]byte("invalid data length of data provided"))
212
-
return
213
-
}
214
-
if dataLen == 0 {
215
-
return
216
-
}
187
+
func (s *Server) handlePublish(peer peer) {
188
+
for {
189
+
dataLen, err := peer.readDataLength()
190
+
if err != nil {
191
+
slog.Error(err.Error(), "peer", peer.addr())
192
+
_, _ = peer.Write([]byte("invalid data length of data provided"))
193
+
return
194
+
}
195
+
if dataLen == 0 {
196
+
continue
197
+
}
217
198
218
-
buf := make([]byte, dataLen)
219
-
_, err = conn.Read(buf)
220
-
if err != nil {
221
-
_, _ = conn.Write([]byte("failed to read data"))
222
-
slog.Error("failed to read data from conn", "error", err, "conn", conn.LocalAddr())
223
-
return
224
-
}
199
+
buf := make([]byte, dataLen)
200
+
_, err = peer.Read(buf)
201
+
if err != nil {
202
+
_, _ = peer.Write([]byte("failed to read data"))
203
+
slog.Error("failed to read data from peer", "error", err, "peer", peer.addr())
204
+
return
205
+
}
225
206
226
-
var msg Message
227
-
err = json.Unmarshal(buf, &msg)
228
-
if err != nil {
229
-
_, _ = conn.Write([]byte("invalid message"))
230
-
slog.Error("failed to unmarshal data to message", "error", err, "conn", conn.LocalAddr())
231
-
return
232
-
}
207
+
var msg Message
208
+
err = json.Unmarshal(buf, &msg)
209
+
if err != nil {
210
+
_, _ = peer.Write([]byte("invalid message"))
211
+
slog.Error("failed to unmarshal data to message", "error", err, "peer", peer.addr())
212
+
continue
213
+
}
233
214
234
-
topic := s.getTopic(msg.Topic)
235
-
if topic != nil {
236
-
topic.sendMessageToSubscribers(msg)
215
+
topic := s.getTopic(msg.Topic)
216
+
if topic != nil {
217
+
topic.sendMessageToSubscribers(msg)
218
+
}
237
219
}
238
220
}
239
221
240
-
func (s *Server) subscribeToTopics(conn net.Conn, topics []string) {
222
+
func (s *Server) subscribeToTopics(peer peer, topics []string) {
241
223
for _, topic := range topics {
242
-
s.addSubsciberToTopic(topic, conn)
224
+
s.addSubsciberToTopic(topic, peer)
243
225
}
244
226
}
245
227
246
-
func (s *Server) addSubsciberToTopic(topicName string, conn net.Conn) {
228
+
func (s *Server) addSubsciberToTopic(topicName string, peer peer) {
247
229
s.mu.Lock()
248
230
defer s.mu.Unlock()
249
231
···
252
234
t = newTopic(topicName)
253
235
}
254
236
255
-
t.subscriptions[conn.LocalAddr()] = Subscriber{
256
-
conn: conn,
237
+
t.subscriptions[peer.addr()] = Subscriber{
238
+
peer: peer,
257
239
currentOffset: 0,
258
240
}
259
241
260
242
s.topics[topicName] = t
261
243
}
262
244
263
-
func (s *Server) unsubscribeToTopics(conn net.Conn, topics []string) {
245
+
func (s *Server) unsubscribeToTopics(peer peer, topics []string) {
264
246
for _, topic := range topics {
265
-
s.removeSubsciberFromTopic(topic, conn)
247
+
s.removeSubsciberFromTopic(topic, peer)
266
248
}
267
249
}
268
250
269
-
func (s *Server) removeSubsciberFromTopic(topicName string, conn net.Conn) {
251
+
func (s *Server) removeSubsciberFromTopic(topicName string, peer peer) {
270
252
s.mu.Lock()
271
253
defer s.mu.Unlock()
272
254
···
275
257
return
276
258
}
277
259
278
-
delete(t.subscriptions, conn.LocalAddr())
260
+
delete(t.subscriptions, peer.addr())
279
261
}
280
262
281
-
func (s *Server) unsubscribeConnectionFromAllTopics(addr net.Addr) {
263
+
func (s *Server) unsubscribePeerFromAllTopics(peer peer) {
282
264
s.mu.Lock()
283
265
defer s.mu.Unlock()
284
266
285
267
for _, topic := range s.topics {
286
-
delete(topic.subscriptions, addr)
268
+
delete(topic.subscriptions, peer.addr())
287
269
}
288
270
}
289
271
+73
-6
server_test.go
+73
-6
server_test.go
···
4
4
"context"
5
5
"encoding/binary"
6
6
"encoding/json"
7
+
"fmt"
7
8
"net"
8
9
"testing"
10
+
"time"
9
11
10
12
"github.com/stretchr/testify/assert"
11
13
"github.com/stretchr/testify/require"
···
202
204
err = binary.Write(publisherConn, binary.BigEndian, Publish)
203
205
require.NoError(t, err)
204
206
205
-
// send some data
206
-
data := []byte("hello world")
207
+
// send a message
207
208
msg := Message{
208
209
Topic: "topic a",
209
-
Data: data,
210
+
Data: []byte("hello world"),
210
211
}
211
212
212
213
rawMsg, err := json.Marshal(msg)
···
221
222
222
223
// check the subsribers got the data
223
224
for _, conn := range subscribers {
224
-
buf := make([]byte, len(data))
225
+
226
+
var dataLen uint64
227
+
err = binary.Read(conn, binary.BigEndian, &dataLen)
228
+
require.NoError(t, err)
229
+
230
+
buf := make([]byte, dataLen)
225
231
n, err := conn.Read(buf)
226
232
require.NoError(t, err)
227
-
require.Equal(t, len(data), n)
233
+
require.Equal(t, int(dataLen), n)
234
+
235
+
assert.Equal(t, rawMsg, buf)
236
+
}
237
+
}
238
+
239
+
func TestPublishMultipleTimes(t *testing.T) {
240
+
_ = createServer(t)
241
+
242
+
publisherConn, err := net.Dial("tcp", "localhost:3000")
243
+
require.NoError(t, err)
244
+
245
+
err = binary.Write(publisherConn, binary.BigEndian, Publish)
246
+
require.NoError(t, err)
247
+
248
+
messages := make([][]byte, 0, 10)
249
+
for i := 0; i < 10; i++ {
250
+
msg := Message{
251
+
Topic: "topic a",
252
+
Data: []byte(fmt.Sprintf("message %d", i)),
253
+
}
254
+
255
+
rawMsg, err := json.Marshal(msg)
256
+
require.NoError(t, err)
257
+
258
+
messages = append(messages, rawMsg)
259
+
}
260
+
261
+
subscribeFinCh := make(chan struct{})
262
+
// create a subscriber that will read messages
263
+
subscriberConn := createConnectionAndSubscribe(t, []string{"topic a", "topic b"})
264
+
go func() {
265
+
// check subscriber got all messages
266
+
for _, msg := range messages {
267
+
var dataLen uint64
268
+
err = binary.Read(subscriberConn, binary.BigEndian, &dataLen)
269
+
require.NoError(t, err)
270
+
271
+
buf := make([]byte, dataLen)
272
+
n, err := subscriberConn.Read(buf)
273
+
require.NoError(t, err)
274
+
require.Equal(t, int(dataLen), n)
228
275
229
-
assert.Equal(t, data, buf)
276
+
assert.Equal(t, msg, buf)
277
+
}
278
+
279
+
subscribeFinCh <- struct{}{}
280
+
}()
281
+
282
+
// send multiple messages
283
+
for _, msg := range messages {
284
+
// send data length first
285
+
err = binary.Write(publisherConn, binary.BigEndian, uint32(len(msg)))
286
+
require.NoError(t, err)
287
+
n, err := publisherConn.Write(msg)
288
+
require.NoError(t, err)
289
+
require.Equal(t, len(msg), n)
290
+
}
291
+
292
+
select {
293
+
case <-subscribeFinCh:
294
+
break
295
+
case <-time.After(time.Second):
296
+
t.Fatal(fmt.Errorf("timed out waiting for subscriber to read messages"))
230
297
}
231
298
}
+12
-5
subscriber.go
+12
-5
subscriber.go
···
1
1
package messagebroker
2
2
3
3
import (
4
+
"encoding/binary"
4
5
"fmt"
5
-
"net"
6
6
)
7
7
8
8
type Subscriber struct {
9
-
conn net.Conn
9
+
peer peer
10
10
currentOffset int
11
11
}
12
12
13
-
func (s *Subscriber) SendMessage(data []byte) error {
14
-
_, err := s.conn.Write(data)
13
+
func (s *Subscriber) SendMessage(msg []byte) error {
14
+
dataLen := uint64(len(msg))
15
+
16
+
err := binary.Write(&s.peer, binary.BigEndian, dataLen)
17
+
if err != nil {
18
+
return fmt.Errorf("failed to send data length: %w", err)
19
+
}
20
+
21
+
_, err = s.peer.Write(msg)
15
22
if err != nil {
16
-
return fmt.Errorf("failed to write to connection: %w", err)
23
+
return fmt.Errorf("failed to write to peer: %w", err)
17
24
}
18
25
return nil
19
26
}
+9
-11
topic.go
+9
-11
topic.go
···
1
1
package messagebroker
2
2
3
3
import (
4
+
"encoding/json"
4
5
"log/slog"
5
6
"net"
6
7
"sync"
···
19
20
}
20
21
}
21
22
22
-
func (t *topic) addSubscriber(conn net.Conn) {
23
-
t.mu.Lock()
24
-
defer t.mu.Unlock()
25
-
26
-
slog.Info("adding subscriber", "conn", conn.LocalAddr())
27
-
t.subscriptions[conn.LocalAddr()] = Subscriber{conn: conn}
28
-
}
29
-
30
23
func (t *topic) removeSubscriber(addr net.Addr) {
31
24
t.mu.Lock()
32
25
defer t.mu.Unlock()
33
26
34
-
slog.Info("removing subscriber", "conn", addr)
27
+
slog.Info("removing subscriber", "peer", addr)
35
28
delete(t.subscriptions, addr)
36
29
}
37
30
···
40
33
subscribers := t.subscriptions
41
34
t.mu.Unlock()
42
35
36
+
msgData, err := json.Marshal(msg)
37
+
if err != nil {
38
+
slog.Error("failed to marshal message for subscribers", "error", err)
39
+
}
40
+
43
41
for addr, subscriber := range subscribers {
44
-
err := subscriber.SendMessage(msg.Data)
42
+
err := subscriber.SendMessage(msgData)
45
43
if err != nil {
46
-
slog.Error("failed to send to message", "error", err, "conn", addr)
44
+
slog.Error("failed to send to message", "error", err, "peer", addr)
47
45
continue
48
46
}
49
47
}