diff --git a/sei-db/common/metrics/buckets.go b/sei-db/common/metrics/buckets.go new file mode 100644 index 0000000000..5bb4004416 --- /dev/null +++ b/sei-db/common/metrics/buckets.go @@ -0,0 +1,25 @@ +package metrics + +import "github.com/sei-protocol/sei-chain/sei-db/common/unit" + +// Shared histogram bucket boundaries for use across the codebase. +// The OTel defaults are too coarse for meaningful percentile queries in Grafana. + +// LatencyBuckets covers 10μs to 5 minutes — wide enough for both fast key +// lookups and slow compactions/flushes without needing per-metric tuning. +var LatencyBuckets = []float64{ + 0.00001, 0.000025, 0.00005, 0.0001, 0.00025, 0.0005, // 10μs–500μs + 0.001, 0.0025, 0.005, 0.01, 0.025, 0.05, // 1ms–50ms + 0.1, 0.25, 0.5, 1, 2.5, 5, 10, 30, 60, 120, 300, // 100ms–5min +} + +// ByteSizeBuckets covers 256B to 1GB for data size histograms. +var ByteSizeBuckets = []float64{ + 256, unit.KB, 4 * unit.KB, 16 * unit.KB, 64 * unit.KB, 256 * unit.KB, + unit.MB, 4 * unit.MB, 16 * unit.MB, 64 * unit.MB, 256 * unit.MB, unit.GB, +} + +// CountBuckets covers 1 to 1M for per-operation step/iteration counts. +var CountBuckets = []float64{ + 1, 5, 10, 50, 100, 500, 1000, 5000, 10000, 100000, 1000000, +} diff --git a/sei-db/common/threading/pool.go b/sei-db/common/threading/pool.go index b86f85b9c3..4af9cebfa1 100644 --- a/sei-db/common/threading/pool.go +++ b/sei-db/common/threading/pool.go @@ -9,5 +9,12 @@ type Pool interface { // If Submit is called concurrently with or after shutdown (i.e. when ctx is done/cancelled), the task may // be silently dropped. Callers that need a guarantee of execution must // ensure Submit happens-before shutdown. + // + // This method is permitted to return an error only under the following conditions: + // - the pool is shutting down (i.e. its context is done/cancelled) + // - the provided ctx parameter is done/cancelled before this method returns + // - invalid input (e.g. 
the task is nil) + // + // If this method returns an error, the task may or may not have been executed. Submit(ctx context.Context, task func()) error } diff --git a/sei-db/db_engine/dbcache/cache.go b/sei-db/db_engine/dbcache/cache.go index da65017fdf..ccbaf6464c 100644 --- a/sei-db/db_engine/dbcache/cache.go +++ b/sei-db/db_engine/dbcache/cache.go @@ -1,6 +1,11 @@ package dbcache import ( + "context" + "fmt" + "time" + + "github.com/sei-protocol/sei-chain/sei-db/common/threading" "github.com/sei-protocol/sei-chain/sei-db/db_engine/types" ) @@ -22,6 +27,9 @@ type Reader func(key []byte) (value []byte, found bool, err error) // - the Reader method returns an error (for methods that accpet a Reader) // - the cache is shutting down // - the cache's work pools are shutting down +// +// Cache errors are generally not recoverable, and it should be assumed that a cache that has returned an error +// is in a corrupted state, and should be discarded. type Cache interface { // Get returns the value for the given key, or (nil, false, nil) if not found. @@ -64,6 +72,14 @@ type Cache interface { BatchSet(updates []CacheUpdate) error } +// DefaultEstimatedOverheadPerEntry is a rough estimate of the fixed heap overhead per cache entry +// on a 64-bit architecture (amd64/arm64). It accounts for the shardEntry struct (48 B), +// list.Element (48 B), lruQueueEntry (32 B), two map-entry costs (~64 B), string allocation +// rounding (~16 B), and a margin for the duplicate key copy stored in the LRU. Derived from +// static analysis of Go size classes and map bucket layout; validate experimentally for your +// target platform. +const DefaultEstimatedOverheadPerEntry uint64 = 250 + // CacheUpdate describes a single key-value mutation to apply to the cache. type CacheUpdate struct { // The key to update. @@ -76,3 +92,35 @@ type CacheUpdate struct { func (u *CacheUpdate) IsDelete() bool { return u.Value == nil } + +// BuildCache creates a new Cache. 
+func BuildCache( + ctx context.Context, + shardCount uint64, + maxSize uint64, + readPool threading.Pool, + miscPool threading.Pool, + estimatedOverheadPerEntry uint64, + cacheName string, + metricsScrapeInterval time.Duration, +) (Cache, error) { + + if maxSize == 0 { + return NewNoOpCache(), nil + } + + cache, err := NewStandardCache( + ctx, + shardCount, + maxSize, + readPool, + miscPool, + estimatedOverheadPerEntry, + cacheName, + metricsScrapeInterval, + ) + if err != nil { + return nil, fmt.Errorf("failed to create cache: %w", err) + } + return cache, nil +} diff --git a/sei-db/db_engine/dbcache/cache_impl.go b/sei-db/db_engine/dbcache/cache_impl.go new file mode 100644 index 0000000000..1292f74caf --- /dev/null +++ b/sei-db/db_engine/dbcache/cache_impl.go @@ -0,0 +1,194 @@ +package dbcache + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/sei-protocol/sei-chain/sei-db/common/threading" + "github.com/sei-protocol/sei-chain/sei-db/db_engine/types" +) + +var _ Cache = (*cache)(nil) + +// A standard implementation of a flatcache. +type cache struct { + ctx context.Context + + // A utility for assigning keys to shard indices. + shardManager *shardManager + + // The shards in the cache. + shards []*shard + + // A pool for asynchronous reads. + readPool threading.Pool + + // A pool for miscellaneous operations that are neither computationally intensive nor IO bound. + miscPool threading.Pool +} + +// Creates a new Cache. If cacheName is non-empty, OTel metrics are enabled and the +// background size scrape runs every metricsScrapeInterval. +func NewStandardCache( + ctx context.Context, + // The number of shards in the cache. Must be a power of two and greater than 0. + shardCount uint64, + // The maximum size of the cache, in bytes. + maxSize uint64, + // A work pool for reading from the DB. + readPool threading.Pool, + // A work pool for miscellaneous operations that are neither computationally intensive nor IO bound. 
+ miscPool threading.Pool, + // The estimated overhead per entry, in bytes. This is used to calculate the maximum size of the cache. + // This value should be derived experimentally, and may differ between different builds and architectures. + estimatedOverheadPerEntry uint64, + // Name used as the "cache" attribute on metrics. Empty string disables metrics. + cacheName string, + // How often to scrape cache size for metrics. Ignored if cacheName is empty. + metricsScrapeInterval time.Duration, +) (Cache, error) { + if shardCount == 0 || (shardCount&(shardCount-1)) != 0 { + return nil, ErrNumShardsNotPowerOfTwo + } + if maxSize == 0 { + return nil, fmt.Errorf("maxSize must be greater than 0") + } + + shardManager, err := newShardManager(shardCount) + if err != nil { + return nil, fmt.Errorf("failed to create shard manager: %w", err) + } + sizePerShard := maxSize / shardCount + if sizePerShard == 0 { + return nil, fmt.Errorf("maxSize must be greater than shardCount") + } + + shards := make([]*shard, shardCount) + for i := uint64(0); i < shardCount; i++ { + shards[i], err = NewShard(ctx, readPool, sizePerShard, estimatedOverheadPerEntry) + if err != nil { + return nil, fmt.Errorf("failed to create shard: %w", err) + } + } + + c := &cache{ + ctx: ctx, + shardManager: shardManager, + shards: shards, + readPool: readPool, + miscPool: miscPool, + } + + if cacheName != "" { + metrics := newCacheMetrics(ctx, cacheName, metricsScrapeInterval, c.getCacheSizeInfo) + for _, s := range c.shards { + s.metrics = metrics + } + } + + return c, nil +} + +func (c *cache) getCacheSizeInfo() (bytes uint64, entries uint64) { + for _, s := range c.shards { + b, e := s.getSizeInfo() + bytes += b + entries += e + } + return bytes, entries +} + +func (c *cache) BatchSet(updates []CacheUpdate) error { + // Group entries by shard index so each shard is locked only once. 
+ shardMap := make(map[uint64][]CacheUpdate) + for i := range updates { + idx := c.shardManager.Shard(updates[i].Key) + shardMap[idx] = append(shardMap[idx], updates[i]) + } + + var wg sync.WaitGroup + for shardIndex, shardEntries := range shardMap { + wg.Add(1) + err := c.miscPool.Submit(c.ctx, func() { + defer wg.Done() + c.shards[shardIndex].BatchSet(shardEntries) + }) + if err != nil { + return fmt.Errorf("failed to submit batch set: %w", err) + } + } + wg.Wait() + + return nil +} + +func (c *cache) BatchGet(read Reader, keys map[string]types.BatchGetResult) error { + work := make(map[uint64]map[string]types.BatchGetResult) + for key := range keys { + idx := c.shardManager.Shard([]byte(key)) + if work[idx] == nil { + work[idx] = make(map[string]types.BatchGetResult) + } + work[idx][key] = types.BatchGetResult{} + } + + var wg sync.WaitGroup + for shardIndex, subMap := range work { + wg.Add(1) + + err := c.miscPool.Submit(c.ctx, func() { + defer wg.Done() + err := c.shards[shardIndex].BatchGet(read, subMap) + if err != nil { + for key := range subMap { + subMap[key] = types.BatchGetResult{Error: err} + } + } + }) + if err != nil { + return fmt.Errorf("failed to submit batch get: %w", err) + } + } + wg.Wait() + + for _, subMap := range work { + for key, result := range subMap { + keys[key] = result + } + } + + return nil +} + +func (c *cache) Delete(key []byte) { + shardIndex := c.shardManager.Shard(key) + shard := c.shards[shardIndex] + shard.Delete(key) +} + +func (c *cache) Get(read Reader, key []byte, updateLru bool) ([]byte, bool, error) { + shardIndex := c.shardManager.Shard(key) + shard := c.shards[shardIndex] + + value, ok, err := shard.Get(read, key, updateLru) + if err != nil { + return nil, false, fmt.Errorf("failed to get value from shard: %w", err) + } + if !ok { + return nil, false, nil + } + return value, ok, nil +} + +func (c *cache) Set(key []byte, value []byte) { + shardIndex := c.shardManager.Shard(key) + shard := c.shards[shardIndex] + + if 
value == nil { + shard.Delete(key) + } else { + shard.Set(key, value) + } +} diff --git a/sei-db/db_engine/dbcache/cache_impl_test.go b/sei-db/db_engine/dbcache/cache_impl_test.go new file mode 100644 index 0000000000..5433019c93 --- /dev/null +++ b/sei-db/db_engine/dbcache/cache_impl_test.go @@ -0,0 +1,738 @@ +package dbcache + +import ( + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/sei-protocol/sei-chain/sei-db/common/threading" + "github.com/sei-protocol/sei-chain/sei-db/db_engine/types" +) + +// --------------------------------------------------------------------------- +// helpers +// --------------------------------------------------------------------------- + +func noopRead(key []byte) ([]byte, bool, error) { return nil, false, nil } + +func newTestCache(t *testing.T, store map[string][]byte, shardCount, maxSize uint64) (Cache, Reader) { + t.Helper() + read := func(key []byte) ([]byte, bool, error) { + v, ok := store[string(key)] + if !ok { + return nil, false, nil + } + return v, true, nil + } + pool := threading.NewAdHocPool() + c, err := NewStandardCache(context.Background(), shardCount, maxSize, pool, pool, 16, "", 0) + require.NoError(t, err) + return c, read +} + +// --------------------------------------------------------------------------- +// NewStandardCache — validation +// --------------------------------------------------------------------------- + +func TestNewStandardCacheValid(t *testing.T) { + pool := threading.NewAdHocPool() + c, err := NewStandardCache(context.Background(), 4, 1024, pool, pool, 16, "", 0) + require.NoError(t, err) + require.NotNil(t, c) +} + +func TestNewStandardCacheSingleShard(t *testing.T) { + pool := threading.NewAdHocPool() + c, err := NewStandardCache(context.Background(), 1, 1024, pool, pool, 16, "", 0) + require.NoError(t, err) + require.NotNil(t, c) +} + +func TestNewStandardCacheShardCountZero(t *testing.T) { + pool := 
threading.NewAdHocPool() + _, err := NewStandardCache(context.Background(), 0, 1024, pool, pool, 16, "", 0) + require.Error(t, err) +} + +func TestNewStandardCacheShardCountNotPowerOfTwo(t *testing.T) { + pool := threading.NewAdHocPool() + for _, n := range []uint64{3, 5, 6, 7, 9, 10} { + _, err := NewStandardCache(context.Background(), n, 1024, pool, pool, 16, "", 0) + require.Error(t, err, "shardCount=%d", n) + } +} + +func TestNewStandardCacheMaxSizeZero(t *testing.T) { + pool := threading.NewAdHocPool() + _, err := NewStandardCache(context.Background(), 4, 0, pool, pool, 16, "", 0) + require.Error(t, err) +} + +func TestNewStandardCacheMaxSizeLessThanShardCount(t *testing.T) { + pool := threading.NewAdHocPool() + // shardCount=4, maxSize=3 → sizePerShard=0 + _, err := NewStandardCache(context.Background(), 4, 3, pool, pool, 16, "", 0) + require.Error(t, err) +} + +func TestNewStandardCacheWithMetrics(t *testing.T) { + pool := threading.NewAdHocPool() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + c, err := NewStandardCache(ctx, 2, 1024, pool, pool, 0, "test-cache", time.Hour) + require.NoError(t, err) + require.NotNil(t, c) +} + +// --------------------------------------------------------------------------- +// Get +// --------------------------------------------------------------------------- + +func TestCacheGetFromDB(t *testing.T) { + store := map[string][]byte{"foo": []byte("bar")} + c, read := newTestCache(t, store, 4, 4096) + + val, found, err := c.Get(read, []byte("foo"), true) + require.NoError(t, err) + require.True(t, found) + require.Equal(t, "bar", string(val)) +} + +func TestCacheGetNotFound(t *testing.T) { + c, read := newTestCache(t, map[string][]byte{}, 4, 4096) + + val, found, err := c.Get(read, []byte("missing"), true) + require.NoError(t, err) + require.False(t, found) + require.Nil(t, val) +} + +func TestCacheGetAfterSet(t *testing.T) { + c, read := newTestCache(t, map[string][]byte{}, 4, 4096) + + 
c.Set([]byte("k"), []byte("v")) + + val, found, err := c.Get(read, []byte("k"), true) + require.NoError(t, err) + require.True(t, found) + require.Equal(t, "v", string(val)) +} + +func TestCacheGetAfterDelete(t *testing.T) { + store := map[string][]byte{"k": []byte("v")} + c, read := newTestCache(t, store, 4, 4096) + + // Warm the cache so the key is present before deleting. + _, _, err := c.Get(read, []byte("k"), true) + require.NoError(t, err) + + c.Delete([]byte("k")) + + val, found, err := c.Get(read, []byte("k"), true) + require.NoError(t, err) + require.False(t, found) + require.Nil(t, val) +} + +func TestCacheGetDBError(t *testing.T) { + dbErr := errors.New("db fail") + readFunc := func(key []byte) ([]byte, bool, error) { return nil, false, dbErr } + pool := threading.NewAdHocPool() + c, _ := NewStandardCache(context.Background(), 1, 4096, pool, pool, 0, "", 0) + + _, _, err := c.Get(readFunc, []byte("k"), true) + require.Error(t, err) + require.ErrorIs(t, err, dbErr) +} + +func TestCacheGetSameKeyConsistentShard(t *testing.T) { + var readCalls atomic.Int64 + readFunc := func(key []byte) ([]byte, bool, error) { + readCalls.Add(1) + return []byte("val"), true, nil + } + pool := threading.NewAdHocPool() + c, _ := NewStandardCache(context.Background(), 4, 4096, pool, pool, 0, "", 0) + + val1, _, _ := c.Get(readFunc, []byte("key"), true) + val2, _, _ := c.Get(readFunc, []byte("key"), true) + + require.Equal(t, string(val1), string(val2)) + require.Equal(t, int64(1), readCalls.Load(), "second Get should hit cache") +} + +// --------------------------------------------------------------------------- +// Set +// --------------------------------------------------------------------------- + +func TestCacheSetNewKey(t *testing.T) { + c, read := newTestCache(t, map[string][]byte{}, 4, 4096) + + c.Set([]byte("a"), []byte("1")) + + val, found, err := c.Get(read, []byte("a"), false) + require.NoError(t, err) + require.True(t, found) + require.Equal(t, "1", string(val)) +} 
+ +func TestCacheSetOverwrite(t *testing.T) { + c, read := newTestCache(t, map[string][]byte{}, 4, 4096) + + c.Set([]byte("a"), []byte("old")) + c.Set([]byte("a"), []byte("new")) + + val, found, err := c.Get(read, []byte("a"), false) + require.NoError(t, err) + require.True(t, found) + require.Equal(t, "new", string(val)) +} + +func TestCacheSetNilValue(t *testing.T) { + c, read := newTestCache(t, map[string][]byte{}, 4, 4096) + + c.Set([]byte("k"), nil) + + val, found, err := c.Get(read, []byte("k"), false) + require.NoError(t, err) + require.False(t, found, "Set(key, nil) should be treated as a deletion") + require.Nil(t, val) +} + +func TestCacheSetNilConsistentWithBatchSet(t *testing.T) { + store := map[string][]byte{"a": []byte("orig-a"), "b": []byte("orig-b")} + + cSet, readSet := newTestCache(t, store, 1, 4096) + cBatch, readBatch := newTestCache(t, store, 1, 4096) + + // Warm both caches so the backing store value is loaded. + _, _, err := cSet.Get(readSet, []byte("a"), true) + require.NoError(t, err) + _, _, err = cBatch.Get(readBatch, []byte("b"), true) + require.NoError(t, err) + + // Delete via Set(key, nil) in one cache and BatchSet({key, nil}) in the other. 
+ cSet.Set([]byte("a"), nil) + require.NoError(t, cBatch.BatchSet([]CacheUpdate{ + {Key: []byte("b"), Value: nil}, + })) + + valA, foundA, err := cSet.Get(readSet, []byte("a"), false) + require.NoError(t, err) + valB, foundB, err := cBatch.Get(readBatch, []byte("b"), false) + require.NoError(t, err) + + require.Equal(t, foundA, foundB, "Set(key, nil) and BatchSet with nil value should agree on found") + require.Equal(t, valA, valB, "Set(key, nil) and BatchSet with nil value should agree on value") + require.False(t, foundA, "nil value should be treated as a deletion") +} + +// --------------------------------------------------------------------------- +// Delete +// --------------------------------------------------------------------------- + +func TestCacheDeleteExistingKey(t *testing.T) { + c, read := newTestCache(t, map[string][]byte{}, 4, 4096) + + c.Set([]byte("k"), []byte("v")) + c.Delete([]byte("k")) + + _, found, err := c.Get(read, []byte("k"), false) + require.NoError(t, err) + require.False(t, found) +} + +func TestCacheDeleteNonexistent(t *testing.T) { + c, read := newTestCache(t, map[string][]byte{}, 4, 4096) + + c.Delete([]byte("ghost")) + + _, found, err := c.Get(read, []byte("ghost"), false) + require.NoError(t, err) + require.False(t, found) +} + +func TestCacheDeleteThenSet(t *testing.T) { + c, read := newTestCache(t, map[string][]byte{}, 4, 4096) + + c.Set([]byte("k"), []byte("v1")) + c.Delete([]byte("k")) + c.Set([]byte("k"), []byte("v2")) + + val, found, err := c.Get(read, []byte("k"), false) + require.NoError(t, err) + require.True(t, found) + require.Equal(t, "v2", string(val)) +} + +// --------------------------------------------------------------------------- +// BatchSet +// --------------------------------------------------------------------------- + +func TestCacheBatchSetMultipleKeys(t *testing.T) { + c, read := newTestCache(t, map[string][]byte{}, 4, 4096) + + err := c.BatchSet([]CacheUpdate{ + {Key: []byte("a"), Value: []byte("1")}, + 
{Key: []byte("b"), Value: []byte("2")}, + {Key: []byte("c"), Value: []byte("3")}, + }) + require.NoError(t, err) + + for _, tc := range []struct{ key, want string }{{"a", "1"}, {"b", "2"}, {"c", "3"}} { + val, found, err := c.Get(read, []byte(tc.key), false) + require.NoError(t, err, "key=%q", tc.key) + require.True(t, found, "key=%q", tc.key) + require.Equal(t, tc.want, string(val), "key=%q", tc.key) + } +} + +func TestCacheBatchSetMixedSetAndDelete(t *testing.T) { + c, read := newTestCache(t, map[string][]byte{}, 4, 4096) + + c.Set([]byte("keep"), []byte("v")) + c.Set([]byte("remove"), []byte("v")) + + err := c.BatchSet([]CacheUpdate{ + {Key: []byte("keep"), Value: []byte("updated")}, + {Key: []byte("remove"), Value: nil}, + {Key: []byte("new"), Value: []byte("fresh")}, + }) + require.NoError(t, err) + + val, found, _ := c.Get(read, []byte("keep"), false) + require.True(t, found) + require.Equal(t, "updated", string(val)) + + _, found, _ = c.Get(read, []byte("remove"), false) + require.False(t, found) + + val, found, _ = c.Get(read, []byte("new"), false) + require.True(t, found) + require.Equal(t, "fresh", string(val)) +} + +func TestCacheBatchSetEmpty(t *testing.T) { + c, _ := newTestCache(t, map[string][]byte{}, 4, 4096) + + require.NoError(t, c.BatchSet(nil)) + require.NoError(t, c.BatchSet([]CacheUpdate{})) +} + +func TestCacheBatchSetPoolFailure(t *testing.T) { + readPool := threading.NewAdHocPool() + c, _ := NewStandardCache(context.Background(), 1, 4096, readPool, &failPool{}, 0, "", 0) + + err := c.BatchSet([]CacheUpdate{ + {Key: []byte("k"), Value: []byte("v")}, + }) + require.Error(t, err) +} + +// --------------------------------------------------------------------------- +// BatchGet +// --------------------------------------------------------------------------- + +func TestCacheBatchGetAllCached(t *testing.T) { + c, read := newTestCache(t, map[string][]byte{}, 4, 4096) + + c.Set([]byte("a"), []byte("1")) + c.Set([]byte("b"), []byte("2")) + + keys := 
map[string]types.BatchGetResult{"a": {}, "b": {}} + require.NoError(t, c.BatchGet(read, keys)) + + require.True(t, keys["a"].IsFound()) + require.Equal(t, "1", string(keys["a"].Value)) + require.True(t, keys["b"].IsFound()) + require.Equal(t, "2", string(keys["b"].Value)) +} + +func TestCacheBatchGetAllFromDB(t *testing.T) { + store := map[string][]byte{"x": []byte("10"), "y": []byte("20")} + c, read := newTestCache(t, store, 4, 4096) + + keys := map[string]types.BatchGetResult{"x": {}, "y": {}} + require.NoError(t, c.BatchGet(read, keys)) + + require.True(t, keys["x"].IsFound()) + require.Equal(t, "10", string(keys["x"].Value)) + require.True(t, keys["y"].IsFound()) + require.Equal(t, "20", string(keys["y"].Value)) +} + +func TestCacheBatchGetMixedCachedAndDB(t *testing.T) { + store := map[string][]byte{"db-key": []byte("from-db")} + c, read := newTestCache(t, store, 4, 4096) + + c.Set([]byte("cached"), []byte("from-cache")) + + keys := map[string]types.BatchGetResult{"cached": {}, "db-key": {}} + require.NoError(t, c.BatchGet(read, keys)) + + require.True(t, keys["cached"].IsFound()) + require.Equal(t, "from-cache", string(keys["cached"].Value)) + require.True(t, keys["db-key"].IsFound()) + require.Equal(t, "from-db", string(keys["db-key"].Value)) +} + +func TestCacheBatchGetNotFoundKeys(t *testing.T) { + c, read := newTestCache(t, map[string][]byte{}, 4, 4096) + + keys := map[string]types.BatchGetResult{"nope": {}} + require.NoError(t, c.BatchGet(read, keys)) + require.False(t, keys["nope"].IsFound()) +} + +func TestCacheBatchGetDeletedKey(t *testing.T) { + c, read := newTestCache(t, map[string][]byte{}, 4, 4096) + + c.Set([]byte("k"), []byte("v")) + c.Delete([]byte("k")) + + keys := map[string]types.BatchGetResult{"k": {}} + require.NoError(t, c.BatchGet(read, keys)) + require.False(t, keys["k"].IsFound()) +} + +func TestCacheBatchGetDBError(t *testing.T) { + dbErr := errors.New("broken") + readFunc := func(key []byte) ([]byte, bool, error) { return nil, false, 
dbErr } + pool := threading.NewAdHocPool() + c, _ := NewStandardCache(context.Background(), 1, 4096, pool, pool, 0, "", 0) + + keys := map[string]types.BatchGetResult{"fail": {}} + require.NoError(t, c.BatchGet(readFunc, keys), "BatchGet itself should not fail") + require.Error(t, keys["fail"].Error) +} + +func TestCacheBatchGetEmpty(t *testing.T) { + c, read := newTestCache(t, map[string][]byte{}, 4, 4096) + keys := map[string]types.BatchGetResult{} + require.NoError(t, c.BatchGet(read, keys)) +} + +func TestCacheBatchGetPoolFailure(t *testing.T) { + readPool := threading.NewAdHocPool() + c, _ := NewStandardCache(context.Background(), 1, 4096, readPool, &failPool{}, 0, "", 0) + + keys := map[string]types.BatchGetResult{"k": {}} + err := c.BatchGet(noopRead, keys) + require.Error(t, err) +} + +func TestCacheBatchGetShardReadPoolFailure(t *testing.T) { + miscPool := threading.NewAdHocPool() + c, _ := NewStandardCache(context.Background(), 1, 4096, &failPool{}, miscPool, 0, "", 0) + + keys := map[string]types.BatchGetResult{"a": {}, "b": {}} + require.NoError(t, c.BatchGet(noopRead, keys)) + + for k, r := range keys { + require.Error(t, r.Error, "key=%q should have per-key error", k) + } +} + +// --------------------------------------------------------------------------- +// Cross-shard distribution +// --------------------------------------------------------------------------- + +func TestCacheDistributesAcrossShards(t *testing.T) { + c, _ := newTestCache(t, map[string][]byte{}, 4, 4096) + impl := c.(*cache) + + for i := 0; i < 100; i++ { + c.Set([]byte(fmt.Sprintf("key-%d", i)), []byte("v")) + } + + nonEmpty := 0 + for _, s := range impl.shards { + _, entries := s.getSizeInfo() + if entries > 0 { + nonEmpty++ + } + } + require.GreaterOrEqual(t, nonEmpty, 2, "keys should distribute across multiple shards") +} + +func TestCacheGetRoutesToSameShard(t *testing.T) { + c, _ := newTestCache(t, map[string][]byte{}, 4, 4096) + impl := c.(*cache) + + c.Set([]byte("key"), 
[]byte("val")) + + idx := impl.shardManager.Shard([]byte("key")) + _, entries := impl.shards[idx].getSizeInfo() + require.Equal(t, uint64(1), entries, "key should be in the shard determined by shardManager") +} + +// --------------------------------------------------------------------------- +// getCacheSizeInfo +// --------------------------------------------------------------------------- + +func TestCacheGetCacheSizeInfoEmpty(t *testing.T) { + c, _ := newTestCache(t, map[string][]byte{}, 4, 4096) + impl := c.(*cache) + + bytes, entries := impl.getCacheSizeInfo() + require.Equal(t, uint64(0), bytes) + require.Equal(t, uint64(0), entries) +} + +func TestCacheGetCacheSizeInfoAggregatesShards(t *testing.T) { + c, _ := newTestCache(t, map[string][]byte{}, 4, 4096) + impl := c.(*cache) + + for i := 0; i < 20; i++ { + c.Set([]byte(fmt.Sprintf("k%d", i)), []byte(fmt.Sprintf("v%d", i))) + } + + bytes, entries := impl.getCacheSizeInfo() + require.Equal(t, uint64(20), entries) + require.Greater(t, bytes, uint64(0)) +} + +// --------------------------------------------------------------------------- +// estimatedOverheadPerEntry +// --------------------------------------------------------------------------- + +func TestCacheSizeInfoIncludesOverhead(t *testing.T) { + const overhead = 200 + pool := threading.NewAdHocPool() + c, err := NewStandardCache(context.Background(), 1, 100_000, pool, pool, overhead, "", 0) + require.NoError(t, err) + impl := c.(*cache) + + c.Set([]byte("ab"), []byte("cd")) + c.Set([]byte("efg"), []byte("hi")) + + bytes, entries := impl.getCacheSizeInfo() + require.Equal(t, uint64(2), entries) + // (2+2+200) + (3+2+200) = 409 + require.Equal(t, uint64(409), bytes) +} + +func TestCacheOverheadCausesEarlierEviction(t *testing.T) { + const overhead = 200 + pool := threading.NewAdHocPool() + // Single shard, maxSize=500. Each 10-byte value entry costs 1+10+200=211 bytes. + // Two entries = 422 < 500. Three entries = 633 > 500, so one must be evicted. 
+ c, err := NewStandardCache(context.Background(), 1, 500, pool, pool, overhead, "", 0) + require.NoError(t, err) + impl := c.(*cache) + + c.Set([]byte("a"), []byte("0123456789")) + c.Set([]byte("b"), []byte("0123456789")) + + _, entries := impl.getCacheSizeInfo() + require.Equal(t, uint64(2), entries, "two entries should fit") + + c.Set([]byte("c"), []byte("0123456789")) + + bytes, entries := impl.getCacheSizeInfo() + require.Equal(t, uint64(2), entries, "third entry should trigger eviction") + require.LessOrEqual(t, bytes, uint64(500)) +} + +// --------------------------------------------------------------------------- +// Many keys — BatchGet/BatchSet spanning all shards +// --------------------------------------------------------------------------- + +func TestCacheBatchSetThenBatchGetManyKeys(t *testing.T) { + c, read := newTestCache(t, map[string][]byte{}, 4, 100_000) + + updates := make([]CacheUpdate, 200) + for i := range updates { + updates[i] = CacheUpdate{ + Key: []byte(fmt.Sprintf("key-%03d", i)), + Value: []byte(fmt.Sprintf("val-%03d", i)), + } + } + require.NoError(t, c.BatchSet(updates)) + + keys := make(map[string]types.BatchGetResult, 200) + for i := 0; i < 200; i++ { + keys[fmt.Sprintf("key-%03d", i)] = types.BatchGetResult{} + } + require.NoError(t, c.BatchGet(read, keys)) + + for i := 0; i < 200; i++ { + k := fmt.Sprintf("key-%03d", i) + want := fmt.Sprintf("val-%03d", i) + require.True(t, keys[k].IsFound(), "key=%q", k) + require.Equal(t, want, string(keys[k].Value), "key=%q", k) + require.NoError(t, keys[k].Error, "key=%q", k) + } +} + +// --------------------------------------------------------------------------- +// Concurrency +// --------------------------------------------------------------------------- + +func TestCacheConcurrentGetSet(t *testing.T) { + store := map[string][]byte{} + for i := 0; i < 50; i++ { + store[fmt.Sprintf("db-%d", i)] = []byte(fmt.Sprintf("v-%d", i)) + } + c, read := newTestCache(t, store, 4, 100_000) + + var wg 
sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(2) + key := []byte(fmt.Sprintf("key-%d", i)) + val := []byte(fmt.Sprintf("val-%d", i)) + + go func() { + defer wg.Done() + c.Set(key, val) + }() + go func() { + defer wg.Done() + c.Get(read, key, true) + }() + } + wg.Wait() +} + +func TestCacheConcurrentBatchSetAndBatchGet(t *testing.T) { + store := map[string][]byte{} + for i := 0; i < 50; i++ { + store[fmt.Sprintf("db-%d", i)] = []byte(fmt.Sprintf("v-%d", i)) + } + c, read := newTestCache(t, store, 4, 100_000) + + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + updates := make([]CacheUpdate, 50) + for i := range updates { + updates[i] = CacheUpdate{ + Key: []byte(fmt.Sprintf("set-%d", i)), + Value: []byte(fmt.Sprintf("sv-%d", i)), + } + } + c.BatchSet(updates) + }() + + wg.Add(1) + go func() { + defer wg.Done() + keys := make(map[string]types.BatchGetResult) + for i := 0; i < 50; i++ { + keys[fmt.Sprintf("db-%d", i)] = types.BatchGetResult{} + } + c.BatchGet(read, keys) + }() + + wg.Wait() +} + +func TestCacheConcurrentDeleteAndGet(t *testing.T) { + c, read := newTestCache(t, map[string][]byte{}, 4, 100_000) + + for i := 0; i < 100; i++ { + c.Set([]byte(fmt.Sprintf("k-%d", i)), []byte("v")) + } + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(2) + key := []byte(fmt.Sprintf("k-%d", i)) + go func() { + defer wg.Done() + c.Delete(key) + }() + go func() { + defer wg.Done() + c.Get(read, key, true) + }() + } + wg.Wait() +} + +// --------------------------------------------------------------------------- +// Eviction through the cache layer +// --------------------------------------------------------------------------- + +func TestCacheEvictsPerShard(t *testing.T) { + c, _ := newTestCache(t, map[string][]byte{}, 1, 20) + impl := c.(*cache) + + c.Set([]byte("a"), []byte("11111111")) + c.Set([]byte("b"), []byte("22222222")) + + c.Set([]byte("c"), []byte("33333333")) + + bytes, _ := impl.shards[0].getSizeInfo() + 
require.LessOrEqual(t, bytes, uint64(20)) +} + +// --------------------------------------------------------------------------- +// Edge: BatchSet with keys all routed to the same shard +// --------------------------------------------------------------------------- + +func TestCacheBatchSetSameShard(t *testing.T) { + c, read := newTestCache(t, map[string][]byte{}, 1, 4096) + + err := c.BatchSet([]CacheUpdate{ + {Key: []byte("x"), Value: []byte("1")}, + {Key: []byte("y"), Value: []byte("2")}, + {Key: []byte("z"), Value: []byte("3")}, + }) + require.NoError(t, err) + + for _, tc := range []struct{ key, want string }{{"x", "1"}, {"y", "2"}, {"z", "3"}} { + val, found, err := c.Get(read, []byte(tc.key), false) + require.NoError(t, err) + require.True(t, found) + require.Equal(t, tc.want, string(val)) + } +} + +// --------------------------------------------------------------------------- +// Edge: BatchGet after BatchSet with deletes +// --------------------------------------------------------------------------- + +func TestCacheBatchGetAfterBatchSetWithDeletes(t *testing.T) { + c, read := newTestCache(t, map[string][]byte{}, 4, 4096) + + c.Set([]byte("a"), []byte("1")) + c.Set([]byte("b"), []byte("2")) + c.Set([]byte("c"), []byte("3")) + + err := c.BatchSet([]CacheUpdate{ + {Key: []byte("a"), Value: []byte("updated")}, + {Key: []byte("b"), Value: nil}, + }) + require.NoError(t, err) + + keys := map[string]types.BatchGetResult{"a": {}, "b": {}, "c": {}} + require.NoError(t, c.BatchGet(read, keys)) + + require.True(t, keys["a"].IsFound()) + require.Equal(t, "updated", string(keys["a"].Value)) + require.False(t, keys["b"].IsFound()) + require.True(t, keys["c"].IsFound()) + require.Equal(t, "3", string(keys["c"].Value)) +} + +// --------------------------------------------------------------------------- +// Power-of-two shard counts +// --------------------------------------------------------------------------- + +func TestNewStandardCachePowerOfTwoShardCounts(t 
*testing.T) { + pool := threading.NewAdHocPool() + for _, n := range []uint64{1, 2, 4, 8, 16, 32, 64} { + c, err := NewStandardCache(context.Background(), n, n*100, pool, pool, 0, "", 0) + require.NoError(t, err, "shardCount=%d", n) + require.NotNil(t, c, "shardCount=%d", n) + } +} diff --git a/sei-db/db_engine/dbcache/cache_metrics.go b/sei-db/db_engine/dbcache/cache_metrics.go new file mode 100644 index 0000000000..a6344bf08f --- /dev/null +++ b/sei-db/db_engine/dbcache/cache_metrics.go @@ -0,0 +1,136 @@ +package dbcache + +import ( + "context" + "time" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" + + "github.com/sei-protocol/sei-chain/sei-db/common/metrics" +) + +const cacheMeterName = "seidb_pebblecache" + +// CacheMetrics records OTel metrics for a pebblecache instance. +// All report methods are nil-safe: if the receiver is nil, they are no-ops, +// allowing the cache to call them unconditionally regardless of whether metrics +// are enabled. +// +// The cacheName is used as the "cache" attribute on all recorded metrics, +// enabling multiple cache instances to be distinguished in dashboards. +type CacheMetrics struct { + // Pre-computed attribute option reused on every recording to avoid + // per-call allocations on the hot path. + attrs metric.MeasurementOption + + sizeBytes metric.Int64Gauge + sizeEntries metric.Int64Gauge + hits metric.Int64Counter + misses metric.Int64Counter + missLatency metric.Float64Histogram +} + +// newCacheMetrics creates a CacheMetrics that records cache statistics via OTel. +// A background goroutine scrapes cache size every scrapeInterval until ctx is +// cancelled. The cacheName is attached as the "cache" attribute to all recorded +// metrics, enabling multiple cache instances to be distinguished in dashboards. +// +// Multiple instances are safe: OTel instrument registration is idempotent, so each +// call receives references to the same underlying instruments. 
The "cache" attribute +// distinguishes series (e.g. pebblecache_hits{cache="state"}). +func newCacheMetrics( + ctx context.Context, + cacheName string, + scrapeInterval time.Duration, + getSize func() (bytes uint64, entries uint64), +) *CacheMetrics { + meter := otel.Meter(cacheMeterName) + + sizeBytes, _ := meter.Int64Gauge( + "pebblecache_size_bytes", + metric.WithDescription("Current cache size in bytes"), + metric.WithUnit("By"), + ) + sizeEntries, _ := meter.Int64Gauge( + "pebblecache_size_entries", + metric.WithDescription("Current number of entries in the cache"), + metric.WithUnit("{count}"), + ) + hits, _ := meter.Int64Counter( + "pebblecache_hits", + metric.WithDescription("Total number of cache hits"), + metric.WithUnit("{count}"), + ) + misses, _ := meter.Int64Counter( + "pebblecache_misses", + metric.WithDescription("Total number of cache misses"), + metric.WithUnit("{count}"), + ) + missLatency, _ := meter.Float64Histogram( + "pebblecache_miss_latency", + metric.WithDescription("Time taken to resolve a cache miss from the backing store"), + metric.WithUnit("s"), + metric.WithExplicitBucketBoundaries(metrics.LatencyBuckets...), + ) + + cm := &CacheMetrics{ + attrs: metric.WithAttributes(attribute.String("cache", cacheName)), + sizeBytes: sizeBytes, + sizeEntries: sizeEntries, + hits: hits, + misses: misses, + missLatency: missLatency, + } + + go cm.collectLoop(ctx, scrapeInterval, getSize) + + return cm +} + +func (cm *CacheMetrics) reportCacheHits(count int64) { + if cm == nil { + return + } + cm.hits.Add(context.Background(), count, cm.attrs) +} + +func (cm *CacheMetrics) reportCacheMisses(count int64) { + if cm == nil { + return + } + cm.misses.Add(context.Background(), count, cm.attrs) +} + +func (cm *CacheMetrics) reportCacheMissLatency(latency time.Duration) { + if cm == nil { + return + } + cm.missLatency.Record(context.Background(), latency.Seconds(), cm.attrs) +} + +// collectLoop periodically scrapes cache size from the provided function +// 
and records it as gauge values. It exits when ctx is cancelled. +func (cm *CacheMetrics) collectLoop( + ctx context.Context, + interval time.Duration, + getSize func() (bytes uint64, entries uint64), +) { + + if cm == nil { + return + } + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + bytes, entries := getSize() + cm.sizeBytes.Record(ctx, int64(bytes), cm.attrs) //nolint:gosec // G115: safe, cache size fits int64 + cm.sizeEntries.Record(ctx, int64(entries), cm.attrs) //nolint:gosec // G115: safe, entry count fits int64 + } + } +} diff --git a/sei-db/db_engine/dbcache/shard.go b/sei-db/db_engine/dbcache/shard.go new file mode 100644 index 0000000000..6a71105add --- /dev/null +++ b/sei-db/db_engine/dbcache/shard.go @@ -0,0 +1,437 @@ +package dbcache + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/sei-protocol/sei-chain/sei-db/common/threading" + "github.com/sei-protocol/sei-chain/sei-db/db_engine/types" +) + +// A single shard of a Cache. +type shard struct { + ctx context.Context + + // A lock to protect the shard's data. + lock sync.Mutex + + // The data in the shard. + data map[string]*shardEntry + + // Organizes data for garbage collection. + gcQueue *lruQueue + + // A pool for asynchronous reads. + readPool threading.Pool + + // The maximum size of this cache, in bytes. + maxSize uint64 + + // The estimated overhead per entry, in bytes. This is used to calculate the maximum size of the cache. + // This value should be derived experimentally, and may differ between different builds and architectures. + estimatedOverheadPerEntry uint64 + + // Cache-level metrics. Nil-safe; if nil, no metrics are recorded. + metrics *CacheMetrics +} + +// The result of a read from the underlying database. +type readResult struct { + value []byte + err error +} + +// The status of a value in the cache. 
type valueStatus int

const (
	// The value is not known and we are not currently attempting to find it.
	statusUnknown valueStatus = iota
	// We've scheduled a read of the value but haven't yet finished the read.
	statusScheduled
	// The data is available.
	statusAvailable
	// We are aware that the value is deleted (special case of data being available).
	statusDeleted
)

// A single shardEntry in a shard. Records data for a single key.
type shardEntry struct {
	// The parent shard that contains this entry.
	shard *shard

	// The current status of this entry.
	status valueStatus

	// The value, if known.
	value []byte

	// If the value is not available when we request it,
	// it will be written to this channel when it is available.
	// Buffered with capacity 1: each waiter pulls the result and immediately
	// pushes it back, so any number of waiters can observe the same result.
	valueChan chan readResult
}

/*
This implementation currently uses a single exclusive lock, as opposed to a RW lock. This is a lot simpler than
using a RW lock, but it comes at higher risk of contention under certain workloads. If this contention ever
becomes a problem, we might consider switching to a RW lock. Below is a potential implementation strategy
for converting to a RW lock:

- Create a background goroutine that is responsible for garbage collection and updating the LRU.
- The GC goroutine should periodically wake up, grab the lock, and do garbage collection.
- When Get() is called, the calling goroutine should grab a read lock and attempt to read the value.
  - If the value is present, send a message to the GC goroutine over a channel (so it can update the LRU)
    and return the value. In this way, many readers can read from this shard concurrently.
  - If the value is missing, drop the read lock and acquire a write lock. Then, handle the read
    like we currently handle in the current implementation.
*/

// Creates a new Shard.
func NewShard(
	ctx context.Context,
	// A work pool for asynchronous reads.
	readPool threading.Pool,
	// The maximum size of this shard, in bytes.
	maxSize uint64,
	// The estimated overhead per entry, in bytes. This is used to calculate the maximum size of the cache.
	// This value should be derived experimentally, and may differ between different builds and architectures.
	estimatedOverheadPerEntry uint64,
) (*shard, error) {

	if maxSize == 0 {
		return nil, fmt.Errorf("maxSize must be greater than 0")
	}

	return &shard{
		ctx:                       ctx,
		readPool:                  readPool,
		lock:                      sync.Mutex{},
		data:                      make(map[string]*shardEntry),
		gcQueue:                   newLRUQueue(),
		estimatedOverheadPerEntry: estimatedOverheadPerEntry,
		maxSize:                   maxSize,
	}, nil
}

// Get returns the value for the given key, or (nil, false, nil) if not found.
//
// The shard lock is acquired here and released by whichever get* helper
// handles the entry's current status; every branch below transfers lock
// ownership to the helper it calls.
func (s *shard) Get(read Reader, key []byte, updateLru bool) ([]byte, bool, error) {
	s.lock.Lock()

	entry := s.getEntry(key, true)

	switch entry.status {
	case statusAvailable:
		return s.getAvailable(entry, key, updateLru)
	case statusDeleted:
		return s.getDeleted(key, updateLru)
	case statusScheduled:
		return s.getScheduled(entry)
	case statusUnknown:
		return s.getUnknown(read, entry, key)
	default:
		s.lock.Unlock()
		panic(fmt.Sprintf("unexpected status: %#v", entry.status))
	}
}

// Handles Get for a key whose value is already cached. Lock must be held; releases it.
func (s *shard) getAvailable(entry *shardEntry, key []byte, updateLru bool) ([]byte, bool, error) {
	value := entry.value
	if updateLru {
		s.gcQueue.Touch(key)
	}
	s.lock.Unlock()
	// Metrics are reported outside the lock to keep the critical section short.
	s.metrics.reportCacheHits(1)
	return value, true, nil
}

// Handles Get for a key known to be deleted. Lock must be held; releases it.
// A cached deletion counts as a hit: we answered without touching the backing store.
func (s *shard) getDeleted(key []byte, updateLru bool) ([]byte, bool, error) {
	if updateLru {
		s.gcQueue.Touch(key)
	}
	s.lock.Unlock()
	s.metrics.reportCacheHits(1)
	return nil, false, nil
}

// Handles Get for a key with an in-flight read from another goroutine. Lock must be held; releases it.
// Counted as a miss because this caller still pays the backing-store latency.
func (s *shard) getScheduled(entry *shardEntry) ([]byte, bool, error) {
	valueChan := entry.valueChan
	s.lock.Unlock()
	s.metrics.reportCacheMisses(1)
	startTime := time.Now()
	result, err := threading.InterruptiblePull(s.ctx, valueChan)
	s.metrics.reportCacheMissLatency(time.Since(startTime))
	if err != nil {
		return nil, false, fmt.Errorf("failed to pull value from channel: %w", err)
	}
	valueChan <- result // reload the channel in case there are other listeners
	if result.err != nil {
		return nil, false, fmt.Errorf("failed to read value from database: %w", result.err)
	}
	// A nil value means "known deleted": report not-found with no error.
	return result.value, result.value != nil, nil
}

// Handles Get for a key not yet read. Schedules the read and waits. Lock must be held; releases it.
//
// NOTE(review): if readPool.Submit fails, the entry is left in statusScheduled
// with nobody left to deliver a result, so later Gets on this key will block
// until ctx is cancelled. This appears to rely on the Cache contract that an
// errored cache is corrupted and must be discarded — confirm callers do so.
func (s *shard) getUnknown(read Reader, entry *shardEntry, key []byte) ([]byte, bool, error) {
	entry.status = statusScheduled
	valueChan := make(chan readResult, 1)
	entry.valueChan = valueChan
	s.lock.Unlock()
	s.metrics.reportCacheMisses(1)
	startTime := time.Now()
	err := s.readPool.Submit(s.ctx, func() {
		value, _, readErr := read(key)
		entry.injectValue(key, readResult{value: value, err: readErr})
	})
	if err != nil {
		return nil, false, fmt.Errorf("failed to schedule read: %w", err)
	}
	result, err := threading.InterruptiblePull(s.ctx, valueChan)
	s.metrics.reportCacheMissLatency(time.Since(startTime))
	if err != nil {
		return nil, false, fmt.Errorf("failed to pull value from channel: %w", err)
	}
	valueChan <- result // reload the channel in case there are other listeners
	if result.err != nil {
		return nil, false, result.err
	}
	return result.value, result.value != nil, nil
}

// This method is called by the read scheduler when a value becomes available.
// It caches the result (unless a Set/Delete raced ahead and changed the entry's
// status away from statusScheduled) and then wakes all waiters via valueChan.
//
// NOTE(review): when a concurrent Set/Delete wins the race, the cache keeps the
// newer value but waiters still receive the possibly stale DB result pushed
// below — presumably acceptable since their read began before the mutation.
func (se *shardEntry) injectValue(key []byte, result readResult) {
	se.shard.lock.Lock()

	if se.status == statusScheduled {
		if result.err != nil {
			// Don't cache errors — reset so the next caller retries.
			delete(se.shard.data, string(key))
		} else if result.value == nil {
			// Not found in the backing store: remember the absence.
			se.status = statusDeleted
			se.value = nil
			size := uint64(len(key)) + se.shard.estimatedOverheadPerEntry
			se.shard.gcQueue.Push(key, size)
			se.shard.evictUnlocked()
		} else {
			se.status = statusAvailable
			se.value = result.value
			size := uint64(len(key)) + uint64(len(result.value)) + se.shard.estimatedOverheadPerEntry
			se.shard.gcQueue.Push(key, size)
			se.shard.evictUnlocked()
		}
	}

	se.shard.lock.Unlock()

	// Send outside the lock: the buffered channel hands the result to the
	// first waiter, which re-pushes it for any others.
	se.valueChan <- result
}

// Get a shard entry for a given key. Caller is responsible for holding the shard's lock
// when this method is called.
func (s *shard) getEntry(key []byte, createIfMissing bool) *shardEntry {
	if entry, ok := s.data[string(key)]; ok {
		return entry
	}
	if !createIfMissing {
		return nil
	}
	// New entries start as statusUnknown; the caller decides whether to
	// schedule a read or overwrite the value directly.
	entry := &shardEntry{
		shard:  s,
		status: statusUnknown,
	}
	keyStr := string(key)
	s.data[keyStr] = entry
	return entry
}

// Tracks a key whose value is not yet available and must be waited on.
type pendingRead struct {
	key       string
	entry     *shardEntry
	valueChan chan readResult
	// True when this BatchGet call is responsible for scheduling the read
	// (as opposed to piggybacking on a read another goroutine started).
	needsSchedule bool
	// Populated after the read completes, used by bulkInjectValues.
	result readResult
}

// BatchGet reads a batch of keys from the shard. Results are written into the provided map.
// Phase 1 (under the lock): answer cached keys immediately and collect the
// rest as pendingReads. Phase 2 (lock released): schedule missing reads,
// wait for all results, then apply cache updates in bulk asynchronously.
//
// NOTE(review): unlike Get(updateLru=true), batch hits do not Touch the LRU —
// confirm this is intentional.
func (s *shard) BatchGet(read Reader, keys map[string]types.BatchGetResult) error {
	pending := make([]pendingRead, 0, len(keys))
	var hits int64

	s.lock.Lock()
	for key := range keys {
		entry := s.getEntry([]byte(key), true)

		switch entry.status {
		case statusAvailable, statusDeleted:
			// A deleted entry has a nil value, which the result type reports
			// as not-found; both cases count as hits.
			keys[key] = types.BatchGetResult{Value: entry.value}
			hits++
		case statusScheduled:
			pending = append(pending, pendingRead{
				key:       key,
				entry:     entry,
				valueChan: entry.valueChan,
			})
		case statusUnknown:
			entry.status = statusScheduled
			valueChan := make(chan readResult, 1)
			entry.valueChan = valueChan
			pending = append(pending, pendingRead{
				key:           key,
				entry:         entry,
				valueChan:     valueChan,
				needsSchedule: true,
			})
		default:
			s.lock.Unlock()
			panic(fmt.Sprintf("unexpected status: %#v", entry.status))
		}
	}
	s.lock.Unlock()

	if hits > 0 {
		s.metrics.reportCacheHits(hits)
	}
	if len(pending) == 0 {
		return nil
	}

	s.metrics.reportCacheMisses(int64(len(pending)))
	startTime := time.Now()

	for i := range pending {
		if pending[i].needsSchedule {
			// Take a pointer to the slice element so the closure reads this
			// iteration's pendingRead, not a copy.
			p := &pending[i]
			err := s.readPool.Submit(s.ctx, func() {
				value, _, readErr := read([]byte(p.key))
				p.entry.valueChan <- readResult{value: value, err: readErr}
			})
			if err != nil {
				return fmt.Errorf("failed to schedule read: %w", err)
			}
		}
	}

	for i := range pending {
		result, err := threading.InterruptiblePull(s.ctx, pending[i].valueChan)
		if err != nil {
			return fmt.Errorf("failed to pull value from channel: %w", err)
		}
		// Re-push so concurrent Get waiters on the same key also see the result.
		pending[i].valueChan <- result
		pending[i].result = result

		if result.err != nil {
			keys[pending[i].key] = types.BatchGetResult{Error: result.err}
		} else {
			keys[pending[i].key] = types.BatchGetResult{Value: result.value}
		}
	}

	s.metrics.reportCacheMissLatency(time.Since(startTime))
	// Cache updates are applied asynchronously, amortizing one lock
	// acquisition over the whole batch; until then the entries remain
	// statusScheduled and waiters are served via the re-loaded channels.
	go s.bulkInjectValues(pending)

	return nil
}

// Applies deferred cache updates for a batch of reads under a single lock acquisition.
func (s *shard) bulkInjectValues(reads []pendingRead) {
	s.lock.Lock()
	for i := range reads {
		entry := reads[i].entry
		// A racing Set/Delete already resolved this entry; keep the newer value.
		if entry.status != statusScheduled {
			continue
		}
		result := reads[i].result
		if result.err != nil {
			// Don't cache errors — reset so the next caller retries.
			delete(s.data, reads[i].key)
		} else if result.value == nil {
			entry.status = statusDeleted
			entry.value = nil
			size := uint64(len(reads[i].key)) + s.estimatedOverheadPerEntry
			s.gcQueue.Push([]byte(reads[i].key), size)
		} else {
			entry.status = statusAvailable
			entry.value = result.value
			size := uint64(len(reads[i].key)) + uint64(len(result.value)) + s.estimatedOverheadPerEntry
			s.gcQueue.Push([]byte(reads[i].key), size)
		}
	}
	s.evictUnlocked()
	s.lock.Unlock()
}

// Evicts least recently used entries until the cache is within its size budget.
// Caller is required to hold the lock.
func (s *shard) evictUnlocked() {
	for s.gcQueue.GetTotalSize() > s.maxSize {
		next := s.gcQueue.PopLeastRecentlyUsed()
		delete(s.data, next)
	}
}

// getSizeInfo returns the current size (bytes) and entry count under the shard lock.
func (s *shard) getSizeInfo() (bytes uint64, entries uint64) {
	s.lock.Lock()
	defer s.lock.Unlock()
	return s.gcQueue.GetTotalSize(), s.gcQueue.GetCount()
}

// Set sets the value for the given key.
func (s *shard) Set(key []byte, value []byte) {
	s.lock.Lock()
	s.setUnlocked(key, value)
	s.evictUnlocked()
	s.lock.Unlock()
}

// Set a value. Caller is required to hold the lock.
// Overwrites any in-flight or deleted state for the key.
// NOTE(review): assumes gcQueue.Push upserts (replaces size/position for an
// existing key) rather than double-counting — confirm against lruQueue.
func (s *shard) setUnlocked(key []byte, value []byte) {
	entry := s.getEntry(key, true)
	entry.status = statusAvailable
	entry.value = value

	size := uint64(len(key)) + uint64(len(value)) + s.estimatedOverheadPerEntry
	s.gcQueue.Push(key, size)
}

// BatchSet sets the values for a batch of keys.
// Applies all updates under one lock acquisition; a nil value means delete.
func (s *shard) BatchSet(entries []CacheUpdate) {
	s.lock.Lock()
	for i := range entries {
		if entries[i].IsDelete() {
			s.deleteUnlocked(entries[i].Key)
		} else {
			s.setUnlocked(entries[i].Key, entries[i].Value)
		}
	}
	s.evictUnlocked()
	s.lock.Unlock()
}

// Delete deletes the value for the given key.
func (s *shard) Delete(key []byte) {
	s.lock.Lock()
	s.deleteUnlocked(key)
	s.evictUnlocked()
	s.lock.Unlock()
}

// Delete a value. Caller is required to hold the lock.
// Only marks existing entries as deleted; an absent key is left untracked
// (no tombstone is created).
func (s *shard) deleteUnlocked(key []byte) {
	entry := s.getEntry(key, false)
	if entry == nil {
		// Key is not in the cache, so nothing to do.
		return
	}
	entry.status = statusDeleted
	entry.value = nil

	size := uint64(len(key)) + s.estimatedOverheadPerEntry
	s.gcQueue.Push(key, size)
}
diff --git a/sei-db/db_engine/dbcache/shard_test.go b/sei-db/db_engine/dbcache/shard_test.go
new file mode 100644
index 0000000000..5e438dd72b
--- /dev/null
+++ b/sei-db/db_engine/dbcache/shard_test.go
@@ -0,0 +1,930 @@
package dbcache

import (
	"context"
	"errors"
	"fmt"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/stretchr/testify/require"

	"github.com/sei-protocol/sei-chain/sei-db/common/threading"
	"github.com/sei-protocol/sei-chain/sei-db/db_engine/types"
)

// ---------------------------------------------------------------------------
// helpers
// ---------------------------------------------------------------------------

// newTestShard builds a shard backed by an in-memory map, with zero
// per-entry overhead so size assertions only count keys and values.
func newTestShard(t *testing.T, maxSize uint64, store map[string][]byte) (*shard, Reader) {
	t.Helper()
	read := Reader(func(key []byte) ([]byte, bool, error) {
		v, ok := store[string(key)]
		if !ok {
			return nil, false, nil
		}
		return v, true, nil
	})
	s, err := NewShard(context.Background(), threading.NewAdHocPool(), maxSize, 0)
	require.NoError(t, err)
	return s, read
}

// ---------------------------------------------------------------------------
// NewShard
//
--------------------------------------------------------------------------- + +func TestNewShardValid(t *testing.T) { + s, err := NewShard(context.Background(), threading.NewAdHocPool(), 1024, 0) + require.NoError(t, err) + require.NotNil(t, s) +} + +func TestNewShardZeroMaxSize(t *testing.T) { + _, err := NewShard(context.Background(), threading.NewAdHocPool(), 0, 0) + require.Error(t, err) +} + +// --------------------------------------------------------------------------- +// Get — cache miss flows +// --------------------------------------------------------------------------- + +func TestGetCacheMissFoundInDB(t *testing.T) { + store := map[string][]byte{"hello": []byte("world")} + s, read := newTestShard(t, 4096, store) + + val, found, err := s.Get(read, []byte("hello"), true) + require.NoError(t, err) + require.True(t, found) + require.Equal(t, "world", string(val)) +} + +func TestGetCacheMissNotFoundInDB(t *testing.T) { + s, read := newTestShard(t, 4096, map[string][]byte{}) + + val, found, err := s.Get(read, []byte("missing"), true) + require.NoError(t, err) + require.False(t, found) + require.Nil(t, val) +} + +func TestGetCacheMissDBError(t *testing.T) { + dbErr := errors.New("disk on fire") + readFunc := Reader(func(key []byte) ([]byte, bool, error) { return nil, false, dbErr }) + s, _ := NewShard(context.Background(), threading.NewAdHocPool(), 4096, 0) + + _, _, err := s.Get(readFunc, []byte("boom"), true) + require.Error(t, err) + require.ErrorIs(t, err, dbErr) +} + +func TestGetDBErrorDoesNotCacheResult(t *testing.T) { + var calls atomic.Int64 + readFunc := Reader(func(key []byte) ([]byte, bool, error) { + n := calls.Add(1) + if n == 1 { + return nil, false, errors.New("transient") + } + return []byte("recovered"), true, nil + }) + s, _ := NewShard(context.Background(), threading.NewAdHocPool(), 4096, 0) + + _, _, err := s.Get(readFunc, []byte("key"), true) + require.Error(t, err, "first call should fail") + + val, found, err := s.Get(readFunc, 
[]byte("key"), true) + require.NoError(t, err, "second call should succeed") + require.True(t, found) + require.Equal(t, "recovered", string(val)) + require.Equal(t, int64(2), calls.Load(), "error should not be cached") +} + +// --------------------------------------------------------------------------- +// Get — cache hit flows +// --------------------------------------------------------------------------- + +func TestGetCacheHitAvailable(t *testing.T) { + s, read := newTestShard(t, 4096, map[string][]byte{"k": []byte("v")}) + + s.Get(read, []byte("k"), true) + + val, found, err := s.Get(read, []byte("k"), true) + require.NoError(t, err) + require.True(t, found) + require.Equal(t, "v", string(val)) +} + +func TestGetCacheHitDeleted(t *testing.T) { + s, read := newTestShard(t, 4096, map[string][]byte{}) + + s.Get(read, []byte("gone"), true) + + val, found, err := s.Get(read, []byte("gone"), true) + require.NoError(t, err) + require.False(t, found) + require.Nil(t, val) +} + +func TestGetAfterSet(t *testing.T) { + var readCalls atomic.Int64 + readFunc := Reader(func(key []byte) ([]byte, bool, error) { + readCalls.Add(1) + return nil, false, nil + }) + s, _ := NewShard(context.Background(), threading.NewAdHocPool(), 4096, 0) + + s.Set([]byte("k"), []byte("from-set")) + + val, found, err := s.Get(readFunc, []byte("k"), true) + require.NoError(t, err) + require.True(t, found) + require.Equal(t, "from-set", string(val)) + require.Equal(t, int64(0), readCalls.Load(), "readFunc should not be called for Set-populated entry") +} + +func TestGetAfterDelete(t *testing.T) { + store := map[string][]byte{"k": []byte("v")} + s, read := newTestShard(t, 4096, store) + + // Warm the cache so the key is present before deleting. 
+ _, _, err := s.Get(read, []byte("k"), true) + require.NoError(t, err) + + s.Delete([]byte("k")) + + val, found, err := s.Get(read, []byte("k"), true) + require.NoError(t, err) + require.False(t, found) + require.Nil(t, val) +} + +// --------------------------------------------------------------------------- +// Get — concurrent reads on the same key +// --------------------------------------------------------------------------- + +func TestGetConcurrentSameKey(t *testing.T) { + var readCalls atomic.Int64 + gate := make(chan struct{}) + + readFunc := Reader(func(key []byte) ([]byte, bool, error) { + readCalls.Add(1) + <-gate + return []byte("value"), true, nil + }) + s, _ := NewShard(context.Background(), threading.NewAdHocPool(), 4096, 0) + + const n = 10 + var wg sync.WaitGroup + errs := make([]error, n) + vals := make([]string, n) + founds := make([]bool, n) + + for i := 0; i < n; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + v, f, e := s.Get(readFunc, []byte("shared"), true) + vals[idx] = string(v) + founds[idx] = f + errs[idx] = e + }(i) + } + + time.Sleep(50 * time.Millisecond) + close(gate) + wg.Wait() + + for i := 0; i < n; i++ { + require.NoError(t, errs[i], "goroutine %d", i) + require.True(t, founds[i], "goroutine %d", i) + require.Equal(t, "value", vals[i], "goroutine %d", i) + } + + require.Equal(t, int64(1), readCalls.Load(), "readFunc should be called exactly once") +} + +// --------------------------------------------------------------------------- +// Get — context cancellation +// --------------------------------------------------------------------------- + +func TestGetContextCancelled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + readFunc := Reader(func(key []byte) ([]byte, bool, error) { + time.Sleep(time.Second) + return []byte("late"), true, nil + }) + s, _ := NewShard(ctx, threading.NewAdHocPool(), 4096, 0) + + cancel() + + _, _, err := s.Get(readFunc, []byte("k"), true) + require.Error(t, err) 
+} + +// --------------------------------------------------------------------------- +// Get — updateLru flag +// --------------------------------------------------------------------------- + +func TestGetUpdateLruTrue(t *testing.T) { + store := map[string][]byte{ + "a": []byte("1"), + "b": []byte("2"), + } + s, read := newTestShard(t, 4096, store) + + s.Get(read, []byte("a"), true) + s.Get(read, []byte("b"), true) + + s.Get(read, []byte("a"), true) + + s.lock.Lock() + lru := s.gcQueue.PopLeastRecentlyUsed() + s.lock.Unlock() + + require.Equal(t, "b", lru) +} + +func TestGetUpdateLruFalse(t *testing.T) { + store := map[string][]byte{ + "a": []byte("1"), + "b": []byte("2"), + } + s, read := newTestShard(t, 4096, store) + + s.Get(read, []byte("a"), true) + s.Get(read, []byte("b"), true) + + s.Get(read, []byte("a"), false) + + s.lock.Lock() + lru := s.gcQueue.PopLeastRecentlyUsed() + s.lock.Unlock() + + require.Equal(t, "a", lru, "updateLru=false should not move entry") +} + +// --------------------------------------------------------------------------- +// Set +// --------------------------------------------------------------------------- + +func TestSetNewKey(t *testing.T) { + s, read := newTestShard(t, 4096, map[string][]byte{}) + + s.Set([]byte("k"), []byte("v")) + + val, found, err := s.Get(read, []byte("k"), false) + require.NoError(t, err) + require.True(t, found) + require.Equal(t, "v", string(val)) +} + +func TestSetOverwritesExistingKey(t *testing.T) { + s, read := newTestShard(t, 4096, map[string][]byte{}) + + s.Set([]byte("k"), []byte("old")) + s.Set([]byte("k"), []byte("new")) + + val, found, err := s.Get(read, []byte("k"), false) + require.NoError(t, err) + require.True(t, found) + require.Equal(t, "new", string(val)) +} + +func TestSetOverwritesDeletedKey(t *testing.T) { + s, read := newTestShard(t, 4096, map[string][]byte{}) + + s.Delete([]byte("k")) + s.Set([]byte("k"), []byte("revived")) + + val, found, err := s.Get(read, []byte("k"), false) + 
require.NoError(t, err)
	require.True(t, found)
	require.Equal(t, "revived", string(val))
}

func TestSetNilValue(t *testing.T) {
	s, read := newTestShard(t, 4096, map[string][]byte{})

	// A nil value via Set is stored as a present-but-nil entry.
	s.Set([]byte("k"), nil)

	got, found, err := s.Get(read, []byte("k"), false)
	require.NoError(t, err)
	require.True(t, found)
	require.Nil(t, got)
}

func TestSetEmptyKey(t *testing.T) {
	s, read := newTestShard(t, 4096, map[string][]byte{})

	s.Set([]byte(""), []byte("empty-key-val"))

	got, found, err := s.Get(read, []byte(""), false)
	require.NoError(t, err)
	require.True(t, found)
	require.Equal(t, "empty-key-val", string(got))
}

// ---------------------------------------------------------------------------
// Delete
// ---------------------------------------------------------------------------

func TestDeleteExistingKey(t *testing.T) {
	s, read := newTestShard(t, 4096, map[string][]byte{})

	s.Set([]byte("k"), []byte("v"))
	s.Delete([]byte("k"))

	got, found, err := s.Get(read, []byte("k"), false)
	require.NoError(t, err)
	require.False(t, found)
	require.Nil(t, got)
}

func TestDeleteNonexistentKey(t *testing.T) {
	s, read := newTestShard(t, 4096, map[string][]byte{})

	// Deleting an unknown key is a no-op.
	s.Delete([]byte("ghost"))

	got, found, err := s.Get(read, []byte("ghost"), false)
	require.NoError(t, err)
	require.False(t, found)
	require.Nil(t, got)
}

func TestDeleteThenSetThenGet(t *testing.T) {
	s, read := newTestShard(t, 4096, map[string][]byte{})

	s.Set([]byte("k"), []byte("v1"))
	s.Delete([]byte("k"))
	s.Set([]byte("k"), []byte("v2"))

	got, found, err := s.Get(read, []byte("k"), false)
	require.NoError(t, err)
	require.True(t, found)
	require.Equal(t, "v2", string(got))
}

// ---------------------------------------------------------------------------
// BatchSet
// ---------------------------------------------------------------------------

func TestBatchSetSetsMultiple(t *testing.T) {
	s, read := newTestShard(t, 4096, map[string][]byte{})

	s.BatchSet([]CacheUpdate{
		{Key: []byte("a"), Value: []byte("1")},
		{Key: []byte("b"), Value: []byte("2")},
		{Key: []byte("c"), Value: []byte("3")},
	})

	for _, tc := range []struct {
		key, want string
	}{{"a", "1"}, {"b", "2"}, {"c", "3"}} {
		got, found, err := s.Get(read, []byte(tc.key), false)
		require.NoError(t, err, "Get(%q)", tc.key)
		require.True(t, found, "Get(%q)", tc.key)
		require.Equal(t, tc.want, string(got), "Get(%q)", tc.key)
	}
}

func TestBatchSetMixedSetAndDelete(t *testing.T) {
	s, read := newTestShard(t, 4096, map[string][]byte{})

	s.Set([]byte("keep"), []byte("v"))
	s.Set([]byte("remove"), []byte("v"))

	// One batch: update, delete (nil value), and insert.
	s.BatchSet([]CacheUpdate{
		{Key: []byte("keep"), Value: []byte("updated")},
		{Key: []byte("remove"), Value: nil},
		{Key: []byte("new"), Value: []byte("fresh")},
	})

	got, found, _ := s.Get(read, []byte("keep"), false)
	require.True(t, found)
	require.Equal(t, "updated", string(got))

	_, found, _ = s.Get(read, []byte("remove"), false)
	require.False(t, found, "expected remove to be deleted")

	got, found, _ = s.Get(read, []byte("new"), false)
	require.True(t, found)
	require.Equal(t, "fresh", string(got))
}

func TestBatchSetEmpty(t *testing.T) {
	s, _ := newTestShard(t, 4096, map[string][]byte{})

	// Both a nil and an empty slice must leave the shard untouched.
	s.BatchSet(nil)
	s.BatchSet([]CacheUpdate{})

	bytes, entries := s.getSizeInfo()
	require.Equal(t, uint64(0), bytes)
	require.Equal(t, uint64(0), entries)
}

// ---------------------------------------------------------------------------
// BatchGet
// ---------------------------------------------------------------------------

func TestBatchGetAllCached(t *testing.T) {
	s, read := newTestShard(t, 4096, map[string][]byte{})

	s.Set([]byte("a"), []byte("1"))
	s.Set([]byte("b"), []byte("2"))

	results := map[string]types.BatchGetResult{
		"a": {},
		"b": {},
	}
	require.NoError(t, s.BatchGet(read, results))

	for k, want := range map[string]string{"a": "1", "b": "2"} {
		r := results[k]
		require.True(t, r.IsFound(), "key=%q", k)
		require.Equal(t, want, string(r.Value), "key=%q", k)
	}
}

func TestBatchGetAllFromDB(t *testing.T) {
	backing := map[string][]byte{"x": []byte("10"), "y": []byte("20")}
	s, read := newTestShard(t, 4096, backing)

	results := map[string]types.BatchGetResult{
		"x": {},
		"y": {},
	}
	require.NoError(t, s.BatchGet(read, results))

	for k, want := range map[string]string{"x": "10", "y": "20"} {
		r := results[k]
		require.True(t, r.IsFound(), "key=%q", k)
		require.Equal(t, want, string(r.Value), "key=%q", k)
	}
}

func TestBatchGetMixedCachedAndDB(t *testing.T) {
	backing := map[string][]byte{"db-key": []byte("from-db")}
	s, read := newTestShard(t, 4096, backing)

	s.Set([]byte("cached"), []byte("from-cache"))

	results := map[string]types.BatchGetResult{
		"cached": {},
		"db-key": {},
	}
	require.NoError(t, s.BatchGet(read, results))

	require.True(t, results["cached"].IsFound())
	require.Equal(t, "from-cache", string(results["cached"].Value))
	require.True(t, results["db-key"].IsFound())
	require.Equal(t, "from-db", string(results["db-key"].Value))
}

func TestBatchGetNotFoundKeys(t *testing.T) {
	s, read := newTestShard(t, 4096, map[string][]byte{})

	results := map[string]types.BatchGetResult{
		"nope": {},
	}
	require.NoError(t, s.BatchGet(read, results))
	require.False(t, results["nope"].IsFound())
}

func TestBatchGetDeletedKeys(t *testing.T) {
	s, read := newTestShard(t, 4096, map[string][]byte{})

	s.Set([]byte("del"), []byte("v"))
	s.Delete([]byte("del"))

	results := map[string]types.BatchGetResult{
		"del": {},
	}
	require.NoError(t, s.BatchGet(read, results))
	require.False(t, results["del"].IsFound())
}

func TestBatchGetDBError(t *testing.T) {
	dbErr := errors.New("broken")
	failingRead := Reader(func([]byte) ([]byte, bool, error) { return nil, false, dbErr })
	s, _ := NewShard(context.Background(), threading.NewAdHocPool(), 4096, 0)

	// Backing-store failures surface per-key, not as a BatchGet error.
	results := map[string]types.BatchGetResult{
		"fail": {},
	}
	require.NoError(t, s.BatchGet(failingRead, results), "BatchGet itself should not fail")
	require.Error(t, results["fail"].Error, "expected per-key error")
}

func TestBatchGetEmpty(t *testing.T) {
	s, read := newTestShard(t, 4096, map[string][]byte{})

	require.NoError(t, s.BatchGet(read, map[string]types.BatchGetResult{}))
}

func TestBatchGetCachesResults(t *testing.T) {
	var readCalls atomic.Int64
	backing := map[string][]byte{"k": []byte("v")}
	countingRead := Reader(func(key []byte) ([]byte, bool, error) {
		readCalls.Add(1)
		v, ok := backing[string(key)]
		return v, ok, nil
	})
	s, _ := NewShard(context.Background(), threading.NewAdHocPool(), 4096, 0)

	s.BatchGet(countingRead, map[string]types.BatchGetResult{"k": {}})

	// Give the asynchronous bulk cache-update a moment to land.
	time.Sleep(50 * time.Millisecond)

	got, found, err := s.Get(countingRead, []byte("k"), false)
	require.NoError(t, err)
	require.True(t, found)
	require.Equal(t, "v", string(got))
	require.Equal(t, int64(1), readCalls.Load(), "result should be cached")
}

// ---------------------------------------------------------------------------
// Eviction
// ---------------------------------------------------------------------------

func TestEvictionRespectMaxSize(t *testing.T) {
	s, _ := newTestShard(t, 30, map[string][]byte{})

	s.Set([]byte("a"), []byte("aaaaaaaaaa"))
	s.Set([]byte("b"), []byte("bbbbbbbbbb"))

	_, entries := s.getSizeInfo()
	require.Equal(t, uint64(2), entries)

	// A third 11-byte entry overflows the 30-byte budget and evicts one.
	s.Set([]byte("c"), []byte("cccccccccc"))

	bytes, entries := s.getSizeInfo()
	require.LessOrEqual(t, bytes, uint64(30), "shard size should not exceed maxSize")
	require.Equal(t, uint64(2), entries)
}

func TestEvictionOrderIsLRU(t *testing.T) {
	s, read := newTestShard(t, 15, map[string][]byte{})

	s.Set([]byte("a"), []byte("1111"))
	s.Set([]byte("b"), []byte("2222"))
	s.Set([]byte("c"), []byte("3333"))

	s.Get(read, []byte("a"), true)

	s.Set([]byte("d"), []byte("4444"))

	s.lock.Lock()
	_, 
bExists := s.data["b"] + _, aExists := s.data["a"] + s.lock.Unlock() + + require.False(t, bExists, "expected 'b' to be evicted (it was LRU)") + require.True(t, aExists, "expected 'a' to survive (it was recently touched)") +} + +func TestEvictionOnDelete(t *testing.T) { + s, _ := newTestShard(t, 10, map[string][]byte{}) + + s.Set([]byte("a"), []byte("val")) + s.Delete([]byte("longkey1")) + + bytes, _ := s.getSizeInfo() + require.LessOrEqual(t, bytes, uint64(10), "size should not exceed maxSize") +} + +func TestEvictionOnGetFromDB(t *testing.T) { + store := map[string][]byte{ + "x": []byte("12345678901234567890"), + } + s, read := newTestShard(t, 25, store) + + s.Set([]byte("a"), []byte("small")) + + s.Get(read, []byte("x"), true) + + time.Sleep(50 * time.Millisecond) + + bytes, _ := s.getSizeInfo() + require.LessOrEqual(t, bytes, uint64(25), "size should not exceed maxSize after DB read") +} + +// --------------------------------------------------------------------------- +// getSizeInfo +// --------------------------------------------------------------------------- + +func TestGetSizeInfoEmpty(t *testing.T) { + s, _ := newTestShard(t, 4096, map[string][]byte{}) + bytes, entries := s.getSizeInfo() + require.Equal(t, uint64(0), bytes) + require.Equal(t, uint64(0), entries) +} + +func TestGetSizeInfoAfterSets(t *testing.T) { + s, _ := newTestShard(t, 4096, map[string][]byte{}) + + s.Set([]byte("ab"), []byte("cd")) + s.Set([]byte("efg"), []byte("hi")) + + bytes, entries := s.getSizeInfo() + require.Equal(t, uint64(2), entries) + require.Equal(t, uint64(9), bytes) +} + +// --------------------------------------------------------------------------- +// estimatedOverheadPerEntry +// --------------------------------------------------------------------------- + +func TestOverheadIncludedInSizeAfterSet(t *testing.T) { + const overhead = 100 + s, _ := NewShard(context.Background(), threading.NewAdHocPool(), 100_000, overhead) + + s.Set([]byte("ab"), []byte("cd")) + 
s.Set([]byte("efg"), []byte("hi")) + + bytes, entries := s.getSizeInfo() + require.Equal(t, uint64(2), entries) + // (2+2+100) + (3+2+100) = 209 + require.Equal(t, uint64(209), bytes) +} + +func TestOverheadIncludedInSizeAfterDelete(t *testing.T) { + const overhead = 100 + store := map[string][]byte{"abc": []byte("val")} + read := Reader(func(key []byte) ([]byte, bool, error) { + v, ok := store[string(key)] + return v, ok, nil + }) + s, _ := NewShard(context.Background(), threading.NewAdHocPool(), 100_000, overhead) + + // Warm the cache so the key is present before deleting. + _, _, err := s.Get(read, []byte("abc"), true) + require.NoError(t, err) + + s.Delete([]byte("abc")) + + bytes, entries := s.getSizeInfo() + require.Equal(t, uint64(1), entries) + // 3 + 100 = 103 + require.Equal(t, uint64(103), bytes) +} + +func TestOverheadIncludedInSizeAfterDBRead(t *testing.T) { + const overhead = 100 + store := map[string][]byte{"key": []byte("value")} + read := Reader(func(key []byte) ([]byte, bool, error) { + v, ok := store[string(key)] + return v, ok, nil + }) + s, _ := NewShard(context.Background(), threading.NewAdHocPool(), 100_000, overhead) + + val, found, err := s.Get(read, []byte("key"), true) + require.NoError(t, err) + require.True(t, found) + require.Equal(t, "value", string(val)) + + bytes, entries := s.getSizeInfo() + require.Equal(t, uint64(1), entries) + // 3 + 5 + 100 = 108 + require.Equal(t, uint64(108), bytes) +} + +func TestOverheadIncludedInSizeAfterDBReadNotFound(t *testing.T) { + const overhead = 100 + read := Reader(func(key []byte) ([]byte, bool, error) { return nil, false, nil }) + s, _ := NewShard(context.Background(), threading.NewAdHocPool(), 100_000, overhead) + + _, found, err := s.Get(read, []byte("key"), true) + require.NoError(t, err) + require.False(t, found) + + bytes, entries := s.getSizeInfo() + require.Equal(t, uint64(1), entries) + // 3 + 100 = 103 + require.Equal(t, uint64(103), bytes) +} + +func 
TestOverheadTriggersEarlierEviction(t *testing.T) { + const overhead = 50 + s, _ := NewShard(context.Background(), threading.NewAdHocPool(), 100, overhead) + + // "a" + "1234" + 50 = 55 bytes + s.Set([]byte("a"), []byte("1234")) + _, entries := s.getSizeInfo() + require.Equal(t, uint64(1), entries) + + // "b" + "5678" + 50 = 55 bytes, total = 110 > 100 → evict "a" + s.Set([]byte("b"), []byte("5678")) + bytes, entries := s.getSizeInfo() + require.Equal(t, uint64(1), entries, "overhead should cause eviction to keep only one entry") + require.LessOrEqual(t, bytes, uint64(100)) +} + +func TestOverheadIncludedInBatchGetFromDB(t *testing.T) { + const overhead = 100 + store := map[string][]byte{"x": []byte("10"), "y": []byte("20")} + read := Reader(func(key []byte) ([]byte, bool, error) { + v, ok := store[string(key)] + return v, ok, nil + }) + s, _ := NewShard(context.Background(), threading.NewAdHocPool(), 100_000, overhead) + + keys := map[string]types.BatchGetResult{"x": {}, "y": {}} + require.NoError(t, s.BatchGet(read, keys)) + + time.Sleep(50 * time.Millisecond) + + bytes, entries := s.getSizeInfo() + require.Equal(t, uint64(2), entries) + // (1+2+100) + (1+2+100) = 206 + require.Equal(t, uint64(206), bytes) +} + +func TestOverheadSizeUpdatedOnOverwrite(t *testing.T) { + const overhead = 100 + s, _ := NewShard(context.Background(), threading.NewAdHocPool(), 100_000, overhead) + + s.Set([]byte("k"), []byte("short")) + b1, _ := s.getSizeInfo() + // 1 + 5 + 100 = 106 + require.Equal(t, uint64(106), b1) + + s.Set([]byte("k"), []byte("a-longer-value")) + b2, entries := s.getSizeInfo() + require.Equal(t, uint64(1), entries) + // 1 + 14 + 100 = 115 + require.Equal(t, uint64(115), b2) +} + +// --------------------------------------------------------------------------- +// injectValue — edge cases +// --------------------------------------------------------------------------- + +func TestInjectValueNotFound(t *testing.T) { + s, read := newTestShard(t, 4096, 
map[string][]byte{}) + + val, found, err := s.Get(read, []byte("missing"), true) + require.NoError(t, err) + require.False(t, found) + require.Nil(t, val) + + s.lock.Lock() + entry, ok := s.data["missing"] + s.lock.Unlock() + require.True(t, ok, "entry should exist in map") + require.Equal(t, statusDeleted, entry.status) +} + +// --------------------------------------------------------------------------- +// Concurrent Set and Get +// --------------------------------------------------------------------------- + +func TestConcurrentSetAndGet(t *testing.T) { + s, read := newTestShard(t, 4096, map[string][]byte{}) + + const n = 100 + var wg sync.WaitGroup + + for i := 0; i < n; i++ { + wg.Add(2) + key := []byte(fmt.Sprintf("key-%d", i)) + val := []byte(fmt.Sprintf("val-%d", i)) + + go func() { + defer wg.Done() + s.Set(key, val) + }() + go func() { + defer wg.Done() + s.Get(read, key, true) + }() + } + + wg.Wait() +} + +func TestConcurrentBatchSetAndBatchGet(t *testing.T) { + store := map[string][]byte{} + for i := 0; i < 50; i++ { + store[fmt.Sprintf("db-%d", i)] = []byte(fmt.Sprintf("v-%d", i)) + } + s, read := newTestShard(t, 100_000, store) + + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + updates := make([]CacheUpdate, 20) + for i := 0; i < 20; i++ { + updates[i] = CacheUpdate{ + Key: []byte(fmt.Sprintf("set-%d", i)), + Value: []byte(fmt.Sprintf("sv-%d", i)), + } + } + s.BatchSet(updates) + }() + + wg.Add(1) + go func() { + defer wg.Done() + keys := make(map[string]types.BatchGetResult) + for i := 0; i < 50; i++ { + keys[fmt.Sprintf("db-%d", i)] = types.BatchGetResult{} + } + s.BatchGet(read, keys) + }() + + wg.Wait() +} + +// --------------------------------------------------------------------------- +// Pool submission failure +// --------------------------------------------------------------------------- + +type failPool struct{} + +func (fp *failPool) Submit(_ context.Context, _ func()) error { + return errors.New("pool exhausted") +} 
+ +func TestGetPoolSubmitFailure(t *testing.T) { + readFunc := Reader(func(key []byte) ([]byte, bool, error) { return []byte("v"), true, nil }) + s, _ := NewShard(context.Background(), &failPool{}, 4096, 0) + + _, _, err := s.Get(readFunc, []byte("k"), true) + require.Error(t, err) +} + +func TestBatchGetPoolSubmitFailure(t *testing.T) { + readFunc := Reader(func(key []byte) ([]byte, bool, error) { return []byte("v"), true, nil }) + s, _ := NewShard(context.Background(), &failPool{}, 4096, 0) + + keys := map[string]types.BatchGetResult{"k": {}} + err := s.BatchGet(readFunc, keys) + require.Error(t, err) +} + +// --------------------------------------------------------------------------- +// Large values +// --------------------------------------------------------------------------- + +func TestSetLargeValueExceedingMaxSizeEvictsOldEntries(t *testing.T) { + s, _ := newTestShard(t, 100, map[string][]byte{}) + + s.Set([]byte("a"), []byte("small")) + + bigVal := make([]byte, 95) + for i := range bigVal { + bigVal[i] = 'X' + } + s.Set([]byte("b"), bigVal) + + bytes, _ := s.getSizeInfo() + require.LessOrEqual(t, bytes, uint64(100), "size should not exceed maxSize after large set") +} + +// --------------------------------------------------------------------------- +// bulkInjectValues — error entries are not cached +// --------------------------------------------------------------------------- + +func TestBatchGetDBErrorNotCached(t *testing.T) { + var calls atomic.Int64 + readFunc := Reader(func(key []byte) ([]byte, bool, error) { + n := calls.Add(1) + if n == 1 { + return nil, false, errors.New("transient db error") + } + return []byte("ok"), true, nil + }) + s, _ := NewShard(context.Background(), threading.NewAdHocPool(), 4096, 0) + + keys := map[string]types.BatchGetResult{"k": {}} + s.BatchGet(readFunc, keys) + + time.Sleep(50 * time.Millisecond) + + val, found, err := s.Get(readFunc, []byte("k"), true) + require.NoError(t, err, "retry should succeed") + 
require.True(t, found) + require.Equal(t, "ok", string(val)) +} + +// --------------------------------------------------------------------------- +// Edge: Set then Delete then BatchGet +// --------------------------------------------------------------------------- + +func TestSetDeleteThenBatchGet(t *testing.T) { + s, read := newTestShard(t, 4096, map[string][]byte{}) + + s.Set([]byte("k"), []byte("v")) + s.Delete([]byte("k")) + + keys := map[string]types.BatchGetResult{"k": {}} + require.NoError(t, s.BatchGet(read, keys)) + require.False(t, keys["k"].IsFound()) +}