-
Notifications
You must be signed in to change notification settings - Fork 971
[voxtral_tts] enable MLX backend #19177
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
seyeong-han
wants to merge
7
commits into
pytorch:main
Choose a base branch
from
seyeong-han:voxtral-tts-mlx
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
43f9261
examples/voxtral_tts: enable MLX backend
seyeong-han bbe9fe8
examples/voxtral_tts: document MLX streaming scope
seyeong-han 9783048
backends/mlx: fix advanced indexing gather order
seyeong-han 6d027dd
backends/mlx: fix CI lint fallout
seyeong-han fca9ecf
backends/mlx: reduce index handler lint complexity
seyeong-han f894a11
examples/voxtral_tts: drop test_mlx_parity.py from PR
seyeong-han 0c733aa
examples/voxtral_tts: clarify safetensors are loaded at export time
seyeong-han File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,80 @@ | ||
| #!/usr/bin/env python3 | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| import sys | ||
| import unittest | ||
|
|
||
| import torch | ||
| import torch.nn as nn | ||
| import torch.nn.functional as F | ||
| from torch.export import export | ||
|
|
||
| try: | ||
| import executorch.exir as exir | ||
| from executorch.backends.mlx.partitioner import MLXPartitioner | ||
| from executorch.extension.pybindings.portable_lib import ( | ||
| _load_for_executorch_from_buffer, | ||
| ) | ||
|
|
||
| _MLX_RUNTIME_OK = sys.platform == "darwin" | ||
| except (AttributeError, ImportError, OSError): | ||
| _MLX_RUNTIME_OK = False | ||
|
|
||
|
|
||
| class _Unfold1DModel(nn.Module): | ||
| def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
| return F.unfold(x.unsqueeze(-1), kernel_size=(3, 1), stride=(1, 1)) | ||
|
|
||
|
|
||
| class _SeparatedAdvancedIndexModel(nn.Module): | ||
| def forward( | ||
| self, x: torch.Tensor, idx0: torch.Tensor, idx2: torch.Tensor | ||
| ) -> torch.Tensor: | ||
| return x[idx0, :, idx2] | ||
|
|
||
|
|
||
| def _run_with_mlx(module: nn.Module, inputs: tuple[torch.Tensor, ...]) -> torch.Tensor: | ||
| exported = export(module.eval(), inputs, strict=True) | ||
| edge = exir.to_edge_transform_and_lower( | ||
| exported, | ||
| partitioner=[MLXPartitioner()], | ||
| compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), | ||
| ) | ||
| lowered = edge.to_executorch( | ||
| config=exir.ExecutorchBackendConfig(extract_delegate_segments=True) | ||
| ) | ||
| runtime_module = _load_for_executorch_from_buffer(lowered.buffer) | ||
| return runtime_module.forward(list(inputs))[0] | ||
|
|
||
|
|
||
| @unittest.skipUnless(_MLX_RUNTIME_OK, "MLX runtime tests require macOS + pybindings") | ||
| class TestMLXRuntimeOps(unittest.TestCase): | ||
| def test_unfold_1d_preserves_channel_major_patch_order(self): | ||
| x = torch.arange(10, dtype=torch.float32).reshape(1, 2, 5) | ||
| module = _Unfold1DModel() | ||
|
|
||
| with torch.no_grad(): | ||
| expected = module(x) | ||
| actual = _run_with_mlx(module, (x,)) | ||
|
|
||
| torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-5) | ||
|
|
||
| def test_separated_advanced_indices_keep_broadcast_dims_front(self): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you use the existing test_ops.py file? |
||
| x = torch.arange(2 * 3 * 5, dtype=torch.float32).reshape(2, 3, 5) | ||
| idx0 = torch.tensor([[0], [1]], dtype=torch.long) | ||
| idx2 = torch.tensor([[0, 2, 4]], dtype=torch.long) | ||
| module = _SeparatedAdvancedIndexModel() | ||
|
|
||
| with torch.no_grad(): | ||
| expected = module(x, idx0, idx2) | ||
| actual = _run_with_mlx(module, (x, idx0, idx2)) | ||
|
|
||
| torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-5) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What op was not lowering correctly? Why is this change needed?