Skip to content

Commit 4341579

Browse files
committed
user provided bound for torchtrt compile when export dimension is unbounded
1 parent 85063ab commit 4341579

4 files changed

Lines changed: 533 additions & 14 deletions

File tree

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 85 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
import os
66
import platform
77
import warnings
8-
from typing import Any, Collection, List, Optional, Sequence, Union
8+
from typing import Any, Collection, Dict, List, Optional, Sequence, Tuple, Union
99

10+
import sympy
1011
import torch
1112
from torch.export import ExportedProgram
1213
from torch.fx.node import Target
@@ -791,7 +792,7 @@ def _insert_complex_io_adapters(
791792
Outputs: insert view_as_complex before the output node for each originally-complex
792793
output that comes from a TRT block.
793794
794-
Leverages metadata that was captued when the complex rewriter pass was run
795+
Leverages metadata that was captured when the complex rewriter pass was run
795796
"""
796797
complex_input_names = gm.meta.get("complex_input_names", [])
797798
complex_input_dtypes = gm.meta.get("complex_input_dtypes", {})
@@ -875,6 +876,75 @@ def _insert_complex_io_adapters(
875876
partitioned_module.recompile()
876877

877878

879+
def _build_user_symbol_bounds(
    gm: torch.fx.GraphModule,
    sample_arg_inputs: Sequence[Input],
    sample_kwarg_inputs: dict[Any, Any],
) -> Dict[sympy.Symbol, Tuple[int, int]]:
    """Map symbolic dimensions of graph placeholders to user-declared bounds.

    Each placeholder of ``gm`` is paired with a user-supplied dynamic
    ``Input``: positionally against ``sample_arg_inputs`` and by name against
    ``sample_kwarg_inputs`` (a keyword entry replaces a positional entry that
    resolved to the same placeholder target). For every dimension of the
    placeholder's ``meta["val"]`` that is a ``torch.SymInt`` backed by a plain
    ``sympy.Symbol``, the pair ``(min_shape[d], max_shape[d])`` from the
    matching ``Input`` is recorded. The first occurrence of a symbol wins.

    This function only reads from ``gm`` and the sample inputs; it mutates
    neither. The resulting map is what ``compile_module`` forwards to
    ``partitioning.construct_submodule_inputs``.

    Args:
        gm: Graph module whose ``placeholder`` nodes are inspected.
        sample_arg_inputs: Positional sample inputs, matched to placeholders
            in graph order. Entries that are not dynamic ``Input`` objects
            are ignored.
        sample_kwarg_inputs: Keyword sample inputs, matched to placeholders
            by comparing the key with the placeholder's ``target``.

    Returns:
        ``{sympy.Symbol: (min, max)}`` for every plain symbol found. Empty
        when no dynamic ``Input`` was supplied.
    """

    def _is_dynamic_input(candidate: Any) -> bool:
        # Only ``Input`` objects declared with a dynamic shape carry the
        # ``min_shape`` / ``max_shape`` entries read below.
        return (
            isinstance(candidate, Input)
            and candidate.shape_mode == Input._ShapeMode.DYNAMIC
        )

    placeholder_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]

    # Positional inputs line up with placeholders in graph order; ``zip``
    # stops at the shorter of the two sequences.
    dynamic_inputs: dict[str, Input] = {
        node.target: candidate
        for node, candidate in zip(placeholder_nodes, sample_arg_inputs)
        if _is_dynamic_input(candidate)
    }
    # Keyword inputs are keyed by name and applied after the positional ones.
    dynamic_inputs.update(
        (name, candidate)
        for name, candidate in sample_kwarg_inputs.items()
        if _is_dynamic_input(candidate)
    )

    if not dynamic_inputs:
        return {}

    bounds: Dict[sympy.Symbol, Tuple[int, int]] = {}
    for node in placeholder_nodes:
        user_input = dynamic_inputs.get(node.target)
        if user_input is None:
            continue

        fake_tensor = node.meta.get("val")
        if not isinstance(fake_tensor, torch.Tensor):
            continue

        lower_shape = user_input.shape["min_shape"]
        upper_shape = user_input.shape["max_shape"]

        for axis, extent in enumerate(fake_tensor.size()):
            if not isinstance(extent, torch.SymInt):
                continue
            if axis >= len(lower_shape):
                # The user-declared shape has no entry for this dimension.
                continue

            symbol = extent.node.expr
            # Composite expressions (e.g. ``2*s0``) are skipped on purpose:
            # only plain symbols get a direct bound. A symbol that was
            # already recorded keeps its first bound.
            if not isinstance(symbol, sympy.Symbol) or symbol in bounds:
                continue

            lower, upper = int(lower_shape[axis]), int(upper_shape[axis])
            bounds[symbol] = (lower, upper)
            logger.debug(
                "Recorded user-supplied bounds for %s: [%d, %d]",
                symbol,
                lower,
                upper,
            )

    return bounds
946+
947+
878948
@fn_supports_debugger # type: ignore[misc]
879949
def compile_module(
880950
gm: torch.fx.GraphModule,
@@ -906,6 +976,16 @@ def compile_module(
906976
if sample_kwarg_inputs is None:
907977
sample_kwarg_inputs = {}
908978

979+
# Build a read-only ``{sympy.Symbol: (min, max)}`` map from the user's
980+
# sample ``Input`` objects. This is forwarded to the partitioner so that
981+
# symbols whose upper bound is left unbounded by ``Dim.DYNAMIC`` get
982+
# filled in with the user's declared bounds, *without* mutating the
983+
# exporter's ``ShapeEnv.var_to_range`` (which preserves the original
984+
# ``range_constraints`` on save / re-export).
985+
user_symbol_bounds = _build_user_symbol_bounds(
986+
gm, sample_arg_inputs, sample_kwarg_inputs
987+
)
988+
909989
# Configure user compilation settings to converters.
910990
CONVERTERS.set_compilation_settings(settings)
911991

@@ -1087,7 +1167,9 @@ def preserve_module_specs(
10871167
)
10881168

10891169
# Get the submodule inputs for min, opt, max shapes of the graph inputs
1090-
submodule_inputs = partitioning.construct_submodule_inputs(submodule)
1170+
submodule_inputs = partitioning.construct_submodule_inputs(
1171+
submodule, user_symbol_bounds=user_symbol_bounds
1172+
)
10911173

10921174
assert submodule_inputs is not None
10931175

py/torch_tensorrt/dynamo/partitioning/common.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import logging
22
from typing import Any, Dict, Optional, Sequence, Set, Tuple
33

4+
import sympy
45
import torch
56
from torch._subclasses.fake_tensor import FakeTensor
67
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
7-
88
from torch_tensorrt._Input import Input
99
from torch_tensorrt.dynamo.utils import (
1010
COMPLEX_TO_REAL_DTYPE,
@@ -20,11 +20,16 @@ def construct_dynamic_input(
2020
input_dtype: torch.dtype,
2121
name: str = "",
2222
is_shape_tensor: bool = False,
23+
user_symbol_bounds: Optional[Dict[sympy.Symbol, Tuple[int, int]]] = None,
2324
) -> Input:
2425
"""
2526
Constructs a torch_tensorrt.Input based on a symbolic input
2627
Args:
2728
input_shape: A symbolic shape / regular shape of a tensor (which can have a mix of SymInt nodes and static values)
29+
user_symbol_bounds: Optional read-only ``{sym: (min, max)}`` map, used
30+
only to fill missing upper bounds when the exporter's ``ShapeEnv``
31+
reports them as unbounded. See
32+
:func:`torch_tensorrt.dynamo.utils.extract_var_range_info`.
2833
Returns:
2934
A dynamic shaped torch_tensorrt.Input which has the properties of the symbolic shaped input.
3035
"""
@@ -33,7 +38,9 @@ def construct_dynamic_input(
3338
max_shape = []
3439
for d, dim in enumerate(input_shape):
3540
if isinstance(dim, torch.SymInt):
36-
min_max_opt = extract_var_range_info(dim)
41+
min_max_opt = extract_var_range_info(
42+
dim, user_symbol_bounds=user_symbol_bounds
43+
)
3744
unwrapped_min_max_opt: Dict[str, int] = {}
3845
if "min" not in min_max_opt or min_max_opt["min"] is None:
3946
logger.warning(
@@ -85,9 +92,14 @@ def get_input(
8592
dtype: torch.dtype,
8693
name: str = "",
8794
is_shape_tensor: bool = False,
95+
user_symbol_bounds: Optional[Dict[sympy.Symbol, Tuple[int, int]]] = None,
8896
) -> Input:
8997
"""
9098
Based on type of dimensions in the input_shape, construct regular or dynamic shaped inputs
99+
100+
``user_symbol_bounds`` is forwarded to :func:`construct_dynamic_input` and
101+
is only consulted when an exporter-supplied symbolic dimension has no
102+
upper bound (e.g. ``Dim.DYNAMIC`` without an explicit ``max``).
91103
"""
92104
if dtype in COMPLEX_TO_REAL_DTYPE:
93105
real_dtype = COMPLEX_TO_REAL_DTYPE[dtype]
@@ -106,19 +118,27 @@ def get_input(
106118
dtype,
107119
name=name,
108120
is_shape_tensor=is_shape_tensor,
121+
user_symbol_bounds=user_symbol_bounds,
109122
)
110123
else:
111124
return Input(
112125
shape=input_shape, dtype=dtype, name=name, is_shape_tensor=is_shape_tensor
113126
)
114127

115128

116-
def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
129+
def construct_submodule_inputs(
130+
module: torch.fx.GraphModule,
131+
user_symbol_bounds: Optional[Dict[sympy.Symbol, Tuple[int, int]]] = None,
132+
) -> Sequence[Input]:
117133
"""
118134
Construct torch_tensorrt Inputs based on the module inputs.
119135
The module inputs will have meta data which has the shape and dtype info
120136
Args:
121137
module: Input FX GraphModule
138+
user_symbol_bounds: Optional read-only ``{sym: (min, max)}`` map built
139+
from user-supplied ``torch_tensorrt.Input`` objects. Used only to
140+
fill missing upper bounds left by ``Dim.DYNAMIC`` in the exporter;
141+
never mutates the parent ``ShapeEnv``.
122142
Returns:
123143
Sequence of torch_tensorrt.Input's representing inputs to given module
124144
"""
@@ -134,7 +154,12 @@ def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
134154
if isinstance(input_meta, (FakeTensor, torch.Tensor)):
135155
input_shape = input_meta.size()
136156
torchtrt_inputs.append(
137-
get_input(input_shape, input_meta.dtype, name=input.name)
157+
get_input(
158+
input_shape,
159+
input_meta.dtype,
160+
name=input.name,
161+
user_symbol_bounds=user_symbol_bounds,
162+
)
138163
)
139164
elif isinstance(input_meta, torch.SymInt):
140165
# Assuming sym_integers | shape inputs always have torch.int64 dtype
@@ -144,6 +169,7 @@ def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
144169
torch.int64,
145170
name=input.name,
146171
is_shape_tensor=True,
172+
user_symbol_bounds=user_symbol_bounds,
147173
)
148174
)
149175
elif isinstance(input_meta, torch.SymFloat):
@@ -153,6 +179,7 @@ def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
153179
torch.float32,
154180
name=input.name,
155181
is_shape_tensor=False, # Only SymInt inputs are treated as shape tensors
182+
user_symbol_bounds=user_symbol_bounds,
156183
)
157184
)
158185
else:
@@ -164,7 +191,12 @@ def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
164191
input_meta = input.meta["tensor_meta"]
165192
input_shape = input_meta.shape
166193
torchtrt_inputs.append(
167-
get_input(input_shape, input_meta.dtype, name=input.name)
194+
get_input(
195+
input_shape,
196+
input_meta.dtype,
197+
name=input.name,
198+
user_symbol_bounds=user_symbol_bounds,
199+
)
168200
)
169201
else:
170202
raise AssertionError(

py/torch_tensorrt/dynamo/utils.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,9 @@
2424
import sympy
2525
import tensorrt as trt
2626
import torch
27-
from torch._subclasses.fake_tensor import FakeTensor
28-
from torch._subclasses.fake_tensor import FakeScriptObject
27+
from torch._subclasses.fake_tensor import FakeScriptObject, FakeTensor
2928
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
3029
from torch.utils._sympy.numbers import int_oo
31-
32-
from packaging import version
3330
from torch_tensorrt._Device import Device
3431
from torch_tensorrt._enums import dtype
3532
from torch_tensorrt._features import ENABLED_FEATURES
@@ -40,6 +37,8 @@
4037
from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
4138
from torch_tensorrt.dynamo._settings import CompilationSettings
4239

40+
from packaging import version
41+
4342
from .types import TRTDataType
4443

4544
logger = logging.getLogger(__name__)
@@ -105,7 +104,7 @@ class Frameworks(Enum):
105104
torch.complex128: torch.float64,
106105
}
107106

108-
COMPLEX_DTYPES: frozenset = frozenset(COMPLEX_TO_REAL_DTYPE)
107+
COMPLEX_DTYPES: frozenset[torch.dtype] = frozenset(COMPLEX_TO_REAL_DTYPE)
109108

110109

111110
def unified_dtype_converter(
@@ -438,9 +437,24 @@ def contains_sym_int(tensor: torch.Tensor) -> bool:
438437
return any(isinstance(dim, torch.SymInt) for dim in tensor)
439438

440439

441-
def extract_var_range_info(symbolic_integer: torch.SymInt) -> Dict[str, Optional[int]]:
440+
def extract_var_range_info(
441+
symbolic_integer: torch.SymInt,
442+
user_symbol_bounds: Optional[Dict[sympy.Symbol, Tuple[int, int]]] = None,
443+
) -> Dict[str, Optional[int]]:
442444
"""
443445
This function returns the min, max, opt values of a symbolic integer.
446+
447+
Args:
448+
symbolic_integer: The ``torch.SymInt`` whose range is being queried.
449+
user_symbol_bounds: Optional read-only map from a top-level sympy symbol
450+
to ``(min, max)`` bounds supplied by the user via
451+
``torch_tensorrt.Input``. These are used **only** to fill the gap
452+
when the exporter's ``ShapeEnv`` reports an unbounded upper range
453+
(``int_oo``). The exporter's bounds always win when they are
454+
finite; the lower bound is intersected (``max(exporter_lower,
455+
user_lower)``) so we never widen the exporter's 0/1 specialization
456+
(e.g. ``lower == 2 -> 1``) to ``0`` when the user passes
457+
``min_shape=0``. ``ShapeEnv`` itself is never mutated.
444458
"""
445459
node = symbolic_integer.node
446460
expr = node.expr
@@ -474,6 +488,23 @@ def extract_var_range_info(symbolic_integer: torch.SymInt) -> Dict[str, Optional
474488

475489
# Torchdynamo 0/1 specialization outlier
476490
min_val = 1 if min_val == 2 else min_val
491+
492+
# If the exporter left this symbol with an unbounded upper range (i.e.
493+
# the user used ``Dim.DYNAMIC`` without an explicit upper), fall back to
494+
# the bounds the user supplied via ``torch_tensorrt.Input(min_shape=...,
495+
# max_shape=...)``. Only fills the gap; never overrides a finite exporter
496+
# max. The lower bound is intersected so the exporter's specialization
497+
# (e.g. lower == 1) is preserved.
498+
if (
499+
max_val is None
500+
and user_symbol_bounds
501+
and isinstance(expr, sympy.Symbol)
502+
and expr in user_symbol_bounds
503+
):
504+
user_min, user_max = user_symbol_bounds[expr]
505+
min_val = max(min_val, int(user_min))
506+
max_val = int(user_max)
507+
477508
min_max_opt: Dict[str, Optional[int]] = {}
478509
min_max_opt["min"] = min_val
479510
min_max_opt["max"] = max_val

0 commit comments

Comments
 (0)