+168
set.go
+168
set.go
···
···
1
+
package set
2
+
3
+
import (
4
+
"iter"
5
+
"maps"
6
+
)
7
+
8
+
type Set[T comparable] struct {
9
+
data map[T]struct{}
10
+
}
11
+
12
+
func New[T comparable]() Set[T] {
13
+
return Set[T]{
14
+
data: make(map[T]struct{}),
15
+
}
16
+
}
17
+
18
+
func (s *Set[T]) Insert(item T) bool {
19
+
_, exists := s.data[item]
20
+
s.data[item] = struct{}{}
21
+
return !exists
22
+
}
23
+
24
+
func (s *Set[T]) Remove(item T) bool {
25
+
_, exists := s.data[item]
26
+
if exists {
27
+
delete(s.data, item)
28
+
}
29
+
return exists
30
+
}
31
+
32
+
func (s Set[T]) Contains(item T) bool {
33
+
_, exists := s.data[item]
34
+
return exists
35
+
}
36
+
37
+
func (s Set[T]) Len() int {
38
+
return len(s.data)
39
+
}
40
+
41
+
func (s Set[T]) IsEmpty() bool {
42
+
return len(s.data) == 0
43
+
}
44
+
45
+
func (s *Set[T]) Clear() {
46
+
s.data = make(map[T]struct{})
47
+
}
48
+
49
+
func (s Set[T]) All() iter.Seq[T] {
50
+
return func(yield func(T) bool) {
51
+
for item := range s.data {
52
+
if !yield(item) {
53
+
return
54
+
}
55
+
}
56
+
}
57
+
}
58
+
59
+
func (s Set[T]) Clone() Set[T] {
60
+
return Set[T]{
61
+
data: maps.Clone(s.data),
62
+
}
63
+
}
64
+
65
+
func (s Set[T]) Union(other Set[T]) iter.Seq[T] {
66
+
if s.Len() >= other.Len() {
67
+
return chain(s.All(), other.Difference(s))
68
+
} else {
69
+
return chain(other.All(), s.Difference(other))
70
+
}
71
+
}
72
+
73
+
func chain[T any](seqs ...iter.Seq[T]) iter.Seq[T] {
74
+
return func(yield func(T) bool) {
75
+
for _, seq := range seqs {
76
+
for item := range seq {
77
+
if !yield(item) {
78
+
return
79
+
}
80
+
}
81
+
}
82
+
}
83
+
}
84
+
85
+
func (s Set[T]) Intersection(other Set[T]) iter.Seq[T] {
86
+
return func(yield func(T) bool) {
87
+
for item := range s.data {
88
+
if other.Contains(item) {
89
+
if !yield(item) {
90
+
return
91
+
}
92
+
}
93
+
}
94
+
}
95
+
}
96
+
97
+
func (s Set[T]) Difference(other Set[T]) iter.Seq[T] {
98
+
return func(yield func(T) bool) {
99
+
for item := range s.data {
100
+
if !other.Contains(item) {
101
+
if !yield(item) {
102
+
return
103
+
}
104
+
}
105
+
}
106
+
}
107
+
}
108
+
109
+
func (s Set[T]) SymmetricDifference(other Set[T]) iter.Seq[T] {
110
+
return func(yield func(T) bool) {
111
+
for item := range s.data {
112
+
if !other.Contains(item) {
113
+
if !yield(item) {
114
+
return
115
+
}
116
+
}
117
+
}
118
+
for item := range other.data {
119
+
if !s.Contains(item) {
120
+
if !yield(item) {
121
+
return
122
+
}
123
+
}
124
+
}
125
+
}
126
+
}
127
+
128
+
func (s Set[T]) IsSubset(other Set[T]) bool {
129
+
for item := range s.data {
130
+
if !other.Contains(item) {
131
+
return false
132
+
}
133
+
}
134
+
return true
135
+
}
136
+
137
+
func (s Set[T]) IsSuperset(other Set[T]) bool {
138
+
return other.IsSubset(s)
139
+
}
140
+
141
+
func (s Set[T]) IsDisjoint(other Set[T]) bool {
142
+
for item := range s.data {
143
+
if other.Contains(item) {
144
+
return false
145
+
}
146
+
}
147
+
return true
148
+
}
149
+
150
+
func (s Set[T]) Equal(other Set[T]) bool {
151
+
if s.Len() != other.Len() {
152
+
return false
153
+
}
154
+
for item := range s.data {
155
+
if !other.Contains(item) {
156
+
return false
157
+
}
158
+
}
159
+
return true
160
+
}
161
+
162
+
func Collect[T comparable](seq iter.Seq[T]) Set[T] {
163
+
result := New[T]()
164
+
for item := range seq {
165
+
result.Insert(item)
166
+
}
167
+
return result
168
+
}
+240
set_test.go
+240
set_test.go
···
···
1
+
package set
2
+
3
+
import (
4
+
"slices"
5
+
"testing"
6
+
)
7
+
8
+
func TestNew(t *testing.T) {
9
+
s := New[int]()
10
+
if s.Len() != 0 {
11
+
t.Errorf("New set should be empty, got length %d", s.Len())
12
+
}
13
+
if !s.IsEmpty() {
14
+
t.Error("New set should be empty")
15
+
}
16
+
}
17
+
18
+
func TestFromSlice(t *testing.T) {
19
+
s := Collect(slices.Values([]int{1, 2, 3, 2, 1}))
20
+
if s.Len() != 3 {
21
+
t.Errorf("Expected length 3, got %d", s.Len())
22
+
}
23
+
if !s.Contains(1) || !s.Contains(2) || !s.Contains(3) {
24
+
t.Error("Set should contain all unique elements from slice")
25
+
}
26
+
}
27
+
28
+
func TestInsert(t *testing.T) {
29
+
s := New[string]()
30
+
31
+
if !s.Insert("hello") {
32
+
t.Error("First insert should return true")
33
+
}
34
+
if s.Insert("hello") {
35
+
t.Error("Duplicate insert should return false")
36
+
}
37
+
if s.Len() != 1 {
38
+
t.Errorf("Expected length 1, got %d", s.Len())
39
+
}
40
+
}
41
+
42
+
func TestRemove(t *testing.T) {
43
+
s := Collect(slices.Values([]int{1, 2, 3}))
44
+
45
+
if !s.Remove(2) {
46
+
t.Error("Remove existing element should return true")
47
+
}
48
+
if s.Remove(2) {
49
+
t.Error("Remove non-existing element should return false")
50
+
}
51
+
if s.Contains(2) {
52
+
t.Error("Element should be removed")
53
+
}
54
+
if s.Len() != 2 {
55
+
t.Errorf("Expected length 2, got %d", s.Len())
56
+
}
57
+
}
58
+
59
+
func TestContains(t *testing.T) {
60
+
s := Collect(slices.Values([]int{1, 2, 3}))
61
+
62
+
if !s.Contains(1) {
63
+
t.Error("Should contain 1")
64
+
}
65
+
if s.Contains(4) {
66
+
t.Error("Should not contain 4")
67
+
}
68
+
}
69
+
70
+
func TestClear(t *testing.T) {
71
+
s := Collect(slices.Values([]int{1, 2, 3}))
72
+
s.Clear()
73
+
74
+
if !s.IsEmpty() {
75
+
t.Error("Set should be empty after clear")
76
+
}
77
+
if s.Len() != 0 {
78
+
t.Errorf("Expected length 0, got %d", s.Len())
79
+
}
80
+
}
81
+
82
+
func TestIterator(t *testing.T) {
83
+
s := Collect(slices.Values([]int{1, 2, 3}))
84
+
var items []int
85
+
86
+
for item := range s.All() {
87
+
items = append(items, item)
88
+
}
89
+
90
+
slices.Sort(items)
91
+
expected := []int{1, 2, 3}
92
+
if !slices.Equal(items, expected) {
93
+
t.Errorf("Expected %v, got %v", expected, items)
94
+
}
95
+
}
96
+
97
+
func TestClone(t *testing.T) {
98
+
s1 := Collect(slices.Values([]int{1, 2, 3}))
99
+
s2 := s1.Clone()
100
+
101
+
if !s1.Equal(s2) {
102
+
t.Error("Cloned set should be equal to original")
103
+
}
104
+
105
+
s2.Insert(4)
106
+
if s1.Contains(4) {
107
+
t.Error("Modifying clone should not affect original")
108
+
}
109
+
}
110
+
111
+
func TestUnion(t *testing.T) {
112
+
s1 := Collect(slices.Values([]int{1, 2}))
113
+
s2 := Collect(slices.Values([]int{2, 3}))
114
+
115
+
result := Collect(s1.Union(s2))
116
+
expected := Collect(slices.Values([]int{1, 2, 3}))
117
+
118
+
if !result.Equal(expected) {
119
+
t.Errorf("Expected %v, got %v", expected, result)
120
+
}
121
+
}
122
+
123
+
func TestIntersection(t *testing.T) {
124
+
s1 := Collect(slices.Values([]int{1, 2, 3}))
125
+
s2 := Collect(slices.Values([]int{2, 3, 4}))
126
+
127
+
expected := Collect(slices.Values([]int{2, 3}))
128
+
result := Collect(s1.Intersection(s2))
129
+
130
+
if !result.Equal(expected) {
131
+
t.Errorf("Expected %v, got %v", expected, result)
132
+
}
133
+
}
134
+
135
+
func TestDifference(t *testing.T) {
136
+
s1 := Collect(slices.Values([]int{1, 2, 3}))
137
+
s2 := Collect(slices.Values([]int{2, 3, 4}))
138
+
139
+
expected := Collect(slices.Values([]int{1}))
140
+
result := Collect(s1.Difference(s2))
141
+
142
+
if !result.Equal(expected) {
143
+
t.Errorf("Expected %v, got %v", expected, result)
144
+
}
145
+
}
146
+
147
+
func TestSymmetricDifference(t *testing.T) {
148
+
s1 := Collect(slices.Values([]int{1, 2, 3}))
149
+
s2 := Collect(slices.Values([]int{2, 3, 4}))
150
+
151
+
expected := Collect(slices.Values([]int{1, 4}))
152
+
result := Collect(s1.SymmetricDifference(s2))
153
+
154
+
if !result.Equal(expected) {
155
+
t.Errorf("Expected %v, got %v", expected, result)
156
+
}
157
+
}
158
+
159
+
func TestSymmetricDifferenceCommutativeProperty(t *testing.T) {
160
+
s1 := Collect(slices.Values([]int{1, 2, 3}))
161
+
s2 := Collect(slices.Values([]int{2, 3, 4}))
162
+
163
+
result1 := Collect(s1.SymmetricDifference(s2))
164
+
result2 := Collect(s2.SymmetricDifference(s1))
165
+
166
+
if !result1.Equal(result2) {
167
+
t.Errorf("Expected %v, got %v", result1, result2)
168
+
}
169
+
}
170
+
171
+
func TestIsSubset(t *testing.T) {
172
+
s1 := Collect(slices.Values([]int{1, 2}))
173
+
s2 := Collect(slices.Values([]int{1, 2, 3}))
174
+
175
+
if !s1.IsSubset(s2) {
176
+
t.Error("s1 should be subset of s2")
177
+
}
178
+
if s2.IsSubset(s1) {
179
+
t.Error("s2 should not be subset of s1")
180
+
}
181
+
}
182
+
183
+
func TestIsSuperset(t *testing.T) {
184
+
s1 := Collect(slices.Values([]int{1, 2, 3}))
185
+
s2 := Collect(slices.Values([]int{1, 2}))
186
+
187
+
if !s1.IsSuperset(s2) {
188
+
t.Error("s1 should be superset of s2")
189
+
}
190
+
if s2.IsSuperset(s1) {
191
+
t.Error("s2 should not be superset of s1")
192
+
}
193
+
}
194
+
195
+
func TestIsDisjoint(t *testing.T) {
196
+
s1 := Collect(slices.Values([]int{1, 2}))
197
+
s2 := Collect(slices.Values([]int{3, 4}))
198
+
s3 := Collect(slices.Values([]int{2, 3}))
199
+
200
+
if !s1.IsDisjoint(s2) {
201
+
t.Error("s1 and s2 should be disjoint")
202
+
}
203
+
if s1.IsDisjoint(s3) {
204
+
t.Error("s1 and s3 should not be disjoint")
205
+
}
206
+
}
207
+
208
+
func TestEqual(t *testing.T) {
209
+
s1 := Collect(slices.Values([]int{1, 2, 3}))
210
+
s2 := Collect(slices.Values([]int{3, 2, 1}))
211
+
s3 := Collect(slices.Values([]int{1, 2}))
212
+
213
+
if !s1.Equal(s2) {
214
+
t.Error("s1 and s2 should be equal")
215
+
}
216
+
if s1.Equal(s3) {
217
+
t.Error("s1 and s3 should not be equal")
218
+
}
219
+
}
220
+
221
+
func TestCollect(t *testing.T) {
222
+
s1 := Collect(slices.Values([]int{1, 2}))
223
+
s2 := Collect(slices.Values([]int{2, 3}))
224
+
225
+
unionSet := Collect(s1.Union(s2))
226
+
if unionSet.Len() != 3 {
227
+
t.Errorf("Expected union set length 3, got %d", unionSet.Len())
228
+
}
229
+
if !unionSet.Contains(1) || !unionSet.Contains(2) || !unionSet.Contains(3) {
230
+
t.Error("Union set should contain 1, 2, and 3")
231
+
}
232
+
233
+
diffSet := Collect(s1.Difference(s2))
234
+
if diffSet.Len() != 1 {
235
+
t.Errorf("Expected difference set length 1, got %d", diffSet.Len())
236
+
}
237
+
if !diffSet.Contains(1) {
238
+
t.Error("Difference set should contain 1")
239
+
}
240
+
}