Skip to content

Commit 4e40499

Browse files
committed
fix: handle WindowFunction in ExprFunctionExt filter and distinct on Expr
1 parent 3aefba7 commit 4e40499

2 files changed

Lines changed: 205 additions & 11 deletions

File tree

datafusion/core/tests/expr_api/mod.rs

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use arrow::util::pretty::{pretty_format_batches, pretty_format_columns};
2424
use datafusion::prelude::*;
2525
use datafusion_common::{DFSchema, ScalarValue};
2626
use datafusion_expr::ExprFunctionExt;
27-
use datafusion_expr::expr::NullTreatment;
27+
use datafusion_expr::expr::{NullTreatment, WindowFunction, WindowFunctionDefinition};
2828
use datafusion_expr::simplify::SimplifyContext;
2929
use datafusion_functions::core::expr_ext::FieldAccessor;
3030
use datafusion_functions_aggregate::first_last::first_value_udaf;
@@ -263,6 +263,101 @@ async fn test_aggregate_ext_distinct() {
263263
.await;
264264
}
265265

266+
/// Test that `filter()` works as the first chained call on a window function.
267+
/// This verifies the fix for https://github.com/apache/datafusion/issues/21697
268+
#[tokio::test]
269+
async fn test_window_ext_filter_first() {
270+
// Build a window function with filter as the FIRST chained method.
271+
// Before the fix, this would fail because filter() on Expr only handled
272+
// AggregateFunction, not WindowFunction.
273+
let window_expr = Expr::from(WindowFunction::new(
274+
WindowFunctionDefinition::AggregateUDF(sum_udaf()),
275+
vec![col("i")],
276+
))
277+
.filter(col("i").is_not_null())
278+
.order_by(vec![col("id").sort(true, true)])
279+
.build()
280+
.unwrap()
281+
.alias("sum_filtered");
282+
283+
let ctx = SessionContext::new();
284+
let result = ctx
285+
.read_batch(TEST_BATCH.clone())
286+
.unwrap()
287+
.select(vec![col("id"), col("i"), window_expr])
288+
.unwrap()
289+
.collect()
290+
.await
291+
.unwrap();
292+
293+
let result = pretty_format_batches(&result).unwrap().to_string();
294+
let actual_lines = result.lines().collect::<Vec<_>>();
295+
296+
// TEST_BATCH: id=["1","2","3"], i=[10, NULL, 5]
297+
// Ordered by id ASC: row1(id=1,i=10), row2(id=2,i=NULL), row3(id=3,i=5)
298+
// FILTER(i IS NOT NULL) excludes row2 from the sum
299+
// Running sum: row1=10, row2=10 (NULL filtered out), row3=15
300+
let expected_lines = vec![
301+
"+----+----+--------------+",
302+
"| id | i | sum_filtered |",
303+
"+----+----+--------------+",
304+
"| 1 | 10 | 10 |",
305+
"| 2 | | 10 |",
306+
"| 3 | 5 | 15 |",
307+
"+----+----+--------------+",
308+
];
309+
310+
assert_eq!(
311+
expected_lines, actual_lines,
312+
"\n\nexpected:\n\n{expected_lines:#?}\nactual:\n\n{actual_lines:#?}\n\n"
313+
);
314+
}
315+
316+
/// Test that `distinct()` works as the first chained call on a window function.
317+
#[tokio::test]
318+
async fn test_window_ext_distinct_first() {
319+
let window_expr = Expr::from(WindowFunction::new(
320+
WindowFunctionDefinition::AggregateUDF(sum_udaf()),
321+
vec![col("i")],
322+
))
323+
.distinct()
324+
.order_by(vec![col("id").sort(true, true)])
325+
.build()
326+
.unwrap()
327+
.alias("sum_distinct");
328+
329+
let ctx = SessionContext::new();
330+
let result = ctx
331+
.read_batch(TEST_BATCH.clone())
332+
.unwrap()
333+
.select(vec![col("id"), col("i"), window_expr])
334+
.unwrap()
335+
.collect()
336+
.await
337+
.unwrap();
338+
339+
let result = pretty_format_batches(&result).unwrap().to_string();
340+
let actual_lines = result.lines().collect::<Vec<_>>();
341+
342+
// TEST_BATCH: id=["1","2","3"], i=[10, NULL, 5]
343+
// Ordered by id ASC: row1(id=1,i=10), row2(id=2,i=NULL), row3(id=3,i=5)
344+
// DISTINCT running sum: row1=10, row2=10 (NULL skipped), row3=15
345+
let expected_lines = vec![
346+
"+----+----+--------------+",
347+
"| id | i | sum_distinct |",
348+
"+----+----+--------------+",
349+
"| 1 | 10 | 10 |",
350+
"| 2 | | 10 |",
351+
"| 3 | 5 | 15 |",
352+
"+----+----+--------------+",
353+
];
354+
355+
assert_eq!(
356+
expected_lines, actual_lines,
357+
"\n\nexpected:\n\n{expected_lines:#?}\nactual:\n\n{actual_lines:#?}\n\n"
358+
);
359+
}
360+
266361
#[tokio::test]
267362
async fn test_aggregate_ext_null_treatment() {
268363
let agg = first_value_udaf()

datafusion/expr/src/expr_fn.rs

Lines changed: 109 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -932,26 +932,34 @@ impl ExprFunctionExt for Expr {
932932
builder
933933
}
934934
fn filter(self, filter: Expr) -> ExprFuncBuilder {
935-
match self {
935+
let mut builder = match self {
936936
Expr::AggregateFunction(udaf) => {
937-
let mut builder =
938-
ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)));
939-
builder.filter = Some(filter);
940-
builder
937+
ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)))
938+
}
939+
Expr::WindowFunction(udwf) => {
940+
ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)))
941941
}
942942
_ => ExprFuncBuilder::new(None),
943+
};
944+
if builder.fun.is_some() {
945+
builder.filter = Some(filter);
943946
}
947+
builder
944948
}
945949
fn distinct(self) -> ExprFuncBuilder {
946-
match self {
950+
let mut builder = match self {
947951
Expr::AggregateFunction(udaf) => {
948-
let mut builder =
949-
ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)));
950-
builder.distinct = true;
951-
builder
952+
ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)))
953+
}
954+
Expr::WindowFunction(udwf) => {
955+
ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)))
952956
}
953957
_ => ExprFuncBuilder::new(None),
958+
};
959+
if builder.fun.is_some() {
960+
builder.distinct = true;
954961
}
962+
builder
955963
}
956964
fn null_treatment(
957965
self,
@@ -998,6 +1006,10 @@ impl ExprFunctionExt for Expr {
9981006
#[cfg(test)]
9991007
mod test {
10001008
use super::*;
1009+
use crate::WindowFunctionDefinition;
1010+
use crate::expr::{AggregateFunction, WindowFunction};
1011+
use crate::lit;
1012+
use crate::test::function_stub::sum_udaf;
10011013

10021014
#[test]
10031015
fn filter_is_null_and_is_not_null() {
@@ -1009,4 +1021,91 @@ mod test {
10091021
"col2 IS NOT NULL"
10101022
);
10111023
}
1024+
1025+
/// Create a window function expression for testing
1026+
fn test_window_expr() -> Expr {
1027+
Expr::WindowFunction(Box::new(WindowFunction::new(
1028+
WindowFunctionDefinition::AggregateUDF(sum_udaf()),
1029+
vec![col("x")],
1030+
)))
1031+
}
1032+
1033+
#[test]
1034+
fn test_window_filter_first() {
1035+
let result = test_window_expr().filter(col("a").gt_eq(lit(5))).build();
1036+
assert!(
1037+
result.is_ok(),
1038+
"filter as first call on WindowFunction should succeed"
1039+
);
1040+
let expr = result.unwrap();
1041+
match expr {
1042+
Expr::WindowFunction(wf) => {
1043+
assert!(wf.params.filter.is_some(), "filter should be set");
1044+
}
1045+
other => panic!("expected WindowFunction, got {other:?}"),
1046+
}
1047+
}
1048+
1049+
#[test]
1050+
fn test_window_distinct_first() {
1051+
let result = test_window_expr().distinct().build();
1052+
assert!(
1053+
result.is_ok(),
1054+
"distinct as first call on WindowFunction should succeed"
1055+
);
1056+
let expr = result.unwrap();
1057+
match expr {
1058+
Expr::WindowFunction(wf) => {
1059+
assert!(wf.params.distinct, "distinct should be true");
1060+
}
1061+
other => panic!("expected WindowFunction, got {other:?}"),
1062+
}
1063+
}
1064+
1065+
#[test]
1066+
fn test_window_filter_then_partition_by() {
1067+
let result = test_window_expr()
1068+
.filter(col("a").gt_eq(lit(5)))
1069+
.partition_by(vec![col("y")])
1070+
.build();
1071+
assert!(
1072+
result.is_ok(),
1073+
"filter then partition_by on WindowFunction should succeed"
1074+
);
1075+
let expr = result.unwrap();
1076+
match expr {
1077+
Expr::WindowFunction(wf) => {
1078+
assert!(wf.params.filter.is_some(), "filter should be set");
1079+
assert_eq!(
1080+
wf.params.partition_by.len(),
1081+
1,
1082+
"partition_by should have one entry"
1083+
);
1084+
}
1085+
other => panic!("expected WindowFunction, got {other:?}"),
1086+
}
1087+
}
1088+
1089+
#[test]
1090+
fn test_aggregate_filter_still_works() {
1091+
let agg = Expr::AggregateFunction(AggregateFunction::new_udf(
1092+
sum_udaf(),
1093+
vec![col("x")],
1094+
false,
1095+
None,
1096+
vec![],
1097+
None,
1098+
));
1099+
let result = agg.filter(col("a").gt_eq(lit(5))).build();
1100+
assert!(
1101+
result.is_ok(),
1102+
"filter on AggregateFunction should still work"
1103+
);
1104+
match result.unwrap() {
1105+
Expr::AggregateFunction(af) => {
1106+
assert!(af.params.filter.is_some(), "filter should be set");
1107+
}
1108+
other => panic!("expected AggregateFunction, got {other:?}"),
1109+
}
1110+
}
10121111
}

0 commit comments

Comments
 (0)