a rusty set datastructure for go

set: init

Signed-off-by: oppiliappan <me@oppi.li>

oppi.li 4710e037

Changed files
+411
+3
go.mod
···
··· 1 + module tangled.org/oppi.li/set 2 + 3 + go 1.25.0
+168
set.go
···
··· 1 + package set 2 + 3 + import ( 4 + "iter" 5 + "maps" 6 + ) 7 + 8 + type Set[T comparable] struct { 9 + data map[T]struct{} 10 + } 11 + 12 + func New[T comparable]() Set[T] { 13 + return Set[T]{ 14 + data: make(map[T]struct{}), 15 + } 16 + } 17 + 18 + func (s *Set[T]) Insert(item T) bool { 19 + _, exists := s.data[item] 20 + s.data[item] = struct{}{} 21 + return !exists 22 + } 23 + 24 + func (s *Set[T]) Remove(item T) bool { 25 + _, exists := s.data[item] 26 + if exists { 27 + delete(s.data, item) 28 + } 29 + return exists 30 + } 31 + 32 + func (s Set[T]) Contains(item T) bool { 33 + _, exists := s.data[item] 34 + return exists 35 + } 36 + 37 + func (s Set[T]) Len() int { 38 + return len(s.data) 39 + } 40 + 41 + func (s Set[T]) IsEmpty() bool { 42 + return len(s.data) == 0 43 + } 44 + 45 + func (s *Set[T]) Clear() { 46 + s.data = make(map[T]struct{}) 47 + } 48 + 49 + func (s Set[T]) All() iter.Seq[T] { 50 + return func(yield func(T) bool) { 51 + for item := range s.data { 52 + if !yield(item) { 53 + return 54 + } 55 + } 56 + } 57 + } 58 + 59 + func (s Set[T]) Clone() Set[T] { 60 + return Set[T]{ 61 + data: maps.Clone(s.data), 62 + } 63 + } 64 + 65 + func (s Set[T]) Union(other Set[T]) iter.Seq[T] { 66 + if s.Len() >= other.Len() { 67 + return chain(s.All(), other.Difference(s)) 68 + } else { 69 + return chain(other.All(), s.Difference(other)) 70 + } 71 + } 72 + 73 + func chain[T any](seqs ...iter.Seq[T]) iter.Seq[T] { 74 + return func(yield func(T) bool) { 75 + for _, seq := range seqs { 76 + for item := range seq { 77 + if !yield(item) { 78 + return 79 + } 80 + } 81 + } 82 + } 83 + } 84 + 85 + func (s Set[T]) Intersection(other Set[T]) iter.Seq[T] { 86 + return func(yield func(T) bool) { 87 + for item := range s.data { 88 + if other.Contains(item) { 89 + if !yield(item) { 90 + return 91 + } 92 + } 93 + } 94 + } 95 + } 96 + 97 + func (s Set[T]) Difference(other Set[T]) iter.Seq[T] { 98 + return func(yield func(T) bool) { 99 + for item := range s.data { 100 + if !other.Contains(item) { 101 + if !yield(item) { 102 + return 103 + } 104 + } 105 + } 106 + } 107 + } 108 + 109 + func (s Set[T]) SymmetricDifference(other Set[T]) iter.Seq[T] { 110 + return func(yield func(T) bool) { 111 + for item := range s.data { 112 + if !other.Contains(item) { 113 + if !yield(item) { 114 + return 115 + } 116 + } 117 + } 118 + for item := range other.data { 119 + if !s.Contains(item) { 120 + if !yield(item) { 121 + return 122 + } 123 + } 124 + } 125 + } 126 + } 127 + 128 + func (s Set[T]) IsSubset(other Set[T]) bool { 129 + for item := range s.data { 130 + if !other.Contains(item) { 131 + return false 132 + } 133 + } 134 + return true 135 + } 136 + 137 + func (s Set[T]) IsSuperset(other Set[T]) bool { 138 + return other.IsSubset(s) 139 + } 140 + 141 + func (s Set[T]) IsDisjoint(other Set[T]) bool { 142 + for item := range s.data { 143 + if other.Contains(item) { 144 + return false 145 + } 146 + } 147 + return true 148 + } 149 + 150 + func (s Set[T]) Equal(other Set[T]) bool { 151 + if s.Len() != other.Len() { 152 + return false 153 + } 154 + for item := range s.data { 155 + if !other.Contains(item) { 156 + return false 157 + } 158 + } 159 + return true 160 + } 161 + 162 + func Collect[T comparable](seq iter.Seq[T]) Set[T] { 163 + result := New[T]() 164 + for item := range seq { 165 + result.Insert(item) 166 + } 167 + return result 168 + }
+240
set_test.go
···
··· 1 + package set 2 + 3 + import ( 4 + "slices" 5 + "testing" 6 + ) 7 + 8 + func TestNew(t *testing.T) { 9 + s := New[int]() 10 + if s.Len() != 0 { 11 + t.Errorf("New set should be empty, got length %d", s.Len()) 12 + } 13 + if !s.IsEmpty() { 14 + t.Error("New set should be empty") 15 + } 16 + } 17 + 18 + func TestFromSlice(t *testing.T) { 19 + s := Collect(slices.Values([]int{1, 2, 3, 2, 1})) 20 + if s.Len() != 3 { 21 + t.Errorf("Expected length 3, got %d", s.Len()) 22 + } 23 + if !s.Contains(1) || !s.Contains(2) || !s.Contains(3) { 24 + t.Error("Set should contain all unique elements from slice") 25 + } 26 + } 27 + 28 + func TestInsert(t *testing.T) { 29 + s := New[string]() 30 + 31 + if !s.Insert("hello") { 32 + t.Error("First insert should return true") 33 + } 34 + if s.Insert("hello") { 35 + t.Error("Duplicate insert should return false") 36 + } 37 + if s.Len() != 1 { 38 + t.Errorf("Expected length 1, got %d", s.Len()) 39 + } 40 + } 41 + 42 + func TestRemove(t *testing.T) { 43 + s := Collect(slices.Values([]int{1, 2, 3})) 44 + 45 + if !s.Remove(2) { 46 + t.Error("Remove existing element should return true") 47 + } 48 + if s.Remove(2) { 49 + t.Error("Remove non-existing element should return false") 50 + } 51 + if s.Contains(2) { 52 + t.Error("Element should be removed") 53 + } 54 + if s.Len() != 2 { 55 + t.Errorf("Expected length 2, got %d", s.Len()) 56 + } 57 + } 58 + 59 + func TestContains(t *testing.T) { 60 + s := Collect(slices.Values([]int{1, 2, 3})) 61 + 62 + if !s.Contains(1) { 63 + t.Error("Should contain 1") 64 + } 65 + if s.Contains(4) { 66 + t.Error("Should not contain 4") 67 + } 68 + } 69 + 70 + func TestClear(t *testing.T) { 71 + s := Collect(slices.Values([]int{1, 2, 3})) 72 + s.Clear() 73 + 74 + if !s.IsEmpty() { 75 + t.Error("Set should be empty after clear") 76 + } 77 + if s.Len() != 0 { 78 + t.Errorf("Expected length 0, got %d", s.Len()) 79 + } 80 + } 81 + 82 + func TestIterator(t *testing.T) { 83 + s := Collect(slices.Values([]int{1, 2, 3})) 84 + var items []int 85 + 86 + for item := range s.All() { 87 + items = append(items, item) 88 + } 89 + 90 + slices.Sort(items) 91 + expected := []int{1, 2, 3} 92 + if !slices.Equal(items, expected) { 93 + t.Errorf("Expected %v, got %v", expected, items) 94 + } 95 + } 96 + 97 + func TestClone(t *testing.T) { 98 + s1 := Collect(slices.Values([]int{1, 2, 3})) 99 + s2 := s1.Clone() 100 + 101 + if !s1.Equal(s2) { 102 + t.Error("Cloned set should be equal to original") 103 + } 104 + 105 + s2.Insert(4) 106 + if s1.Contains(4) { 107 + t.Error("Modifying clone should not affect original") 108 + } 109 + } 110 + 111 + func TestUnion(t *testing.T) { 112 + s1 := Collect(slices.Values([]int{1, 2})) 113 + s2 := Collect(slices.Values([]int{2, 3})) 114 + 115 + result := Collect(s1.Union(s2)) 116 + expected := Collect(slices.Values([]int{1, 2, 3})) 117 + 118 + if !result.Equal(expected) { 119 + t.Errorf("Expected %v, got %v", expected, result) 120 + } 121 + } 122 + 123 + func TestIntersection(t *testing.T) { 124 + s1 := Collect(slices.Values([]int{1, 2, 3})) 125 + s2 := Collect(slices.Values([]int{2, 3, 4})) 126 + 127 + expected := Collect(slices.Values([]int{2, 3})) 128 + result := Collect(s1.Intersection(s2)) 129 + 130 + if !result.Equal(expected) { 131 + t.Errorf("Expected %v, got %v", expected, result) 132 + } 133 + } 134 + 135 + func TestDifference(t *testing.T) { 136 + s1 := Collect(slices.Values([]int{1, 2, 3})) 137 + s2 := Collect(slices.Values([]int{2, 3, 4})) 138 + 139 + expected := Collect(slices.Values([]int{1})) 140 + result := Collect(s1.Difference(s2)) 141 + 142 + if !result.Equal(expected) { 143 + t.Errorf("Expected %v, got %v", expected, result) 144 + } 145 + } 146 + 147 + func TestSymmetricDifference(t *testing.T) { 148 + s1 := Collect(slices.Values([]int{1, 2, 3})) 149 + s2 := Collect(slices.Values([]int{2, 3, 4})) 150 + 151 + expected := Collect(slices.Values([]int{1, 4})) 152 + result := Collect(s1.SymmetricDifference(s2)) 153 + 154 + if !result.Equal(expected) { 155 + t.Errorf("Expected %v, got %v", expected, result) 156 + } 157 + } 158 + 159 + func TestSymmetricDifferenceCommutativeProperty(t *testing.T) { 160 + s1 := Collect(slices.Values([]int{1, 2, 3})) 161 + s2 := Collect(slices.Values([]int{2, 3, 4})) 162 + 163 + result1 := Collect(s1.SymmetricDifference(s2)) 164 + result2 := Collect(s2.SymmetricDifference(s1)) 165 + 166 + if !result1.Equal(result2) { 167 + t.Errorf("Expected %v, got %v", result1, result2) 168 + } 169 + } 170 + 171 + func TestIsSubset(t *testing.T) { 172 + s1 := Collect(slices.Values([]int{1, 2})) 173 + s2 := Collect(slices.Values([]int{1, 2, 3})) 174 + 175 + if !s1.IsSubset(s2) { 176 + t.Error("s1 should be subset of s2") 177 + } 178 + if s2.IsSubset(s1) { 179 + t.Error("s2 should not be subset of s1") 180 + } 181 + } 182 + 183 + func TestIsSuperset(t *testing.T) { 184 + s1 := Collect(slices.Values([]int{1, 2, 3})) 185 + s2 := Collect(slices.Values([]int{1, 2})) 186 + 187 + if !s1.IsSuperset(s2) { 188 + t.Error("s1 should be superset of s2") 189 + } 190 + if s2.IsSuperset(s1) { 191 + t.Error("s2 should not be superset of s1") 192 + } 193 + } 194 + 195 + func TestIsDisjoint(t *testing.T) { 196 + s1 := Collect(slices.Values([]int{1, 2})) 197 + s2 := Collect(slices.Values([]int{3, 4})) 198 + s3 := Collect(slices.Values([]int{2, 3})) 199 + 200 + if !s1.IsDisjoint(s2) { 201 + t.Error("s1 and s2 should be disjoint") 202 + } 203 + if s1.IsDisjoint(s3) { 204 + t.Error("s1 and s3 should not be disjoint") 205 + } 206 + } 207 + 208 + func TestEqual(t *testing.T) { 209 + s1 := Collect(slices.Values([]int{1, 2, 3})) 210 + s2 := Collect(slices.Values([]int{3, 2, 1})) 211 + s3 := Collect(slices.Values([]int{1, 2})) 212 + 213 + if !s1.Equal(s2) { 214 + t.Error("s1 and s2 should be equal") 215 + } 216 + if s1.Equal(s3) { 217 + t.Error("s1 and s3 should not be equal") 218 + } 219 + } 220 + 221 + func TestCollect(t *testing.T) { 222 + s1 := Collect(slices.Values([]int{1, 2})) 223 + s2 := Collect(slices.Values([]int{2, 3})) 224 + 225 + unionSet := Collect(s1.Union(s2)) 226 + if unionSet.Len() != 3 { 227 + t.Errorf("Expected union set length 3, got %d", unionSet.Len()) 228 + } 229 + if !unionSet.Contains(1) || !unionSet.Contains(2) || !unionSet.Contains(3) { 230 + t.Error("Union set should contain 1, 2, and 3") 231 + } 232 + 233 + diffSet := Collect(s1.Difference(s2)) 234 + if diffSet.Len() != 1 { 235 + t.Errorf("Expected difference set length 1, got %d", diffSet.Len()) 236 + } 237 + if !diffSet.Contains(1) { 238 + t.Error("Difference set should contain 1") 239 + } 240 + }