diff --git a/.gitignore b/.gitignore index aca8bb31..55892921 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,6 @@ go.work.sum .DS_Store .idea/ + +# tmp folder +/tmp \ No newline at end of file diff --git a/fuzz_network_test.go b/fuzz_network_test.go new file mode 100644 index 00000000..3e14d1cb --- /dev/null +++ b/fuzz_network_test.go @@ -0,0 +1,20 @@ +package simplex_test + +import ( + "testing" + "time" + + "github.com/ava-labs/simplex/testutil/random_network" +) + +func TestNetworkSimpleFuzz(t *testing.T) { + for i := 0; i < 10; i++ { + t.Run("", func(t *testing.T) { + config := random_network.DefaultFuzzConfig() + config.RandomSeed = time.Now().UnixMilli() + network := random_network.NewNetwork(config, t) + network.Run() + network.PrintStatus() + }) + } +} diff --git a/testutil/logger.go b/testutil/logger.go index 7e945083..30d586bb 100644 --- a/testutil/logger.go +++ b/testutil/logger.go @@ -5,6 +5,7 @@ package testutil import ( "fmt" + "io" "os" "strings" "testing" @@ -14,6 +15,11 @@ import ( "go.uber.org/zap/zapcore" ) +const ( + LOG_LEVEL = "log_level" + INFO_LOG_LEVEL = "info" +) + type TestLogger struct { *zap.Logger t *testing.T @@ -94,13 +100,16 @@ func (tl *TestLogger) Error(msg string, fields ...zap.Field) { } func MakeLogger(t *testing.T, node ...int) *TestLogger { - return MakeLoggerWithFile(t, nil, node...) + // Preserve existing behavior: logs to stdout by default. + return MakeLoggerWithFile(t, nil, true, node...) } -// MakeLoggerWithFile creates a TestLogger that optionally writes to a file in addition to stdout. -// If fileWriter is nil, logs only to stdout (same as MakeLogger). -// If fileWriter is provided, logs to both stdout and the file. -func MakeLoggerWithFile(t *testing.T, fileWriter zapcore.WriteSyncer, node ...int) *TestLogger { +// MakeLoggerWithFile creates a TestLogger that can write to a file and optionally to stdout. +// - If writeStdout is true, logs may be written to stdout. 
+// - If fileWriter is non-nil, logs may be written to that fileWriter. +// - If both are enabled, logs go to both. +// - If neither is enabled, logs are discarded. +func MakeLoggerWithFile(t *testing.T, fileWriter zapcore.WriteSyncer, writeStdout bool, node ...int) *TestLogger { defaultEncoderConfig := zapcore.EncoderConfig{ TimeKey: "timestamp", LevelKey: "level", @@ -118,28 +127,48 @@ func MakeLoggerWithFile(t *testing.T, fileWriter zapcore.WriteSyncer, node ...in config.EncodeTime = zapcore.TimeEncoderOfLayout("[01-02|15:04:05.000]") config.ConsoleSeparator = " " - // Create stdout encoder - stdoutEncoder := zapcore.NewConsoleEncoder(config) - if strings.ToLower(os.Getenv("LOG_LEVEL")) == "info" { - stdoutEncoder = &DebugSwallowingEncoder{consoleEncoder: stdoutEncoder, ObjectEncoder: stdoutEncoder, pool: buffer.NewPool()} - } - atomicLevel := zap.NewAtomicLevelAt(zapcore.DebugLevel) - // Create stdout core - stdoutCore := zapcore.NewCore(stdoutEncoder, zapcore.AddSync(os.Stdout), atomicLevel) + var cores []zapcore.Core - // If file writer is provided, create a tee core with both stdout and file - var core zapcore.Core + // Stdout core only if explicitly enabled + if writeStdout { + stdoutEncoder := zapcore.NewConsoleEncoder(config) + if strings.ToLower(os.Getenv("LOG_LEVEL")) == "info" { + stdoutEncoder = &DebugSwallowingEncoder{ + consoleEncoder: stdoutEncoder, + ObjectEncoder: stdoutEncoder, + pool: buffer.NewPool(), + } + } + stdoutCore := zapcore.NewCore(stdoutEncoder, zapcore.AddSync(os.Stdout), atomicLevel) + cores = append(cores, stdoutCore) + } + + // File core only if provided if fileWriter != nil { fileEncoder := zapcore.NewConsoleEncoder(config) - if strings.ToLower(os.Getenv("LOG_LEVEL")) == "info" { - fileEncoder = &DebugSwallowingEncoder{consoleEncoder: fileEncoder, ObjectEncoder: fileEncoder, pool: buffer.NewPool()} + if strings.ToLower(os.Getenv(LOG_LEVEL)) == INFO_LOG_LEVEL { + fileEncoder = &DebugSwallowingEncoder{ + consoleEncoder: 
fileEncoder, + ObjectEncoder: fileEncoder, + pool: buffer.NewPool(), + } } fileCore := zapcore.NewCore(fileEncoder, fileWriter, atomicLevel) - core = zapcore.NewTee(stdoutCore, fileCore) - } else { - core = stdoutCore + cores = append(cores, fileCore) + } + + // If neither stdout nor file enabled, discard logs. + var core zapcore.Core + switch len(cores) { + case 0: + discardEncoder := zapcore.NewConsoleEncoder(config) + core = zapcore.NewCore(discardEncoder, zapcore.AddSync(io.Discard), atomicLevel) + case 1: + core = cores[0] + default: + core = zapcore.NewTee(cores...) } logger := zap.New(core, zap.AddCaller()) @@ -150,16 +179,13 @@ func MakeLoggerWithFile(t *testing.T, fileWriter zapcore.WriteSyncer, node ...in traceVerboseLogger := zap.New(core, zap.AddCaller(), zap.AddCallerSkip(1)) traceVerboseLogger = traceVerboseLogger.With(zap.String("test", t.Name())) - if len(node) > 0 { traceVerboseLogger = traceVerboseLogger.With(zap.Int("myNodeID", node[0])) } - l := &TestLogger{t: t, Logger: logger, traceVerboseLogger: traceVerboseLogger, + return &TestLogger{t: t, Logger: logger, traceVerboseLogger: traceVerboseLogger, atomicLevel: atomicLevel, } - - return l } type DebugSwallowingEncoder struct { diff --git a/testutil/random_network/block.go b/testutil/random_network/block.go new file mode 100644 index 00000000..59e3f3f9 --- /dev/null +++ b/testutil/random_network/block.go @@ -0,0 +1,140 @@ +// Copyright (C) 2019-2025, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
+ +package random_network + +import ( + "context" + "crypto/sha256" + "encoding/asn1" + "fmt" + + "github.com/ava-labs/simplex" +) + +var _ simplex.Block = (*Block)(nil) + +type Block struct { + blacklist simplex.Blacklist + + // contents + txs []*TX + + // protocol metadata + metadata simplex.ProtocolMetadata + digest simplex.Digest + + // mempool access + mempool *Mempool +} + +func NewBlock(metadata simplex.ProtocolMetadata, blacklist simplex.Blacklist, mempool *Mempool, txs []*TX) *Block { + b := &Block{ + mempool: mempool, + txs: txs, + metadata: metadata, + blacklist: blacklist, + } + + b.ComputeAndSetDigest() + return b +} + +func (b *Block) Verify(ctx context.Context) (simplex.VerifiedBlock, error) { + return b, b.mempool.VerifyBlock(ctx, b) +} + +func (b *Block) Blacklist() simplex.Blacklist { + return b.blacklist +} + +func (b *Block) BlockHeader() simplex.BlockHeader { + return simplex.BlockHeader{ + ProtocolMetadata: b.metadata, + Digest: b.digest, + } +} + +type encodedBlock struct { + ProtocolMetadata []byte + TXs []asn1TX + Blacklist []byte +} + +func (b *Block) Bytes() ([]byte, error) { + mdBytes := b.metadata.Bytes() + + var asn1TXs []asn1TX + for _, tx := range b.txs { + asn1TXs = append(asn1TXs, asn1TX{ID: tx.ID[:], ShouldFailVerification: tx.shouldFailVerification}) + } + + blacklistBytes := b.blacklist.Bytes() + + encodedB := encodedBlock{ + ProtocolMetadata: mdBytes, + TXs: asn1TXs, + Blacklist: blacklistBytes, + } + + return asn1.Marshal(encodedB) +} + +func (b *Block) containsTX(txID txID) bool { + for _, tx := range b.txs { + if tx.ID == txID { + return true + } + } + return false +} + +func (b *Block) ComputeAndSetDigest() { + tbBytes, err := b.Bytes() + if err != nil { + panic(fmt.Sprintf("failed to serialize test block: %v", err)) + } + + b.digest = sha256.Sum256(tbBytes) +} + +type BlockDeserializer struct { + mempool *Mempool +} + +var _ simplex.BlockDeserializer = (*BlockDeserializer)(nil) + +func (bd *BlockDeserializer) 
DeserializeBlock(ctx context.Context, buff []byte) (simplex.Block, error) { + var encodedBlock encodedBlock + _, err := asn1.Unmarshal(buff, &encodedBlock) + if err != nil { + return nil, err + } + + md, err := simplex.ProtocolMetadataFromBytes(encodedBlock.ProtocolMetadata) + if err != nil { + return nil, err + } + + var blacklist simplex.Blacklist + if err := blacklist.FromBytes(encodedBlock.Blacklist); err != nil { + return nil, err + } + + txs := make([]*TX, len(encodedBlock.TXs)) + for i, asn1Tx := range encodedBlock.TXs { + tx := asn1Tx.toTX() + txs[i] = tx + } + + b := &Block{ + metadata: *md, + txs: txs, + blacklist: blacklist, + mempool: bd.mempool, + } + + b.ComputeAndSetDigest() + + return b, nil +} diff --git a/testutil/random_network/config.go b/testutil/random_network/config.go new file mode 100644 index 00000000..d60512db --- /dev/null +++ b/testutil/random_network/config.go @@ -0,0 +1,59 @@ +// Copyright (C) 2019-2025, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package random_network + +import ( + "time" + + "github.com/ava-labs/simplex" +) + +type FuzzConfig struct { + // The minimum and maximum number of nodes in the network. + MinNodes int // Default is 3. + MaxNodes int // Default is 10. + + // The minimum and maximum number of transactions to be issued at a block. Default is between 5 and 20. + MinTxsPerIssue int + MaxTxsPerIssue int + + // Number of transactions per block. Default is 15. + TxsPerBlock int + + // The number of blocks that must be finalized before ending the fuzz test. Default is 100. + NumFinalizedBlocks int + + RandomSeed int64 + + // Probability that a node will be randomly crashed. Default is .1 (10%). + NodeCrashProbability float64 + + // Probability that a crashed node will be restarted. Default is .5 (50%). + NodeRecoverProbability float64 + + // Amount to advance the time by. Default is simplex.DefaultMaxProposalWaitTime / 5. 
+ AdvanceTimeTickAmount time.Duration + + // Creates main.log for network logs and {nodeID-short}.log for each node. + // NodeID is represented as a 16-character hex string (first 8 bytes). + // Default directory is "tmp". + // If empty, logging to files is disabled and logs will only be printed to console. + LogDirectory string +} + +func DefaultFuzzConfig() *FuzzConfig { + return &FuzzConfig{ + MinNodes: 3, + MaxNodes: 10, + MinTxsPerIssue: 5, + MaxTxsPerIssue: 20, + TxsPerBlock: 15, + NumFinalizedBlocks: 100, + RandomSeed: time.Now().UnixMilli(), + NodeCrashProbability: 0.1, + NodeRecoverProbability: 0.5, + AdvanceTimeTickAmount: simplex.DefaultMaxProposalWaitTime / 5, + LogDirectory: "tmp", + } +} diff --git a/testutil/random_network/logging.go b/testutil/random_network/logging.go new file mode 100644 index 00000000..92e27592 --- /dev/null +++ b/testutil/random_network/logging.go @@ -0,0 +1,83 @@ +// Copyright (C) 2019-2025, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
+ +package random_network + +import ( + "fmt" + "os" + "path/filepath" + "testing" + + "github.com/ava-labs/simplex" + "github.com/ava-labs/simplex/testutil" + "go.uber.org/zap/zapcore" +) + +// CreateNetworkLogger creates a logger for the network that writes to both console and main.log +func CreateNetworkLogger(t *testing.T, config *FuzzConfig) *testutil.TestLogger { + if config.LogDirectory == "" { + return testutil.MakeLogger(t, 0) + } + + // Clear the log directory before creating new logs + if err := clearLogDirectory(config.LogDirectory); err != nil { + t.Fatalf("Failed to clear log directory: %v", err) + } + + // Create file writer for main.log + fileWriter, err := setupFileOutput(t, config.LogDirectory, "main.log") + if err != nil { + t.Fatalf("Failed to setup file output for main.log: %v", err) + } + + return testutil.MakeLoggerWithFile(t, fileWriter, true) +} + +// CreateNodeLogger creates a logger for a node that writes to both console and {nodeID}.log +func CreateNodeLogger(t *testing.T, config *FuzzConfig, nodeID simplex.NodeID) *testutil.TestLogger { + if config.LogDirectory == "" { + return testutil.MakeLogger(t, int(nodeID[0])) + } + + filename := fmt.Sprintf("%s.log", nodeID.String()) + + // Create file writer for node-specific log + fileWriter, err := setupFileOutput(t, config.LogDirectory, filename) + if err != nil { + t.Fatalf("Failed to setup file output for %s: %v", filename, err) + } + + return testutil.MakeLoggerWithFile(t, fileWriter, false, int(nodeID[0])) +} + +// setupFileOutput creates a file for logging and returns a WriteSyncer. +// The file is closed automatically when the test ends. 
+func setupFileOutput(t *testing.T, logDir, filename string) (zapcore.WriteSyncer, error) { + // Create log directory if it doesn't exist + if err := os.MkdirAll(logDir, 0755); err != nil { + return nil, fmt.Errorf("failed to create log directory %s: %w", logDir, err) + } + + // Create full path + logPath := filepath.Join(logDir, filename) + + // Open file for appending (create if doesn't exist) + file, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return nil, fmt.Errorf("failed to open log file %s: %w", logPath, err) + } + + t.Cleanup(func() { file.Close() }) + + // Wrap with AddSync to make it safe for concurrent writes + return zapcore.AddSync(file), nil +} + +// clearLogDirectory removes the contents of the log directory +func clearLogDirectory(logDir string) error { + if err := os.RemoveAll(logDir); err != nil { + return fmt.Errorf("failed to remove log directory %s: %w", logDir, err) + } + return nil +} diff --git a/testutil/random_network/logging_test.go b/testutil/random_network/logging_test.go new file mode 100644 index 00000000..dc8e7f5e --- /dev/null +++ b/testutil/random_network/logging_test.go @@ -0,0 +1,102 @@ +// Copyright (C) 2019-2025, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
+ +package random_network + +import ( + "os" + "path/filepath" + "testing" + + "github.com/ava-labs/simplex/testutil" + "github.com/stretchr/testify/require" +) + +func TestFileLogging(t *testing.T) { + // Create temp directory for test logs + tempDir := t.TempDir() + + // Create config with file logging enabled + config := DefaultFuzzConfig() + config.LogDirectory = tempDir + + // Create network logger + networkLogger := CreateNetworkLogger(t, config) + networkLogger.Info("Test network log message") + + // Create node logger + nodeID := testutil.GenerateNodeID(t) + nodeLogger := CreateNodeLogger(t, config, nodeID) + nodeLogger.Info("Test node log message") + + // Verify main.log was created and contains the message + mainLogPath := filepath.Join(tempDir, "main.log") + require.FileExists(t, mainLogPath) + + mainLogContent, err := os.ReadFile(mainLogPath) + require.NoError(t, err) + require.Contains(t, string(mainLogContent), "Test network log message") + + // Verify node log was created with hex filename + nodeLogPattern := filepath.Join(tempDir, "*.log") + matches, err := filepath.Glob(nodeLogPattern) + require.NoError(t, err) + require.GreaterOrEqual(t, len(matches), 2) // Should have main.log and at least one node log + + // Find the node log file (not main.log) + var nodeLogPath string + for _, match := range matches { + if filepath.Base(match) != "main.log" { + nodeLogPath = match + break + } + } + require.NotEmpty(t, nodeLogPath) + + nodeLogContent, err := os.ReadFile(nodeLogPath) + require.NoError(t, err) + require.Contains(t, string(nodeLogContent), "Test node log message") +} + +func TestLogDirectoryClearing(t *testing.T) { + // Create temp directory for test logs + tempDir := t.TempDir() + + // Create an old log file + oldLogPath := filepath.Join(tempDir, "old.log") + err := os.WriteFile(oldLogPath, []byte("old log content"), 0644) + require.NoError(t, err) + require.FileExists(t, oldLogPath) + + // Create config with file logging enabled + config := 
DefaultFuzzConfig() + config.LogDirectory = tempDir + + // Create network logger - this should clear the directory + networkLogger := CreateNetworkLogger(t, config) + networkLogger.Info("New log message") + + // Verify old log was removed + require.NoFileExists(t, oldLogPath) + + // Verify new main.log was created + mainLogPath := filepath.Join(tempDir, "main.log") + require.FileExists(t, mainLogPath) +} + +func TestConsoleOnlyLogging(t *testing.T) { + // Create config with file logging disabled (empty directory) + config := DefaultFuzzConfig() + config.LogDirectory = "" + + // Create network logger - should not panic + networkLogger := CreateNetworkLogger(t, config) + networkLogger.Info("Console only network log") + + // Create node logger - should not panic + nodeID := testutil.GenerateNodeID(t) + nodeLogger := CreateNodeLogger(t, config, nodeID) + nodeLogger.Info("Console only node log") + + // No files should be created (test passes if no panic) +} diff --git a/testutil/random_network/mempool.go b/testutil/random_network/mempool.go new file mode 100644 index 00000000..47087efa --- /dev/null +++ b/testutil/random_network/mempool.go @@ -0,0 +1,332 @@ +// Copyright (C) 2019-2025, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
+ +package random_network + +import ( + "bytes" + "context" + "errors" + "fmt" + "slices" + "sync" + + "github.com/ava-labs/simplex" + "go.uber.org/zap" +) + +var emptyDigest = simplex.Digest{} + +var ( + errAlreadyAccepted = errors.New("tx already accepted") + errAlreadyInChain = errors.New("tx already in chain") + errDuplicateTxInBlock = errors.New("duplicate tx in block") + errDoubleBlockVerification = errors.New("block has already been verified") + errParentNotFound = errors.New("parent block not accepted or verified") +) + +type Mempool struct { + lock *sync.Mutex + config *FuzzConfig + txsReady chan struct{} + logger simplex.Logger + + // txID -> TX + unacceptedTxs map[txID]*TX + + // blocks that have been verified but not accepted + verifiedButNotAcceptedBlocks map[simplex.Digest]*Block + + // all the blocks that have been accepted + acceptedBlocks map[simplex.Digest]*Block + + // fast lookup of accepted txs + acceptedTXs map[txID]struct{} +} + +func NewMempool(l simplex.Logger, config *FuzzConfig) *Mempool { + return &Mempool{ + unacceptedTxs: make(map[txID]*TX), + verifiedButNotAcceptedBlocks: make(map[simplex.Digest]*Block), + acceptedTXs: make(map[txID]struct{}), + acceptedBlocks: make(map[simplex.Digest]*Block), + lock: &sync.Mutex{}, + txsReady: make(chan struct{}, 1), + logger: l, + config: config, + } +} + +func (m *Mempool) AddPendingTXs(txs ...*TX) { + m.lock.Lock() + defer m.lock.Unlock() + + for _, tx := range txs { + m.unacceptedTxs[tx.ID] = tx + } +} + +// NotifyTxsReady signals that there are pending transactions in the mempool. 
+func (m *Mempool) NotifyTxsReady() { + select { + case m.txsReady <- struct{}{}: + default: + } +} + +// waitForPendingTxs waits until there are pending transactions in the mempool or the context is canceled +func (m *Mempool) waitForPendingTxs(ctx context.Context) { + for { + // Check if txs are available + m.lock.Lock() + m.logger.Debug("Checking for pending txs in mempool", zap.Int("unacceptedTxs", len(m.unacceptedTxs))) + hasTxs := len(m.unacceptedTxs) > 0 + m.lock.Unlock() + + if hasTxs { + return + } + + m.logger.Debug("No pending txs in mempool, waiting for txs to be added") + + // No txs available, wait for notification or cancellation + select { + case <-m.txsReady: + m.logger.Debug("Received notification of pending txs in mempool") + case <-ctx.Done(): + return + } + } +} + +// VerifyBlock verifies the block and its transactions. Errors if any tx is invalid or if there are duplicate txs in the block. +func (m *Mempool) VerifyBlock(ctx context.Context, b *Block) error { + m.lock.Lock() + defer m.lock.Unlock() + + // Ensure the block has not already been verified or accepted + if _, exists := m.verifiedButNotAcceptedBlocks[b.digest]; exists { + m.logger.Error("Block has already been verified", zap.Error(errDoubleBlockVerification), zap.Stringer("Digest", b.digest)) + return fmt.Errorf("%w: %s", errDoubleBlockVerification, b.digest) + } + + if _, exists := m.acceptedBlocks[b.digest]; exists { + m.logger.Error("Block has already been accepted", zap.Error(errDoubleBlockVerification), zap.Stringer("Digest", b.digest)) + return fmt.Errorf("%w: %s", errDoubleBlockVerification, b.digest) + } + + // Ensure the parent block is accepted or verified + if parentInChain := m.isParentAcceptedOrVerified(b); !parentInChain { + m.logger.Error("Parent has not been accepted or verified", zap.Error(errParentNotFound), zap.Stringer("Digest", b.digest)) + return fmt.Errorf("%w: parent digest %s, block digest %s", errParentNotFound, b.metadata.Prev, b.digest) + } + + // Assert 
there are no duplicate txs in the block + txIDSet := make(map[txID]struct{}) + for _, tx := range b.txs { + if _, exists := txIDSet[tx.ID]; exists { + return errDuplicateTxInBlock + } + txIDSet[tx.ID] = struct{}{} + } + + // Verify each transaction + for _, tx := range b.txs { + if err := m.verifyTx(ctx, tx, b.metadata.Prev); err != nil { + return err + } + } + + // Update state - don't delete from unverifiedTXs yet, as multiple nodes may build blocks with the same txs + // txs will be deleted when the block is accepted + m.verifiedButNotAcceptedBlocks[b.digest] = b + + return nil +} + +func (m *Mempool) isParentAcceptedOrVerified(block *Block) bool { + // Genesis block case + if block.metadata.Prev == emptyDigest { + return true + } + + _, exists := m.acceptedBlocks[block.metadata.Prev] + if exists { + return true + } + + _, exists = m.verifiedButNotAcceptedBlocks[block.metadata.Prev] + if exists { + return true + } + + return false +} + +// verifyTx verifies a single transaction against the mempool state and the block's chain. 
+func (m *Mempool) verifyTx(ctx context.Context, tx *TX, blockParent simplex.Digest) error { + if _, exists := m.acceptedTXs[tx.ID]; exists { + return fmt.Errorf("%w: %s", errAlreadyAccepted, tx.ID) + } + + if m.isTxInChain(tx.ID, blockParent) { + return errAlreadyInChain + } + + if err := tx.Verify(ctx); err != nil { + return err + } + return nil +} + +// recursively check if the tx has already been included in any ancestor block to prevent double spends +func (m *Mempool) isTxInChain(txID txID, parentDigest simplex.Digest) bool { + block, exists := m.verifiedButNotAcceptedBlocks[parentDigest] + if !exists { + return false + } + + if block.containsTX(txID) { + return true + } + + return m.isTxInChain(txID, block.metadata.Prev) +} + +// AcceptBlock accepts the block and updates the mempool +// state to clean up transactions, remove sibling/uncle blocks, +// and move any non-conflicting transactions from purged sibling/uncle blocks back to unaccepted +func (m *Mempool) AcceptBlock(b *Block) { + m.lock.Lock() + defer m.lock.Unlock() + + m.acceptedBlocks[b.digest] = b + + for _, tx := range b.txs { + m.acceptedTXs[tx.ID] = struct{}{} + delete(m.unacceptedTxs, tx.ID) + } + + // delete any verified but not accepted blocks that are siblings or uncles and move not conflicting txs back to unaccepted + delete(m.verifiedButNotAcceptedBlocks, b.digest) + + for _, verifiedBlock := range m.verifiedButNotAcceptedBlocks { + // any block that shares a parent(excluding our block) should be purged + if verifiedBlock.metadata.Prev == b.metadata.Prev { + delete(m.verifiedButNotAcceptedBlocks, verifiedBlock.digest) + m.purgeBlockAndChildren(verifiedBlock) + } + } + + if len(m.unacceptedTxs) > 0 { + m.logger.Debug("After accepting block, moved txs back to unaccepted due to sibling/uncle blocks being purged", zap.Int("num unaccepted txs", len(m.unacceptedTxs))) + m.NotifyTxsReady() + } +} + +// purgeBlockAndChildren goes through any blocks that build off of this one and move their txs 
+// back to unaccepted. It also moves this blocks transactions to unaccepted. +func (m *Mempool) purgeBlockAndChildren(block *Block) { + m.moveTxsToUnaccepted(block) + + for digest, verifiedBlock := range m.verifiedButNotAcceptedBlocks { + if verifiedBlock.metadata.Prev == block.digest { + delete(m.verifiedButNotAcceptedBlocks, digest) + m.purgeBlockAndChildren(verifiedBlock) + } + } +} + +func (m *Mempool) moveTxsToUnaccepted(block *Block) { + for _, tx := range block.txs { + if _, exists := m.acceptedTXs[tx.ID]; !exists { + m.unacceptedTxs[tx.ID] = tx + } + } +} + +func (m *Mempool) BuildBlock(ctx context.Context, md simplex.ProtocolMetadata, bl simplex.Blacklist) (simplex.VerifiedBlock, bool) { + m.waitForPendingTxs(ctx) + + // Pack the block once we have pending txs + txs := m.packBlock(ctx, m.config.TxsPerBlock, md.Prev) + if ctx.Err() != nil { + return nil, false + } + + // sort transactions + slices.SortFunc(txs, func(a *TX, b *TX) int { + return bytes.Compare(a.ID[:], b.ID[:]) + }) + + block := NewBlock(md, bl, m, txs) + m.logger.Debug("Built block with txs", zap.String("block digest", block.digest.String()), zap.Int("num txs", len(block.txs)), zap.Uint64("round", md.Round), zap.Uint64("seq", md.Seq)) + // in the future we can create a malicious block but we need to ensure the number of crashed nodes in under the threshold f(since we cant tolerate more than f malicious nodes) + err := m.VerifyBlock(ctx, block) + if err != nil { + m.logger.Error("Failed to verify built block", zap.String("block digest", block.digest.String()), zap.Error(err)) + return nil, false + } + + return block, true +} + +func (m *Mempool) packBlock(ctx context.Context, maxTxs int, parentDigest simplex.Digest) []*TX { + m.lock.Lock() + defer m.lock.Unlock() + + txs := make([]*TX, 0, maxTxs) + for _, tx := range m.unacceptedTxs { + if err := m.verifyTx(ctx, tx, parentDigest); err != nil { + m.logger.Debug("Skipping tx during block packing due to failed verification", 
zap.Stringer("txID", tx), zap.Error(err)) + continue + } + txs = append(txs, tx) + if len(txs) >= maxTxs { + break + } + } + + return txs +} + +func (m *Mempool) WaitForPendingBlock(ctx context.Context) { + m.waitForPendingTxs(ctx) +} + +// IsTxAccepted returns true if the transaction has been accepted +func (m *Mempool) IsTxAccepted(txID txID) bool { + m.lock.Lock() + defer m.lock.Unlock() + + _, accepted := m.acceptedTXs[txID] + return accepted +} + +// IsTxPending returns true if the transaction is still pending (unaccepted) +func (m *Mempool) IsTxPending(txID txID) bool { + m.lock.Lock() + defer m.lock.Unlock() + + _, pending := m.unacceptedTxs[txID] + return pending +} + +// Clear resets the mempool state to simulate a node restart. +// We do not remove accepted & unaccepted transactions/blocks from the mempool(since we don't have tx gossip) +// but we do clear verified blocks since we are expected to re-verify after a restart. +func (m *Mempool) Clear() { + m.lock.Lock() + defer m.lock.Unlock() + + // move all the transactions from verified to unaccepted, since we are clearing the mempool but the transactions are still valid and can be re-included in future blocks + for _, block := range m.verifiedButNotAcceptedBlocks { + for _, tx := range block.txs { + if _, accepted := m.acceptedTXs[tx.ID]; !accepted { + m.unacceptedTxs[tx.ID] = tx + } + } + } + + m.verifiedButNotAcceptedBlocks = make(map[simplex.Digest]*Block) +} diff --git a/testutil/random_network/mempool_test.go b/testutil/random_network/mempool_test.go new file mode 100644 index 00000000..0119a434 --- /dev/null +++ b/testutil/random_network/mempool_test.go @@ -0,0 +1,181 @@ +// Copyright (C) 2019-2025, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
+ +package random_network + +import ( + "context" + "testing" + + "github.com/ava-labs/simplex" + "github.com/ava-labs/simplex/testutil" + "github.com/stretchr/testify/require" +) + +var emptyBlacklist = simplex.Blacklist{ + NodeCount: 4, + SuspectedNodes: simplex.SuspectedNodes{}, + Updates: []simplex.BlacklistUpdate{}, +} + +func TestMempoolVerifiesTx(t *testing.T) { + logger := testutil.MakeLogger(t, 1) + logger.Silence() // so we dont log errors/warns + + ctx := context.Background() + require := require.New(t) + round0MD := NewProtocolMetadata(0, 0, simplex.Digest{}) + config := DefaultFuzzConfig() + + tests := []struct { + name string + expectErr error + setup func() (*Mempool, *Block, error) + }{ + { + name: "ValidTx", + expectErr: nil, + setup: func() (*Mempool, *Block, error) { + mempool := NewMempool(logger, config) + tx := CreateNewTX() + mempool.AddPendingTXs(tx) + + block := NewBlock(round0MD, emptyBlacklist, mempool, []*TX{tx}) + + return mempool, block, nil + }, + }, + { + name: "Duplicate Tx In Block", + expectErr: errDuplicateTxInBlock, + setup: func() (*Mempool, *Block, error) { + mempool := NewMempool(logger, config) + tx := CreateNewTX() + mempool.AddPendingTXs(tx) + + block := NewBlock(round0MD, emptyBlacklist, mempool, []*TX{tx, tx}) + return mempool, block, nil + }, + }, + { + name: "Already Accepted", + expectErr: errDoubleBlockVerification, + setup: func() (*Mempool, *Block, error) { + mempool := NewMempool(logger, config) + tx := CreateNewTX() + mempool.AddPendingTXs(tx) + + block := NewBlock(round0MD, emptyBlacklist, mempool, []*TX{tx}) + mempool.AcceptBlock(block) + + mempool.AddPendingTXs(tx) + return mempool, block, nil + }, + }, + { + name: "Already In Chain", + expectErr: errAlreadyInChain, + setup: func() (*Mempool, *Block, error) { + mempool := NewMempool(logger, config) + tx := CreateNewTX() + mempool.AddPendingTXs(tx) + + parentBlock := NewBlock(round0MD, emptyBlacklist, mempool, []*TX{tx}) + if err := mempool.VerifyBlock(ctx, 
parentBlock); err != nil { + return nil, nil, err + } + + mempool.AddPendingTXs(tx) + md := NewProtocolMetadata(1, 1, parentBlock.digest) + block := NewBlock(md, emptyBlacklist, mempool, []*TX{tx}) + return mempool, block, nil + }, + }, + { + name: "Tx Verification Fails", + expectErr: errTxVerification, + setup: func() (*Mempool, *Block, error) { + mempool := NewMempool(logger, config) + tx := CreateNewTX() + tx.SetShouldFailVerification() + mempool.AddPendingTXs(tx) + + block := NewBlock(round0MD, emptyBlacklist, mempool, []*TX{tx}) + + return mempool, block, nil + }, + }, + { + name: "Correctly verifies transaction not in chain", + expectErr: nil, + setup: func() (*Mempool, *Block, error) { + mempool := NewMempool(logger, config) + tx1 := CreateNewTX() + mempool.AddPendingTXs(tx1) + + blockWithSameTxButNotParent := NewBlock(round0MD, emptyBlacklist, mempool, []*TX{tx1}) + err := mempool.VerifyBlock(ctx, blockWithSameTxButNotParent) + + mempool.AddPendingTXs(tx1) + block := NewBlock(NewProtocolMetadata(1, 1, simplex.Digest{}), emptyBlacklist, mempool, []*TX{tx1}) + return mempool, block, err + }, + }, + { + name: "Double Block Verification", + expectErr: errDoubleBlockVerification, + setup: func() (*Mempool, *Block, error) { + mempool := NewMempool(logger, config) + tx := CreateNewTX() + mempool.AddPendingTXs(tx) + + block := NewBlock(round0MD, emptyBlacklist, mempool, []*TX{tx}) + err := mempool.VerifyBlock(ctx, block) + if err != nil { + return nil, nil, err + } + + mempool.AddPendingTXs(tx) + return mempool, block, nil + }, + }, + { + name: "Parent Previously Verified But Was Pruned", + expectErr: errParentNotFound, + setup: func() (*Mempool, *Block, error) { + mempool := NewMempool(logger, config) + tx1 := CreateNewTX() + tx2 := CreateNewTX() + childTx := CreateNewTX() + mempool.AddPendingTXs(tx1) + mempool.AddPendingTXs(tx2) + + // create & verify two siblings + brother := NewBlock(round0MD, emptyBlacklist, mempool, []*TX{tx1}) + if err := 
mempool.VerifyBlock(ctx, brother); err != nil { + return nil, nil, err + } + + sister := NewBlock(round0MD, emptyBlacklist, mempool, []*TX{tx2}) + if err := mempool.VerifyBlock(ctx, sister); err != nil { + return nil, nil, err + } + + // accept the sister, so the brother should be pruned + mempool.AcceptBlock(sister) + + childBlock := NewBlock(NewProtocolMetadata(1, 1, brother.digest), emptyBlacklist, mempool, []*TX{childTx}) + return mempool, childBlock, nil + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mempool, block, err := tt.setup() + require.NoError(err) + err = mempool.VerifyBlock(ctx, block) + require.ErrorIs(err, tt.expectErr) + }) + } +} diff --git a/testutil/random_network/network.go b/testutil/random_network/network.go new file mode 100644 index 00000000..a2f6d247 --- /dev/null +++ b/testutil/random_network/network.go @@ -0,0 +1,327 @@ +// Copyright (C) 2019-2025, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
// ---- network.go ----

package random_network

import (
	"math/rand"
	"sync"
	"testing"
	"time"

	"github.com/ava-labs/simplex"
	"github.com/ava-labs/simplex/testutil"
	"go.uber.org/zap"
	"go.uber.org/zap/zapcore"
)

// Network is a randomized in-memory simplex network used for fuzz testing.
// It drives a set of Nodes through repeated cycles of transaction issuance,
// random crashes/recoveries, and replication until a target height is reached.
type Network struct {
	*testutil.BasicInMemoryNetwork
	logger *testutil.TestLogger
	t      *testing.T

	// lock guards time advancement and status/log-level operations that
	// iterate over nodes while the network is running.
	lock       sync.Mutex
	nodes      []*Node
	numNodes   uint64
	randomness *rand.Rand // seeded from config.RandomSeed for reproducible runs
	config     *FuzzConfig

	// tx stats
	allIssuedTxs int
}

// NewNetwork creates a network with a random node count in
// [config.MinNodes, config.MaxNodes], all driven by the seeded RNG so a
// failing run can be replayed from its logged seed.
func NewNetwork(config *FuzzConfig, t *testing.T) *Network {
	logger := CreateNetworkLogger(t, config)
	logger.Info("Creating new network with random seed", zap.Int64("seed", config.RandomSeed))
	r := rand.New(rand.NewSource(config.RandomSeed))

	numNodes := r.Intn(config.MaxNodes-config.MinNodes+1) + config.MinNodes
	nodeIds := make([]simplex.NodeID, numNodes)
	for i := range numNodes {
		// single-byte IDs; assumes numNodes <= 256 — TODO confirm config bounds
		nodeIds[i] = []byte{byte(i)}
	}

	nodes := make([]*Node, numNodes)

	logger.Info("Initiating network with nodes", zap.Int("num nodes", numNodes))
	basicNetwork := testutil.NewBasicInMemoryNetwork(t, nodeIds)

	for i := range numNodes {
		node := NewNode(t, nodeIds[i], basicNetwork, config, randomNodeConfig{})
		basicNetwork.AddNode(node.BasicNode)
		nodes[i] = node
	}

	return &Network{
		BasicInMemoryNetwork: basicNetwork,
		nodes:                nodes,
		randomness:           r,
		logger:               logger,
		t:                    t,
		config:               config,
		numNodes:             uint64(numNodes),
	}
}

// StartInstances is shadowed to prevent bypassing Run, which owns the
// start/stop lifecycle for this fuzz network.
func (n *Network) StartInstances() {
	panic("Call Run() Instead")
}

// Run executes the fuzz loop: crash/recover nodes, issue transactions, wait
// for acceptance, and bring lagging nodes up to the max height, repeating
// until every node has finalized config.NumFinalizedBlocks blocks.
func (n *Network) Run() {
	n.BasicInMemoryNetwork.StartInstances()
	defer n.BasicInMemoryNetwork.StopInstances()

	targetHeight := uint64(n.config.NumFinalizedBlocks)

	for {
		n.crashAndRecoverNodes()
		txs := n.issueTxs()

		maxHeight := n.getMaxHeight()
		minHeight := n.getMinHeight()
		n.logger.Info("Issued Transactions", zap.Int("count", len(txs)), zap.Uint64("min height", minHeight), zap.Uint64("max height", maxHeight))
		n.waitForTxAcceptance(txs)

		maxHeight = n.getMaxHeight()
		minHeight = n.getMinHeight()
		n.logger.Info("All issued transactions accepted", zap.Int("count", len(txs)), zap.Uint64("min height", minHeight), zap.Uint64("max height", maxHeight))
		// get the max height and ensure all nodes recover to that height
		n.recoverToHeight(n.getMaxHeight())

		if minHeight >= targetHeight {
			n.logger.Info("Reached target height", zap.Uint64("targetHeight", targetHeight), zap.Uint64("minHeight", minHeight))
			break
		}
	}

	// if we have gotten this far, the test has succeeded so we can clear the log directory
	clearLogDirectory(n.config.LogDirectory)
}

// getMinHeightNodeID returns the ID of the node with the fewest indexed
// blocks; used purely for debug logging while replication catches up.
func (n *Network) getMinHeightNodeID() simplex.NodeID {
	minHeight := n.nodes[0].storage.NumBlocks()
	minHeightNodeID := n.nodes[0].E.ID

	for _, node := range n.nodes[1:] {
		height := node.storage.NumBlocks()

		if height < minHeight {
			minHeight = height
			minHeightNodeID = node.E.ID
		}
	}

	return minHeightNodeID
}

// recoverToHeight advances simulated time (and probabilistically restarts
// crashed nodes) until every node's storage reaches at least height.
func (n *Network) recoverToHeight(height uint64) {
	for n.getMinHeight() < height {
		n.logger.Debug("Advancing network time", zap.Uint64("num crashed nodes", n.numCrashedNodes()),
			zap.Uint64("min height", n.getMinHeight()),
			zap.Uint64("max height", n.getMaxHeight()),
			zap.Stringer("Smallest node ID", n.getMinHeightNodeID()),
			zap.Uint64("target height", height),
		)
		for i, node := range n.nodes {
			isCrashed := node.isCrashed.Load()
			if isCrashed {
				// randomly decide to recover based on NodeRecoverPercentage
				if n.randomness.Float64() < n.config.NodeRecoverProbability {
					n.logger.Debug("Recovering node", zap.Stringer("nodeID", node.E.ID))
					n.startNode(i)
				}
			}
		}

		n.lock.Lock()
		n.BasicInMemoryNetwork.AdvanceTime(n.config.AdvanceTimeTickAmount)
		n.lock.Unlock()
	}

}

// issueTxs creates a random batch of transactions (size in
// [MinTxsPerIssue, MaxTxsPerIssue]), adds them to every node's mempool, and
// only then notifies mempools — so no node builds a block before all peers
// hold the same pending set.
func (n *Network) issueTxs() []*TX {
	n.lock.Lock()
	defer n.lock.Unlock()

	numTxs := n.randomness.Intn(n.config.MaxTxsPerIssue-n.config.MinTxsPerIssue+1) + n.config.MinTxsPerIssue // randomize between min and max inclusive
	txs := make([]*TX, 0, numTxs)

	for range numTxs {
		tx := CreateNewTX()
		txs = append(txs, tx)
		n.allIssuedTxs++
	}

	// first add to all mempools
	for _, node := range n.nodes {
		node.mempool.AddPendingTXs(txs...)
	}

	// then notify all mempools that new txs are ready
	for _, node := range n.nodes {
		node.mempool.NotifyTxsReady()
	}

	return txs
}

// waitForTxAcceptance busy-waits (advancing simulated time each iteration)
// until every non-crashed node has accepted all of txs. Crashed nodes are
// skipped; they catch up later via recoverToHeight.
func (n *Network) waitForTxAcceptance(txs []*TX) {
	for {
		allAccepted := true
		for _, node := range n.nodes {
			if node.isCrashed.Load() {
				continue
			}
			if accepted := node.areTxsAccepted(txs); !accepted {
				// take the mempool lock so reading unacceptedTxs is race-free
				node.mempool.lock.Lock()
				n.logger.Debug("Not all txs accepted yet by node", zap.Stringer("nodeID", node.E.ID), zap.Int("unaccepted txs in mempool", len(node.mempool.unacceptedTxs)))
				node.mempool.lock.Unlock()
				allAccepted = false
			}
		}

		if allAccepted {
			return
		}

		n.lock.Lock()
		n.logger.Debug("Advancing network time to wait for tx acceptance", zap.Uint64("num crashed nodes", n.numCrashedNodes()))
		n.BasicInMemoryNetwork.AdvanceTime(n.config.AdvanceTimeTickAmount)
		n.lock.Unlock()

		time.Sleep(20 * time.Millisecond)
	}
}

// numCrashedNodes counts nodes currently flagged as crashed.
func (n *Network) numCrashedNodes() uint64 {
	numCrashed := 0
	for _, node := range n.nodes {
		if node.isCrashed.Load() {
			numCrashed++
		}
	}
	return uint64(numCrashed)
}

// crashAndRecoverNodes randomly crashes and recovers nodes, never allowing
// more than f = (numNodes-1)/3 simultaneous crashes so the remaining nodes
// can still make progress.
func (n *Network) crashAndRecoverNodes() {
	if n.config.NodeCrashProbability == 0 {
		return
	}

	f := (int(n.numNodes) - 1) / 3

	if f == 0 {
		n.logger.Info("Not enough nodes for crash testing", zap.Uint64("numNodes", n.numNodes))
		return
	}

	crashedNodes := []simplex.NodeID{}
	recoveredNodes := []simplex.NodeID{}
	maxLeftToCrash := f - int(n.numCrashedNodes())
	// go through each node, randomly decide to crash based on NodeCrashPercentage
	for i, node := range n.nodes {
		isCrashed := node.isCrashed.Load()
		if isCrashed {
			// randomly decide to recover based on NodeRecoverPercentage
			if n.randomness.Float64() < n.config.NodeRecoverProbability {
				n.startNode(i)
				recoveredNodes = append(recoveredNodes, node.E.ID)
				maxLeftToCrash++
			}
			continue
		}

		// check if we can still crash more nodes
		if maxLeftToCrash <= 0 {
			continue
		}

		// randomly decide to crash the node
		if n.randomness.Float64() < n.config.NodeCrashProbability {
			maxLeftToCrash--
			n.crashNode(i)
			crashedNodes = append(crashedNodes, node.E.ID)
		}
	}

	if len(recoveredNodes)+len(crashedNodes) > 0 {
		n.logger.Info("Recovered and crashed nodes", zap.Stringers("crashed", crashedNodes), zap.Stringers("recovered", recoveredNodes), zap.Uint64("num crashed", n.numCrashedNodes()))
	}
}

// getMinHeight returns the smallest block count across all nodes.
func (n *Network) getMinHeight() uint64 {
	minHeight := n.nodes[0].storage.NumBlocks()
	for _, node := range n.nodes[1:] {
		height := node.storage.NumBlocks()

		if height < minHeight {
			minHeight = height
		}
	}
	return minHeight
}

// getMaxHeight returns the largest block count across all nodes.
func (n *Network) getMaxHeight() uint64 {
	maxHeight := n.nodes[0].storage.NumBlocks()
	for _, node := range n.nodes[1:] {
		height := node.storage.NumBlocks()

		if height > maxHeight {
			maxHeight = height
		}
	}
	return maxHeight
}

// SetInfoLog raises every node logger to Info level, silencing per-node
// debug output mid-run.
func (n *Network) SetInfoLog() {
	n.lock.Lock()
	defer n.lock.Unlock()

	for _, node := range n.nodes {
		node.logger.SetLevel(zapcore.InfoLevel)
	}
}

// PrintStatus logs a summary of the network: node heights/rounds, per-node
// message counts, and total issued transactions.
func (n *Network) PrintStatus() {
	n.lock.Lock()
	defer n.lock.Unlock()

	// prints the number of nodes
	n.logger.Info("Network Status", zap.Int("num nodes", len(n.nodes)), zap.Int64("Seed", n.config.RandomSeed))

	// prints the number of txs in each node's mempool
	for _, node := range n.nodes {
		n.logger.Info("Node Status", zap.Stringer("nodeID", node.E.ID), zap.Int("Short", int(node.E.ID[0])), zap.Uint64("Round", node.E.Metadata().Round), zap.Uint64("Height", node.storage.NumBlocks()))
		node.PrintMessageTypesSent()
	}

	// prints total issued txs and failed txs
	n.logger.Info("Transaction Stats", zap.Int("total issued txs", n.allIssuedTxs))
}

// crashNode marks the node at idx as crashed and stops its instance.
func (n *Network) crashNode(idx int) {
	n.logger.Debug("Crashing node", zap.Stringer("nodeID", n.nodes[idx].E.ID))
	n.nodes[idx].isCrashed.Store(true)
	n.nodes[idx].Stop()
}

// startNode rebuilds the node at idx from clones of its WAL and storage
// (simulating a restart from persisted state), sharing its mempool but
// clearing pending state, then registers and starts the replacement.
func (n *Network) startNode(idx int) {
	instance := n.nodes[idx]
	nodeID := instance.E.ID
	mempool := instance.mempool
	clonedWal := instance.wal.Clone()
	clonedStorage := instance.storage.Clone()
	mempool.Clear()

	newNode := NewNode(n.t, nodeID, n.BasicInMemoryNetwork, n.config, randomNodeConfig{
		mempool: mempool,
		wal:     clonedWal,
		storage: clonedStorage,
		logger:  instance.logger,
	})
	n.nodes[idx] = newNode
	n.BasicInMemoryNetwork.ReplaceNode(newNode.BasicNode)

	newNode.Start()
}

// ---- node.go ----
// Copyright (C) 2019-2025, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.

package random_network

import (
	"sync/atomic"
	"testing"

	"github.com/ava-labs/simplex"
	"github.com/ava-labs/simplex/testutil"
	"github.com/stretchr/testify/require"
)

// randomNodeConfig carries optional pre-existing components for NewNode;
// nil fields are replaced with fresh instances. Used when restarting a
// crashed node with its prior state.
type randomNodeConfig struct {
	mempool *Mempool
	storage *Storage
	wal     *testutil.TestWAL
	logger  *testutil.TestLogger
}

// Node is a single fuzz-network participant wrapping a BasicNode with its
// own storage, WAL, mempool, and crash flag.
type Node struct {
	*testutil.BasicNode

	storage *Storage
	wal     *testutil.TestWAL
	mempool *Mempool

	logger *testutil.TestLogger
	// isCrashed is read/written concurrently by the network loop, hence atomic.
	isCrashed atomic.Bool
}

// NewNode constructs a node for the fuzz network, reusing any components
// supplied in nodeConfig (restart path) or creating fresh ones otherwise.
func NewNode(t *testing.T, nodeID simplex.NodeID, net *testutil.BasicInMemoryNetwork, config *FuzzConfig, nodeConfig randomNodeConfig) *Node {
	var l *testutil.TestLogger
	if nodeConfig.logger != nil {
		l = nodeConfig.logger
	} else {
		l = CreateNodeLogger(t, config, nodeID)
	}

	var mempool *Mempool

	if nodeConfig.mempool != nil {
		mempool = nodeConfig.mempool
	} else {
		mempool = NewMempool(l, config)
	}

	comm := testutil.NewTestComm(nodeID, net, testutil.AllowAllMessages)
	epochConfig, wal, _ := testutil.DefaultTestNodeEpochConfig(t, nodeID, comm, mempool)

	epochConfig.Logger = l
	epochConfig.MaxRoundWindow = 100
	epochConfig.ReplicationEnabled = true

	// storage
	var storage *Storage
	if nodeConfig.storage != nil {
		storage = nodeConfig.storage
	} else {
		storage = NewStorage(mempool)
	}
	epochConfig.Storage = storage

	// wal: prefer the caller-supplied (cloned) WAL over the default one
	if nodeConfig.wal != nil {
		wal = nodeConfig.wal
	}
	epochConfig.WAL = wal

	epochConfig.BlockDeserializer = &BlockDeserializer{
		mempool: mempool,
	}

	e, err := simplex.NewEpoch(epochConfig)
	require.NoError(t, err)

	n := &Node{
		BasicNode: testutil.NewBasicNode(t, e, l),
		storage:   storage,
		mempool:   mempool,
		logger:    l,
		wal:       wal,
		isCrashed: atomic.Bool{},
	}

	// intercept inbound messages so they can be deep-copied (see copyMessage)
	n.BasicNode.CustomHandler = n.HandleMessage

	return n
}

// HandleMessage deep-copies the inbound message before delegating, so this
// node never mutates state shared with the sender.
func (n *Node) HandleMessage(msg *simplex.Message, from simplex.NodeID) error {
	msgCopy := n.copyMessage(msg)
	return n.BasicNode.HandleMessage(&msgCopy, from)
}

// copyBlock creates a copy of a simplex.Block with the node's mempool.
func (n *Node) copyBlock(b simplex.Block) simplex.Block {
	blockCopy := *b.(*Block)
	blockCopy.mempool = n.mempool
	return &blockCopy
}

// copyMessage creates a copy of the message and its relevant fields to avoid mutating shared state in the in-memory network
// this is important because blocks are not serialized/deserialized in our current comm implementation, so sending blocks
// also sends relevant state associated with the node that is sending the message which can cause unintended side effects.
func (n *Node) copyMessage(msg *simplex.Message) simplex.Message {
	msgCopy := *msg

	switch {
	case msgCopy.BlockMessage != nil:
		blockMsgCopy := *msgCopy.BlockMessage
		blockMsgCopy.Block = n.copyBlock(msgCopy.BlockMessage.Block)
		msgCopy.BlockMessage = &blockMsgCopy

	case msgCopy.ReplicationResponse != nil:
		rrCopy := *msgCopy.ReplicationResponse

		// Also copy the Data slice to avoid mutating shared slice
		rrCopy.Data = make([]simplex.QuorumRound, len(msgCopy.ReplicationResponse.Data))
		copy(rrCopy.Data, msgCopy.ReplicationResponse.Data)
		msgCopy.ReplicationResponse = &rrCopy

		// convert quorum rounds to our type
		// (msgCopy.ReplicationResponse now aliases rrCopy, so this rewrites the copied slice)
		for i, qr := range msgCopy.ReplicationResponse.Data {
			if qr.Block != nil {
				msgCopy.ReplicationResponse.Data[i].Block = n.copyBlock(qr.Block)
			}
		}

		if msgCopy.ReplicationResponse.LatestRound != nil {
			latestRoundCopy := *msgCopy.ReplicationResponse.LatestRound
			msgCopy.ReplicationResponse.LatestRound = &latestRoundCopy
			if latestRoundCopy.Block != nil {
				msgCopy.ReplicationResponse.LatestRound.Block = n.copyBlock(latestRoundCopy.Block)
			}
		}

		if msgCopy.ReplicationResponse.LatestSeq != nil {
			latestSeqCopy := *msgCopy.ReplicationResponse.LatestSeq
			msgCopy.ReplicationResponse.LatestSeq = &latestSeqCopy
			if latestSeqCopy.Block != nil {
				msgCopy.ReplicationResponse.LatestSeq.Block = n.copyBlock(latestSeqCopy.Block)
			}
		}
	default:
		// no-op
	}
	return msgCopy
}

// areTxsAccepted reports whether every tx in txs has been accepted by this
// node's mempool.
func (n *Node) areTxsAccepted(txs []*TX) bool {
	n.mempool.lock.Lock()
	defer n.mempool.lock.Unlock()

	for _, tx := range txs {
		if _, exists := n.mempool.acceptedTXs[tx.ID]; !exists {
			return false
		}
	}
	return true
}

// ---- storage.go ----
// Copyright (C) 2019-2025, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.
+ +package random_network + +import ( + "context" + + "github.com/ava-labs/simplex" + "github.com/ava-labs/simplex/testutil" +) + +type Storage struct { + *testutil.InMemStorage + mempool *Mempool +} + +func NewStorage(mempool *Mempool) *Storage { + return &Storage{ + InMemStorage: testutil.NewInMemStorage(), + mempool: mempool, + } +} + +func (s *Storage) Index(ctx context.Context, block simplex.VerifiedBlock, certificate simplex.Finalization) error { + s.mempool.AcceptBlock(block.(*Block)) + return s.InMemStorage.Index(ctx, block, certificate) +} + +func (s *Storage) Clone() *Storage { + return &Storage{ + InMemStorage: s.InMemStorage.Clone(), + mempool: s.mempool, // Share the same mempool + } +} diff --git a/testutil/random_network/tx.go b/testutil/random_network/tx.go new file mode 100644 index 00000000..418ccaf7 --- /dev/null +++ b/testutil/random_network/tx.go @@ -0,0 +1,76 @@ +// Copyright (C) 2019-2025, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
+ +package random_network + +import ( + "context" + "crypto/rand" + "encoding/asn1" + "errors" + "fmt" +) + +var errTxVerification = errors.New("tx verification failed") + +type txID [32]byte + +type TX struct { + ID txID + shouldFailVerification bool +} + +func (t *TX) String() string { + return fmt.Sprintf("%x", t.ID[:]) +} + +type asn1TX struct { + ID []byte + ShouldFailVerification bool +} + +func (aTX asn1TX) toTX() *TX { + var idArr txID + copy(idArr[:], aTX.ID) + return &TX{ID: idArr, shouldFailVerification: aTX.ShouldFailVerification} +} + +func CreateNewTX() *TX { + id := make([]byte, 32) + _, err := rand.Read(id) + if err != nil { + panic(err) + } + + var idArr txID + copy(idArr[:], id) + + return &TX{ID: idArr} +} + +func (t *TX) SetShouldFailVerification() { + t.shouldFailVerification = true +} + +func (t *TX) Bytes() ([]byte, error) { + return asn1.Marshal(asn1TX{ID: t.ID[:], ShouldFailVerification: t.shouldFailVerification}) +} + +func TxFromBytes(b []byte) (*TX, error) { + var asn1TX asn1TX + _, err := asn1.Unmarshal(b, &asn1TX) + if err != nil { + return nil, err + } + + return asn1TX.toTX(), nil +} + +func (t *TX) Verify(ctx context.Context) error { + // TBD + // Can set artificial failure here for testing or longer verification times + if t.shouldFailVerification { + return errTxVerification + } + return nil +} diff --git a/testutil/random_network/tx_test.go b/testutil/random_network/tx_test.go new file mode 100644 index 00000000..aacbc78c --- /dev/null +++ b/testutil/random_network/tx_test.go @@ -0,0 +1,49 @@ +// Copyright (C) 2019-2025, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
+ +package random_network + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCreateNewTX(t *testing.T) { + tx := CreateNewTX() + require.NotNil(t, tx) + + var zero txID + require.NotEqual(t, zero, tx.ID) +} + +func TestTXSerialize(t *testing.T) { + tx := CreateNewTX() + + b, err := tx.Bytes() + require.NoError(t, err) + require.NotEmpty(t, b) + + tx2, err := TxFromBytes(b) + require.NoError(t, err) + require.Equal(t, tx, tx2) +} + +func TestTXSerializeWithShouldFailVerification(t *testing.T) { + tx := CreateNewTX() + tx.SetShouldFailVerification() + + b, err := tx.Bytes() + require.NoError(t, err) + require.NotEmpty(t, b) + + tx2, err := TxFromBytes(b) + require.NoError(t, err) + require.True(t, tx2.shouldFailVerification) + require.Equal(t, tx, tx2) +} + +func TestTxFromBytesInvalid(t *testing.T) { + _, err := TxFromBytes([]byte{0x01, 0x02, 0x03}) + require.Error(t, err) +} diff --git a/testutil/random_network/utils.go b/testutil/random_network/utils.go new file mode 100644 index 00000000..a9b70756 --- /dev/null +++ b/testutil/random_network/utils.go @@ -0,0 +1,16 @@ +// Copyright (C) 2019-2025, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package random_network + +import ( + "github.com/ava-labs/simplex" +) + +func NewProtocolMetadata(round, seq uint64, prev simplex.Digest) simplex.ProtocolMetadata { + return simplex.ProtocolMetadata{ + Round: round, + Seq: seq, + Prev: prev, + } +}