@@ -24,7 +24,7 @@ use arrow::util::pretty::{pretty_format_batches, pretty_format_columns};
2424use datafusion:: prelude:: * ;
2525use datafusion_common:: { DFSchema , ScalarValue } ;
2626use datafusion_expr:: ExprFunctionExt ;
27- use datafusion_expr:: expr:: NullTreatment ;
27+ use datafusion_expr:: expr:: { NullTreatment , WindowFunction , WindowFunctionDefinition } ;
2828use datafusion_expr:: simplify:: SimplifyContext ;
2929use datafusion_functions:: core:: expr_ext:: FieldAccessor ;
3030use datafusion_functions_aggregate:: first_last:: first_value_udaf;
@@ -263,6 +263,101 @@ async fn test_aggregate_ext_distinct() {
263263 . await ;
264264}
265265
266+ /// Test that `filter()` works as the first chained call on a window function.
267+ /// This verifies the fix for https://github.com/apache/datafusion/issues/21697
268+ #[ tokio:: test]
269+ async fn test_window_ext_filter_first ( ) {
270+ // Build a window function with filter as the FIRST chained method.
271+ // Before the fix, this would fail because filter() on Expr only handled
272+ // AggregateFunction, not WindowFunction.
273+ let window_expr = Expr :: from ( WindowFunction :: new (
274+ WindowFunctionDefinition :: AggregateUDF ( sum_udaf ( ) ) ,
275+ vec ! [ col( "i" ) ] ,
276+ ) )
277+ . filter ( col ( "i" ) . is_not_null ( ) )
278+ . order_by ( vec ! [ col( "id" ) . sort( true , true ) ] )
279+ . build ( )
280+ . unwrap ( )
281+ . alias ( "sum_filtered" ) ;
282+
283+ let ctx = SessionContext :: new ( ) ;
284+ let result = ctx
285+ . read_batch ( TEST_BATCH . clone ( ) )
286+ . unwrap ( )
287+ . select ( vec ! [ col( "id" ) , col( "i" ) , window_expr] )
288+ . unwrap ( )
289+ . collect ( )
290+ . await
291+ . unwrap ( ) ;
292+
293+ let result = pretty_format_batches ( & result) . unwrap ( ) . to_string ( ) ;
294+ let actual_lines = result. lines ( ) . collect :: < Vec < _ > > ( ) ;
295+
296+ // TEST_BATCH: id=["1","2","3"], i=[10, NULL, 5]
297+ // Ordered by id ASC: row1(id=1,i=10), row2(id=2,i=NULL), row3(id=3,i=5)
298+ // FILTER(i IS NOT NULL) excludes row2 from the sum
299+ // Running sum: row1=10, row2=10 (NULL filtered out), row3=15
300+ let expected_lines = vec ! [
301+ "+----+----+--------------+" ,
302+ "| id | i | sum_filtered |" ,
303+ "+----+----+--------------+" ,
304+ "| 1 | 10 | 10 |" ,
305+ "| 2 | | 10 |" ,
306+ "| 3 | 5 | 15 |" ,
307+ "+----+----+--------------+" ,
308+ ] ;
309+
310+ assert_eq ! (
311+ expected_lines, actual_lines,
312+ "\n \n expected:\n \n {expected_lines:#?}\n actual:\n \n {actual_lines:#?}\n \n "
313+ ) ;
314+ }
315+
316+ /// Test that `distinct()` works as the first chained call on a window function.
317+ #[ tokio:: test]
318+ async fn test_window_ext_distinct_first ( ) {
319+ let window_expr = Expr :: from ( WindowFunction :: new (
320+ WindowFunctionDefinition :: AggregateUDF ( sum_udaf ( ) ) ,
321+ vec ! [ col( "i" ) ] ,
322+ ) )
323+ . distinct ( )
324+ . order_by ( vec ! [ col( "id" ) . sort( true , true ) ] )
325+ . build ( )
326+ . unwrap ( )
327+ . alias ( "sum_distinct" ) ;
328+
329+ let ctx = SessionContext :: new ( ) ;
330+ let result = ctx
331+ . read_batch ( TEST_BATCH . clone ( ) )
332+ . unwrap ( )
333+ . select ( vec ! [ col( "id" ) , col( "i" ) , window_expr] )
334+ . unwrap ( )
335+ . collect ( )
336+ . await
337+ . unwrap ( ) ;
338+
339+ let result = pretty_format_batches ( & result) . unwrap ( ) . to_string ( ) ;
340+ let actual_lines = result. lines ( ) . collect :: < Vec < _ > > ( ) ;
341+
342+ // TEST_BATCH: id=["1","2","3"], i=[10, NULL, 5]
343+ // Ordered by id ASC: row1(id=1,i=10), row2(id=2,i=NULL), row3(id=3,i=5)
344+ // DISTINCT running sum: row1=10, row2=10 (NULL skipped), row3=15
345+ let expected_lines = vec ! [
346+ "+----+----+--------------+" ,
347+ "| id | i | sum_distinct |" ,
348+ "+----+----+--------------+" ,
349+ "| 1 | 10 | 10 |" ,
350+ "| 2 | | 10 |" ,
351+ "| 3 | 5 | 15 |" ,
352+ "+----+----+--------------+" ,
353+ ] ;
354+
355+ assert_eq ! (
356+ expected_lines, actual_lines,
357+ "\n \n expected:\n \n {expected_lines:#?}\n actual:\n \n {actual_lines:#?}\n \n "
358+ ) ;
359+ }
360+
266361#[ tokio:: test]
267362async fn test_aggregate_ext_null_treatment ( ) {
268363 let agg = first_value_udaf ( )
0 commit comments