Skip to content

Commit 55e5086

Browse files
committed
feat: a generic run hf tool that works for a number of model classes
1 parent f59c9b1 commit 55e5086

21 files changed

Lines changed: 3786 additions & 3 deletions

tools/hf/__init__.py

Whitespace-only changes.

tools/hf/common/__init__.py

Whitespace-only changes.

tools/hf/common/accuracy.py

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
"""
2+
Numerical accuracy comparison helpers.
3+
4+
Compares the outputs of a PyTorch reference model and a TRT-compiled
5+
module on the same inputs, reporting cosine similarity, max/mean
6+
absolute error, and an allclose() pass/fail per output tensor.
7+
8+
Default tolerances are tuned for FP16 (atol=1e-2, rtol=1e-2,
9+
cos_sim_min=0.99). Override via --accuracy-atol / --accuracy-rtol /
10+
--accuracy-cos-sim-min when comparing tighter precisions or models
11+
known to have larger accumulated error.
12+
"""
13+
from __future__ import annotations
14+
15+
from typing import Iterable
16+
17+
import torch
18+
import torch.utils._pytree as pytree
19+
20+
21+
# --------------------------------------------------------------------------- #
22+
# Output flattening
23+
# --------------------------------------------------------------------------- #
24+
25+
def _flatten_to_tensors(out) -> list[torch.Tensor]:
    """
    Collect every tensor leaf of a model output into a flat list.

    ``out`` (a single tensor, or any container torch's pytree can
    traverse) is passed to ``pytree.tree_flatten``. Leaves that are not
    ``torch.Tensor`` instances are discarded; the remaining leaves keep
    the order in which ``tree_flatten`` produced them.
    """
    flat_leaves, _spec = pytree.tree_flatten(out)
    return list(filter(lambda leaf: isinstance(leaf, torch.Tensor), flat_leaves))
33+
34+
35+
# --------------------------------------------------------------------------- #
36+
# Per-tensor metrics
37+
# --------------------------------------------------------------------------- #
38+
39+
def _cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> float:
    """
    Return the cosine similarity of two tensors, each flattened to 1-D.

    Special cases:
      * either tensor is empty            -> NaN
      * both tensors have zero norm       -> 1.0 (treated as identical)
      * exactly one tensor has zero norm  -> 0.0 (undefined, reported as 0)

    The arithmetic runs in float64. In float32 the product of the two
    norms (and the dot product itself) can underflow to zero for very
    small magnitudes or overflow for very large ones, which would report
    0.0 or NaN for tensors that are in fact identical.
    """
    a = a.detach().to(torch.float64).flatten()
    b = b.detach().to(torch.float64).flatten()
    if a.numel() == 0 or b.numel() == 0:
        return float("nan")

    # Each norm is computed once and multiplied as Python floats.
    norm_a = a.norm().item()
    norm_b = b.norm().item()
    denom = norm_a * norm_b
    if denom == 0.0:
        # Both zero: define cosine as 1.0; one zero / one not: undefined -> 0.
        return 1.0 if (norm_a == 0.0 and norm_b == 0.0) else 0.0
    return (a @ b).item() / denom
49+
50+
51+
def per_tensor_metrics(
    pt: torch.Tensor,
    trt: torch.Tensor,
    *,
    atol: float = 1e-2,
    rtol: float = 1e-2,
) -> dict:
    """
    Compute comparison metrics for one reference / candidate tensor pair.

    Args:
        pt: Reference tensor.
        trt: Candidate tensor to compare against ``pt``.
        atol: Absolute tolerance passed to ``torch.allclose``.
        rtol: Relative tolerance passed to ``torch.allclose``.

    Returns:
        A dict with keys ``shape_pt``, ``shape_trt``, ``shape_match``,
        ``dtype_pt``, ``dtype_trt``, ``cos_sim``, ``max_abs``,
        ``mean_abs`` and ``allclose``. When the shapes differ, the three
        numeric metrics are NaN and ``allclose`` is False.
    """
    shape_pt = tuple(pt.shape)
    shape_trt = tuple(trt.shape)
    dtype_pt = str(pt.dtype).replace("torch.", "")
    dtype_trt = str(trt.dtype).replace("torch.", "")

    if pt.shape != trt.shape:
        # Same key set as the success row so consumers can treat every
        # row uniformly.
        return {
            "shape_pt": shape_pt,
            "shape_trt": shape_trt,
            "shape_match": False,
            "dtype_pt": dtype_pt,
            "dtype_trt": dtype_trt,
            "cos_sim": float("nan"),
            "max_abs": float("nan"),
            "mean_abs": float("nan"),
            "allclose": False,
        }

    # Cast both to FP32 on the CPU for a fair comparison: this avoids
    # kernel rounding nondeterminism between repeated GPU runs of the same
    # op, and lets tensors that live on different devices be compared.
    a = pt.detach().to(device="cpu", dtype=torch.float32)
    b = trt.detach().to(device="cpu", dtype=torch.float32)

    diff = (a - b).abs()
    has_values = diff.numel() > 0
    return {
        "shape_pt": shape_pt,
        "shape_trt": shape_trt,
        "shape_match": True,
        "dtype_pt": dtype_pt,
        "dtype_trt": dtype_trt,
        "cos_sim": _cosine_similarity(a, b),
        # max()/mean() are not meaningful on an empty tensor; report 0.0.
        "max_abs": diff.max().item() if has_values else 0.0,
        "mean_abs": diff.mean().item() if has_values else 0.0,
        "allclose": torch.allclose(a, b, rtol=rtol, atol=atol),
    }
86+
87+
88+
# --------------------------------------------------------------------------- #
89+
# Compare two outputs (each a tensor / dict / dataclass / tuple)
90+
# --------------------------------------------------------------------------- #
91+
92+
def compare_outputs(
    pt_out,
    trt_out,
    *,
    atol: float = 1e-2,
    rtol: float = 1e-2,
    output_names: Iterable[str] | None = None,
) -> list[dict]:
    """
    Compare two model outputs tensor-by-tensor.

    Both outputs are flattened with ``_flatten_to_tensors`` and paired up
    positionally. Each pair yields one row: ``{"name": ...}`` merged with
    the result of ``per_tensor_metrics``.

    If the two outputs flatten to a different number of tensors, a single
    ``<output-count-mismatch>`` row is returned instead.

    Rows are named from ``output_names`` when given; any missing names
    are filled in as ``out[i]``.
    """
    ref_tensors = _flatten_to_tensors(pt_out)
    cand_tensors = _flatten_to_tensors(trt_out)
    n_ref = len(ref_tensors)

    if n_ref != len(cand_tensors):
        mismatch_row = {
            "name": "<output-count-mismatch>",
            "shape_pt": f"{n_ref} tensors",
            "shape_trt": f"{len(cand_tensors)} tensors",
            "shape_match": False,
            "cos_sim": float("nan"),
            "max_abs": float("nan"),
            "mean_abs": float("nan"),
            "allclose": False,
        }
        return [mismatch_row]

    if output_names:
        labels = list(output_names)
    else:
        labels = [f"out[{i}]" for i in range(n_ref)]
    # Too few caller-supplied names: pad with positional placeholders.
    if len(labels) < n_ref:
        labels.extend(f"out[{i}]" for i in range(len(labels), n_ref))

    return [
        {"name": label, **per_tensor_metrics(ref, cand, atol=atol, rtol=rtol)}
        for label, ref, cand in zip(labels, ref_tensors, cand_tensors)
    ]
125+
126+
127+
# --------------------------------------------------------------------------- #
128+
# Reporting
129+
# --------------------------------------------------------------------------- #
130+
131+
def overall_pass(
    rows: list[dict],
    *,
    cos_sim_min: float = 0.99,
) -> bool:
    """
    A run passes if every output tensor has matching shape and a cosine
    similarity of at least ``cos_sim_min``.

    An empty ``rows`` list fails, as does any row whose ``cos_sim`` is
    missing or NaN, whatever the threshold.

    Cos-sim is the canonical numerical-equivalence metric; allclose is
    reported but does NOT gate the verdict. In FP16, isolated elements
    can drift past atol (e.g. one logit out of 50k vocab differs by 8.0)
    even when the two tensors are otherwise identical — cos_sim stays
    at 1.0 in those cases and that's the right answer.
    """
    if not rows:
        return False
    for row in rows:
        if not row.get("shape_match", False):
            return False
        cos_sim = row.get("cos_sim")
        if cos_sim is None:
            # A row without the metric must not pass by default, even
            # when the threshold is zero or negative.
            return False
        # NaN compares False against everything, so this one test rejects
        # both NaN and below-threshold values.
        if not (cos_sim >= cos_sim_min):
            return False
    return True
157+
158+
159+
def print_accuracy_table(
    rows: list[dict],
    *,
    title: str = "",
    cos_sim_min: float = 0.99,
) -> None:
    """
    Print the comparison rows as an aligned text table, then the verdict.

    With no rows, a single notice line is printed and nothing else. An
    optional ``title`` is printed as a banner above the table. Each cell
    is rendered with ``_fmt``; the closing PASS/FAIL line comes from
    ``overall_pass`` evaluated with ``cos_sim_min``.
    """
    if not rows:
        print("[accuracy] No outputs to compare.")
        return

    if title:
        rule = "=" * 70
        print(f"\n{rule}")
        print(f" {title}")
        print(rule)

    columns = ("name", "shape_pt", "cos_sim", "max_abs", "mean_abs", "allclose")
    # Render every cell once; the same strings size the columns and fill
    # the body.
    cells = [[_fmt(row.get(col, ""), col) for col in columns] for row in rows]
    widths = [
        max(len(col), max(len(line[idx]) for line in cells))
        for idx, col in enumerate(columns)
    ]

    header = " ".join(col.ljust(width) for col, width in zip(columns, widths))
    print(header)
    print("-" * len(header))
    for line in cells:
        print(" ".join(cell.ljust(width) for cell, width in zip(line, widths)))

    verdict = "PASS" if overall_pass(rows, cos_sim_min=cos_sim_min) else "FAIL"
    print(f"\nOverall: {verdict} (cos_sim_min={cos_sim_min})")
184+
185+
186+
def _fmt(v, col_name: str) -> str:
    """
    Render one table cell as text.

    Booleans become ``yes`` / ``no``. Floats use six fixed decimals in
    the ``cos_sim`` column and three-digit scientific notation elsewhere.
    Any other value is passed through ``str``.
    """
    if isinstance(v, bool):
        return "yes" if v else "no"
    if not isinstance(v, float):
        return str(v)
    spec = ".6f" if col_name == "cos_sim" else ".3e"
    return format(v, spec)

tools/hf/common/bench.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
"""
2+
Shared benchmarking harness for all HF model strategies.
3+
"""
4+
from __future__ import annotations
5+
6+
import timeit
7+
from typing import Callable, Sequence
8+
9+
import numpy as np
10+
import torch
11+
12+
13+
WARMUP_ITERS = 5


def warmup_and_time(
    fn: Callable,
    args: Sequence,
    iterations: int = 10,
    warmup: int = WARMUP_ITERS,
) -> list[float]:
    """
    Run fn(*args) with warmup and return per-iteration wall-clock times (seconds).

    Args:
        fn: Callable to benchmark.
        args: Positional arguments unpacked into every call of ``fn``.
        iterations: Number of timed calls; the result has this many entries.
        warmup: Number of untimed calls made before timing starts.

    Returns:
        One elapsed time in seconds per timed iteration, in call order.

    When CUDA is available the device is synchronized after the warmup
    phase and inside every timed iteration, so each measurement covers
    the GPU work launched by ``fn``. On hosts without CUDA the
    synchronization is skipped, because ``torch.cuda.synchronize()``
    raises there.
    """
    sync = torch.cuda.synchronize if torch.cuda.is_available() else None

    for _ in range(warmup):
        fn(*args)
    if sync is not None:
        sync()

    timings: list[float] = []
    for _ in range(iterations):
        start = timeit.default_timer()
        fn(*args)
        if sync is not None:
            sync()
        timings.append(timeit.default_timer() - start)
    return timings
34+
35+
36+
def compute_stats(timings: list[float], batch_size: int = 1) -> dict:
    """
    Summarize per-iteration timings as latency and throughput statistics.

    Args:
        timings: Per-iteration durations in seconds.
        batch_size: Items processed per iteration; the throughput of an
            iteration is ``batch_size / duration``.

    Returns:
        A dict with ``mean_latency_ms``, ``median_latency_ms``,
        ``p99_latency_ms`` and ``std_latency_ms`` (milliseconds), plus
        ``mean_throughput`` and ``median_throughput`` (items per second,
        averaged over the per-iteration throughputs). Every value is NaN
        when ``timings`` is empty.
    """
    t = np.asarray(timings, dtype=np.float64)
    if t.size == 0:
        # numpy reductions over an empty array warn or raise depending on
        # the version; report "no data" explicitly instead.
        keys = (
            "mean_latency_ms",
            "median_latency_ms",
            "p99_latency_ms",
            "std_latency_ms",
            "mean_throughput",
            "median_throughput",
        )
        return {key: float("nan") for key in keys}

    fps = batch_size / t
    return {
        "mean_latency_ms": float(np.mean(t) * 1000),
        "median_latency_ms": float(np.median(t) * 1000),
        "p99_latency_ms": float(np.percentile(t, 99) * 1000),
        "std_latency_ms": float(np.std(t) * 1000),
        "mean_throughput": float(np.mean(fps)),
        "median_throughput": float(np.median(fps)),
    }

0 commit comments

Comments
 (0)