Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions vortex-array/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,11 @@ name = "expr_large_struct_pack"
path = "benches/expr/large_struct_pack.rs"
harness = false

[[bench]]
name = "expr_case_when"
path = "benches/expr/case_when_bench.rs"
harness = false

[[bench]]
name = "chunked_dict_builder"
harness = false
Expand Down
210 changes: 210 additions & 0 deletions vortex-array/benches/expr/case_when_bench.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

#![allow(clippy::unwrap_used)]
#![allow(clippy::cast_possible_truncation)]

use std::sync::LazyLock;

use divan::Bencher;
use vortex_array::ArrayRef;
use vortex_array::Canonical;
use vortex_array::IntoArray;
use vortex_array::VortexSessionExecute;
use vortex_array::arrays::StructArray;
use vortex_array::expr::case_when;
use vortex_array::expr::case_when_no_else;
use vortex_array::expr::eq;
use vortex_array::expr::get_item;
use vortex_array::expr::gt;
use vortex_array::expr::lit;
use vortex_array::expr::nested_case_when;
use vortex_array::expr::root;
use vortex_array::session::ArraySession;
use vortex_buffer::Buffer;
use vortex_session::VortexSession;

static SESSION: LazyLock<VortexSession> =
LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

fn main() {
divan::main();
}

fn make_struct_array(size: usize) -> ArrayRef {
let data: Buffer<i32> = (0..size as i32).collect();
let field = data.into_array();
StructArray::from_fields(&[("value", field)])
.unwrap()
.into_array()
}

/// Benchmark a simple binary CASE WHEN with varying array sizes.
#[divan::bench(args = [1000, 10000, 100000])]
fn case_when_simple(bencher: Bencher, size: usize) {
let array = make_struct_array(size);

// CASE WHEN value > 500 THEN 100 ELSE 0 END
let expr = case_when(
gt(get_item("value", root()), lit(500i32)),
lit(100i32),
lit(0i32),
);

bencher
.with_inputs(|| (&expr, &array))
.bench_refs(|(expr, array)| {
let mut ctx = SESSION.create_execution_ctx();
array
.apply(expr)
.unwrap()
.execute::<Canonical>(&mut ctx)
.unwrap()
});
}

/// Benchmark n-ary CASE WHEN with 3 conditions.
#[divan::bench(args = [1000, 10000, 100000])]
fn case_when_nary_3_conditions(bencher: Bencher, size: usize) {
let array = make_struct_array(size);

// CASE WHEN value > 750 THEN 3 WHEN value > 500 THEN 2 WHEN value > 250 THEN 1 ELSE 0 END
let expr = nested_case_when(
vec![
(gt(get_item("value", root()), lit(750i32)), lit(3i32)),
(gt(get_item("value", root()), lit(500i32)), lit(2i32)),
(gt(get_item("value", root()), lit(250i32)), lit(1i32)),
],
Some(lit(0i32)),
);

bencher
.with_inputs(|| (&expr, &array))
.bench_refs(|(expr, array)| {
let mut ctx = SESSION.create_execution_ctx();
array
.apply(expr)
.unwrap()
.execute::<Canonical>(&mut ctx)
.unwrap()
});
}

/// Benchmark n-ary CASE WHEN with 10 conditions.
#[divan::bench(args = [1000, 10000, 100000])]
fn case_when_nary_10_conditions(bencher: Bencher, size: usize) {
let array = make_struct_array(size);

let pairs: Vec<_> = (0..10)
.map(|i| {
let threshold = (i + 1) * (size as i32 / 10);
(
gt(get_item("value", root()), lit(threshold)),
lit((i + 1) * 100),
)
})
.collect();
let expr = nested_case_when(pairs, Some(lit(0i32)));

bencher
.with_inputs(|| (&expr, &array))
.bench_refs(|(expr, array)| {
let mut ctx = SESSION.create_execution_ctx();
array
.apply(expr)
.unwrap()
.execute::<Canonical>(&mut ctx)
.unwrap()
});
}

/// Benchmark n-ary CASE WHEN with equality conditions (lookup-table style).
#[divan::bench(args = [1000, 10000, 100000])]
fn case_when_nary_equality_lookup(bencher: Bencher, size: usize) {
let array = make_struct_array(size);

// Map specific values: CASE WHEN value = 0 THEN 'a' WHEN value = 1 THEN 'b' ... ELSE 'other' END
let pairs: Vec<_> = (0..5)
.map(|i| (eq(get_item("value", root()), lit(i)), lit(i * 10)))
.collect();
let expr = nested_case_when(pairs, Some(lit(-1i32)));

bencher
.with_inputs(|| (&expr, &array))
.bench_refs(|(expr, array)| {
let mut ctx = SESSION.create_execution_ctx();
array
.apply(expr)
.unwrap()
.execute::<Canonical>(&mut ctx)
.unwrap()
});
}

/// Benchmark CASE WHEN without ELSE clause (result is nullable).
#[divan::bench(args = [1000, 10000, 100000])]
fn case_when_without_else(bencher: Bencher, size: usize) {
let array = make_struct_array(size);

// CASE WHEN value > 500 THEN 100 END
let expr = case_when_no_else(gt(get_item("value", root()), lit(500i32)), lit(100i32));

bencher
.with_inputs(|| (&expr, &array))
.bench_refs(|(expr, array)| {
let mut ctx = SESSION.create_execution_ctx();
array
.apply(expr)
.unwrap()
.execute::<Canonical>(&mut ctx)
.unwrap()
});
}

/// Benchmark CASE WHEN where all conditions are true.
#[divan::bench(args = [1000, 10000, 100000])]
fn case_when_all_true(bencher: Bencher, size: usize) {
let array = make_struct_array(size);

// CASE WHEN value >= 0 THEN 100 ELSE 0 END (always true for our data)
let expr = case_when(
gt(get_item("value", root()), lit(-1i32)),
lit(100i32),
lit(0i32),
);

bencher
.with_inputs(|| (&expr, &array))
.bench_refs(|(expr, array)| {
let mut ctx = SESSION.create_execution_ctx();
array
.apply(expr)
.unwrap()
.execute::<Canonical>(&mut ctx)
.unwrap()
});
}

/// Benchmark CASE WHEN where all conditions are false.
#[divan::bench(args = [1000, 10000, 100000])]
fn case_when_all_false(bencher: Bencher, size: usize) {
let array = make_struct_array(size);

// CASE WHEN value > 1000000 THEN 100 ELSE 0 END (always false for our data)
let expr = case_when(
gt(get_item("value", root()), lit(1_000_000i32)),
lit(100i32),
lit(0i32),
);

bencher
.with_inputs(|| (&expr, &array))
.bench_refs(|(expr, array)| {
let mut ctx = SESSION.create_execution_ctx();
array
.apply(expr)
.unwrap()
.execute::<Canonical>(&mut ctx)
.unwrap()
});
}
122 changes: 122 additions & 0 deletions vortex-array/public-api.lock
Original file line number Diff line number Diff line change
Expand Up @@ -10106,6 +10106,10 @@ pub fn vortex_array::expr::and_collect<I>(iter: I) -> core::option::Option<vorte

pub fn vortex_array::expr::between(arr: vortex_array::expr::Expression, lower: vortex_array::expr::Expression, upper: vortex_array::expr::Expression, options: vortex_array::scalar_fn::fns::between::BetweenOptions) -> vortex_array::expr::Expression

pub fn vortex_array::expr::case_when(condition: vortex_array::expr::Expression, then_value: vortex_array::expr::Expression, else_value: vortex_array::expr::Expression) -> vortex_array::expr::Expression

pub fn vortex_array::expr::case_when_no_else(condition: vortex_array::expr::Expression, then_value: vortex_array::expr::Expression) -> vortex_array::expr::Expression

pub fn vortex_array::expr::cast(child: vortex_array::expr::Expression, target: vortex_array::dtype::DType) -> vortex_array::expr::Expression

pub fn vortex_array::expr::checked_add(lhs: vortex_array::expr::Expression, rhs: vortex_array::expr::Expression) -> vortex_array::expr::Expression
Expand Down Expand Up @@ -10160,6 +10164,8 @@ pub fn vortex_array::expr::merge(elements: impl core::iter::traits::collect::Int

pub fn vortex_array::expr::merge_opts(elements: impl core::iter::traits::collect::IntoIterator<Item = impl core::convert::Into<vortex_array::expr::Expression>>, duplicate_handling: vortex_array::scalar_fn::fns::merge::DuplicateHandling) -> vortex_array::expr::Expression

pub fn vortex_array::expr::nested_case_when(when_then_pairs: alloc::vec::Vec<(vortex_array::expr::Expression, vortex_array::expr::Expression)>, else_value: core::option::Option<vortex_array::expr::Expression>) -> vortex_array::expr::Expression

pub fn vortex_array::expr::not(operand: vortex_array::expr::Expression) -> vortex_array::expr::Expression

pub fn vortex_array::expr::not_eq(lhs: vortex_array::expr::Expression, rhs: vortex_array::expr::Expression) -> vortex_array::expr::Expression
Expand Down Expand Up @@ -13210,6 +13216,86 @@ pub fn vortex_array::scalar_fn::fns::binary::or_kleene(lhs: &vortex_array::Array

pub fn vortex_array::scalar_fn::fns::binary::scalar_cmp(lhs: &vortex_array::scalar::Scalar, rhs: &vortex_array::scalar::Scalar, operator: vortex_array::scalar_fn::fns::operators::CompareOperator) -> vortex_array::scalar::Scalar

pub mod vortex_array::scalar_fn::fns::case_when

pub struct vortex_array::scalar_fn::fns::case_when::CaseWhen

impl core::clone::Clone for vortex_array::scalar_fn::fns::case_when::CaseWhen

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::clone(&self) -> vortex_array::scalar_fn::fns::case_when::CaseWhen

impl vortex_array::scalar_fn::ScalarFnVTable for vortex_array::scalar_fn::fns::case_when::CaseWhen

pub type vortex_array::scalar_fn::fns::case_when::CaseWhen::Options = vortex_array::scalar_fn::fns::case_when::CaseWhenOptions

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::arity(&self, options: &Self::Options) -> vortex_array::scalar_fn::Arity

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::child_name(&self, options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::ChildName

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::deserialize(&self, metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<Self::Options>

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::execute(&self, options: &Self::Options, args: &dyn vortex_array::scalar_fn::ExecutionArgs, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::ArrayRef>

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::fmt_sql(&self, options: &Self::Options, expr: &vortex_array::expr::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::id(&self) -> vortex_array::scalar_fn::ScalarFnId

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::is_fallible(&self, _options: &Self::Options) -> bool

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::is_null_sensitive(&self, _options: &Self::Options) -> bool

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::reduce(&self, options: &Self::Options, node: &dyn vortex_array::scalar_fn::ReduceNode, ctx: &dyn vortex_array::scalar_fn::ReduceCtx) -> vortex_error::VortexResult<core::option::Option<vortex_array::scalar_fn::ReduceNodeRef>>

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::return_dtype(&self, options: &Self::Options, arg_dtypes: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult<vortex_array::dtype::DType>

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::simplify(&self, options: &Self::Options, expr: &vortex_array::expr::Expression, ctx: &dyn vortex_array::scalar_fn::SimplifyCtx) -> vortex_error::VortexResult<core::option::Option<vortex_array::expr::Expression>>

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::simplify_untyped(&self, options: &Self::Options, expr: &vortex_array::expr::Expression) -> vortex_error::VortexResult<core::option::Option<vortex_array::expr::Expression>>

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::stat_expression(&self, options: &Self::Options, expr: &vortex_array::expr::Expression, stat: vortex_array::expr::stats::Stat, catalog: &dyn vortex_array::expr::pruning::StatsCatalog) -> core::option::Option<vortex_array::expr::Expression>

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::stat_falsification(&self, options: &Self::Options, expr: &vortex_array::expr::Expression, catalog: &dyn vortex_array::expr::pruning::StatsCatalog) -> core::option::Option<vortex_array::expr::Expression>

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::validity(&self, options: &Self::Options, expression: &vortex_array::expr::Expression) -> vortex_error::VortexResult<core::option::Option<vortex_array::expr::Expression>>

pub struct vortex_array::scalar_fn::fns::case_when::CaseWhenOptions

pub vortex_array::scalar_fn::fns::case_when::CaseWhenOptions::has_else: bool

pub vortex_array::scalar_fn::fns::case_when::CaseWhenOptions::num_when_then_pairs: u32

impl vortex_array::scalar_fn::fns::case_when::CaseWhenOptions

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhenOptions::num_children(&self) -> usize

impl core::clone::Clone for vortex_array::scalar_fn::fns::case_when::CaseWhenOptions

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhenOptions::clone(&self) -> vortex_array::scalar_fn::fns::case_when::CaseWhenOptions

impl core::cmp::Eq for vortex_array::scalar_fn::fns::case_when::CaseWhenOptions

impl core::cmp::PartialEq for vortex_array::scalar_fn::fns::case_when::CaseWhenOptions

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhenOptions::eq(&self, other: &vortex_array::scalar_fn::fns::case_when::CaseWhenOptions) -> bool

impl core::fmt::Debug for vortex_array::scalar_fn::fns::case_when::CaseWhenOptions

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhenOptions::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result

impl core::fmt::Display for vortex_array::scalar_fn::fns::case_when::CaseWhenOptions

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhenOptions::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result

impl core::hash::Hash for vortex_array::scalar_fn::fns::case_when::CaseWhenOptions

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhenOptions::hash<__H: core::hash::Hasher>(&self, state: &mut __H)

impl core::marker::Copy for vortex_array::scalar_fn::fns::case_when::CaseWhenOptions

impl core::marker::StructuralPartialEq for vortex_array::scalar_fn::fns::case_when::CaseWhenOptions

pub mod vortex_array::scalar_fn::fns::cast

pub struct vortex_array::scalar_fn::fns::cast::Cast
Expand Down Expand Up @@ -15042,6 +15128,42 @@ pub fn vortex_array::scalar_fn::fns::binary::Binary::stat_falsification(&self, o

pub fn vortex_array::scalar_fn::fns::binary::Binary::validity(&self, operator: &vortex_array::scalar_fn::fns::operators::Operator, expression: &vortex_array::expr::Expression) -> vortex_error::VortexResult<core::option::Option<vortex_array::expr::Expression>>

impl vortex_array::scalar_fn::ScalarFnVTable for vortex_array::scalar_fn::fns::case_when::CaseWhen

pub type vortex_array::scalar_fn::fns::case_when::CaseWhen::Options = vortex_array::scalar_fn::fns::case_when::CaseWhenOptions

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::arity(&self, options: &Self::Options) -> vortex_array::scalar_fn::Arity

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::child_name(&self, options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::ChildName

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::deserialize(&self, metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<Self::Options>

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::execute(&self, options: &Self::Options, args: &dyn vortex_array::scalar_fn::ExecutionArgs, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::ArrayRef>

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::fmt_sql(&self, options: &Self::Options, expr: &vortex_array::expr::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::id(&self) -> vortex_array::scalar_fn::ScalarFnId

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::is_fallible(&self, _options: &Self::Options) -> bool

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::is_null_sensitive(&self, _options: &Self::Options) -> bool

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::reduce(&self, options: &Self::Options, node: &dyn vortex_array::scalar_fn::ReduceNode, ctx: &dyn vortex_array::scalar_fn::ReduceCtx) -> vortex_error::VortexResult<core::option::Option<vortex_array::scalar_fn::ReduceNodeRef>>

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::return_dtype(&self, options: &Self::Options, arg_dtypes: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult<vortex_array::dtype::DType>

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::simplify(&self, options: &Self::Options, expr: &vortex_array::expr::Expression, ctx: &dyn vortex_array::scalar_fn::SimplifyCtx) -> vortex_error::VortexResult<core::option::Option<vortex_array::expr::Expression>>

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::simplify_untyped(&self, options: &Self::Options, expr: &vortex_array::expr::Expression) -> vortex_error::VortexResult<core::option::Option<vortex_array::expr::Expression>>

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::stat_expression(&self, options: &Self::Options, expr: &vortex_array::expr::Expression, stat: vortex_array::expr::stats::Stat, catalog: &dyn vortex_array::expr::pruning::StatsCatalog) -> core::option::Option<vortex_array::expr::Expression>

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::stat_falsification(&self, options: &Self::Options, expr: &vortex_array::expr::Expression, catalog: &dyn vortex_array::expr::pruning::StatsCatalog) -> core::option::Option<vortex_array::expr::Expression>

pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::validity(&self, options: &Self::Options, expression: &vortex_array::expr::Expression) -> vortex_error::VortexResult<core::option::Option<vortex_array::expr::Expression>>

impl vortex_array::scalar_fn::ScalarFnVTable for vortex_array::scalar_fn::fns::cast::Cast

pub type vortex_array::scalar_fn::fns::cast::Cast::Options = vortex_array::dtype::DType
Expand Down
Loading
Loading