From b7d835a04504f323318650a01444952d9b09c77a Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Sat, 28 Feb 2026 12:22:49 -0500 Subject: [PATCH 1/2] Aggregate Fns: Sum Signed-off-by: Nicholas Gates --- vortex-array/src/aggregate_fn/fns/mod.rs | 1 + .../aggregate_fn/fns/sum/bool_accumulator.rs | 124 +++++++++ .../aggregate_fn/fns/sum/float_accumulator.rs | 118 ++++++++ .../aggregate_fn/fns/sum/int_accumulator.rs | 185 +++++++++++++ vortex-array/src/aggregate_fn/fns/sum/mod.rs | 106 +++++++ .../src/aggregate_fn/fns/sum/tests.rs | 261 ++++++++++++++++++ vortex-array/src/aggregate_fn/session.rs | 2 + 7 files changed, 797 insertions(+) create mode 100644 vortex-array/src/aggregate_fn/fns/sum/bool_accumulator.rs create mode 100644 vortex-array/src/aggregate_fn/fns/sum/float_accumulator.rs create mode 100644 vortex-array/src/aggregate_fn/fns/sum/int_accumulator.rs create mode 100644 vortex-array/src/aggregate_fn/fns/sum/mod.rs create mode 100644 vortex-array/src/aggregate_fn/fns/sum/tests.rs diff --git a/vortex-array/src/aggregate_fn/fns/mod.rs b/vortex-array/src/aggregate_fn/fns/mod.rs index ce44ea21406..777480d0e7a 100644 --- a/vortex-array/src/aggregate_fn/fns/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/mod.rs @@ -2,3 +2,4 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors pub mod mean; +pub mod sum; diff --git a/vortex-array/src/aggregate_fn/fns/sum/bool_accumulator.rs b/vortex-array/src/aggregate_fn/fns/sum/bool_accumulator.rs new file mode 100644 index 00000000000..cd15b04cbc1 --- /dev/null +++ b/vortex-array/src/aggregate_fn/fns/sum/bool_accumulator.rs @@ -0,0 +1,124 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::ops::BitAnd; + +use vortex_error::VortexResult; +use vortex_mask::Mask; + +use crate::ArrayRef; +use crate::IntoArray; +use crate::aggregate_fn::accumulator::Accumulator; +use crate::arrays::PrimitiveArray; +use crate::canonical::ToCanonical; +use crate::scalar::Scalar; + +/// Accumulator that sums boolean values by counting `true` as 1 and `false` as 0. +/// +/// Output type is `u64` (nullable). Overflow is theoretically possible but extremely +/// unlikely since it would require more than `u64::MAX` true values. +pub(super) struct BoolSumAccumulator { + count: u64, + /// Whether at least one non-null value has been accumulated. + has_values: bool, + /// Whether accumulate() or merge() has been called at all (even with all-null data). + has_input: bool, + checked: bool, + overflowed: bool, + results: Vec>, +} + +impl BoolSumAccumulator { + pub(super) fn new(checked: bool) -> Self { + Self { + count: 0, + has_values: false, + has_input: false, + checked, + overflowed: false, + results: Vec::new(), + } + } +} + +impl Accumulator for BoolSumAccumulator { + fn accumulate(&mut self, batch: &ArrayRef) -> VortexResult<()> { + self.has_input = true; + if self.overflowed { + return Ok(()); + } + + let bool_array = batch.to_bool(); + let validity = bool_array.validity_mask()?; + + let true_count = match &validity { + Mask::AllTrue(_) => bool_array.to_bit_buffer().true_count() as u64, + Mask::AllFalse(_) => return Ok(()), + Mask::Values(v) => bool_array + .to_bit_buffer() + .bitand(v.bit_buffer()) + .true_count() as u64, + }; + + self.has_values = true; + if self.checked { + if let Some(new_count) = self.count.checked_add(true_count) { + self.count = new_count; + } else { + self.overflowed = true; + } + } else { + self.count = self.count.wrapping_add(true_count); + } + + Ok(()) + } + + fn merge(&mut self, state: &Scalar) -> VortexResult<()> { + if state.is_null() { + return Ok(()); + } + self.has_input = true; + if let Some(v) = state.as_primitive().typed_value::() { + self.has_values = true; + if self.checked { + if let Some(new_count) = self.count.checked_add(v) { + self.count = new_count; + } else { + self.overflowed = true; + } + } else { + self.count = self.count.wrapping_add(v); + } + } + Ok(()) + } + + fn is_saturated(&self) -> bool { + self.checked && self.overflowed + } + + fn flush(&mut self) -> VortexResult<()> { + let result = if self.overflowed { + None + } else if self.has_values { + Some(self.count) + } else if self.has_input { + // All-null group. + None + } else { + // Empty group: identity is zero. + Some(0) + }; + self.results.push(result); + self.count = 0; + self.has_values = false; + self.has_input = false; + self.overflowed = false; + Ok(()) + } + + fn finish(self: Box) -> VortexResult { + Ok(PrimitiveArray::from_option_iter(self.results).into_array()) + } +} diff --git a/vortex-array/src/aggregate_fn/fns/sum/float_accumulator.rs b/vortex-array/src/aggregate_fn/fns/sum/float_accumulator.rs new file mode 100644 index 00000000000..245708e8799 --- /dev/null +++ b/vortex-array/src/aggregate_fn/fns/sum/float_accumulator.rs @@ -0,0 +1,118 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_error::VortexResult; +use vortex_mask::Mask; + +use crate::ArrayRef; +use crate::IntoArray; +use crate::aggregate_fn::accumulator::Accumulator; +use crate::arrays::PrimitiveArray; +use crate::canonical::ToCanonical; +use crate::dtype::NativePType; +use crate::match_each_native_ptype; +use crate::scalar::Scalar; + +pub(super) struct FloatSumAccumulator { + sum: f64, + /// Whether at least one non-null value has been accumulated. + has_values: bool, + /// Whether accumulate() or merge() has been called at all (even with all-null data). + has_input: bool, + results: Vec>, +} + +impl FloatSumAccumulator { + pub(super) fn new() -> Self { + Self { + sum: 0.0, + has_values: false, + has_input: false, + results: Vec::new(), + } + } +} + +fn accumulate_all_valid(values: &[T], sum: &mut f64, has_values: &mut bool) { + for v in values { + *has_values = true; + *sum += v.to_f64().unwrap_or(0.0); + } +} + +fn accumulate_with_mask( + values: &[T], + mask: &vortex_mask::MaskValues, + sum: &mut f64, + has_values: &mut bool, +) { + for (v, valid) in values.iter().zip(mask.bit_buffer().iter()) { + if valid { + *has_values = true; + *sum += v.to_f64().unwrap_or(0.0); + } + } +} + +impl Accumulator for FloatSumAccumulator { + fn accumulate(&mut self, batch: &ArrayRef) -> VortexResult<()> { + self.has_input = true; + let primitive = batch.to_primitive(); + let validity = primitive.validity_mask()?; + + match_each_native_ptype!(primitive.ptype(), integral: |_T| { + unreachable!("FloatSumAccumulator should not be used with integer types"); + }, floating: |T| { + let values = primitive.as_slice::(); + match &validity { + Mask::AllTrue(_) => accumulate_all_valid( + values, + &mut self.sum, + &mut self.has_values, + ), + Mask::AllFalse(_) => {} + Mask::Values(v) => accumulate_with_mask( + values, + v, + &mut self.sum, + &mut self.has_values, + ), + } + }); + + Ok(()) + } + + fn merge(&mut self, state: &Scalar) -> VortexResult<()> { + if state.is_null() { + return Ok(()); + } + self.has_input = true; + if let Some(v) = state.as_primitive().typed_value::() { + self.has_values = true; + self.sum += v; + } + Ok(()) + } + + fn flush(&mut self) -> VortexResult<()> { + let result = if self.has_values { + Some(self.sum) + } else if self.has_input { + // All-null group. + None + } else { + // Empty group: identity is zero. + Some(0.0) + }; + self.results.push(result); + self.sum = 0.0; + self.has_values = false; + self.has_input = false; + Ok(()) + } + + fn finish(self: Box) -> VortexResult { + Ok(PrimitiveArray::from_option_iter(self.results).into_array()) + } +} diff --git a/vortex-array/src/aggregate_fn/fns/sum/int_accumulator.rs b/vortex-array/src/aggregate_fn/fns/sum/int_accumulator.rs new file mode 100644 index 00000000000..85dfafe38d4 --- /dev/null +++ b/vortex-array/src/aggregate_fn/fns/sum/int_accumulator.rs @@ -0,0 +1,185 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use num_traits::CheckedAdd; +use num_traits::ToPrimitive; +use num_traits::WrappingAdd; +use vortex_error::VortexResult; +use vortex_mask::Mask; + +use crate::ArrayRef; +use crate::IntoArray; +use crate::aggregate_fn::accumulator::Accumulator; +use crate::arrays::PrimitiveArray; +use crate::canonical::ToCanonical; +use crate::dtype::NativePType; +use crate::match_each_native_ptype; +use crate::scalar::Scalar; + +pub(super) struct IntSumAccumulator { + sum: R, + overflowed: bool, + /// Whether at least one non-null value has been accumulated. + has_values: bool, + /// Whether accumulate() or merge() has been called at all (even with all-null data). + has_input: bool, + checked: bool, + results: Vec>, +} + +impl IntSumAccumulator { + pub(super) fn new(checked: bool) -> Self { + Self { + sum: R::default(), + overflowed: false, + has_values: false, + has_input: false, + checked, + results: Vec::new(), + } + } +} + +fn accumulate_all_valid( + values: &[T], + sum: &mut R, + overflowed: &mut bool, + has_values: &mut bool, + checked: bool, +) { + for &v in values { + if *overflowed { + return; + } + *has_values = true; + if checked { + let Some(widened) = R::from(v) else { + *overflowed = true; + return; + }; + let Some(new_sum) = sum.checked_add(&widened) else { + *overflowed = true; + return; + }; + *sum = new_sum; + } else { + let widened = R::from(v).unwrap_or_default(); + *sum = sum.wrapping_add(&widened); + } + } +} + +fn accumulate_with_mask( + values: &[T], + mask: &vortex_mask::MaskValues, + sum: &mut R, + overflowed: &mut bool, + has_values: &mut bool, + checked: bool, +) { + for (&v, valid) in values.iter().zip(mask.bit_buffer().iter()) { + if *overflowed { + return; + } + if valid { + *has_values = true; + if checked { + let Some(widened) = R::from(v) else { + *overflowed = true; + return; + }; + let Some(new_sum) = sum.checked_add(&widened) else { + *overflowed = true; + return; + }; + *sum = new_sum; + } else { + let widened = R::from(v).unwrap_or_default(); + *sum = sum.wrapping_add(&widened); + } + } + } +} + +impl Accumulator for IntSumAccumulator { + fn accumulate(&mut self, batch: &ArrayRef) -> VortexResult<()> { + self.has_input = true; + let primitive = batch.to_primitive(); + let validity = primitive.validity_mask()?; + + match_each_native_ptype!(primitive.ptype(), integral: |T| { + let values = primitive.as_slice::(); + match &validity { + Mask::AllTrue(_) => accumulate_all_valid( + values, + &mut self.sum, + &mut self.overflowed, + &mut self.has_values, + self.checked, + ), + Mask::AllFalse(_) => {} + Mask::Values(v) => accumulate_with_mask( + values, + v, + &mut self.sum, + &mut self.overflowed, + &mut self.has_values, + self.checked, + ), + } + }, floating: |_T| { + unreachable!("IntSumAccumulator should not be used with floating-point types"); + }); + + Ok(()) + } + + fn merge(&mut self, state: &Scalar) -> VortexResult<()> { + if state.is_null() { + return Ok(()); + } + self.has_input = true; + let val = state.as_primitive().typed_value::(); + if let Some(v) = val { + self.has_values = true; + if self.checked { + if let Some(new_sum) = self.sum.checked_add(&v) { + self.sum = new_sum; + } else { + self.overflowed = true; + } + } else { + self.sum = self.sum.wrapping_add(&v); + } + } + Ok(()) + } + + fn is_saturated(&self) -> bool { + self.checked && self.overflowed + } + + fn flush(&mut self) -> VortexResult<()> { + let result = if self.overflowed { + None + } else if self.has_values { + Some(self.sum) + } else if self.has_input { + // All-null group: no non-null values were seen. + None + } else { + // Empty group: no accumulate/merge calls at all. Identity is zero. + Some(R::default()) + }; + self.results.push(result); + self.sum = R::default(); + self.overflowed = false; + self.has_values = false; + self.has_input = false; + Ok(()) + } + + fn finish(self: Box) -> VortexResult { + Ok(PrimitiveArray::from_option_iter(self.results).into_array()) + } +} diff --git a/vortex-array/src/aggregate_fn/fns/sum/mod.rs b/vortex-array/src/aggregate_fn/fns/sum/mod.rs new file mode 100644 index 00000000000..ca0b39c3ece --- /dev/null +++ b/vortex-array/src/aggregate_fn/fns/sum/mod.rs @@ -0,0 +1,106 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +mod bool_accumulator; +mod float_accumulator; +mod int_accumulator; + +use std::fmt::Display; +use std::fmt::Formatter; + +use vortex_error::VortexResult; +use vortex_error::vortex_bail; + +use self::bool_accumulator::BoolSumAccumulator; +use self::float_accumulator::FloatSumAccumulator; +use self::int_accumulator::IntSumAccumulator; +use crate::aggregate_fn::AggregateFnId; +use crate::aggregate_fn::AggregateFnVTable; +use crate::aggregate_fn::accumulator::Accumulator; +use crate::dtype::DType; +use crate::dtype::Nullability; +use crate::dtype::PType; + +/// Options for the Sum aggregate function. +#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct SumOptions { + /// Whether to use checked arithmetic (default: `true`). + /// + /// When `true`, integer overflow produces a null result. + /// When `false`, integer overflow wraps around. + /// + /// Note that i64/u64 inputs can still overflow even with type widening, + /// since they are already at the widest integer type. + pub checked: bool, +} + +impl Display for SumOptions { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + if self.checked { + write!(f, "SUM(checked)") + } else { + write!(f, "SUM(wrapping)") + } + } +} + +/// Maps an input PType to the widened output PType for sum. +fn sum_output_ptype(ptype: PType) -> PType { + match ptype { + PType::U8 | PType::U16 | PType::U32 | PType::U64 => PType::U64, + PType::I8 | PType::I16 | PType::I32 | PType::I64 => PType::I64, + PType::F16 | PType::F32 | PType::F64 => PType::F64, + } +} + +/// Computes the sum of numeric or boolean values. +/// +/// For primitive numeric types, the output is widened (unsigned -> u64, signed -> i64, +/// float -> f64). For boolean inputs, `true` counts as 1 and `false` as 0, producing +/// a u64 output. +#[derive(Clone)] +pub struct Sum; + +impl AggregateFnVTable for Sum { + type Options = SumOptions; + + fn id(&self) -> AggregateFnId { + AggregateFnId::new_ref("vortex.sum") + } + + fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> VortexResult { + match input_dtype { + DType::Bool(_) => Ok(DType::Primitive(PType::U64, Nullability::Nullable)), + DType::Primitive(p, _) => Ok(DType::Primitive( + sum_output_ptype(*p), + Nullability::Nullable, + )), + _ => vortex_bail!("Sum requires numeric or boolean input, got {}", input_dtype), + } + } + + fn state_dtype(&self, options: &Self::Options, input_dtype: &DType) -> VortexResult { + self.return_dtype(options, input_dtype) + } + + fn accumulator( + &self, + options: &Self::Options, + input_dtype: &DType, + ) -> VortexResult> { + let checked = options.checked; + match input_dtype { + DType::Bool(_) => Ok(Box::new(BoolSumAccumulator::new(checked))), + DType::Primitive(p, _) => match sum_output_ptype(*p) { + PType::U64 => Ok(Box::new(IntSumAccumulator::::new(checked))), + PType::I64 => Ok(Box::new(IntSumAccumulator::::new(checked))), + PType::F64 => Ok(Box::new(FloatSumAccumulator::new())), + _ => unreachable!(), + }, + _ => vortex_bail!("Sum requires numeric or boolean input, got {}", input_dtype), + } + } +} + +#[cfg(test)] +mod tests; diff --git a/vortex-array/src/aggregate_fn/fns/sum/tests.rs b/vortex-array/src/aggregate_fn/fns/sum/tests.rs new file mode 100644 index 00000000000..afc1bde78c9 --- /dev/null +++ b/vortex-array/src/aggregate_fn/fns/sum/tests.rs @@ -0,0 +1,261 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_buffer::buffer; +use vortex_error::VortexResult; + +use crate::ArrayRef; +use crate::IntoArray; +use crate::aggregate_fn::AggregateFnVTable; +use crate::aggregate_fn::fns::sum::Sum; +use crate::aggregate_fn::fns::sum::SumOptions; +use crate::arrays::BoolArray; +use crate::arrays::PrimitiveArray; +use crate::dtype::DType; +use crate::dtype::Nullability; +use crate::dtype::PType; +use crate::scalar::Scalar; +use crate::validity::Validity; + +fn checked_opts() -> SumOptions { + SumOptions { checked: true } +} + +fn unchecked_opts() -> SumOptions { + SumOptions { checked: false } +} + +fn run_sum(batch: &ArrayRef, options: &SumOptions) -> VortexResult { + let mut acc = Sum.accumulator(options, batch.dtype())?; + acc.accumulate(batch)?; + acc.flush()?; + acc.finish() +} + +fn get_i64_value(array: &ArrayRef, idx: usize) -> VortexResult> { + let scalar = array.scalar_at(idx)?; + Ok(scalar.as_primitive().typed_value::()) +} + +fn get_u64_value(array: &ArrayRef, idx: usize) -> VortexResult> { + let scalar = array.scalar_at(idx)?; + Ok(scalar.as_primitive().typed_value::()) +} + +fn get_f64_value(array: &ArrayRef, idx: usize) -> VortexResult> { + let scalar = array.scalar_at(idx)?; + Ok(scalar.as_primitive().typed_value::()) +} + +#[test] +fn sum_i32() -> VortexResult<()> { + let arr = PrimitiveArray::new(buffer![1i32, 2, 3, 4], Validity::NonNullable).into_array(); + let result = run_sum(&arr, &checked_opts())?; + assert_eq!(get_i64_value(&result, 0)?, Some(10)); + Ok(()) +} + +#[test] +fn sum_u8() -> VortexResult<()> { + let arr = PrimitiveArray::new(buffer![10u8, 20, 30], Validity::NonNullable).into_array(); + let result = run_sum(&arr, &checked_opts())?; + assert_eq!(get_u64_value(&result, 0)?, Some(60)); + Ok(()) +} + +#[test] +fn sum_f64() -> VortexResult<()> { + let arr = PrimitiveArray::new(buffer![1.5f64, 2.5, 3.0], Validity::NonNullable).into_array(); + let result = run_sum(&arr, &checked_opts())?; + assert_eq!(get_f64_value(&result, 0)?, Some(7.0)); + Ok(()) +} + +#[test] +fn sum_with_nulls() -> VortexResult<()> { + let arr = PrimitiveArray::from_option_iter([Some(2i32), None, Some(4)]).into_array(); + let result = run_sum(&arr, &checked_opts())?; + assert_eq!(get_i64_value(&result, 0)?, Some(6)); + Ok(()) +} + +#[test] +fn sum_all_null() -> VortexResult<()> { + let arr = PrimitiveArray::from_option_iter([None::, None, None]).into_array(); + let result = run_sum(&arr, &checked_opts())?; + assert_eq!(get_i64_value(&result, 0)?, None); + Ok(()) +} + +#[test] +fn sum_empty_flush_produces_zero() -> VortexResult<()> { + let mut acc = Sum.accumulator( + &checked_opts(), + &DType::Primitive(PType::I32, Nullability::NonNullable), + )?; + acc.flush()?; + let result = acc.finish()?; + assert_eq!(get_i64_value(&result, 0)?, Some(0)); + Ok(()) +} + +#[test] +fn sum_empty_flush_f64_produces_zero() -> VortexResult<()> { + let mut acc = Sum.accumulator( + &checked_opts(), + &DType::Primitive(PType::F64, Nullability::NonNullable), + )?; + acc.flush()?; + let result = acc.finish()?; + assert_eq!(get_f64_value(&result, 0)?, Some(0.0)); + Ok(()) +} + +#[test] +fn sum_multi_group() -> VortexResult<()> { + let mut acc = Sum.accumulator( + &checked_opts(), + &DType::Primitive(PType::I32, Nullability::NonNullable), + )?; + + let batch1 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array(); + acc.accumulate(&batch1)?; + acc.flush()?; + + let batch2 = PrimitiveArray::new(buffer![3i32, 6, 9], Validity::NonNullable).into_array(); + acc.accumulate(&batch2)?; + acc.flush()?; + + let result = acc.finish()?; + assert_eq!(get_i64_value(&result, 0)?, Some(30)); + assert_eq!(get_i64_value(&result, 1)?, Some(18)); + Ok(()) +} + +#[test] +fn sum_merge() -> VortexResult<()> { + let mut acc = Sum.accumulator( + &checked_opts(), + &DType::Primitive(PType::I32, Nullability::NonNullable), + )?; + + let state1 = Scalar::primitive(100i64, Nullability::Nullable); + acc.merge(&state1)?; + + let state2 = Scalar::primitive(50i64, Nullability::Nullable); + acc.merge(&state2)?; + + acc.flush()?; + let result = acc.finish()?; + assert_eq!(get_i64_value(&result, 0)?, Some(150)); + Ok(()) +} + +#[test] +fn sum_checked_overflow() -> VortexResult<()> { + let arr = PrimitiveArray::new(buffer![i64::MAX, 1i64], Validity::NonNullable).into_array(); + let result = run_sum(&arr, &checked_opts())?; + assert_eq!(get_i64_value(&result, 0)?, None); + Ok(()) +} + +#[test] +fn sum_checked_overflow_is_saturated() -> VortexResult<()> { + let mut acc = Sum.accumulator( + &checked_opts(), + &DType::Primitive(PType::I64, Nullability::NonNullable), + )?; + assert!(!acc.is_saturated()); + + let batch = PrimitiveArray::new(buffer![i64::MAX, 1i64], Validity::NonNullable).into_array(); + acc.accumulate(&batch)?; + assert!(acc.is_saturated()); + + acc.flush()?; + assert!(!acc.is_saturated()); + Ok(()) +} + +#[test] +fn sum_unchecked_wrapping() -> VortexResult<()> { + let arr = PrimitiveArray::new(buffer![i64::MAX, 1i64], Validity::NonNullable).into_array(); + let result = run_sum(&arr, &unchecked_opts())?; + assert_eq!(get_i64_value(&result, 0)?, Some(i64::MAX.wrapping_add(1))); + Ok(()) +} + +// Boolean sum tests + +#[test] +fn sum_bool_all_true() -> VortexResult<()> { + let arr: BoolArray = [true, true, true].into_iter().collect(); + let result = run_sum(&arr.into_array(), &checked_opts())?; + assert_eq!(get_u64_value(&result, 0)?, Some(3)); + Ok(()) +} + +#[test] +fn sum_bool_mixed() -> VortexResult<()> { + let arr: BoolArray = [true, false, true, false, true].into_iter().collect(); + let result = run_sum(&arr.into_array(), &checked_opts())?; + assert_eq!(get_u64_value(&result, 0)?, Some(3)); + Ok(()) +} + +#[test] +fn sum_bool_all_false() -> VortexResult<()> { + let arr: BoolArray = [false, false, false].into_iter().collect(); + let result = run_sum(&arr.into_array(), &checked_opts())?; + assert_eq!(get_u64_value(&result, 0)?, Some(0)); + Ok(()) +} + +#[test] +fn sum_bool_with_nulls() -> VortexResult<()> { + let arr = BoolArray::from_iter([Some(true), None, Some(true), Some(false)]); + let result = run_sum(&arr.into_array(), &checked_opts())?; + assert_eq!(get_u64_value(&result, 0)?, Some(2)); + Ok(()) +} + +#[test] +fn sum_bool_all_null() -> VortexResult<()> { + let arr = BoolArray::from_iter([None::, None, None]); + let result = run_sum(&arr.into_array(), &checked_opts())?; + assert_eq!(get_u64_value(&result, 0)?, None); + Ok(()) +} + +#[test] +fn sum_bool_empty_flush_produces_zero() -> VortexResult<()> { + let mut acc = Sum.accumulator(&checked_opts(), &DType::Bool(Nullability::NonNullable))?; + acc.flush()?; + let result = acc.finish()?; + assert_eq!(get_u64_value(&result, 0)?, Some(0)); + Ok(()) +} + +#[test] +fn sum_bool_multi_group() -> VortexResult<()> { + let mut acc = Sum.accumulator(&checked_opts(), &DType::Bool(Nullability::NonNullable))?; + + let batch1: BoolArray = [true, true, false].into_iter().collect(); + acc.accumulate(&batch1.into_array())?; + acc.flush()?; + + let batch2: BoolArray = [false, true].into_iter().collect(); + acc.accumulate(&batch2.into_array())?; + acc.flush()?; + + let result = acc.finish()?; + assert_eq!(get_u64_value(&result, 0)?, Some(2)); + assert_eq!(get_u64_value(&result, 1)?, Some(1)); + Ok(()) +} + +#[test] +fn sum_bool_return_dtype() -> VortexResult<()> { + let dtype = Sum.return_dtype(&checked_opts(), &DType::Bool(Nullability::NonNullable))?; + assert_eq!(dtype, DType::Primitive(PType::U64, Nullability::Nullable)); + Ok(()) +} diff --git a/vortex-array/src/aggregate_fn/session.rs b/vortex-array/src/aggregate_fn/session.rs index 3918b2a9c53..6ec6a10b735 100644 --- a/vortex-array/src/aggregate_fn/session.rs +++ b/vortex-array/src/aggregate_fn/session.rs @@ -10,6 +10,7 @@ use vortex_session::registry::Registry; use crate::aggregate_fn::AggregateFnPluginRef; use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::fns::mean::Mean; +use crate::aggregate_fn::fns::sum::Sum; /// Registry of aggregate function vtables. pub type AggregateFnRegistry = Registry; @@ -26,6 +27,7 @@ impl Default for AggregateFnSession { registry: AggregateFnRegistry::default(), }; session.register(Mean); + session.register(Sum); session } } From dc29d55df137f4b7562ace1a447a8ce6ad54c48f Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Sat, 28 Feb 2026 12:24:03 -0500 Subject: [PATCH 2/2] Aggregate Fns: Sum Signed-off-by: Nicholas Gates --- vortex-array/src/aggregate_fn/fns/sum/mod.rs | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/vortex-array/src/aggregate_fn/fns/sum/mod.rs b/vortex-array/src/aggregate_fn/fns/sum/mod.rs index ca0b39c3ece..b1b21c6e9a6 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/mod.rs @@ -57,7 +57,20 @@ fn sum_output_ptype(ptype: PType) -> PType { /// /// For primitive numeric types, the output is widened (unsigned -> u64, signed -> i64, /// float -> f64). For boolean inputs, `true` counts as 1 and `false` as 0, producing -/// a u64 output. +/// a u64 output. All output types are nullable. +/// +/// # Flush semantics +/// +/// - **Empty group** (flush with no prior accumulate/merge): produces **zero** (the additive +/// identity). +/// - **All-null group** (accumulate called but every value was null): produces **null**. +/// - **Checked overflow**: produces **null**, and `is_saturated()` returns true so callers can +/// skip further accumulation. +/// +/// Note: this differs from the [`compute::sum`](crate::compute::sum) function, which treats +/// both empty and all-null arrays identically (returning the zero accumulator). The aggregate +/// distinguishes the two cases because all-null is semantically "unknown" while empty is +/// "no data, so the identity applies". #[derive(Clone)] pub struct Sum;