
[voxtral_tts] enable MLX backend #19177

Open
seyeong-han wants to merge 7 commits into pytorch:main from seyeong-han:voxtral-tts-mlx

Conversation

@seyeong-han (Contributor) commented Apr 28, 2026

Summary

Enable Voxtral TTS on the ExecuTorch MLX backend for Apple Silicon.

This PR adds MLX export support for the LM/flow methods and the codec decoder, the `make voxtral_tts-mlx` target and CMake preset wiring, README instructions, a one-shot MLX E2E script, and MLX parity/regression tests. The native MLX codec fix lives in the backend's advanced-indexing lowering, which is exercised by the codec's `F.unfold` / im2col conv rewrite.

The shared runner exposes `--streaming` and `--speaker` for MLX builds. Offline MLX synthesis is validated here with the native MLX codec artifact; streaming uses the same codec artifact and avoids the old portable CPU codec fallback. `--backend mlx --qlinear-codec` is still rejected because MLX codec quantization has not yet been validated.
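For context, the im2col pattern mentioned above can be illustrated with a minimal sketch (hypothetical shapes, not the codec's actual graph): a `Conv1d` rewritten as `F.unfold` plus a matmul, which is the shape of graph that exercises the backend's advanced-indexing lowering.

```python
import torch
import torch.nn.functional as F

# Minimal im2col sketch (hypothetical shapes, not the codec's actual graph):
# a Conv1d expressed as F.unfold + matmul.
x = torch.randn(1, 4, 16)   # (batch, in_channels, length)
w = torch.randn(8, 4, 3)    # (out_channels, in_channels, kernel)

ref = F.conv1d(x, w, padding=1)

# F.unfold needs a 4-D input, so treat the signal as a (length, 1) image.
cols = F.unfold(x.unsqueeze(-1), kernel_size=(3, 1), padding=(1, 0))
# cols: (batch, in_channels * kernel, length) — one column per output position.
out = (w.reshape(8, -1) @ cols).reshape(1, 8, 16)

assert torch.allclose(ref, out, atol=1e-4)
```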

Benchmark

Apple Silicon MLX benchmark using bf16 + 4w linear + 8w embedding export, with native MLX LM/flow and native MLX codec:

| Run | Audio | Generate time | RTF | Process wall |
|-----|-------|---------------|-----|--------------|
| 1 | 3.44s | 3132ms | 0.910465 | 5.19s |
| 2 | 3.44s | 2634ms | 0.765698 | 3.15s |
| 3 | 3.44s | 2607ms | 0.757849 | 3.13s |

Average generation RTF: 0.811337 (0.761774 warm-run average). Average process wall: 3.82s (3.14s warm-run average). WAV quality check: peak 0.42575, 0 clipped samples. Apple Speech transcribed the benchmark WAV as "Hello how are you today".
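The per-run RTF figures in the table can be reproduced from the raw timings, assuming RTF = generation time / audio duration (which is what the numbers imply):

```python
# Sketch of how the RTF figures above are derived, assuming
# RTF = generation time / audio duration.
audio_s = 3.44
gen_ms = [3132, 2634, 2607]

rtfs = [ms / 1000 / audio_s for ms in gen_ms]
print([round(r, 6) for r in rtfs])      # [0.910465, 0.765698, 0.757849]
print(round(sum(rtfs) / len(rtfs), 6))  # 0.811337
```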

Test plan

  • `conda run -n et-mlx python -m pytest -q examples/models/voxtral_tts/test_mlx_parity.py`
    • 8 passed in 395.04s
  • `PYTHONPATH=/Users/younghan conda run -n et-mlx python -m pytest -q backends/mlx/test/test_runtime_ops.py`
    • 2 passed in 4.82s
  • Native MLX codec export with `--backend mlx --dtype bf16 --export-target codec`
    • wrote /tmp/voxtral_tts_mlx_native_codec_256/codec_decoder.pte (289.2 MB)
  • Three runner benchmark invocations using /tmp/voxtral_tts_mlx_final/model.pte and /tmp/voxtral_tts_mlx_native_codec_256/codec_decoder.pte
    • average generation RTF 0.811337, average process wall 3.82s
  • Apple Speech verification on the generated WAV
    • FINAL: "Hello how are you today"

@pytorch-bot

pytorch-bot Bot commented Apr 28, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19177

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 3 Unrelated Failures

As of commit 0c733aa with merge base f3e49ff:

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla Bot added the CLA Signed label Apr 28, 2026
@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@seyeong-han changed the title from examples/voxtral_tts: enable MLX backend to feat: [voxtral_tts] enable MLX backend Apr 28, 2026
@seyeong-han changed the title from feat: [voxtral_tts] enable MLX backend to [voxtral_tts] enable MLX backend Apr 29, 2026
Adds Apple Silicon MLX export and runner wiring for Voxtral TTS while keeping codec lowering portable for waveform correctness.

Made-with: Cursor
Clarify that the runner exposes streaming flags for MLX builds, while this branch only reports offline MLX performance because codec decoding still falls back to portable CPU.

Made-with: Cursor
Keep the advanced-indexing fix reviewable by isolating gather permutation logic, tightening MLX test availability checks, and updating the Qwen MoE CI cast budget to match the corrected graph.

Made-with: Cursor
Excluded per PR feedback; parity testing will be handled separately.
The original intro said weights are loaded directly from safetensors, which was ambiguous: at inference the C++ runner loads .pte files, not safetensors. Localize the claim to export.
Comment thread on backends/mlx/ops.py:

```python
require_args(args, 2, 2, "aten.index.Tensor")
require_kwargs(P.kwargs(n), set(), "aten.index.Tensor")
x, idx_list = args

def _index_gather_permutation(
```
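For readers outside the thread, here is an assumed-behavior sketch of the kind of permute-then-gather rewrite a helper like `_index_gather_permutation` might isolate (this is not the actual MLX lowering): indexing one dim is equivalent to moving that dim to the front, gathering rows, and permuting back.

```python
import torch

# Assumed-behavior sketch (not the actual MLX lowering): single-dim advanced
# indexing rewritten as permute -> row gather -> inverse permute.
def index_via_permute_gather(x, dim, idx):
    # Move the indexed dim to the front so a plain row gather applies.
    perm = [dim] + [d for d in range(x.dim()) if d != dim]
    gathered = x.permute(perm)[idx]
    # Invert the permutation to restore the original dim order
    # (valid here because a 1-D idx preserves the rank).
    inv = sorted(range(len(perm)), key=perm.__getitem__)
    return gathered.permute(inv)

x = torch.arange(24).reshape(2, 3, 4)
idx = torch.tensor([2, 0])
assert torch.equal(index_via_permute_gather(x, 1, idx), x.index_select(1, idx))
```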
Contributor:

What op was not lowering correctly? Why is this change needed?


```python
torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-5)

def test_separated_advanced_indices_keep_broadcast_dims_front(self):
```
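The test name refers to standard NumPy/PyTorch advanced-indexing semantics: when advanced indices are separated by a slice, the broadcast index dimensions move to the front of the result. A minimal illustration:

```python
import torch

x = torch.arange(24).reshape(2, 3, 4)
i = torch.tensor([0, 1])
j = torch.tensor([1, 3])

# Adjacent advanced indices: the broadcast index dim replaces dims 1-2 in place.
assert x[:, i, j].shape == (2, 2)

# Separated by a slice: the broadcast index dim moves to the front.
assert x[i, :, j].shape == (2, 3)
```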
Contributor:

Can you use the existing test_ops.py file?


Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants