|
5 | 5 | import os |
6 | 6 | import platform |
7 | 7 | import warnings |
8 | | -from typing import Any, Collection, List, Optional, Sequence, Union |
| 8 | +from typing import Any, Collection, Dict, List, Optional, Sequence, Tuple, Union |
9 | 9 |
|
| 10 | +import sympy |
10 | 11 | import torch |
11 | 12 | from torch.export import ExportedProgram |
12 | 13 | from torch.fx.node import Target |
@@ -791,7 +792,7 @@ def _insert_complex_io_adapters( |
791 | 792 | Outputs: insert view_as_complex before the output node for each originally-complex |
792 | 793 | output that comes from a TRT block. |
793 | 794 |
|
794 | | - Leverages metadata that was captued when the complex rewriter pass was run |
| 795 | + Leverages metadata that was captured when the complex rewriter pass was run |
795 | 796 | """ |
796 | 797 | complex_input_names = gm.meta.get("complex_input_names", []) |
797 | 798 | complex_input_dtypes = gm.meta.get("complex_input_dtypes", {}) |
@@ -875,6 +876,75 @@ def _insert_complex_io_adapters( |
875 | 876 | partitioned_module.recompile() |
876 | 877 |
|
877 | 878 |
|
def _build_user_symbol_bounds(
    gm: torch.fx.GraphModule,
    sample_arg_inputs: Sequence[Input],
    sample_kwarg_inputs: dict[Any, Any],
) -> Dict[sympy.Symbol, Tuple[int, int]]:
    """Map shape symbols of graph inputs to the user's declared ``(min, max)``.

    Every placeholder of ``gm`` is paired with a user-supplied ``Input``:
    positional samples are matched to placeholders by position, keyword
    samples by placeholder target name (a keyword entry replaces a positional
    one registered under the same name). Only ``Input`` objects whose
    ``shape_mode`` is ``DYNAMIC`` take part.

    For each paired placeholder whose ``meta["val"]`` is a tensor, every
    dimension that is a ``torch.SymInt`` backed by a *bare* ``sympy.Symbol``
    is recorded as ``symbol -> (min_shape[d], max_shape[d])``. The first
    occurrence of a symbol wins; later occurrences are ignored.

    This function only reads from ``gm`` and the sample inputs; it builds and
    returns a fresh dictionary and does not modify any shape environment.

    Args:
        gm: Graph module whose placeholder nodes are inspected.
        sample_arg_inputs: Positional sample inputs, in placeholder order.
        sample_kwarg_inputs: Keyword sample inputs keyed by placeholder name.

    Returns:
        A dictionary from plain shape symbols to ``(min, max)`` integer
        bounds. Empty when no dynamic ``Input`` was supplied.
    """

    def _is_dynamic_input(candidate: Any) -> bool:
        # ``shape_mode`` is only touched once we know this is an ``Input``.
        return (
            isinstance(candidate, Input)
            and candidate.shape_mode == Input._ShapeMode.DYNAMIC
        )

    input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]

    # Positional samples line up with placeholders by index; ``zip`` stops at
    # the shorter of the two, so surplus placeholders or samples are ignored.
    dynamic_inputs: dict[str, Input] = {}
    for node, candidate in zip(input_nodes, sample_arg_inputs):
        if _is_dynamic_input(candidate):
            dynamic_inputs[node.target] = candidate
    for kwarg_name, candidate in sample_kwarg_inputs.items():
        if _is_dynamic_input(candidate):
            dynamic_inputs[kwarg_name] = candidate

    bounds: Dict[sympy.Symbol, Tuple[int, int]] = {}
    if not dynamic_inputs:
        return bounds

    for node in input_nodes:
        if node.target in dynamic_inputs:
            user_input = dynamic_inputs[node.target]
            fake_tensor = node.meta.get("val")
            if isinstance(fake_tensor, torch.Tensor):
                lower_shape = user_input.shape["min_shape"]
                upper_shape = user_input.shape["max_shape"]

                for axis, extent in enumerate(fake_tensor.size()):
                    # Static dimensions, and axes beyond the rank the user
                    # declared, carry no bound information.
                    if isinstance(extent, torch.SymInt) and axis < len(lower_shape):
                        symbol = extent.node.expr
                        # Composite expressions (e.g. ``2*s0``) are skipped:
                        # only bare symbols get bounds. Symbols that were
                        # already recorded keep their first bounds.
                        if isinstance(symbol, sympy.Symbol) and symbol not in bounds:
                            lower = int(lower_shape[axis])
                            upper = int(upper_shape[axis])
                            bounds[symbol] = (lower, upper)
                            logger.debug(
                                "Recorded user-supplied bounds for %s: [%d, %d]",
                                symbol,
                                lower,
                                upper,
                            )

    return bounds
| 946 | + |
| 947 | + |
878 | 948 | @fn_supports_debugger # type: ignore[misc] |
879 | 949 | def compile_module( |
880 | 950 | gm: torch.fx.GraphModule, |
@@ -906,6 +976,16 @@ def compile_module( |
906 | 976 | if sample_kwarg_inputs is None: |
907 | 977 | sample_kwarg_inputs = {} |
908 | 978 |
|
| 979 | + # Build a read-only ``{sympy.Symbol: (min, max)}`` map from the user's |
| 980 | + # sample ``Input`` objects. This is forwarded to the partitioner so that |
| 981 | + # symbols whose upper bound is left unbounded by ``Dim.DYNAMIC`` get |
| 982 | + # filled in with the user's declared bounds, *without* mutating the |
| 983 | + # exporter's ``ShapeEnv.var_to_range`` (which preserves the original |
| 984 | + # ``range_constraints`` on save / re-export). |
| 985 | + user_symbol_bounds = _build_user_symbol_bounds( |
| 986 | + gm, sample_arg_inputs, sample_kwarg_inputs |
| 987 | + ) |
| 988 | + |
909 | 989 | # Configure user compilation settings to converters. |
910 | 990 | CONVERTERS.set_compilation_settings(settings) |
911 | 991 |
|
@@ -1087,7 +1167,9 @@ def preserve_module_specs( |
1087 | 1167 | ) |
1088 | 1168 |
|
1089 | 1169 | # Get the submodule inputs for min, opt, max shapes of the graph inputs |
1090 | | - submodule_inputs = partitioning.construct_submodule_inputs(submodule) |
| 1170 | + submodule_inputs = partitioning.construct_submodule_inputs( |
| 1171 | + submodule, user_symbol_bounds=user_symbol_bounds |
| 1172 | + ) |
1091 | 1173 |
|
1092 | 1174 | assert submodule_inputs is not None |
1093 | 1175 |
|
|
0 commit comments