+13
-8
internal/server/server.go
+13
-8
internal/server/server.go
···
376
376
t = newTopic(topicName)
377
377
}
378
378
379
+
t.mu.Lock()
379
380
t.subscriptions[peer.Addr()] = newSubscriber(peer, t, s.ackDelay, s.ackTimeout, startAt)
381
+
t.mu.Unlock()
380
382
381
383
s.topics[topicName] = t
382
384
}
···
396
398
if !ok {
397
399
return
398
400
}
399
-
sub, ok := t.subscriptions[peer.Addr()]
400
-
if !ok {
401
+
402
+
sub := t.findSubscription(peer.Addr())
403
+
if sub == nil {
401
404
return
402
405
}
406
+
403
407
sub.unsubscribe()
404
-
delete(t.subscriptions, peer.Addr())
408
+
t.removeSubscription(peer.Addr())
405
409
}
406
410
407
411
func (s *Server) unsubscribePeerFromAllTopics(peer *Peer) {
408
412
s.mu.Lock()
409
413
defer s.mu.Unlock()
410
414
411
-
for _, topic := range s.topics {
412
-
sub, ok := topic.subscriptions[peer.Addr()]
413
-
if !ok {
414
-
continue
415
+
for _, t := range s.topics {
416
+
sub := t.findSubscription(peer.Addr())
417
+
if sub == nil {
418
+
return
415
419
}
420
+
416
421
sub.unsubscribe()
417
-
delete(topic.subscriptions, peer.Addr())
422
+
t.removeSubscription(peer.Addr())
418
423
}
419
424
}
420
425
+14
internal/server/server_test.go
+14
internal/server/server_test.go
···
128
128
129
129
_ = createConnectionAndSubscribe(t, []string{topicA, topicB}, Current, 0)
130
130
131
+
srv.mu.Lock()
132
+
defer srv.mu.Unlock()
131
133
assert.Len(t, srv.topics, 2)
132
134
assert.Len(t, srv.topics[topicA].subscriptions, 1)
133
135
assert.Len(t, srv.topics[topicB].subscriptions, 1)
···
139
141
conn := createConnectionAndSubscribe(t, []string{topicA, topicB, topicC}, Current, 0)
140
142
141
143
assert.Len(t, srv.topics, 3)
144
+
145
+
srv.mu.Lock()
142
146
assert.Len(t, srv.topics[topicA].subscriptions, 1)
143
147
assert.Len(t, srv.topics[topicB].subscriptions, 1)
144
148
assert.Len(t, srv.topics[topicC].subscriptions, 1)
149
+
srv.mu.Unlock()
145
150
146
151
topics := []string{topicA, topicB}
147
152
···
156
161
assert.Equal(t, expectedRes, resp)
157
162
158
163
assert.Len(t, srv.topics, 3)
164
+
165
+
srv.mu.Lock()
159
166
assert.Len(t, srv.topics[topicA].subscriptions, 0)
160
167
assert.Len(t, srv.topics[topicB].subscriptions, 0)
161
168
assert.Len(t, srv.topics[topicC].subscriptions, 1)
169
+
srv.mu.Unlock()
162
170
}
163
171
164
172
func TestSubscriberClosesWithoutUnsubscribing(t *testing.T) {
···
167
175
conn := createConnectionAndSubscribe(t, []string{topicA, topicB}, Current, 0)
168
176
169
177
assert.Len(t, srv.topics, 2)
178
+
179
+
srv.mu.Lock()
170
180
assert.Len(t, srv.topics[topicA].subscriptions, 1)
171
181
assert.Len(t, srv.topics[topicB].subscriptions, 1)
182
+
srv.mu.Unlock()
172
183
173
184
// close the conn
174
185
err := conn.Close()
···
189
200
time.Sleep(time.Millisecond * 100)
190
201
191
202
assert.Len(t, srv.topics, 2)
203
+
204
+
srv.mu.Lock()
192
205
assert.Len(t, srv.topics[topicA].subscriptions, 0)
193
206
assert.Len(t, srv.topics[topicB].subscriptions, 0)
207
+
srv.mu.Unlock()
194
208
}
195
209
196
210
func TestInvalidAction(t *testing.T) {
+14
internal/server/topic.go
+14
internal/server/topic.go
···
46
46
47
47
return nil
48
48
}
49
+
50
+
func (t *topic) findSubscription(addr net.Addr) *subscriber {
51
+
t.mu.Lock()
52
+
defer t.mu.Unlock()
53
+
54
+
return t.subscriptions[addr]
55
+
}
56
+
57
+
func (t *topic) removeSubscription(addr net.Addr) {
58
+
t.mu.Lock()
59
+
defer t.mu.Unlock()
60
+
61
+
delete(t.subscriptions, addr)
62
+
}