+1
-21
.gitignore
+1
-21
.gitignore
···
1
-
# If you prefer the allow list template instead of the deny list, see community template:
2
-
# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore
3
-
#
4
-
# Binaries for programs and plugins
5
-
*.exe
6
-
*.exe~
7
-
*.dll
8
-
*.so
9
-
*.dylib
10
-
11
-
# Test binary, built with `go test -c`
12
-
*.test
13
-
14
-
# Output of the go coverage tool, specifically when used with LiteIDE
15
-
*.out
16
-
17
-
# Dependency directories (remove the comment below to include it)
18
-
# vendor/
19
-
20
-
# Go workspace file
21
-
go.work
1
+
.DS_STORE
+11
go.mod
+11
go.mod
+10
go.sum
+10
go.sum
···
1
+
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
2
+
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
3
+
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
4
+
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
5
+
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
6
+
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
7
+
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
8
+
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
9
+
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
10
+
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+299
server.go
+299
server.go
···
1
+
package main
2
+
3
+
import (
4
+
"context"
5
+
"encoding/binary"
6
+
"encoding/json"
7
+
"fmt"
8
+
"log/slog"
9
+
"net"
10
+
"strings"
11
+
"sync"
12
+
)
13
+
14
+
// Action represents the type of action that a connection requests to do
15
+
type Action uint8
16
+
17
+
const (
18
+
Subscribe Action = 1
19
+
Unsubscribe Action = 2
20
+
Publish Action = 3
21
+
)
22
+
23
+
type message struct {
24
+
Topic string `json:"topic"`
25
+
Data []byte `json:"data"`
26
+
}
27
+
28
+
type Server struct {
29
+
addr string
30
+
lis net.Listener
31
+
32
+
mu sync.Mutex
33
+
topics map[string]topic
34
+
}
35
+
36
+
func NewServer(ctx context.Context, addr string) (*Server, error) {
37
+
lis, err := net.Listen("tcp", addr)
38
+
if err != nil {
39
+
return nil, fmt.Errorf("failed to listen: %w", err)
40
+
}
41
+
42
+
srv := &Server{
43
+
lis: lis,
44
+
topics: map[string]topic{},
45
+
}
46
+
47
+
go srv.start(ctx)
48
+
49
+
return srv, nil
50
+
}
51
+
52
+
func (s *Server) Shutdown() error {
53
+
return s.lis.Close()
54
+
}
55
+
56
+
func (s *Server) start(ctx context.Context) {
57
+
for {
58
+
conn, err := s.lis.Accept()
59
+
if err != nil {
60
+
slog.Error("listener failed to accept", "error", err)
61
+
// TODO: see if there's a better way to check for this error
62
+
if strings.Contains(err.Error(), "use of closed network connection") {
63
+
return
64
+
}
65
+
}
66
+
67
+
go s.handleConn(conn)
68
+
}
69
+
}
70
+
71
+
func getActionFromConn(conn net.Conn) (Action, error) {
72
+
var action Action
73
+
err := binary.Read(conn, binary.BigEndian, &action)
74
+
if err != nil {
75
+
return 0, err
76
+
}
77
+
78
+
return action, nil
79
+
}
80
+
81
+
func getDataLengthFromConn(conn net.Conn) (uint32, error) {
82
+
var dataLen uint32
83
+
err := binary.Read(conn, binary.BigEndian, &dataLen)
84
+
if err != nil {
85
+
return 0, fmt.Errorf("failed to read data length from conn: %w", err)
86
+
}
87
+
88
+
return dataLen, nil
89
+
}
90
+
91
+
func (s *Server) handleConn(conn net.Conn) {
92
+
action, err := getActionFromConn(conn)
93
+
if err != nil {
94
+
slog.Error("failed to read action from conn", "error", err, "conn", conn.LocalAddr())
95
+
return
96
+
}
97
+
98
+
switch action {
99
+
case Subscribe:
100
+
s.handleSubscribingConn(conn)
101
+
case Unsubscribe:
102
+
s.handleUnsubscribingConn(conn)
103
+
case Publish:
104
+
s.handlePublisherConn(conn)
105
+
default:
106
+
slog.Error("unknown action", "action", action, "conn", conn.LocalAddr())
107
+
_, _ = conn.Write([]byte("unknown action"))
108
+
}
109
+
}
110
+
111
+
func (s *Server) handleSubscribingConn(conn net.Conn) {
112
+
// subscribe the connection to the topic
113
+
s.subscribeConnToTopic(conn)
114
+
115
+
// keep handling the connection, getting the action from the conection when it wishes to do something else.
116
+
// once the connection ends, it will be unsubscribed from all topics and returned
117
+
for {
118
+
action, err := getActionFromConn(conn)
119
+
if err != nil {
120
+
// TODO: see if there's a way to check if the connection has been ended etc
121
+
slog.Error("failed to read action from subscriber", "error", err, "conn", conn.LocalAddr())
122
+
123
+
s.unsubscribeConnectionFromAllTopics(conn.LocalAddr())
124
+
125
+
return
126
+
}
127
+
128
+
switch action {
129
+
case Subscribe:
130
+
s.subscribeConnToTopic(conn)
131
+
case Unsubscribe:
132
+
s.handleUnsubscribingConn(conn)
133
+
default:
134
+
slog.Error("unknown action for subscriber", "action", action, "conn", conn.LocalAddr())
135
+
continue
136
+
}
137
+
}
138
+
}
139
+
140
+
func (s *Server) subscribeConnToTopic(conn net.Conn) {
141
+
// get the topics the connection wishes to subscribe to
142
+
dataLen, err := getDataLengthFromConn(conn)
143
+
if err != nil {
144
+
slog.Error(err.Error(), "conn", conn.LocalAddr())
145
+
_, _ = conn.Write([]byte("invalid data length of topics provided"))
146
+
return
147
+
}
148
+
if dataLen == 0 {
149
+
_, _ = conn.Write([]byte("data length of topics is 0"))
150
+
return
151
+
}
152
+
153
+
buf := make([]byte, dataLen)
154
+
_, err = conn.Read(buf)
155
+
if err != nil {
156
+
slog.Error("failed to read subscibers topic data", "error", err, "conn", conn.LocalAddr())
157
+
_, _ = conn.Write([]byte("failed to read topic data"))
158
+
return
159
+
}
160
+
161
+
var topics []string
162
+
err = json.Unmarshal(buf, &topics)
163
+
if err != nil {
164
+
slog.Error("failed to unmarshal subscibers topic data", "error", err, "conn", conn.LocalAddr())
165
+
_, _ = conn.Write([]byte("invalid topic data provided"))
166
+
return
167
+
}
168
+
169
+
s.subscribeToTopics(conn, topics)
170
+
_, _ = conn.Write([]byte("subscribed"))
171
+
}
172
+
173
+
func (s *Server) handleUnsubscribingConn(conn net.Conn) {
174
+
// get the topics the connection wishes to unsubscribe from
175
+
dataLen, err := getDataLengthFromConn(conn)
176
+
if err != nil {
177
+
slog.Error(err.Error(), "conn", conn.LocalAddr())
178
+
_, _ = conn.Write([]byte("invalid data length of topics provided"))
179
+
return
180
+
}
181
+
if dataLen == 0 {
182
+
_, _ = conn.Write([]byte("data length of topics is 0"))
183
+
return
184
+
}
185
+
186
+
buf := make([]byte, dataLen)
187
+
_, err = conn.Read(buf)
188
+
if err != nil {
189
+
slog.Error("failed to read subscibers topic data", "error", err, "conn", conn.LocalAddr())
190
+
_, _ = conn.Write([]byte("failed to read topic data"))
191
+
return
192
+
}
193
+
194
+
var topics []string
195
+
err = json.Unmarshal(buf, &topics)
196
+
if err != nil {
197
+
slog.Error("failed to unmarshal subscibers topic data", "error", err, "conn", conn.LocalAddr())
198
+
_, _ = conn.Write([]byte("invalid topic data provided"))
199
+
return
200
+
}
201
+
202
+
s.unsubscribeToTopics(conn, topics)
203
+
204
+
_, _ = conn.Write([]byte("unsubscribed"))
205
+
}
206
+
207
+
func (s *Server) handlePublisherConn(conn net.Conn) {
208
+
dataLen, err := getDataLengthFromConn(conn)
209
+
if err != nil {
210
+
slog.Error(err.Error(), "conn", conn.LocalAddr())
211
+
_, _ = conn.Write([]byte("invalid data length of data provided"))
212
+
return
213
+
}
214
+
if dataLen == 0 {
215
+
return
216
+
}
217
+
218
+
buf := make([]byte, dataLen)
219
+
_, err = conn.Read(buf)
220
+
if err != nil {
221
+
_, _ = conn.Write([]byte("failed to read data"))
222
+
slog.Error("failed to read data from conn", "error", err, "conn", conn.LocalAddr())
223
+
return
224
+
}
225
+
226
+
var msg message
227
+
err = json.Unmarshal(buf, &msg)
228
+
if err != nil {
229
+
_, _ = conn.Write([]byte("invalid message"))
230
+
slog.Error("failed to unmarshal data to message", "error", err, "conn", conn.LocalAddr())
231
+
return
232
+
}
233
+
234
+
topic := s.getTopic(msg.Topic)
235
+
if topic != nil {
236
+
topic.sendMessageToSubscribers(msg)
237
+
}
238
+
}
239
+
240
+
func (s *Server) subscribeToTopics(conn net.Conn, topics []string) {
241
+
for _, topic := range topics {
242
+
s.addSubsciberToTopic(topic, conn)
243
+
}
244
+
}
245
+
246
+
func (s *Server) addSubsciberToTopic(topicName string, conn net.Conn) {
247
+
s.mu.Lock()
248
+
defer s.mu.Unlock()
249
+
250
+
t, ok := s.topics[topicName]
251
+
if !ok {
252
+
t = newTopic(topicName)
253
+
}
254
+
255
+
t.subscriptions[conn.LocalAddr()] = Subscriber{
256
+
conn: conn,
257
+
currentOffset: 0,
258
+
}
259
+
260
+
s.topics[topicName] = t
261
+
}
262
+
263
+
func (s *Server) unsubscribeToTopics(conn net.Conn, topics []string) {
264
+
for _, topic := range topics {
265
+
s.removeSubsciberFromTopic(topic, conn)
266
+
}
267
+
}
268
+
269
+
func (s *Server) removeSubsciberFromTopic(topicName string, conn net.Conn) {
270
+
s.mu.Lock()
271
+
defer s.mu.Unlock()
272
+
273
+
t, ok := s.topics[topicName]
274
+
if !ok {
275
+
return
276
+
}
277
+
278
+
delete(t.subscriptions, conn.LocalAddr())
279
+
}
280
+
281
+
func (s *Server) unsubscribeConnectionFromAllTopics(addr net.Addr) {
282
+
s.mu.Lock()
283
+
defer s.mu.Unlock()
284
+
285
+
for _, topic := range s.topics {
286
+
delete(topic.subscriptions, addr)
287
+
}
288
+
}
289
+
290
+
func (s *Server) getTopic(topicName string) *topic {
291
+
s.mu.Lock()
292
+
defer s.mu.Unlock()
293
+
294
+
if topic, ok := s.topics[topicName]; ok {
295
+
return &topic
296
+
}
297
+
298
+
return nil
299
+
}
+231
server_test.go
+231
server_test.go
···
1
+
package main
2
+
3
+
import (
4
+
"context"
5
+
"encoding/binary"
6
+
"encoding/json"
7
+
"net"
8
+
"testing"
9
+
10
+
"github.com/stretchr/testify/assert"
11
+
"github.com/stretchr/testify/require"
12
+
)
13
+
14
+
func createServer(t *testing.T) *Server {
15
+
srv, err := NewServer(context.Background(), ":3000")
16
+
require.NoError(t, err)
17
+
18
+
t.Cleanup(func() {
19
+
srv.Shutdown()
20
+
})
21
+
22
+
return srv
23
+
}
24
+
25
+
func createServerWithExistingTopic(t *testing.T, topicName string) *Server {
26
+
srv := createServer(t)
27
+
srv.topics[topicName] = topic{
28
+
name: topicName,
29
+
subscriptions: make(map[net.Addr]Subscriber),
30
+
}
31
+
32
+
return srv
33
+
}
34
+
35
+
func createConnectionAndSubscribe(t *testing.T, topics []string) net.Conn {
36
+
conn, err := net.Dial("tcp", "localhost:3000")
37
+
require.NoError(t, err)
38
+
39
+
err = binary.Write(conn, binary.BigEndian, Subscribe)
40
+
require.NoError(t, err)
41
+
42
+
rawTopics, err := json.Marshal(topics)
43
+
require.NoError(t, err)
44
+
45
+
err = binary.Write(conn, binary.BigEndian, uint32(len(rawTopics)))
46
+
require.NoError(t, err)
47
+
48
+
_, err = conn.Write(rawTopics)
49
+
require.NoError(t, err)
50
+
51
+
expectedRes := "subscribed"
52
+
53
+
buf := make([]byte, len(expectedRes))
54
+
n, err := conn.Read(buf)
55
+
require.NoError(t, err)
56
+
require.Equal(t, len(expectedRes), n)
57
+
58
+
assert.Equal(t, expectedRes, string(buf))
59
+
60
+
return conn
61
+
}
62
+
63
+
func TestSubscribeToTopics(t *testing.T) {
64
+
// create a server with an existing topic so we can test subscribing to a new and
65
+
// existing topic
66
+
srv := createServerWithExistingTopic(t, "topic a")
67
+
68
+
_ = createConnectionAndSubscribe(t, []string{"topic a", "topic b"})
69
+
70
+
assert.Len(t, srv.topics, 2)
71
+
assert.Len(t, srv.topics["topic a"].subscriptions, 1)
72
+
assert.Len(t, srv.topics["topic b"].subscriptions, 1)
73
+
}
74
+
75
+
func TestUnsubscribesFromTopic(t *testing.T) {
76
+
srv := createServerWithExistingTopic(t, "topic a")
77
+
78
+
conn := createConnectionAndSubscribe(t, []string{"topic a", "topic b", "topic c"})
79
+
80
+
assert.Len(t, srv.topics, 3)
81
+
assert.Len(t, srv.topics["topic a"].subscriptions, 1)
82
+
assert.Len(t, srv.topics["topic b"].subscriptions, 1)
83
+
assert.Len(t, srv.topics["topic c"].subscriptions, 1)
84
+
85
+
err := binary.Write(conn, binary.BigEndian, Unsubscribe)
86
+
require.NoError(t, err)
87
+
88
+
topics := []string{"topic a", "topic b"}
89
+
rawTopics, err := json.Marshal(topics)
90
+
require.NoError(t, err)
91
+
92
+
err = binary.Write(conn, binary.BigEndian, uint32(len(rawTopics)))
93
+
require.NoError(t, err)
94
+
95
+
_, err = conn.Write(rawTopics)
96
+
require.NoError(t, err)
97
+
98
+
expectedRes := "unsubscribed"
99
+
100
+
buf := make([]byte, len(expectedRes))
101
+
n, err := conn.Read(buf)
102
+
require.NoError(t, err)
103
+
require.Equal(t, len(expectedRes), n)
104
+
105
+
assert.Equal(t, expectedRes, string(buf))
106
+
107
+
assert.Len(t, srv.topics, 3)
108
+
assert.Len(t, srv.topics["topic a"].subscriptions, 0)
109
+
assert.Len(t, srv.topics["topic b"].subscriptions, 0)
110
+
assert.Len(t, srv.topics["topic c"].subscriptions, 1)
111
+
}
112
+
113
+
func TestSubscriberClosesWithoutUnsubscribing(t *testing.T) {
114
+
srv := createServer(t)
115
+
116
+
conn := createConnectionAndSubscribe(t, []string{"topic a", "topic b"})
117
+
118
+
assert.Len(t, srv.topics, 2)
119
+
assert.Len(t, srv.topics["topic a"].subscriptions, 1)
120
+
assert.Len(t, srv.topics["topic b"].subscriptions, 1)
121
+
122
+
// close the conn
123
+
err := conn.Close()
124
+
require.NoError(t, err)
125
+
126
+
publisherConn, err := net.Dial("tcp", "localhost:3000")
127
+
require.NoError(t, err)
128
+
129
+
err = binary.Write(publisherConn, binary.BigEndian, Publish)
130
+
require.NoError(t, err)
131
+
132
+
data := []byte("hello world")
133
+
// send data length first
134
+
err = binary.Write(publisherConn, binary.BigEndian, uint32(len(data)))
135
+
require.NoError(t, err)
136
+
n, err := publisherConn.Write(data)
137
+
require.NoError(t, err)
138
+
require.Equal(t, len(data), n)
139
+
140
+
assert.Len(t, srv.topics, 2)
141
+
assert.Len(t, srv.topics["topic a"].subscriptions, 0)
142
+
assert.Len(t, srv.topics["topic b"].subscriptions, 0)
143
+
}
144
+
145
+
func TestInvalidAction(t *testing.T) {
146
+
_ = createServer(t)
147
+
148
+
conn, err := net.Dial("tcp", "localhost:3000")
149
+
require.NoError(t, err)
150
+
151
+
err = binary.Write(conn, binary.BigEndian, uint8(99))
152
+
require.NoError(t, err)
153
+
154
+
expectedRes := "unknown action"
155
+
156
+
buf := make([]byte, len(expectedRes))
157
+
n, err := conn.Read(buf)
158
+
require.NoError(t, err)
159
+
require.Equal(t, len(expectedRes), n)
160
+
161
+
assert.Equal(t, expectedRes, string(buf))
162
+
}
163
+
164
+
func TestInvalidMessagePublished(t *testing.T) {
165
+
_ = createServer(t)
166
+
167
+
publisherConn, err := net.Dial("tcp", "localhost:3000")
168
+
require.NoError(t, err)
169
+
170
+
err = binary.Write(publisherConn, binary.BigEndian, Publish)
171
+
require.NoError(t, err)
172
+
173
+
// send some data
174
+
data := []byte("this isn't wrapped in a message type")
175
+
176
+
// send data length first
177
+
err = binary.Write(publisherConn, binary.BigEndian, uint32(len(data)))
178
+
require.NoError(t, err)
179
+
n, err := publisherConn.Write(data)
180
+
require.NoError(t, err)
181
+
require.Equal(t, len(data), n)
182
+
183
+
buf := make([]byte, 15)
184
+
_, err = publisherConn.Read(buf)
185
+
require.NoError(t, err)
186
+
assert.Equal(t, "invalid message", string(buf))
187
+
}
188
+
189
+
func TestSendsDataToTopicSubscribers(t *testing.T) {
190
+
_ = createServer(t)
191
+
192
+
subscribers := make([]net.Conn, 0, 5)
193
+
for i := 0; i < 5; i++ {
194
+
subscriberConn := createConnectionAndSubscribe(t, []string{"topic a", "topic b"})
195
+
196
+
subscribers = append(subscribers, subscriberConn)
197
+
}
198
+
199
+
publisherConn, err := net.Dial("tcp", "localhost:3000")
200
+
require.NoError(t, err)
201
+
202
+
err = binary.Write(publisherConn, binary.BigEndian, Publish)
203
+
require.NoError(t, err)
204
+
205
+
// send some data
206
+
data := []byte("hello world")
207
+
msg := message{
208
+
Topic: "topic a",
209
+
Data: data,
210
+
}
211
+
212
+
rawMsg, err := json.Marshal(msg)
213
+
require.NoError(t, err)
214
+
215
+
// send data length first
216
+
err = binary.Write(publisherConn, binary.BigEndian, uint32(len(rawMsg)))
217
+
require.NoError(t, err)
218
+
n, err := publisherConn.Write(rawMsg)
219
+
require.NoError(t, err)
220
+
require.Equal(t, len(rawMsg), n)
221
+
222
+
// check the subsribers got the data
223
+
for _, conn := range subscribers {
224
+
buf := make([]byte, len(data))
225
+
n, err := conn.Read(buf)
226
+
require.NoError(t, err)
227
+
require.Equal(t, len(data), n)
228
+
229
+
assert.Equal(t, data, buf)
230
+
}
231
+
}
+19
subscriber.go
+19
subscriber.go
···
1
+
package main
2
+
3
+
import (
4
+
"fmt"
5
+
"net"
6
+
)
7
+
8
+
type Subscriber struct {
9
+
conn net.Conn
10
+
currentOffset int
11
+
}
12
+
13
+
func (s *Subscriber) SendMessage(data []byte) error {
14
+
_, err := s.conn.Write(data)
15
+
if err != nil {
16
+
return fmt.Errorf("failed to write to connection: %w", err)
17
+
}
18
+
return nil
19
+
}
+50
topic.go
+50
topic.go
···
1
+
package main
2
+
3
+
import (
4
+
"log/slog"
5
+
"net"
6
+
"sync"
7
+
)
8
+
9
+
type topic struct {
10
+
name string
11
+
subscriptions map[net.Addr]Subscriber
12
+
mu sync.Mutex
13
+
}
14
+
15
+
func newTopic(name string) topic {
16
+
return topic{
17
+
name: name,
18
+
subscriptions: make(map[net.Addr]Subscriber),
19
+
}
20
+
}
21
+
22
+
func (t *topic) addSubscriber(conn net.Conn) {
23
+
t.mu.Lock()
24
+
defer t.mu.Unlock()
25
+
26
+
slog.Info("adding subscriber", "conn", conn.LocalAddr())
27
+
t.subscriptions[conn.LocalAddr()] = Subscriber{conn: conn}
28
+
}
29
+
30
+
func (t *topic) removeSubscriber(addr net.Addr) {
31
+
t.mu.Lock()
32
+
defer t.mu.Unlock()
33
+
34
+
slog.Info("removing subscriber", "conn", addr)
35
+
delete(t.subscriptions, addr)
36
+
}
37
+
38
+
func (t *topic) sendMessageToSubscribers(msg message) {
39
+
t.mu.Lock()
40
+
subscribers := t.subscriptions
41
+
t.mu.Unlock()
42
+
43
+
for addr, subscriber := range subscribers {
44
+
err := subscriber.SendMessage(msg.Data)
45
+
if err != nil {
46
+
slog.Error("failed to send to message", "error", err, "conn", addr)
47
+
continue
48
+
}
49
+
}
50
+
}