
Commit e6c5096

[graph_trainer][aot_fx_trace][graph_pp] split_fsdp_collectives
stack-info: PR: #3088, branch: sanketpurandare/stack/7
1 parent 8f71d5a commit e6c5096

1 file changed: 174 additions & 0 deletions

@@ -0,0 +1,174 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import dataclasses
from contextlib import contextmanager
from copy import deepcopy
from functools import partial
from typing import Any

import torch
import torch.fx.node
import torch.utils._pytree as pytree
from torch._functorch._aot_autograd.descriptors import AOTOutput
from torch._functorch.partitioners import _extract_graph_with_inputs_outputs

from .graph_pp_utils import (
    find_last_all_gather_in_chain,
    find_last_non_view_node_in_chain,
    find_last_user_in_wait_chain,
    is_reduce_scatter_tensor,
    is_wait_tensor,
)

@contextmanager
def exclude_from_fx_side_effectful(exclude_vals: set[Any]):
    original_val = torch.fx.node._side_effectful_functions.copy()
    try:
        torch.fx.node._side_effectful_functions -= exclude_vals
        yield
    finally:
        torch.fx.node._side_effectful_functions.clear()
        torch.fx.node._side_effectful_functions.update(original_val)


exclude_wait_from_fx_side_effectful = partial(
    exclude_from_fx_side_effectful,
    {
        torch.ops._c10d_functional.wait_tensor,
        torch.ops._c10d_functional.wait_tensor.default,
    },
)
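# FX registers `wait_tensor` as side-effectful, so dead-code elimination
# will not normally prune it. The graph splits below rely on DCE to drop
# nodes that fall outside each extracted subgraph, so `wait_tensor` is
# temporarily removed from the registry while extraction runs.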


@dataclasses.dataclass(frozen=True)
class PrefetchOutput(AOTOutput):
    pass


@dataclasses.dataclass(frozen=True)
class EpilogueInput(AOTOutput):
    pass


def split_fsdp_prefetch(
    gm: torch.fx.GraphModule,
    num_params: int,
) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
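    """Split ``gm`` into a parameter-prefetch graph and a main graph.

    The first ``num_params`` placeholders are treated as parameters. For
    each one, the cut point is the last non-view node in the linear chain
    that runs through its last all-gather and the matching wait_tensor;
    parameters with no all-gather chain are passed through unchanged. The
    prefetch graph produces these unsharded values, and the main graph
    consumes them together with the remaining inputs to compute the
    original outputs. Returns ``(prefetch_gm, main_gm)``.
    """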
    g = deepcopy(gm.graph)
    all_g_ins = g.find_nodes(op="placeholder")
    param_g_ins = all_g_ins[:num_params]
    rem_g_ins = all_g_ins[num_params:]

    prefetch_g_outs_map = []

    for param_g_in in param_g_ins:
        # 1. Find the last all_gather in the chain from each placeholder
        last_ag_node = find_last_all_gather_in_chain(param_g_in)
        if last_ag_node is None:
            prefetch_g_outs_map.append(param_g_in)
        else:
            # 2. Find the wait_tensor of the last all_gather
            last_ag_wait_node = next(iter(last_ag_node.users))
            assert is_wait_tensor(last_ag_wait_node)

            # 3. Continue the linear chain from the last wait_tensor
            last_wait_chain_user = find_last_user_in_wait_chain(last_ag_wait_node)

            # 4. Get the last non-view node in the wait chain
            last_non_view_wait_chain_user = find_last_non_view_node_in_chain(
                last_wait_chain_user
            )

            prefetch_g_outs_map.append(last_non_view_wait_chain_user)

    prefetch_g_outs = prefetch_g_outs_map
    prefetch_g_outs_descs: list[AOTOutput] = [
        PrefetchOutput() for _ in range(len(prefetch_g_outs))
    ]
    g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output")))
    g_outs_descs = pytree.arg_tree_leaves(
        next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs))
    )
    with exclude_wait_from_fx_side_effectful():
        prefetch_g = _extract_graph_with_inputs_outputs(
            g,
            param_g_ins,
            prefetch_g_outs,
            prefetch_g_outs_descs,
            ignore_must_be_in_fw_bw=True,
        )

        main_g = _extract_graph_with_inputs_outputs(
            g,
            prefetch_g_outs + rem_g_ins,
            g_outs,
            g_outs_descs,
            ignore_must_be_in_fw_bw=True,
        )
    prefetch_gm = torch.fx._lazy_graph_module._make_graph_module(gm, prefetch_g)
    main_gm = torch.fx._lazy_graph_module._make_graph_module(gm, main_g)
    return prefetch_gm, main_gm


def split_fsdp_reduce_scatters_epilogue(
    gm: torch.fx.GraphModule,
    num_grads: int,
) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
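    """Split ``gm`` into a main graph and a gradient reduce-scatter epilogue.

    The first ``num_grads`` outputs are treated as gradients. For each one,
    the graph is cut just before the reduce_scatter (if any) on its linear
    chain, so the epilogue holds the reduce-scatter and everything after
    it, while the main graph produces the pre-reduction values plus the
    remaining outputs. Returns ``(main_gm, epi_gm)``.
    """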
    g = deepcopy(gm.graph)
    g_ins = g.find_nodes(op="placeholder")
    g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output")))
    grad_outs = g_outs[:num_grads]
    rem_g_outs = g_outs[num_grads:]
    out_descs = pytree.arg_tree_leaves(
        # The default must cover every output, not just the gradients,
        # since it is sliced into grad and non-grad descriptors below.
        next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs))
    )
    grad_outs_descs = out_descs[:num_grads]
    rem_g_outs_descs = out_descs[num_grads:]

    grad_outs_map = []
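    # Walk upward from each gradient output along its single-input,
    # single-user chain; if a reduce_scatter is found, remember the node
    # feeding it so the cut puts the reduction itself into the epilogue.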
    for grad_out in grad_outs:
        n = grad_out
        earliest_rs = None
        while n is not None:
            if len(n.all_input_nodes) != 1:
                break
            n_in = n.all_input_nodes[0]
            if len(n_in.users) > 1:
                break
            prev_n = n
            n = n_in
            # Maybe we also need to track all_reduce?
            if is_reduce_scatter_tensor(prev_n):
                # In AP, for mesh dim > 1, the reduction of gradients
                # happens in multiple steps
                earliest_rs = n
        if earliest_rs is not None:
            grad_outs_map.append(earliest_rs)
        else:
            grad_outs_map.append(grad_out)

    epi_g_ins = grad_outs_map
    epi_g_ins_descs: list[AOTOutput] = [EpilogueInput() for _ in range(len(epi_g_ins))]

    with exclude_wait_from_fx_side_effectful():
        main_g = _extract_graph_with_inputs_outputs(
            g,
            g_ins,
            epi_g_ins + rem_g_outs,
            epi_g_ins_descs + rem_g_outs_descs,
            ignore_must_be_in_fw_bw=True,
        )
        epi_g = _extract_graph_with_inputs_outputs(
            g,
            epi_g_ins,
            grad_outs,
            grad_outs_descs,
            ignore_must_be_in_fw_bw=True,
        )
    epi_gm = torch.fx._lazy_graph_module._make_graph_module(gm, epi_g)
    main_gm = torch.fx._lazy_graph_module._make_graph_module(gm, main_g)
    return main_gm, epi_gm
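
A minimal usage sketch, not part of this commit: assuming the trainer holds separate forward and backward graph modules (`fw_gm`, `bw_gm`) and knows `num_params` and `num_grads`, the two helpers could be applied as follows. All four names are hypothetical stand-ins for values produced elsewhere in the graph trainer.

# Hypothetical driver code; fw_gm, bw_gm, num_params, num_grads are assumed.
prefetch_gm, fw_main_gm = split_fsdp_prefetch(fw_gm, num_params)
bw_main_gm, epi_gm = split_fsdp_reduce_scatters_epilogue(bw_gm, num_grads)

# prefetch_gm : parameter all-gathers, can be overlapped with earlier compute
# fw_main_gm  : forward compute, takes the unsharded parameters as inputs
# bw_main_gm  : backward compute, emits pre-reduction gradients
# epi_gm      : gradient reduce-scatters, can be deferred and overlapped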

0 commit comments
