Skip to content

Commit 2208e9b

Browse files
hwchen2017loadams
authored andcommitted
Cross layer overlapping for domino (deepspeedai#7178)
1. Add implementation for cross layer communication overlapping to achieve communication "free". 2. Optimize the implementation for communication overlapping within transformer layer. Signed-off-by: Hongwei Chen <hongweichen@microsoft.com> Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Signed-off-by: yisheng <yi.sheng@intel.com>
1 parent 1b0f96f commit 2208e9b

2 files changed

Lines changed: 489 additions & 261 deletions

File tree

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# DeepSpeed Team
5+
6+
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/23.08/megatron/core/tensor_parallel/layers.py
7+
8+
import torch
9+
from torch.nn.parameter import Parameter
10+
import torch.nn.functional as F
11+
from deepspeed.accelerator import get_accelerator
12+
import deepspeed.comm as dist
13+
from typing import Callable
14+
15+
TP_group = None
16+
17+
18+
class DominoAsyncColumnParallelLinearImpl(torch.autograd.Function):
19+
20+
@staticmethod
21+
def forward(ctx, inp, weight, bias, handle_dic, h_id): # inp: (b, s, k), weight: (m, k), bias (m)
22+
ctx.save_for_backward(inp, weight, bias)
23+
ctx.handle_dic = handle_dic
24+
ctx.h_id = h_id
25+
output = torch.matmul(inp, weight.t()) # (b, s, k) @ (k, m) -> (b, s, m)
26+
if bias is not None: # bias (m)
27+
output = output + bias
28+
return output
29+
30+
@staticmethod
31+
def backward(ctx, grad_output):
32+
inp, weight, bias = ctx.saved_tensors
33+
grad_input = grad_weight = grad_bias = None
34+
grad_input = torch.matmul(grad_output, weight) # (b, s, m) @ (m, k) -> (b, s, k)
35+
handle = dist.all_reduce(grad_input, group=TP_group, async_op=True)
36+
ctx.handle_dic[ctx.h_id] = handle
37+
grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]) # (b*s, m)
38+
39+
inp = inp.view(inp.shape[0] * inp.shape[1], inp.shape[2]) # (b*s, k)
40+
grad_weight = torch.matmul(grad_output.t(), inp) # (m, b*s) @ (b*s, k) -> (m, k)
41+
42+
if bias is not None:
43+
grad_bias = grad_output.sum(dim=0) # (b*s, m) -> (m)
44+
return grad_input, grad_weight, grad_bias, None, None
45+
46+
47+
class DominoAsyncColumnParallelLinear(torch.nn.Module):
48+
49+
def __init__(self,
50+
input_size,
51+
output_size,
52+
_tp_group,
53+
config,
54+
init_method: Callable,
55+
bias=True,
56+
skip_bias_add=False):
57+
super(DominoAsyncColumnParallelLinear, self).__init__()
58+
59+
self.skip_bias_add = skip_bias_add
60+
61+
global TP_group
62+
if TP_group == None:
63+
TP_group = _tp_group
64+
65+
self.weight = Parameter(
66+
torch.empty(
67+
output_size,
68+
input_size,
69+
device=get_accelerator().current_device_name(),
70+
dtype=config.params_dtype,
71+
))
72+
if config.perform_initialization:
73+
init_method(self.weight)
74+
75+
if bias:
76+
self.bias = Parameter(
77+
torch.empty(output_size, device=get_accelerator().current_device_name(), dtype=config.params_dtype))
78+
79+
if config.perform_initialization:
80+
with torch.no_grad():
81+
self.bias.zero_()
82+
else:
83+
self.register_parameter('bias', None)
84+
85+
def forward(self, input_: torch.Tensor, handle_dic, h_id):
86+
87+
bias = self.bias if not self.skip_bias_add else None
88+
89+
output = DominoAsyncColumnParallelLinearImpl.apply(input_, self.weight, bias, handle_dic, h_id)
90+
91+
output_bias = self.bias if self.skip_bias_add else None
92+
return output, output_bias
93+
94+
95+
class RowParallelLinearNoComm(torch.nn.Module):
96+
97+
def __init__(
98+
self,
99+
input_size: int,
100+
output_size: int,
101+
config,
102+
init_method: Callable,
103+
bias: bool = True,
104+
stride: int = 1,
105+
skip_bias_add: bool = False,
106+
):
107+
super(RowParallelLinearNoComm, self).__init__()
108+
109+
self.skip_bias_add = skip_bias_add
110+
111+
self.weight = Parameter(
112+
torch.empty(
113+
output_size,
114+
input_size,
115+
device=get_accelerator().current_device_name(),
116+
dtype=config.params_dtype,
117+
))
118+
if config.perform_initialization:
119+
init_method(self.weight)
120+
if bias:
121+
self.bias = Parameter(
122+
torch.empty(
123+
output_size,
124+
device=get_accelerator().current_device_name(),
125+
dtype=config.params_dtype,
126+
))
127+
128+
if config.perform_initialization:
129+
with torch.no_grad():
130+
self.bias.zero_()
131+
else:
132+
self.register_parameter('bias', None)
133+
134+
def forward(self, input_):
135+
bias = self.bias if not self.skip_bias_add else None
136+
137+
output = F.linear(input_, self.weight, bias)
138+
139+
output_bias = self.bias if self.skip_bias_add else None
140+
return output, output_bias

0 commit comments

Comments
 (0)