|
| 1 | +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. |
| 2 | +# |
| 3 | +# This source code is licensed under the BSD license found in the |
| 4 | +# LICENSE file in the root directory of this source tree. |
| 5 | + |
| 6 | +import dataclasses |
| 7 | +from contextlib import contextmanager |
| 8 | +from copy import deepcopy |
| 9 | +from functools import partial |
| 10 | +from typing import Any |
| 11 | + |
| 12 | +import torch |
| 13 | +import torch.fx.node |
| 14 | +import torch.utils._pytree as pytree |
| 15 | +from torch._functorch._aot_autograd.descriptors import AOTOutput |
| 16 | +from torch._functorch.partitioners import _extract_graph_with_inputs_outputs |
| 17 | + |
| 18 | +from .graph_pp_utils import ( |
| 19 | + find_last_all_gather_in_chain, |
| 20 | + find_last_non_view_node_in_chain, |
| 21 | + find_last_user_in_wait_chain, |
| 22 | + is_reduce_scatter_tensor, |
| 23 | + is_wait_tensor, |
| 24 | +) |
| 25 | + |
| 26 | + |
| 27 | +@contextmanager |
| 28 | +def exclude_from_fx_side_effectful(exclude_vals: set[Any]): |
| 29 | + original_val = torch.fx.node._side_effectful_functions.copy() |
| 30 | + try: |
| 31 | + torch.fx.node._side_effectful_functions -= exclude_vals |
| 32 | + yield |
| 33 | + finally: |
| 34 | + torch.fx.node._side_effectful_functions.clear() |
| 35 | + torch.fx.node._side_effectful_functions.update(original_val) |
| 36 | + |
| 37 | + |
| 38 | +exclude_wait_from_fx_side_effectful = partial( |
| 39 | + exclude_from_fx_side_effectful, |
| 40 | + { |
| 41 | + torch.ops._c10d_functional.wait_tensor, |
| 42 | + torch.ops._c10d_functional.wait_tensor.default, |
| 43 | + }, |
| 44 | +) |
| 45 | + |
| 46 | + |
| 47 | +@dataclasses.dataclass(frozen=True) |
| 48 | +class PrefetchOutput(AOTOutput): |
| 49 | + pass |
| 50 | + |
| 51 | + |
| 52 | +@dataclasses.dataclass(frozen=True) |
| 53 | +class EpilogueInput(AOTOutput): |
| 54 | + pass |
| 55 | + |
| 56 | + |
| 57 | +def split_fsdp_prefetch( |
| 58 | + gm: torch.fx.GraphModule, |
| 59 | + num_params: int, |
| 60 | +) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]: |
| 61 | + g = deepcopy(gm.graph) |
| 62 | + all_g_ins = g.find_nodes(op="placeholder") |
| 63 | + param_g_ins = all_g_ins[:num_params] |
| 64 | + rem_g_ins = all_g_ins[num_params:] |
| 65 | + |
| 66 | + prefetch_g_outs_map = [] |
| 67 | + |
| 68 | + for param_g_in in param_g_ins: |
| 69 | + # 1. Find last all_gather from each placeholder |
| 70 | + last_ag_node = find_last_all_gather_in_chain(param_g_in) |
| 71 | + if last_ag_node is None: |
| 72 | + prefetch_g_outs_map.append(param_g_in) |
| 73 | + else: |
| 74 | + # 2. Find last wait_tensor from last all_gather |
| 75 | + last_ag_wait_node = next(iter(last_ag_node.users)) |
| 76 | + assert is_wait_tensor(last_ag_wait_node) |
| 77 | + |
| 78 | + # 3. Continue the linear chain from the last wait_tensor |
| 79 | + last_wait_chain_user = find_last_user_in_wait_chain(last_ag_wait_node) |
| 80 | + |
| 81 | + # 4. Get the last non-view node in the wait chain |
| 82 | + last_non_view_wait_chain_user = find_last_non_view_node_in_chain( |
| 83 | + last_wait_chain_user |
| 84 | + ) |
| 85 | + |
| 86 | + prefetch_g_outs_map.append(last_non_view_wait_chain_user) |
| 87 | + |
| 88 | + prefetch_g_outs = prefetch_g_outs_map |
| 89 | + prefetch_g_outs_descs: list[AOTOutput] = [ |
| 90 | + PrefetchOutput() for _ in range(len(prefetch_g_outs)) |
| 91 | + ] |
| 92 | + g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output"))) |
| 93 | + g_outs_descs = pytree.arg_tree_leaves( |
| 94 | + next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs)) |
| 95 | + ) |
| 96 | + with exclude_wait_from_fx_side_effectful(): |
| 97 | + prefetch_g = _extract_graph_with_inputs_outputs( |
| 98 | + g, |
| 99 | + param_g_ins, |
| 100 | + prefetch_g_outs, |
| 101 | + prefetch_g_outs_descs, |
| 102 | + ignore_must_be_in_fw_bw=True, |
| 103 | + ) |
| 104 | + |
| 105 | + main_g = _extract_graph_with_inputs_outputs( |
| 106 | + g, |
| 107 | + prefetch_g_outs + rem_g_ins, |
| 108 | + g_outs, |
| 109 | + g_outs_descs, |
| 110 | + ignore_must_be_in_fw_bw=True, |
| 111 | + ) |
| 112 | + prefetch_gm = torch.fx._lazy_graph_module._make_graph_module(gm, prefetch_g) |
| 113 | + main_gm = torch.fx._lazy_graph_module._make_graph_module(gm, main_g) |
| 114 | + return prefetch_gm, main_gm |
| 115 | + |
| 116 | + |
| 117 | +def split_fsdp_reduce_scatters_epilogue( |
| 118 | + gm: torch.fx.GraphModule, |
| 119 | + num_grads: int, |
| 120 | +) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]: |
| 121 | + g = deepcopy(gm.graph) |
| 122 | + g_ins = g.find_nodes(op="placeholder") |
| 123 | + g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output"))) |
| 124 | + grad_outs = g_outs[:num_grads] |
| 125 | + rem_g_outs = g_outs[num_grads:] |
| 126 | + out_descs = pytree.arg_tree_leaves( |
| 127 | + next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(grad_outs)) |
| 128 | + ) |
| 129 | + grad_outs_descs = out_descs[:num_grads] |
| 130 | + rem_g_outs_descs = out_descs[num_grads:] |
| 131 | + |
| 132 | + grad_outs_map = [] |
| 133 | + for grad_out in grad_outs: |
| 134 | + n = grad_out |
| 135 | + earliest_rs = None |
| 136 | + while n is not None: |
| 137 | + if len(n.all_input_nodes) != 1: |
| 138 | + break |
| 139 | + n_in = n.all_input_nodes[0] |
| 140 | + if len(n_in.users) > 1: |
| 141 | + break |
| 142 | + prev_n = n |
| 143 | + n = n_in |
| 144 | + # Maybe we also need to track all_reduce? |
| 145 | + if is_reduce_scatter_tensor(prev_n): |
| 146 | + # In AP for mesh dim > 1 |
| 147 | + # The reduction of gradients happen in multiple steps |
| 148 | + earliest_rs = n |
| 149 | + if earliest_rs is not None: |
| 150 | + grad_outs_map.append(earliest_rs) |
| 151 | + else: |
| 152 | + grad_outs_map.append(grad_out) |
| 153 | + |
| 154 | + epi_g_ins = grad_outs_map |
| 155 | + epi_g_ins_descs: list[AOTOutput] = [EpilogueInput() for _ in range(len(epi_g_ins))] |
| 156 | + |
| 157 | + with exclude_wait_from_fx_side_effectful(): |
| 158 | + main_g = _extract_graph_with_inputs_outputs( |
| 159 | + g, |
| 160 | + g_ins, |
| 161 | + epi_g_ins + rem_g_outs, |
| 162 | + epi_g_ins_descs + rem_g_outs_descs, |
| 163 | + ignore_must_be_in_fw_bw=True, |
| 164 | + ) |
| 165 | + epi_g = _extract_graph_with_inputs_outputs( |
| 166 | + g, |
| 167 | + epi_g_ins, |
| 168 | + grad_outs, |
| 169 | + grad_outs_descs, |
| 170 | + ignore_must_be_in_fw_bw=True, |
| 171 | + ) |
| 172 | + epi_gm = torch.fx._lazy_graph_module._make_graph_module(gm, epi_g) |
| 173 | + main_gm = torch.fx._lazy_graph_module._make_graph_module(gm, main_g) |
| 174 | + return main_gm, epi_gm |
0 commit comments