
Commit 44c51e3

Fix Adam subgroup inconsistency (#7982)
Fix CPUAdam same-step subgroup drift in ZeRO-3 (#7819)

This PR ports the fix from #7820 to the latest DeepSpeed version. It makes `Adam_Optimizer::IncrementStep` idempotent for repeated calls at the same logical step and avoids unnecessary recomputation when the step has not changed.

ZeRO-3/SuperOffload can invoke multiple subgroup updates within a single logical step on a shared native optimizer object. The previous logic mixed multiply and recompute paths, producing non-bit-identical bias-correction metadata across subgroup calls.

This change aligns the step-transition logic in both the CPU and XPU headers, clarifies first-step and non-sequential-step behavior, and prevents unnecessary work on repeated same-step updates. It also adds CPUAdam regression tests covering subgroup-style repeated same-step updates through both `step_subgroup()` and `step()` with parameter swapping.

Signed-off-by: st_bang <st.bang@dgist.ac.kr>
1 parent 077bff5 · commit 44c51e3

3 files changed

Lines changed: 72 additions & 10 deletions
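The drift described in the commit message comes from maintaining the bias-correction terms in two numerically different ways. The standalone sketch below (illustrative only, not part of this commit) shows that a running float product of beta1 and a fresh `std::pow` recomputation are not guaranteed to agree bit-for-bit, which is why subgroups that took different paths at the same logical step could diverge:

// Illustrative sketch: "multiply" path vs. "recompute" path for beta1^step.
#include <cmath>
#include <cstdio>

int main()
{
    const float beta1 = 0.9f;
    float running = 1.0f;  // beta1^step maintained by repeated multiplication
    for (int step = 1; step <= 20; ++step) {
        running *= beta1;                                // "multiply" path
        const float recomputed = std::pow(beta1, step);  // "recompute" path
        // Exact comparison is intentional: the point is bit-identity, not closeness.
        if (running != recomputed) {
            std::printf("step %2d: multiply=%.9g recompute=%.9g (differ)\n",
                        step, running, recomputed);
        }
    }
    return 0;
}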


csrc/includes/cpu_adam.h

Lines changed: 8 additions & 5 deletions
@@ -63,14 +63,17 @@ class Adam_Optimizer {
             _betta1_t = std::pow(_betta1, step);
             _betta2_t = std::pow(_betta2, step);
         } else {
-            _step++;
-            if (_step != step) {
+            if (step == _step + 1) {  // first optimizer step increase
+                _step++;
+                _betta1_t *= _betta1;
+                _betta2_t *= _betta2;
+            } else if (step ==
+                       _step) {  // no need to update step; beta1_t and beta2_t already updated
+                return;
+            } else {  // support step increase not equal to 1
                 _betta1_t = std::pow(_betta1, step);
                 _betta2_t = std::pow(_betta2, step);
                 _step = step;
-            } else {
-                _betta1_t *= _betta1;
-                _betta2_t *= _betta2;
             }
         }
     }
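For readers skimming the hunk, the new transition logic reduces to three cases. The condensed standalone restatement below is illustrative only: the header above also contains an outer branch (partially visible in the context lines) that fully recomputes the terms when the betas themselves change, and the member names here merely mirror the real class.

// Condensed restatement of the new step-transition logic, with a tiny check
// that a repeated same-step call is a no-op.
#include <cassert>
#include <cmath>
#include <cstddef>

struct StepState {
    size_t _step = 0;
    float _betta1 = 0.9f, _betta2 = 0.999f;
    float _betta1_t = 1.0f, _betta2_t = 1.0f;

    void IncrementStep(size_t step)
    {
        if (step == _step + 1) {  // sequential advance: cheap incremental update
            _step++;
            _betta1_t *= _betta1;
            _betta2_t *= _betta2;
        } else if (step == _step) {  // repeated call at the same step (another subgroup): no-op
            return;
        } else {  // non-sequential step: recompute from scratch
            _betta1_t = std::pow(_betta1, step);
            _betta2_t = std::pow(_betta2, step);
            _step = step;
        }
    }
};

int main()
{
    StepState s;
    s.IncrementStep(1);  // first subgroup at logical step 1
    const float b1 = s._betta1_t, b2 = s._betta2_t;
    s.IncrementStep(1);  // second subgroup at the same logical step
    assert(s._step == 1 && s._betta1_t == b1 && s._betta2_t == b2);
    return 0;
}

Because the second call with an unchanged step returns immediately, every subgroup update within one logical step observes identical bias-correction terms, which is the invariant the new tests assert.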

csrc/xpu/includes/cpu_adam.h

Lines changed: 8 additions & 5 deletions
@@ -69,14 +69,17 @@ class Adam_Optimizer {
             _betta1_t = std::pow(_betta1, step);
             _betta2_t = std::pow(_betta2, step);
         } else {
-            _step++;
-            if (_step != step) {
+            if (step == _step + 1) {  // first optimizer step increase
+                _step++;
+                _betta1_t *= _betta1;
+                _betta2_t *= _betta2;
+            } else if (step ==
+                       _step) {  // no need to update step; beta1_t and beta2_t already updated
+                return;
+            } else {  // support step increase not equal to 1
                 _betta1_t = std::pow(_betta1, step);
                 _betta2_t = std::pow(_betta2, step);
                 _step = step;
-            } else {
-                _betta1_t *= _betta1;
-                _betta2_t *= _betta2;
             }
         }
     }

tests/unit/ops/adam/test_cpu_adam.py

Lines changed: 56 additions & 0 deletions
@@ -312,3 +312,59 @@ def test_multiple_subgroups(self):
         optimizer.rollback_subgroup(0)
         assert optimizer.state[0]['step'] == 1, "Subgroup 0 step count should be decremented"
         assert optimizer.state[1]['step'] == 1, "Subgroup 1 step count should be unchanged"
+
+    def test_step_subgroup_same_step_idempotent_across_subgroups(self):
+        """Repeated same-step subgroup updates should remain bit-identical."""
+        from deepspeed.ops.adam import DeepSpeedCPUAdam
+
+        model_size = 128
+        steps = 4
+        base = torch.randn(model_size, device='cpu', dtype=torch.float32)
+        param_a = torch.nn.Parameter(base.clone())
+        param_b = torch.nn.Parameter(base.clone())
+
+        optimizer = DeepSpeedCPUAdam([param_a])
+        for logical_step in range(1, steps + 1):
+            grad = torch.randn(model_size, device='cpu', dtype=torch.float32)
+
+            optimizer.param_groups[0]['params'] = [param_a]
+            param_a.grad = grad.clone()
+            optimizer.step_subgroup(0)
+
+            optimizer.param_groups[0]['params'] = [param_b]
+            param_b.grad = grad.clone()
+            optimizer.step_subgroup(1)
+
+            assert optimizer.state[0]['step'] == logical_step
+            assert optimizer.state[1]['step'] == logical_step
+            assert torch.equal(param_a.data, param_b.data)
+            assert torch.equal(optimizer.state[0]['exp_avg'], optimizer.state[1]['exp_avg'])
+            assert torch.equal(optimizer.state[0]['exp_avg_sq'], optimizer.state[1]['exp_avg_sq'])
+
+    def test_step_same_step_idempotent_across_param_keys(self):
+        """Repeated optimizer.step() with swapped param keys should be deterministic."""
+        from deepspeed.ops.adam import DeepSpeedCPUAdam
+
+        model_size = 128
+        steps = 4
+        base = torch.randn(model_size, device='cpu', dtype=torch.float32)
+        param_a = torch.nn.Parameter(base.clone())
+        param_b = torch.nn.Parameter(base.clone())
+
+        optimizer = DeepSpeedCPUAdam([param_a])
+        for logical_step in range(1, steps + 1):
+            grad = torch.randn(model_size, device='cpu', dtype=torch.float32)
+
+            optimizer.param_groups[0]['params'] = [param_a]
+            param_a.grad = grad.clone()
+            optimizer.step()
+
+            optimizer.param_groups[0]['params'] = [param_b]
+            param_b.grad = grad.clone()
+            optimizer.step()
+
+            assert optimizer.state[param_a]['step'] == logical_step
+            assert optimizer.state[param_b]['step'] == logical_step
+            assert torch.equal(param_a.data, param_b.data)
+            assert torch.equal(optimizer.state[param_a]['exp_avg'], optimizer.state[param_b]['exp_avg'])
+            assert torch.equal(optimizer.state[param_a]['exp_avg_sq'], optimizer.state[param_b]['exp_avg_sq'])
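For reference, these regression tests can typically be run from a DeepSpeed source checkout with pytest, e.g. `pytest tests/unit/ops/adam/test_cpu_adam.py -k idempotent` (assuming an environment where the CPUAdam extension can be built or is prebuilt; the exact invocation is not specified in this commit).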
