Skip to content

Commit eec2c04

Browse files
committed
user passing input, adding test case
1 parent 4341579 commit eec2c04

1 file changed

Lines changed: 18 additions & 3 deletions

File tree

tests/py/dynamo/models/test_dynamic_shape_user_bounds.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -365,9 +365,24 @@ def test_dim_dynamic_save_preserves_range_constraints(tmpdir):
365365
}
366366
assertions.assertEqual(expected_constraints, reloaded_constraints)
367367

368-
# Sanity: the compiled engine must accept up to the user's max_shape.
369-
big_input = torch.randn(8, 8, device="cuda")
370-
trt_module(big_input)
368+
# Sanity: the compiled engine must accept inputs across the user's
369+
# declared profile - both at the lower edge (min_shape=1) and at the
370+
# upper edge (max_shape=8).
371+
trt_module(torch.randn(1, 8, device="cuda"))
372+
trt_module(torch.randn(4, 8, device="cuda"))
373+
trt_module(torch.randn(8, 8, device="cuda"))
374+
375+
# And it must REJECT inputs beyond max_shape, even though the model
376+
# graph (exported with ``Dim.DYNAMIC``) is itself unbounded and could
377+
# theoretically handle batch=16 in eager. The TRT engine's profile is
378+
# the binding runtime envelope: ``Input(max_shape=8)`` is the user
379+
# opting into a strict cap. If a user wants batch=16 they must either
380+
# re-compile with ``max_shape>=16`` or omit ``max_shape`` (heuristic
381+
# fallback). This pins down the contract to prevent regressions where
382+
# ``Input.max_shape`` would silently widen back to the heuristic.
383+
too_big = torch.randn(16, 8, device="cuda")
384+
with assertions.assertRaises(Exception):
385+
trt_module(too_big)
371386

372387

373388
if __name__ == "__main__":

0 commit comments

Comments
 (0)