diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 436db6fa524..1d9df743c74 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -30,6 +30,30 @@ pub mod vortex_array::aggregate_fn pub mod vortex_array::aggregate_fn::fns +pub mod vortex_array::aggregate_fn::fns::mean + +pub struct vortex_array::aggregate_fn::fns::mean::Mean + +impl core::clone::Clone for vortex_array::aggregate_fn::fns::mean::Mean + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::clone(&self) -> vortex_array::aggregate_fn::fns::mean::Mean + +impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::fns::mean::Mean + +pub type vortex_array::aggregate_fn::fns::mean::Mean::Options = vortex_array::scalar_fn::EmptyOptions + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::accumulator(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult> + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::id(&self) -> vortex_array::aggregate_fn::AggregateFnId + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult>> + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::state_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + pub mod vortex_array::aggregate_fn::session pub struct vortex_array::aggregate_fn::session::AggregateFnSession @@ -42,7 +66,7 @@ pub fn vortex_array::aggregate_fn::session::AggregateFnSession::registry(&self) impl core::default::Default for vortex_array::aggregate_fn::session::AggregateFnSession -pub fn vortex_array::aggregate_fn::session::AggregateFnSession::default() -> vortex_array::aggregate_fn::session::AggregateFnSession +pub fn vortex_array::aggregate_fn::session::AggregateFnSession::default() -> Self impl core::fmt::Debug for vortex_array::aggregate_fn::session::AggregateFnSession @@ -166,9 +190,9 @@ pub fn vortex_array::aggregate_fn::AggregateFnPlugin::id(&self) -> vortex_array: impl vortex_array::aggregate_fn::AggregateFnPlugin for V -pub fn V::deserialize(&self, metadata: &[u8], session: &vortex_session::VortexSession) -> vortex_error::VortexResult +pub fn V::deserialize(&self, metadata: &[u8], session: &vortex_session::VortexSession) -> core::result::Result -pub fn V::id(&self) -> vortex_array::aggregate_fn::AggregateFnId +pub fn V::id(&self) -> arcref::ArcRef pub trait vortex_array::aggregate_fn::AggregateFnVTable: 'static + core::marker::Sized + core::clone::Clone + core::marker::Send + core::marker::Sync @@ -186,6 +210,22 @@ pub fn vortex_array::aggregate_fn::AggregateFnVTable::serialize(&self, options: pub fn vortex_array::aggregate_fn::AggregateFnVTable::state_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult +impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::fns::mean::Mean + +pub type vortex_array::aggregate_fn::fns::mean::Mean::Options = vortex_array::scalar_fn::EmptyOptions + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::accumulator(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult> + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::id(&self) -> vortex_array::aggregate_fn::AggregateFnId + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult>> + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::state_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + pub trait vortex_array::aggregate_fn::AggregateFnVTableExt: vortex_array::aggregate_fn::AggregateFnVTable pub fn vortex_array::aggregate_fn::AggregateFnVTableExt::bind(&self, options: Self::Options) -> vortex_array::aggregate_fn::AggregateFnRef diff --git a/vortex-array/src/aggregate_fn/fns/mean.rs b/vortex-array/src/aggregate_fn/fns/mean.rs new file mode 100644 index 00000000000..00516f716b7 --- /dev/null +++ b/vortex-array/src/aggregate_fn/fns/mean.rs @@ -0,0 +1,309 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_mask::Mask; + +use crate::ArrayRef; +use crate::IntoArray; +use crate::aggregate_fn::AggregateFnId; +use crate::aggregate_fn::AggregateFnVTable; +use crate::aggregate_fn::accumulator::Accumulator; +use crate::arrays::PrimitiveArray; +use crate::canonical::ToCanonical; +use crate::dtype::DType; +use crate::dtype::NativePType; +use crate::dtype::Nullability; +use crate::dtype::PType; +use crate::dtype::StructFields; +use crate::match_each_native_ptype; +use crate::scalar::Scalar; +use crate::scalar_fn::EmptyOptions; + +/// Computes the arithmetic mean of numeric values. +#[derive(Clone)] +pub struct Mean; + +impl AggregateFnVTable for Mean { + type Options = EmptyOptions; + + fn id(&self) -> AggregateFnId { + AggregateFnId::new_ref("vortex.mean") + } + + fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> VortexResult { + if !input_dtype.is_int() && !input_dtype.is_float() { + vortex_bail!("Mean requires numeric input, got {}", input_dtype); + } + Ok(DType::Primitive(PType::F64, Nullability::Nullable)) + } + + fn state_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> VortexResult { + if !input_dtype.is_int() && !input_dtype.is_float() { + vortex_bail!("Mean requires numeric input, got {}", input_dtype); + } + Ok(DType::Struct( + StructFields::from_iter([ + ( + "sum", + DType::Primitive(PType::F64, Nullability::NonNullable), + ), + ( + "count", + DType::Primitive(PType::U64, Nullability::NonNullable), + ), + ]), + Nullability::Nullable, + )) + } + + fn accumulator( + &self, + _options: &Self::Options, + input_dtype: &DType, + ) -> VortexResult> { + if !input_dtype.is_int() && !input_dtype.is_float() { + vortex_bail!("Mean requires numeric input, got {}", input_dtype); + } + Ok(Box::new(MeanAccumulator::new())) + } +} + +struct MeanAccumulator { + sum: f64, + count: u64, + results: Vec>, +} + +impl MeanAccumulator { + fn new() -> Self { + Self { + sum: 0.0, + count: 0, + results: Vec::new(), + } + } +} + +/// Accumulate all-valid values of type `T` into `sum` and `count`. +fn accumulate_all_valid(values: &[T], sum: &mut f64, count: &mut u64) { + for v in values { + *sum += v.to_f64().unwrap_or(0.0); + *count += 1; + } +} + +/// Accumulate partially-valid values of type `T` into `sum` and `count`. +fn accumulate_with_mask( + values: &[T], + mask: &vortex_mask::MaskValues, + sum: &mut f64, + count: &mut u64, +) { + for (val, valid) in values.iter().zip(mask.bit_buffer().iter()) { + if valid { + *sum += val.to_f64().unwrap_or(0.0); + *count += 1; + } + } +} + +impl Accumulator for MeanAccumulator { + fn accumulate(&mut self, batch: &ArrayRef) -> VortexResult<()> { + let primitive = batch.to_primitive(); + let validity = primitive.validity_mask()?; + + match_each_native_ptype!(primitive.ptype(), |T| { + let values = primitive.as_slice::(); + match &validity { + Mask::AllTrue(_) => accumulate_all_valid(values, &mut self.sum, &mut self.count), + Mask::AllFalse(_) => {} + Mask::Values(v) => accumulate_with_mask(values, v, &mut self.sum, &mut self.count), + } + }); + + Ok(()) + } + + fn merge(&mut self, state: &Scalar) -> VortexResult<()> { + if state.is_null() { + return Ok(()); + } + + let s = state.as_struct(); + let Some(sum_scalar) = s.field_by_idx(0) else { + vortex_bail!("Mean state struct missing sum field at index 0"); + }; + let Some(count_scalar) = s.field_by_idx(1) else { + vortex_bail!("Mean state struct missing count field at index 1"); + }; + + self.sum += sum_scalar + .as_primitive() + .typed_value::() + .unwrap_or(0.0); + self.count += count_scalar + .as_primitive() + .typed_value::() + .unwrap_or(0); + Ok(()) + } + + fn flush(&mut self) -> VortexResult<()> { + if self.count == 0 { + self.results.push(None); + } else { + self.results.push(Some(self.sum / self.count as f64)); + } + self.sum = 0.0; + self.count = 0; + Ok(()) + } + + fn finish(self: Box) -> VortexResult { + Ok(PrimitiveArray::from_option_iter(self.results).into_array()) + } +} + +#[cfg(test)] +mod tests { + use vortex_buffer::buffer; + use vortex_error::VortexResult; + + use crate::ArrayRef; + use crate::IntoArray; + use crate::aggregate_fn::AggregateFnVTable; + use crate::aggregate_fn::fns::mean::Mean; + use crate::arrays::PrimitiveArray; + use crate::dtype::DType; + use crate::dtype::Nullability; + use crate::dtype::PType; + use crate::dtype::StructFields; + use crate::scalar::Scalar; + use crate::scalar_fn::EmptyOptions; + use crate::validity::Validity; + + fn run_mean(batch: &ArrayRef) -> VortexResult { + let mut acc = Mean.accumulator(&EmptyOptions, batch.dtype())?; + acc.accumulate(batch)?; + acc.flush()?; + acc.finish() + } + + fn get_f64_value(array: &ArrayRef, idx: usize) -> VortexResult> { + let scalar = array.scalar_at(idx)?; + Ok(scalar.as_primitive().typed_value::()) + } + + #[test] + fn mean_i32() -> VortexResult<()> { + let arr = PrimitiveArray::new(buffer![1i32, 2, 3, 4], Validity::NonNullable).into_array(); + let result = run_mean(&arr)?; + assert_eq!(get_f64_value(&result, 0)?, Some(2.5)); + Ok(()) + } + + #[test] + fn mean_f64() -> VortexResult<()> { + let arr = + PrimitiveArray::new(buffer![1.0f64, 2.0, 3.0], Validity::NonNullable).into_array(); + let result = run_mean(&arr)?; + assert_eq!(get_f64_value(&result, 0)?, Some(2.0)); + Ok(()) + } + + #[test] + fn mean_with_nulls() -> VortexResult<()> { + let arr = PrimitiveArray::from_option_iter([Some(2i32), None, Some(4)]).into_array(); + let result = run_mean(&arr)?; + assert_eq!(get_f64_value(&result, 0)?, Some(3.0)); + Ok(()) + } + + #[test] + fn mean_all_null() -> VortexResult<()> { + let arr = PrimitiveArray::from_option_iter([None::, None, None]).into_array(); + let result = run_mean(&arr)?; + assert_eq!(get_f64_value(&result, 0)?, None); + Ok(()) + } + + #[test] + fn mean_empty_flush() -> VortexResult<()> { + let mut acc = Mean.accumulator( + &EmptyOptions, + &DType::Primitive(PType::I32, Nullability::NonNullable), + )?; + acc.flush()?; + let result = acc.finish()?; + assert_eq!(get_f64_value(&result, 0)?, None); + Ok(()) + } + + #[test] + fn mean_multi_group() -> VortexResult<()> { + let mut acc = Mean.accumulator( + &EmptyOptions, + &DType::Primitive(PType::I32, Nullability::NonNullable), + )?; + + let batch1 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array(); + acc.accumulate(&batch1)?; + acc.flush()?; + + let batch2 = PrimitiveArray::new(buffer![3i32, 6, 9], Validity::NonNullable).into_array(); + acc.accumulate(&batch2)?; + acc.flush()?; + + let result = acc.finish()?; + assert_eq!(get_f64_value(&result, 0)?, Some(15.0)); + assert_eq!(get_f64_value(&result, 1)?, Some(6.0)); + Ok(()) + } + + #[test] + fn mean_merge() -> VortexResult<()> { + let mut acc = Mean.accumulator( + &EmptyOptions, + &DType::Primitive(PType::I32, Nullability::NonNullable), + )?; + + let state_dtype = DType::Struct( + StructFields::from_iter([ + ( + "sum", + DType::Primitive(PType::F64, Nullability::NonNullable), + ), + ( + "count", + DType::Primitive(PType::U64, Nullability::NonNullable), + ), + ]), + Nullability::Nullable, + ); + + let state = Scalar::struct_( + state_dtype.clone(), + vec![ + Scalar::primitive(30.0f64, Nullability::NonNullable), + Scalar::primitive(3u64, Nullability::NonNullable), + ], + ); + acc.merge(&state)?; + + let state2 = Scalar::struct_( + state_dtype, + vec![ + Scalar::primitive(20.0f64, Nullability::NonNullable), + Scalar::primitive(2u64, Nullability::NonNullable), + ], + ); + acc.merge(&state2)?; + + acc.flush()?; + let result = acc.finish()?; + assert_eq!(get_f64_value(&result, 0)?, Some(10.0)); + Ok(()) + } +} diff --git a/vortex-array/src/aggregate_fn/fns/mod.rs b/vortex-array/src/aggregate_fn/fns/mod.rs index 0d735177e5d..ce44ea21406 100644 --- a/vortex-array/src/aggregate_fn/fns/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/mod.rs @@ -1,2 +1,4 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors + +pub mod mean; diff --git a/vortex-array/src/aggregate_fn/session.rs b/vortex-array/src/aggregate_fn/session.rs index 4e7bd12937b..3918b2a9c53 100644 --- a/vortex-array/src/aggregate_fn/session.rs +++ b/vortex-array/src/aggregate_fn/session.rs @@ -9,16 +9,27 @@ use vortex_session::registry::Registry; use crate::aggregate_fn::AggregateFnPluginRef; use crate::aggregate_fn::AggregateFnVTable; +use crate::aggregate_fn::fns::mean::Mean; /// Registry of aggregate function vtables. pub type AggregateFnRegistry = Registry; /// Session state for aggregate function vtables. -#[derive(Debug, Default)] +#[derive(Debug)] pub struct AggregateFnSession { registry: AggregateFnRegistry, } +impl Default for AggregateFnSession { + fn default() -> Self { + let session = Self { + registry: AggregateFnRegistry::default(), + }; + session.register(Mean); + session + } +} + impl AggregateFnSession { /// Returns the aggregate function registry. pub fn registry(&self) -> &AggregateFnRegistry {