|
11 | 11 | import torch |
12 | 12 | from torch.export import ExportedProgram |
13 | 13 | from torch.fx.node import Target |
| 14 | +from torch.utils._sympy.numbers import int_oo |
14 | 15 | from torch_tensorrt._Device import Device |
15 | 16 | from torch_tensorrt._enums import EngineCapability, dtype |
16 | 17 | from torch_tensorrt._features import needs_cross_compile |
@@ -881,20 +882,21 @@ def _build_user_symbol_bounds( |
881 | 882 | sample_arg_inputs: Sequence[Input], |
882 | 883 | sample_kwarg_inputs: dict[Any, Any], |
883 | 884 | ) -> Dict[sympy.Symbol, Tuple[int, int]]: |
884 | | - """Collect ``{sympy.Symbol: (min, max)}`` from user-supplied ``Input``s. |
885 | | -
|
886 | | - This is a *read-only* bridge between ``torch_tensorrt.Input`` (where the |
887 | | - user declares ``min_shape``/``max_shape``) and the partitioner's shape |
888 | | - reader (``extract_var_range_info``). Each sympy symbol that appears in a |
889 | | - top-level placeholder's ``meta["val"].shape`` is recorded once with the |
890 | | - corresponding ``(min_shape[d], max_shape[d])`` from the user-provided |
891 | | - dynamic ``Input``. |
892 | | -
|
893 | | - The map is consulted only to *fill* missing upper bounds (``int_oo`` / |
894 | | - unbounded) left by ``Dim.DYNAMIC``; the parent ``ShapeEnv`` is never |
895 | | - mutated. As a result, downstream consumers such as |
896 | | - :func:`torch_tensorrt.save(..., output_format="exported_program")` see |
897 | | - the original ``range_constraints`` from the exporter. |
| 885 | + """Build a read-only ``{sympy.Symbol: (min, max)}`` map from dynamic ``Input``s. |
| 886 | +
|
| 887 | + Symbols are taken from top-level placeholder ``meta["val"].shape``; the |
| 888 | + map is threaded to ``extract_var_range_info`` to fill upper bounds left |
| 889 | + unbounded by ``Dim.DYNAMIC`` (``int_oo``). ``ShapeEnv`` is never mutated. |
| 890 | +
|
| 891 | + Validation against the exporter's finite bounds (when present): |
| 892 | +
|
| 893 | + - **Outside** (``user_min < exp_min`` or ``user_max > exp_max``): raises |
| 894 | + ``ValueError`` -- shapes the user listed in ``Input`` would be rejected |
| 895 | + by TRT at runtime since the engine profile follows the exporter. |
| 896 | + - **Subset** (user fully inside exporter's range): emits a warning that |
| 897 | + the engine profile will be widened to the exporter's bounds. |
| 898 | + - **``Dim.DYNAMIC``** (unbounded exporter upper): no check; user's |
| 899 | + ``Input`` fills the gap. |
898 | 900 | """ |
899 | 901 | placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"] |
900 | 902 |
|
@@ -934,12 +936,88 @@ def _build_user_symbol_bounds( |
934 | 936 | continue |
935 | 937 | if expr in user_symbol_bounds: |
936 | 938 | continue |
937 | | - user_symbol_bounds[expr] = (int(min_shape[d]), int(max_shape[d])) |
| 939 | + user_min = int(min_shape[d]) |
| 940 | + user_max = int(max_shape[d]) |
| 941 | + user_symbol_bounds[expr] = (user_min, user_max) |
938 | 942 | logger.debug( |
939 | 943 | "Recorded user-supplied bounds for %s: [%d, %d]", |
940 | 944 | expr, |
941 | | - int(min_shape[d]), |
942 | | - int(max_shape[d]), |
| 945 | + user_min, |
| 946 | + user_max, |
| 947 | + ) |
| 948 | + |
| 949 | + # If the exporter has already declared *finite* bounds for this |
| 950 | + # symbol (e.g. ``Dim("batch", min=10, max=20)``) and the user's |
| 951 | + # ``Input`` disagrees, ``extract_var_range_info`` will silently |
| 952 | + # keep the exporter's bounds (the override path is gated on |
| 953 | + # ``max_val is None``). Reject incompatible bounds at compile |
| 954 | + # time so the user doesn't hit a confusing TRT runtime error |
| 955 | + # ("shape outside profile") on shapes they explicitly declared. |
| 956 | + shape_env = getattr(dim.node, "shape_env", None) |
| 957 | + if shape_env is None: |
| 958 | + continue |
| 959 | + exp_range = shape_env.var_to_range.get(expr) |
| 960 | + if exp_range is None: |
| 961 | + continue |
| 962 | + exp_lower = exp_range.lower |
| 963 | + exp_upper = exp_range.upper |
| 964 | + exp_max_unbounded = exp_upper is int_oo or exp_upper == sympy.oo |
| 965 | + if exp_max_unbounded: |
| 966 | + # Pure ``Dim.DYNAMIC`` case -- user is filling the gap, which |
| 967 | + # is the intended use. No warning, no error. |
| 968 | + continue |
| 969 | + try: |
| 970 | + exp_min = int(exp_lower) |
| 971 | + exp_max = int(exp_upper) |
| 972 | + except (TypeError, ValueError): |
| 973 | + continue |
| 974 | + if user_min == exp_min and user_max == exp_max: |
| 975 | + continue |
| 976 | + |
| 977 | + # User extends outside the exporter's range -> shapes the user |
| 978 | + # declared (e.g. ``Input.min_shape[d] = user_min`` when |
| 979 | + # ``user_min < exp_min``) are guaranteed to fail at runtime |
| 980 | + # because the TRT engine profile follows the exporter. Hard |
| 981 | + # error so the user finds out at compile time. |
| 982 | + if user_min < exp_min or user_max > exp_max: |
| 983 | + raise ValueError( |
| 984 | + f"torch_tensorrt.Input bounds for symbol {expr} " |
| 985 | + f"(min={user_min}, max={user_max}) extend outside the " |
| 986 | + f"exporter's declared range (min={exp_min}, max={exp_max}). " |
| 987 | + f"The TRT engine profile follows the exporter's bounds, " |
| 988 | + f"so runtime tensors with shapes outside [{exp_min}, " |
| 989 | + f"{exp_max}] would be rejected by TRT with a 'shape " |
| 990 | + f"outside profile' error -- including shapes you listed " |
| 991 | + f"in Input.min_shape / Input.max_shape. To resolve, " |
| 992 | + f"either (a) re-export with " |
| 993 | + f"torch.export.Dim('{expr}', min={user_min}, " |
| 994 | + f"max={user_max}) so the exporter agrees with the " |
| 995 | + f"desired profile, or (b) pass an Input whose " |
| 996 | + f"min_shape/max_shape stay within [{exp_min}, " |
| 997 | + f"{exp_max}]." |
| 998 | + ) |
| 999 | + |
| 1000 | + # User is a strict subset of the exporter's range. No shape |
| 1001 | + # the user declared will fail at runtime, but the engine |
| 1002 | + # profile is silently widened to [exp_min, exp_max] -- so the |
| 1003 | + # user's narrower intent is dropped. Warn but do not error. |
| 1004 | + logger.warning( |
| 1005 | + "torch_tensorrt.Input bounds for symbol %s " |
| 1006 | + "(min=%d, max=%d) are narrower than the exporter's " |
| 1007 | + "declared range (min=%d, max=%d). The TRT engine profile " |
| 1008 | + "will use the exporter's wider [%d, %d] -- the Input " |
| 1009 | + "bounds are dropped. To get the narrower profile, " |
| 1010 | + "re-export with torch.export.Dim('%s', min=%d, max=%d).", |
| 1011 | + expr, |
| 1012 | + user_min, |
| 1013 | + user_max, |
| 1014 | + exp_min, |
| 1015 | + exp_max, |
| 1016 | + exp_min, |
| 1017 | + exp_max, |
| 1018 | + expr, |
| 1019 | + user_min, |
| 1020 | + user_max, |
943 | 1021 | ) |
944 | 1022 |
|
945 | 1023 | return user_symbol_bounds |
|
0 commit comments