|
| 1 | +from typing import Optional |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +import pandas as pd |
| 5 | +from matplotlib import pyplot as plt |
| 6 | +import logging |
| 7 | + |
| 8 | +plt.style.use("fivethirtyeight") |
| 9 | +RANDOM_COL = "Random" |
| 10 | + |
| 11 | +logger = logging.getLogger("causalml") |
| 12 | + |
| 13 | + |
| 14 | +def get_toc( |
| 15 | + df, |
| 16 | + outcome_col="y", |
| 17 | + treatment_col="w", |
| 18 | + treatment_effect_col="tau", |
| 19 | + normalize=False, |
| 20 | +): |
| 21 | + """Get the Targeting Operator Characteristic (TOC) of model estimates in population. |
| 22 | +
|
| 23 | + TOC(q) is the difference between the ATE among the top-q fraction of units ranked |
| 24 | + by the prioritization score and the overall ATE. A positive TOC at low q indicates |
| 25 | + the model successfully identifies units with above-average treatment benefit. |
| 26 | +
|
| 27 | + By definition, TOC(0) = 0 and TOC(1) = 0 (the subset ATE equals the overall ATE |
| 28 | + when the entire population is selected). |
| 29 | +
|
| 30 | + If the true treatment effect is provided (e.g. in synthetic data), it's used directly |
| 31 | + to calculate TOC. Otherwise, it's estimated as the difference between the mean outcomes |
| 32 | + of the treatment and control groups in each quantile band. |
| 33 | +
|
| 34 | + Note: when using observed outcomes, if a quantile band contains only treated or only |
| 35 | + control units, the code falls back to TOC(q) = 0 for that band (i.e., subset ATE is |
| 36 | + set to the overall ATE). This is a conservative approximation and is logged as a warning. |
| 37 | +
|
| 38 | + For details, see Yadlowsky et al. (2021), `Evaluating Treatment Prioritization Rules |
| 39 | + via Rank-Weighted Average Treatment Effects`. https://arxiv.org/abs/2111.07966 |
| 40 | +
|
| 41 | + For the former, `treatment_effect_col` should be provided. For the latter, both |
| 42 | + `outcome_col` and `treatment_col` should be provided. |
| 43 | +
|
| 44 | + Args: |
| 45 | + df (pandas.DataFrame): a data frame with model estimates and actual data as columns |
| 46 | + outcome_col (str, optional): the column name for the actual outcome |
| 47 | + treatment_col (str, optional): the column name for the treatment indicator (0 or 1) |
| 48 | + treatment_effect_col (str, optional): the column name for the true treatment effect |
| 49 | + normalize (bool, optional): whether to normalize the TOC curve by its maximum |
| 50 | + absolute value. Uses max(|TOC|) as the reference to avoid division by zero |
| 51 | + at q=1 where TOC is always zero by definition. |
| 52 | +
|
| 53 | + Returns: |
| 54 | + (pandas.DataFrame): TOC values of model estimates in population, indexed by quantile q |
| 55 | + """ |
| 56 | + assert ( |
| 57 | + (outcome_col in df.columns and df[outcome_col].notnull().all()) |
| 58 | + and (treatment_col in df.columns and df[treatment_col].notnull().all()) |
| 59 | + or ( |
| 60 | + treatment_effect_col in df.columns |
| 61 | + and df[treatment_effect_col].notnull().all() |
| 62 | + ) |
| 63 | + ), "{outcome_col} and {treatment_col}, or {treatment_effect_col} should be present without null.".format( |
| 64 | + outcome_col=outcome_col, |
| 65 | + treatment_col=treatment_col, |
| 66 | + treatment_effect_col=treatment_effect_col, |
| 67 | + ) |
| 68 | + |
| 69 | + df = df.copy() |
| 70 | + |
| 71 | + model_names = [ |
| 72 | + x |
| 73 | + for x in df.columns |
| 74 | + if x not in [outcome_col, treatment_col, treatment_effect_col] |
| 75 | + ] |
| 76 | + |
| 77 | + use_oracle = ( |
| 78 | + treatment_effect_col in df.columns and df[treatment_effect_col].notnull().all() |
| 79 | + ) |
| 80 | + |
| 81 | + if use_oracle: |
| 82 | + overall_ate = df[treatment_effect_col].mean() |
| 83 | + else: |
| 84 | + treated = df[treatment_col] == 1 |
| 85 | + overall_ate = ( |
| 86 | + df.loc[treated, outcome_col].mean() - df.loc[~treated, outcome_col].mean() |
| 87 | + ) |
| 88 | + |
| 89 | + n_total = len(df) |
| 90 | + |
| 91 | + toc = [] |
| 92 | + for col in model_names: |
| 93 | + sorted_df = df.sort_values(col, ascending=False).reset_index(drop=True) |
| 94 | + |
| 95 | + if use_oracle: |
| 96 | + # O(n) via cumulative sum |
| 97 | + cumsum_tau = sorted_df[treatment_effect_col].cumsum().values |
| 98 | + counts = np.arange(1, n_total + 1) |
| 99 | + subset_ates = cumsum_tau / counts |
| 100 | + else: |
| 101 | + cumsum_tr = sorted_df[treatment_col].cumsum().values |
| 102 | + cumsum_ct = np.arange(1, n_total + 1) - cumsum_tr |
| 103 | + cumsum_y_tr = ( |
| 104 | + (sorted_df[outcome_col] * sorted_df[treatment_col]).cumsum().values |
| 105 | + ) |
| 106 | + cumsum_y_ct = ( |
| 107 | + (sorted_df[outcome_col] * (1 - sorted_df[treatment_col])) |
| 108 | + .cumsum() |
| 109 | + .values |
| 110 | + ) |
| 111 | + |
| 112 | + # Guard against division by zero when a band is all-treated or all-control; |
| 113 | + # fall back to overall_ate (TOC = 0) for those positions. |
| 114 | + with np.errstate(invalid="ignore", divide="ignore"): |
| 115 | + subset_ates = np.where( |
| 116 | + (cumsum_tr == 0) | (cumsum_ct == 0), |
| 117 | + overall_ate, |
| 118 | + cumsum_y_tr / cumsum_tr - cumsum_y_ct / cumsum_ct, |
| 119 | + ) |
| 120 | + |
| 121 | + if np.any((cumsum_tr == 0) | (cumsum_ct == 0)): |
| 122 | + logger.warning( |
| 123 | + "Some quantile bands contain only treated or only control units " |
| 124 | + "for column '%s'. TOC is set to 0 for those positions.", |
| 125 | + col, |
| 126 | + ) |
| 127 | + |
| 128 | + toc_values = subset_ates - overall_ate |
| 129 | + toc.append(pd.Series(toc_values, index=np.linspace(0, 1, n_total + 1)[1:])) |
| 130 | + |
| 131 | + toc = pd.concat(toc, join="inner", axis=1) |
| 132 | + toc.loc[0] = np.zeros((toc.shape[1],)) |
| 133 | + toc = toc.sort_index().interpolate() |
| 134 | + toc.columns = model_names |
| 135 | + toc.index.name = "q" |
| 136 | + |
| 137 | + if normalize: |
| 138 | + # Normalize by max absolute value rather than the value at q=1, which is |
| 139 | + # always zero by definition and would cause division by zero. |
| 140 | + max_abs = toc.abs().max() |
| 141 | + max_abs = max_abs.replace(0, 1) # guard for flat TOC curves |
| 142 | + toc = toc.div(max_abs, axis=1) |
| 143 | + |
| 144 | + return toc |
| 145 | + |
| 146 | + |
| 147 | +def rate_score( |
| 148 | + df, |
| 149 | + outcome_col="y", |
| 150 | + treatment_col="w", |
| 151 | + treatment_effect_col="tau", |
| 152 | + weighting="autoc", |
| 153 | + normalize=False, |
| 154 | +): |
| 155 | + """Calculate the Rank-weighted Average Treatment Effect (RATE) score. |
| 156 | +
|
| 157 | + RATE is the weighted area under the Targeting Operator Characteristic (TOC) curve: |
| 158 | +
|
| 159 | + RATE = integral_0^1 alpha(q) * TOC(q) dq |
| 160 | +
|
| 161 | + Two standard weighting schemes are supported (Yadlowsky et al., 2021): |
| 162 | +
|
| 163 | + - ``"autoc"``: alpha(q) = 1/q. Places more weight on the highest-priority units. |
| 164 | + Most powerful when treatment effects are concentrated in a small subgroup. |
| 165 | +
|
| 166 | + - ``"qini"``: alpha(q) = q. Uniform weighting across units; reduces to the Qini |
| 167 | + coefficient. More powerful when treatment effects are diffuse across the population. |
| 168 | +
|
| 169 | + A positive RATE indicates the prioritization rule effectively identifies units with |
| 170 | + above-average treatment benefit. A RATE near zero suggests little heterogeneity or |
| 171 | + a poor prioritization rule. |
| 172 | +
|
| 173 | + Note: the integral is approximated via a weighted mean over the discrete quantile grid |
| 174 | + using midpoint values. Weights are normalized to sum to 1 (i.e. ``weights / weights.sum()``), |
| 175 | + so the absolute scale matches the TOC values but may differ slightly from the paper's |
| 176 | + continuous integral definition. Model rankings are preserved. |
| 177 | +
|
| 178 | + For details, see Yadlowsky et al. (2021), `Evaluating Treatment Prioritization Rules |
| 179 | + via Rank-Weighted Average Treatment Effects`. https://arxiv.org/abs/2111.07966 |
| 180 | +
|
| 181 | + For the former, `treatment_effect_col` should be provided. For the latter, both |
| 182 | + `outcome_col` and `treatment_col` should be provided. |
| 183 | +
|
| 184 | + Args: |
| 185 | + df (pandas.DataFrame): a data frame with model estimates and actual data as columns |
| 186 | + outcome_col (str, optional): the column name for the actual outcome |
| 187 | + treatment_col (str, optional): the column name for the treatment indicator (0 or 1) |
| 188 | + treatment_effect_col (str, optional): the column name for the true treatment effect |
| 189 | + weighting (str, optional): the weighting scheme for the RATE integral. |
| 190 | + One of ``"autoc"`` (default) or ``"qini"``. |
| 191 | + normalize (bool, optional): whether to normalize the TOC curve before scoring |
| 192 | +
|
| 193 | + Returns: |
| 194 | + (pandas.Series): RATE scores of model estimates |
| 195 | + """ |
| 196 | + assert weighting in ( |
| 197 | + "autoc", |
| 198 | + "qini", |
| 199 | + ), "{} weighting is not implemented. Select one of {}".format( |
| 200 | + weighting, ("autoc", "qini") |
| 201 | + ) |
| 202 | + |
| 203 | + toc = get_toc( |
| 204 | + df, |
| 205 | + outcome_col=outcome_col, |
| 206 | + treatment_col=treatment_col, |
| 207 | + treatment_effect_col=treatment_effect_col, |
| 208 | + normalize=normalize, |
| 209 | + ) |
| 210 | + |
| 211 | + quantiles = toc.index.values # includes 0 and 1 |
| 212 | + |
| 213 | + # Use midpoints to avoid division by zero for autoc at q=0 |
| 214 | + q_mid = (quantiles[:-1] + quantiles[1:]) / 2 |
| 215 | + toc_mid = (toc.iloc[:-1].values + toc.iloc[1:].values) / 2 |
| 216 | + |
| 217 | + if weighting == "autoc": |
| 218 | + weights = 1.0 / q_mid |
| 219 | + else: |
| 220 | + weights = q_mid |
| 221 | + |
| 222 | + # Normalize weights so they sum to 1 over the integration domain |
| 223 | + weights = weights / weights.sum() |
| 224 | + |
| 225 | + rate = pd.Series( |
| 226 | + np.average(toc_mid, axis=0, weights=weights), |
| 227 | + index=toc.columns, |
| 228 | + ) |
| 229 | + rate.name = "RATE ({})".format(weighting) |
| 230 | + return rate |
| 231 | + |
| 232 | + |
| 233 | +def plot_toc( |
| 234 | + df, |
| 235 | + outcome_col="y", |
| 236 | + treatment_col="w", |
| 237 | + treatment_effect_col="tau", |
| 238 | + normalize=False, |
| 239 | + n=100, |
| 240 | + figsize=(8, 8), |
| 241 | + ax: Optional[plt.Axes] = None, |
| 242 | +) -> plt.Axes: |
| 243 | + """Plot the Targeting Operator Characteristic (TOC) curve of model estimates. |
| 244 | +
|
| 245 | + The TOC(q) shows the excess ATE when treating only the top-q fraction of units |
| 246 | + prioritized by a model score, relative to the overall ATE. A positive and steeply |
| 247 | + decreasing curve indicates the model effectively ranks high-benefit units first. |
| 248 | +
|
| 249 | + If the true treatment effect is provided (e.g. in synthetic data), it's used directly. |
| 250 | + Otherwise, it's estimated from observed outcomes and treatment assignments. |
| 251 | +
|
| 252 | + For details, see Yadlowsky et al. (2021), `Evaluating Treatment Prioritization Rules |
| 253 | + via Rank-Weighted Average Treatment Effects`. https://arxiv.org/abs/2111.07966 |
| 254 | +
|
| 255 | + For the former, `treatment_effect_col` should be provided. For the latter, both |
| 256 | + `outcome_col` and `treatment_col` should be provided. |
| 257 | +
|
| 258 | + Args: |
| 259 | + df (pandas.DataFrame): a data frame with model estimates and actual data as columns |
| 260 | + outcome_col (str, optional): the column name for the actual outcome |
| 261 | + treatment_col (str, optional): the column name for the treatment indicator (0 or 1) |
| 262 | + treatment_effect_col (str, optional): the column name for the true treatment effect |
| 263 | + normalize (bool, optional): whether to normalize the TOC curve by its maximum |
| 264 | + absolute value before plotting |
| 265 | + n (int, optional): the number of samples to be used for plotting |
| 266 | + figsize (tuple, optional): the size of the figure to plot |
| 267 | + ax (plt.Axes, optional): an existing axes object to draw on |
| 268 | +
|
| 269 | + Returns: |
| 270 | + (plt.Axes): the matplotlib Axes with the TOC plot |
| 271 | + """ |
| 272 | + toc = get_toc( |
| 273 | + df, |
| 274 | + outcome_col=outcome_col, |
| 275 | + treatment_col=treatment_col, |
| 276 | + treatment_effect_col=treatment_effect_col, |
| 277 | + normalize=normalize, |
| 278 | + ) |
| 279 | + |
| 280 | + if (n is not None) and (n < toc.shape[0]): |
| 281 | + toc = toc.iloc[np.linspace(0, len(toc) - 1, n, endpoint=True).astype(int)] |
| 282 | + |
| 283 | + if ax is None: |
| 284 | + _, ax = plt.subplots(figsize=figsize) |
| 285 | + |
| 286 | + ax = toc.plot(ax=ax) |
| 287 | + |
| 288 | + # Random baseline (TOC = 0 everywhere) |
| 289 | + ax.plot( |
| 290 | + [toc.index[0], toc.index[-1]], |
| 291 | + [0, 0], |
| 292 | + label=RANDOM_COL, |
| 293 | + color="k", |
| 294 | + linestyle="--", |
| 295 | + ) |
| 296 | + ax.legend() |
| 297 | + ax.set_xlabel("Fraction treated (q)") |
| 298 | + ax.set_ylabel("TOC(q)") |
| 299 | + ax.set_title("Targeting Operator Characteristic (TOC)") |
| 300 | + |
| 301 | + return ax |
0 commit comments