+6
message.go
+6
message.go
+1
-1
peer.go
server/peer.go
+1
-1
peer.go
server/peer.go
+5
-8
server.go
server/server.go
+5
-8
server.go
server/server.go
···
1
-
package messagebroker
1
+
package server
2
2
3
3
import (
4
4
"context"
···
8
8
"net"
9
9
"strings"
10
10
"sync"
11
+
12
+
"github.com/willdot/messagebroker"
11
13
)
12
14
13
15
// Action represents the type of action that a peer requests to do
···
19
21
Publish Action = 3
20
22
)
21
23
22
-
type Message struct {
23
-
Topic string `json:"topic"`
24
-
Data []byte `json:"data"`
25
-
}
26
-
27
24
type Server struct {
28
25
addr string
29
26
lis net.Listener
···
32
29
topics map[string]topic
33
30
}
34
31
35
-
func NewServer(ctx context.Context, addr string) (*Server, error) {
32
+
func New(ctx context.Context, addr string) (*Server, error) {
36
33
lis, err := net.Listen("tcp", addr)
37
34
if err != nil {
38
35
return nil, fmt.Errorf("failed to listen: %w", err)
···
204
201
return
205
202
}
206
203
207
-
var msg Message
204
+
var msg messagebroker.Message
208
205
err = json.Unmarshal(buf, &msg)
209
206
if err != nil {
210
207
_, _ = peer.Write([]byte("invalid message"))
+5
-4
server_test.go
server/server_test.go
+5
-4
server_test.go
server/server_test.go
···
1
-
package messagebroker
1
+
package server
2
2
3
3
import (
4
4
"context"
···
11
11
12
12
"github.com/stretchr/testify/assert"
13
13
"github.com/stretchr/testify/require"
14
+
"github.com/willdot/messagebroker"
14
15
)
15
16
16
17
func createServer(t *testing.T) *Server {
17
-
srv, err := NewServer(context.Background(), ":3000")
18
+
srv, err := New(context.Background(), ":3000")
18
19
require.NoError(t, err)
19
20
20
21
t.Cleanup(func() {
···
205
206
require.NoError(t, err)
206
207
207
208
// send a message
208
-
msg := Message{
209
+
msg := messagebroker.Message{
209
210
Topic: "topic a",
210
211
Data: []byte("hello world"),
211
212
}
···
247
248
248
249
messages := make([][]byte, 0, 10)
249
250
for i := 0; i < 10; i++ {
250
-
msg := Message{
251
+
msg := messagebroker.Message{
251
252
Topic: "topic a",
252
253
Data: []byte(fmt.Sprintf("message %d", i)),
253
254
}
+1
-1
subscriber.go
server/subscriber.go
+1
-1
subscriber.go
server/subscriber.go
+137
subscriber/subscriber.go
+137
subscriber/subscriber.go
···
1
+
package subscriber
2
+
3
+
import (
4
+
"context"
5
+
"encoding/binary"
6
+
"encoding/json"
7
+
"fmt"
8
+
"log/slog"
9
+
"net"
10
+
"time"
11
+
12
+
"github.com/willdot/messagebroker"
13
+
"github.com/willdot/messagebroker/server"
14
+
)
15
+
16
+
type Subscriber struct {
17
+
conn net.Conn
18
+
}
19
+
20
+
func New(addr string) (*Subscriber, error) {
21
+
conn, err := net.Dial("tcp", addr)
22
+
if err != nil {
23
+
return nil, fmt.Errorf("failed to dial: %w", err)
24
+
}
25
+
26
+
return &Subscriber{
27
+
conn: conn,
28
+
}, nil
29
+
}
30
+
31
+
func (s *Subscriber) Close() error {
32
+
return s.conn.Close()
33
+
}
34
+
35
+
func (s *Subscriber) SubscribeToTopics(topicNames []string) error {
36
+
err := binary.Write(s.conn, binary.BigEndian, server.Subscribe)
37
+
if err != nil {
38
+
return fmt.Errorf("failed to subscribe: %w", err)
39
+
}
40
+
41
+
b, err := json.Marshal(topicNames)
42
+
if err != nil {
43
+
return fmt.Errorf("failed to marshal topic names: %w", err)
44
+
}
45
+
46
+
err = binary.Write(s.conn, binary.BigEndian, uint32(len(b)))
47
+
if err != nil {
48
+
return fmt.Errorf("failed to write topic data length: %w", err)
49
+
}
50
+
51
+
_, err = s.conn.Write(b)
52
+
if err != nil {
53
+
return fmt.Errorf("failed to subscribe to topics: %w", err)
54
+
}
55
+
buf := make([]byte, 512)
56
+
_, err = s.conn.Read(buf)
57
+
if err != nil {
58
+
return fmt.Errorf("failed to read confirmation of subscription: %w", err)
59
+
}
60
+
61
+
// TODO: this is soooo hacky - need to have some sort of response code
62
+
if string(buf[:10]) != "subscribed" {
63
+
return fmt.Errorf("failed to subscribe: '%s'", string(buf))
64
+
}
65
+
66
+
return nil
67
+
}
68
+
69
+
type Consumer struct {
70
+
Msgs chan messagebroker.Message
71
+
Err error
72
+
}
73
+
74
+
// TODO: maybe buffer the message channel up?
75
+
func (s *Subscriber) Consume(ctx context.Context) *Consumer {
76
+
consumer := &Consumer{
77
+
Msgs: make(chan messagebroker.Message),
78
+
}
79
+
80
+
go s.consume(ctx, consumer)
81
+
82
+
return consumer
83
+
}
84
+
85
+
func (s *Subscriber) consume(ctx context.Context, consumer *Consumer) {
86
+
defer close(consumer.Msgs)
87
+
for {
88
+
if ctx.Err() != nil {
89
+
return
90
+
}
91
+
92
+
msg, err := s.readMessage()
93
+
if err != nil {
94
+
consumer.Err = err
95
+
return
96
+
}
97
+
98
+
if msg != nil {
99
+
consumer.Msgs <- *msg
100
+
}
101
+
}
102
+
}
103
+
104
+
func (s *Subscriber) readMessage() (*messagebroker.Message, error) {
105
+
err := s.conn.SetReadDeadline(time.Now().Add(time.Second))
106
+
if err != nil {
107
+
return nil, err
108
+
}
109
+
110
+
var dataLen uint64
111
+
err = binary.Read(s.conn, binary.BigEndian, &dataLen)
112
+
if err != nil {
113
+
if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
114
+
return nil, nil
115
+
}
116
+
return nil, err
117
+
}
118
+
119
+
if dataLen <= 0 {
120
+
return nil, nil
121
+
}
122
+
123
+
buf := make([]byte, dataLen)
124
+
_, err = s.conn.Read(buf)
125
+
if err != nil {
126
+
return nil, err
127
+
}
128
+
129
+
var msg messagebroker.Message
130
+
err = json.Unmarshal(buf, &msg)
131
+
if err != nil {
132
+
slog.Error("failed to unmarshal message", "error", err)
133
+
return nil, nil
134
+
}
135
+
136
+
return &msg, nil
137
+
}
+139
subscriber/subscriber_test.go
+139
subscriber/subscriber_test.go
···
1
+
package subscriber_test
2
+
3
+
import (
4
+
"context"
5
+
"encoding/binary"
6
+
"encoding/json"
7
+
"fmt"
8
+
"net"
9
+
"testing"
10
+
"time"
11
+
12
+
"github.com/stretchr/testify/assert"
13
+
"github.com/stretchr/testify/require"
14
+
"github.com/willdot/messagebroker"
15
+
"github.com/willdot/messagebroker/server"
16
+
"github.com/willdot/messagebroker/subscriber"
17
+
)
18
+
19
+
const (
20
+
serverAddr = ":3000"
21
+
)
22
+
23
+
func createServer(t *testing.T) {
24
+
server, err := server.New(context.Background(), serverAddr)
25
+
require.NoError(t, err)
26
+
27
+
t.Cleanup(func() {
28
+
server.Shutdown()
29
+
})
30
+
}
31
+
32
+
func TestNew(t *testing.T) {
33
+
createServer(t)
34
+
35
+
sub, err := subscriber.New(serverAddr)
36
+
require.NoError(t, err)
37
+
38
+
t.Cleanup(func() {
39
+
sub.Close()
40
+
})
41
+
}
42
+
43
+
func TestNewInvalidServerAddr(t *testing.T) {
44
+
createServer(t)
45
+
46
+
_, err := subscriber.New(":123456")
47
+
require.Error(t, err)
48
+
}
49
+
50
+
func TestSubscribeToTopics(t *testing.T) {
51
+
createServer(t)
52
+
53
+
sub, err := subscriber.New(serverAddr)
54
+
require.NoError(t, err)
55
+
56
+
t.Cleanup(func() {
57
+
sub.Close()
58
+
})
59
+
60
+
topics := []string{"topic a", "topic b"}
61
+
62
+
err = sub.SubscribeToTopics(topics)
63
+
require.NoError(t, err)
64
+
}
65
+
66
+
func TestSubscribeConsumeFromSubscription(t *testing.T) {
67
+
createServer(t)
68
+
69
+
sub, err := subscriber.New(serverAddr)
70
+
require.NoError(t, err)
71
+
72
+
t.Cleanup(func() {
73
+
sub.Close()
74
+
})
75
+
76
+
topics := []string{"topic a", "topic b"}
77
+
78
+
err = sub.SubscribeToTopics(topics)
79
+
require.NoError(t, err)
80
+
81
+
ctx, cancel := context.WithCancel(context.Background())
82
+
t.Cleanup(func() {
83
+
cancel()
84
+
})
85
+
86
+
consumer := sub.Consume(ctx)
87
+
require.NoError(t, err)
88
+
89
+
var receivedMessages []messagebroker.Message
90
+
91
+
consumerFinCh := make(chan struct{})
92
+
go func() {
93
+
for msg := range consumer.Msgs {
94
+
receivedMessages = append(receivedMessages, msg)
95
+
}
96
+
97
+
require.NoError(t, err)
98
+
consumerFinCh <- struct{}{}
99
+
}()
100
+
101
+
publisherConn, err := net.Dial("tcp", "localhost:3000")
102
+
require.NoError(t, err)
103
+
104
+
err = binary.Write(publisherConn, binary.BigEndian, server.Publish)
105
+
require.NoError(t, err)
106
+
107
+
// send some messages
108
+
sentMessages := make([]messagebroker.Message, 0, 10)
109
+
for i := 0; i < 10; i++ {
110
+
msg := messagebroker.Message{
111
+
Topic: "topic a",
112
+
Data: []byte(fmt.Sprintf("message %d", i)),
113
+
}
114
+
115
+
sentMessages = append(sentMessages, msg)
116
+
117
+
b, err := json.Marshal(msg)
118
+
require.NoError(t, err)
119
+
120
+
err = binary.Write(publisherConn, binary.BigEndian, uint32(len(b)))
121
+
require.NoError(t, err)
122
+
n, err := publisherConn.Write(b)
123
+
require.NoError(t, err)
124
+
require.Equal(t, len(b), n)
125
+
}
126
+
127
+
// give the consumer some time to read the messages -- TODO: make better!
128
+
time.Sleep(time.Millisecond * 500)
129
+
cancel()
130
+
131
+
select {
132
+
case <-consumerFinCh:
133
+
break
134
+
case <-time.After(time.Second):
135
+
t.Fatal("timed out waiting for consumer to read messages")
136
+
}
137
+
138
+
assert.ElementsMatch(t, receivedMessages, sentMessages)
139
+
}
+4
-2
topic.go
server/topic.go
+4
-2
topic.go
server/topic.go
···
1
-
package messagebroker
1
+
package server
2
2
3
3
import (
4
4
"encoding/json"
5
5
"log/slog"
6
6
"net"
7
7
"sync"
8
+
9
+
"github.com/willdot/messagebroker"
8
10
)
9
11
10
12
type topic struct {
···
28
30
delete(t.subscriptions, addr)
29
31
}
30
32
31
-
func (t *topic) sendMessageToSubscribers(msg Message) {
33
+
func (t *topic) sendMessageToSubscribers(msg messagebroker.Message) {
32
34
t.mu.Lock()
33
35
subscribers := t.subscriptions
34
36
t.mu.Unlock()