
Feature request: Add a streaming_shard operator for early sample-level sharding when file-level sharding is insufficient #8253

Description

@muyihao


Proposed API

It would be useful to provide a first-class streaming sharding operator in datasets, for example:

dataset = datasets.load_dataset(..., streaming=True, split="train")

dataset = dataset.streaming_shard(
    num_shards=world_size,
    index=rank,
)

or, as a standalone function:

dataset = datasets.streaming_shard(
    dataset,
    num_shards=world_size,
    index=rank,
)

Then users could write:

dataset = datasets.load_dataset(..., streaming=True, split="train")

dataset = dataset.streaming_shard(
    num_shards=world_size,
    index=rank,
)

dataset = dataset.map(tokenize_fn)
dataset = dataset.map(build_labels_or_loss_mask)
dataset = dataset.filter(...)

This would move rank-aware sample sharding earlier in the streaming pipeline, before expensive transformations.

Why this is needed

Existing file-level sharding works well when the dataset has enough physical shards. However, it is not enough when:

num_physical_files < global_world_size

or when physical files cannot be split.

For example, if we have:

global_world_size = 64
num_physical_files = 4

then file-level assignment alone cannot give each rank a distinct physical shard. Many ranks would either read the same files or receive no files at all. In practice, the only reliable way to divide the dataset is sample-level sharding, as performed by IterableDatasetShard.
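To make the arithmetic concrete, here is a minimal sketch that assigns files to ranks round-robin and counts how many ranks end up with any data. The round-robin policy, the helper name assign_files_round_robin, and the file names are all assumptions for illustration; the actual assignment logic in datasets may differ.

```python
def assign_files_round_robin(files, world_size, rank):
    """Hypothetical file-level assignment: rank r gets files r, r + world_size, ..."""
    return files[rank::world_size]


files = [f"part-{i:05d}.jsonl" for i in range(4)]  # 4 hypothetical physical files
world_size = 64

ranks_with_data = [
    rank
    for rank in range(world_size)
    if assign_files_round_robin(files, world_size, rank)
]
print(len(ranks_with_data))               # 4
print(world_size - len(ranks_with_data))  # 60 ranks receive no file
```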

The problem is that IterableDatasetShard is applied at the training/dataloader layer. This means users cannot easily place the shard boundary before expensive dataset transformations.

Expected behavior

For a streaming dataset yielding:

[0, 1, 2, 3, 4, 5, 6, 7]

with:

num_shards = 2

the result could be:

rank 0 -> [0, 2, 4, 6]
rank 1 -> [1, 3, 5, 7]

This is sample-level logical sharding and does not depend on the number of physical files.
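The strided behavior above can be sketched with itertools.islice. The function below is a hypothetical stand-in for the proposed operator, written against a plain Python iterable rather than a datasets IterableDataset.

```python
from itertools import islice
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def streaming_shard(examples: Iterable[T], num_shards: int, index: int) -> Iterator[T]:
    """Hypothetical strided sharding: keep positions index, index + num_shards, ...

    Arguments are validated eagerly; the stream itself is consumed lazily.
    """
    if num_shards < 1:
        raise ValueError("num_shards must be >= 1")
    if not 0 <= index < num_shards:
        raise ValueError("index must be in [0, num_shards)")
    return islice(examples, index, None, num_shards)


for rank in range(2):
    print(rank, list(streaming_shard(range(8), num_shards=2, index=rank)))
# 0 [0, 2, 4, 6]
# 1 [1, 3, 5, 7]
```

Note that islice still pulls every raw example through the iterator on every rank, so reading and decoding are not reduced; what this placement saves is the downstream map/filter work, which is the cost this request targets.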

A batch-aware mode may also be useful to align with IterableDatasetShard semantics:

dataset.streaming_shard(
    num_shards=world_size,
    index=rank,
    batch_size=batch_size,
    split_batches=False,
    drop_last=False,
)
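A batch-aware variant could buffer one global batch, hand each rank a contiguous slice of it, and, unless drop_last is set, pad the final incomplete batch with examples from the first batch so that all ranks yield the same number of examples. The sketch below is modeled on my reading of IterableDatasetShard's grouping logic; the name batch_aware_shard is hypothetical, and the padding details should be checked against the accelerate source before claiming identical semantics.

```python
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def batch_aware_shard(
    examples: Iterable[T],
    num_shards: int,
    index: int,
    batch_size: int,
    split_batches: bool = False,
    drop_last: bool = False,
) -> Iterator[T]:
    """Hypothetical batch-aware sharding.

    split_batches=False: each rank receives whole batches of batch_size.
    split_batches=True: each global batch of batch_size is split across ranks.
    """
    if split_batches and batch_size % num_shards != 0:
        raise ValueError("batch_size must be divisible by num_shards")
    global_batch = batch_size if split_batches else batch_size * num_shards
    per_rank = batch_size // num_shards if split_batches else batch_size
    start, stop = index * per_rank, (index + 1) * per_rank

    first_batch: List[T] = []
    current: List[T] = []
    for example in examples:
        current.append(example)
        if len(current) == global_batch:
            yield from current[start:stop]  # this rank's slice of the global batch
            if not first_batch:
                first_batch = list(current)
            current = []

    # Incomplete last global batch: drop it, or pad it by cycling through the
    # first full batch (or itself, if no full batch was ever seen).
    if current and not drop_last:
        padding = first_batch or list(current)
        while len(current) < global_batch:
            current.extend(padding)
        yield from current[start:stop]


for rank in range(2):
    print(rank, list(batch_aware_shard(range(10), 2, rank, batch_size=2)))
# 0 [0, 1, 4, 5, 8, 9]
# 1 [2, 3, 6, 7, 0, 1]
```

With drop_last=True the partial final batch [8, 9] is discarded instead, so rank 0 yields [0, 1, 4, 5] and rank 1 yields [2, 3, 6, 7].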

Current workaround

The current workaround is to rely on IterableDatasetShard:

import datasets
from accelerate.data_loader import IterableDatasetShard

dataset = datasets.load_dataset(..., streaming=True, split="train")

dataset = dataset.map(tokenize_fn)
dataset = dataset.map(build_labels_or_loss_mask)

dataset = IterableDatasetShard(
    dataset,
    batch_size=batch_size,
    num_processes=world_size,
    process_index=rank,
)

This partitions the data correctly, but it is inefficient: the sharding happens after the preceding transformations, so every rank runs tokenize_fn and build_labels_or_loss_mask on every example, including the ones it will discard.

Desired pattern

The desired pattern is:

dataset = datasets.load_dataset(..., streaming=True, split="train")

dataset = dataset.streaming_shard(
    num_shards=world_size,
    index=rank,
)

dataset = dataset.map(tokenize_fn)
dataset = dataset.map(build_labels_or_loss_mask)

This way, each rank only preprocesses the examples it will actually consume.
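The difference in preprocessing cost can be simulated without datasets at all. In the sketch below, ToyStream is a hypothetical lazy stream with chainable map and streaming_shard methods (not the datasets IterableDataset), and counting_tokenize is a hypothetical stand-in for an expensive tokenize_fn that counts its calls. The late variant uses strided sharding rather than IterableDatasetShard's batch-wise slicing; both consume the full stream on every rank, so the call count is the same.

```python
from itertools import islice
from typing import Callable, Iterator


class ToyStream:
    """Hypothetical lazy stream with chainable map/streaming_shard (not the datasets API)."""

    def __init__(self, make_iter: Callable[[], Iterator]):
        self._make_iter = make_iter

    def __iter__(self) -> Iterator:
        return self._make_iter()

    def map(self, fn: Callable) -> "ToyStream":
        # Lazily apply fn to every example that flows through this stage.
        return ToyStream(lambda: (fn(x) for x in self))

    def streaming_shard(self, num_shards: int, index: int) -> "ToyStream":
        # Strided sample-level sharding: keep positions index, index + num_shards, ...
        return ToyStream(lambda: islice(iter(self), index, None, num_shards))


calls = {"early": 0, "late": 0}


def counting_tokenize(key: str) -> Callable[[int], int]:
    """Stand-in for an expensive tokenize_fn that records how often it runs."""
    def tokenize(x: int) -> int:
        calls[key] += 1
        return x * 10
    return tokenize


world_size, num_examples = 4, 1000
source = ToyStream(lambda: iter(range(num_examples)))

for rank in range(world_size):
    # Desired pattern: shard first, then preprocess.
    early = source.streaming_shard(world_size, rank).map(counting_tokenize("early"))
    # Current workaround: preprocess first, shard afterwards.
    late = source.map(counting_tokenize("late")).streaming_shard(world_size, rank)
    assert list(early) == list(late)  # both orders deliver the same examples

print(calls)  # {'early': 1000, 'late': 4000}
```

Both orders deliver identical examples to each rank, but sharding late preprocesses every example once per rank (world_size × num_examples calls in total), while sharding early preprocesses each example exactly once.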

Use cases

This would be useful for:

  1. Large-scale LLM pretraining with a small number of very large JSONL / Parquet files.
  2. SFT pipelines where tokenization, chat template application, and loss-mask construction are expensive.
  3. Multimodal training where image/video metadata parsing or remote object loading should not be repeated across ranks.
  4. Distributed jobs where the number of source files is smaller than global_world_size.
  5. Data sources where physical repartitioning is expensive or controlled by an upstream production pipeline.

Relationship to existing APIs

This feature is related to:

  • accelerate.data_loader.IterableDatasetShard, which provides sample-level sharding at the dataloader/training layer.
  • Datasets file-level sharding, which works well when there are enough physical source shards.
  • to_iterable_dataset(num_shards=...), which can create iterable shards when converting a map-style dataset.

The proposed feature is different because it provides logical sample-level sharding directly in the datasets streaming pipeline, before downstream transformations such as map, filter, tokenization, or loss-mask construction.

Open questions

  1. Should the operator exactly match IterableDatasetShard semantics?
  2. Should it support both sample-level and batch-aware sharding?
  3. Should it support both strided and contiguous logical sharding?
  4. How should it interact with shuffle(buffer_size=...) and set_epoch()?
  5. Should this be implemented as a method on IterableDataset, or as a utility function under datasets.distributed?
  6. Should it expose options such as drop_last, even_batches, and split_batches to match distributed dataloader behavior?
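On question 4, one constraint seems worth stating: if shuffling happens before the shard, every rank must produce the same shuffled order (for example, a seed shared by all ranks and varied only per epoch), otherwise strided sharding no longer partitions the data. The stdlib-only simulation below uses a generic fixed-size buffer shuffle, not the datasets implementation; buffer_shuffle and shuffle_then_shard are hypothetical names.

```python
import random
from itertools import islice
from typing import Iterable, Iterator, List


def buffer_shuffle(examples: Iterable[int], buffer_size: int, seed: int) -> Iterator[int]:
    """Generic fixed-size buffer shuffle (not the datasets implementation)."""
    rng = random.Random(seed)
    buffer: List[int] = []
    for example in examples:
        if len(buffer) < buffer_size:
            buffer.append(example)
            continue
        # Emit a random buffered example and put the new one in its slot.
        i = rng.randrange(buffer_size)
        yield buffer[i]
        buffer[i] = example
    # Flush the remaining buffered examples in random order.
    rng.shuffle(buffer)
    yield from buffer


def shuffle_then_shard(num_examples: int, world_size: int, rank: int, seed: int) -> List[int]:
    """Shuffle the full stream on this rank, then keep every world_size-th example."""
    shuffled = buffer_shuffle(range(num_examples), buffer_size=16, seed=seed)
    return list(islice(shuffled, rank, None, world_size))


N, W = 100, 4

# Seed shared by all ranks: every rank sees the same order, so the shards
# are disjoint and together cover the whole dataset.
flat_same = [x for r in range(W) for x in shuffle_then_shard(N, W, r, seed=42)]
print(sorted(flat_same) == list(range(N)))  # True

# Seed that differs per rank: orders diverge, so shards can overlap and miss examples.
flat_diff = [x for r in range(W) for x in shuffle_then_shard(N, W, r, seed=42 + r)]
print(len(flat_diff), len(set(flat_diff)))  # total is 100; the distinct count is typically lower
```

Shuffling after the shard with a per-rank seed does not have this problem, since each rank only reorders the examples it already owns.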

Summary

When physical file-level sharding is impossible or when the number of source files is smaller than the global world size, distributed streaming training has to rely on IterableDatasetShard for sample-level data partitioning. However, this sharding happens too late, after expensive dataset transformations may already have been executed.

A streaming_shard operator in datasets would allow users to apply logical rank-aware sharding earlier in the streaming pipeline, reducing duplicated preprocessing work and improving efficiency for large-scale distributed training.

Motivation

When using datasets.load_dataset(..., streaming=True) for distributed training, Hugging Face Datasets can rely on physical file-level sharding if the dataset is already split into enough source files.

However, in many real-world training pipelines, this assumption does not hold.

For example:

  • the dataset may contain only a few very large JSONL / Parquet files;
  • the number of physical files may be smaller than the global distributed world size;
  • the source files may be generated by an upstream pipeline and cannot be repartitioned easily;
  • the files may be stored on remote storage such as HDFS or object storage, where rewriting or repartitioning is expensive;
  • some file formats or data sources do not provide a convenient way to split one physical file into multiple independent streaming shards.

In these cases, file-level sharding is not sufficient. The training job has to rely entirely on accelerate.data_loader.IterableDatasetShard or similar dataloader-level logic to split the data across ranks.

This works functionally, but it happens too late in the pipeline.

A common pipeline looks like this:

dataset = datasets.load_dataset(..., streaming=True, split="train")

dataset = dataset.map(tokenize_fn)
dataset = dataset.map(build_labels_or_loss_mask)
dataset = dataset.filter(...)

Then the distributed dataloader wraps the dataset with IterableDatasetShard.

The issue is that every rank may execute the full upstream streaming pipeline on every example before the dataloader-level shard decides which examples the current rank keeps. For expensive preprocessing steps such as tokenization, chat template application, multimodal metadata loading, filtering, or loss-mask construction, this duplicates CPU work across ranks (in the worst case, each example is preprocessed once per rank, i.e. world_size times) and adds unnecessary I/O pressure.

Your contribution

I propose adding a streaming_shard operator that can be applied before expensive streaming dataset transformations.

I’m happy to work on this feature if the maintainers think it makes sense.

Metadata

Assignees: No one assigned
Labels: enhancement (New feature or request)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests