|
100 | 100 | from megatron.bridge.utils.common_utils import get_world_size_safe, print_rank_0 |
101 | 101 |
|
102 | 102 |
|
| 103 | +# For Optimizer CUDA graph support |
| 104 | +try: |
| 105 | + from megatron.core.optimizer.optimizer_cuda_graph import OptimizerCudaGraphWrapper |
| 106 | + |
| 107 | + HAS_OPTIMIZER_CUDA_GRAPH = True |
| 108 | +except ImportError: |
| 109 | + HAS_OPTIMIZER_CUDA_GRAPH = False |
103 | 110 | # For Paged Stashing support |
104 | 111 | try: |
105 | 112 | from megatron.core.transformer.moe.paged_stash import PagedStashRunner |
@@ -303,6 +310,12 @@ def train( |
303 | 310 | forward_backward_func = FullCudaGraphWrapper( |
304 | 311 | forward_backward_func, cuda_graph_warmup_steps=config.model.cuda_graph_warmup_steps |
305 | 312 | ) |
| 313 | + |
| 314 | + if config.optimizer.optimizer_cuda_graph and HAS_OPTIMIZER_CUDA_GRAPH: |
| 315 | + optimizer.step = OptimizerCudaGraphWrapper( |
| 316 | + optimizer.step, cuda_graph_warmup_steps=config.model.cuda_graph_warmup_steps |
| 317 | + ) |
| 318 | + |
306 | 319 | # Wrap model with PagedStashRunner when moe_expert_rank_capacity_factor padding is enabled. |
307 | 320 | # PagedStashRunner is responsible for detecting overflow and re-running iteration in eager-mode without padding. |
308 | 321 | if HAS_PAGED_STASHING and config.model.moe_expert_rank_capacity_factor is not None: |
@@ -1581,6 +1594,11 @@ def _delete_cuda_graphs(cuda_graph_helper: TECudaGraphHelper): |
1581 | 1594 | if "training" in FullCudaGraphWrapper.cuda_graph: |
1582 | 1595 | del FullCudaGraphWrapper.cuda_graph["training"] |
1583 | 1596 |
|
| 1597 | + # Explicitly delete optimizer CUDA graph |
| 1598 | + if HAS_OPTIMIZER_CUDA_GRAPH and OptimizerCudaGraphWrapper.cuda_graph is not None: |
| 1599 | + del OptimizerCudaGraphWrapper.cuda_graph |
| 1600 | + OptimizerCudaGraphWrapper.cuda_graph = None |
| 1601 | + |
1584 | 1602 | # Cleanup CUDA graphs object for partial Cuda-graphs (implemented in TransformerEngine) |
1585 | 1603 | if cuda_graph_helper is not None: |
1586 | 1604 | cuda_graph_helper.delete_cuda_graphs() |
|
0 commit comments