Refactor pipeline parallel helpers for graph PP reuse#2724

Open
sanketpurandare wants to merge 1 commit into main from sanketpurandare/stack/3

Conversation

Contributor

@sanketpurandare sanketpurandare commented Mar 27, 2026

Stacked PRs:


Refactor pipeline parallel helpers for graph PP reuse

Extract pipeline metadata, module splitting, and PP rank-to-stage mapping from pipeline_llm so graph PP can reuse the underlying setup logic without duplicating it. Add backward_requires_autograd to the schedule builder for graph PP, which runs explicit backward graphs instead of autograd. Existing eager PP behavior is unchanged.

How does this refactoring help?

This is the call chain:
pipeline_llm
-> _get_pipeline_metadata
-> _generate_llm_fqn_per_model_part
-> _build_get_mesh_callback
-> _pipeline_module_split
-> _get_pp_rank_to_stage_indices_mapping
-> _split_module
-> PipelineStage
-> parallelize_fn
-> _build_pipeline_schedule

GraphPP will define:
graph_pipeline_llm
-> _get_pipeline_metadata
-> _generate_llm_fqn_per_model_part
-> _build_get_mesh_callback
-> _pipeline_module_split
-> _get_pp_rank_to_stage_indices_mapping
-> _split_module
-> GraphPipelineStage
-> parallelize_fn
-> _build_pipeline_schedule
-> GraphPPRunner

So by overriding only _pipeline_module_split and reusing all the other helpers, we minimize redundancy and maximize code reuse.
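For illustration, the factoring described above can be sketched as follows. This is a hypothetical sketch of the pattern only: the helper bodies are stand-ins, not torchtitan's actual code, and the `stage_cls` parameter is an assumption about how the split step could be parameterized.

```python
# Hypothetical sketch of the reuse pattern: shared private helpers plus a
# per-variant module-split step. All bodies are illustrative stand-ins.

class PipelineStage:
    def __init__(self, module):
        self.module = module

class GraphPipelineStage(PipelineStage):
    # Graph PP variant: runs captured forward/backward graphs.
    pass

def _build_pipeline_schedule(stages: list) -> dict:
    # Stand-in for schedule construction shared by both variants.
    return {"stages": stages}

def _pipeline_module_split(model_parts, stage_cls=PipelineStage):
    # The one step each variant overrides: which stage type wraps each part.
    return [stage_cls(part) for part in model_parts]

def pipeline_llm(model_parts):
    # Eager PP entrypoint: reuses all helpers with the default stage type.
    return _build_pipeline_schedule(_pipeline_module_split(model_parts))

def graph_pipeline_llm(model_parts):
    # Graph PP entrypoint: same helpers, only the split step differs.
    return _build_pipeline_schedule(
        _pipeline_module_split(model_parts, stage_cls=GraphPipelineStage)
    )
```

Either way the split step is parameterized, both entrypoints share everything else in the chain.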

Keep pipeline_llm as the only public entrypoint exported by torchtitan.distributed.pipeline_parallel. Make build_pipeline_schedule, generate_llm_fqn_per_model_part, and pipeline_module_split private because they are implementation details with narrower contracts: schedule construction depends on the current PP config shape, LLM FQN generation encodes TorchTitan-specific module naming heuristics, and module splitting assumes models tolerate deleted or empty layer containers.

Update internal and experiment callsites to use the private helper names directly where reuse is still needed. This keeps the reusable code centralized while avoiding accidentally blessing those helpers as stable public API.

@meta-cla meta-cla bot added the CLA Signed label Mar 27, 2026
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/3 branch from 2d2fb54 to bd097f8 on March 27, 2026 00:49
@sanketpurandare sanketpurandare changed the base branch from sanketpurandare/stack/2 to main March 27, 2026 01:13
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/3 branch from bd097f8 to d3d14c2 on March 27, 2026 01:14
@sanketpurandare sanketpurandare changed the base branch from main to sanketpurandare/stack/2 March 27, 2026 01:14
Contributor

@tianyu-l tianyu-l left a comment

Please add more detailed description about why making these changes.

Comment on lines 40 to 48
__all__ = [
"pipeline_llm",
"get_pipeline_metadata",
"build_pipeline_schedule",
"generate_llm_fqn_per_model_part",
"pipeline_module_split",
"split_module",
"get_pp_rank_to_stage_indices_mapping",
]
Contributor

do all these fields need to be public and exposed?

Contributor Author

Sorry this should have been marked as draft PR, it is not ready for review yet.

Contributor

question not answered

Contributor Author

I have updated the PR description to include the change description.

@tianyu-l tianyu-l requested a review from acisseJZhong March 27, 2026 02:17
@tianyu-l
Contributor

@claude review this PR


claude Bot commented Mar 27, 2026

Claude finished @tianyu-l's task in 3m 25s


Review of PR #2724: Refactor pipeline_parallel.py for graph PP reuse

Summary

This PR refactors pipeline_parallel.py by extracting three nested/inline functions into top-level, reusable public functions:

  1. get_pipeline_metadata — extracted from the top of pipeline_llm, computes num_virtual_stages, num_layers, input_weight, output_weight
  2. split_module — extracted from _build_stage_from_modules (nested inside pipeline_module_split), deep-copies and prunes the model to keep only specified modules
  3. get_pp_rank_to_stage_indices_mapping — extracted from _get_stage_indices (nested inside pipeline_module_split), maps PP rank to stage indices for loop/v-style schedules

Additionally, build_pipeline_schedule gains two new parameters: backward_requires_autograd and scale_grads (previously hardcoded to True and False respectively).

All three new functions and new __all__ exports are presumably intended for reuse by the graph PP experiments (PRs #2725–#2727 in this stack).
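For context, the rank-to-stage mapping in item 3 can be illustrated with a self-contained sketch. The function name and exact index formulas here are assumptions for illustration, not torchtitan's implementation; it assumes loop schedules stride stages across ranks and V-style schedules pair each rank with its mirror stage.

```python
# Illustrative sketch of rank-to-stage mapping for loop vs. V-style
# schedules, assuming num_stages is a multiple of pp_degree.
def rank_to_stages(pp_rank: int, pp_degree: int, style: str, num_stages: int) -> tuple[int, ...]:
    if num_stages % pp_degree != 0:
        raise ValueError(
            f"num_stages {num_stages} must be evenly divisible by pp_degree {pp_degree}"
        )
    stages_per_rank = num_stages // pp_degree
    if style == "loop":
        # Rank r owns stages r, r + pp_degree, r + 2 * pp_degree, ...
        return tuple(pp_rank + i * pp_degree for i in range(stages_per_rank))
    if style == "v":
        # V-style needs exactly 2 stages per rank: one on the way "down"
        # the V and its mirror on the way back "up".
        if stages_per_rank != 2:
            raise ValueError("V-style schedules need exactly 2 stages per rank")
        return (pp_rank, 2 * pp_degree - 1 - pp_rank)
    raise ValueError(f"unknown schedule style: {style}")
```

For pp_degree=4 and num_stages=8, rank 1 owns stages (1, 5) under a loop schedule and (1, 6) under a V-style schedule.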


Issues and Suggestions

1. pp_degree parameter lacks type annotation (get_pp_rank_to_stage_indices_mapping)

pipeline_parallel.py:466

def get_pp_rank_to_stage_indices_mapping(
    pp_rank: int,
    pp_degree,          # <-- missing type annotation
    pp_schedule: str,
    num_stages: int,
) -> tuple[int, ...]:

pp_degree should be annotated as int. Every other parameter in this function is typed.



2. assert used for user-facing validation in get_pp_rank_to_stage_indices_mapping

pipeline_parallel.py:486-488 and pipeline_parallel.py:493-495

Per the repo's code guidelines (CLAUDE.md): "assert only for internal invariants that indicate programmer error" and "ValueError for user-facing errors (bad config, invalid input)."

The two assert statements validate num_stages % pp_degree and stages_per_rank == 2. These were originally inside a private nested function (_get_stage_indices), where assert was arguably acceptable. Now that get_pp_rank_to_stage_indices_mapping is a public API, they should be ValueError raises, consistent with how pipeline_llm validates similar conditions (e.g. pipeline_parallel.py:173-180).

# Current (assert):
assert (
    num_stages % pp_degree == 0
), f"num_stages {num_stages} must be evenly divisible by pp_degree {pp_degree}"

# Should be:
if num_stages % pp_degree != 0:
    raise ValueError(
        f"num_stages {num_stages} must be evenly divisible by pp_degree {pp_degree}"
    )

Same applies to the stages_per_rank == 2 assertion on line 493.



3. split_module broadens isinstance check vs. original without explanation

pipeline_parallel.py:425-427

The original _build_stage_from_modules checked:

isinstance(module_value, (nn.ModuleDict, nn.ModuleList))

The new split_module checks:

isinstance(module_value, (nn.ModuleDict, nn.ModuleList, ModuleDict, ModuleList))

Where ModuleDict and ModuleList are from torchtitan.protocols.module. This is a behavioral change compared to the original code path. If this is intentional (to support torchtitan's custom protocol types), it should be called out in the PR description. If it's accidental, it should be reverted to match the original behavior.


4. build_pipeline_schedule docstring not updated for new parameters

pipeline_parallel.py:216-226

The docstring still only documents the original four parameters. The new backward_requires_autograd and scale_grads parameters are not documented:

    """Builds a pipeline schedule for the given job configuration and stages.

    Args:
        parallelism (ParallelismConfig): The parallelism configuration.
        local_batch_size (int): The local batch size for computing microbatches.
        stages (list[PipelineStage]): The stages to be scheduled.
        loss_fn (Callable): The loss function.
        # Missing: backward_requires_autograd, scale_grads

At minimum, add brief descriptions for both new parameters so downstream consumers understand what they control.



5. Agree with reviewer: consider whether all new exports need to be public

pipeline_parallel.py:40-48

Three new symbols are added to __all__: get_pipeline_metadata, split_module, get_pp_rank_to_stage_indices_mapping. As @tianyu-l noted, it's worth considering whether all of these truly need to be part of the public API.

split_module is the strongest candidate for public exposure — it's a general-purpose utility that could be independently useful. get_pp_rank_to_stage_indices_mapping also has a clear independent use case.

get_pipeline_metadata is more debatable — it returns a fairly specific tuple (num_virtual_stages, num_layers, input_weight, output_weight) that's tightly coupled to the pipeline_llm flow. If the graph PP experiment only needs to call it once in a similar orchestration function, consider whether the caller should just inline the logic or whether this really warrants a public function. Without seeing the consuming PRs (#2725–#2727), it's hard to judge definitively.


6. Stale blank line removal in imports

pipeline_parallel.py:24-25

The PR removes the blank line between the torch.distributed.pipelining.schedules imports and the torchtitan imports. This is a minor PEP 8 style regression — standard convention separates third-party/stdlib imports from local imports with a blank line. The original code had this correctly.


7. PR description needs more detail

Agreeing with the existing review comment: the PR description should explain why these functions are being extracted and what the intended reuse pattern is. Linking to or describing the consuming code in the stacked PRs (#2725–#2727) would help reviewers understand the motivation.


Correctness Assessment

The refactoring itself appears mechanically correct — the extracted functions preserve the same logic as the original inline/nested versions. The main pipeline_llm function's behavior should be identical before and after this change. The new backward_requires_autograd and scale_grads parameters default to the same values as the original hardcoded ones (True and False), so existing callers are unaffected.

One thing to verify: the split_module isinstance broadening (issue #3 above) is the only behavioral difference that could affect existing code paths.


@sanketpurandare sanketpurandare marked this pull request as draft March 27, 2026 02:21
@sanketpurandare sanketpurandare changed the base branch from sanketpurandare/stack/2 to main March 27, 2026 18:57
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/3 branch from d3d14c2 to 2fb3b75 on March 27, 2026 18:57
@sanketpurandare sanketpurandare changed the base branch from main to sanketpurandare/stack/2 March 27, 2026 18:57
@sanketpurandare sanketpurandare changed the base branch from sanketpurandare/stack/2 to main April 1, 2026 16:28
@sanketpurandare sanketpurandare changed the base branch from main to sanketpurandare/stack/2 April 1, 2026 16:28
@sanketpurandare sanketpurandare changed the base branch from sanketpurandare/stack/2 to main April 24, 2026 19:53
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/3 branch from 2fb3b75 to 36c1ede on April 24, 2026 19:53
@sanketpurandare sanketpurandare changed the base branch from main to sanketpurandare/stack/2 April 24, 2026 19:53
@sanketpurandare sanketpurandare marked this pull request as ready for review April 24, 2026 19:54
stages: list[PipelineStage],
loss_fn: Callable,
backward_requires_autograd: bool = True,
scale_grads: bool = False,
Contributor

why do you need to expose this? IMO we should deprecate this -- IIRC it's assuming all microbatches have the same number of valid tokens.

Contributor Author

scale_grads has been removed

local_batch_size: int,
stages: list[PipelineStage],
loss_fn: Callable,
backward_requires_autograd: bool = True,
Contributor

curious what this field is used for

Contributor Author

This is used when we want to run the backward actions for a microbatch without supplying a loss function to the PP runtime. Currently, if we don't supply a loss function, the PP runtime assumes there is no backward. But in graph PP the forward graph contains the loss function, so no explicit loss function is passed to the PP runtime.
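A hedged sketch of the distinction described here. The flag name follows the PR, but the body is an illustrative stand-in, not torchtitan's actual schedule builder:

```python
# Illustrative sketch of what backward_requires_autograd controls.
from typing import Callable, Optional

def build_schedule(
    loss_fn: Optional[Callable] = None,
    backward_requires_autograd: bool = True,
) -> dict:
    if backward_requires_autograd:
        # Eager PP: backward only happens if a loss_fn drives autograd.
        has_backward = loss_fn is not None
    else:
        # Graph PP: explicit backward graphs run even without a loss_fn,
        # since the captured forward graph already computes the loss.
        has_backward = True
    return {"has_backward": has_backward, "loss_fn": loss_fn}
```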

Comment on lines 40 to 48
__all__ = [
"pipeline_llm",
"get_pipeline_metadata",
"build_pipeline_schedule",
"generate_llm_fqn_per_model_part",
"pipeline_module_split",
"split_module",
"get_pp_rank_to_stage_indices_mapping",
]
Contributor

question not answered

@sanketpurandare sanketpurandare marked this pull request as draft April 30, 2026 18:15
@sanketpurandare sanketpurandare changed the base branch from sanketpurandare/stack/2 to main April 30, 2026 18:15
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/3 branch from 36c1ede to 4dd3fe7 on April 30, 2026 18:15
@sanketpurandare sanketpurandare changed the base branch from main to sanketpurandare/stack/2 April 30, 2026 18:16
@sanketpurandare sanketpurandare marked this pull request as ready for review April 30, 2026 18:16
@sanketpurandare sanketpurandare marked this pull request as draft April 30, 2026 18:40
@sanketpurandare sanketpurandare changed the base branch from sanketpurandare/stack/2 to main April 30, 2026 18:41
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/3 branch from 4dd3fe7 to e3fcf74 on April 30, 2026 18:41
@sanketpurandare sanketpurandare changed the base branch from main to sanketpurandare/stack/2 April 30, 2026 18:41
@sanketpurandare sanketpurandare marked this pull request as ready for review April 30, 2026 18:41
@sanketpurandare sanketpurandare marked this pull request as draft May 4, 2026 03:05
@sanketpurandare sanketpurandare changed the base branch from sanketpurandare/stack/2 to main May 4, 2026 03:05
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/3 branch from e3fcf74 to ba8eebc on May 4, 2026 03:05
@sanketpurandare sanketpurandare marked this pull request as ready for review May 4, 2026 03:06
Extract pipeline metadata, module splitting, and PP rank-to-stage mapping from pipeline_llm so graph PP can reuse the underlying setup logic without duplicating it. Add backward_requires_autograd to the schedule builder for graph PP, which runs explicit backward graphs instead of autograd. Existing eager PP behavior is unchanged.

Keep pipeline_llm as the only public entrypoint exported by torchtitan.distributed.pipeline_parallel. Make build_pipeline_schedule, generate_llm_fqn_per_model_part, and pipeline_module_split private because they are implementation details with narrower contracts: schedule construction depends on the current PP config shape, LLM FQN generation encodes TorchTitan-specific module naming heuristics, and module splitting assumes models tolerate deleted or empty layer containers.

Update internal and experiment callsites to use the private helper names directly where reuse is still needed. This keeps the reusable code centralized while avoiding accidentally blessing those helpers as stable public API.

stack-info: PR: #2724, branch: sanketpurandare/stack/3
@sanketpurandare sanketpurandare marked this pull request as draft May 4, 2026 04:38
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/3 branch from ba8eebc to d036765 on May 4, 2026 04:38
@sanketpurandare sanketpurandare changed the title Refactor pipeline_parallel.py for graph PP reuse Refactor pipeline parallel helpers for graph PP reuse May 4, 2026
@sanketpurandare sanketpurandare marked this pull request as ready for review May 4, 2026 04:39
@sanketpurandare sanketpurandare requested a review from tianyu-l May 4, 2026 05:00
Labels: ciflow/8gpu, CLA Signed