+13
-8
internal/server/server.go
+13
-8
internal/server/server.go
···
376
t = newTopic(topicName)
377
}
378
379
t.subscriptions[peer.Addr()] = newSubscriber(peer, t, s.ackDelay, s.ackTimeout, startAt)
380
381
s.topics[topicName] = t
382
}
···
396
if !ok {
397
return
398
}
399
-
sub, ok := t.subscriptions[peer.Addr()]
400
-
if !ok {
401
return
402
}
403
sub.unsubscribe()
404
-
delete(t.subscriptions, peer.Addr())
405
}
406
407
func (s *Server) unsubscribePeerFromAllTopics(peer *Peer) {
408
s.mu.Lock()
409
defer s.mu.Unlock()
410
411
-
for _, topic := range s.topics {
412
-
sub, ok := topic.subscriptions[peer.Addr()]
413
-
if !ok {
414
-
continue
415
}
416
sub.unsubscribe()
417
-
delete(topic.subscriptions, peer.Addr())
418
}
419
}
420
···
376
t = newTopic(topicName)
377
}
378
379
+
t.mu.Lock()
380
t.subscriptions[peer.Addr()] = newSubscriber(peer, t, s.ackDelay, s.ackTimeout, startAt)
381
+
t.mu.Unlock()
382
383
s.topics[topicName] = t
384
}
···
398
if !ok {
399
return
400
}
401
+
402
+
sub := t.findSubscription(peer.Addr())
403
+
if sub == nil {
404
return
405
}
406
+
407
sub.unsubscribe()
408
+
t.removeSubscription(peer.Addr())
409
}
410
411
func (s *Server) unsubscribePeerFromAllTopics(peer *Peer) {
412
s.mu.Lock()
413
defer s.mu.Unlock()
414
415
+
for _, t := range s.topics {
416
+
sub := t.findSubscription(peer.Addr())
417
+
if sub == nil {
418
+
return
419
}
420
+
421
sub.unsubscribe()
422
+
t.removeSubscription(peer.Addr())
423
}
424
}
425
+14
internal/server/server_test.go
+14
internal/server/server_test.go
···
128
129
_ = createConnectionAndSubscribe(t, []string{topicA, topicB}, Current, 0)
130
131
assert.Len(t, srv.topics, 2)
132
assert.Len(t, srv.topics[topicA].subscriptions, 1)
133
assert.Len(t, srv.topics[topicB].subscriptions, 1)
···
139
conn := createConnectionAndSubscribe(t, []string{topicA, topicB, topicC}, Current, 0)
140
141
assert.Len(t, srv.topics, 3)
142
assert.Len(t, srv.topics[topicA].subscriptions, 1)
143
assert.Len(t, srv.topics[topicB].subscriptions, 1)
144
assert.Len(t, srv.topics[topicC].subscriptions, 1)
145
146
topics := []string{topicA, topicB}
147
···
156
assert.Equal(t, expectedRes, resp)
157
158
assert.Len(t, srv.topics, 3)
159
assert.Len(t, srv.topics[topicA].subscriptions, 0)
160
assert.Len(t, srv.topics[topicB].subscriptions, 0)
161
assert.Len(t, srv.topics[topicC].subscriptions, 1)
162
}
163
164
func TestSubscriberClosesWithoutUnsubscribing(t *testing.T) {
···
167
conn := createConnectionAndSubscribe(t, []string{topicA, topicB}, Current, 0)
168
169
assert.Len(t, srv.topics, 2)
170
assert.Len(t, srv.topics[topicA].subscriptions, 1)
171
assert.Len(t, srv.topics[topicB].subscriptions, 1)
172
173
// close the conn
174
err := conn.Close()
···
189
time.Sleep(time.Millisecond * 100)
190
191
assert.Len(t, srv.topics, 2)
192
assert.Len(t, srv.topics[topicA].subscriptions, 0)
193
assert.Len(t, srv.topics[topicB].subscriptions, 0)
194
}
195
196
func TestInvalidAction(t *testing.T) {
···
128
129
_ = createConnectionAndSubscribe(t, []string{topicA, topicB}, Current, 0)
130
131
+
srv.mu.Lock()
132
+
defer srv.mu.Unlock()
133
assert.Len(t, srv.topics, 2)
134
assert.Len(t, srv.topics[topicA].subscriptions, 1)
135
assert.Len(t, srv.topics[topicB].subscriptions, 1)
···
141
conn := createConnectionAndSubscribe(t, []string{topicA, topicB, topicC}, Current, 0)
142
143
assert.Len(t, srv.topics, 3)
144
+
145
+
srv.mu.Lock()
146
assert.Len(t, srv.topics[topicA].subscriptions, 1)
147
assert.Len(t, srv.topics[topicB].subscriptions, 1)
148
assert.Len(t, srv.topics[topicC].subscriptions, 1)
149
+
srv.mu.Unlock()
150
151
topics := []string{topicA, topicB}
152
···
161
assert.Equal(t, expectedRes, resp)
162
163
assert.Len(t, srv.topics, 3)
164
+
165
+
srv.mu.Lock()
166
assert.Len(t, srv.topics[topicA].subscriptions, 0)
167
assert.Len(t, srv.topics[topicB].subscriptions, 0)
168
assert.Len(t, srv.topics[topicC].subscriptions, 1)
169
+
srv.mu.Unlock()
170
}
171
172
func TestSubscriberClosesWithoutUnsubscribing(t *testing.T) {
···
175
conn := createConnectionAndSubscribe(t, []string{topicA, topicB}, Current, 0)
176
177
assert.Len(t, srv.topics, 2)
178
+
179
+
srv.mu.Lock()
180
assert.Len(t, srv.topics[topicA].subscriptions, 1)
181
assert.Len(t, srv.topics[topicB].subscriptions, 1)
182
+
srv.mu.Unlock()
183
184
// close the conn
185
err := conn.Close()
···
200
time.Sleep(time.Millisecond * 100)
201
202
assert.Len(t, srv.topics, 2)
203
+
204
+
srv.mu.Lock()
205
assert.Len(t, srv.topics[topicA].subscriptions, 0)
206
assert.Len(t, srv.topics[topicB].subscriptions, 0)
207
+
srv.mu.Unlock()
208
}
209
210
func TestInvalidAction(t *testing.T) {
+14
internal/server/topic.go
+14
internal/server/topic.go
···
46
47
return nil
48
}
49
+
50
+
func (t *topic) findSubscription(addr net.Addr) *subscriber {
51
+
t.mu.Lock()
52
+
defer t.mu.Unlock()
53
+
54
+
return t.subscriptions[addr]
55
+
}
56
+
57
+
func (t *topic) removeSubscription(addr net.Addr) {
58
+
t.mu.Lock()
59
+
defer t.mu.Unlock()
60
+
61
+
delete(t.subscriptions, addr)
62
+
}