diff --git a/vortex-array/Cargo.toml b/vortex-array/Cargo.toml
index 22feac3d943..e915c8e2002 100644
--- a/vortex-array/Cargo.toml
+++ b/vortex-array/Cargo.toml
@@ -127,6 +127,11 @@ name = "expr_large_struct_pack"
 path = "benches/expr/large_struct_pack.rs"
 harness = false
 
+[[bench]]
+name = "expr_case_when"
+path = "benches/expr/case_when_bench.rs"
+harness = false
+
 [[bench]]
 name = "chunked_dict_builder"
 harness = false
diff --git a/vortex-array/benches/expr/case_when_bench.rs b/vortex-array/benches/expr/case_when_bench.rs
new file mode 100644
index 00000000000..e25ad180fa2
--- /dev/null
+++ b/vortex-array/benches/expr/case_when_bench.rs
@@ -0,0 +1,212 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright the Vortex contributors
+
+#![allow(clippy::unwrap_used)]
+#![allow(clippy::cast_possible_truncation)]
+
+use std::sync::LazyLock;
+
+use divan::Bencher;
+use vortex_array::ArrayRef;
+use vortex_array::Canonical;
+use vortex_array::IntoArray;
+use vortex_array::VortexSessionExecute;
+use vortex_array::arrays::StructArray;
+use vortex_array::expr::case_when;
+use vortex_array::expr::case_when_no_else;
+use vortex_array::expr::eq;
+use vortex_array::expr::get_item;
+use vortex_array::expr::gt;
+use vortex_array::expr::lit;
+use vortex_array::expr::nested_case_when;
+use vortex_array::expr::root;
+use vortex_array::session::ArraySession;
+use vortex_buffer::Buffer;
+use vortex_session::VortexSession;
+
+/// Session shared by every benchmark, built lazily from an empty session plus `ArraySession`.
+static SESSION: LazyLock<VortexSession> =
+    LazyLock::new(|| VortexSession::empty().with::<ArraySession>());
+
+fn main() {
+    divan::main();
+}
+
+/// Builds a struct array with a single `value` field holding the integers `0..size` as `i32`.
+fn make_struct_array(size: usize) -> ArrayRef {
+    let data: Buffer<i32> = (0..size as i32).collect();
+    let field = data.into_array();
+    StructArray::from_fields(&[("value", field)])
+        .unwrap()
+        .into_array()
+}
+
+/// Benchmark a simple binary CASE WHEN with varying array sizes.
+#[divan::bench(args = [1000, 10000, 100000])]
+fn case_when_simple(bencher: Bencher, size: usize) {
+    let array = make_struct_array(size);
+
+    // CASE WHEN value > 500 THEN 100 ELSE 0 END
+    let expr = case_when(
+        gt(get_item("value", root()), lit(500i32)),
+        lit(100i32),
+        lit(0i32),
+    );
+
+    bencher
+        .with_inputs(|| (&expr, &array))
+        .bench_refs(|(expr, array)| {
+            let mut ctx = SESSION.create_execution_ctx();
+            array
+                .apply(expr)
+                .unwrap()
+                .execute::<Canonical>(&mut ctx)
+                .unwrap()
+        });
+}
+
+/// Benchmark n-ary CASE WHEN with 3 conditions.
+#[divan::bench(args = [1000, 10000, 100000])]
+fn case_when_nary_3_conditions(bencher: Bencher, size: usize) {
+    let array = make_struct_array(size);
+
+    // CASE WHEN value > 750 THEN 3 WHEN value > 500 THEN 2 WHEN value > 250 THEN 1 ELSE 0 END
+    let expr = nested_case_when(
+        vec![
+            (gt(get_item("value", root()), lit(750i32)), lit(3i32)),
+            (gt(get_item("value", root()), lit(500i32)), lit(2i32)),
+            (gt(get_item("value", root()), lit(250i32)), lit(1i32)),
+        ],
+        Some(lit(0i32)),
+    );
+
+    bencher
+        .with_inputs(|| (&expr, &array))
+        .bench_refs(|(expr, array)| {
+            let mut ctx = SESSION.create_execution_ctx();
+            array
+                .apply(expr)
+                .unwrap()
+                .execute::<Canonical>(&mut ctx)
+                .unwrap()
+        });
+}
+
+/// Benchmark n-ary CASE WHEN with 10 ascending `>` thresholds (first WHEN matches ~90% of rows).
+#[divan::bench(args = [1000, 10000, 100000])]
+fn case_when_nary_10_conditions(bencher: Bencher, size: usize) {
+    let array = make_struct_array(size);
+
+    let pairs: Vec<_> = (0..10)
+        .map(|i| {
+            let threshold = (i + 1) * (size as i32 / 10);
+            (
+                gt(get_item("value", root()), lit(threshold)),
+                lit((i + 1) * 100),
+            )
+        })
+        .collect();
+    let expr = nested_case_when(pairs, Some(lit(0i32)));
+
+    bencher
+        .with_inputs(|| (&expr, &array))
+        .bench_refs(|(expr, array)| {
+            let mut ctx = SESSION.create_execution_ctx();
+            array
+                .apply(expr)
+                .unwrap()
+                .execute::<Canonical>(&mut ctx)
+                .unwrap()
+        });
+}
+
+/// Benchmark n-ary CASE WHEN with equality conditions (lookup-table style).
+#[divan::bench(args = [1000, 10000, 100000])]
+fn case_when_nary_equality_lookup(bencher: Bencher, size: usize) {
+    let array = make_struct_array(size);
+
+    // Map specific values: CASE WHEN value = 0 THEN 0 WHEN value = 1 THEN 10 ... ELSE -1 END
+    let pairs: Vec<_> = (0..5)
+        .map(|i| (eq(get_item("value", root()), lit(i)), lit(i * 10)))
+        .collect();
+    let expr = nested_case_when(pairs, Some(lit(-1i32)));
+
+    bencher
+        .with_inputs(|| (&expr, &array))
+        .bench_refs(|(expr, array)| {
+            let mut ctx = SESSION.create_execution_ctx();
+            array
+                .apply(expr)
+                .unwrap()
+                .execute::<Canonical>(&mut ctx)
+                .unwrap()
+        });
+}
+
+/// Benchmark CASE WHEN without ELSE clause (result is nullable).
+#[divan::bench(args = [1000, 10000, 100000])]
+fn case_when_without_else(bencher: Bencher, size: usize) {
+    let array = make_struct_array(size);
+
+    // CASE WHEN value > 500 THEN 100 END
+    let expr = case_when_no_else(gt(get_item("value", root()), lit(500i32)), lit(100i32));
+
+    bencher
+        .with_inputs(|| (&expr, &array))
+        .bench_refs(|(expr, array)| {
+            let mut ctx = SESSION.create_execution_ctx();
+            array
+                .apply(expr)
+                .unwrap()
+                .execute::<Canonical>(&mut ctx)
+                .unwrap()
+        });
+}
+
+/// Benchmark CASE WHEN where all conditions are true.
+#[divan::bench(args = [1000, 10000, 100000])]
+fn case_when_all_true(bencher: Bencher, size: usize) {
+    let array = make_struct_array(size);
+
+    // CASE WHEN value > -1 THEN 100 ELSE 0 END (always true for our data)
+    let expr = case_when(
+        gt(get_item("value", root()), lit(-1i32)),
+        lit(100i32),
+        lit(0i32),
+    );
+
+    bencher
+        .with_inputs(|| (&expr, &array))
+        .bench_refs(|(expr, array)| {
+            let mut ctx = SESSION.create_execution_ctx();
+            array
+                .apply(expr)
+                .unwrap()
+                .execute::<Canonical>(&mut ctx)
+                .unwrap()
+        });
+}
+
+/// Benchmark CASE WHEN where all conditions are false.
+#[divan::bench(args = [1000, 10000, 100000])]
+fn case_when_all_false(bencher: Bencher, size: usize) {
+    let array = make_struct_array(size);
+
+    // CASE WHEN value > 1000000 THEN 100 ELSE 0 END (always false for our data)
+    let expr = case_when(
+        gt(get_item("value", root()), lit(1_000_000i32)),
+        lit(100i32),
+        lit(0i32),
+    );
+
+    bencher
+        .with_inputs(|| (&expr, &array))
+        .bench_refs(|(expr, array)| {
+            let mut ctx = SESSION.create_execution_ctx();
+            array
+                .apply(expr)
+                .unwrap()
+                .execute::<Canonical>(&mut ctx)
+                .unwrap()
+        });
+}
diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index e4a72149f27..4e1817ab468 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -10106,6 +10106,10 @@ pub fn vortex_array::expr::and_collect(iter: I) -> core::option::Option vortex_array::expr::Expression +pub fn vortex_array::expr::case_when(condition: vortex_array::expr::Expression, then_value: vortex_array::expr::Expression, else_value: vortex_array::expr::Expression) -> vortex_array::expr::Expression + +pub fn vortex_array::expr::case_when_no_else(condition: vortex_array::expr::Expression, then_value: vortex_array::expr::Expression) -> vortex_array::expr::Expression + pub fn vortex_array::expr::cast(child: vortex_array::expr::Expression, target: vortex_array::dtype::DType) -> vortex_array::expr::Expression pub fn vortex_array::expr::checked_add(lhs: vortex_array::expr::Expression, rhs: vortex_array::expr::Expression) -> vortex_array::expr::Expression @@ -10160,6 +10164,8 @@ pub fn vortex_array::expr::merge(elements: impl core::iter::traits::collect::Int pub fn vortex_array::expr::merge_opts(elements: impl core::iter::traits::collect::IntoIterator>, duplicate_handling: vortex_array::scalar_fn::fns::merge::DuplicateHandling) -> vortex_array::expr::Expression +pub fn vortex_array::expr::nested_case_when(when_then_pairs: alloc::vec::Vec<(vortex_array::expr::Expression, vortex_array::expr::Expression)>, else_value: 
core::option::Option) -> vortex_array::expr::Expression + pub fn vortex_array::expr::not(operand: vortex_array::expr::Expression) -> vortex_array::expr::Expression pub fn vortex_array::expr::not_eq(lhs: vortex_array::expr::Expression, rhs: vortex_array::expr::Expression) -> vortex_array::expr::Expression @@ -13210,6 +13216,86 @@ pub fn vortex_array::scalar_fn::fns::binary::or_kleene(lhs: &vortex_array::Array pub fn vortex_array::scalar_fn::fns::binary::scalar_cmp(lhs: &vortex_array::scalar::Scalar, rhs: &vortex_array::scalar::Scalar, operator: vortex_array::scalar_fn::fns::operators::CompareOperator) -> vortex_array::scalar::Scalar +pub mod vortex_array::scalar_fn::fns::case_when + +pub struct vortex_array::scalar_fn::fns::case_when::CaseWhen + +impl core::clone::Clone for vortex_array::scalar_fn::fns::case_when::CaseWhen + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::clone(&self) -> vortex_array::scalar_fn::fns::case_when::CaseWhen + +impl vortex_array::scalar_fn::ScalarFnVTable for vortex_array::scalar_fn::fns::case_when::CaseWhen + +pub type vortex_array::scalar_fn::fns::case_when::CaseWhen::Options = vortex_array::scalar_fn::fns::case_when::CaseWhenOptions + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::arity(&self, options: &Self::Options) -> vortex_array::scalar_fn::Arity + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::child_name(&self, options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::ChildName + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::deserialize(&self, metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::execute(&self, options: &Self::Options, args: &dyn vortex_array::scalar_fn::ExecutionArgs, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::fmt_sql(&self, options: &Self::Options, expr: 
&vortex_array::expr::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::id(&self) -> vortex_array::scalar_fn::ScalarFnId + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::is_fallible(&self, _options: &Self::Options) -> bool + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::is_null_sensitive(&self, _options: &Self::Options) -> bool + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::reduce(&self, options: &Self::Options, node: &dyn vortex_array::scalar_fn::ReduceNode, ctx: &dyn vortex_array::scalar_fn::ReduceCtx) -> vortex_error::VortexResult> + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::return_dtype(&self, options: &Self::Options, arg_dtypes: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult>> + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::simplify(&self, options: &Self::Options, expr: &vortex_array::expr::Expression, ctx: &dyn vortex_array::scalar_fn::SimplifyCtx) -> vortex_error::VortexResult> + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::simplify_untyped(&self, options: &Self::Options, expr: &vortex_array::expr::Expression) -> vortex_error::VortexResult> + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::stat_expression(&self, options: &Self::Options, expr: &vortex_array::expr::Expression, stat: vortex_array::expr::stats::Stat, catalog: &dyn vortex_array::expr::pruning::StatsCatalog) -> core::option::Option + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::stat_falsification(&self, options: &Self::Options, expr: &vortex_array::expr::Expression, catalog: &dyn vortex_array::expr::pruning::StatsCatalog) -> core::option::Option + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::validity(&self, options: &Self::Options, expression: 
&vortex_array::expr::Expression) -> vortex_error::VortexResult> + +pub struct vortex_array::scalar_fn::fns::case_when::CaseWhenOptions + +pub vortex_array::scalar_fn::fns::case_when::CaseWhenOptions::has_else: bool + +pub vortex_array::scalar_fn::fns::case_when::CaseWhenOptions::num_when_then_pairs: u32 + +impl vortex_array::scalar_fn::fns::case_when::CaseWhenOptions + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhenOptions::num_children(&self) -> usize + +impl core::clone::Clone for vortex_array::scalar_fn::fns::case_when::CaseWhenOptions + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhenOptions::clone(&self) -> vortex_array::scalar_fn::fns::case_when::CaseWhenOptions + +impl core::cmp::Eq for vortex_array::scalar_fn::fns::case_when::CaseWhenOptions + +impl core::cmp::PartialEq for vortex_array::scalar_fn::fns::case_when::CaseWhenOptions + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhenOptions::eq(&self, other: &vortex_array::scalar_fn::fns::case_when::CaseWhenOptions) -> bool + +impl core::fmt::Debug for vortex_array::scalar_fn::fns::case_when::CaseWhenOptions + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhenOptions::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::fmt::Display for vortex_array::scalar_fn::fns::case_when::CaseWhenOptions + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhenOptions::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::hash::Hash for vortex_array::scalar_fn::fns::case_when::CaseWhenOptions + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhenOptions::hash<__H: core::hash::Hasher>(&self, state: &mut __H) + +impl core::marker::Copy for vortex_array::scalar_fn::fns::case_when::CaseWhenOptions + +impl core::marker::StructuralPartialEq for vortex_array::scalar_fn::fns::case_when::CaseWhenOptions + pub mod vortex_array::scalar_fn::fns::cast pub struct vortex_array::scalar_fn::fns::cast::Cast @@ -15042,6 +15128,42 @@ pub fn 
vortex_array::scalar_fn::fns::binary::Binary::stat_falsification(&self, o pub fn vortex_array::scalar_fn::fns::binary::Binary::validity(&self, operator: &vortex_array::scalar_fn::fns::operators::Operator, expression: &vortex_array::expr::Expression) -> vortex_error::VortexResult> +impl vortex_array::scalar_fn::ScalarFnVTable for vortex_array::scalar_fn::fns::case_when::CaseWhen + +pub type vortex_array::scalar_fn::fns::case_when::CaseWhen::Options = vortex_array::scalar_fn::fns::case_when::CaseWhenOptions + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::arity(&self, options: &Self::Options) -> vortex_array::scalar_fn::Arity + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::child_name(&self, options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::ChildName + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::deserialize(&self, metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::execute(&self, options: &Self::Options, args: &dyn vortex_array::scalar_fn::ExecutionArgs, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::fmt_sql(&self, options: &Self::Options, expr: &vortex_array::expr::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::id(&self) -> vortex_array::scalar_fn::ScalarFnId + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::is_fallible(&self, _options: &Self::Options) -> bool + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::is_null_sensitive(&self, _options: &Self::Options) -> bool + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::reduce(&self, options: &Self::Options, node: &dyn vortex_array::scalar_fn::ReduceNode, ctx: &dyn vortex_array::scalar_fn::ReduceCtx) -> vortex_error::VortexResult> + +pub fn 
vortex_array::scalar_fn::fns::case_when::CaseWhen::return_dtype(&self, options: &Self::Options, arg_dtypes: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult>> + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::simplify(&self, options: &Self::Options, expr: &vortex_array::expr::Expression, ctx: &dyn vortex_array::scalar_fn::SimplifyCtx) -> vortex_error::VortexResult> + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::simplify_untyped(&self, options: &Self::Options, expr: &vortex_array::expr::Expression) -> vortex_error::VortexResult> + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::stat_expression(&self, options: &Self::Options, expr: &vortex_array::expr::Expression, stat: vortex_array::expr::stats::Stat, catalog: &dyn vortex_array::expr::pruning::StatsCatalog) -> core::option::Option + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::stat_falsification(&self, options: &Self::Options, expr: &vortex_array::expr::Expression, catalog: &dyn vortex_array::expr::pruning::StatsCatalog) -> core::option::Option + +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::validity(&self, options: &Self::Options, expression: &vortex_array::expr::Expression) -> vortex_error::VortexResult> + impl vortex_array::scalar_fn::ScalarFnVTable for vortex_array::scalar_fn::fns::cast::Cast pub type vortex_array::scalar_fn::fns::cast::Cast::Options = vortex_array::dtype::DType diff --git a/vortex-array/src/expr/exprs.rs b/vortex-array/src/expr/exprs.rs index fb99623ba70..bc30ba86ec4 100644 --- a/vortex-array/src/expr/exprs.rs +++ b/vortex-array/src/expr/exprs.rs @@ -6,6 +6,7 @@ use std::sync::Arc; use vortex_error::VortexExpect; +use vortex_error::vortex_panic; use vortex_utils::iter::ReduceBalancedIterExt; use crate::dtype::DType; @@ -20,6 +21,8 @@ use crate::scalar_fn::ScalarFnVTableExt; use 
crate::scalar_fn::fns::between::Between;
 use crate::scalar_fn::fns::between::BetweenOptions;
 use crate::scalar_fn::fns::binary::Binary;
+use crate::scalar_fn::fns::case_when::CaseWhen;
+use crate::scalar_fn::fns::case_when::CaseWhenOptions;
 use crate::scalar_fn::fns::cast::Cast;
 use crate::scalar_fn::fns::dynamic::DynamicComparison;
 use crate::scalar_fn::fns::dynamic::DynamicComparisonExpr;
@@ -109,6 +112,70 @@ pub fn get_item(field: impl Into, child: Expression) -> Expression {
     GetItem.new_expr(field.into(), vec![child])
 }
 
+// ---- CaseWhen ----
+
+/// Creates a `CASE WHEN condition THEN then_value ELSE else_value END` expression.
+///
+/// Equivalent to [`nested_case_when`] with a single WHEN/THEN pair and an ELSE value.
+pub fn case_when(
+    condition: Expression,
+    then_value: Expression,
+    else_value: Expression,
+) -> Expression {
+    let options = CaseWhenOptions {
+        num_when_then_pairs: 1,
+        has_else: true,
+    };
+    CaseWhen.new_expr(options, [condition, then_value, else_value])
+}
+
+/// Creates a `CASE WHEN condition THEN then_value END` expression with no ELSE value.
+///
+/// Rows where `condition` is not true evaluate to NULL.
+pub fn case_when_no_else(condition: Expression, then_value: Expression) -> Expression {
+    let options = CaseWhenOptions {
+        num_when_then_pairs: 1,
+        has_else: false,
+    };
+    CaseWhen.new_expr(options, [condition, then_value])
+}
+
+/// Creates an n-ary CASE WHEN expression from WHEN/THEN pairs and an optional ELSE value.
+///
+/// Children are laid out as `[when_0, then_0, when_1, then_1, ..., else?]`.
+///
+/// # Panics
+///
+/// Panics if `when_then_pairs` is empty or holds more than `u32::MAX` pairs.
+pub fn nested_case_when(
+    when_then_pairs: Vec<(Expression, Expression)>,
+    else_value: Option<Expression>,
+) -> Expression {
+    assert!(
+        !when_then_pairs.is_empty(),
+        "nested_case_when requires at least one when/then pair"
+    );
+    let Ok(num_when_then_pairs) = u32::try_from(when_then_pairs.len()) else {
+        vortex_panic!("nested_case_when has too many when/then pairs");
+    };
+    let has_else = else_value.is_some();
+
+    // The pairs are owned, so move each expression into place instead of cloning it.
+    let mut children = Vec::with_capacity(when_then_pairs.len() * 2 + usize::from(has_else));
+    children.extend(
+        when_then_pairs
+            .into_iter()
+            .flat_map(|(condition, then_value)| [condition, then_value]),
+    );
+    children.extend(else_value);
+
+    let options = CaseWhenOptions {
+        num_when_then_pairs,
+        has_else,
+    };
+    CaseWhen.new_expr(options, children)
+}
+
 // ---- Binary operators ----
 
 /// Create a new [`Binary`] using the [`Eq`](Operator::Eq) operator.
diff --git a/vortex-array/src/scalar_fn/fns/case_when.rs b/vortex-array/src/scalar_fn/fns/case_when.rs
new file mode 100644
index 00000000000..f701548f145
--- /dev/null
+++ b/vortex-array/src/scalar_fn/fns/case_when.rs
@@ -0,0 +1,925 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright the Vortex contributors
+
+//! N-ary CASE WHEN expression for conditional value selection.
+
+use std::fmt;
+use std::fmt::Formatter;
+use std::hash::Hash;
+use std::sync::Arc;
+
+use prost::Message;
+use vortex_error::VortexResult;
+use vortex_error::vortex_bail;
+use vortex_proto::expr as pb;
+use vortex_session::VortexSession;
+
+use crate::ArrayRef;
+use crate::ExecutionCtx;
+use crate::IntoArray;
+use crate::arrays::BoolArray;
+use crate::arrays::ConstantArray;
+use crate::dtype::DType;
+use crate::expr::Expression;
+use crate::scalar::Scalar;
+use crate::scalar_fn::Arity;
+use crate::scalar_fn::ChildName;
+use crate::scalar_fn::ExecutionArgs;
+use crate::scalar_fn::ScalarFnId;
+use crate::scalar_fn::ScalarFnVTable;
+use crate::scalar_fn::fns::zip::zip_impl;
+
+/// Describes the shape of an n-ary CaseWhen expression.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub struct CaseWhenOptions {
+    /// How many WHEN/THEN pairs the expression carries.
+    pub num_when_then_pairs: u32,
+    /// True when a trailing ELSE child exists.
+    /// Without it, rows that match no WHEN produce NULL.
+    pub has_else: bool,
+}
+
+impl CaseWhenOptions {
+    /// Number of child expressions: a condition and a value for every WHEN/THEN pair,
+    /// followed by one more child when an ELSE clause exists.
+    pub fn num_children(&self) -> usize {
+        let pair_children = self.num_when_then_pairs as usize * 2;
+        pair_children + usize::from(self.has_else)
+    }
+}
+
+impl fmt::Display for CaseWhenOptions {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        let Self {
+            num_when_then_pairs,
+            has_else,
+        } = self;
+        write!(f, "case_when(pairs={}, else={})", num_when_then_pairs, has_else)
+    }
+}
+
+/// An n-ary CASE WHEN expression.
+///
+/// Children are in order: `[when_0, then_0, when_1, then_1, ..., else?]`.
+#[derive(Clone)]
+pub struct CaseWhen;
+
+impl ScalarFnVTable for CaseWhen {
+    type Options = CaseWhenOptions;
+
+    fn id(&self) -> ScalarFnId {
+        ScalarFnId::from("vortex.case_when")
+    }
+
+    /// Encodes the options as the total child count: two per WHEN/THEN pair, plus one for ELSE.
+    ///
+    /// # Errors
+    ///
+    /// Fails if the child count does not fit in a `u32`, rather than silently wrapping.
+    fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
+        let num_children = options
+            .num_when_then_pairs
+            .checked_mul(2)
+            .and_then(|pair_children| pair_children.checked_add(u32::from(options.has_else)));
+        let Some(num_children) = num_children else {
+            vortex_bail!(
+                "CaseWhen has too many WHEN/THEN pairs to serialize: {}",
+                options.num_when_then_pairs
+            );
+        };
+        Ok(Some(pb::CaseWhenOpts { num_children }.encode_to_vec()))
+    }
+
+    /// Decodes the child count written by `serialize`: an odd count means an ELSE is present.
+    ///
+    /// # Errors
+    ///
+    /// Fails if the metadata cannot be decoded or encodes fewer than 2 children.
+    fn deserialize(
+        &self,
+        metadata: &[u8],
+        _session: &VortexSession,
+    ) -> VortexResult<Self::Options> {
+        let opts = pb::CaseWhenOpts::decode(metadata)?;
+        if opts.num_children < 2 {
+            vortex_bail!(
+                "CaseWhen expects at least 2 children, got {}",
+                opts.num_children
+            );
+        }
+        Ok(CaseWhenOptions {
+            num_when_then_pairs: opts.num_children / 2,
+            has_else: opts.num_children % 2 == 1,
+        })
+    }
+
+    fn arity(&self, options: &Self::Options) -> Arity {
+        Arity::Exact(options.num_children())
+    }
+
+    /// Names children `when_{i}` / `then_{i}` by pair index, and `else` for the trailing child.
+    fn child_name(&self, options: &Self::Options, child_idx: usize) -> ChildName {
+        let num_pair_children = options.num_when_then_pairs as usize * 2;
+        if child_idx < num_pair_children {
+            let pair_idx = child_idx / 2;
+            // Even positions hold conditions, odd positions hold the matching values.
+            if child_idx.is_multiple_of(2) {
+                ChildName::from(Arc::from(format!("when_{pair_idx}")))
+            } else {
+                ChildName::from(Arc::from(format!("then_{pair_idx}")))
+            }
+        } else if options.has_else && child_idx == num_pair_children {
+            ChildName::from("else")
+        } else {
+            unreachable!("Invalid child index {} for CaseWhen", child_idx)
+        }
+    }
+
+    /// Writes `CASE WHEN c THEN v [WHEN c THEN v ...] [ELSE e] END`.
+    fn fmt_sql(
+        &self,
+        options: &Self::Options,
+        expr: &Expression,
+        f: &mut Formatter<'_>,
+    ) -> fmt::Result {
+        write!(f, "CASE")?;
+        for i in 0..options.num_when_then_pairs as usize {
+            write!(
+                f,
+                " WHEN {} THEN {}",
+                expr.child(i * 2),
+                expr.child(i * 2 + 1)
+            )?;
+        }
+        if options.has_else {
+            let else_idx = options.num_when_then_pairs as usize * 2;
+            write!(f, " ELSE {}", expr.child(else_idx))?;
+        }
+        write!(f, " END")
+    }
+
+    /// Computes the result dtype from the THEN/ELSE dtypes.
+    ///
+    /// # Errors
+    ///
+    /// Fails if there are no WHEN/THEN pairs, if `arg_dtypes` has the wrong length, or if any
+    /// THEN/ELSE dtype differs from the first THEN dtype (ignoring nullability).
+    fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
+        if options.num_when_then_pairs == 0 {
+            vortex_bail!("CaseWhen must have at least one WHEN/THEN pair");
+        }
+
+        let expected_len = options.num_children();
+        if arg_dtypes.len() != expected_len {
+            vortex_bail!(
+                "CaseWhen expects {expected_len} argument dtypes, got {}",
+                arg_dtypes.len()
+            );
+        }
+
+        // The return dtype is based on the first THEN expression (index 1).
+        // Validate all other THEN branches match and union their nullability.
+        let first_then = &arg_dtypes[1];
+        let mut result_dtype = first_then.clone();
+
+        for i in 1..options.num_when_then_pairs as usize {
+            let then_i = &arg_dtypes[i * 2 + 1];
+            if !first_then.eq_ignore_nullability(then_i) {
+                vortex_bail!(
+                    "CaseWhen THEN dtypes must match (ignoring nullability), got {} and {}",
+                    first_then,
+                    then_i
+                );
+            }
+            result_dtype = result_dtype.union_nullability(then_i.nullability());
+        }
+
+        if options.has_else {
+            let else_dtype = &arg_dtypes[options.num_when_then_pairs as usize * 2];
+            if !first_then.eq_ignore_nullability(else_dtype) {
+                vortex_bail!(
+                    "CaseWhen THEN and ELSE dtypes must match (ignoring nullability), got {} and {}",
+                    first_then,
+                    else_dtype
+                );
+            }
+            result_dtype = result_dtype.union_nullability(else_dtype.nullability());
+        } else {
+            // No ELSE means unmatched rows are NULL
+            result_dtype = result_dtype.as_nullable();
+        }
+
+        Ok(result_dtype)
+    }
+
+    /// Evaluates the pairs from last to first so that earlier WHEN clauses take precedence.
+    fn execute(
+        &self,
+        options: &Self::Options,
+        args: &dyn ExecutionArgs,
+        ctx: &mut ExecutionCtx,
+    ) -> VortexResult<ArrayRef> {
+        let row_count = args.row_count();
+        let num_pairs = options.num_when_then_pairs as usize;
+
+        // Start from the fallback: the ELSE child, or an all-NULL constant when there is none.
+        let mut result: ArrayRef = if options.has_else {
+            args.get(num_pairs * 2)?
+        } else {
+            let then_dtype = args.get(1)?.dtype().as_nullable();
+            ConstantArray::new(Scalar::null(then_dtype), row_count).into_array()
+        };
+
+        for i in (0..num_pairs).rev() {
+            let condition = args.get(i * 2)?;
+            let then_value = args.get(i * 2 + 1)?;
+
+            let cond_bool = condition.execute::<BoolArray>(ctx)?;
+            let mask = cond_bool.to_mask_fill_null_false();
+
+            if mask.all_true() {
+                // Every row matches, so this branch replaces whatever later pairs produced.
+                // NOTE(review): `then_value` keeps its own nullability here, which may be
+                // narrower than what `return_dtype` reports (e.g. no ELSE) - confirm callers
+                // tolerate that or cast.
+                result = then_value;
+                continue;
+            }
+
+            if mask.all_false() {
+                // No row matches; keep the result accumulated so far.
+                continue;
+            }
+
+            result = zip_impl(&then_value, &result, &mask)?;
+        }
+
+        Ok(result)
+    }
+
+    fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
+        // CaseWhen is null-sensitive because NULL conditions are treated as false
+        true
+    }
+
+    fn is_fallible(&self, _options: &Self::Options) -> bool {
+        false
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::LazyLock;
+
+    use vortex_buffer::buffer;
+    use vortex_error::VortexExpect as _;
+    use vortex_session::VortexSession;
+
+    use super::*;
+    use crate::Canonical;
+    use crate::IntoArray;
+    use crate::ToCanonical;
+    use crate::VortexSessionExecute as _;
+    use crate::arrays::BoolArray;
+    use crate::arrays::PrimitiveArray;
+    use crate::arrays::StructArray;
+    use crate::dtype::DType;
+    use crate::dtype::Nullability;
+    use crate::dtype::PType;
+    use crate::expr::case_when;
+    use crate::expr::case_when_no_else;
+    use crate::expr::col;
+    use crate::expr::eq;
+    use crate::expr::get_item;
+    use crate::expr::gt;
+    use crate::expr::lit;
+    use crate::expr::nested_case_when;
+    use crate::expr::root;
+    use crate::expr::test_harness;
+    use crate::scalar::Scalar;
+    use crate::session::ArraySession;
+
+    static SESSION: LazyLock<VortexSession> =
+        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());
+
+    /// Helper to evaluate an expression using the apply+execute pattern
+    fn evaluate_expr(expr: &Expression, array: &ArrayRef) -> ArrayRef {
+        let mut ctx = SESSION.create_execution_ctx();
+        array
+            .apply(expr)
+            .unwrap()
+            .execute::<Canonical>(&mut ctx)
+            .unwrap()
+            .into_array()
+    }
+
+    // ====================
Serialization Tests ==================== + + #[test] + fn test_serialization_roundtrip() { + let options = CaseWhenOptions { + num_when_then_pairs: 1, + has_else: true, + }; + let serialized = CaseWhen.serialize(&options).unwrap().unwrap(); + let deserialized = CaseWhen + .deserialize(&serialized, &VortexSession::empty()) + .unwrap(); + assert_eq!(options, deserialized); + } + + #[test] + fn test_serialization_no_else() { + let options = CaseWhenOptions { + num_when_then_pairs: 1, + has_else: false, + }; + let serialized = CaseWhen.serialize(&options).unwrap().unwrap(); + let deserialized = CaseWhen + .deserialize(&serialized, &VortexSession::empty()) + .unwrap(); + assert_eq!(options, deserialized); + } + + // ==================== Display Tests ==================== + + #[test] + fn test_display_with_else() { + let expr = case_when(gt(col("value"), lit(0i32)), lit(100i32), lit(0i32)); + let display = format!("{}", expr); + assert!(display.contains("CASE")); + assert!(display.contains("WHEN")); + assert!(display.contains("THEN")); + assert!(display.contains("ELSE")); + assert!(display.contains("END")); + } + + #[test] + fn test_display_no_else() { + let expr = case_when_no_else(gt(col("value"), lit(0i32)), lit(100i32)); + let display = format!("{}", expr); + assert!(display.contains("CASE")); + assert!(display.contains("WHEN")); + assert!(display.contains("THEN")); + assert!(!display.contains("ELSE")); + assert!(display.contains("END")); + } + + #[test] + fn test_display_nested_nary() { + // CASE WHEN x > 10 THEN 'high' WHEN x > 5 THEN 'medium' ELSE 'low' END + let expr = nested_case_when( + vec![ + (gt(col("x"), lit(10i32)), lit("high")), + (gt(col("x"), lit(5i32)), lit("medium")), + ], + Some(lit("low")), + ); + let display = format!("{}", expr); + assert_eq!(display.matches("CASE").count(), 1); + assert_eq!(display.matches("WHEN").count(), 2); + assert_eq!(display.matches("THEN").count(), 2); + } + + // ==================== DType Tests ==================== + + 
#[test] + fn test_return_dtype_with_else() { + let expr = case_when(lit(true), lit(100i32), lit(0i32)); + let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let result_dtype = expr.return_dtype(&input_dtype).unwrap(); + assert_eq!( + result_dtype, + DType::Primitive(PType::I32, Nullability::NonNullable) + ); + } + + #[test] + fn test_return_dtype_with_nullable_else() { + let expr = case_when( + lit(true), + lit(100i32), + lit(Scalar::null(DType::Primitive( + PType::I32, + Nullability::Nullable, + ))), + ); + let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let result_dtype = expr.return_dtype(&input_dtype).unwrap(); + assert_eq!( + result_dtype, + DType::Primitive(PType::I32, Nullability::Nullable) + ); + } + + #[test] + fn test_return_dtype_without_else_is_nullable() { + let expr = case_when_no_else(lit(true), lit(100i32)); + let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let result_dtype = expr.return_dtype(&input_dtype).unwrap(); + assert_eq!( + result_dtype, + DType::Primitive(PType::I32, Nullability::Nullable) + ); + } + + #[test] + fn test_return_dtype_with_struct_input() { + let dtype = test_harness::struct_dtype(); + let expr = case_when( + gt(get_item("col1", root()), lit(10u16)), + lit(100i32), + lit(0i32), + ); + let result_dtype = expr.return_dtype(&dtype).unwrap(); + assert_eq!( + result_dtype, + DType::Primitive(PType::I32, Nullability::NonNullable) + ); + } + + #[test] + fn test_return_dtype_mismatched_then_else_errors() { + let expr = case_when(lit(true), lit(100i32), lit("zero")); + let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let err = expr.return_dtype(&input_dtype).unwrap_err(); + assert!( + err.to_string() + .contains("THEN and ELSE dtypes must match (ignoring nullability)") + ); + } + + // ==================== Arity Tests ==================== + + #[test] + fn test_arity_with_else() { + let options = CaseWhenOptions { + 
num_when_then_pairs: 1, + has_else: true, + }; + assert_eq!(CaseWhen.arity(&options), Arity::Exact(3)); + } + + #[test] + fn test_arity_without_else() { + let options = CaseWhenOptions { + num_when_then_pairs: 1, + has_else: false, + }; + assert_eq!(CaseWhen.arity(&options), Arity::Exact(2)); + } + + // ==================== Child Name Tests ==================== + + #[test] + fn test_child_names() { + let options = CaseWhenOptions { + num_when_then_pairs: 1, + has_else: true, + }; + assert_eq!(CaseWhen.child_name(&options, 0).to_string(), "when_0"); + assert_eq!(CaseWhen.child_name(&options, 1).to_string(), "then_0"); + assert_eq!(CaseWhen.child_name(&options, 2).to_string(), "else"); + } + + // ==================== N-ary Serialization Tests ==================== + + #[test] + fn test_serialization_roundtrip_nary() { + let options = CaseWhenOptions { + num_when_then_pairs: 3, + has_else: true, + }; + let serialized = CaseWhen.serialize(&options).unwrap().unwrap(); + let deserialized = CaseWhen + .deserialize(&serialized, &VortexSession::empty()) + .unwrap(); + assert_eq!(options, deserialized); + } + + #[test] + fn test_serialization_roundtrip_nary_no_else() { + let options = CaseWhenOptions { + num_when_then_pairs: 4, + has_else: false, + }; + let serialized = CaseWhen.serialize(&options).unwrap().unwrap(); + let deserialized = CaseWhen + .deserialize(&serialized, &VortexSession::empty()) + .unwrap(); + assert_eq!(options, deserialized); + } + + // ==================== N-ary Arity Tests ==================== + + #[test] + fn test_arity_nary_with_else() { + let options = CaseWhenOptions { + num_when_then_pairs: 3, + has_else: true, + }; + // 3 pairs * 2 children + 1 else = 7 + assert_eq!(CaseWhen.arity(&options), Arity::Exact(7)); + } + + #[test] + fn test_arity_nary_without_else() { + let options = CaseWhenOptions { + num_when_then_pairs: 3, + has_else: false, + }; + // 3 pairs * 2 children = 6 + assert_eq!(CaseWhen.arity(&options), Arity::Exact(6)); + } + + // 
==================== N-ary Child Name Tests ==================== + + #[test] + fn test_child_names_nary() { + let options = CaseWhenOptions { + num_when_then_pairs: 3, + has_else: true, + }; + assert_eq!(CaseWhen.child_name(&options, 0).to_string(), "when_0"); + assert_eq!(CaseWhen.child_name(&options, 1).to_string(), "then_0"); + assert_eq!(CaseWhen.child_name(&options, 2).to_string(), "when_1"); + assert_eq!(CaseWhen.child_name(&options, 3).to_string(), "then_1"); + assert_eq!(CaseWhen.child_name(&options, 4).to_string(), "when_2"); + assert_eq!(CaseWhen.child_name(&options, 5).to_string(), "then_2"); + assert_eq!(CaseWhen.child_name(&options, 6).to_string(), "else"); + } + + // ==================== N-ary DType Tests ==================== + + #[test] + fn test_return_dtype_nary_mismatched_then_types_errors() { + let expr = nested_case_when( + vec![(lit(true), lit(100i32)), (lit(false), lit("oops"))], + Some(lit(0i32)), + ); + let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let err = expr.return_dtype(&input_dtype).unwrap_err(); + assert!(err.to_string().contains("THEN dtypes must match")); + } + + #[test] + fn test_return_dtype_nary_mixed_nullability() { + // When some THEN branches are nullable and others are not, + // the result should be nullable (union of nullabilities). 
+ let non_null_then = lit(100i32); + let nullable_then = lit(Scalar::null(DType::Primitive( + PType::I32, + Nullability::Nullable, + ))); + let expr = nested_case_when( + vec![(lit(true), non_null_then), (lit(false), nullable_then)], + Some(lit(0i32)), + ); + let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let result = expr.return_dtype(&input_dtype).unwrap(); + assert_eq!(result, DType::Primitive(PType::I32, Nullability::Nullable)); + } + + #[test] + fn test_return_dtype_nary_no_else_is_nullable() { + let expr = nested_case_when( + vec![(lit(true), lit(10i32)), (lit(false), lit(20i32))], + None, + ); + let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let result = expr.return_dtype(&input_dtype).unwrap(); + assert_eq!(result, DType::Primitive(PType::I32, Nullability::Nullable)); + } + + // ==================== Expression Manipulation Tests ==================== + + #[test] + fn test_replace_children() { + let expr = case_when(lit(true), lit(1i32), lit(0i32)); + expr.with_children([lit(false), lit(2i32), lit(3i32)]) + .vortex_expect("operation should succeed in test"); + } + + // ==================== Evaluate Tests ==================== + + #[test] + fn test_evaluate_simple_condition() { + let test_array = + StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())]) + .unwrap() + .into_array(); + + let expr = case_when( + gt(get_item("value", root()), lit(2i32)), + lit(100i32), + lit(0i32), + ); + + let result = evaluate_expr(&expr, &test_array).to_primitive(); + assert_eq!(result.as_slice::(), &[0, 0, 100, 100, 100]); + } + + #[test] + fn test_evaluate_nary_multiple_conditions() { + // Test n-ary via nested_case_when + let test_array = + StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())]) + .unwrap() + .into_array(); + + let expr = nested_case_when( + vec![ + (eq(get_item("value", root()), lit(1i32)), lit(10i32)), + (eq(get_item("value", root()), lit(3i32)), 
lit(30i32)), + ], + Some(lit(0i32)), + ); + + let result = evaluate_expr(&expr, &test_array).to_primitive(); + assert_eq!(result.as_slice::(), &[10, 0, 30, 0, 0]); + } + + #[test] + fn test_evaluate_nary_first_match_wins() { + let test_array = + StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())]) + .unwrap() + .into_array(); + + // Both conditions match for values > 3, but first one wins + let expr = nested_case_when( + vec![ + (gt(get_item("value", root()), lit(2i32)), lit(100i32)), + (gt(get_item("value", root()), lit(3i32)), lit(200i32)), + ], + Some(lit(0i32)), + ); + + let result = evaluate_expr(&expr, &test_array).to_primitive(); + assert_eq!(result.as_slice::(), &[0, 0, 100, 100, 100]); + } + + #[test] + fn test_evaluate_no_else_returns_null() { + let test_array = + StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())]) + .unwrap() + .into_array(); + + let expr = case_when_no_else(gt(get_item("value", root()), lit(3i32)), lit(100i32)); + + let result = evaluate_expr(&expr, &test_array); + assert!(result.dtype().is_nullable()); + + assert_eq!( + result.scalar_at(0).unwrap(), + Scalar::null(result.dtype().clone()) + ); + assert_eq!( + result.scalar_at(1).unwrap(), + Scalar::null(result.dtype().clone()) + ); + assert_eq!( + result.scalar_at(2).unwrap(), + Scalar::null(result.dtype().clone()) + ); + assert_eq!( + result.scalar_at(3).unwrap(), + Scalar::from(100i32).cast(result.dtype()).unwrap() + ); + assert_eq!( + result.scalar_at(4).unwrap(), + Scalar::from(100i32).cast(result.dtype()).unwrap() + ); + } + + #[test] + fn test_evaluate_all_conditions_false() { + let test_array = + StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())]) + .unwrap() + .into_array(); + + let expr = case_when( + gt(get_item("value", root()), lit(100i32)), + lit(1i32), + lit(0i32), + ); + + let result = evaluate_expr(&expr, &test_array).to_primitive(); + assert_eq!(result.as_slice::(), &[0, 0, 0, 0, 0]); + } + 
+ #[test] + fn test_evaluate_all_conditions_true() { + let test_array = + StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())]) + .unwrap() + .into_array(); + + let expr = case_when( + gt(get_item("value", root()), lit(0i32)), + lit(100i32), + lit(0i32), + ); + + let result = evaluate_expr(&expr, &test_array).to_primitive(); + assert_eq!(result.as_slice::(), &[100, 100, 100, 100, 100]); + } + + #[test] + fn test_evaluate_with_literal_condition() { + let test_array = buffer![1i32, 2, 3].into_array(); + let expr = case_when(lit(true), lit(100i32), lit(0i32)); + let result = evaluate_expr(&expr, &test_array); + + if let Some(constant) = result.as_constant() { + assert_eq!(constant, Scalar::from(100i32)); + } else { + let prim = result.to_primitive(); + assert_eq!(prim.as_slice::(), &[100, 100, 100]); + } + } + + #[test] + fn test_evaluate_with_bool_column_result() { + let test_array = + StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())]) + .unwrap() + .into_array(); + + let expr = case_when( + gt(get_item("value", root()), lit(2i32)), + lit(true), + lit(false), + ); + + let result = evaluate_expr(&expr, &test_array).to_bool(); + assert_eq!( + result.to_bit_buffer().iter().collect::>(), + vec![false, false, true, true, true] + ); + } + + #[test] + fn test_evaluate_with_nullable_condition() { + let test_array = StructArray::from_fields(&[( + "cond", + BoolArray::from_iter([Some(true), None, Some(false), None, Some(true)]).into_array(), + )]) + .unwrap() + .into_array(); + + let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32)); + + let result = evaluate_expr(&expr, &test_array).to_primitive(); + assert_eq!(result.as_slice::(), &[100, 0, 0, 0, 100]); + } + + #[test] + fn test_evaluate_with_nullable_result_values() { + let test_array = StructArray::from_fields(&[ + ("value", buffer![1i32, 2, 3, 4, 5].into_array()), + ( + "result", + PrimitiveArray::from_option_iter([Some(10), None, Some(30), Some(40), 
Some(50)]) + .into_array(), + ), + ]) + .unwrap() + .into_array(); + + let expr = case_when( + gt(get_item("value", root()), lit(2i32)), + get_item("result", root()), + lit(0i32), + ); + + let result = evaluate_expr(&expr, &test_array); + let prim = result.to_primitive(); + assert_eq!(prim.as_slice::(), &[0, 0, 30, 40, 50]); + } + + #[test] + fn test_evaluate_with_all_null_condition() { + let test_array = StructArray::from_fields(&[( + "cond", + BoolArray::from_iter([None, None, None]).into_array(), + )]) + .unwrap() + .into_array(); + + let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32)); + + let result = evaluate_expr(&expr, &test_array).to_primitive(); + assert_eq!(result.as_slice::(), &[0, 0, 0]); + } + + // ==================== N-ary Evaluate Tests ==================== + + #[test] + fn test_evaluate_nary_no_else_returns_null() { + let test_array = + StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())]) + .unwrap() + .into_array(); + + // Two conditions, no ELSE — unmatched rows should be NULL + let expr = nested_case_when( + vec![ + (eq(get_item("value", root()), lit(1i32)), lit(10i32)), + (eq(get_item("value", root()), lit(3i32)), lit(30i32)), + ], + None, + ); + + let result = evaluate_expr(&expr, &test_array); + assert!(result.dtype().is_nullable()); + + assert_eq!( + result.scalar_at(0).unwrap(), + Scalar::from(10i32).cast(result.dtype()).unwrap() + ); + assert_eq!( + result.scalar_at(1).unwrap(), + Scalar::null(result.dtype().clone()) + ); + assert_eq!( + result.scalar_at(2).unwrap(), + Scalar::from(30i32).cast(result.dtype()).unwrap() + ); + assert_eq!( + result.scalar_at(3).unwrap(), + Scalar::null(result.dtype().clone()) + ); + assert_eq!( + result.scalar_at(4).unwrap(), + Scalar::null(result.dtype().clone()) + ); + } + + #[test] + fn test_evaluate_nary_many_conditions() { + let test_array = + StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())]) + .unwrap() + .into_array(); + + // 5 
WHEN/THEN pairs: each value maps to its value * 10 + let expr = nested_case_when( + vec![ + (eq(get_item("value", root()), lit(1i32)), lit(10i32)), + (eq(get_item("value", root()), lit(2i32)), lit(20i32)), + (eq(get_item("value", root()), lit(3i32)), lit(30i32)), + (eq(get_item("value", root()), lit(4i32)), lit(40i32)), + (eq(get_item("value", root()), lit(5i32)), lit(50i32)), + ], + Some(lit(0i32)), + ); + + let result = evaluate_expr(&expr, &test_array).to_primitive(); + assert_eq!(result.as_slice::(), &[10, 20, 30, 40, 50]); + } + + #[test] + fn test_evaluate_nary_all_false_no_else() { + let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())]) + .unwrap() + .into_array(); + + // All conditions are false, no ELSE — everything should be NULL + let expr = nested_case_when( + vec![ + (gt(get_item("value", root()), lit(100i32)), lit(10i32)), + (gt(get_item("value", root()), lit(200i32)), lit(20i32)), + ], + None, + ); + + let result = evaluate_expr(&expr, &test_array); + assert!(result.dtype().is_nullable()); + for i in 0..3 { + assert_eq!( + result.scalar_at(i).unwrap(), + Scalar::null(result.dtype().clone()) + ); + } + } + + #[test] + fn test_evaluate_nary_overlapping_conditions_first_wins() { + let test_array = + StructArray::from_fields(&[("value", buffer![10i32, 20, 30].into_array())]) + .unwrap() + .into_array(); + + // value=10: matches cond1 (>5) and cond2 (>0), first should win + // value=20: matches all three, first should win + // value=30: matches all three, first should win + let expr = nested_case_when( + vec![ + (gt(get_item("value", root()), lit(5i32)), lit(1i32)), + (gt(get_item("value", root()), lit(0i32)), lit(2i32)), + (gt(get_item("value", root()), lit(15i32)), lit(3i32)), + ], + Some(lit(0i32)), + ); + + let result = evaluate_expr(&expr, &test_array).to_primitive(); + // First matching condition always wins + assert_eq!(result.as_slice::(), &[1, 1, 1]); + } + + #[test] + fn 
test_evaluate_nary_with_nullable_conditions() { + let test_array = StructArray::from_fields(&[ + ( + "cond1", + BoolArray::from_iter([Some(true), None, Some(false)]).into_array(), + ), + ( + "cond2", + BoolArray::from_iter([Some(false), Some(true), None]).into_array(), + ), + ]) + .unwrap() + .into_array(); + + let expr = nested_case_when( + vec![ + (get_item("cond1", root()), lit(10i32)), + (get_item("cond2", root()), lit(20i32)), + ], + Some(lit(0i32)), + ); + + let result = evaluate_expr(&expr, &test_array).to_primitive(); + // row 0: cond1=true → 10 + // row 1: cond1=NULL(→false), cond2=true → 20 + // row 2: cond1=false, cond2=NULL(→false) → else=0 + assert_eq!(result.as_slice::(), &[10, 20, 0]); + } +} diff --git a/vortex-array/src/scalar_fn/fns/mod.rs b/vortex-array/src/scalar_fn/fns/mod.rs index 95c66b09ef6..94fc8fb0384 100644 --- a/vortex-array/src/scalar_fn/fns/mod.rs +++ b/vortex-array/src/scalar_fn/fns/mod.rs @@ -3,6 +3,7 @@ pub mod between; pub mod binary; +pub mod case_when; pub mod cast; pub mod dynamic; pub mod fill_null; diff --git a/vortex-datafusion/src/convert/exprs.rs b/vortex-datafusion/src/convert/exprs.rs index b0380db5cb3..455a5c65a71 100644 --- a/vortex-datafusion/src/convert/exprs.rs +++ b/vortex-datafusion/src/convert/exprs.rs @@ -29,6 +29,7 @@ use vortex::expr::get_item; use vortex::expr::is_null; use vortex::expr::list_contains; use vortex::expr::lit; +use vortex::expr::nested_case_when; use vortex::expr::not; use vortex::expr::pack; use vortex::expr::root; @@ -144,6 +145,45 @@ impl DefaultExpressionConvertor { scalar_fn.name() )) } + + /// Attempts to convert a DataFusion CaseExpr to a Vortex expression. + fn try_convert_case_expr(&self, case_expr: &df_expr::CaseExpr) -> DFResult { + // DataFusion CaseExpr has: + // - expr(): Optional base expression (for "CASE expr WHEN ..." 
form) + // - when_then_expr(): Vec of (when, then) pairs + // - else_expr(): Optional else expression + + // We don't support the "CASE expr WHEN value1 THEN result1" form yet + if case_expr.expr().is_some() { + return Err(exec_datafusion_err!( + "CASE expr WHEN form is not yet supported, only searched CASE is supported" + )); + } + + let when_then_pairs = case_expr.when_then_expr(); + if when_then_pairs.is_empty() { + return Err(exec_datafusion_err!( + "CASE expression must have at least one WHEN clause" + )); + } + + // Convert all when/then pairs to (condition, value) tuples + let mut pairs = Vec::with_capacity(when_then_pairs.len()); + for (when_expr, then_expr) in when_then_pairs { + let condition = self.convert(when_expr.as_ref())?; + let value = self.convert(then_expr.as_ref())?; + pairs.push((condition, value)); + } + + // Convert optional else expression + let else_value = case_expr + .else_expr() + .map(|e| self.convert(e.as_ref())) + .transpose()?; + + // Build a single n-ary CASE WHEN expression from DataFusion WHEN/THEN pairs + Ok(nested_case_when(pairs, else_value)) + } } impl ExpressionConvertor for DefaultExpressionConvertor { @@ -235,6 +275,10 @@ impl ExpressionConvertor for DefaultExpressionConvertor { return self.try_convert_scalar_function(scalar_fn); } + if let Some(case_expr) = df.as_any().downcast_ref::() { + return self.try_convert_case_expr(case_expr); + } + Err(exec_datafusion_err!( "Couldn't convert DataFusion physical {df} expression to a vortex expression" )) @@ -380,10 +424,12 @@ fn can_be_pushed_down_impl(df_expr: &Arc, schema: &Schema) -> && can_be_pushed_down_impl(like.pattern(), schema) } else if let Some(lit) = expr.downcast_ref::() { supported_data_types(&lit.value().data_type()) - } else if expr.downcast_ref::().is_some() - || expr.downcast_ref::().is_some() - { - true + } else if let Some(cast_expr) = expr.downcast_ref::() { + // CastExpr child must be an expression type that convert() can handle + 
is_convertible_expr(cast_expr.expr()) + } else if let Some(cast_col_expr) = expr.downcast_ref::() { + // CastColumnExpr child must be an expression type that convert() can handle + is_convertible_expr(cast_col_expr.expr()) } else if let Some(is_null) = expr.downcast_ref::() { can_be_pushed_down_impl(is_null.arg(), schema) } else if let Some(is_not_null) = expr.downcast_ref::() { @@ -396,12 +442,39 @@ fn can_be_pushed_down_impl(df_expr: &Arc, schema: &Schema) -> .all(|e| can_be_pushed_down_impl(e, schema)) } else if let Some(scalar_fn) = expr.downcast_ref::() { can_scalar_fn_be_pushed_down(scalar_fn) + } else if let Some(case_expr) = expr.downcast_ref::() { + can_case_be_pushed_down(case_expr, schema) } else { tracing::debug!(%df_expr, "DataFusion expression can't be pushed down"); false } } +/// Checks if an expression type is one that convert() can handle. +/// This is less restrictive than can_be_pushed_down since it only checks +/// expression types, not data type support. +fn is_convertible_expr(df_expr: &Arc) -> bool { + let expr = df_expr.as_any(); + + // Expression types that convert() handles + expr.downcast_ref::().is_some() + || expr.downcast_ref::().is_some() + || expr.downcast_ref::().is_some() + || expr.downcast_ref::().is_some() + || expr + .downcast_ref::() + .is_some_and(|e| is_convertible_expr(e.expr())) + || expr + .downcast_ref::() + .is_some_and(|e| is_convertible_expr(e.expr())) + || expr.downcast_ref::().is_some() + || expr.downcast_ref::().is_some() + || expr.downcast_ref::().is_some() + || expr + .downcast_ref::() + .is_some_and(|sf| ScalarFunctionExpr::try_downcast_func::(sf).is_some()) +} + fn can_binary_be_pushed_down(binary: &df_expr::BinaryExpr, schema: &Schema) -> bool { let is_op_supported = try_operator_from_df(binary.op()).is_ok(); is_op_supported @@ -409,6 +482,32 @@ fn can_binary_be_pushed_down(binary: &df_expr::BinaryExpr, schema: &Schema) -> b && can_be_pushed_down_impl(binary.right(), schema) } +fn 
can_case_be_pushed_down(case_expr: &df_expr::CaseExpr, schema: &Schema) -> bool { + // We only support the "searched CASE" form (CASE WHEN cond THEN result ...) + // not the "simple CASE" form (CASE expr WHEN value THEN result ...) + if case_expr.expr().is_some() { + return false; + } + + // Check all when/then pairs + for (when_expr, then_expr) in case_expr.when_then_expr() { + if !can_be_pushed_down_impl(when_expr, schema) + || !can_be_pushed_down_impl(then_expr, schema) + { + return false; + } + } + + // Check the optional else clause + if let Some(else_expr) = case_expr.else_expr() + && !can_be_pushed_down_impl(else_expr, schema) + { + return false; + } + + true +} + fn supported_data_types(dt: &DataType) -> bool { use DataType::*; @@ -442,7 +541,8 @@ fn supported_data_types(dt: &DataType) -> bool { is_supported } -/// Checks if a GetField scalar function can be pushed down. +/// Checks if a scalar function can be pushed down. +/// Currently only GetFieldFunc is supported. fn can_scalar_fn_be_pushed_down(scalar_fn: &ScalarFunctionExpr) -> bool { ScalarFunctionExpr::try_downcast_func::(scalar_fn).is_some() } @@ -811,4 +911,96 @@ mod tests { Ok(()) } + + /// Test that applying a CASE expression to an Arrow RecordBatch using DataFusion + /// matches the result of applying the converted Vortex expression. 
+ #[test] + fn test_case_when_datafusion_vortex_equivalence() { + use datafusion::arrow::array::Int32Array; + use datafusion::arrow::array::RecordBatch; + use datafusion_physical_expr::expressions::CaseExpr; + use vortex::VortexSessionDefault; + use vortex::array::ArrayRef; + use vortex::array::Canonical; + use vortex::array::VortexSessionExecute as _; + use vortex::array::arrow::FromArrowArray; + use vortex::session::VortexSession; + + // Create test data + let values = Arc::new(Int32Array::from(vec![1, 5, 10, 15, 20])); + let schema = Arc::new(Schema::new(vec![Field::new( + "value", + DataType::Int32, + false, + )])); + let batch = RecordBatch::try_new(schema, vec![values]).unwrap(); + + // Build a DataFusion CASE expression: + // CASE WHEN value > 10 THEN 100 WHEN value > 5 THEN 50 ELSE 0 END + let col_value = Arc::new(df_expr::Column::new("value", 0)) as Arc; + let lit_10 = + Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(10)))) as Arc; + let lit_5 = + Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(5)))) as Arc; + let lit_100 = + Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(100)))) as Arc; + let lit_50 = + Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(50)))) as Arc; + let lit_0 = + Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(0)))) as Arc; + + // WHEN value > 10 THEN 100 + let when1 = Arc::new(df_expr::BinaryExpr::new( + col_value.clone(), + DFOperator::Gt, + lit_10, + )) as Arc; + // WHEN value > 5 THEN 50 + let when2 = Arc::new(df_expr::BinaryExpr::new(col_value, DFOperator::Gt, lit_5)) + as Arc; + + let case_expr = + CaseExpr::try_new(None, vec![(when1, lit_100), (when2, lit_50)], Some(lit_0)).unwrap(); + + // Apply DataFusion expression + let df_result = case_expr.evaluate(&batch).unwrap(); + let df_array = df_result.into_array(batch.num_rows()).unwrap(); + + // Convert to Vortex expression + let expr_convertor = DefaultExpressionConvertor::default(); + let vortex_expr = 
expr_convertor.try_convert_case_expr(&case_expr).unwrap(); + + // Convert batch to Vortex array + let vortex_array: ArrayRef = ArrayRef::from_arrow(&batch, false).unwrap(); + + // Apply Vortex expression + let session = VortexSession::default(); + let mut ctx = session.create_execution_ctx(); + let vortex_result = vortex_array + .apply(&vortex_expr) + .unwrap() + .execute::(&mut ctx) + .unwrap(); + + // Convert back to Arrow for comparison + let vortex_as_arrow = vortex_result.into_primitive().as_slice::().to_vec(); + + // Convert DataFusion result to Vec for comparison + let df_as_arrow: Vec = df_array + .as_any() + .downcast_ref::() + .unwrap() + .values() + .to_vec(); + + // Compare results + // Expected: [0, 0, 50, 100, 100] for values [1, 5, 10, 15, 20] + // value=1: not > 10, not > 5 -> ELSE 0 + // value=5: not > 10, not > 5 -> ELSE 0 + // value=10: not > 10, > 5 -> 50 + // value=15: > 10 -> 100 + // value=20: > 10 -> 100 + assert_eq!(df_as_arrow, vec![0, 0, 50, 100, 100]); + assert_eq!(vortex_as_arrow, df_as_arrow); + } } diff --git a/vortex-proto/proto/expr.proto b/vortex-proto/proto/expr.proto index 4540bce0a63..3b47db2a756 100644 --- a/vortex-proto/proto/expr.proto +++ b/vortex-proto/proto/expr.proto @@ -80,3 +80,12 @@ message SelectOpts { FieldNames exclude = 2; } } + +// Options for `vortex.case_when` +// Encodes num_when_then_pairs and has_else into a single u32 (num_children). +// num_children = num_when_then_pairs * 2 + (has_else ? 
1 : 0) +// has_else = num_children % 2 == 1 +// num_when_then_pairs = num_children / 2 +message CaseWhenOpts { + uint32 num_children = 1; +} diff --git a/vortex-proto/public-api.lock b/vortex-proto/public-api.lock index 303fa403401..1f6a2f409e7 100644 --- a/vortex-proto/public-api.lock +++ b/vortex-proto/public-api.lock @@ -810,6 +810,42 @@ pub fn vortex_proto::expr::BinaryOpts::clear(&mut self) pub fn vortex_proto::expr::BinaryOpts::encoded_len(&self) -> usize +pub struct vortex_proto::expr::CaseWhenOpts + +pub vortex_proto::expr::CaseWhenOpts::num_children: u32 + +impl core::clone::Clone for vortex_proto::expr::CaseWhenOpts + +pub fn vortex_proto::expr::CaseWhenOpts::clone(&self) -> vortex_proto::expr::CaseWhenOpts + +impl core::cmp::Eq for vortex_proto::expr::CaseWhenOpts + +impl core::cmp::PartialEq for vortex_proto::expr::CaseWhenOpts + +pub fn vortex_proto::expr::CaseWhenOpts::eq(&self, other: &vortex_proto::expr::CaseWhenOpts) -> bool + +impl core::default::Default for vortex_proto::expr::CaseWhenOpts + +pub fn vortex_proto::expr::CaseWhenOpts::default() -> Self + +impl core::fmt::Debug for vortex_proto::expr::CaseWhenOpts + +pub fn vortex_proto::expr::CaseWhenOpts::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::hash::Hash for vortex_proto::expr::CaseWhenOpts + +pub fn vortex_proto::expr::CaseWhenOpts::hash<__H: core::hash::Hasher>(&self, state: &mut __H) + +impl core::marker::Copy for vortex_proto::expr::CaseWhenOpts + +impl core::marker::StructuralPartialEq for vortex_proto::expr::CaseWhenOpts + +impl prost::message::Message for vortex_proto::expr::CaseWhenOpts + +pub fn vortex_proto::expr::CaseWhenOpts::clear(&mut self) + +pub fn vortex_proto::expr::CaseWhenOpts::encoded_len(&self) -> usize + pub struct vortex_proto::expr::CastOpts pub vortex_proto::expr::CastOpts::target: core::option::Option diff --git a/vortex-proto/src/generated/vortex.expr.rs b/vortex-proto/src/generated/vortex.expr.rs index f3b6d2cf624..9bc61475e59 
100644 --- a/vortex-proto/src/generated/vortex.expr.rs +++ b/vortex-proto/src/generated/vortex.expr.rs @@ -145,3 +145,13 @@ pub mod select_opts { Exclude(super::FieldNames), } } +/// Options for `vortex.case_when` +/// Encodes num_when_then_pairs and has_else into a single u32 (num_children). +/// num_children = num_when_then_pairs * 2 + (has_else ? 1 : 0) +/// has_else = num_children % 2 == 1 +/// num_when_then_pairs = num_children / 2 +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] +pub struct CaseWhenOpts { + #[prost(uint32, tag = "1")] + pub num_children: u32, +}