From 43de73155fdc3970252d235da1f5d444a8cbb4ab Mon Sep 17 00:00:00 2001 From: spekary Date: Fri, 24 Jan 2025 18:06:00 -0800 Subject: [PATCH] Handling nil receiver for Set types --- safe_map_test.go | 2 +- set.go | 56 ++++++++++++--- set_ordered.go | 172 ++++++++++++++++++++++++++++++++++++++++++++ set_ordered_test.go | 162 +++++++++++++++++++++++++++++++++++++++++ set_test.go | 49 +++++++++++-- seti.go | 2 +- seti_test.go | 65 ++++++++++++----- 7 files changed, 475 insertions(+), 33 deletions(-) create mode 100644 set_ordered.go create mode 100644 set_ordered_test.go diff --git a/safe_map_test.go b/safe_map_test.go index 7890853..ac39fb7 100644 --- a/safe_map_test.go +++ b/safe_map_test.go @@ -16,7 +16,7 @@ func init() { gob.Register(new(SafeMap[string, int])) } -func TestNil(t *testing.T) { +func TestSafeMap_Nil(t *testing.T) { var m SafeMap[string, int] assert.False(t, m.Has("z")) diff --git a/set.go b/set.go index 09742bf..3a9a576 100644 --- a/set.go +++ b/set.go @@ -6,7 +6,6 @@ import ( "encoding/json" "fmt" "iter" - "slices" ) // Set is a collection that keeps track of membership. @@ -33,11 +32,17 @@ func NewSet[K comparable](values ...K) *Set[K] { // Clear resets the set to an empty set func (m *Set[K]) Clear() { + if m == nil { + return + } m.items = nil } // Len returns the number of items in the set func (m *Set[K]) Len() int { + if m == nil || m.items == nil { + return 0 + } return m.items.Len() } @@ -46,7 +51,7 @@ func (m *Set[K]) Len() int { // While its safe to call methods of the set from within the Range function, its discouraged. // If you ever switch to one of the SafeSet sets, it will cause a deadlock. func (m *Set[K]) Range(f func(k K) bool) { - if m == nil || m.items == nil { + if m.Len() == 0 { return } for k := range m.items { @@ -58,22 +63,34 @@ func (m *Set[K]) Range(f func(k K) bool) { // Has returns true if the value exists in the set. func (m *Set[K]) Has(k K) bool { + if m.Len() == 0 { + return false + } return m.items.Has(k) } // Delete removes the value from the set. If the value does not exist, nothing happens. func (m *Set[K]) Delete(k K) { + if m.Len() == 0 { + return + } m.items.Delete(k) } // Values returns a new slice containing the values of the set. func (m *Set[K]) Values() []K { + if m.Len() == 0 { + return nil + } return m.items.Keys() } // Add adds the value to the set. // If the value already exists, nothing changes. func (m *Set[K]) Add(k ...K) SetI[K] { + if m == nil { + panic("cannot add values to a nil Set") + } if m.items == nil { m.items = make(map[K]struct{}) } @@ -91,6 +108,9 @@ func (m *Set[K]) Merge(in SetI[K]) { // Copy adds the values from in to the set. func (m *Set[K]) Copy(in SetI[K]) { + if m == nil { + panic("cannot copy to a nil Set") + } if in == nil || in.Len() == 0 { return } @@ -105,6 +125,9 @@ func (m *Set[K]) Copy(in SetI[K]) { // Equal returns true if the two sets are the same length and contain the same values. func (m *Set[K]) Equal(m2 SetI[K]) bool { + if m == nil { + return m2.Len() == 0 + } if m.Len() != m2.Len() { return false } @@ -148,6 +171,9 @@ func (m *Set[K]) UnmarshalBinary(data []byte) (err error) { // MarshalJSON implements the json.Marshaler interface to convert the map into a JSON object. func (m *Set[K]) MarshalJSON() (out []byte, err error) { + if m.Len() == 0 { + return []byte("[]"), nil + } return json.Marshal(m.Values()) } @@ -165,12 +191,13 @@ func (m *Set[K]) UnmarshalJSON(in []byte) (err error) { // String returns the set as a string in a predictable way. func (m *Set[K]) String() string { - vals := slices.Clone(m.Values()) ret := "{" - for i, v := range vals { - ret += fmt.Sprintf("%#v", v) - if i < m.Len()-1 { - ret += "," + if m.Len() != 0 { + for i, v := range m.Values() { + ret += fmt.Sprintf("%#v", v) + if i < m.Len()-1 { + ret += "," + } } } ret += "}" @@ -179,12 +206,20 @@ func (m *Set[K]) String() string { // All returns an iterator over all the items in the set. Order is not determinate. func (m *Set[K]) All() iter.Seq[K] { + if m.Len() == 0 { + return func(yield func(K) bool) { + return + } + } return m.items.KeysIter() } // Insert adds the values from seq to the map. // Duplicates are overridden. func (m *Set[K]) Insert(seq iter.Seq[K]) { + if m == nil { + panic("cannot insert into a nil Set") + } if m.items == nil { m.items = NewStdMap[K, struct{}]() } @@ -206,12 +241,17 @@ func CollectSet[K comparable](seq iter.Seq[K]) *Set[K] { // the new keys and values are set using ordinary assignment. func (m *Set[K]) Clone() *Set[K] { m1 := NewSet[K]() - m1.items = m.items.Clone() + if m.Len() != 0 { + m1.items = m.items.Clone() + } return m1 } // DeleteFunc deletes any values for which del returns true. func (m *Set[K]) DeleteFunc(del func(K) bool) { + if m.Len() == 0 { + return + } del2 := func(k K, s struct{}) bool { return del(k) } diff --git a/set_ordered.go b/set_ordered.go new file mode 100644 index 0000000..2902952 --- /dev/null +++ b/set_ordered.go @@ -0,0 +1,172 @@ +package maps + +import ( + "cmp" + "encoding/json" + "fmt" + "iter" + "slices" +) + +// OrderedSet implements a set of values that will be returned sorted. +// +// Ordered sets are useful when in general you don't care about ordering, but +// you would still like the same values to be presented in the same order when +// they are asked for. Examples include test code, iterators, values stored in a database, +// or values that will be presented to a user. +type OrderedSet[K cmp.Ordered] struct { + Set[K] +} + +func NewOrderedSet[K cmp.Ordered](values ...K) *OrderedSet[K] { + s := new(OrderedSet[K]) + for _, k := range values { + s.Add(k) + } + return s +} + +// Clear resets the set to an empty set +func (m *OrderedSet[K]) Clear() { + if m == nil { + return + } + m.Set.Clear() +} + +// Len returns the number of items in the set +func (m *OrderedSet[K]) Len() int { + if m == nil || m.items == nil { + return 0 + } + return m.Set.Len() +} + +// Range will range over the values in order. +func (m *OrderedSet[K]) Range(f func(k K) bool) { + if m.Len() == 0 { + return + } + values := m.Values() + for _, k := range values { + if !f(k) { + break + } + } +} + +// Has returns true if the value exists in the set. +func (m *OrderedSet[K]) Has(k K) bool { + if m.Len() == 0 { + return false + } + return m.Set.Has(k) +} + +// Delete removes the value from the set. If the value does not exist, nothing happens. +func (m *OrderedSet[K]) Delete(k K) { + if m.Len() == 0 { + return + } + m.Set.Delete(k) +} + +// Equal returns true if the two sets are the same length and contain the same values. +func (m *OrderedSet[K]) Equal(m2 SetI[K]) bool { + if m == nil { + return m2.Len() == 0 + } + return m.Set.Equal(m2) +} + +// Values returns a new slice containing the values of the set. +func (m *OrderedSet[K]) Values() []K { + if m.Len() == 0 { + return nil + } + v := m.items.Keys() + slices.Sort(v) + return v +} + +// Add adds the value to the set. +// If the value already exists, nothing changes. +func (m *OrderedSet[K]) Add(k ...K) SetI[K] { + if m == nil { + panic("cannot add values to a nil Set") + } + m.Set.Add(k...) + return m +} + +// Copy adds the values from in to the set. +func (m *OrderedSet[K]) Copy(in SetI[K]) { + if m == nil { + panic("cannot copy to a nil Set") + } + m.Set.Copy(in) +} + +// MarshalJSON implements the json.Marshaler interface to convert the map into a JSON object. +func (m *OrderedSet[K]) MarshalJSON() (out []byte, err error) { + if m.Len() == 0 { + return []byte("[]"), nil + } + return json.Marshal(m.Values()) +} + +// All returns an iterator over all the items in the set. Order is determinate. +func (m *OrderedSet[K]) All() iter.Seq[K] { + if m.Len() == 0 { + return func(yield func(K) bool) { + return + } + } + v := m.Values() + return slices.Values(v) +} + +// Insert adds the values from seq to the map. +// Duplicates are overridden. +func (m *OrderedSet[K]) Insert(seq iter.Seq[K]) { + if m == nil { + panic("cannot insert into a nil Set") + } + m.Set.Insert(seq) +} + +// Clone returns a copy of the Set. This is a shallow clone: +// the new keys and values are set using ordinary assignment. +func (m *OrderedSet[K]) Clone() *OrderedSet[K] { + m1 := NewOrderedSet[K]() + if m != nil { + m1.items = m.items.Clone() + } + return m1 +} + +// DeleteFunc deletes any values for which del returns true. +func (m *OrderedSet[K]) DeleteFunc(del func(K) bool) { + if m.Len() == 0 { + return + } + m.Set.DeleteFunc(del) +} + +// String returns the set as a string. +func (m *OrderedSet[K]) String() string { + if m == nil { + return "{}" + } + ret := "{" + if m.Len() != 0 { + for i, v := range m.Values() { + ret += fmt.Sprintf("%#v", v) + if i < m.Len()-1 { + ret += "," + } + } + } + ret += "}" + return ret +} diff --git a/set_ordered_test.go b/set_ordered_test.go new file mode 100644 index 0000000..fab1235 --- /dev/null +++ b/set_ordered_test.go @@ -0,0 +1,162 @@ +package maps + +import ( + "cmp" + "encoding/gob" + "fmt" + "github.com/stretchr/testify/assert" + "slices" + "testing" +) + +type orderedSetT = OrderedSet[string] +type orderedSetTI = SetI[string] + +func TestOrderedSet_SetI(t *testing.T) { + runSetITests[orderedSetT](t, makeSetI[orderedSetT]) +} + +func init() { + gob.Register(new(orderedSetT)) +} + +func TestOrderedSet_Values(t *testing.T) { + type testCase[K cmp.Ordered] struct { + name string + m *OrderedSet[K] + want []K + } + tests := []testCase[int]{ + {"none", NewOrderedSet[int](), []int(nil)}, + {"one", NewOrderedSet[int](1), []int{1}}, + {"three", NewOrderedSet[int](1, 2, 3), []int{1, 2, 3}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, tt.m.Values(), "Values()") + }) + } +} + +func TestOrderedSet_MarshalJSON(t *testing.T) { + type testCase[K cmp.Ordered] struct { + name string + m *OrderedSet[K] + wantOut string + wantErr bool + } + tests := []testCase[string]{ + {"zero", NewOrderedSet[string](), `[]`, false}, + {"one", NewOrderedSet("a"), `["a"]`, false}, + {"three", NewOrderedSet("a", "c", "b"), `["a","b","c"]`, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + b, err := tt.m.MarshalJSON() + gotOut := string(b) + assert.Equal(t, tt.wantErr, err != nil) + assert.Equalf(t, tt.wantOut, gotOut, "MarshalJSON()") + }) + } +} + +func TestOrderedSet_All(t *testing.T) { + set := NewOrderedSet[int]() + set.Add(5) + set.Add(3) + set.Add(8) + set.Add(1) + + iterator := set.All() + var result []int + + for v := range iterator { + result = append(result, v) + } + + expected := []int{1, 3, 5, 8} + assert.Equal(t, expected, result) +} + +func TestOrderedSet_Range(t *testing.T) { + type testCase[K cmp.Ordered] struct { + name string + m *OrderedSet[K] + expected []int + } + tests := []testCase[int]{ + {"none", NewOrderedSet[int](), []int(nil)}, + {"one", NewOrderedSet[int](1), []int{1}}, + {"three", NewOrderedSet[int](1, 2, 3), []int{1, 2, 3}}, + {"four", NewOrderedSet[int](4, 3, 2, 1), []int{1, 2, 3}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var values []int + var count int + tt.m.Range(func(i int) bool { + values = append(values, i) + count++ + return count < 3 + }) + assert.Equal(t, tt.expected, values) + }) + } +} + +func TestOrderedSet_Clone(t *testing.T) { + m1 := NewOrderedSet("a", "b", "c") + m2 := m1.Clone() + assert.True(t, m1.Equal(m2)) + + var m3 *OrderedSet[string] + m4 := m3.Clone() + m3.Equal(m4) + assert.True(t, m3.Equal(m4)) + + m2.Add("d") + assert.False(t, m1.Equal(m2)) +} + +func TestOrderedSet_Nil(t *testing.T) { + t.Run("Nil", func(t *testing.T) { + var m1, m2 *OrderedSet[string] + + assert.Equal(t, 0, m1.Len()) + m1.Clear() + assert.True(t, m1.Equal(m2)) + m3 := m2.Clone() + assert.True(t, m1.Equal(m3)) + m3.Add("a") + assert.False(t, m1.Equal(m3)) + m1.Range(func(k string) bool { + assert.Fail(t, "no range should happen") + return false + }) + assert.False(t, m1.Has("b")) + m1.Delete("a") + assert.Empty(t, m1.Values()) + assert.Equal(t, "{}", m1.String()) + m1.DeleteFunc(func(k string) bool { + return false + }) + for _ = range m1.All() { + assert.Fail(t, "no range should happen") + } + assert.Panics(t, func() { + m1.Insert(slices.Values([]string{"a"})) + }) + assert.Panics(t, func() { + m1.Add("a") + }) + assert.Panics(t, func() { + m1.Copy(m2) + }) + }) +} + +func ExampleOrderedSet_String() { + m := NewOrderedSet("a", "c", "a", "b") + fmt.Print(m.String()) + // Output: {"a","b","c"} +} diff --git a/set_test.go b/set_test.go index a787c1a..4113b97 100644 --- a/set_test.go +++ b/set_test.go @@ -4,7 +4,7 @@ import ( "encoding/gob" "fmt" "github.com/stretchr/testify/assert" - "sort" + "slices" "testing" ) @@ -22,12 +22,8 @@ func init() { func ExampleSet_String() { m := new(Set[string]) m.Add("a") - m.Add("b") - m.Add("a") - v := m.Values() - sort.Strings(v) - fmt.Print(v) - // Output: [a b] + fmt.Print(m.String()) + // Output: {"a"} } func TestCollectSet(t *testing.T) { @@ -43,3 +39,42 @@ func TestSet_Clone(t *testing.T) { m3 := m2.Clone() assert.True(t, m1.Equal(m3)) } + +func TestSet_Nil(t *testing.T) { + t.Run("Nil", func(t *testing.T) { + var m1, m2 *Set[string] + + assert.Equal(t, 0, m1.Len()) + m1.Clear() + assert.True(t, m1.Equal(m2)) + m3 := m2.Clone() + assert.True(t, m1.Equal(m3)) + m3.Add("a") + assert.False(t, m1.Equal(m3)) + m1.Range(func(k string) bool { + assert.Fail(t, "no range should happen") + return false + }) + assert.False(t, m1.Has("b")) + m1.Delete("a") + assert.Empty(t, m1.Values()) + assert.Equal(t, "{}", m1.String()) + m1.DeleteFunc(func(k string) bool { + return false + }) + + for _ = range m1.All() { + assert.Fail(t, "no range should happen") + } + assert.Panics(t, func() { + m1.Insert(slices.Values([]string{"a"})) + }) + assert.Panics(t, func() { + m1.Add("a") + }) + assert.Panics(t, func() { + m1.Copy(m2) + }) + + }) +} diff --git a/seti.go b/seti.go index f5db475..15547df 100644 --- a/seti.go +++ b/seti.go @@ -7,6 +7,7 @@ type SetI[K comparable] interface { Add(k ...K) SetI[K] Clear() Len() int + Copy(in SetI[K]) Range(func(k K) bool) Has(k K) bool Values() []K @@ -15,6 +16,5 @@ type SetI[K comparable] interface { Delete(k K) All() iter.Seq[K] Insert(seq iter.Seq[K]) - Clone() *Set[K] DeleteFunc(del func(K) bool) } diff --git a/seti_test.go b/seti_test.go index 9395144..75469e3 100644 --- a/seti_test.go +++ b/seti_test.go @@ -37,6 +37,7 @@ func runSetITests[M any](t *testing.T, f makeSetF) { testSetAll(t, f) testSetInsert(t, f) testSetDeleteFunc(t, f) + testSetCopy(t, f) } func testSetClear(t *testing.T, f makeSetF) { @@ -103,28 +104,25 @@ func testSetHas(t *testing.T, f makeSetF) { func testSetRange(t *testing.T, f makeSetF) { tests := []struct { - name string - m setTI - expected int + name string + m setTI + expectedLen int }{ - {"0", f(), 0}, - {"1", f("a"), 1}, - {"2", f("a", "b"), 2}, - {"3", f("a", "b", "c"), 2}, + {"none", f(), 0}, + {"one", f("a"), 1}, + {"three", f("b", "a", "c"), 3}, + {"four", f("d", "a", "c", "b"), 3}, } for _, tt := range tests { t.Run("Range "+tt.name, func(t *testing.T) { - count := 0 - tt.m.Range(func(k string) bool { + var values []string + var count int + tt.m.Range(func(i string) bool { + values = append(values, i) count++ - if count > 1 { - return false - } - return true + return count < 3 }) - if count != tt.expected { - t.Errorf("Expected %d, got %d", tt.expected, count) - } + assert.Equal(t, tt.expectedLen, len(values)) }) } } @@ -199,6 +197,12 @@ func testSetMarshalJSON(t *testing.T, f makeSetF) { assert.NoError(t, err) // Note: The below output is what is produced, but isn't guaranteed. go seems to currently be sorting keys assert.Contains(t, string(s), `"a"`) + + m = f() + s, err = json.Marshal(m) + assert.NoError(t, err) + // Note: The below output is what is produced, but isn't guaranteed. go seems to currently be sorting keys + assert.Equal(t, "[]", string(s)) }) } @@ -212,6 +216,26 @@ func testSetUnmarshalJSON[M any](t *testing.T, f makeSetF) { m2 := i.(SetI[string]) assert.True(t, m2.Has("c")) + + b = []byte(`[]`) + + var m3 M + + json.Unmarshal(b, &m3) + i = &m3 + m4 := i.(SetI[string]) + + assert.Equal(t, 0, m4.Len()) + + b = []byte(`["d"]`) + + // Unmarshalling into an existing set should add values + json.Unmarshal(b, &m) + i = &m + m5 := i.(SetI[string]) + + assert.Equal(t, 4, m5.Len()) + } func testSetDelete(t *testing.T, f makeSetF) { @@ -262,3 +286,12 @@ func testSetDeleteFunc(t *testing.T, f makeSetF) { assert.Equal(t, 1, m1.Len()) }) } + +func testSetCopy(t *testing.T, f makeSetF) { + t.Run("DeleteFunc", func(t *testing.T) { + m1 := f("a", "b", "c") + m2 := f() + m2.Copy(m1) + assert.True(t, m1.Equal(m2)) + }) +}