An experimental pub/sub client and server project.

initial commit with basic pub/sub

+1 -21
.gitignore
··· 1 - # If you prefer the allow list template instead of the deny list, see community template: 2 - # https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore 3 - # 4 - # Binaries for programs and plugins 5 - *.exe 6 - *.exe~ 7 - *.dll 8 - *.so 9 - *.dylib 10 - 11 - # Test binary, built with `go test -c` 12 - *.test 13 - 14 - # Output of the go coverage tool, specifically when used with LiteIDE 15 - *.out 16 - 17 - # Dependency directories (remove the comment below to include it) 18 - # vendor/ 19 - 20 - # Go workspace file 21 - go.work 1 + .DS_STORE
+11
go.mod
··· 1 + module github.com/willdot/message-broker 2 + 3 + go 1.21.0 4 + 5 + require github.com/stretchr/testify v1.8.4 6 + 7 + require ( 8 + github.com/davecgh/go-spew v1.1.1 // indirect 9 + github.com/pmezard/go-difflib v1.0.0 // indirect 10 + gopkg.in/yaml.v3 v3.0.1 // indirect 11 + )
+10
go.sum
··· 1 + github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 2 + github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 + github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 4 + github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 5 + github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= 6 + github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= 7 + gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 8 + gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 9 + gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 10 + gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+299
server.go
··· 1 + package main 2 + 3 + import ( 4 + "context" 5 + "encoding/binary" 6 + "encoding/json" 7 + "fmt" 8 + "log/slog" 9 + "net" 10 + "strings" 11 + "sync" 12 + ) 13 + 14 + // Action represents the type of action that a connection requests to do 15 + type Action uint8 16 + 17 + const ( 18 + Subscribe Action = 1 19 + Unsubscribe Action = 2 20 + Publish Action = 3 21 + ) 22 + 23 + type message struct { 24 + Topic string `json:"topic"` 25 + Data []byte `json:"data"` 26 + } 27 + 28 + type Server struct { 29 + addr string 30 + lis net.Listener 31 + 32 + mu sync.Mutex 33 + topics map[string]topic 34 + } 35 + 36 + func NewServer(ctx context.Context, addr string) (*Server, error) { 37 + lis, err := net.Listen("tcp", addr) 38 + if err != nil { 39 + return nil, fmt.Errorf("failed to listen: %w", err) 40 + } 41 + 42 + srv := &Server{ 43 + lis: lis, 44 + topics: map[string]topic{}, 45 + } 46 + 47 + go srv.start(ctx) 48 + 49 + return srv, nil 50 + } 51 + 52 + func (s *Server) Shutdown() error { 53 + return s.lis.Close() 54 + } 55 + 56 + func (s *Server) start(ctx context.Context) { 57 + for { 58 + conn, err := s.lis.Accept() 59 + if err != nil { 60 + slog.Error("listener failed to accept", "error", err) 61 + // TODO: see if there's a better way to check for this error 62 + if strings.Contains(err.Error(), "use of closed network connection") { 63 + return 64 + } 65 + } 66 + 67 + go s.handleConn(conn) 68 + } 69 + } 70 + 71 + func getActionFromConn(conn net.Conn) (Action, error) { 72 + var action Action 73 + err := binary.Read(conn, binary.BigEndian, &action) 74 + if err != nil { 75 + return 0, err 76 + } 77 + 78 + return action, nil 79 + } 80 + 81 + func getDataLengthFromConn(conn net.Conn) (uint32, error) { 82 + var dataLen uint32 83 + err := binary.Read(conn, binary.BigEndian, &dataLen) 84 + if err != nil { 85 + return 0, fmt.Errorf("failed to read data length from conn: %w", err) 86 + } 87 + 88 + return dataLen, nil 89 + } 90 + 91 + func (s *Server) handleConn(conn net.Conn) { 92 + action, err := getActionFromConn(conn) 93 + if err != nil { 94 + slog.Error("failed to read action from conn", "error", err, "conn", conn.LocalAddr()) 95 + return 96 + } 97 + 98 + switch action { 99 + case Subscribe: 100 + s.handleSubscribingConn(conn) 101 + case Unsubscribe: 102 + s.handleUnsubscribingConn(conn) 103 + case Publish: 104 + s.handlePublisherConn(conn) 105 + default: 106 + slog.Error("unknown action", "action", action, "conn", conn.LocalAddr()) 107 + _, _ = conn.Write([]byte("unknown action")) 108 + } 109 + } 110 + 111 + func (s *Server) handleSubscribingConn(conn net.Conn) { 112 + // subscribe the connection to the topic 113 + s.subscribeConnToTopic(conn) 114 + 115 + // keep handling the connection, getting the action from the conection when it wishes to do something else. 116 + // once the connection ends, it will be unsubscribed from all topics and returned 117 + for { 118 + action, err := getActionFromConn(conn) 119 + if err != nil { 120 + // TODO: see if there's a way to check if the connection has been ended etc 121 + slog.Error("failed to read action from subscriber", "error", err, "conn", conn.LocalAddr()) 122 + 123 + s.unsubscribeConnectionFromAllTopics(conn.LocalAddr()) 124 + 125 + return 126 + } 127 + 128 + switch action { 129 + case Subscribe: 130 + s.subscribeConnToTopic(conn) 131 + case Unsubscribe: 132 + s.handleUnsubscribingConn(conn) 133 + default: 134 + slog.Error("unknown action for subscriber", "action", action, "conn", conn.LocalAddr()) 135 + continue 136 + } 137 + } 138 + } 139 + 140 + func (s *Server) subscribeConnToTopic(conn net.Conn) { 141 + // get the topics the connection wishes to subscribe to 142 + dataLen, err := getDataLengthFromConn(conn) 143 + if err != nil { 144 + slog.Error(err.Error(), "conn", conn.LocalAddr()) 145 + _, _ = conn.Write([]byte("invalid data length of topics provided")) 146 + return 147 + } 148 + if dataLen == 0 { 149 + _, _ = conn.Write([]byte("data length of topics is 0")) 150 + return 151 + } 152 + 153 + buf := make([]byte, dataLen) 154 + _, err = conn.Read(buf) 155 + if err != nil { 156 + slog.Error("failed to read subscibers topic data", "error", err, "conn", conn.LocalAddr()) 157 + _, _ = conn.Write([]byte("failed to read topic data")) 158 + return 159 + } 160 + 161 + var topics []string 162 + err = json.Unmarshal(buf, &topics) 163 + if err != nil { 164 + slog.Error("failed to unmarshal subscibers topic data", "error", err, "conn", conn.LocalAddr()) 165 + _, _ = conn.Write([]byte("invalid topic data provided")) 166 + return 167 + } 168 + 169 + s.subscribeToTopics(conn, topics) 170 + _, _ = conn.Write([]byte("subscribed")) 171 + } 172 + 173 + func (s *Server) handleUnsubscribingConn(conn net.Conn) { 174 + // get the topics the connection wishes to unsubscribe from 175 + dataLen, err := getDataLengthFromConn(conn) 176 + if err != nil { 177 + slog.Error(err.Error(), "conn", conn.LocalAddr()) 178 + _, _ = conn.Write([]byte("invalid data length of topics provided")) 179 + return 180 + } 181 + if dataLen == 0 { 182 + _, _ = conn.Write([]byte("data length of topics is 0")) 183 + return 184 + } 185 + 186 + buf := make([]byte, dataLen) 187 + _, err = conn.Read(buf) 188 + if err != nil { 189 + slog.Error("failed to read subscibers topic data", "error", err, "conn", conn.LocalAddr()) 190 + _, _ = conn.Write([]byte("failed to read topic data")) 191 + return 192 + } 193 + 194 + var topics []string 195 + err = json.Unmarshal(buf, &topics) 196 + if err != nil { 197 + slog.Error("failed to unmarshal subscibers topic data", "error", err, "conn", conn.LocalAddr()) 198 + _, _ = conn.Write([]byte("invalid topic data provided")) 199 + return 200 + } 201 + 202 + s.unsubscribeToTopics(conn, topics) 203 + 204 + _, _ = conn.Write([]byte("unsubscribed")) 205 + } 206 + 207 + func (s *Server) handlePublisherConn(conn net.Conn) { 208 + dataLen, err := getDataLengthFromConn(conn) 209 + if err != nil { 210 + slog.Error(err.Error(), "conn", conn.LocalAddr()) 211 + _, _ = conn.Write([]byte("invalid data length of data provided")) 212 + return 213 + } 214 + if dataLen == 0 { 215 + return 216 + } 217 + 218 + buf := make([]byte, dataLen) 219 + _, err = conn.Read(buf) 220 + if err != nil { 221 + _, _ = conn.Write([]byte("failed to read data")) 222 + slog.Error("failed to read data from conn", "error", err, "conn", conn.LocalAddr()) 223 + return 224 + } 225 + 226 + var msg message 227 + err = json.Unmarshal(buf, &msg) 228 + if err != nil { 229 + _, _ = conn.Write([]byte("invalid message")) 230 + slog.Error("failed to unmarshal data to message", "error", err, "conn", conn.LocalAddr()) 231 + return 232 + } 233 + 234 + topic := s.getTopic(msg.Topic) 235 + if topic != nil { 236 + topic.sendMessageToSubscribers(msg) 237 + } 238 + } 239 + 240 + func (s *Server) subscribeToTopics(conn net.Conn, topics []string) { 241 + for _, topic := range topics { 242 + s.addSubsciberToTopic(topic, conn) 243 + } 244 + } 245 + 246 + func (s *Server) addSubsciberToTopic(topicName string, conn net.Conn) { 247 + s.mu.Lock() 248 + defer s.mu.Unlock() 249 + 250 + t, ok := s.topics[topicName] 251 + if !ok { 252 + t = newTopic(topicName) 253 + } 254 + 255 + t.subscriptions[conn.LocalAddr()] = Subscriber{ 256 + conn: conn, 257 + currentOffset: 0, 258 + } 259 + 260 + s.topics[topicName] = t 261 + } 262 + 263 + func (s *Server) unsubscribeToTopics(conn net.Conn, topics []string) { 264 + for _, topic := range topics { 265 + s.removeSubsciberFromTopic(topic, conn) 266 + } 267 + } 268 + 269 + func (s *Server) removeSubsciberFromTopic(topicName string, conn net.Conn) { 270 + s.mu.Lock() 271 + defer s.mu.Unlock() 272 + 273 + t, ok := s.topics[topicName] 274 + if !ok { 275 + return 276 + } 277 + 278 + delete(t.subscriptions, conn.LocalAddr()) 279 + } 280 + 281 + func (s *Server) unsubscribeConnectionFromAllTopics(addr net.Addr) { 282 + s.mu.Lock() 283 + defer s.mu.Unlock() 284 + 285 + for _, topic := range s.topics { 286 + delete(topic.subscriptions, addr) 287 + } 288 + } 289 + 290 + func (s *Server) getTopic(topicName string) *topic { 291 + s.mu.Lock() 292 + defer s.mu.Unlock() 293 + 294 + if topic, ok := s.topics[topicName]; ok { 295 + return &topic 296 + } 297 + 298 + return nil 299 + }
+231
server_test.go
··· 1 + package main 2 + 3 + import ( 4 + "context" 5 + "encoding/binary" 6 + "encoding/json" 7 + "net" 8 + "testing" 9 + 10 + "github.com/stretchr/testify/assert" 11 + "github.com/stretchr/testify/require" 12 + ) 13 + 14 + func createServer(t *testing.T) *Server { 15 + srv, err := NewServer(context.Background(), ":3000") 16 + require.NoError(t, err) 17 + 18 + t.Cleanup(func() { 19 + srv.Shutdown() 20 + }) 21 + 22 + return srv 23 + } 24 + 25 + func createServerWithExistingTopic(t *testing.T, topicName string) *Server { 26 + srv := createServer(t) 27 + srv.topics[topicName] = topic{ 28 + name: topicName, 29 + subscriptions: make(map[net.Addr]Subscriber), 30 + } 31 + 32 + return srv 33 + } 34 + 35 + func createConnectionAndSubscribe(t *testing.T, topics []string) net.Conn { 36 + conn, err := net.Dial("tcp", "localhost:3000") 37 + require.NoError(t, err) 38 + 39 + err = binary.Write(conn, binary.BigEndian, Subscribe) 40 + require.NoError(t, err) 41 + 42 + rawTopics, err := json.Marshal(topics) 43 + require.NoError(t, err) 44 + 45 + err = binary.Write(conn, binary.BigEndian, uint32(len(rawTopics))) 46 + require.NoError(t, err) 47 + 48 + _, err = conn.Write(rawTopics) 49 + require.NoError(t, err) 50 + 51 + expectedRes := "subscribed" 52 + 53 + buf := make([]byte, len(expectedRes)) 54 + n, err := conn.Read(buf) 55 + require.NoError(t, err) 56 + require.Equal(t, len(expectedRes), n) 57 + 58 + assert.Equal(t, expectedRes, string(buf)) 59 + 60 + return conn 61 + } 62 + 63 + func TestSubscribeToTopics(t *testing.T) { 64 + // create a server with an existing topic so we can test subscribing to a new and 65 + // existing topic 66 + srv := createServerWithExistingTopic(t, "topic a") 67 + 68 + _ = createConnectionAndSubscribe(t, []string{"topic a", "topic b"}) 69 + 70 + assert.Len(t, srv.topics, 2) 71 + assert.Len(t, srv.topics["topic a"].subscriptions, 1) 72 + assert.Len(t, srv.topics["topic b"].subscriptions, 1) 73 + } 74 + 75 + func TestUnsubscribesFromTopic(t *testing.T) { 76 + srv := createServerWithExistingTopic(t, "topic a") 77 + 78 + conn := createConnectionAndSubscribe(t, []string{"topic a", "topic b", "topic c"}) 79 + 80 + assert.Len(t, srv.topics, 3) 81 + assert.Len(t, srv.topics["topic a"].subscriptions, 1) 82 + assert.Len(t, srv.topics["topic b"].subscriptions, 1) 83 + assert.Len(t, srv.topics["topic c"].subscriptions, 1) 84 + 85 + err := binary.Write(conn, binary.BigEndian, Unsubscribe) 86 + require.NoError(t, err) 87 + 88 + topics := []string{"topic a", "topic b"} 89 + rawTopics, err := json.Marshal(topics) 90 + require.NoError(t, err) 91 + 92 + err = binary.Write(conn, binary.BigEndian, uint32(len(rawTopics))) 93 + require.NoError(t, err) 94 + 95 + _, err = conn.Write(rawTopics) 96 + require.NoError(t, err) 97 + 98 + expectedRes := "unsubscribed" 99 + 100 + buf := make([]byte, len(expectedRes)) 101 + n, err := conn.Read(buf) 102 + require.NoError(t, err) 103 + require.Equal(t, len(expectedRes), n) 104 + 105 + assert.Equal(t, expectedRes, string(buf)) 106 + 107 + assert.Len(t, srv.topics, 3) 108 + assert.Len(t, srv.topics["topic a"].subscriptions, 0) 109 + assert.Len(t, srv.topics["topic b"].subscriptions, 0) 110 + assert.Len(t, srv.topics["topic c"].subscriptions, 1) 111 + } 112 + 113 + func TestSubscriberClosesWithoutUnsubscribing(t *testing.T) { 114 + srv := createServer(t) 115 + 116 + conn := createConnectionAndSubscribe(t, []string{"topic a", "topic b"}) 117 + 118 + assert.Len(t, srv.topics, 2) 119 + assert.Len(t, srv.topics["topic a"].subscriptions, 1) 120 + assert.Len(t, srv.topics["topic b"].subscriptions, 1) 121 + 122 + // close the conn 123 + err := conn.Close() 124 + require.NoError(t, err) 125 + 126 + publisherConn, err := net.Dial("tcp", "localhost:3000") 127 + require.NoError(t, err) 128 + 129 + err = binary.Write(publisherConn, binary.BigEndian, Publish) 130 + require.NoError(t, err) 131 + 132 + data := []byte("hello world") 133 + // send data length first 134 + err = binary.Write(publisherConn, binary.BigEndian, uint32(len(data))) 135 + require.NoError(t, err) 136 + n, err := publisherConn.Write(data) 137 + require.NoError(t, err) 138 + require.Equal(t, len(data), n) 139 + 140 + assert.Len(t, srv.topics, 2) 141 + assert.Len(t, srv.topics["topic a"].subscriptions, 0) 142 + assert.Len(t, srv.topics["topic b"].subscriptions, 0) 143 + } 144 + 145 + func TestInvalidAction(t *testing.T) { 146 + _ = createServer(t) 147 + 148 + conn, err := net.Dial("tcp", "localhost:3000") 149 + require.NoError(t, err) 150 + 151 + err = binary.Write(conn, binary.BigEndian, uint8(99)) 152 + require.NoError(t, err) 153 + 154 + expectedRes := "unknown action" 155 + 156 + buf := make([]byte, len(expectedRes)) 157 + n, err := conn.Read(buf) 158 + require.NoError(t, err) 159 + require.Equal(t, len(expectedRes), n) 160 + 161 + assert.Equal(t, expectedRes, string(buf)) 162 + } 163 + 164 + func TestInvalidMessagePublished(t *testing.T) { 165 + _ = createServer(t) 166 + 167 + publisherConn, err := net.Dial("tcp", "localhost:3000") 168 + require.NoError(t, err) 169 + 170 + err = binary.Write(publisherConn, binary.BigEndian, Publish) 171 + require.NoError(t, err) 172 + 173 + // send some data 174 + data := []byte("this isn't wrapped in a message type") 175 + 176 + // send data length first 177 + err = binary.Write(publisherConn, binary.BigEndian, uint32(len(data))) 178 + require.NoError(t, err) 179 + n, err := publisherConn.Write(data) 180 + require.NoError(t, err) 181 + require.Equal(t, len(data), n) 182 + 183 + buf := make([]byte, 15) 184 + _, err = publisherConn.Read(buf) 185 + require.NoError(t, err) 186 + assert.Equal(t, "invalid message", string(buf)) 187 + } 188 + 189 + func TestSendsDataToTopicSubscribers(t *testing.T) { 190 + _ = createServer(t) 191 + 192 + subscribers := make([]net.Conn, 0, 5) 193 + for i := 0; i < 5; i++ { 194 + subscriberConn := createConnectionAndSubscribe(t, []string{"topic a", "topic b"}) 195 + 196 + subscribers = append(subscribers, subscriberConn) 197 + } 198 + 199 + publisherConn, err := net.Dial("tcp", "localhost:3000") 200 + require.NoError(t, err) 201 + 202 + err = binary.Write(publisherConn, binary.BigEndian, Publish) 203 + require.NoError(t, err) 204 + 205 + // send some data 206 + data := []byte("hello world") 207 + msg := message{ 208 + Topic: "topic a", 209 + Data: data, 210 + } 211 + 212 + rawMsg, err := json.Marshal(msg) 213 + require.NoError(t, err) 214 + 215 + // send data length first 216 + err = binary.Write(publisherConn, binary.BigEndian, uint32(len(rawMsg))) 217 + require.NoError(t, err) 218 + n, err := publisherConn.Write(rawMsg) 219 + require.NoError(t, err) 220 + require.Equal(t, len(rawMsg), n) 221 + 222 + // check the subsribers got the data 223 + for _, conn := range subscribers { 224 + buf := make([]byte, len(data)) 225 + n, err := conn.Read(buf) 226 + require.NoError(t, err) 227 + require.Equal(t, len(data), n) 228 + 229 + assert.Equal(t, data, buf) 230 + } 231 + }
+19
subscriber.go
··· 1 + package main 2 + 3 + import ( 4 + "fmt" 5 + "net" 6 + ) 7 + 8 + type Subscriber struct { 9 + conn net.Conn 10 + currentOffset int 11 + } 12 + 13 + func (s *Subscriber) SendMessage(data []byte) error { 14 + _, err := s.conn.Write(data) 15 + if err != nil { 16 + return fmt.Errorf("failed to write to connection: %w", err) 17 + } 18 + return nil 19 + }
+50
topic.go
··· 1 + package main 2 + 3 + import ( 4 + "log/slog" 5 + "net" 6 + "sync" 7 + ) 8 + 9 + type topic struct { 10 + name string 11 + subscriptions map[net.Addr]Subscriber 12 + mu sync.Mutex 13 + } 14 + 15 + func newTopic(name string) topic { 16 + return topic{ 17 + name: name, 18 + subscriptions: make(map[net.Addr]Subscriber), 19 + } 20 + } 21 + 22 + func (t *topic) addSubscriber(conn net.Conn) { 23 + t.mu.Lock() 24 + defer t.mu.Unlock() 25 + 26 + slog.Info("adding subscriber", "conn", conn.LocalAddr()) 27 + t.subscriptions[conn.LocalAddr()] = Subscriber{conn: conn} 28 + } 29 + 30 + func (t *topic) removeSubscriber(addr net.Addr) { 31 + t.mu.Lock() 32 + defer t.mu.Unlock() 33 + 34 + slog.Info("removing subscriber", "conn", addr) 35 + delete(t.subscriptions, addr) 36 + } 37 + 38 + func (t *topic) sendMessageToSubscribers(msg message) { 39 + t.mu.Lock() 40 + subscribers := t.subscriptions 41 + t.mu.Unlock() 42 + 43 + for addr, subscriber := range subscribers { 44 + err := subscriber.SendMessage(msg.Data) 45 + if err != nil { 46 + slog.Error("failed to send to message", "error", err, "conn", addr) 47 + continue 48 + } 49 + } 50 + }