diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index c0eabcd59f3..436db6fa524 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -26,6 +26,178 @@ impl vortex_array::accessor::ArrayAccessor< pub fn vortex_array::arrays::PrimitiveArray::with_iterator(&self, f: F) -> R where F: for<'a> core::ops::function::FnOnce(&mut dyn core::iter::traits::iterator::Iterator>) -> R +pub mod vortex_array::aggregate_fn + +pub mod vortex_array::aggregate_fn::fns + +pub mod vortex_array::aggregate_fn::session + +pub struct vortex_array::aggregate_fn::session::AggregateFnSession + +impl vortex_array::aggregate_fn::session::AggregateFnSession + +pub fn vortex_array::aggregate_fn::session::AggregateFnSession::register(&self, vtable: V) + +pub fn vortex_array::aggregate_fn::session::AggregateFnSession::registry(&self) -> &vortex_array::aggregate_fn::session::AggregateFnRegistry + +impl core::default::Default for vortex_array::aggregate_fn::session::AggregateFnSession + +pub fn vortex_array::aggregate_fn::session::AggregateFnSession::default() -> vortex_array::aggregate_fn::session::AggregateFnSession + +impl core::fmt::Debug for vortex_array::aggregate_fn::session::AggregateFnSession + +pub fn vortex_array::aggregate_fn::session::AggregateFnSession::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +pub trait vortex_array::aggregate_fn::session::AggregateFnSessionExt: vortex_session::SessionExt + +pub fn vortex_array::aggregate_fn::session::AggregateFnSessionExt::aggregate_fns(&self) -> vortex_session::Ref<'_, vortex_array::aggregate_fn::session::AggregateFnSession> + +impl vortex_array::aggregate_fn::session::AggregateFnSessionExt for S + +pub fn S::aggregate_fns(&self) -> vortex_session::Ref<'_, vortex_array::aggregate_fn::session::AggregateFnSession> + +pub type vortex_array::aggregate_fn::session::AggregateFnRegistry = vortex_session::registry::Registry + +pub struct vortex_array::aggregate_fn::AggregateFn(_) + +impl vortex_array::aggregate_fn::AggregateFn + +pub fn vortex_array::aggregate_fn::AggregateFn::erased(self) -> vortex_array::aggregate_fn::AggregateFnRef + +pub fn vortex_array::aggregate_fn::AggregateFn::new(vtable: V, options: ::Options) -> Self + +pub fn vortex_array::aggregate_fn::AggregateFn::options(&self) -> &::Options + +pub fn vortex_array::aggregate_fn::AggregateFn::vtable(&self) -> &V + +pub struct vortex_array::aggregate_fn::AggregateFnOptions<'a> + +impl vortex_array::aggregate_fn::AggregateFnOptions<'_> + +pub fn vortex_array::aggregate_fn::AggregateFnOptions<'_>::serialize(&self) -> vortex_error::VortexResult>> + +impl<'a> vortex_array::aggregate_fn::AggregateFnOptions<'a> + +pub fn vortex_array::aggregate_fn::AggregateFnOptions<'a>::as_any(&self) -> &'a dyn core::any::Any + +impl core::cmp::Eq for vortex_array::aggregate_fn::AggregateFnOptions<'_> + +impl core::cmp::PartialEq for vortex_array::aggregate_fn::AggregateFnOptions<'_> + +pub fn vortex_array::aggregate_fn::AggregateFnOptions<'_>::eq(&self, other: &Self) -> bool + +impl core::fmt::Debug for vortex_array::aggregate_fn::AggregateFnOptions<'_> + +pub fn vortex_array::aggregate_fn::AggregateFnOptions<'_>::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::fmt::Display for vortex_array::aggregate_fn::AggregateFnOptions<'_> + +pub fn vortex_array::aggregate_fn::AggregateFnOptions<'_>::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::hash::Hash for vortex_array::aggregate_fn::AggregateFnOptions<'_> + +pub fn vortex_array::aggregate_fn::AggregateFnOptions<'_>::hash(&self, state: &mut H) + +pub struct vortex_array::aggregate_fn::AggregateFnRef(_) + +impl vortex_array::aggregate_fn::AggregateFnRef + +pub fn vortex_array::aggregate_fn::AggregateFnRef::accumulator(&self, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult> + +pub fn vortex_array::aggregate_fn::AggregateFnRef::as_(&self) -> &::Options + +pub fn vortex_array::aggregate_fn::AggregateFnRef::as_opt(&self) -> core::option::Option<&::Options> + +pub fn vortex_array::aggregate_fn::AggregateFnRef::id(&self) -> vortex_array::aggregate_fn::AggregateFnId + +pub fn vortex_array::aggregate_fn::AggregateFnRef::is(&self) -> bool + +pub fn vortex_array::aggregate_fn::AggregateFnRef::options(&self) -> vortex_array::aggregate_fn::AggregateFnOptions<'_> + +pub fn vortex_array::aggregate_fn::AggregateFnRef::return_dtype(&self, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::AggregateFnRef::state_dtype(&self, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::AggregateFnRef::vtable_ref(&self) -> core::option::Option<&V> + +impl core::clone::Clone for vortex_array::aggregate_fn::AggregateFnRef + +pub fn vortex_array::aggregate_fn::AggregateFnRef::clone(&self) -> vortex_array::aggregate_fn::AggregateFnRef + +impl core::cmp::Eq for vortex_array::aggregate_fn::AggregateFnRef + +impl core::cmp::PartialEq for vortex_array::aggregate_fn::AggregateFnRef + +pub fn vortex_array::aggregate_fn::AggregateFnRef::eq(&self, other: &Self) -> bool + +impl core::fmt::Debug for vortex_array::aggregate_fn::AggregateFnRef + +pub fn vortex_array::aggregate_fn::AggregateFnRef::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::fmt::Display for vortex_array::aggregate_fn::AggregateFnRef + +pub fn vortex_array::aggregate_fn::AggregateFnRef::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::hash::Hash for vortex_array::aggregate_fn::AggregateFnRef + +pub fn vortex_array::aggregate_fn::AggregateFnRef::hash(&self, state: &mut H) + +pub trait vortex_array::aggregate_fn::Accumulator: core::marker::Send + core::marker::Sync + +pub fn vortex_array::aggregate_fn::Accumulator::accumulate(&mut self, batch: &vortex_array::ArrayRef) -> vortex_error::VortexResult<()> + +pub fn vortex_array::aggregate_fn::Accumulator::accumulate_list(&mut self, list: &vortex_array::arrays::ListViewArray) -> vortex_error::VortexResult<()> + +pub fn vortex_array::aggregate_fn::Accumulator::finish(self: alloc::boxed::Box) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::Accumulator::flush(&mut self) -> vortex_error::VortexResult<()> + +pub fn vortex_array::aggregate_fn::Accumulator::is_saturated(&self) -> bool + +pub fn vortex_array::aggregate_fn::Accumulator::merge(&mut self, state: &vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()> + +pub fn vortex_array::aggregate_fn::Accumulator::merge_list(&mut self, states: &vortex_array::ArrayRef) -> vortex_error::VortexResult<()> + +pub trait vortex_array::aggregate_fn::AggregateFnPlugin: 'static + core::marker::Send + core::marker::Sync + +pub fn vortex_array::aggregate_fn::AggregateFnPlugin::deserialize(&self, metadata: &[u8], session: &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::AggregateFnPlugin::id(&self) -> vortex_array::aggregate_fn::AggregateFnId + +impl vortex_array::aggregate_fn::AggregateFnPlugin for V + +pub fn V::deserialize(&self, metadata: &[u8], session: &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn V::id(&self) -> vortex_array::aggregate_fn::AggregateFnId + +pub trait vortex_array::aggregate_fn::AggregateFnVTable: 'static + core::marker::Sized + core::clone::Clone + core::marker::Send + core::marker::Sync + +pub type vortex_array::aggregate_fn::AggregateFnVTable::Options: 'static + core::marker::Send + core::marker::Sync + core::clone::Clone + core::fmt::Debug + core::fmt::Display + core::cmp::PartialEq + core::cmp::Eq + core::hash::Hash + +pub fn vortex_array::aggregate_fn::AggregateFnVTable::accumulator(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult> + +pub fn vortex_array::aggregate_fn::AggregateFnVTable::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::AggregateFnVTable::id(&self) -> vortex_array::aggregate_fn::AggregateFnId + +pub fn vortex_array::aggregate_fn::AggregateFnVTable::return_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::AggregateFnVTable::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult>> + +pub fn vortex_array::aggregate_fn::AggregateFnVTable::state_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub trait vortex_array::aggregate_fn::AggregateFnVTableExt: vortex_array::aggregate_fn::AggregateFnVTable + +pub fn vortex_array::aggregate_fn::AggregateFnVTableExt::bind(&self, options: Self::Options) -> vortex_array::aggregate_fn::AggregateFnRef + +impl vortex_array::aggregate_fn::AggregateFnVTableExt for V + +pub fn V::bind(&self, options: Self::Options) -> vortex_array::aggregate_fn::AggregateFnRef + +pub type vortex_array::aggregate_fn::AggregateFnId = arcref::ArcRef + +pub type vortex_array::aggregate_fn::AggregateFnPluginRef = alloc::sync::Arc + pub mod vortex_array::arrays pub mod vortex_array::arrays::build_views diff --git a/vortex-array/src/aggregate_fn/accumulator.rs b/vortex-array/src/aggregate_fn/accumulator.rs new file mode 100644 index 00000000000..ecf55d5945b --- /dev/null +++ b/vortex-array/src/aggregate_fn/accumulator.rs @@ -0,0 +1,75 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_error::VortexResult; + +use crate::ArrayRef; +use crate::arrays::ListViewArray; +use crate::scalar::Scalar; + +/// The execution interface for all aggregation. +/// +/// An accumulator processes one group at a time: the caller feeds element batches via +/// [`accumulate`](Accumulator::accumulate), then calls [`flush`](Accumulator::flush) to finalize +/// the group and begin the next. The accumulator owns an output buffer and returns all results +/// via [`finish`](Accumulator::finish). +pub trait Accumulator: Send + Sync { + /// Feed a batch of elements for the currently open group. + /// + /// May be called multiple times per group (e.g., chunked elements). + fn accumulate(&mut self, batch: &ArrayRef) -> VortexResult<()>; + + /// Accumulate all groups defined by a [`ListViewArray`] in one call. + /// + /// Default: for each group, accumulate its elements then flush. + /// Override for vectorized fast paths (e.g., segmented sum over the flat + /// elements + offsets without per-group slicing). + fn accumulate_list(&mut self, list: &ListViewArray) -> VortexResult<()> { + for i in 0..list.len() { + self.accumulate(&list.list_elements_at(i)?)?; + self.flush()?; + } + Ok(()) + } + + /// Merge pre-computed partial state into the currently open group. + /// + /// The scalar's dtype must match the aggregate's `state_dtype`. + /// This is equivalent to having processed raw elements that would produce + /// this state — used by encoding-specific optimizations. + fn merge(&mut self, state: &Scalar) -> VortexResult<()>; + + /// Merge an array of pre-computed states, one per group, flushing each. + /// + /// The array's dtype must match the aggregate's `state_dtype`. + /// Default: merge + flush for each element. + fn merge_list(&mut self, states: &ArrayRef) -> VortexResult<()> { + for i in 0..states.len() { + self.merge(&states.scalar_at(i)?)?; + self.flush()?; + } + Ok(()) + } + + /// Whether the currently open group's result is fully determined. + /// + /// When true, callers may skip further accumulate/merge calls and proceed + /// directly to [`flush`](Accumulator::flush). Resets to false after flush. + fn is_saturated(&self) -> bool { + false + } + + /// Finalize the currently open group: push its result to the output buffer + /// and reset internal state for the next group. + /// + /// Flushing a group with zero accumulated elements produces the aggregate's + /// identity value (e.g., 0 for Sum, u64::MAX for Min) or null if no identity + /// exists. + fn flush(&mut self) -> VortexResult<()>; + + /// Return all flushed results as a single array. + /// + /// Length equals the number of [`flush`](Accumulator::flush) calls made over the + /// accumulator's lifetime. + fn finish(self: Box) -> VortexResult; +} diff --git a/vortex-array/src/aggregate_fn/erased.rs b/vortex-array/src/aggregate_fn/erased.rs new file mode 100644 index 00000000000..9b2953fe2f2 --- /dev/null +++ b/vortex-array/src/aggregate_fn/erased.rs @@ -0,0 +1,121 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Type-erased aggregate function ([`AggregateFnRef`]). + +use std::fmt::Debug; +use std::fmt::Display; +use std::fmt::Formatter; +use std::hash::Hash; +use std::hash::Hasher; +use std::sync::Arc; + +use vortex_error::VortexExpect; +use vortex_error::VortexResult; +use vortex_utils::debug_with::DebugWith; + +use crate::aggregate_fn::AggregateFnId; +use crate::aggregate_fn::AggregateFnVTable; +use crate::aggregate_fn::accumulator::Accumulator; +use crate::aggregate_fn::options::AggregateFnOptions; +use crate::aggregate_fn::typed::AggregateFnInner; +use crate::aggregate_fn::typed::DynAggregateFn; +use crate::dtype::DType; + +/// A type-erased aggregate function, pairing a vtable with bound options behind a trait object. +/// +/// This stores an [`AggregateFnVTable`] and its options behind an `Arc`, +/// allowing heterogeneous storage and dispatch. +/// +/// Use [`super::AggregateFn::new()`] to construct, and [`super::AggregateFn::erased()`] to +/// obtain an [`AggregateFnRef`]. +#[derive(Clone)] +pub struct AggregateFnRef(pub(super) Arc); + +impl AggregateFnRef { + /// Returns the ID of this aggregate function. + pub fn id(&self) -> AggregateFnId { + self.0.id() + } + + /// Returns whether the aggregate function is of the given vtable type. + pub fn is(&self) -> bool { + self.0.as_any().is::>() + } + + /// Returns the typed options for this aggregate function if it matches the given vtable type. + pub fn as_opt(&self) -> Option<&V::Options> { + self.downcast_inner::().map(|inner| &inner.options) + } + + /// Returns a reference to the typed vtable if it matches the given vtable type. + pub fn vtable_ref(&self) -> Option<&V> { + self.downcast_inner::().map(|inner| &inner.vtable) + } + + /// Downcast the inner to the concrete `AggregateFnInner`. + fn downcast_inner(&self) -> Option<&AggregateFnInner> { + self.0.as_any().downcast_ref::>() + } + + /// Returns the typed options for this aggregate function if it matches the given vtable type. + /// + /// # Panics + /// + /// Panics if the vtable type does not match. + pub fn as_(&self) -> &V::Options { + self.as_opt::() + .vortex_expect("Aggregate function options type mismatch") + } + + /// The type-erased options for this aggregate function. + pub fn options(&self) -> AggregateFnOptions<'_> { + AggregateFnOptions { inner: &*self.0 } + } + + /// Compute the return [`DType`] per group given the input element type. + pub fn return_dtype(&self, input_dtype: &DType) -> VortexResult { + self.0.return_dtype(input_dtype) + } + + /// DType of the intermediate accumulator state. + pub fn state_dtype(&self, input_dtype: &DType) -> VortexResult { + self.0.state_dtype(input_dtype) + } + + /// Create an accumulator for streaming aggregation. + pub fn accumulator(&self, input_dtype: &DType) -> VortexResult> { + self.0.accumulator(input_dtype) + } +} + +impl Debug for AggregateFnRef { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AggregateFnRef") + .field("vtable", &self.0.id()) + .field("options", &DebugWith(|fmt| self.0.options_debug(fmt))) + .finish() + } +} + +impl Display for AggregateFnRef { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}(", self.0.id())?; + self.0.options_display(f)?; + write!(f, ")") + } +} + +impl PartialEq for AggregateFnRef { + fn eq(&self, other: &Self) -> bool { + self.0.id() == other.0.id() && self.0.options_eq(other.0.options_any()) + } +} +impl Eq for AggregateFnRef {} + +impl Hash for AggregateFnRef { + fn hash(&self, state: &mut H) { + self.0.id().hash(state); + self.0.options_hash(state); + } +} diff --git a/vortex-array/src/aggregate_fn/fns/mod.rs b/vortex-array/src/aggregate_fn/fns/mod.rs new file mode 100644 index 00000000000..0d735177e5d --- /dev/null +++ b/vortex-array/src/aggregate_fn/fns/mod.rs @@ -0,0 +1,2 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors diff --git a/vortex-array/src/aggregate_fn/mod.rs b/vortex-array/src/aggregate_fn/mod.rs new file mode 100644 index 00000000000..6ce112e20ee --- /dev/null +++ b/vortex-array/src/aggregate_fn/mod.rs @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Aggregate function vtable machinery. +//! +//! This module contains the [`AggregateFnVTable`] trait, the [`Accumulator`] trait, and the +//! type-erasure infrastructure for aggregate functions. + +use arcref::ArcRef; + +mod accumulator; +pub use accumulator::*; + +mod vtable; +pub use vtable::*; + +mod plugin; +pub use plugin::*; + +mod typed; +pub use typed::*; + +mod erased; +pub use erased::*; + +mod options; +pub use options::*; + +pub mod fns; +pub mod session; + +/// A unique identifier for an aggregate function. +pub type AggregateFnId = ArcRef; + +/// Private module to seal [`typed::DynAggregateFn`]. +mod sealed { + use crate::aggregate_fn::AggregateFnVTable; + use crate::aggregate_fn::typed::AggregateFnInner; + + /// Marker trait to prevent external implementations of [`super::typed::DynAggregateFn`]. + pub(crate) trait Sealed {} + + /// This can be the **only** implementor for [`super::typed::DynAggregateFn`]. + impl Sealed for AggregateFnInner {} +} diff --git a/vortex-array/src/aggregate_fn/options.rs b/vortex-array/src/aggregate_fn/options.rs new file mode 100644 index 00000000000..792e2924ea9 --- /dev/null +++ b/vortex-array/src/aggregate_fn/options.rs @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::any::Any; +use std::fmt::Debug; +use std::fmt::Display; +use std::hash::Hash; +use std::hash::Hasher; + +use vortex_error::VortexResult; + +use crate::aggregate_fn::typed::DynAggregateFn; + +/// An opaque handle to aggregate function options. +pub struct AggregateFnOptions<'a> { + pub(super) inner: &'a dyn DynAggregateFn, +} + +impl Display for AggregateFnOptions<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.inner.options_display(f) + } +} + +impl Debug for AggregateFnOptions<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.inner.options_debug(f) + } +} + +impl PartialEq for AggregateFnOptions<'_> { + fn eq(&self, other: &Self) -> bool { + self.inner.id() == other.inner.id() && self.inner.options_eq(other.inner.options_any()) + } +} +impl Eq for AggregateFnOptions<'_> {} + +impl Hash for AggregateFnOptions<'_> { + fn hash(&self, state: &mut H) { + self.inner.id().hash(state); + self.inner.options_hash(state); + } +} + +impl AggregateFnOptions<'_> { + /// Serialize the options to a byte vector. + pub fn serialize(&self) -> VortexResult>> { + self.inner.options_serialize() + } +} + +impl<'a> AggregateFnOptions<'a> { + /// Return the underlying `Any` reference. + pub fn as_any(&self) -> &'a dyn Any { + self.inner.options_any() + } +} diff --git a/vortex-array/src/aggregate_fn/plugin.rs b/vortex-array/src/aggregate_fn/plugin.rs new file mode 100644 index 00000000000..b7ff8b893ac --- /dev/null +++ b/vortex-array/src/aggregate_fn/plugin.rs @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::sync::Arc; + +use vortex_error::VortexResult; +use vortex_session::VortexSession; + +use crate::aggregate_fn::AggregateFn; +use crate::aggregate_fn::AggregateFnId; +use crate::aggregate_fn::AggregateFnRef; +use crate::aggregate_fn::AggregateFnVTable; + +/// Reference-counted pointer to an aggregate function plugin. +pub type AggregateFnPluginRef = Arc; + +/// Registry trait for ID-based deserialization of aggregate functions. +/// +/// Plugins are registered in the session by their [`AggregateFnId`]. When a serialized aggregate +/// function is encountered, the session resolves the ID to the plugin and calls [`deserialize`] +/// to reconstruct the value as an [`AggregateFnRef`]. +/// +/// [`deserialize`]: AggregateFnPlugin::deserialize +pub trait AggregateFnPlugin: 'static + Send + Sync { + /// Returns the ID for this aggregate function. + fn id(&self) -> AggregateFnId; + + /// Deserialize an aggregate function from serialized metadata. + fn deserialize(&self, metadata: &[u8], session: &VortexSession) + -> VortexResult; +} + +impl std::fmt::Debug for dyn AggregateFnPlugin { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("AggregateFnPlugin") + .field(&self.id()) + .finish() + } +} + +impl AggregateFnPlugin for V { + fn id(&self) -> AggregateFnId { + AggregateFnVTable::id(self) + } + + fn deserialize( + &self, + metadata: &[u8], + session: &VortexSession, + ) -> VortexResult { + let options = AggregateFnVTable::deserialize(self, metadata, session)?; + Ok(AggregateFn::new(self.clone(), options).erased()) + } +} diff --git a/vortex-array/src/aggregate_fn/session.rs b/vortex-array/src/aggregate_fn/session.rs new file mode 100644 index 00000000000..4e7bd12937b --- /dev/null +++ b/vortex-array/src/aggregate_fn/session.rs @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::sync::Arc; + +use vortex_session::Ref; +use vortex_session::SessionExt; +use vortex_session::registry::Registry; + +use crate::aggregate_fn::AggregateFnPluginRef; +use crate::aggregate_fn::AggregateFnVTable; + +/// Registry of aggregate function vtables. +pub type AggregateFnRegistry = Registry; + +/// Session state for aggregate function vtables. +#[derive(Debug, Default)] +pub struct AggregateFnSession { + registry: AggregateFnRegistry, +} + +impl AggregateFnSession { + /// Returns the aggregate function registry. + pub fn registry(&self) -> &AggregateFnRegistry { + &self.registry + } + + /// Register an aggregate function vtable in the session, replacing any existing vtable with + /// the same ID. + pub fn register(&self, vtable: V) { + self.registry + .register(vtable.id(), Arc::new(vtable) as AggregateFnPluginRef); + } +} + +/// Extension trait for accessing aggregate function session data. +pub trait AggregateFnSessionExt: SessionExt { + /// Returns the aggregate function vtable registry. + fn aggregate_fns(&self) -> Ref<'_, AggregateFnSession> { + self.get::() + } +} +impl AggregateFnSessionExt for S {} diff --git a/vortex-array/src/aggregate_fn/typed.rs b/vortex-array/src/aggregate_fn/typed.rs new file mode 100644 index 00000000000..25e05424ea1 --- /dev/null +++ b/vortex-array/src/aggregate_fn/typed.rs @@ -0,0 +1,138 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Typed and inner representations of aggregate functions. +//! +//! - [`AggregateFn`]: The public typed wrapper, parameterized by a concrete +//! [`AggregateFnVTable`]. +//! - [`AggregateFnInner`]: The private inner struct that holds the vtable + options. +//! - [`DynAggregateFn`]: The private sealed trait for type-erased dispatch (bound, options in +//! self). + +use std::any::Any; +use std::fmt; +use std::fmt::Debug; +use std::fmt::Display; +use std::fmt::Formatter; +use std::hash::Hash; +use std::hash::Hasher; +use std::sync::Arc; + +use vortex_error::VortexResult; + +use crate::aggregate_fn::AggregateFnId; +use crate::aggregate_fn::AggregateFnRef; +use crate::aggregate_fn::AggregateFnVTable; +use crate::aggregate_fn::accumulator::Accumulator; +use crate::dtype::DType; + +/// An object-safe, sealed trait for bound aggregate function dispatch. +/// +/// Options are stored inside the implementing [`AggregateFnInner`], not passed externally. +/// This is the sole trait behind [`AggregateFnRef`]'s `Arc`. +pub(super) trait DynAggregateFn: 'static + Send + Sync + super::sealed::Sealed { + fn as_any(&self) -> &dyn Any; + fn id(&self) -> AggregateFnId; + fn options_any(&self) -> &dyn Any; + + fn return_dtype(&self, input_dtype: &DType) -> VortexResult; + fn state_dtype(&self, input_dtype: &DType) -> VortexResult; + fn accumulator(&self, input_dtype: &DType) -> VortexResult>; + + fn options_serialize(&self) -> VortexResult>>; + fn options_eq(&self, other_options: &dyn Any) -> bool; + fn options_hash(&self, hasher: &mut dyn Hasher); + fn options_display(&self, f: &mut Formatter<'_>) -> fmt::Result; + fn options_debug(&self, f: &mut Formatter<'_>) -> fmt::Result; +} + +/// The private inner representation of a bound aggregate function, pairing a vtable with its +/// options. +/// +/// This is the sole implementor of [`DynAggregateFn`], enabling [`AggregateFnRef`] to safely +/// downcast back to the concrete vtable type via [`Any`]. +pub(super) struct AggregateFnInner { + pub(super) vtable: V, + pub(super) options: V::Options, +} + +impl DynAggregateFn for AggregateFnInner { + #[inline(always)] + fn as_any(&self) -> &dyn Any { + self + } + + #[inline(always)] + fn id(&self) -> AggregateFnId { + V::id(&self.vtable) + } + + fn options_any(&self) -> &dyn Any { + &self.options + } + + fn return_dtype(&self, input_dtype: &DType) -> VortexResult { + V::return_dtype(&self.vtable, &self.options, input_dtype) + } + + fn state_dtype(&self, input_dtype: &DType) -> VortexResult { + V::state_dtype(&self.vtable, &self.options, input_dtype) + } + + fn accumulator(&self, input_dtype: &DType) -> VortexResult> { + V::accumulator(&self.vtable, &self.options, input_dtype) + } + + fn options_serialize(&self) -> VortexResult>> { + V::serialize(&self.vtable, &self.options) + } + + fn options_eq(&self, other_options: &dyn Any) -> bool { + other_options + .downcast_ref::() + .is_some_and(|o| self.options == *o) + } + + fn options_hash(&self, mut hasher: &mut dyn Hasher) { + self.options.hash(&mut hasher); + } + + fn options_display(&self, f: &mut Formatter<'_>) -> fmt::Result { + Display::fmt(&self.options, f) + } + + fn options_debug(&self, f: &mut Formatter<'_>) -> fmt::Result { + Debug::fmt(&self.options, f) + } +} + +/// A typed aggregate function instance, parameterized by a concrete [`AggregateFnVTable`]. +/// +/// You can construct one via [`new()`], and erase the type with [`erased()`] to obtain an +/// [`AggregateFnRef`]. +/// +/// [`new()`]: AggregateFn::new +/// [`erased()`]: AggregateFn::erased +pub struct AggregateFn(pub(super) Arc>); + +impl AggregateFn { + /// Create a new typed aggregate function instance. + pub fn new(vtable: V, options: V::Options) -> Self { + Self(Arc::new(AggregateFnInner { vtable, options })) + } + + /// Returns a reference to the vtable. + pub fn vtable(&self) -> &V { + &self.0.vtable + } + + /// Returns a reference to the options. + pub fn options(&self) -> &V::Options { + &self.0.options + } + + /// Erase the concrete type information, returning a type-erased [`AggregateFnRef`]. + pub fn erased(self) -> AggregateFnRef { + AggregateFnRef(self.0) + } +} diff --git a/vortex-array/src/aggregate_fn/vtable.rs b/vortex-array/src/aggregate_fn/vtable.rs new file mode 100644 index 00000000000..04f0a0030c7 --- /dev/null +++ b/vortex-array/src/aggregate_fn/vtable.rs @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::fmt::Debug; +use std::fmt::Display; +use std::hash::Hash; + +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_session::VortexSession; + +use crate::aggregate_fn::AggregateFn; +use crate::aggregate_fn::AggregateFnId; +use crate::aggregate_fn::AggregateFnRef; +use crate::aggregate_fn::accumulator::Accumulator; +use crate::dtype::DType; + +/// Defines the interface for aggregate function vtables. +/// +/// This trait is non-object-safe and allows the implementer to make use of associated types +/// for improved type safety, while allowing Vortex to enforce runtime checks on the inputs and +/// outputs of each function. +/// +/// The [`AggregateFnVTable`] trait should be implemented for a struct that holds global data across +/// all instances of the aggregate. In almost all cases, this struct will be an empty unit +/// struct, since most aggregates do not require any global state. +pub trait AggregateFnVTable: 'static + Sized + Clone + Send + Sync { + /// Options for this aggregate function. + type Options: 'static + Send + Sync + Clone + Debug + Display + PartialEq + Eq + Hash; + + /// Returns the ID of the aggregate function vtable. + fn id(&self) -> AggregateFnId; + + /// Serialize the options for this aggregate function. + /// + /// Should return `Ok(None)` if the function is not serializable, and `Ok(vec![])` if it is + /// serializable but has no metadata. + fn serialize(&self, options: &Self::Options) -> VortexResult>> { + _ = options; + Ok(None) + } + + /// Deserialize the options of this aggregate function. + fn deserialize( + &self, + _metadata: &[u8], + _session: &VortexSession, + ) -> VortexResult { + vortex_bail!("Aggregate function {} is not deserializable", self.id()); + } + + /// Compute the return [`DType`] per group given the input element type. + fn return_dtype(&self, options: &Self::Options, input_dtype: &DType) -> VortexResult; + + /// DType of the intermediate accumulator state. + /// + /// Use a struct dtype when multiple fields are needed + /// (e.g., Mean: `Struct { sum: f64, count: u64 }`). + fn state_dtype(&self, options: &Self::Options, input_dtype: &DType) -> VortexResult; + + /// Create an accumulator for streaming aggregation. + fn accumulator( + &self, + options: &Self::Options, + input_dtype: &DType, + ) -> VortexResult>; +} + +/// Factory functions for aggregate vtables. +pub trait AggregateFnVTableExt: AggregateFnVTable { + /// Bind this vtable with the given options into an [`AggregateFnRef`]. + fn bind(&self, options: Self::Options) -> AggregateFnRef { + AggregateFn::new(self.clone(), options).erased() + } +} +impl AggregateFnVTableExt for V {} diff --git a/vortex-array/src/lib.rs b/vortex-array/src/lib.rs index d85569c046f..64affff1f04 100644 --- a/vortex-array/src/lib.rs +++ b/vortex-array/src/lib.rs @@ -28,6 +28,7 @@ use vortex_session::VortexSession; use crate::session::ArraySession; pub mod accessor; +pub mod aggregate_fn; #[doc(hidden)] pub mod aliases; mod array; diff --git a/vortex/public-api.lock b/vortex/public-api.lock index a823449de25..0c8ce9d0cd9 100644 --- a/vortex/public-api.lock +++ b/vortex/public-api.lock @@ -1,5 +1,7 @@ pub mod vortex +pub use vortex::aggregate_fn + pub use vortex::compute pub use vortex::expr diff --git a/vortex/src/lib.rs b/vortex/src/lib.rs index b5b374d9861..ae3eae817bf 100644 --- a/vortex/src/lib.rs +++ b/vortex/src/lib.rs @@ -5,6 +5,8 @@ #![doc = include_str!(concat!("../", env!("CARGO_PKG_README")))] // vortex::compute is deprecated and will be ported over to expressions. +pub use vortex_array::aggregate_fn; +use vortex_array::aggregate_fn::session::AggregateFnSession; pub use vortex_array::compute; use vortex_array::dtype::session::DTypeSession; // vortex::expr is in the process of having its dependencies inverted, and will eventually be @@ -165,6 +167,7 @@ impl VortexSessionDefault for VortexSession { .with::() .with::() .with::() + .with::() .with::(); #[cfg(feature = "files")]