Skip to content

Commit 1e52a33

Browse files
authored
Fix ValueError on read-only arrays in BaseSLearner.predict() (#878)
* #856: Fix ValueError on read-only arrays in BaseSLearner.predict() Learners like CatBoost set flags.writeable=False on arrays passed to predict(), causing the subsequent in-place mutation (X_new[:, 0] = 1) to raise ValueError. Build separate control and treatment arrays instead of mutating in place. * #856: Add regression test for read-only array prediction Adds ReadOnlyLinearRegression wrapper that simulates CatBoost's behavior of setting flags.writeable=False on arrays after fit/predict. Verifies BaseSLearner.predict() succeeds without ValueError.
1 parent b4c76dd commit 1e52a33

2 files changed

Lines changed: 46 additions & 12 deletions

File tree

causalml/inference/meta/slearner.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -109,13 +109,14 @@ def predict(
109109
for group in self.t_groups:
110110
model = self.models[group]
111111

112-
# set the treatment column to zero (the control group)
113-
X_new = np.hstack((np.zeros((X.shape[0], 1)), X))
114-
yhat_cs[group] = model.predict(X_new)
112+
# Build separate arrays for control and treatment to avoid in-place
113+
# mutation, which fails when learners like CatBoost set the
114+
# writeable flag to False on arrays passed to predict().
115+
X_new_c = np.hstack((np.zeros((X.shape[0], 1)), X))
116+
yhat_cs[group] = model.predict(X_new_c)
115117

116-
# set the treatment column to one (the treatment group)
117-
X_new[:, 0] = 1
118-
yhat_ts[group] = model.predict(X_new)
118+
X_new_t = np.hstack((np.ones((X.shape[0], 1)), X))
119+
yhat_ts[group] = model.predict(X_new_t)
119120

120121
if (y is not None) and (treatment is not None) and verbose:
121122
mask = (treatment == group) | (treatment == self.control_name)
@@ -346,13 +347,14 @@ def predict(
346347
for group in self.t_groups:
347348
model = self.models[group]
348349

349-
# set the treatment column to zero (the control group)
350-
X_new = np.hstack((np.zeros((X.shape[0], 1)), X))
351-
yhat_cs[group] = model.predict_proba(X_new)[:, 1]
350+
# Build separate arrays for control and treatment to avoid in-place
351+
# mutation, which fails when learners like CatBoost set the
352+
# writeable flag to False on arrays passed to predict().
353+
X_new_c = np.hstack((np.zeros((X.shape[0], 1)), X))
354+
yhat_cs[group] = model.predict_proba(X_new_c)[:, 1]
352355

353-
# set the treatment column to one (the treatment group)
354-
X_new[:, 0] = 1
355-
yhat_ts[group] = model.predict_proba(X_new)[:, 1]
356+
X_new_t = np.hstack((np.ones((X.shape[0], 1)), X))
357+
yhat_ts[group] = model.predict_proba(X_new_t)[:, 1]
356358

357359
if y is not None and (treatment is not None) and verbose:
358360
mask = (treatment == group) | (treatment == self.control_name)

tests/test_meta_learners.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,23 @@
3838
from .const import RANDOM_SEED, N_SAMPLE, ERROR_THRESHOLD, CONTROL_NAME, CONVERSION
3939

4040

41+
class ReadOnlyLinearRegression:
42+
"""Minimal regressor that marks input arrays read-only like CatBoost."""
43+
44+
def __init__(self):
45+
self.model = LinearRegression()
46+
47+
def fit(self, X, y):
48+
self.model.fit(X, y)
49+
X.flags.writeable = False
50+
return self
51+
52+
def predict(self, X):
53+
result = self.model.predict(X)
54+
X.flags.writeable = False
55+
return result
56+
57+
4158
def test_synthetic_data():
4259
y, X, treatment, tau, b, e = synthetic_data(mode=1, n=N_SAMPLE, p=8, sigma=0.1)
4360

@@ -97,6 +114,21 @@ def test_BaseSLearner(generate_regression_data):
97114
assert (ate_p_pt == ate_p) and (lb_pt == lb) and (ub_pt == ub)
98115

99116

117+
def test_BaseSLearner_predict_with_readonly_arrays(generate_regression_data):
118+
y, X, treatment, _, _, _ = generate_regression_data()
119+
X_readonly = np.array(X, copy=True)
120+
X_readonly.flags.writeable = False
121+
122+
learner = BaseSLearner(learner=ReadOnlyLinearRegression())
123+
124+
# Exercise both fit() and predict() with read-only array behavior.
125+
learner.fit(X=X_readonly, treatment=treatment, y=y)
126+
cate = learner.predict(X=X_readonly)
127+
128+
assert cate.shape == (X.shape[0], 1)
129+
assert not X_readonly.flags.writeable
130+
131+
100132
def test_BaseSRegressor(generate_regression_data):
101133
y, X, treatment, tau, b, e = generate_regression_data()
102134

0 commit comments

Comments
 (0)