From 0808196869591fb679ae8eca2805fdbcec172c48 Mon Sep 17 00:00:00 2001 From: Gabe Rodriguez Date: Fri, 20 Mar 2026 10:33:18 +0100 Subject: [PATCH 1/2] Revert back to published 0.7.2 state --- pod/Cargo.toml | 30 +- pod/src/bytemuck.rs | 2 +- pod/src/error.rs | 54 +++ pod/src/lib.rs | 13 +- pod/src/list/list_trait.rs | 28 ++ pod/src/list/list_view.rs | 645 ++++++++++++++++++++++++++++ pod/src/list/list_view_mut.rs | 426 ++++++++++++++++++ pod/src/list/list_view_read_only.rs | 202 +++++++++ pod/src/list/mod.rs | 9 + pod/src/option.rs | 387 ++--------------- pod/src/optional_keys.rs | 360 ++++++++++++++++ pod/src/pod_length.rs | 41 ++ pod/src/primitives.rs | 150 ++----- pod/src/slice.rs | 221 ++++++++++ 14 files changed, 2095 insertions(+), 473 deletions(-) create mode 100644 pod/src/error.rs create mode 100644 pod/src/list/list_trait.rs create mode 100644 pod/src/list/list_view.rs create mode 100644 pod/src/list/list_view_mut.rs create mode 100644 pod/src/list/list_view_read_only.rs create mode 100644 pod/src/list/mod.rs create mode 100644 pod/src/optional_keys.rs create mode 100644 pod/src/pod_length.rs create mode 100644 pod/src/slice.rs diff --git a/pod/Cargo.toml b/pod/Cargo.toml index fdbdc5d1..cdf0c7ea 100644 --- a/pod/Cargo.toml +++ b/pod/Cargo.toml @@ -8,28 +8,30 @@ license = "Apache-2.0" edition = "2021" [features] -bytemuck = ["dep:bytemuck", "dep:bytemuck_derive", "solana-address/bytemuck"] -serde = ["dep:serde", "dep:serde_derive", "solana-address/serde"] -borsh = ["dep:borsh", "solana-address/borsh"] -wincode = ["dep:wincode", "dep:wincode-derive"] +serde-traits = ["dep:serde"] +borsh = ["dep:borsh", "solana-pubkey/borsh"] +wincode = ["dep:wincode"] [dependencies] -borsh = { version = "1.5.7", default-features = false, features = ["derive", "unstable__schema"], optional = true } -bytemuck = { version = "1.23.2", optional = true } -bytemuck_derive = { version = "1.10.1", optional = true } -serde = { version = "1.0.228", default-features = false, optional = true, features = ["alloc"] } -serde_derive = { version = "1.0.228", optional = true } -solana-address = "2.2.0" +borsh = { version = "1.5.7", features = ["derive", "unstable__schema"], optional = true } +bytemuck = { version = "1.23.2" } +bytemuck_derive = { version = "1.10.1" } +num-derive = "0.4" +num_enum = "0.7" +num-traits = "0.2" +serde = { version = "1.0.228", optional = true } +wincode = { version = "0.4.4", features = ["derive"], optional = true } solana-program-error = "3.0.0" solana-program-option = "3.0.0" -wincode = { version = "0.4.4", default-features = false, optional = true } -wincode-derive = { version = "0.4.2", optional = true } +solana-pubkey = "3.0.0" +solana-zk-sdk = "4.0.0" +thiserror = "2.0" [dev-dependencies] +base64 = { version = "0.22.1" } serde_json = "1.0.145" -spl-pod = { path = ".", features = ["bytemuck", "wincode", "borsh"] } +spl-pod = { path = ".", features = ["wincode"] } test-case = "3.3.1" -wincode = { version = "0.4.4", default-features = false, features = ["alloc"] } [lib] crate-type = ["lib"] diff --git a/pod/src/bytemuck.rs b/pod/src/bytemuck.rs index f50eacc5..744553d2 100644 --- a/pod/src/bytemuck.rs +++ b/pod/src/bytemuck.rs @@ -4,7 +4,7 @@ use {bytemuck::Pod, solana_program_error::ProgramError}; /// On-chain size of a `Pod` type pub const fn pod_get_packed_len() -> usize { - core::mem::size_of::() + std::mem::size_of::() } /// Convert a `Pod` into a slice of bytes (zero copy) diff --git a/pod/src/error.rs b/pod/src/error.rs new file mode 100644 index 00000000..7e1a4547 --- /dev/null +++ b/pod/src/error.rs @@ -0,0 +1,54 @@ +//! Error types +use { + solana_program_error::{ProgramError, ToStr}, + std::num::TryFromIntError, +}; + +/// Errors that may be returned by the spl-pod library. +#[repr(u32)] +#[derive( + Debug, + Clone, + PartialEq, + Eq, + thiserror::Error, + num_enum::TryFromPrimitive, + num_derive::FromPrimitive, +)] +pub enum PodSliceError { + /// Error in checked math operation + #[error("Error in checked math operation")] + CalculationFailure, + /// Provided byte buffer too small for expected type + #[error("Provided byte buffer too small for expected type")] + BufferTooSmall, + /// Provided byte buffer too large for expected type + #[error("Provided byte buffer too large for expected type")] + BufferTooLarge, + /// An integer conversion failed because the value was out of range for the target type + #[error("An integer conversion failed because the value was out of range for the target type")] + ValueOutOfRange, +} + +impl From for ProgramError { + fn from(e: PodSliceError) -> Self { + ProgramError::Custom(e as u32) + } +} + +impl ToStr for PodSliceError { + fn to_str(&self) -> &'static str { + match self { + PodSliceError::CalculationFailure => "Error in checked math operation", + PodSliceError::BufferTooSmall => "Provided byte buffer too small for expected type", + PodSliceError::BufferTooLarge => "Provided byte buffer too large for expected type", + PodSliceError::ValueOutOfRange => "An integer conversion failed because the value was out of range for the target type" + } + } +} + +impl From for PodSliceError { + fn from(_: TryFromIntError) -> Self { + PodSliceError::ValueOutOfRange + } +} diff --git a/pod/src/lib.rs b/pod/src/lib.rs index 2509506f..b9a26da1 100644 --- a/pod/src/lib.rs +++ b/pod/src/lib.rs @@ -1,15 +1,14 @@ -#![no_std] - //! Crate containing `Pod` types and `bytemuck` utilities used in SPL -#[cfg(any(feature = "borsh", feature = "serde", test))] -extern crate alloc; - -#[cfg(feature = "bytemuck")] pub mod bytemuck; +pub mod error; +pub mod list; pub mod option; +pub mod optional_keys; +pub mod pod_length; pub mod primitives; +pub mod slice; // Export current sdk types for downstream users building with a different sdk // version -pub use {solana_address, solana_program_error, solana_program_option}; +pub use {solana_program_error, solana_program_option, solana_pubkey}; diff --git a/pod/src/list/list_trait.rs b/pod/src/list/list_trait.rs new file mode 100644 index 00000000..c7ebf374 --- /dev/null +++ b/pod/src/list/list_trait.rs @@ -0,0 +1,28 @@ +use { + crate::{list::ListView, pod_length::PodLength}, + bytemuck::Pod, + solana_program_error::ProgramError, + std::ops::Deref, +}; + +/// A trait to abstract the shared, read-only behavior +/// between `ListViewReadOnly` and `ListViewMut`. +pub trait List: Deref { + /// The type of the items stored in the list. + type Item: Pod; + /// Length prefix type used (`PodU16`, `PodU32`, …). + type Length: PodLength; + + /// Returns the total number of items that can be stored in the list. + fn capacity(&self) -> usize; + + /// Returns the number of **bytes currently occupied** by the live elements + fn bytes_used(&self) -> Result { + ListView::::size_of(self.len()) + } + + /// Returns the number of **bytes reserved** by the entire backing buffer. + fn bytes_allocated(&self) -> Result { + ListView::::size_of(self.capacity()) + } +} diff --git a/pod/src/list/list_view.rs b/pod/src/list/list_view.rs new file mode 100644 index 00000000..23a64391 --- /dev/null +++ b/pod/src/list/list_view.rs @@ -0,0 +1,645 @@ +//! `ListView`, a compact, zero-copy array wrapper. + +use { + crate::{ + bytemuck::{ + pod_from_bytes, pod_from_bytes_mut, pod_slice_from_bytes, pod_slice_from_bytes_mut, + }, + error::PodSliceError, + list::{list_view_mut::ListViewMut, list_view_read_only::ListViewReadOnly}, + pod_length::PodLength, + primitives::PodU32, + }, + bytemuck::Pod, + solana_program_error::ProgramError, + std::{ + marker::PhantomData, + mem::{align_of, size_of}, + ops::Range, + }, +}; + +/// An API for interpreting a raw buffer (`&[u8]`) as a variable-length collection of Pod elements. +/// +/// `ListView` provides a safe, zero-copy, `Vec`-like interface for a slice of +/// `Pod` data that resides in an external, pre-allocated `&[u8]` buffer. +/// It does not own the buffer itself, but acts as a view over it, which can be +/// read-only (`ListViewReadOnly`) or mutable (`ListViewMut`). +/// +/// This is useful in environments where allocations are restricted or expensive, +/// such as Solana programs, allowing for efficient reads and manipulation of +/// dynamic-length data structures. +/// +/// ## Memory Layout +/// +/// The structure assumes the underlying byte buffer is formatted as follows: +/// 1. **Length**: A length field of type `L` at the beginning of the buffer, +/// indicating the number of currently active elements in the collection. +/// Defaults to `PodU32`. The implementation uses padding to ensure that the +/// data is correctly aligned for any `Pod` type. +/// 2. **Padding**: Optional padding bytes to ensure proper alignment of the data. +/// 3. **Data**: The remaining part of the buffer, which is treated as a slice +/// of `T` elements. The capacity of the collection is the number of `T` +/// elements that can fit into this data portion. +pub struct ListView(PhantomData<(T, L)>); + +struct Layout { + length_range: Range, + data_range: Range, +} + +impl ListView { + /// Calculate the total byte size for a `ListView` holding `num_items`. + /// This includes the length prefix, padding, and data. + pub fn size_of(num_items: usize) -> Result { + let header_padding = Self::header_padding()?; + size_of::() + .checked_mul(num_items) + .and_then(|curr| curr.checked_add(size_of::())) + .and_then(|curr| curr.checked_add(header_padding)) + .ok_or_else(|| PodSliceError::CalculationFailure.into()) + } + + /// Unpack a read-only buffer into a `ListViewReadOnly` + pub fn unpack(buf: &[u8]) -> Result, ProgramError> { + let layout = Self::calculate_layout(buf.len())?; + + // Slice the buffer to get the length prefix and the data. + // The layout calculation provides the correct ranges, accounting for any + // padding between the length and the data. + // + // buf: [ L L L L | P P | D D D D D D D D ...] + // <-----> <------------------> + // len_bytes data_bytes + let len_bytes = &buf[layout.length_range]; + let data_bytes = &buf[layout.data_range]; + + let length = pod_from_bytes::(len_bytes)?; + let data = pod_slice_from_bytes::(data_bytes)?; + let capacity = data.len(); + + if (*length).into() > capacity { + return Err(PodSliceError::BufferTooSmall.into()); + } + + Ok(ListViewReadOnly { + length, + data, + capacity, + }) + } + + /// Unpack the mutable buffer into a mutable `ListViewMut` + pub fn unpack_mut(buf: &mut [u8]) -> Result, ProgramError> { + let view = Self::build_mut_view(buf)?; + if (*view.length).into() > view.capacity { + return Err(PodSliceError::BufferTooSmall.into()); + } + Ok(view) + } + + /// Initialize a buffer: sets `length = 0` and returns a mutable `ListViewMut`. + pub fn init(buf: &mut [u8]) -> Result, ProgramError> { + let view = Self::build_mut_view(buf)?; + *view.length = L::try_from(0)?; + Ok(view) + } + + /// Internal helper to build a mutable view without validation or initialization. + #[inline] + fn build_mut_view(buf: &mut [u8]) -> Result, ProgramError> { + let layout = Self::calculate_layout(buf.len())?; + + // Split the buffer to get the length prefix and the data. + // buf: [ L L L L | P P | D D D D D D D D ...] + // <---- head ---> <--- tail ---------> + let (header_bytes, data_bytes) = buf.split_at_mut(layout.data_range.start); + // header: [ L L L L | P P ] + // <-----> + // len_bytes + let len_bytes = &mut header_bytes[layout.length_range]; + + // Cast the bytes to typed data + let length = pod_from_bytes_mut::(len_bytes)?; + let data = pod_slice_from_bytes_mut::(data_bytes)?; + let capacity = data.len(); + + Ok(ListViewMut { + length, + data, + capacity, + }) + } + + /// Calculate the byte ranges for the length and data sections of the buffer + #[inline] + fn calculate_layout(buf_len: usize) -> Result { + let len_field_end = size_of::(); + let header_padding = Self::header_padding()?; + let data_start = len_field_end.saturating_add(header_padding); + + if buf_len < data_start { + return Err(PodSliceError::BufferTooSmall.into()); + } + + Ok(Layout { + length_range: 0..len_field_end, + data_range: data_start..buf_len, + }) + } + + /// Calculate the padding required to align the data part of the buffer. + /// + /// The goal is to ensure that the data field `T` starts at a memory offset + /// that is a multiple of its alignment requirement. + #[inline] + fn header_padding() -> Result { + // Enforce that the length prefix type `L` itself does not have alignment requirements + if align_of::() != 1 { + return Err(ProgramError::InvalidArgument); + } + + let length_size = size_of::(); + let data_align = align_of::(); + + // No padding is needed for alignments of 0 or 1 + if data_align == 0 || data_align == 1 { + return Ok(0); + } + + // Find how many bytes `length_size` extends past an alignment boundary + #[allow(clippy::arithmetic_side_effects)] + let remainder = length_size.wrapping_rem(data_align); + + // If already aligned (remainder is 0), no padding is needed. + // Otherwise, calculate the distance to the next alignment boundary. + if remainder == 0 { + Ok(0) + } else { + Ok(data_align.wrapping_sub(remainder)) + } + } +} + +#[cfg(test)] +mod tests { + use { + super::*, + crate::{ + list::List, + primitives::{PodU128, PodU16, PodU32, PodU64}, + }, + bytemuck_derive::{Pod as DerivePod, Zeroable}, + }; + + #[test] + fn test_size_of_no_padding() { + // Case 1: T has align 1, so no padding is ever needed. + // 10 items * 1 byte/item + 4 bytes for length = 14 + assert_eq!(ListView::::size_of(10).unwrap(), 14); + + // Case 2: size_of is a multiple of align_of, so no padding needed. + // T = u32 (size 4, align 4), L = PodU32 (size 4). 4 % 4 == 0. + // 10 items * 4 bytes/item + 4 bytes for length = 44 + assert_eq!(ListView::::size_of(10).unwrap(), 44); + + // Case 3: 0 items. Size should just be size_of + padding. + // Padding is 0 here. + // 0 items * 4 bytes/item + 4 bytes for length = 4 + assert_eq!(ListView::::size_of(0).unwrap(), 4); + } + + #[test] + fn test_size_of_with_padding() { + // Case 1: Padding is required. + // T = u64 (size 8, align 8), L = PodU32 (size 4). + // Padding required to align data to 8 bytes is 4. (4 + 4 = 8) + // (10 items * 8 bytes/item) + 4 bytes for length + 4 bytes for padding = 88 + assert_eq!(ListView::::size_of(10).unwrap(), 88); + + #[repr(C, align(16))] + #[derive(DerivePod, Zeroable, Copy, Clone)] + struct Align16(u128); + + // Case 2: Custom struct with high alignment. + // size 16, align 16 + // L = PodU64 (size 8). + // Padding required to align data to 16 bytes is 8. (8 + 8 = 16) + // (10 items * 16 bytes/item) + 8 bytes for length + 8 bytes for padding = 176 + assert_eq!(ListView::::size_of(10).unwrap(), 176); + + // Case 3: 0 items with padding. + // Size should be size_of + padding. + // L = PodU32 (size 4), T = u64 (align 8). Padding is 4. + // Total size = 4 + 4 = 8 + assert_eq!(ListView::::size_of(0).unwrap(), 8); + } + + #[test] + fn test_size_of_overflow() { + // Case 1: Multiplication overflows. + // `size_of::() * usize::MAX` will overflow. + let err = ListView::::size_of(usize::MAX).unwrap_err(); + assert_eq!(err, PodSliceError::CalculationFailure.into()); + + // Case 2: Multiplication does not overflow, but subsequent addition does. + // `size_of::() * usize::MAX` does not overflow, but adding `size_of` will. + let err = ListView::::size_of(usize::MAX).unwrap_err(); + assert_eq!(err, PodSliceError::CalculationFailure.into()); + } + + #[test] + fn test_fails_with_non_aligned_length_type() { + // A custom `PodLength` type with an alignment of 4 + #[repr(C, align(4))] + #[derive(Debug, Copy, Clone, Zeroable, DerivePod)] + struct TestPodU32(u32); + + // Implement the traits for `PodLength` + impl From for usize { + fn from(val: TestPodU32) -> Self { + val.0 as usize + } + } + impl TryFrom for TestPodU32 { + type Error = PodSliceError; + fn try_from(val: usize) -> Result { + Ok(Self(u32::try_from(val)?)) + } + } + + let mut buf = [0u8; 100]; + + let err_size_of = ListView::::size_of(10).unwrap_err(); + assert_eq!(err_size_of, ProgramError::InvalidArgument); + + let err_unpack = ListView::::unpack(&buf).unwrap_err(); + assert_eq!(err_unpack, ProgramError::InvalidArgument); + + let err_init = ListView::::init(&mut buf).unwrap_err(); + assert_eq!(err_init, ProgramError::InvalidArgument); + } + + #[test] + fn test_padding_calculation() { + // `u8` has an alignment of 1, so no padding is ever needed. + assert_eq!(ListView::::header_padding().unwrap(), 0); + + // Zero-Sized Types like `()` have size 0 and align 1, requiring no padding. + assert_eq!(ListView::<(), PodU64>::header_padding().unwrap(), 0); + + // When length and data have the same alignment. + assert_eq!(ListView::::header_padding().unwrap(), 0); + assert_eq!(ListView::::header_padding().unwrap(), 0); + assert_eq!(ListView::::header_padding().unwrap(), 0); + + // When data alignment is smaller than or perfectly divides the length size. + assert_eq!(ListView::::header_padding().unwrap(), 0); // 8 % 2 = 0 + assert_eq!(ListView::::header_padding().unwrap(), 0); // 8 % 4 = 0 + + // When padding IS needed. + assert_eq!(ListView::::header_padding().unwrap(), 2); // size_of is 2. To align to 4, need 2 bytes. + assert_eq!(ListView::::header_padding().unwrap(), 6); // size_of is 2. To align to 8, need 6 bytes. + assert_eq!(ListView::::header_padding().unwrap(), 4); // size_of is 4. To align to 8, need 4 bytes. + + // Test with custom, higher alignments. + #[repr(C, align(8))] + #[derive(DerivePod, Zeroable, Copy, Clone)] + struct Align8(u64); + + // Test against different length types + assert_eq!(ListView::::header_padding().unwrap(), 6); // 2 + 6 = 8 + assert_eq!(ListView::::header_padding().unwrap(), 4); // 4 + 4 = 8 + assert_eq!(ListView::::header_padding().unwrap(), 0); // 8 is already aligned + + #[repr(C, align(16))] + #[derive(DerivePod, Zeroable, Copy, Clone)] + struct Align16(u128); + + assert_eq!(ListView::::header_padding().unwrap(), 14); // 2 + 14 = 16 + assert_eq!(ListView::::header_padding().unwrap(), 12); // 4 + 12 = 16 + assert_eq!(ListView::::header_padding().unwrap(), 8); // 8 + 8 = 16 + } + + #[test] + fn test_unpack_success_no_padding() { + // T = u32 (align 4), L = PodU32 (size 4, align 4). No padding needed. + let length: u32 = 2; + let capacity: usize = 3; + let item_size = size_of::(); + let len_size = size_of::(); + let buf_size = len_size + capacity * item_size; + let mut buf = vec![0u8; buf_size]; + + let pod_len: PodU32 = length.into(); + buf[0..len_size].copy_from_slice(bytemuck::bytes_of(&pod_len)); + + let data_start = len_size; + let items = [100u32, 200u32]; + let items_bytes = bytemuck::cast_slice(&items); + buf[data_start..(data_start + items_bytes.len())].copy_from_slice(items_bytes); + + let view_ro = ListView::::unpack(&buf).unwrap(); + assert_eq!(view_ro.len(), length as usize); + assert_eq!(view_ro.capacity(), capacity); + assert_eq!(*view_ro, items[..]); + + let view_mut = ListView::::unpack_mut(&mut buf).unwrap(); + assert_eq!(view_mut.len(), length as usize); + assert_eq!(view_mut.capacity(), capacity); + assert_eq!(*view_mut, items[..]); + } + + #[test] + fn test_unpack_success_with_padding() { + // T = u64 (align 8), L = PodU32 (size 4, align 4). Needs 4 bytes padding. + let padding = ListView::::header_padding().unwrap(); + assert_eq!(padding, 4); + + let length: u32 = 2; + let capacity: usize = 2; + let item_size = size_of::(); + let len_size = size_of::(); + let buf_size = len_size + padding + capacity * item_size; + let mut buf = vec![0u8; buf_size]; + + let pod_len: PodU32 = length.into(); + buf[0..len_size].copy_from_slice(bytemuck::bytes_of(&pod_len)); + + // Data starts after length and padding + let data_start = len_size + padding; + let items = [100u64, 200u64]; + let items_bytes = bytemuck::cast_slice(&items); + buf[data_start..(data_start + items_bytes.len())].copy_from_slice(items_bytes); + + let view_ro = ListView::::unpack(&buf).unwrap(); + assert_eq!(view_ro.len(), length as usize); + assert_eq!(view_ro.capacity(), capacity); + assert_eq!(*view_ro, items[..]); + + let view_mut = ListView::::unpack_mut(&mut buf).unwrap(); + assert_eq!(view_mut.len(), length as usize); + assert_eq!(view_mut.capacity(), capacity); + assert_eq!(*view_mut, items[..]); + } + + #[test] + fn test_unpack_success_zero_length() { + let capacity: usize = 5; + let item_size = size_of::(); + let len_size = size_of::(); + let buf_size = len_size + capacity * item_size; + let mut buf = vec![0u8; buf_size]; + + let pod_len: PodU32 = 0u32.into(); + buf[0..len_size].copy_from_slice(bytemuck::bytes_of(&pod_len)); + + let view_ro = ListView::::unpack(&buf).unwrap(); + assert_eq!(view_ro.len(), 0); + assert_eq!(view_ro.capacity(), capacity); + assert!(view_ro.is_empty()); + assert_eq!(&*view_ro, &[] as &[u32]); + + let view_mut = ListView::::unpack_mut(&mut buf).unwrap(); + assert_eq!(view_mut.len(), 0); + assert_eq!(view_mut.capacity(), capacity); + assert!(view_mut.is_empty()); + assert_eq!(&*view_mut, &[] as &[u32]); + } + + #[test] + fn test_unpack_success_full_capacity() { + let length: u64 = 3; + let capacity: usize = 3; + let item_size = size_of::(); + let len_size = size_of::(); + let buf_size = len_size + capacity * item_size; + let mut buf = vec![0u8; buf_size]; + + let pod_len: PodU64 = length.into(); + buf[0..len_size].copy_from_slice(bytemuck::bytes_of(&pod_len)); + + let data_start = len_size; + let items = [1u64, 2u64, 3u64]; + let items_bytes = bytemuck::cast_slice(&items); + buf[data_start..].copy_from_slice(items_bytes); + + let view_ro = ListView::::unpack(&buf).unwrap(); + assert_eq!(view_ro.len(), length as usize); + assert_eq!(view_ro.capacity(), capacity); + assert_eq!(*view_ro, items[..]); + + let view_mut = ListView::::unpack_mut(&mut buf).unwrap(); + assert_eq!(view_mut.len(), length as usize); + assert_eq!(view_mut.capacity(), capacity); + assert_eq!(*view_mut, items[..]); + } + + #[test] + fn test_unpack_fail_buffer_too_small_for_header() { + // T = u64 (align 8), L = PodU32 (size 4). Header size is 8. + let header_size = ListView::::size_of(0).unwrap(); + assert_eq!(header_size, 8); + + // Provide a buffer smaller than the required header + let mut buf = vec![0u8; header_size - 1]; // 7 bytes + + let err = ListView::::unpack(&buf).unwrap_err(); + assert_eq!(err, PodSliceError::BufferTooSmall.into()); + + let err = ListView::::unpack_mut(&mut buf).unwrap_err(); + assert_eq!(err, PodSliceError::BufferTooSmall.into()); + } + + #[test] + fn test_unpack_fail_declared_length_exceeds_capacity() { + let declared_length: u32 = 4; + let capacity: usize = 3; // buffer can only hold 3 + let item_size = size_of::(); + let len_size = size_of::(); + let buf_size = len_size + capacity * item_size; + let mut buf = vec![0u8; buf_size]; + + // Write a length that is bigger than capacity + let pod_len: PodU32 = declared_length.into(); + buf[0..len_size].copy_from_slice(bytemuck::bytes_of(&pod_len)); + + let err = ListView::::unpack(&buf).unwrap_err(); + assert_eq!(err, PodSliceError::BufferTooSmall.into()); + + let err = ListView::::unpack_mut(&mut buf).unwrap_err(); + assert_eq!(err, PodSliceError::BufferTooSmall.into()); + } + + #[test] + fn test_unpack_fail_data_part_not_multiple_of_item_size() { + let len_size = size_of::(); + + // data part is 5 bytes, not a multiple of item_size (4) + let buf_size = len_size + 5; + let mut buf = vec![0u8; buf_size]; + + // bytemuck::try_cast_slice returns an alignment error, which we map to InvalidArgument + + let err = ListView::::unpack(&buf).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); + + let err = ListView::::unpack_mut(&mut buf).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); + } + + #[test] + fn test_unpack_empty_buffer() { + let mut buf = []; + let err = ListView::::unpack(&buf).unwrap_err(); + assert_eq!(err, PodSliceError::BufferTooSmall.into()); + + let err = ListView::::unpack_mut(&mut buf).unwrap_err(); + assert_eq!(err, PodSliceError::BufferTooSmall.into()); + } + + #[test] + fn test_init_success_no_padding() { + // T = u32 (align 4), L = PodU32 (size 4). No padding needed. + let capacity: usize = 5; + let len_size = size_of::(); + let buf_size = ListView::::size_of(capacity).unwrap(); + let mut buf = vec![0xFFu8; buf_size]; // Pre-fill to ensure init zeroes it + + let view = ListView::::init(&mut buf).unwrap(); + + assert_eq!(view.len(), 0); + assert_eq!(view.capacity(), capacity); + assert!(view.is_empty()); + + // Check that the underlying buffer's length was actually zeroed + let length_bytes = &buf[0..len_size]; + assert_eq!(length_bytes, &[0u8; 4]); + } + + #[test] + fn test_init_success_with_padding() { + // T = u64 (align 8), L = PodU32 (size 4). Needs 4 bytes padding. + let capacity: usize = 3; + let len_size = size_of::(); + let buf_size = ListView::::size_of(capacity).unwrap(); + let mut buf = vec![0xFFu8; buf_size]; // Pre-fill to ensure init zeroes it + + let view = ListView::::init(&mut buf).unwrap(); + + assert_eq!(view.len(), 0); + assert_eq!(view.capacity(), capacity); + assert!(view.is_empty()); + + // Check that the underlying buffer's length was actually zeroed + let length_bytes = &buf[0..len_size]; + assert_eq!(length_bytes, &[0u8; 4]); + // The padding bytes may or may not be zeroed, we don't assert on them. + } + + #[test] + fn test_init_success_zero_capacity() { + // Test initializing a buffer that can only hold the header. + // T = u64 (align 8), L = PodU32 (size 4). Header size is 8. + let buf_size = ListView::::size_of(0).unwrap(); + assert_eq!(buf_size, 8); + let mut buf = vec![0xFFu8; buf_size]; + + let view = ListView::::init(&mut buf).unwrap(); + + assert_eq!(view.len(), 0); + assert_eq!(view.capacity(), 0); + assert!(view.is_empty()); + + // Check that the underlying buffer's length was actually zeroed + let len_size = size_of::(); + let length_bytes = &buf[0..len_size]; + assert_eq!(length_bytes, &[0u8; 4]); + } + + #[test] + fn test_init_fail_buffer_too_small() { + // Header requires 4 bytes (size_of) + let mut buf = vec![0u8; 3]; + let err = ListView::::init(&mut buf).unwrap_err(); + assert_eq!(err, PodSliceError::BufferTooSmall.into()); + + // With padding, header requires 8 bytes (4 for len, 4 for pad) + let mut buf_padded = vec![0u8; 7]; + let err_padded = ListView::::init(&mut buf_padded).unwrap_err(); + assert_eq!(err_padded, PodSliceError::BufferTooSmall.into()); + } + + #[test] + fn test_init_success_default_length_type() { + // This test uses the default L=PodU32 length type by omitting it. + // T = u32 (align 4), L = PodU32 (size 4). No padding needed as 4 % 4 == 0. + let capacity = 5; + let len_size = size_of::(); // Default L is PodU32 + let buf_size = ListView::::size_of(capacity).unwrap(); + let mut buf = vec![0xFFu8; buf_size]; // Pre-fill to ensure init zeroes it + + let view = ListView::::init(&mut buf).unwrap(); + + assert_eq!(view.len(), 0); + assert_eq!(view.capacity(), capacity); + assert!(view.is_empty()); + + // Check that the underlying buffer's length (a u32) was actually zeroed + let length_bytes = &buf[0..len_size]; + assert_eq!(length_bytes, &[0u8; 4]); + } + + macro_rules! test_list_view_for_length_type { + ($test_name:ident, $LengthType:ty) => { + #[test] + fn $test_name() { + type T = u64; + + let padding = ListView::::header_padding().unwrap(); + let length_usize = 2usize; + let capacity = 3; + + let item_size = size_of::(); + let len_size = size_of::<$LengthType>(); + let buf_size = len_size + padding + capacity * item_size; + let mut buf = vec![0u8; buf_size]; + + // Write length + let pod_len = <$LengthType>::try_from(length_usize).unwrap(); + buf[0..len_size].copy_from_slice(bytemuck::bytes_of(&pod_len)); + + // Write data + let data_start = len_size + padding; + let items = [1000 as T, 2000 as T]; + let items_bytes = bytemuck::cast_slice(&items); + buf[data_start..(data_start + items_bytes.len())].copy_from_slice(items_bytes); + + // Test read-only view + let view_ro = ListView::::unpack(&buf).unwrap(); + assert_eq!(view_ro.len(), length_usize); + assert_eq!(view_ro.capacity(), capacity); + assert_eq!(*view_ro, items[..]); + + // Test mutable view + let mut buf_mut = buf.clone(); + let view_mut = ListView::::unpack_mut(&mut buf_mut).unwrap(); + assert_eq!(view_mut.len(), length_usize); + assert_eq!(view_mut.capacity(), capacity); + assert_eq!(*view_mut, items[..]); + + // Test init + let mut init_buf = vec![0xFFu8; buf_size]; + let init_view = ListView::::init(&mut init_buf).unwrap(); + assert_eq!(init_view.len(), 0); + assert_eq!(init_view.capacity(), capacity); + assert_eq!(<$LengthType>::try_from(0usize).unwrap(), *init_view.length); + } + }; + } + + test_list_view_for_length_type!(list_view_with_pod_u16, PodU16); + test_list_view_for_length_type!(list_view_with_pod_u32, PodU32); + test_list_view_for_length_type!(list_view_with_pod_u64, PodU64); + test_list_view_for_length_type!(list_view_with_pod_u128, PodU128); +} diff --git a/pod/src/list/list_view_mut.rs b/pod/src/list/list_view_mut.rs new file mode 100644 index 00000000..4f0ca49f --- /dev/null +++ b/pod/src/list/list_view_mut.rs @@ -0,0 +1,426 @@ +//! `ListViewMut`, a mutable, compact, zero-copy array wrapper. + +use { + crate::{ + error::PodSliceError, list::list_trait::List, pod_length::PodLength, primitives::PodU32, + }, + bytemuck::Pod, + solana_program_error::ProgramError, + std::ops::{Deref, DerefMut}, +}; + +#[derive(Debug)] +pub struct ListViewMut<'data, T: Pod, L: PodLength = PodU32> { + pub(crate) length: &'data mut L, + pub(crate) data: &'data mut [T], + pub(crate) capacity: usize, +} + +impl ListViewMut<'_, T, L> { + /// Add another item to the slice + pub fn push(&mut self, item: T) -> Result<(), ProgramError> { + let length = (*self.length).into(); + if length >= self.capacity { + Err(PodSliceError::BufferTooSmall.into()) + } else { + self.data[length] = item; + *self.length = L::try_from(length.saturating_add(1))?; + Ok(()) + } + } + + /// Remove and return the element at `index`, shifting all later + /// elements one position to the left. + pub fn remove(&mut self, index: usize) -> Result { + let len = (*self.length).into(); + if index >= len { + return Err(ProgramError::InvalidArgument); + } + + let removed_item = self.data[index]; + + // Move the tail left by one + let tail_start = index + .checked_add(1) + .ok_or(ProgramError::ArithmeticOverflow)?; + self.data.copy_within(tail_start..len, index); + + // Store the new length (len - 1) + let new_len = len.checked_sub(1).unwrap(); + *self.length = L::try_from(new_len)?; + + Ok(removed_item) + } +} + +impl Deref for ListViewMut<'_, T, L> { + type Target = [T]; + + fn deref(&self) -> &Self::Target { + let len = (*self.length).into(); + &self.data[..len] + } +} + +impl DerefMut for ListViewMut<'_, T, L> { + fn deref_mut(&mut self) -> &mut Self::Target { + let len = (*self.length).into(); + &mut self.data[..len] + } +} + +impl List for ListViewMut<'_, T, L> { + type Item = T; + type Length = L; + + fn capacity(&self) -> usize { + self.capacity + } +} + +#[cfg(test)] +mod tests { + use { + super::*, + crate::{ + list::{List, ListView}, + primitives::{PodU16, PodU32, PodU64}, + }, + bytemuck_derive::{Pod, Zeroable}, + }; + + #[repr(C)] + #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Pod, Zeroable)] + struct TestStruct { + a: u64, + b: u32, + _padding: [u8; 4], + } + + impl TestStruct { + fn new(a: u64, b: u32) -> Self { + Self { + a, + b, + _padding: [0; 4], + } + } + } + + fn init_view_mut( + buffer: &mut Vec, + capacity: usize, + ) -> ListViewMut { + let size = ListView::::size_of(capacity).unwrap(); + buffer.resize(size, 0); + ListView::::init(buffer).unwrap() + } + + #[test] + fn test_push() { + let mut buffer = vec![]; + let mut view = init_view_mut::(&mut buffer, 3); + + assert_eq!(view.len(), 0); + assert!(view.is_empty()); + assert_eq!(view.capacity(), 3); + + // Push first item + let item1 = TestStruct::new(1, 10); + view.push(item1).unwrap(); + assert_eq!(view.len(), 1); + assert!(!view.is_empty()); + assert_eq!(*view, [item1]); + + // Push second item + let item2 = TestStruct::new(2, 20); + view.push(item2).unwrap(); + assert_eq!(view.len(), 2); + assert_eq!(*view, [item1, item2]); + + // Push third item to fill capacity + let item3 = TestStruct::new(3, 30); + view.push(item3).unwrap(); + assert_eq!(view.len(), 3); + assert_eq!(*view, [item1, item2, item3]); + + // Try to push beyond capacity + let item4 = TestStruct::new(4, 40); + let err = view.push(item4).unwrap_err(); + assert_eq!(err, PodSliceError::BufferTooSmall.into()); + + // Ensure state is unchanged + assert_eq!(view.len(), 3); + assert_eq!(*view, [item1, item2, item3]); + } + + #[test] + fn test_remove() { + let mut buffer = vec![]; + let mut view = init_view_mut::(&mut buffer, 4); + + let item1 = TestStruct::new(1, 10); + let item2 = TestStruct::new(2, 20); + let item3 = TestStruct::new(3, 30); + let item4 = TestStruct::new(4, 40); + view.push(item1).unwrap(); + view.push(item2).unwrap(); + view.push(item3).unwrap(); + view.push(item4).unwrap(); + + assert_eq!(view.len(), 4); + assert_eq!(*view, [item1, item2, item3, item4]); + + // Remove from the middle + let removed = view.remove(1).unwrap(); + assert_eq!(removed, item2); + assert_eq!(view.len(), 3); + assert_eq!(*view, [item1, item3, item4]); + + // Remove from the end + let removed = view.remove(2).unwrap(); + assert_eq!(removed, item4); + assert_eq!(view.len(), 2); + assert_eq!(*view, [item1, item3]); + + // Remove from the start + let removed = view.remove(0).unwrap(); + assert_eq!(removed, item1); + assert_eq!(view.len(), 1); + assert_eq!(*view, [item3]); + + // Remove the last element + let removed = view.remove(0).unwrap(); + assert_eq!(removed, item3); + assert_eq!(view.len(), 0); + assert!(view.is_empty()); + assert_eq!(*view, []); + } + + #[test] + fn test_remove_out_of_bounds() { + let mut buffer = vec![]; + let mut view = init_view_mut::(&mut buffer, 2); + + view.push(TestStruct::new(1, 10)).unwrap(); + view.push(TestStruct::new(2, 20)).unwrap(); + + // Try to remove at index == len + let err = view.remove(2).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); + assert_eq!(view.len(), 2); // Unchanged + + // Try to remove at index > len + let err = view.remove(100).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); + assert_eq!(view.len(), 2); // Unchanged + + // Empty the view + view.remove(1).unwrap(); + view.remove(0).unwrap(); + assert!(view.is_empty()); + + // Try to remove from empty view + let err = view.remove(0).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); + } + + #[test] + fn test_iter_mut() { + let mut buffer = vec![]; + let mut view = init_view_mut::(&mut buffer, 4); + + let item1 = TestStruct::new(1, 10); + let item2 = TestStruct::new(2, 20); + let item3 = TestStruct::new(3, 30); + view.push(item1).unwrap(); + view.push(item2).unwrap(); + view.push(item3).unwrap(); + + assert_eq!(view.len(), 3); + assert_eq!(view.capacity(), 4); + + // Modify items using iter_mut + for item in view.iter_mut() { + item.a *= 10; + } + + let expected_item1 = TestStruct::new(10, 10); + let expected_item2 = TestStruct::new(20, 20); + let expected_item3 = TestStruct::new(30, 30); + + // Check that the underlying data is modified + assert_eq!(view.len(), 3); + assert_eq!(*view, [expected_item1, expected_item2, expected_item3]); + + // Check that iter_mut only iterates over `len` items, not `capacity` + assert_eq!(view.iter_mut().count(), 3); + } + + #[test] + fn test_iter_mut_empty() { + let mut buffer = vec![]; + let mut view = init_view_mut::(&mut buffer, 5); + + let mut count = 0; + for _ in view.iter_mut() { + count += 1; + } + assert_eq!(count, 0); + assert_eq!(view.iter_mut().next(), None); + } + + #[test] + fn test_zero_capacity() { + let mut buffer = vec![]; + let mut view = init_view_mut::(&mut buffer, 0); + + assert_eq!(view.len(), 0); + assert_eq!(view.capacity(), 0); + assert!(view.is_empty()); + + let err = view.push(TestStruct::new(1, 1)).unwrap_err(); + assert_eq!(err, PodSliceError::BufferTooSmall.into()); + + let err = view.remove(0).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); + } + + #[test] + fn test_default_length_type() { + let capacity = 2; + let mut buffer = vec![]; + let size = ListView::::size_of(capacity).unwrap(); + buffer.resize(size, 0); + + // Initialize the view *without* specifying L. The compiler uses the default. + let view = ListView::::init(&mut buffer).unwrap(); + + // Check that the capacity is correct for a PodU64 length. + assert_eq!(view.capacity(), capacity); + assert_eq!(view.len(), 0); + + // Verify the size of the length field. + assert_eq!(size_of_val(view.length), size_of::()); + } + + #[test] + fn test_bytes_used_and_allocated_mut() { + // capacity 3, start empty + let mut buffer = vec![]; + let mut view = init_view_mut::(&mut buffer, 3); + + // Empty view + assert_eq!( + view.bytes_used().unwrap(), + ListView::::size_of(0).unwrap() + ); + assert_eq!( + view.bytes_allocated().unwrap(), + ListView::::size_of(view.capacity()).unwrap() + ); + + // After pushing elements + view.push(TestStruct::new(1, 2)).unwrap(); + view.push(TestStruct::new(3, 4)).unwrap(); + view.push(TestStruct::new(5, 6)).unwrap(); + assert_eq!( + view.bytes_used().unwrap(), + ListView::::size_of(3).unwrap() + ); + assert_eq!( + view.bytes_allocated().unwrap(), + ListView::::size_of(view.capacity()).unwrap() + ); + } + #[test] + fn test_get_and_get_mut() { + let mut buffer = vec![]; + let mut view = init_view_mut::(&mut buffer, 3); + + let item0 = TestStruct::new(1, 10); + let item1 = TestStruct::new(2, 20); + view.push(item0).unwrap(); + view.push(item1).unwrap(); + + // Test get() + assert_eq!(view.first(), Some(&item0)); + assert_eq!(view.get(1), Some(&item1)); + assert_eq!(view.get(2), None); // out of bounds + assert_eq!(view.get(100), None); // way out of bounds + + // Test get_mut() to modify an item + let modified_item0 = TestStruct::new(111, 110); + let item_ref = view.get_mut(0).unwrap(); + *item_ref = modified_item0; + + // Verify the modification + assert_eq!(view.first(), Some(&modified_item0)); + assert_eq!(*view, [modified_item0, item1]); + + // Test get_mut() out of bounds + assert_eq!(view.get_mut(2), None); + } + + #[test] + fn test_mutable_access_via_indexing() { + let mut buffer = vec![]; + let mut view = init_view_mut::(&mut buffer, 3); + + let item0 = TestStruct::new(1, 10); + let item1 = TestStruct::new(2, 20); + view.push(item0).unwrap(); + view.push(item1).unwrap(); + + assert_eq!(view.len(), 2); + + // Modify via the mutable slice + view[0].a = 99; + + let expected_item0 = TestStruct::new(99, 10); + assert_eq!(view.first(), Some(&expected_item0)); + assert_eq!(*view, [expected_item0, item1]); + } + + #[test] + fn test_sort_by() { + let mut buffer = vec![]; + let mut view = init_view_mut::(&mut buffer, 5); + + let item0 = TestStruct::new(5, 1); + let item1 = TestStruct::new(2, 2); + let item2 = TestStruct::new(5, 3); + let item3 = TestStruct::new(1, 4); + let item4 = TestStruct::new(2, 5); + + view.push(item0).unwrap(); + view.push(item1).unwrap(); + view.push(item2).unwrap(); + view.push(item3).unwrap(); + view.push(item4).unwrap(); + + // Sort by `b` field in descending order. + view.sort_by(|a, b| b.b.cmp(&a.b)); + let expected_order_by_b_desc = [ + item4, // b: 5 + item3, // b: 4 + item2, // b: 3 + item1, // b: 2 + item0, // b: 1 + ]; + assert_eq!(*view, expected_order_by_b_desc); + + // Now, sort by `a` in ascending order. A stable sort preserves the relative + // order of equal elements from the previous state of the list. + view.sort_by(|x, y| x.a.cmp(&y.a)); + + let expected_order_by_a_stable = [ + item3, // a: 1 + item4, // a: 2 (was before item1 in the previous state) + item1, // a: 2 + item2, // a: 5 (was before item0 in the previous state) + item0, // a: 5 + ]; + assert_eq!(*view, expected_order_by_a_stable); + } +} diff --git a/pod/src/list/list_view_read_only.rs b/pod/src/list/list_view_read_only.rs new file mode 100644 index 00000000..6d44379a --- /dev/null +++ b/pod/src/list/list_view_read_only.rs @@ -0,0 +1,202 @@ +//! `ListViewReadOnly`, a read-only, compact, zero-copy array wrapper. + +use { + crate::{list::list_trait::List, pod_length::PodLength, primitives::PodU32}, + bytemuck::Pod, + std::ops::Deref, +}; + +#[derive(Debug)] +pub struct ListViewReadOnly<'data, T: Pod, L: PodLength = PodU32> { + pub(crate) length: &'data L, + pub(crate) data: &'data [T], + pub(crate) capacity: usize, +} + +impl List for ListViewReadOnly<'_, T, L> { + type Item = T; + type Length = L; + + fn capacity(&self) -> usize { + self.capacity + } +} + +impl Deref for ListViewReadOnly<'_, T, L> { + type Target = [T]; + + fn deref(&self) -> &Self::Target { + let len = (*self.length).into(); + &self.data[..len] + } +} + +#[cfg(test)] +mod tests { + use { + super::*, + crate::{ + list::ListView, + pod_length::PodLength, + primitives::{PodU32, PodU64}, + }, + bytemuck_derive::{Pod as DerivePod, Zeroable}, + std::mem::size_of, + }; + + #[repr(C, align(16))] + #[derive(DerivePod, Zeroable, Copy, Clone, Debug, PartialEq)] + struct TestStruct(u128); + + /// Helper to build a byte buffer that conforms to the `ListView` layout. + fn build_test_buffer( + length: usize, + capacity: usize, + items: &[T], + ) -> Vec { + let size = ListView::::size_of(capacity).unwrap(); + let mut buffer = vec![0u8; size]; + + // Write the length prefix + let pod_len = L::try_from(length).unwrap(); + let len_bytes = bytemuck::bytes_of(&pod_len); + buffer[0..size_of::()].copy_from_slice(len_bytes); + + // Write the data items, accounting for padding + if !items.is_empty() { + let data_start = ListView::::size_of(0).unwrap(); + let items_bytes = bytemuck::cast_slice(items); + buffer[data_start..data_start.saturating_add(items_bytes.len())] + .copy_from_slice(items_bytes); + } + + buffer + } + + #[test] + fn test_len_and_capacity() { + let items = [10u32, 20, 30]; + let buffer = build_test_buffer::(items.len(), 5, &items); + let view = ListView::::unpack(&buffer).unwrap(); + + assert_eq!(view.len(), 3); + assert_eq!(view.capacity(), 5); + } + + #[test] + fn test_as_slice() { + let items = [10u32, 20, 30]; + // Buffer has capacity for 5, but we only use 3. + let buffer = build_test_buffer::(items.len(), 5, &items); + let view = ListView::::unpack(&buffer).unwrap(); + + // `as_slice()` should only return the first `len` items. + assert_eq!(*view, items[..]); + } + + #[test] + fn test_is_empty() { + // Not empty + let buffer_full = build_test_buffer::(1, 2, &[10]); + let view_full = ListView::::unpack(&buffer_full).unwrap(); + assert!(!view_full.is_empty()); + + // Empty + let buffer_empty = build_test_buffer::(0, 2, &[]); + let view_empty = ListView::::unpack(&buffer_empty).unwrap(); + assert!(view_empty.is_empty()); + } + + #[test] + fn test_iter() { + let items = [TestStruct(1), TestStruct(2)]; + let buffer = build_test_buffer::(items.len(), 3, &items); + let view = ListView::::unpack(&buffer).unwrap(); + + let mut iter = view.iter(); + assert_eq!(iter.next(), Some(&items[0])); + assert_eq!(iter.next(), Some(&items[1])); + assert_eq!(iter.next(), None); + let collected: Vec<_> = view.iter().collect(); + assert_eq!(collected, vec![&items[0], &items[1]]); + } + + #[test] + fn test_iter_on_empty_list() { + let buffer = build_test_buffer::(0, 5, &[]); + let view = ListView::::unpack(&buffer).unwrap(); + + assert_eq!(view.iter().count(), 0); + assert_eq!(view.iter().next(), None); + } + + #[test] + fn test_zero_capacity() { + // Buffer is just big enough for the header (len + padding), no data. + let buffer = build_test_buffer::(0, 0, &[]); + let view = ListView::::unpack(&buffer).unwrap(); + + assert_eq!(view.len(), 0); + assert_eq!(view.capacity(), 0); + assert!(view.is_empty()); + assert_eq!(*view, []); + } + + #[test] + fn test_with_padding() { + // Test the effect of padding by checking the total header size. + // T=AlignedStruct (align 16), L=PodU32 (size 4). + // The header size should be 16 (4 for len + 12 for padding). + let header_size = ListView::::size_of(0).unwrap(); + assert_eq!(header_size, 16); + + let items = [TestStruct(123), TestStruct(456)]; + let buffer = build_test_buffer::(items.len(), 4, &items); + let view = ListView::::unpack(&buffer).unwrap(); + + // Check if the public API works as expected despite internal padding + assert_eq!(view.len(), 2); + assert_eq!(view.capacity(), 4); + assert_eq!(*view, items[..]); + } + + #[test] + fn test_bytes_used_and_allocated() { + // 3 live elements, capacity 5 + let items = [10u32, 20, 30]; + let capacity = 5; + let buffer = build_test_buffer::(items.len(), capacity, &items); + let view = ListView::::unpack(&buffer).unwrap(); + + let expected_used = ListView::::size_of(view.len()).unwrap(); + let expected_cap = ListView::::size_of(view.capacity()).unwrap(); + + assert_eq!(view.bytes_used().unwrap(), expected_used); + assert_eq!(view.bytes_allocated().unwrap(), expected_cap); + } + + #[test] + fn test_get() { + let items = [10u32, 20, 30]; + let buffer = build_test_buffer::(items.len(), 5, &items); + let view = ListView::::unpack(&buffer).unwrap(); + + // Get in-bounds elements + assert_eq!(view.first(), Some(&10u32)); + assert_eq!(view.get(1), Some(&20u32)); + assert_eq!(view.get(2), Some(&30u32)); + + // Get out-of-bounds element (index == len) + assert_eq!(view.get(3), None); + + // Get way out-of-bounds + assert_eq!(view.get(100), None); + } + + #[test] + fn test_get_on_empty_list() { + let buffer = build_test_buffer::(0, 5, &[]); + let view = ListView::::unpack(&buffer).unwrap(); + assert_eq!(view.first(), None); + } +} diff --git a/pod/src/list/mod.rs b/pod/src/list/mod.rs new file mode 100644 index 00000000..56062237 --- /dev/null +++ b/pod/src/list/mod.rs @@ -0,0 +1,9 @@ +mod list_trait; +mod list_view; +mod list_view_mut; +mod list_view_read_only; + +pub use { + list_trait::List, list_view::ListView, list_view_mut::ListViewMut, + list_view_read_only::ListViewReadOnly, +}; diff --git a/pod/src/option.rs b/pod/src/option.rs index 0a080ecc..02d7edd0 100644 --- a/pod/src/option.rs +++ b/pod/src/option.rs @@ -6,28 +6,18 @@ //! [`Option`](https://doc.rust-lang.org/std/num/type.NonZeroU64.html) //! and provide the same memory layout optimization. -#[cfg(feature = "bytemuck")] -use bytemuck::{Pod, Zeroable}; -#[cfg(feature = "serde")] -use serde::{Deserialize, Deserializer, Serialize, Serializer}; -#[cfg(feature = "wincode")] -use wincode_derive::{SchemaRead, SchemaWrite}; -#[cfg(feature = "borsh")] use { - alloc::format, - borsh::{BorshDeserialize, BorshSchema, BorshSerialize}, -}; -use { - solana_address::{Address, ADDRESS_BYTES}, + bytemuck::{Pod, Zeroable}, solana_program_error::ProgramError, solana_program_option::COption, + solana_pubkey::{Pubkey, PUBKEY_BYTES}, }; /// Trait for types that can be `None`. /// /// This trait is used to indicate that a type can be `None` according to a /// specific value. -pub trait Nullable: PartialEq + Sized { +pub trait Nullable: PartialEq + Pod + Sized { /// Value that represents `None` for the type. const NONE: Self; @@ -48,11 +38,6 @@ pub trait Nullable: PartialEq + Sized { /// This can be used when a specific value of `T` indicates that its /// value is `None`. #[repr(transparent)] -#[cfg_attr( - feature = "borsh", - derive(BorshDeserialize, BorshSerialize, BorshSchema) -)] -#[cfg_attr(feature = "wincode", derive(SchemaRead, SchemaWrite))] #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] pub struct PodOption(T); @@ -92,39 +77,19 @@ impl PodOption { Some(&mut self.0) } } - - /// Maps a `PodOption` to an `Option` by copying the contents of the option. - #[inline] - pub fn copied(&self) -> Option - where - T: Copy, - { - self.as_ref().copied() - } - - /// Maps a `PodOption` to an `Option` by cloning the contents of the option. - #[inline] - pub fn cloned(&self) -> Option - where - T: Clone, - { - self.as_ref().cloned() - } } /// ## Safety /// /// `PodOption` is a transparent wrapper around a `Pod` type `T` with identical /// data representation. -#[cfg(feature = "bytemuck")] -unsafe impl Pod for PodOption {} +unsafe impl Pod for PodOption {} /// ## Safety /// /// `PodOption` is a transparent wrapper around a `Pod` type `T` with identical /// data representation. -#[cfg(feature = "bytemuck")] -unsafe impl Zeroable for PodOption {} +unsafe impl Zeroable for PodOption {} impl From for PodOption { fn from(value: T) -> Self { @@ -132,22 +97,6 @@ impl From for PodOption { } } -impl From> for Option { - fn from(value: PodOption) -> Self { - value.get() - } -} - -impl From> for COption { - fn from(value: PodOption) -> Self { - if value.0.is_none() { - COption::None - } else { - COption::Some(value.0) - } - } -} - impl TryFrom> for PodOption { type Error = ProgramError; @@ -172,310 +121,68 @@ impl TryFrom> for PodOption { } } -/// Implementation of `Nullable` for `Address`. -impl Nullable for Address { - const NONE: Self = Address::new_from_array([0u8; ADDRESS_BYTES]); -} - -#[cfg(feature = "serde")] -impl Serialize for PodOption -where - T: Nullable + Serialize, -{ - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - if self.0.is_none() { - serializer.serialize_none() - } else { - serializer.serialize_some(&self.0) - } - } -} - -#[cfg(feature = "serde")] -impl<'de, T> Deserialize<'de> for PodOption -where - T: Nullable + Deserialize<'de>, -{ - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - let option = Option::::deserialize(deserializer)?; - match option { - Some(value) if value.is_none() => Err(serde::de::Error::custom( - "Invalid PodOption encoding: Some(value) cannot equal the none marker.", - )), - Some(value) => Ok(PodOption(value)), - None => Ok(PodOption(T::NONE)), - } - } +/// Implementation of `Nullable` for `Pubkey`. +impl Nullable for Pubkey { + const NONE: Self = Pubkey::new_from_array([0u8; PUBKEY_BYTES]); } #[cfg(test)] mod tests { - use super::*; - - const ID: Address = Address::new_from_array([8; ADDRESS_BYTES]); + use {super::*, crate::bytemuck::pod_slice_from_bytes}; + const ID: Pubkey = Pubkey::from_str_const("TestSysvar111111111111111111111111111111111"); #[test] - fn test_try_from_option() { - let some_address = Some(ID); - assert_eq!(PodOption::try_from(some_address).unwrap(), PodOption(ID)); + fn test_pod_option_pubkey() { + let some_pubkey = PodOption::from(ID); + assert_eq!(some_pubkey.get(), Some(ID)); - let none_address = None; - assert_eq!( - PodOption::try_from(none_address).unwrap(), - PodOption::from(Address::NONE) - ); + let none_pubkey = PodOption::from(Pubkey::default()); + assert_eq!(none_pubkey.get(), None); - let invalid_option = Some(Address::NONE); - let err = PodOption::try_from(invalid_option).unwrap_err(); - assert_eq!(err, ProgramError::InvalidArgument); - } - - #[test] - fn test_try_from_coption_reject_some_zero_address() { - let invalid_option = COption::Some(Address::NONE); - let err = PodOption::try_from(invalid_option).unwrap_err(); - assert_eq!(err, ProgramError::InvalidArgument); - } + let mut data = Vec::with_capacity(64); + data.extend_from_slice(ID.as_ref()); + data.extend_from_slice(&[0u8; 32]); - #[test] - fn test_from_pod_option() { - let some = PodOption::from(ID); - let none = PodOption::from(Address::NONE); + let values = pod_slice_from_bytes::>(&data).unwrap(); + assert_eq!(values[0], PodOption::from(ID)); + assert_eq!(values[1], PodOption::from(Pubkey::default())); - assert_eq!(Option::
::from(some), Some(ID)); - assert_eq!(Option::
::from(none), None); - assert_eq!(COption::
::from(some), COption::Some(ID)); - assert_eq!(COption::
::from(none), COption::None); - } + let option_pubkey = Some(ID); + let pod_option_pubkey: PodOption = option_pubkey.try_into().unwrap(); + assert_eq!(pod_option_pubkey, PodOption::from(ID)); + assert_eq!( + pod_option_pubkey, + PodOption::try_from(option_pubkey).unwrap() + ); - #[test] - fn test_default() { - let def = PodOption::
::default(); - assert_eq!(def, None.try_into().unwrap()); + let coption_pubkey = COption::Some(ID); + let pod_option_pubkey: PodOption = coption_pubkey.try_into().unwrap(); + assert_eq!(pod_option_pubkey, PodOption::from(ID)); + assert_eq!( + pod_option_pubkey, + PodOption::try_from(coption_pubkey).unwrap() + ); } #[test] - fn test_copied() { - let some_address = PodOption::from(ID); - assert_eq!(some_address.copied(), Some(ID)); - - let none_address = PodOption::from(Address::NONE); - assert_eq!(none_address.copied(), None); - } + fn test_try_from_option() { + let some_pubkey = Some(ID); + assert_eq!(PodOption::try_from(some_pubkey).unwrap(), PodOption(ID)); - #[test] - fn test_as_mut() { - let mut some = PodOption::from(Address::new_from_array([3; ADDRESS_BYTES])); - assert!(some.as_mut().is_some()); - *some.as_mut().unwrap() = Address::new_from_array([4; ADDRESS_BYTES]); + let none_pubkey = None; assert_eq!( - some.get(), - Some(Address::new_from_array([4; ADDRESS_BYTES])) + PodOption::try_from(none_pubkey).unwrap(), + PodOption::from(Pubkey::NONE) ); - let mut none = PodOption::from(Address::NONE); - assert!(none.as_mut().is_none()); - } - - #[derive(Clone, Debug, PartialEq)] - struct TestNonCopyNullable([u8; 4]); - - impl Nullable for TestNonCopyNullable { - const NONE: Self = Self([0u8; 4]); - } - - impl Nullable for u64 { - const NONE: Self = 0; + let invalid_option = Some(Pubkey::NONE); + let err = PodOption::try_from(invalid_option).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); } #[test] - fn test_cloned_with_non_copy_nullable() { - let some = PodOption::from(TestNonCopyNullable([1, 2, 3, 4])); - assert_eq!(some.cloned(), Some(TestNonCopyNullable([1, 2, 3, 4]))); - - let none = PodOption::from(TestNonCopyNullable::NONE); - assert_eq!(none.cloned(), None); - } - - #[cfg(feature = "borsh")] - mod borsh_tests { - use {super::*, alloc::vec}; - - #[test] - fn test_borsh_roundtrip_and_encoding() { - let some = PodOption::from(Address::new_from_array([1; ADDRESS_BYTES])); - let none = PodOption::from(Address::NONE); - - let some_bytes = borsh::to_vec(&some).unwrap(); - let none_bytes = borsh::to_vec(&none).unwrap(); - - assert_eq!(some_bytes, vec![1; ADDRESS_BYTES]); - assert_eq!(none_bytes, vec![0; ADDRESS_BYTES]); - assert_eq!( - borsh::from_slice::>(&some_bytes).unwrap(), - some - ); - assert_eq!( - borsh::from_slice::>(&none_bytes).unwrap(), - none - ); - assert!(borsh::from_slice::>(&[]).is_err()); - } - } - - #[cfg(feature = "wincode")] - mod wincode_tests { - use super::*; - - #[test] - fn test_wincode_pod_option_roundtrip_and_size() { - let some = PodOption::from(9u64); - let none = PodOption::from(0u64); - - let some_bytes = wincode::serialize(&some).unwrap(); - let none_bytes = wincode::serialize(&none).unwrap(); - - assert_eq!(some_bytes.len(), core::mem::size_of::()); - assert_eq!(none_bytes.len(), core::mem::size_of::()); - assert_eq!(some_bytes.as_slice(), &9u64.to_le_bytes()); - assert_eq!(none_bytes.as_slice(), &0u64.to_le_bytes()); - - let some_roundtrip: PodOption = wincode::deserialize(&some_bytes).unwrap(); - let none_roundtrip: PodOption = wincode::deserialize(&none_bytes).unwrap(); - assert_eq!(some_roundtrip, some); - assert_eq!(none_roundtrip, none); - } - - #[test] - fn test_wincode_pod_option_rejects_truncated_input() { - assert!(wincode::deserialize::>(&[]).is_err()); - assert!(wincode::deserialize::>(&[0; 7]).is_err()); - } - } - - #[cfg(feature = "serde")] - mod serde_tests { - use {super::*, alloc::string::ToString}; - - #[test] - fn test_serde_some() { - let some = PodOption::from(Address::new_from_array([1; ADDRESS_BYTES])); - let serialized = serde_json::to_string(&some).unwrap(); - assert_eq!( - &serialized, - "[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]" - ); - let deserialized = serde_json::from_str::>(&serialized).unwrap(); - assert_eq!(some, deserialized); - } - - #[test] - fn test_serde_none() { - let none = PodOption::from(Address::new_from_array([0; ADDRESS_BYTES])); - let serialized = serde_json::to_string(&none).unwrap(); - assert_eq!(&serialized, "null"); - let deserialized = serde_json::from_str::>(&serialized).unwrap(); - assert_eq!(none, deserialized); - } - - #[test] - fn test_serde_reject_zero_address_bytes() { - let zero_bytes = "[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]"; - assert!(serde_json::from_str::>(zero_bytes).is_err()); - } - - #[test] - fn test_serde_reject_invalid_address_string() { - assert!(serde_json::from_str::>("\"not_an_address\"").is_err()); - } - - #[test] - fn test_serde_u64_some() { - let some = PodOption::from(7u64); - let serialized = serde_json::to_string(&some).unwrap(); - assert_eq!(serialized, "7"); - let deserialized = serde_json::from_str::>(&serialized).unwrap(); - assert_eq!(deserialized, some); - } - - #[test] - fn test_serde_u64_none() { - let deserialized = serde_json::from_str::>("null").unwrap(); - assert_eq!(deserialized, PodOption::from(0)); - } - - #[test] - fn test_serde_u64_none_marker_error_message() { - let err = serde_json::from_str::>("0").unwrap_err(); - let message = err.to_string(); - assert!(message.contains("PodOption encoding")); - assert!(message.contains("none marker")); - } - - #[test] - fn test_serde_u64_reject_invalid_input() { - assert!(serde_json::from_str::>("\"abc\"").is_err()); - assert!(serde_json::from_str::>("{}").is_err()); - } - } - - #[cfg(feature = "bytemuck")] - mod bytemuck_tests { - use { - super::*, - crate::bytemuck::{pod_from_bytes, pod_slice_from_bytes}, - alloc::vec::Vec, - }; - - #[test] - fn test_pod_option_address() { - let some_address = PodOption::from(ID); - assert_eq!(some_address.get(), Some(ID)); - - let none_address = PodOption::from(Address::default()); - assert_eq!(none_address.get(), None); - - let mut data = Vec::with_capacity(64); - data.extend_from_slice(ID.as_ref()); - data.extend_from_slice(&[0u8; 32]); - - let values = pod_slice_from_bytes::>(&data).unwrap(); - assert_eq!(values[0], PodOption::from(ID)); - assert_eq!(values[1], PodOption::from(Address::default())); - } - - #[test] - fn test_pod_from_bytes() { - assert_eq!( - Option::
::from( - *pod_from_bytes::>(&[1; ADDRESS_BYTES]).unwrap() - ), - Some(Address::new_from_array([1; ADDRESS_BYTES])), - ); - assert_eq!( - Option::
::from( - *pod_from_bytes::>(&[0; ADDRESS_BYTES]).unwrap() - ), - None, - ); - assert_eq!( - pod_from_bytes::>(&[]).unwrap_err(), - ProgramError::InvalidArgument - ); - assert_eq!( - pod_from_bytes::>(&[0; 1]).unwrap_err(), - ProgramError::InvalidArgument - ); - assert_eq!( - pod_from_bytes::>(&[1; 1]).unwrap_err(), - ProgramError::InvalidArgument - ); - } + fn test_default() { + let def = PodOption::::default(); + assert_eq!(def, None.try_into().unwrap()); } } diff --git a/pod/src/optional_keys.rs b/pod/src/optional_keys.rs new file mode 100644 index 00000000..82a0726e --- /dev/null +++ b/pod/src/optional_keys.rs @@ -0,0 +1,360 @@ +//! Optional pubkeys that can be used a `Pod`s +#[cfg(feature = "borsh")] +use borsh::{BorshDeserialize, BorshSchema, BorshSerialize}; +use { + bytemuck_derive::{Pod, Zeroable}, + solana_program_error::ProgramError, + solana_program_option::COption, + solana_pubkey::Pubkey, + solana_zk_sdk::encryption::pod::elgamal::PodElGamalPubkey, +}; +#[cfg(feature = "serde-traits")] +use { + serde::de::{Error, Unexpected, Visitor}, + serde::{Deserialize, Deserializer, Serialize, Serializer}, + std::{convert::TryFrom, fmt, str::FromStr}, +}; + +/// A Pubkey that encodes `None` as all `0`, meant to be usable as a `Pod` type, +/// similar to all `NonZero*` number types from the `bytemuck` library. +#[cfg_attr( + feature = "borsh", + derive(BorshDeserialize, BorshSerialize, BorshSchema) +)] +#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)] +#[repr(transparent)] +pub struct OptionalNonZeroPubkey(pub Pubkey); +impl TryFrom> for OptionalNonZeroPubkey { + type Error = ProgramError; + fn try_from(p: Option) -> Result { + match p { + None => Ok(Self(Pubkey::default())), + Some(pubkey) => { + if pubkey == Pubkey::default() { + Err(ProgramError::InvalidArgument) + } else { + Ok(Self(pubkey)) + } + } + } + } +} +impl TryFrom> for OptionalNonZeroPubkey { + type Error = ProgramError; + fn try_from(p: COption) -> Result { + match p { + COption::None => Ok(Self(Pubkey::default())), + COption::Some(pubkey) => { + if pubkey == Pubkey::default() { + Err(ProgramError::InvalidArgument) + } else { + Ok(Self(pubkey)) + } + } + } + } +} +impl From for Option { + fn from(p: OptionalNonZeroPubkey) -> Self { + if p.0 == Pubkey::default() { + None + } else { + Some(p.0) + } + } +} +impl From for COption { + fn from(p: OptionalNonZeroPubkey) -> Self { + if p.0 == Pubkey::default() { + COption::None + } else { + COption::Some(p.0) + } + } +} + +#[cfg(feature = "serde-traits")] +impl Serialize for OptionalNonZeroPubkey { + fn serialize(&self, s: S) -> Result + where + S: Serializer, + { + if self.0 == Pubkey::default() { + s.serialize_none() + } else { + s.serialize_some(&self.0.to_string()) + } + } +} + +#[cfg(feature = "serde-traits")] +/// Visitor for deserializing `OptionalNonZeroPubkey` +struct OptionalNonZeroPubkeyVisitor; + +#[cfg(feature = "serde-traits")] +impl Visitor<'_> for OptionalNonZeroPubkeyVisitor { + type Value = OptionalNonZeroPubkey; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a Pubkey in base58 or `null`") + } + + fn visit_str(self, v: &str) -> Result + where + E: Error, + { + let pkey = Pubkey::from_str(v) + .map_err(|_| Error::invalid_value(Unexpected::Str(v), &"value string"))?; + + OptionalNonZeroPubkey::try_from(Some(pkey)) + .map_err(|_| Error::custom("Failed to convert from pubkey")) + } + + fn visit_unit(self) -> Result + where + E: Error, + { + OptionalNonZeroPubkey::try_from(None).map_err(|e| Error::custom(e.to_string())) + } +} + +#[cfg(feature = "serde-traits")] +impl<'de> Deserialize<'de> for OptionalNonZeroPubkey { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_any(OptionalNonZeroPubkeyVisitor) + } +} + +/// An `ElGamalPubkey` that encodes `None` as all `0`, meant to be usable as a +/// `Pod` type. +#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)] +#[repr(transparent)] +pub struct OptionalNonZeroElGamalPubkey(PodElGamalPubkey); +impl OptionalNonZeroElGamalPubkey { + /// Checks equality between an `OptionalNonZeroElGamalPubkey` and an + /// `ElGamalPubkey` when interpreted as bytes. + pub fn equals(&self, other: &PodElGamalPubkey) -> bool { + &self.0 == other + } +} +impl TryFrom> for OptionalNonZeroElGamalPubkey { + type Error = ProgramError; + fn try_from(p: Option) -> Result { + match p { + None => Ok(Self(PodElGamalPubkey::default())), + Some(elgamal_pubkey) => { + if elgamal_pubkey == PodElGamalPubkey::default() { + Err(ProgramError::InvalidArgument) + } else { + Ok(Self(elgamal_pubkey)) + } + } + } + } +} +impl From for Option { + fn from(p: OptionalNonZeroElGamalPubkey) -> Self { + if p.0 == PodElGamalPubkey::default() { + None + } else { + Some(p.0) + } + } +} + +#[cfg(feature = "serde-traits")] +impl Serialize for OptionalNonZeroElGamalPubkey { + fn serialize(&self, s: S) -> Result + where + S: Serializer, + { + if self.0 == PodElGamalPubkey::default() { + s.serialize_none() + } else { + s.serialize_some(&self.0.to_string()) + } + } +} + +#[cfg(feature = "serde-traits")] +struct OptionalNonZeroElGamalPubkeyVisitor; + +#[cfg(feature = "serde-traits")] +impl Visitor<'_> for OptionalNonZeroElGamalPubkeyVisitor { + type Value = OptionalNonZeroElGamalPubkey; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("an ElGamal public key as base64 or `null`") + } + + fn visit_str(self, v: &str) -> Result + where + E: Error, + { + let elgamal_pubkey: PodElGamalPubkey = FromStr::from_str(v).map_err(Error::custom)?; + OptionalNonZeroElGamalPubkey::try_from(Some(elgamal_pubkey)).map_err(Error::custom) + } + + fn visit_unit(self) -> Result + where + E: Error, + { + Ok(OptionalNonZeroElGamalPubkey::default()) + } +} + +#[cfg(feature = "serde-traits")] +impl<'de> Deserialize<'de> for OptionalNonZeroElGamalPubkey { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_any(OptionalNonZeroElGamalPubkeyVisitor) + } +} + +#[cfg(test)] +mod tests { + use { + super::*, + crate::bytemuck::pod_from_bytes, + base64::{prelude::BASE64_STANDARD, Engine}, + solana_pubkey::PUBKEY_BYTES, + }; + + #[test] + fn test_pod_non_zero_option() { + assert_eq!( + Some(Pubkey::new_from_array([1; PUBKEY_BYTES])), + Option::::from( + *pod_from_bytes::(&[1; PUBKEY_BYTES]).unwrap() + ) + ); + assert_eq!( + None, + Option::::from( + *pod_from_bytes::(&[0; PUBKEY_BYTES]).unwrap() + ) + ); + assert_eq!( + pod_from_bytes::(&[]).unwrap_err(), + ProgramError::InvalidArgument + ); + assert_eq!( + pod_from_bytes::(&[0; 1]).unwrap_err(), + ProgramError::InvalidArgument + ); + assert_eq!( + pod_from_bytes::(&[1; 1]).unwrap_err(), + ProgramError::InvalidArgument + ); + } + + #[cfg(feature = "serde-traits")] + #[test] + fn test_pod_non_zero_option_serde_some() { + let optional_non_zero_pubkey_some = + OptionalNonZeroPubkey(Pubkey::new_from_array([1; PUBKEY_BYTES])); + let serialized_some = serde_json::to_string(&optional_non_zero_pubkey_some).unwrap(); + assert_eq!( + &serialized_some, + "\"4vJ9JU1bJJE96FWSJKvHsmmFADCg4gpZQff4P3bkLKi\"" + ); + + let deserialized_some = + serde_json::from_str::(&serialized_some).unwrap(); + assert_eq!(optional_non_zero_pubkey_some, deserialized_some); + } + + #[cfg(feature = "serde-traits")] + #[test] + fn test_pod_non_zero_option_serde_none() { + let optional_non_zero_pubkey_none = + OptionalNonZeroPubkey(Pubkey::new_from_array([0; PUBKEY_BYTES])); + let serialized_none = serde_json::to_string(&optional_non_zero_pubkey_none).unwrap(); + assert_eq!(&serialized_none, "null"); + + let deserialized_none = + serde_json::from_str::(&serialized_none).unwrap(); + assert_eq!(optional_non_zero_pubkey_none, deserialized_none); + } + + const OPTIONAL_NONZERO_ELGAMAL_PUBKEY_LEN: usize = 32; + + // Unfortunately, the `solana-zk-sdk` does not expose a constructor interface + // to construct `PodRistrettoPoint` from bytes. As a work-around, encode the + // bytes as base64 string and then convert the string to a + // `PodElGamalCiphertext`. + // + // The constructor will be added (and this function removed) with + // `solana-zk-sdk` 2.1. + fn elgamal_pubkey_from_bytes(bytes: &[u8]) -> PodElGamalPubkey { + let string = BASE64_STANDARD.encode(bytes); + std::str::FromStr::from_str(&string).unwrap() + } + + #[test] + fn test_pod_non_zero_elgamal_option() { + assert_eq!( + Some(elgamal_pubkey_from_bytes( + &[1; OPTIONAL_NONZERO_ELGAMAL_PUBKEY_LEN] + )), + Option::::from(OptionalNonZeroElGamalPubkey( + elgamal_pubkey_from_bytes(&[1; OPTIONAL_NONZERO_ELGAMAL_PUBKEY_LEN]) + )) + ); + assert_eq!( + None, + Option::::from(OptionalNonZeroElGamalPubkey( + elgamal_pubkey_from_bytes(&[0; OPTIONAL_NONZERO_ELGAMAL_PUBKEY_LEN]) + )) + ); + + assert_eq!( + OptionalNonZeroElGamalPubkey(elgamal_pubkey_from_bytes( + &[1; OPTIONAL_NONZERO_ELGAMAL_PUBKEY_LEN] + )), + *pod_from_bytes::( + &[1; OPTIONAL_NONZERO_ELGAMAL_PUBKEY_LEN] + ) + .unwrap() + ); + assert!(pod_from_bytes::(&[]).is_err()); + } + + #[cfg(feature = "serde-traits")] + #[test] + fn test_pod_non_zero_elgamal_option_serde_some() { + let optional_non_zero_elgamal_pubkey_some = OptionalNonZeroElGamalPubkey( + elgamal_pubkey_from_bytes(&[1; OPTIONAL_NONZERO_ELGAMAL_PUBKEY_LEN]), + ); + let serialized_some = + serde_json::to_string(&optional_non_zero_elgamal_pubkey_some).unwrap(); + assert_eq!( + &serialized_some, + "\"AQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQE=\"" + ); + + let deserialized_some = + serde_json::from_str::(&serialized_some).unwrap(); + assert_eq!(optional_non_zero_elgamal_pubkey_some, deserialized_some); + } + + #[cfg(feature = "serde-traits")] + #[test] + fn test_pod_non_zero_elgamal_option_serde_none() { + let optional_non_zero_elgamal_pubkey_none = OptionalNonZeroElGamalPubkey( + elgamal_pubkey_from_bytes(&[0; OPTIONAL_NONZERO_ELGAMAL_PUBKEY_LEN]), + ); + let serialized_none = + serde_json::to_string(&optional_non_zero_elgamal_pubkey_none).unwrap(); + assert_eq!(&serialized_none, "null"); + + let deserialized_none = + serde_json::from_str::(&serialized_none).unwrap(); + assert_eq!(optional_non_zero_elgamal_pubkey_none, deserialized_none); + } +} diff --git a/pod/src/pod_length.rs b/pod/src/pod_length.rs new file mode 100644 index 00000000..0f46b574 --- /dev/null +++ b/pod/src/pod_length.rs @@ -0,0 +1,41 @@ +use { + crate::{ + error::PodSliceError, + primitives::{PodU128, PodU16, PodU32, PodU64}, + }, + bytemuck::Pod, +}; + +/// Marker trait for converting to/from Pod `uint`'s and `usize` +pub trait PodLength: Pod + Into + TryFrom {} + +/// Blanket implementation to automatically implement `PodLength` for any type +/// that satisfies the required bounds. +impl PodLength for T where T: Pod + Into + TryFrom {} + +/// Implements the `TryFrom` and `From for usize` conversions for a Pod integer type +macro_rules! impl_pod_length_for { + ($PodType:ty, $PrimitiveType:ty) => { + impl TryFrom for $PodType { + type Error = PodSliceError; + + fn try_from(val: usize) -> Result { + let primitive_val = <$PrimitiveType>::try_from(val)?; + Ok(primitive_val.into()) + } + } + + impl From<$PodType> for usize { + fn from(pod_val: $PodType) -> Self { + let primitive_val = <$PrimitiveType>::from(pod_val); + Self::try_from(primitive_val) + .expect("value out of range for usize on this platform") + } + } + }; +} + +impl_pod_length_for!(PodU16, u16); +impl_pod_length_for!(PodU32, u32); +impl_pod_length_for!(PodU64, u64); +impl_pod_length_for!(PodU128, u128); diff --git a/pod/src/primitives.rs b/pod/src/primitives.rs index 6f841b37..23953b2f 100644 --- a/pod/src/primitives.rs +++ b/pod/src/primitives.rs @@ -1,23 +1,18 @@ //! primitive types that can be used in `Pod`s -#[cfg(feature = "bytemuck")] +#[cfg(feature = "borsh")] +use borsh::{BorshDeserialize, BorshSchema, BorshSerialize}; use bytemuck_derive::{Pod, Zeroable}; -#[cfg(feature = "serde")] -use serde_derive::{Deserialize, Serialize}; +#[cfg(feature = "serde-traits")] +use serde::{Deserialize, Serialize}; #[cfg(feature = "wincode")] -use wincode_derive::{SchemaRead, SchemaWrite}; -#[cfg(feature = "borsh")] -use { - alloc::string::ToString, - borsh::{BorshDeserialize, BorshSchema, BorshSerialize}, -}; +use wincode::{SchemaRead, SchemaWrite}; /// The standard `bool` is not a `Pod`, define a replacement that is #[cfg_attr(feature = "wincode", derive(SchemaRead, SchemaWrite))] #[cfg_attr(feature = "wincode", wincode(assert_zero_copy))] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde", serde(from = "bool", into = "bool"))] -#[cfg_attr(feature = "bytemuck", derive(Pod, Zeroable))] -#[derive(Clone, Copy, Debug, Default, PartialEq)] +#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde-traits", serde(from = "bool", into = "bool"))] +#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)] #[repr(transparent)] pub struct PodBool(pub u8); impl PodBool { @@ -79,10 +74,9 @@ macro_rules! impl_int_conversion { /// `u16` type that can be used in `Pod`s #[cfg_attr(feature = "wincode", derive(SchemaRead, SchemaWrite))] #[cfg_attr(feature = "wincode", wincode(assert_zero_copy))] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde", serde(from = "u16", into = "u16"))] -#[cfg_attr(feature = "bytemuck", derive(Pod, Zeroable))] -#[derive(Clone, Copy, Debug, Default, PartialEq)] +#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde-traits", serde(from = "u16", into = "u16"))] +#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)] #[repr(transparent)] pub struct PodU16(pub [u8; 2]); impl_int_conversion!(PodU16, u16); @@ -90,10 +84,9 @@ impl_int_conversion!(PodU16, u16); /// `i16` type that can be used in Pods #[cfg_attr(feature = "wincode", derive(SchemaRead, SchemaWrite))] #[cfg_attr(feature = "wincode", wincode(assert_zero_copy))] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde", serde(from = "i16", into = "i16"))] -#[cfg_attr(feature = "bytemuck", derive(Pod, Zeroable))] -#[derive(Clone, Copy, Debug, Default, PartialEq)] +#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde-traits", serde(from = "i16", into = "i16"))] +#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)] #[repr(transparent)] pub struct PodI16(pub [u8; 2]); impl_int_conversion!(PodI16, i16); @@ -105,10 +98,9 @@ impl_int_conversion!(PodI16, i16); feature = "borsh", derive(BorshDeserialize, BorshSerialize, BorshSchema) )] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde", serde(from = "u32", into = "u32"))] -#[cfg_attr(feature = "bytemuck", derive(Pod, Zeroable))] -#[derive(Clone, Copy, Debug, Default, PartialEq)] +#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde-traits", serde(from = "u32", into = "u32"))] +#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)] #[repr(transparent)] pub struct PodU32(pub [u8; 4]); impl_int_conversion!(PodU32, u32); @@ -120,10 +112,9 @@ impl_int_conversion!(PodU32, u32); feature = "borsh", derive(BorshDeserialize, BorshSerialize, BorshSchema) )] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde", serde(from = "u64", into = "u64"))] -#[cfg_attr(feature = "bytemuck", derive(Pod, Zeroable))] -#[derive(Clone, Copy, Debug, Default, PartialEq)] +#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde-traits", serde(from = "u64", into = "u64"))] +#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)] #[repr(transparent)] pub struct PodU64(pub [u8; 8]); impl_int_conversion!(PodU64, u64); @@ -131,10 +122,9 @@ impl_int_conversion!(PodU64, u64); /// `i64` type that can be used in Pods #[cfg_attr(feature = "wincode", derive(SchemaRead, SchemaWrite))] #[cfg_attr(feature = "wincode", wincode(assert_zero_copy))] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde", serde(from = "i64", into = "i64"))] -#[cfg_attr(feature = "bytemuck", derive(Pod, Zeroable))] -#[derive(Clone, Copy, Debug, Default, PartialEq)] +#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde-traits", serde(from = "i64", into = "i64"))] +#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)] #[repr(transparent)] pub struct PodI64([u8; 8]); impl_int_conversion!(PodI64, i64); @@ -146,48 +136,17 @@ impl_int_conversion!(PodI64, i64); feature = "borsh", derive(BorshDeserialize, BorshSerialize, BorshSchema) )] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde", serde(from = "u128", into = "u128"))] -#[cfg_attr(feature = "bytemuck", derive(Pod, Zeroable))] -#[derive(Clone, Copy, Debug, Default, PartialEq)] +#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde-traits", serde(from = "u128", into = "u128"))] +#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)] #[repr(transparent)] pub struct PodU128(pub [u8; 16]); impl_int_conversion!(PodU128, u128); -/// Implements the `TryFrom` and `From for usize` conversions for a Pod integer type -macro_rules! impl_usize_conversion { - ($PodType:ty, $PrimitiveType:ty) => { - impl TryFrom for $PodType { - type Error = core::num::TryFromIntError; - - fn try_from(val: usize) -> Result { - let primitive_val = <$PrimitiveType>::try_from(val)?; - Ok(primitive_val.into()) - } - } - - impl From<$PodType> for usize { - fn from(pod_val: $PodType) -> Self { - let primitive_val = <$PrimitiveType>::from(pod_val); - Self::try_from(primitive_val) - .expect("value out of range for usize on this platform") - } - } - }; -} - -impl_usize_conversion!(PodU16, u16); -impl_usize_conversion!(PodU32, u32); -impl_usize_conversion!(PodU64, u64); -impl_usize_conversion!(PodU128, u128); - #[cfg(test)] mod tests { - use super::*; - #[cfg(feature = "bytemuck")] - use crate::bytemuck::pod_from_bytes; + use {super::*, crate::bytemuck::pod_from_bytes}; - #[cfg(feature = "bytemuck")] #[test] fn test_pod_bool() { assert!(pod_from_bytes::(&[]).is_err()); @@ -198,7 +157,7 @@ mod tests { } } - #[cfg(feature = "serde")] + #[cfg(feature = "serde-traits")] #[test] fn test_pod_bool_serde() { let pod_false: PodBool = false.into(); @@ -215,14 +174,13 @@ mod tests { assert_eq!(pod_true, deserialized_true); } - #[cfg(feature = "bytemuck")] #[test] fn test_pod_u16() { assert!(pod_from_bytes::(&[]).is_err()); assert_eq!(1u16, u16::from(*pod_from_bytes::(&[1, 0]).unwrap())); } - #[cfg(feature = "serde")] + #[cfg(feature = "serde-traits")] #[test] fn test_pod_u16_serde() { let pod_u16: PodU16 = u16::MAX.into(); @@ -234,7 +192,6 @@ mod tests { assert_eq!(pod_u16, deserialized); } - #[cfg(feature = "bytemuck")] #[test] fn test_pod_i16() { assert!(pod_from_bytes::(&[]).is_err()); @@ -244,10 +201,13 @@ mod tests { ); } - #[cfg(feature = "serde")] + #[cfg(feature = "serde-traits")] #[test] fn test_pod_i16_serde() { let pod_i16: PodI16 = i16::MAX.into(); + + println!("pod_i16 {:?}", pod_i16); + let serialized = serde_json::to_string(&pod_i16).unwrap(); assert_eq!(&serialized, "32767"); @@ -255,7 +215,6 @@ mod tests { assert_eq!(pod_i16, deserialized); } - #[cfg(feature = "bytemuck")] #[test] fn test_pod_u64() { assert!(pod_from_bytes::(&[]).is_err()); @@ -265,7 +224,7 @@ mod tests { ); } - #[cfg(feature = "serde")] + #[cfg(feature = "serde-traits")] #[test] fn test_pod_u64_serde() { let pod_u64: PodU64 = u64::MAX.into(); @@ -277,7 +236,6 @@ mod tests { assert_eq!(pod_u64, deserialized); } - #[cfg(feature = "bytemuck")] #[test] fn test_pod_i64() { assert!(pod_from_bytes::(&[]).is_err()); @@ -289,7 +247,7 @@ mod tests { ); } - #[cfg(feature = "serde")] + #[cfg(feature = "serde-traits")] #[test] fn test_pod_i64_serde() { let pod_i64: PodI64 = i64::MAX.into(); @@ -301,7 +259,6 @@ mod tests { assert_eq!(pod_i64, deserialized); } - #[cfg(feature = "bytemuck")] #[test] fn test_pod_u128() { assert!(pod_from_bytes::(&[]).is_err()); @@ -314,7 +271,7 @@ mod tests { ); } - #[cfg(feature = "serde")] + #[cfg(feature = "serde-traits")] #[test] fn test_pod_u128_serde() { let pod_u128: PodU128 = u128::MAX.into(); @@ -326,31 +283,6 @@ mod tests { assert_eq!(pod_u128, deserialized); } - macro_rules! test_usize_roundtrip { - ($test_name:ident, $PodType:ty, $max:expr) => { - #[test] - fn $test_name() { - // zero - let pod = <$PodType>::try_from(0usize).unwrap(); - assert_eq!(usize::from(pod), 0); - - // mid-range - let pod = <$PodType>::try_from(42usize).unwrap(); - assert_eq!(usize::from(pod), 42); - - // max - let max = $max as usize; - let pod = <$PodType>::try_from(max).unwrap(); - assert_eq!(usize::from(pod), max); - } - }; - } - - test_usize_roundtrip!(test_usize_roundtrip_u16, PodU16, u16::MAX); - test_usize_roundtrip!(test_usize_roundtrip_u32, PodU32, u32::MAX); - test_usize_roundtrip!(test_usize_roundtrip_u64, PodU64, u64::MAX); - test_usize_roundtrip!(test_usize_roundtrip_u128, PodU128, u128::MAX); - #[cfg(feature = "wincode")] mod wincode_tests { use {super::*, test_case::test_case}; @@ -365,18 +297,14 @@ mod tests { #[test_case(PodU128::from_primitive(u128::MAX))] fn wincode_roundtrip< T: PartialEq - + core::fmt::Debug + + std::fmt::Debug + for<'de> wincode::SchemaRead<'de, wincode::config::DefaultConfig, Dst = T> + wincode::SchemaWrite, >( pod: T, ) { - let size = wincode::serialized_size(&pod).unwrap() as usize; - let mut bytes = [0u8; 32]; - assert!(size <= bytes.len()); - wincode::serialize_into(&mut bytes[..size], &pod).unwrap(); - - let deserialized: T = wincode::deserialize(&bytes[..size]).unwrap(); + let bytes = wincode::serialize(&pod).unwrap(); + let deserialized: T = wincode::deserialize(&bytes).unwrap(); assert_eq!(pod, deserialized); } } diff --git a/pod/src/slice.rs b/pod/src/slice.rs new file mode 100644 index 00000000..a5b01e77 --- /dev/null +++ b/pod/src/slice.rs @@ -0,0 +1,221 @@ +//! Special types for working with slices of `Pod`s + +use { + crate::{ + list::{ListView, ListViewMut, ListViewReadOnly}, + primitives::PodU32, + }, + bytemuck::Pod, + solana_program_error::ProgramError, +}; + +#[deprecated( + since = "0.6.0", + note = "This struct will be removed in the next major release (1.0.0). Please use `ListView` instead." +)] +/// Special type for using a slice of `Pod`s in a zero-copy way +#[allow(deprecated)] +pub struct PodSlice<'data, T: Pod> { + inner: ListViewReadOnly<'data, T, PodU32>, +} + +#[allow(deprecated)] +impl<'data, T: Pod> PodSlice<'data, T> { + /// Unpack the buffer into a slice + pub fn unpack<'a>(data: &'a [u8]) -> Result + where + 'a: 'data, + { + let inner = ListView::::unpack(data)?; + Ok(Self { inner }) + } + + /// Get the slice data + pub fn data(&self) -> &[T] { + let len = self.inner.len(); + &self.inner.data[..len] + } + + /// Get the amount of bytes used by `num_items` + pub fn size_of(num_items: usize) -> Result { + ListView::::size_of(num_items) + } +} + +#[deprecated( + since = "0.6.0", + note = "This struct will be removed in the next major release (1.0.0). Please use `ListView` instead." +)] +/// Special type for using a slice of mutable `Pod`s in a zero-copy way. +/// Uses `ListView` under the hood. +pub struct PodSliceMut<'data, T: Pod> { + inner: ListViewMut<'data, T, PodU32>, +} + +#[allow(deprecated)] +impl<'data, T: Pod> PodSliceMut<'data, T> { + /// Unpack the mutable buffer into a mutable slice + pub fn unpack<'a>(data: &'a mut [u8]) -> Result + where + 'a: 'data, + { + let inner = ListView::::unpack_mut(data)?; + Ok(Self { inner }) + } + + /// Unpack the mutable buffer into a mutable slice, and initialize the + /// slice to 0-length + pub fn init<'a>(data: &'a mut [u8]) -> Result + where + 'a: 'data, + { + let inner = ListView::::init(data)?; + Ok(Self { inner }) + } + + /// Add another item to the slice + pub fn push(&mut self, t: T) -> Result<(), ProgramError> { + self.inner.push(t) + } +} + +#[cfg(test)] +#[allow(deprecated)] +mod tests { + use { + super::*, + crate::{bytemuck::pod_slice_to_bytes, error::PodSliceError}, + bytemuck_derive::{Pod, Zeroable}, + }; + + #[repr(C)] + #[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)] + struct TestStruct { + test_field: u8, + test_pubkey: [u8; 32], + } + + const LENGTH_SIZE: usize = std::mem::size_of::(); + + #[test] + fn test_pod_slice() { + let test_field_bytes = [0]; + let test_pubkey_bytes = [1; 32]; + let len_bytes = [2, 0, 0, 0]; + + // Slice will contain 2 `TestStruct` + let mut data_bytes = [0; 66]; + data_bytes[0..1].copy_from_slice(&test_field_bytes); + data_bytes[1..33].copy_from_slice(&test_pubkey_bytes); + data_bytes[33..34].copy_from_slice(&test_field_bytes); + data_bytes[34..66].copy_from_slice(&test_pubkey_bytes); + + let mut pod_slice_bytes = [0; 70]; + pod_slice_bytes[0..4].copy_from_slice(&len_bytes); + pod_slice_bytes[4..70].copy_from_slice(&data_bytes); + + let pod_slice = PodSlice::::unpack(&pod_slice_bytes).unwrap(); + let pod_slice_data = pod_slice.data(); + + assert_eq!(pod_slice.inner.len(), 2); + assert_eq!(pod_slice_to_bytes(pod_slice.data()), data_bytes); + assert_eq!(pod_slice_data[0].test_field, test_field_bytes[0]); + assert_eq!(pod_slice_data[0].test_pubkey, test_pubkey_bytes); + assert_eq!(PodSlice::::size_of(1).unwrap(), 37); + } + + #[test] + fn test_pod_slice_buffer_too_large() { + // Length is 1. We pass one test struct with 6 trailing bytes to + // trigger BufferTooLarge. + let data_len = LENGTH_SIZE + std::mem::size_of::() + 6; + let mut pod_slice_bytes = vec![1; data_len]; + pod_slice_bytes[0..4].copy_from_slice(&[1, 0, 0, 0]); + let err = PodSlice::::unpack(&pod_slice_bytes) + .err() + .unwrap(); + assert!(matches!(err, ProgramError::InvalidArgument)); + } + + #[test] + fn test_pod_slice_buffer_larger_than_length_value() { + // If the buffer is longer than the u32 length value declares, it + // should still unpack successfully, as long as the length of the rest + // of the buffer can be divided by `size_of::`. + let length: u32 = 12; + let length_le = length.to_le_bytes(); + + // First set up the data to have room for extra items. + let data_len = PodSlice::::size_of(length as usize + 2).unwrap(); + let mut data = vec![0; data_len]; + + // Now write the bogus length - which is smaller - into the first 4 + // bytes. + data[..LENGTH_SIZE].copy_from_slice(&length_le); + + let pod_slice = PodSlice::::unpack(&data).unwrap(); + let pod_slice_len = pod_slice.inner.len() as u32; + let data = pod_slice.data(); + let data_vec = data.to_vec(); + + assert_eq!(pod_slice_len, length); + assert_eq!(data.len(), length as usize); + assert_eq!(data_vec.len(), length as usize); + } + + #[test] + fn test_pod_slice_buffer_too_small() { + // 1 `TestStruct` + length = 37 bytes + // we pass 36 to trigger BufferTooSmall + let pod_slice_bytes = [1; 36]; + let err = PodSlice::::unpack(&pod_slice_bytes) + .err() + .unwrap(); + assert!(matches!(err, ProgramError::InvalidArgument)); + } + + #[test] + fn test_pod_slice_buffer_shorter_than_length_value() { + // If the buffer is shorter than the u32 length value declares, we + // should get a BufferTooSmall error. + let length: u32 = 12; + let length_le = length.to_le_bytes(); + for num_items in 0..length { + // First set up the data to have `num_elements` items. + let data_len = PodSlice::::size_of(num_items as usize).unwrap(); + let mut data = vec![0; data_len]; + + // Now write the bogus length - which is larger - into the first 4 + // bytes. + data[..LENGTH_SIZE].copy_from_slice(&length_le); + + // Expect an error on unpacking. + let err = PodSlice::::unpack(&data).err().unwrap(); + assert_eq!( + err, + PodSliceError::BufferTooSmall.into(), + "Expected an `PodSliceError::BufferTooSmall` error" + ); + } + } + + #[test] + fn test_pod_slice_mut() { + // slice can fit 2 `TestStruct` + let mut pod_slice_bytes = [0; 70]; + // set length to 1, so we have room to push 1 more item + let len_bytes = [1, 0, 0, 0]; + pod_slice_bytes[0..4].copy_from_slice(&len_bytes); + + let mut pod_slice = PodSliceMut::::unpack(&mut pod_slice_bytes).unwrap(); + + assert_eq!(pod_slice.inner.len(), 1); + pod_slice.push(TestStruct::default()).unwrap(); + assert_eq!(pod_slice.inner.len(), 2); + + let err = pod_slice + .push(TestStruct::default()) + .expect_err("Expected an `PodSliceError::BufferTooSmall` error"); + assert_eq!(err, PodSliceError::BufferTooSmall.into()); + } +} From 34114b3f3347368367b0aa584fe1374831c731c1 Mon Sep 17 00:00:00 2001 From: Gabe Rodriguez Date: Fri, 20 Mar 2026 10:47:20 +0100 Subject: [PATCH 2/2] Depend on remote deps --- Cargo.lock | 61 ++++++++++++++++++++++++++----- tlv-account-resolution/Cargo.toml | 10 ++--- type-length-value/Cargo.toml | 2 +- 3 files changed, 57 insertions(+), 16 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f99d91ac..7094b111 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1101,7 +1101,6 @@ dependencies = [ "five8_const", "rand 0.9.2", "serde", - "serde_derive", "sha2-const-stable", "solana-atomic-u64", "solana-define-syscall 5.0.0", @@ -1686,23 +1685,38 @@ dependencies = [ "thiserror 2.0.18", ] +[[package]] +name = "spl-list-view" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ff25a803cee203606fc29a5cac537b9b78b5c4bee107579efc0a678a53c4e9f" +dependencies = [ + "bytemuck", + "solana-program-error", + "spl-pod 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "spl-pod" version = "0.7.2" dependencies = [ + "base64", "borsh", "bytemuck", "bytemuck_derive", + "num-derive", + "num-traits", + "num_enum", "serde", - "serde_derive", "serde_json", - "solana-address 2.2.0", "solana-program-error", "solana-program-option", + "solana-pubkey", + "solana-zk-sdk", "spl-pod 0.7.2", "test-case", + "thiserror 2.0.18", "wincode", - "wincode-derive", ] [[package]] @@ -1737,7 +1751,22 @@ dependencies = [ "solana-program-error", "solana-sha256-hasher", "solana-sysvar", - "spl-program-error-derive", + "spl-program-error-derive 0.6.0", + "thiserror 2.0.18", +] + +[[package]] +name = "spl-program-error" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c4f6cf26cb6768110bf024bc7224326c720d711f7ad25d16f40f6cee40edb2d" +dependencies = [ + "num-derive", + "num-traits", + "num_enum", + "solana-msg", + "solana-program-error", + "spl-program-error-derive 0.6.0 (registry+https://github.com/rust-lang/crates.io-index)", "thiserror 2.0.18", ] @@ -1751,6 +1780,18 @@ dependencies = [ "syn", ] +[[package]] +name = "spl-program-error-derive" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ec8965aa4dc6c74701cbb48b9cad5af35b9a394514934949edbb357b78f840d" +dependencies = [ + "proc-macro2", + "quote", + "sha2", + "syn", +] + [[package]] name = "spl-tlv-account-resolution" version = "0.11.1" @@ -1766,11 +1807,11 @@ dependencies = [ "solana-instruction", "solana-program-error", "solana-pubkey", - "spl-discriminator 0.5.1", - "spl-list-view", - "spl-pod 0.7.2", - "spl-program-error", - "spl-type-length-value 0.9.0", + "spl-discriminator 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", + "spl-list-view 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", + "spl-pod 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)", + "spl-program-error 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)", + "spl-type-length-value 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)", "thiserror 2.0.18", "tokio", ] diff --git a/tlv-account-resolution/Cargo.toml b/tlv-account-resolution/Cargo.toml index c8388dbc..1f3699b9 100644 --- a/tlv-account-resolution/Cargo.toml +++ b/tlv-account-resolution/Cargo.toml @@ -20,11 +20,11 @@ solana-account-info = "3.0.0" solana-instruction = { version = "3.0.0", features = ["std"] } solana-program-error = "3.0.0" solana-pubkey = { version = "3.0.0", features = ["curve25519"] } -spl-discriminator = { version = "0.5.1", path = "../discriminator" } -spl-list-view = { version = "0.1.0", path = "../list-view" } -spl-program-error = { version = "0.8.0", path = "../program-error" } -spl-pod = { version = "0.7.1", path = "../pod", features = ["bytemuck"] } -spl-type-length-value = { version = "0.9.0", path = "../type-length-value" } +spl-discriminator = "0.5.1" +spl-list-view = "0.1.0" +spl-program-error = "0.8.0" +spl-pod = "0.7.1" +spl-type-length-value = "0.9.0" thiserror = "2.0" [dev-dependencies] diff --git a/type-length-value/Cargo.toml b/type-length-value/Cargo.toml index 5e000486..7be69d4b 100644 --- a/type-length-value/Cargo.toml +++ b/type-length-value/Cargo.toml @@ -21,7 +21,7 @@ solana-msg = "3.0.0" solana-program-error = "3.0.0" spl-discriminator = { version = "0.5.1", path = "../discriminator" } spl-type-length-value-derive = { version = "0.2", path = "../type-length-value-derive", optional = true } -spl-pod = { version = "0.7.1", path = "../pod", features = ["bytemuck"] } +spl-pod = { version = "0.7.1", path = "../pod" } thiserror = "2.0" [lib]