An experimental pub/sub client and server project.
1package client
2
3import (
4 "context"
5 "encoding/binary"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "io"
10 "log/slog"
11 "net"
12 "sync"
13 "syscall"
14 "time"
15
16 "github.com/willdot/messagebroker/internal/server"
17)
18
// connOpp is an operation carried out against a network connection. It
// reports failure through the returned error.
type connOpp func(net.Conn) error
20
// Subscriber allows subscriptions to a server and the consumption of messages
type Subscriber struct {
	// addr is the server address that was dialled; it is dialled again when
	// reconnecting.
	addr string
	// subscribedTopics holds the names of the topics currently subscribed to.
	subscribedTopics []string

	// connMu is held for the duration of every operation run through
	// connOperation.
	connMu sync.Mutex
	// conn is the connection to the server.
	conn net.Conn
}
28
29// NewSubscriber will connect to the server at the given address
30func NewSubscriber(addr string) (*Subscriber, error) {
31 conn, err := net.Dial("tcp", addr)
32 if err != nil {
33 return nil, fmt.Errorf("failed to dial: %w", err)
34 }
35
36 return &Subscriber{
37 conn: conn,
38 addr: addr,
39 }, nil
40}
41
42func (s *Subscriber) reconnect() error {
43 conn, err := net.Dial("tcp", s.addr)
44 if err != nil {
45 return fmt.Errorf("failed to dial: %w", err)
46 }
47
48 s.conn = conn
49 return nil
50}
51
52// Close cleanly shuts down the subscriber
53func (s *Subscriber) Close() error {
54 return s.conn.Close()
55}
56
57// SubscribeToTopics will subscribe to the provided topics
58func (s *Subscriber) SubscribeToTopics(topicNames []string, startAtType server.StartAtType, startAtIndex int) error {
59 op := func(conn net.Conn) error {
60 return subscribeToTopics(conn, topicNames, startAtType, startAtIndex)
61 }
62
63 err := s.connOperation(op)
64 if err != nil {
65 return fmt.Errorf("failed to subscribe to topics: %w", err)
66 }
67
68 s.addToSubscribedTopics(topicNames)
69
70 return nil
71}
72
73func (s *Subscriber) addToSubscribedTopics(topics []string) {
74 existingSubs := make(map[string]struct{})
75 for _, topic := range s.subscribedTopics {
76 existingSubs[topic] = struct{}{}
77 }
78
79 for _, topic := range topics {
80 existingSubs[topic] = struct{}{}
81 }
82
83 subs := make([]string, 0, len(existingSubs))
84 for topic := range existingSubs {
85 subs = append(subs, topic)
86 }
87
88 s.subscribedTopics = subs
89}
90
91func (s *Subscriber) removeTopicsFromSubscription(topics []string) {
92 existingSubs := make(map[string]struct{})
93 for _, topic := range s.subscribedTopics {
94 existingSubs[topic] = struct{}{}
95 }
96
97 for _, topic := range topics {
98 delete(existingSubs, topic)
99 }
100
101 subs := make([]string, 0, len(existingSubs))
102 for topic := range existingSubs {
103 subs = append(subs, topic)
104 }
105
106 s.subscribedTopics = subs
107}
108
109// UnsubscribeToTopics will unsubscribe to the provided topics
110func (s *Subscriber) UnsubscribeToTopics(topicNames []string) error {
111 op := func(conn net.Conn) error {
112 return unsubscribeToTopics(conn, topicNames)
113 }
114
115 err := s.connOperation(op)
116 if err != nil {
117 return fmt.Errorf("failed to unsubscribe to topics: %w", err)
118 }
119
120 s.removeTopicsFromSubscription(topicNames)
121
122 return nil
123}
124
125func subscribeToTopics(conn net.Conn, topicNames []string, startAtType server.StartAtType, startAtIndex int) error {
126 actionB := make([]byte, 2)
127 binary.BigEndian.PutUint16(actionB, uint16(server.Subscribe))
128 headers := actionB
129
130 b, err := json.Marshal(topicNames)
131 if err != nil {
132 return fmt.Errorf("failed to marshal topic names: %w", err)
133 }
134
135 topicNamesB := make([]byte, 4)
136 binary.BigEndian.PutUint32(topicNamesB, uint32(len(b)))
137 headers = append(headers, topicNamesB...)
138 headers = append(headers, b...)
139
140 startAtTypeB := make([]byte, 2)
141 binary.BigEndian.PutUint16(startAtTypeB, uint16(startAtType))
142 headers = append(headers, startAtTypeB...)
143
144 if startAtType == server.From {
145 fromB := make([]byte, 2)
146 binary.BigEndian.PutUint16(fromB, uint16(startAtIndex))
147 headers = append(headers, fromB...)
148 }
149
150 _, err = conn.Write(headers)
151 if err != nil {
152 return fmt.Errorf("failed to subscribe to topics: %w", err)
153 }
154
155 var resp server.Status
156 err = binary.Read(conn, binary.BigEndian, &resp)
157 if err != nil {
158 return fmt.Errorf("failed to read confirmation of subscribe: %w", err)
159 }
160
161 if resp == server.Subscribed {
162 return nil
163 }
164
165 var dataLen uint16
166 err = binary.Read(conn, binary.BigEndian, &dataLen)
167 if err != nil {
168 return fmt.Errorf("received status %s:", resp)
169 }
170
171 buf := make([]byte, dataLen)
172 _, err = conn.Read(buf)
173 if err != nil {
174 return fmt.Errorf("received status %s:", resp)
175 }
176
177 return fmt.Errorf("received status %s - %s", resp, buf)
178}
179
180func unsubscribeToTopics(conn net.Conn, topicNames []string) error {
181 actionB := make([]byte, 2)
182 binary.BigEndian.PutUint16(actionB, uint16(server.Unsubscribe))
183 headers := actionB
184
185 b, err := json.Marshal(topicNames)
186 if err != nil {
187 return fmt.Errorf("failed to marshal topic names: %w", err)
188 }
189
190 topicNamesB := make([]byte, 4)
191 binary.BigEndian.PutUint32(topicNamesB, uint32(len(b)))
192 headers = append(headers, topicNamesB...)
193
194 _, err = conn.Write(append(headers, b...))
195 if err != nil {
196 return fmt.Errorf("failed to unsubscribe to topics: %w", err)
197 }
198
199 var resp server.Status
200 err = binary.Read(conn, binary.BigEndian, &resp)
201 if err != nil {
202 return fmt.Errorf("failed to read confirmation of unsubscribe: %w", err)
203 }
204
205 if resp == server.Unsubscribed {
206 return nil
207 }
208
209 var dataLen uint16
210 err = binary.Read(conn, binary.BigEndian, &dataLen)
211 if err != nil {
212 return fmt.Errorf("received status %s:", resp)
213 }
214
215 buf := make([]byte, dataLen)
216 _, err = conn.Read(buf)
217 if err != nil {
218 return fmt.Errorf("received status %s:", resp)
219 }
220
221 return fmt.Errorf("received status %s - %s", resp, buf)
222}
223
224// Consumer allows the consumption of messages. If during the consumer receiving messages from the
225// server an error occurs, it will be stored in Err
226type Consumer struct {
227 msgs chan *Message
228 // TODO: better error handling? Maybe a channel of errors?
229 Err error
230}
231
232// Messages returns a channel in which this consumer will put messages onto. It is safe to range over the channel since it will be closed once
233// the consumer has finished either due to an error or from being cancelled.
234func (c *Consumer) Messages() <-chan *Message {
235 return c.msgs
236}
237
238// Consume will create a consumer and start it running in a go routine. You can then use the Msgs channel of the consumer
239// to read the messages
240func (s *Subscriber) Consume(ctx context.Context) *Consumer {
241 consumer := &Consumer{
242 msgs: make(chan *Message),
243 }
244
245 go s.consume(ctx, consumer)
246
247 return consumer
248}
249
250func (s *Subscriber) consume(ctx context.Context, consumer *Consumer) {
251 defer close(consumer.msgs)
252 for {
253 if ctx.Err() != nil {
254 return
255 }
256
257 err := s.readMessage(ctx, consumer.msgs)
258 if err == nil {
259 continue
260 }
261
262 // if we couldn't connect to the server, attempt to reconnect
263 if !errors.Is(err, syscall.EPIPE) && !errors.Is(err, io.EOF) {
264 slog.Error("failed to read message", "error", err)
265 consumer.Err = err
266 return
267 }
268
269 slog.Info("attempting to reconnect")
270
271 for i := 0; i < 5; i++ {
272 time.Sleep(time.Millisecond * 500)
273 err = s.reconnect()
274 if err == nil {
275 break
276 }
277
278 slog.Error("Failed to reconnect", "error", err, "attempt", i)
279 }
280
281 slog.Info("attempting to resubscribe")
282
283 err = s.SubscribeToTopics(s.subscribedTopics, server.Current, 0)
284 if err != nil {
285 consumer.Err = fmt.Errorf("failed to subscribe to topics after reconnecting: %w", err)
286 return
287 }
288
289 }
290}
291
292func (s *Subscriber) readMessage(ctx context.Context, msgChan chan *Message) error {
293 op := func(conn net.Conn) error {
294 err := s.conn.SetReadDeadline(time.Now().Add(time.Millisecond * 300))
295 if err != nil {
296 return err
297 }
298
299 var topicLen uint16
300 err = binary.Read(s.conn, binary.BigEndian, &topicLen)
301 if err != nil {
302 // TODO: check if this is needed elsewhere. I'm not sure where the read deadline resets....
303 if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
304 return nil
305 }
306 return err
307 }
308
309 topicBuf := make([]byte, topicLen)
310 _, err = s.conn.Read(topicBuf)
311 if err != nil {
312 return err
313 }
314
315 var dataLen uint64
316 err = binary.Read(s.conn, binary.BigEndian, &dataLen)
317 if err != nil {
318 return err
319 }
320
321 if dataLen <= 0 {
322 return nil
323 }
324
325 dataBuf := make([]byte, dataLen)
326 _, err = s.conn.Read(dataBuf)
327 if err != nil {
328 return err
329 }
330
331 msg := NewMessage(string(topicBuf), dataBuf)
332
333 msgChan <- msg
334
335 var ack bool
336 select {
337 case <-ctx.Done():
338 return ctx.Err()
339 case ack = <-msg.ack:
340 }
341 ackMessage := server.Nack
342 if ack {
343 ackMessage = server.Ack
344 }
345
346 err = binary.Write(s.conn, binary.BigEndian, ackMessage)
347 if err != nil {
348 return fmt.Errorf("failed to ack/nack message: %w", err)
349 }
350
351 return nil
352 }
353
354 err := s.connOperation(op)
355 if err != nil {
356 var neterr net.Error
357 if errors.As(err, &neterr) && neterr.Timeout() {
358 return nil
359 }
360 return err
361 }
362
363 return err
364}
365
366func (s *Subscriber) connOperation(op connOpp) error {
367 s.connMu.Lock()
368 defer s.connMu.Unlock()
369
370 return op(s.conn)
371}