Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/mlx.yml
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ jobs:
ASTYPE_COUNT=$(${CONDA_RUN} python -m executorch.backends.mlx.pte_inspector \
/tmp/qwen35_moe_mlx_tiny/model.pte --mlx-instructions 2>&1 | grep -c "AsTypeNode" || true)
echo "AsType nodes: ${ASTYPE_COUNT}"
if [ "$ASTYPE_COUNT" -gt 23 ]; then
echo "Failed: expected no more than 23 AsType nodes, got ${ASTYPE_COUNT}"
if [ "$ASTYPE_COUNT" -gt 24 ]; then
echo "Failed: expected no more than 24 AsType nodes, got ${ASTYPE_COUNT}"
exit 1
fi
echo "::endgroup::"
Expand Down
12 changes: 11 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@
#
# ==============================================================================

.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu qwen3_5_moe-cuda qwen3_5_moe-metal clean help
.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda voxtral_tts-mlx whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu qwen3_5_moe-cuda qwen3_5_moe-metal clean help

help:
@echo "This Makefile adds targets to build runners for various models on various backends. Run using \`make <target>\`. Available targets:"
Expand All @@ -105,6 +105,7 @@ help:
@echo " voxtral_realtime-mlx - Build Voxtral Realtime runner with MLX backend"
@echo " voxtral_tts-cpu - Build Voxtral TTS runner (CPU)"
@echo " voxtral_tts-cuda - Build Voxtral TTS runner with CUDA backend"
@echo " voxtral_tts-mlx - Build Voxtral TTS runner with MLX backend (macOS only)"
@echo " whisper-cuda - Build Whisper runner with CUDA backend"
@echo " whisper-cuda-debug - Build Whisper runner with CUDA backend (debug mode)"
@echo " whisper-cpu - Build Whisper runner with CPU backend"
Expand Down Expand Up @@ -416,6 +417,15 @@ voxtral_tts-cuda:
@echo "✓ Build complete!"
@echo " Binary: cmake-out/examples/models/voxtral_tts/voxtral_tts_runner"

# Build the Voxtral TTS runner against the MLX backend (macOS only).
# NOTE: recipe lines must begin with a hard tab — the page scrape had
# normalized them to spaces, which GNU Make rejects.
voxtral_tts-mlx:
	@echo "==> Building and installing ExecuTorch with MLX..."
	cmake --workflow --preset mlx-release
	@echo "==> Building Voxtral TTS runner with MLX..."
	cd examples/models/voxtral_tts && cmake --workflow --preset voxtral-tts-mlx
	@echo ""
	@echo "✓ Build complete!"
	@echo "  Binary: cmake-out/examples/models/voxtral_tts/voxtral_tts_runner"

qwen3_5_moe-cuda:
@echo "==> Building and installing ExecuTorch with CUDA..."
cmake --workflow --preset llm-release-cuda
Expand Down
205 changes: 144 additions & 61 deletions backends/mlx/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1678,82 +1678,112 @@ def _repeat_handler(P: MLXProgramBuilder, n: Node) -> Slot:
return out


@REGISTRY.register(target=[torch.ops.aten.index.Tensor])
def _index_handler(P: MLXProgramBuilder, n: Node) -> Slot:
args = P.args(n)
require_args(args, 2, 2, "aten.index.Tensor")
require_kwargs(P.kwargs(n), set(), "aten.index.Tensor")
x, idx_list = args
def _index_gather_permutation(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What op was not lowering correctly? Why is this change needed?

indexed_axes: Set[int],
x_ndim: int,
broadcast_ndim: int,
) -> List[int]:
indexed_axes_sorted = sorted(indexed_axes)
expected_rank = broadcast_ndim + x_ndim
is_contiguous = indexed_axes_sorted == list(
range(indexed_axes_sorted[0], indexed_axes_sorted[-1] + 1)
)
if not is_contiguous:
return list(range(expected_rank))

non_indexed_axes = [i for i in range(x_ndim) if i not in indexed_axes]
before_axes = [i for i in non_indexed_axes if i < indexed_axes_sorted[0]]
after_axes = [i for i in non_indexed_axes if i > indexed_axes_sorted[-1]]
return (
[broadcast_ndim + i for i in before_axes]
+ list(range(broadcast_ndim))
+ [broadcast_ndim + i for i in after_axes]
+ [broadcast_ndim + i for i in indexed_axes_sorted]
)


def _non_none_index_tensors(idx_list: Any) -> "List[Tuple[int, Slot]]":
    """Validate ``idx_list`` and return ``(axis, index)`` pairs for real indices.

    ``aten.index.Tensor`` passes one optional index per input axis; ``None``
    entries mean "keep the whole axis".

    Raises:
        ValueError: if ``idx_list`` is not a non-empty list, or if every
            entry is ``None``.
    """
    if not isinstance(idx_list, list) or len(idx_list) == 0:
        raise ValueError(
            f"aten.index.Tensor requires a list of index tensors, got {type(idx_list)}"
        )

    # Filter out None indices and track which axes they correspond to
    non_none = [(i, idx) for i, idx in enumerate(idx_list) if idx is not None]
    if len(non_none) == 0:
        raise ValueError("aten.index.Tensor: all indices are None")
    return non_none

def _emit_single_index_handler(
    P: MLXProgramBuilder,
    n: Node,
    x: Slot,
    axis: int,
    idx: Slot,
    x_meta: Any,
) -> Slot:
    """Lower aten.index.Tensor with exactly one non-None index tensor.

    When the index tensor has the same rank as ``x`` (per shape metadata),
    an element-wise ``TakeAlongAxisNode`` is emitted; otherwise (e.g. 1D
    indices into a 3D tensor) a plain ``TakeNode`` gather along ``axis``
    is used.
    """
    idx_meta = n.args[1][axis].meta.get("val")
    # Rank comparison is only possible when both metas are present.
    ndim_match = (
        x_meta is not None
        and idx_meta is not None
        and len(x_meta.shape) == len(idx_meta.shape)
    )
    out = P.make_or_get_slot(n)
    if ndim_match:
        # Same ndim: use TakeAlongAxisNode (element-wise gather)
        P.emit(
            TakeAlongAxisNode(
                x=P.slot_to_tid(x),
                indices=P.slot_to_tid(idx),
                out=P.slot_to_tid(out),
                axis=axis,
            )
        )
    else:
        # Different ndim (e.g. 1D indices into 3D tensor): use TakeNode
        P.emit(
            TakeNode(
                x=P.slot_to_tid(x),
                index=IntOrVidOrTid.from_tid(P.slot_to_tid(idx)),
                out=P.slot_to_tid(out),
                axis=axis,
            )
        )
    return out


def _index_slice_sizes(x_meta: Any, x_ndim: int, indexed_axes: Set[int]) -> List[int]:
    """Compute gather ``slice_sizes``: 1 on indexed axes, full size elsewhere.

    ``int()``-able static sizes are required for non-indexed axes; a SymInt
    (dynamic shape) there cannot be expressed as a gather slice size.

    Raises:
        ValueError: when a non-indexed dimension has a dynamic (non-int) size.
    """
    slice_sizes = []
    for dim in range(x_ndim):
        if dim in indexed_axes:
            slice_sizes.append(1)
            continue

        dim_size = x_meta.shape[dim]
        if not isinstance(dim_size, int):
            raise ValueError(
                f"aten.index.Tensor: non-indexed dimension {dim} has dynamic size "
                f"{dim_size}, which is not supported with multi-index gather"
            )
        slice_sizes.append(dim_size)
    return slice_sizes

# Multi-index: use GatherNode (maps to mlx::gather)
if x_meta is None or x_ndim is None:
raise ValueError(
"aten.index.Tensor with multiple indices requires input shape metadata"
)

def _emit_multi_index_handler(
P: MLXProgramBuilder,
n: Node,
x: Slot,
x_meta: Any,
x_ndim: int,
non_none: List[Tuple[int, Slot]],
) -> Slot:
indices = [P.slot_to_tid(idx) for _, idx in non_none]
axes = [i for i, _ in non_none]
indexed_axes = set(axes)

# slice_sizes: 1 for indexed axes, full dim size for non-indexed axes
# Use int() to handle SymInt values from dynamic shapes
indexed_axes = set(axes)
slice_sizes = []
for dim in range(x_ndim):
if dim in indexed_axes:
slice_sizes.append(1)
else:
dim_size = x_meta.shape[dim]
if not isinstance(dim_size, int):
raise ValueError(
f"aten.index.Tensor: non-indexed dimension {dim} has dynamic size "
f"{dim_size}, which is not supported with multi-index gather"
)
slice_sizes.append(dim_size)
slice_sizes = _index_slice_sizes(x_meta, x_ndim, indexed_axes)

# Emit gather — output shape is broadcast(indices).shape + slice_sizes
_, gather_slot = P.make_tmp_slot()
Expand All @@ -1767,26 +1797,79 @@ def _index_handler(P: MLXProgramBuilder, n: Node) -> Slot:
)
)

# Reshape to match aten.index.Tensor output shape, which strips the
# trailing dimensions introduced by gather's slice_sizes
out_meta = n.meta.get("val")
if out_meta is None:
raise ValueError(
"aten.index.Tensor: output shape metadata required for reshape after gather"
)
out_shape = [P.to_int_or_vid(int(d)) for d in out_meta.shape]

# MLX gather returns broadcast(indices).shape followed by one slice
# dimension per input dimension. For contiguous advanced-index groups,
# PyTorch keeps the broadcast dims at the indexed position, so reorder
# before stripping the singleton indexed slice dimensions via reshape.
non_indexed_axes = [i for i in range(x_ndim) if i not in indexed_axes]
broadcast_ndim = len(out_meta.shape) - len(non_indexed_axes)
if broadcast_ndim < 0:
raise ValueError(
"aten.index.Tensor: could not infer broadcast rank for multi-index gather"
)

reshape_input = gather_slot
expected_rank = broadcast_ndim + x_ndim
perm = _index_gather_permutation(indexed_axes, x_ndim, broadcast_ndim)
if len(perm) != expected_rank:
raise ValueError(
f"aten.index.Tensor: internal gather permutation has rank {len(perm)}, "
f"expected {expected_rank}"
)
if perm != list(range(expected_rank)):
_, ordered_slot = P.make_tmp_slot()
P.emit(
TransposeNode(
x=P.slot_to_tid(gather_slot),
out=P.slot_to_tid(ordered_slot),
perm=perm,
)
)
reshape_input = ordered_slot

# Reshape to match aten.index.Tensor output shape, stripping the singleton
# dimensions introduced by gather's slice_sizes for indexed axes.
out = P.make_or_get_slot(n)
P.emit(
ReshapeNode(
x=P.slot_to_tid(gather_slot),
x=P.slot_to_tid(reshape_input),
out=P.slot_to_tid(out),
shape=out_shape,
)
)
return out


@REGISTRY.register(target=[torch.ops.aten.index.Tensor])
def _index_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    """Lower aten.index.Tensor, dispatching on the number of real indices.

    A single non-None index maps to a take/take-along-axis emit; multiple
    indices go through the multi-index gather path, which needs input shape
    metadata.
    """
    node_args = P.args(n)
    require_args(node_args, 2, 2, "aten.index.Tensor")
    require_kwargs(P.kwargs(n), set(), "aten.index.Tensor")
    x, idx_list = node_args

    x_meta = n.args[0].meta.get("val")
    x_ndim = None if x_meta is None else len(x_meta.shape)
    non_none = _non_none_index_tensors(idx_list)

    if len(non_none) == 1:
        axis, idx = non_none[0]
        return _emit_single_index_handler(P, n, x, axis, idx, x_meta)

    if x_meta is None or x_ndim is None:
        raise ValueError(
            "aten.index.Tensor with multiple indices requires input shape metadata"
        )

    return _emit_multi_index_handler(P, n, x, x_meta, x_ndim, non_none)


@REGISTRY.register(target=[torch.ops.aten.index_select.default])
def _index_select_handler(P: MLXProgramBuilder, n: Node) -> Slot:
"""Handle aten.index_select: select elements along an axis using a 1D index tensor.
Expand Down
80 changes: 80 additions & 0 deletions backends/mlx/test/test_runtime_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import sys
import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.export import export

# Availability guard: the MLX delegate runtime is exercised only when the
# ExecuTorch pybindings import cleanly AND we are on macOS.  Any import-time
# failure (missing extension, bad dlopen, attribute drift) disables the tests
# instead of erroring at collection time.
try:
    import executorch.exir as exir
    from executorch.backends.mlx.partitioner import MLXPartitioner
    from executorch.extension.pybindings.portable_lib import (
        _load_for_executorch_from_buffer,
    )

    # Imports succeeded; still require macOS for the MLX runtime itself.
    _MLX_RUNTIME_OK = sys.platform == "darwin"
except (AttributeError, ImportError, OSError):
    _MLX_RUNTIME_OK = False


class _Unfold1DModel(nn.Module):
    """Runs F.unfold over a (batch, channel, length) input.

    A trailing singleton width dimension is appended so unfold receives the
    4D input it expects; patches are 3 elements tall with stride 1.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        as_4d = x.unsqueeze(-1)
        return F.unfold(as_4d, kernel_size=(3, 1), stride=(1, 1))


class _SeparatedAdvancedIndexModel(nn.Module):
    """Advanced indexing with index tensors separated by a full slice.

    Equivalent to ``x[idx0, :, idx2]``: the separated index group makes
    PyTorch move the broadcast dimensions to the front of the result.
    """

    def forward(
        self, x: torch.Tensor, idx0: torch.Tensor, idx2: torch.Tensor
    ) -> torch.Tensor:
        key = (idx0, slice(None), idx2)
        return x[key]


def _run_with_mlx(module: nn.Module, inputs: tuple[torch.Tensor, ...]) -> torch.Tensor:
    """Export ``module``, lower it through the MLX partitioner, and run it.

    Returns the first output tensor produced by the ExecuTorch runtime.
    """
    exported_program = export(module.eval(), inputs, strict=True)
    edge_program = exir.to_edge_transform_and_lower(
        exported_program,
        partitioner=[MLXPartitioner()],
        compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
    )
    executorch_program = edge_program.to_executorch(
        config=exir.ExecutorchBackendConfig(extract_delegate_segments=True)
    )
    runtime = _load_for_executorch_from_buffer(executorch_program.buffer)
    return runtime.forward(list(inputs))[0]


@unittest.skipUnless(_MLX_RUNTIME_OK, "MLX runtime tests require macOS + pybindings")
class TestMLXRuntimeOps(unittest.TestCase):
    """End-to-end checks that MLX-lowered programs match eager PyTorch.

    Each test exports a tiny module, lowers it through the MLX partitioner,
    runs it via the ExecuTorch pybindings, and compares against the eager
    result.  (The page-scrape review-widget text that had been embedded in
    this class body was removed — it was not code.)
    """

    def test_unfold_1d_preserves_channel_major_patch_order(self):
        # 1x2x5 ramp input makes any patch-order mix-up visible in values.
        x = torch.arange(10, dtype=torch.float32).reshape(1, 2, 5)
        module = _Unfold1DModel()

        with torch.no_grad():
            expected = module(x)
        actual = _run_with_mlx(module, (x,))

        torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-5)

    def test_separated_advanced_indices_keep_broadcast_dims_front(self):
        # Index tensors on axes 0 and 2 (separated by a full slice on axis 1)
        # force PyTorch to place the broadcast dims first in the output.
        x = torch.arange(2 * 3 * 5, dtype=torch.float32).reshape(2, 3, 5)
        idx0 = torch.tensor([[0], [1]], dtype=torch.long)
        idx2 = torch.tensor([[0, 2, 4]], dtype=torch.long)
        module = _SeparatedAdvancedIndexModel()

        with torch.no_grad():
            expected = module(x, idx0, idx2)
        actual = _run_with_mlx(module, (x, idx0, idx2))

        torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-5)


if __name__ == "__main__":
unittest.main()
Loading
Loading