[voxtral_tts] enable MLX backend #19177
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19177

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 3 Unrelated Failures as of commit 0c733aa with merge base f3e49ff.

NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following jobs failed but were present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.
Adds Apple Silicon MLX export and runner wiring for Voxtral TTS while keeping codec lowering portable for waveform correctness. Made-with: Cursor
Clarify that the runner exposes streaming flags for MLX builds, while this branch only reports offline MLX performance because codec decoding still falls back to portable CPU. Made-with: Cursor
Keep the advanced-indexing fix reviewable by isolating gather permutation logic, tightening MLX test availability checks, and updating the Qwen MoE CI cast budget to match the corrected graph. Made-with: Cursor
Force-pushed from 3f98344 to fca9ecf.
Excluded per PR feedback; parity testing will be handled separately.
The original intro said weights are loaded directly from safetensors, which was ambiguous: at inference the C++ runner loads .pte files, not safetensors. The claim is now localized to export.
```python
require_args(args, 2, 2, "aten.index.Tensor")
require_kwargs(P.kwargs(n), set(), "aten.index.Tensor")
x, idx_list = args

def _index_gather_permutation(
```
What op was not lowering correctly? Why is this change needed?
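For context on the pattern involved, here is a minimal sketch, not this PR's implementation, of the gather-plus-permutation idea suggested by the `_index_gather_permutation` name, shown for the simplest `aten.index.Tensor` case (one 1-D index tensor). The helper name `index_via_permute_gather` and all shapes are illustrative assumptions.

```python
import torch

def index_via_permute_gather(x: torch.Tensor, dim: int, idx: torch.Tensor) -> torch.Tensor:
    # Move the indexed dim to the front so the gather is a plain row-select.
    perm = [dim] + [d for d in range(x.ndim) if d != dim]
    xp = x.permute(perm)
    # torch.gather needs an index with the same ndim as the input, so expand
    # the 1-D index across the remaining dims.
    gidx = idx.view(-1, *([1] * (xp.ndim - 1))).expand(idx.numel(), *xp.shape[1:])
    out = torch.gather(xp, 0, gidx)
    # Invert the permutation to restore the original dim order.
    inv = [perm.index(d) for d in range(x.ndim)]
    return out.permute(inv)

x = torch.randn(2, 3, 4)
idx = torch.tensor([2, 0])
assert torch.equal(index_via_permute_gather(x, 1, idx), x[:, idx])
```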
```python
torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-5)

def test_separated_advanced_indices_keep_broadcast_dims_front(self):
```
Can you use the existing test_ops.py file?
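For reference, the broadcast-dims-to-front rule that this test name refers to is standard NumPy/PyTorch advanced-indexing semantics; the shapes below are illustrative, not taken from the test:

```python
import torch

x = torch.randn(2, 3, 4, 5)
i = torch.tensor([0, 1])
j = torch.tensor([2, 3])

# Advanced indices at dims 1 and 3 are separated by a slice at dim 2, so the
# broadcast index dims move to the front of the result.
print(x[:, i, :, j].shape)  # torch.Size([2, 2, 4])

# Adjacent advanced indices keep their position instead.
print(x[:, i, j, :].shape)  # torch.Size([2, 2, 5])
```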
Summary
Enable Voxtral TTS on the ExecuTorch MLX backend for Apple Silicon.
This PR adds MLX export support for the LM/flow methods and the codec decoder, `make voxtral_tts-mlx` and CMake preset wiring, README instructions, a one-shot MLX E2E script, and MLX parity/regression tests. The native MLX codec fix is in the backend advanced-indexing lowering used by the codec's `F.unfold`/im2col conv rewrite.
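As background on what that rewrite means (a minimal sketch, not this PR's code, with arbitrary shapes): a 2-D convolution can be expressed as `F.unfold` (im2col) followed by a matmul, so the backend only has to lower an unfold/matmul pattern instead of a convolution op.

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)        # (N, C_in, H, W)
w = torch.randn(6, 3, 3, 3)        # (C_out, C_in, kH, kW)

cols = F.unfold(x, kernel_size=3)  # im2col: (N, C_in*kH*kW, L), L = 6*6 patches
out = w.flatten(1) @ cols          # (C_out, C_in*kH*kW) @ cols -> (N, C_out, L)
out = out.view(1, 6, 6, 6)         # fold L back into (H_out, W_out)

assert torch.allclose(out, F.conv2d(x, w), atol=1e-4)
```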
The shared runner exposes `--streaming` and `--speaker` for MLX builds. Offline MLX synthesis is validated here with the native MLX codec artifact; streaming uses the same codec artifact and avoids the old portable CPU codec fallback. `--backend mlx --qlinear-codec` is still rejected because MLX codec quantization is not yet validated.

Benchmark
Apple Silicon MLX benchmark using bf16 + 4w linear + 8w embedding export, with native MLX LM/flow and native MLX codec:
Average generation RTF: `0.811337` (`0.761774` warm-run average). Average process wall: `3.82s` (`3.14s` warm-run average). WAV quality check: peak `0.42575`, clipped samples `0`. Apple Speech transcribed the benchmark WAV as `Hello how are you today`.
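For readers unfamiliar with the metric: real-time factor (RTF) is generation wall time divided by the duration of the synthesized audio, so values below 1.0 mean faster-than-real-time synthesis. A trivial sketch with made-up numbers; the helper name is illustrative:

```python
# Illustrative helper: RTF = compute time / synthesized audio duration.
def real_time_factor(generation_seconds: float, audio_seconds: float) -> float:
    return generation_seconds / audio_seconds

# Made-up numbers: 2.0 s of compute producing 2.5 s of audio gives RTF 0.8,
# i.e. faster than real time, similar in spirit to the ~0.81 reported above.
print(real_time_factor(2.0, 2.5))  # 0.8
```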
Test plan

- `conda run -n et-mlx python -m pytest -q examples/models/voxtral_tts/test_mlx_parity.py`: 8 passed in 395.04s
- `PYTHONPATH=/Users/younghan conda run -n et-mlx python -m pytest -q backends/mlx/test/test_runtime_ops.py`: 2 passed in 4.82s
- Codec export with `--backend mlx --dtype bf16 --export-target codec` produces `/tmp/voxtral_tts_mlx_native_codec_256/codec_decoder.pte` (289.2 MB)
- E2E run with `/tmp/voxtral_tts_mlx_final/model.pte` and `/tmp/voxtral_tts_mlx_native_codec_256/codec_decoder.pte`: average generation RTF `0.811337`, average process wall `3.82s`
- Runner output: `FINAL Hello how are you today`