File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -299,3 +299,26 @@ def test_CausalRandomForestRegressor_no_inf_predictions():
299299 preds = model .predict (X = X )
300300
301301 assert np .all (np .isfinite (preds )), "Predictions contain inf or NaN values"
302+
303+
304+ def test_CausalRandomForestRegressor_no_inf_predictions_ttest ():
305+ """Test that CausalRandomForestRegressor with criterion='ttest' does not
306+ predict inf values when some tree splits have zero-count
307+ treatment/control groups (#589)."""
308+ np .random .seed (RANDOM_SEED )
309+ n = 100
310+ X = np .random .randn (n , 5 )
311+ treatment = np .array ([0 ] * 90 + [1 ] * 10 )
312+ y = np .random .randn (n )
313+
314+ model = CausalRandomForestRegressor (
315+ criterion = "ttest" ,
316+ control_name = 0 ,
317+ n_estimators = 10 ,
318+ min_samples_leaf = 1 ,
319+ random_state = RANDOM_SEED ,
320+ )
321+ model .fit (X = X , treatment = treatment , y = y )
322+ preds = model .predict (X = X )
323+
324+ assert np .all (np .isfinite (preds )), "Predictions contain inf or NaN values"
You can’t perform that action at this time.
0 commit comments