|
| 1 | +""" |
| 2 | +Numerical accuracy comparison helpers. |
| 3 | +
|
| 4 | +Compares the outputs of a PyTorch reference model and a TRT-compiled |
| 5 | +module on the same inputs, reporting cosine similarity, max/mean |
| 6 | +absolute error, and an allclose() pass/fail per output tensor. |
| 7 | +
|
| 8 | +Default tolerances are tuned for FP16 (atol=1e-2, rtol=1e-2, |
| 9 | +cos_sim_min=0.99). Override via --accuracy-atol / --accuracy-rtol / |
| 10 | +--accuracy-cos-sim-min when comparing tighter precisions or models |
| 11 | +known to have larger accumulated error. |
| 12 | +""" |
| 13 | +from __future__ import annotations |
| 14 | + |
| 15 | +from typing import Iterable |
| 16 | + |
| 17 | +import torch |
| 18 | +import torch.utils._pytree as pytree |
| 19 | + |
| 20 | + |
| 21 | +# --------------------------------------------------------------------------- # |
| 22 | +# Output flattening |
| 23 | +# --------------------------------------------------------------------------- # |
| 24 | + |
def _flatten_to_tensors(out) -> list[torch.Tensor]:
    """
    Collect every ``torch.Tensor`` leaf of *out* in pytree traversal order.

    *out* may be a single tensor or any nesting of containers that torch's
    pytree can traverse (dicts, tuples, lists, registered output classes).
    Leaves that are not tensors are discarded, so the result may be empty.
    """
    flat_leaves = pytree.tree_flatten(out)[0]  # the tree spec is not needed
    tensors: list[torch.Tensor] = []
    for leaf in flat_leaves:
        if isinstance(leaf, torch.Tensor):
            tensors.append(leaf)
    return tensors
| 33 | + |
| 34 | + |
| 35 | +# --------------------------------------------------------------------------- # |
| 36 | +# Per-tensor metrics |
| 37 | +# --------------------------------------------------------------------------- # |
| 38 | + |
def _cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> float:
    """
    Return the cosine similarity of two tensors, each flattened to 1-D.

    Special cases:
      * either tensor is empty          -> NaN (undefined)
      * both tensors are all-zero       -> 1.0 (treated as identical)
      * exactly one tensor is all-zero  -> 0.0
      * any NaN in the inputs           -> NaN (propagated, never clamped)

    Otherwise the result is clamped to [-1.0, 1.0].
    """
    # Work in float64 on the CPU:
    #  - the product of two small FP32 norms can underflow to 0 and be
    #    mistaken for the "zero vector" case;
    #  - a float64 dot product accumulates less rounding error over large
    #    tensors;
    #  - moving to the CPU lets the two inputs live on different devices.
    vec_a = a.detach().to(device="cpu", dtype=torch.float64).flatten()
    vec_b = b.detach().to(device="cpu", dtype=torch.float64).flatten()
    if vec_a.numel() == 0 or vec_b.numel() == 0:
        return float("nan")

    # Compute each norm exactly once and keep them as Python floats.
    norm_a = vec_a.norm().item()
    norm_b = vec_b.norm().item()
    if norm_a == 0.0 or norm_b == 0.0:
        return 1.0 if (norm_a == 0.0 and norm_b == 0.0) else 0.0

    cos = (vec_a @ vec_b).item() / (norm_a * norm_b)

    # Rounding can push the ratio marginally outside [-1, 1]. Explicit
    # comparisons (rather than min/max) leave NaN untouched, because every
    # comparison against NaN is False.
    if cos > 1.0:
        return 1.0
    if cos < -1.0:
        return -1.0
    return cos
| 49 | + |
| 50 | + |
def per_tensor_metrics(
    pt: torch.Tensor,
    trt: torch.Tensor,
    *,
    atol: float = 1e-2,
    rtol: float = 1e-2,
) -> dict:
    """
    Compute comparison metrics for one (reference, candidate) tensor pair.

    Args:
        pt: Reference tensor.
        trt: Candidate tensor compared against ``pt``.
        atol: Absolute tolerance forwarded to ``torch.allclose``.
        rtol: Relative tolerance forwarded to ``torch.allclose``.

    Returns:
        A dict with the keys ``shape_pt``, ``shape_trt``, ``shape_match``,
        ``dtype_pt``, ``dtype_trt``, ``cos_sim``, ``max_abs``, ``mean_abs``
        and ``allclose``. When the shapes differ, the three numeric metrics
        are NaN and ``allclose`` is False. For matching but empty tensors,
        ``max_abs`` and ``mean_abs`` are 0.0.
    """
    same_shape = pt.shape == trt.shape
    # Shape and dtype are reported on every row so that the mismatch row and
    # the normal row share one schema.
    row = {
        "shape_pt": tuple(pt.shape),
        "shape_trt": tuple(trt.shape),
        "shape_match": same_shape,
        "dtype_pt": str(pt.dtype).replace("torch.", ""),
        "dtype_trt": str(trt.dtype).replace("torch.", ""),
    }

    if not same_shape:
        # Element-wise metrics are undefined for differently shaped tensors.
        undefined = float("nan")
        row.update(
            cos_sim=undefined,
            max_abs=undefined,
            mean_abs=undefined,
            allclose=False,
        )
        return row

    # Cast both to FP32 on the CPU so they share a dtype and a device; this
    # also allows the reference and the candidate to come from different
    # devices (e.g. a CPU reference against a CUDA output).
    ref = pt.detach().to(device="cpu", dtype=torch.float32)
    cand = trt.detach().to(device="cpu", dtype=torch.float32)

    diff = (ref - cand).abs()
    has_elements = diff.numel() > 0  # max()/mean() are not usable on empty tensors
    row.update(
        cos_sim=_cosine_similarity(ref, cand),
        max_abs=diff.max().item() if has_elements else 0.0,
        mean_abs=diff.mean().item() if has_elements else 0.0,
        # torch.allclose(input, other) scales rtol by |other|, i.e. by the
        # candidate tensor here; the argument order is kept as it was.
        allclose=torch.allclose(ref, cand, rtol=rtol, atol=atol),
    )
    return row
| 86 | + |
| 87 | + |
| 88 | +# --------------------------------------------------------------------------- # |
| 89 | +# Compare two outputs (each a tensor / dict / dataclass / tuple) |
| 90 | +# --------------------------------------------------------------------------- # |
| 91 | + |
def compare_outputs(
    pt_out,
    trt_out,
    *,
    atol: float = 1e-2,
    rtol: float = 1e-2,
    output_names: Iterable[str] | None = None,
) -> list[dict]:
    """
    Compare two model outputs tensor by tensor.

    Both outputs are flattened to their tensor leaves and paired by position.
    Each pair yields one row: the metrics from ``per_tensor_metrics`` plus a
    ``name`` key. Names come from ``output_names``; any position it does not
    cover is labelled ``out[i]``.

    If the two outputs flatten to a different number of tensors, a single
    failing row named ``<output-count-mismatch>`` is returned instead.
    """
    ref_tensors = _flatten_to_tensors(pt_out)
    cand_tensors = _flatten_to_tensors(trt_out)
    n_ref = len(ref_tensors)
    n_cand = len(cand_tensors)

    if n_ref != n_cand:
        # Positional pairing is meaningless when the counts differ.
        mismatch_row = {
            "name": "<output-count-mismatch>",
            "shape_pt": f"{n_ref} tensors",
            "shape_trt": f"{n_cand} tensors",
            "shape_match": False,
            "cos_sim": float("nan"),
            "max_abs": float("nan"),
            "mean_abs": float("nan"),
            "allclose": False,
        }
        return [mismatch_row]

    # Start from the caller's names (if any) and pad with positional labels.
    labels = list(output_names) if output_names else []
    labels.extend(f"out[{i}]" for i in range(len(labels), n_ref))

    return [
        {"name": label, **per_tensor_metrics(ref, cand, atol=atol, rtol=rtol)}
        for label, ref, cand in zip(labels, ref_tensors, cand_tensors)
    ]
| 125 | + |
| 126 | + |
| 127 | +# --------------------------------------------------------------------------- # |
| 128 | +# Reporting |
| 129 | +# --------------------------------------------------------------------------- # |
| 130 | + |
def overall_pass(
    rows: list[dict],
    *,
    cos_sim_min: float = 0.99,
) -> bool:
    """
    Decide whether a whole comparison run passes.

    The run passes only when ``rows`` is non-empty and every row has a truthy
    ``shape_match`` and a non-NaN ``cos_sim`` at or above ``cos_sim_min``.
    A row with no ``cos_sim`` key is scored as 0.0.

    Cosine similarity is the gating metric; ``allclose`` is reported but is
    not consulted here. In FP16 a few isolated elements can exceed atol
    while the tensors are otherwise equivalent, and cos_sim stays near 1.0
    in that situation.
    """
    if not rows:
        return False

    def _row_passes(row: dict) -> bool:
        if not row.get("shape_match", False):
            return False
        score = row.get("cos_sim", 0.0)
        is_nan = score != score  # NaN is the only value unequal to itself
        return not (is_nan or score < cos_sim_min)

    return all(_row_passes(row) for row in rows)
| 157 | + |
| 158 | + |
def print_accuracy_table(
    rows: list[dict],
    *,
    title: str = "",
    cos_sim_min: float = 0.99,
) -> None:
    """
    Print the comparison rows as an aligned text table, then the verdict.

    With no rows, a single notice is printed and nothing else. Otherwise an
    optional banner with ``title`` comes first, followed by a header line, a
    dashed rule, one line per row, and finally the overall PASS/FAIL computed
    by ``overall_pass`` with ``cos_sim_min``.
    """
    if not rows:
        print("[accuracy] No outputs to compare.")
        return

    if title:
        banner = "=" * 70
        print(f"\n{banner}")
        print(f" {title}")
        print(f"{banner}")

    cols = ("name", "shape_pt", "cos_sim", "max_abs", "mean_abs", "allclose")

    # Format every cell once; missing keys render as an empty string.
    table = [[_fmt(row.get(col, ""), col) for col in cols] for row in rows]

    # Each column is as wide as its widest cell or its header, whichever is longer.
    widths = [
        max(len(col), max(len(cells[idx]) for cells in table))
        for idx, col in enumerate(cols)
    ]

    header = " ".join(col.ljust(width) for col, width in zip(cols, widths))
    print(header)
    print("-" * len(header))
    for cells in table:
        print(" ".join(cell.ljust(width) for cell, width in zip(cells, widths)))

    verdict = "PASS" if overall_pass(rows, cos_sim_min=cos_sim_min) else "FAIL"
    print(f"\nOverall: {verdict} (cos_sim_min={cos_sim_min})")
| 184 | + |
| 185 | + |
def _fmt(v, col_name: str) -> str:
    """
    Render one table cell as text.

    Booleans become ``yes``/``no``. Floats use six fixed decimals in the
    ``cos_sim`` column and three-digit scientific notation in every other
    column. Any other value is passed through ``str()``.
    """
    if isinstance(v, bool):
        return "yes" if v else "no"
    if not isinstance(v, float):
        return str(v)
    spec = ".6f" if col_name in ("cos_sim",) else ".3e"
    return format(v, spec)
0 commit comments