Skip to content

Commit e322cae

Browse files
committed
fix(dynamo): honor Input(max_shape) for Dim.DYNAMIC dims, crash-fix sympy.oo in extract_var_range_info, validate Input against finite export bounds
1 parent eec2c04 commit e322cae

3 files changed

Lines changed: 371 additions & 21 deletions

File tree

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 95 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import torch
1212
from torch.export import ExportedProgram
1313
from torch.fx.node import Target
14+
from torch.utils._sympy.numbers import int_oo
1415
from torch_tensorrt._Device import Device
1516
from torch_tensorrt._enums import EngineCapability, dtype
1617
from torch_tensorrt._features import needs_cross_compile
@@ -881,20 +882,21 @@ def _build_user_symbol_bounds(
881882
sample_arg_inputs: Sequence[Input],
882883
sample_kwarg_inputs: dict[Any, Any],
883884
) -> Dict[sympy.Symbol, Tuple[int, int]]:
884-
"""Collect ``{sympy.Symbol: (min, max)}`` from user-supplied ``Input``s.
885-
886-
This is a *read-only* bridge between ``torch_tensorrt.Input`` (where the
887-
user declares ``min_shape``/``max_shape``) and the partitioner's shape
888-
reader (``extract_var_range_info``). Each sympy symbol that appears in a
889-
top-level placeholder's ``meta["val"].shape`` is recorded once with the
890-
corresponding ``(min_shape[d], max_shape[d])`` from the user-provided
891-
dynamic ``Input``.
892-
893-
The map is consulted only to *fill* missing upper bounds (``int_oo`` /
894-
unbounded) left by ``Dim.DYNAMIC``; the parent ``ShapeEnv`` is never
895-
mutated. As a result, downstream consumers such as
896-
:func:`torch_tensorrt.save(..., output_format="exported_program")` see
897-
the original ``range_constraints`` from the exporter.
885+
"""Build a read-only ``{sympy.Symbol: (min, max)}`` map from dynamic ``Input``s.
886+
887+
Symbols are taken from top-level placeholder ``meta["val"].shape``; the
888+
map is threaded to ``extract_var_range_info`` to fill upper bounds left
889+
unbounded by ``Dim.DYNAMIC`` (``int_oo``). ``ShapeEnv`` is never mutated.
890+
891+
Validation against the exporter's finite bounds (when present):
892+
893+
- **Outside** (``user_min < exp_min`` or ``user_max > exp_max``): raises
894+
``ValueError`` -- shapes the user listed in ``Input`` would be rejected
895+
by TRT at runtime since the engine profile follows the exporter.
896+
- **Subset** (user fully inside exporter's range): emits a warning that
897+
the engine profile will be widened to the exporter's bounds.
898+
- **``Dim.DYNAMIC``** (unbounded exporter upper): no check; user's
899+
``Input`` fills the gap.
898900
"""
899901
placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]
900902

@@ -934,12 +936,88 @@ def _build_user_symbol_bounds(
934936
continue
935937
if expr in user_symbol_bounds:
936938
continue
937-
user_symbol_bounds[expr] = (int(min_shape[d]), int(max_shape[d]))
939+
user_min = int(min_shape[d])
940+
user_max = int(max_shape[d])
941+
user_symbol_bounds[expr] = (user_min, user_max)
938942
logger.debug(
939943
"Recorded user-supplied bounds for %s: [%d, %d]",
940944
expr,
941-
int(min_shape[d]),
942-
int(max_shape[d]),
945+
user_min,
946+
user_max,
947+
)
948+
949+
# If the exporter has already declared *finite* bounds for this
950+
# symbol (e.g. ``Dim("batch", min=10, max=20)``) and the user's
951+
# ``Input`` disagrees, ``extract_var_range_info`` will silently
952+
# keep the exporter's bounds (the override path is gated on
953+
# ``max_val is None``). Reject incompatible bounds at compile
954+
# time so the user doesn't hit a confusing TRT runtime error
955+
# ("shape outside profile") on shapes they explicitly declared.
956+
shape_env = getattr(dim.node, "shape_env", None)
957+
if shape_env is None:
958+
continue
959+
exp_range = shape_env.var_to_range.get(expr)
960+
if exp_range is None:
961+
continue
962+
exp_lower = exp_range.lower
963+
exp_upper = exp_range.upper
964+
exp_max_unbounded = exp_upper is int_oo or exp_upper == sympy.oo
965+
if exp_max_unbounded:
966+
# Pure ``Dim.DYNAMIC`` case -- user is filling the gap, which
967+
# is the intended use. No warning, no error.
968+
continue
969+
try:
970+
exp_min = int(exp_lower)
971+
exp_max = int(exp_upper)
972+
except (TypeError, ValueError):
973+
continue
974+
if user_min == exp_min and user_max == exp_max:
975+
continue
976+
977+
# User extends outside the exporter's range -> shapes the user
978+
# declared (e.g. ``Input.min_shape[d] = user_min`` when
979+
# ``user_min < exp_min``) are guaranteed to fail at runtime
980+
# because the TRT engine profile follows the exporter. Hard
981+
# error so the user finds out at compile time.
982+
if user_min < exp_min or user_max > exp_max:
983+
raise ValueError(
984+
f"torch_tensorrt.Input bounds for symbol {expr} "
985+
f"(min={user_min}, max={user_max}) extend outside the "
986+
f"exporter's declared range (min={exp_min}, max={exp_max}). "
987+
f"The TRT engine profile follows the exporter's bounds, "
988+
f"so runtime tensors with shapes outside [{exp_min}, "
989+
f"{exp_max}] would be rejected by TRT with a 'shape "
990+
f"outside profile' error -- including shapes you listed "
991+
f"in Input.min_shape / Input.max_shape. To resolve, "
992+
f"either (a) re-export with "
993+
f"torch.export.Dim('{expr}', min={user_min}, "
994+
f"max={user_max}) so the exporter agrees with the "
995+
f"desired profile, or (b) pass an Input whose "
996+
f"min_shape/max_shape stay within [{exp_min}, "
997+
f"{exp_max}]."
998+
)
999+
1000+
# User is a strict subset of the exporter's range. No shape
1001+
# the user declared will fail at runtime, but the engine
1002+
# profile is silently widened to [exp_min, exp_max] -- so the
1003+
# user's narrower intent is dropped. Warn but do not error.
1004+
logger.warning(
1005+
"torch_tensorrt.Input bounds for symbol %s "
1006+
"(min=%d, max=%d) are narrower than the exporter's "
1007+
"declared range (min=%d, max=%d). The TRT engine profile "
1008+
"will use the exporter's wider [%d, %d] -- the Input "
1009+
"bounds are dropped. To get the narrower profile, "
1010+
"re-export with torch.export.Dim('%s', min=%d, max=%d).",
1011+
expr,
1012+
user_min,
1013+
user_max,
1014+
exp_min,
1015+
exp_max,
1016+
exp_min,
1017+
exp_max,
1018+
expr,
1019+
user_min,
1020+
user_max,
9431021
)
9441022

9451023
return user_symbol_bounds

py/torch_tensorrt/dynamo/utils.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -481,10 +481,32 @@ def extract_var_range_info(
481481
or expr.xreplace(var_to_val_map)
482482
)
483483
assert var_range, var_val
484-
min_val, max_val = (
485-
int(var_range.lower),
486-
int(var_range.upper) if var_range.upper != int_oo else None,
487-
)
484+
485+
# ``var_range`` can come from two paths:
486+
# (1) ``shape_env.var_to_range[expr]`` -- stores PyTorch's integer-typed
487+
# ``int_oo`` sentinel for unbounded sides.
488+
# (2) ``shape_env.bound_sympy(expr)`` (composite exprs like ``s0 + s1``)
489+
# -- runs sympy arithmetic, which collapses unbounded operands to
490+
# ``sympy.oo`` (sympy's float-typed ``S.Infinity``).
491+
# ``int_oo`` and ``sympy.oo`` are different objects, so a single-sentinel
492+
# check (``!= int_oo``) misses path (2) and crashes downstream when sympy
493+
# tries to coerce ``oo`` to int. Treat any non-finite bound as ``None``.
494+
def _bound_to_int_or_none(value: Any) -> Optional[int]:
495+
if value is int_oo or value is -int_oo:
496+
return None
497+
if value == sympy.oo or value == -sympy.oo:
498+
return None
499+
try:
500+
return int(value)
501+
except (TypeError, OverflowError, AttributeError):
502+
return None
503+
504+
min_val_opt = _bound_to_int_or_none(var_range.lower)
505+
max_val = _bound_to_int_or_none(var_range.upper)
506+
# An unbounded lower should be impossible for tensor dims (>=0), but if it
507+
# ever does happen we fall back to 1 (post 0/1-specialization default)
508+
# rather than crashing.
509+
min_val = min_val_opt if min_val_opt is not None else 1
488510

489511
# Torchdynamo 0/1 specialization outlier
490512
min_val = 1 if min_val == 2 else min_val

0 commit comments

Comments (0)