def call_flat(
    self,
    coord: Array,
    atype: Array,
    mixed_batch: Mapping[str, Array],
    box: Array | None = None,
    fparam: Array | None = None,
    aparam: Array | None = None,
    charge_spin: Array | None = None,
    do_atomic_virial: bool = False,
    neighbor_list: NeighborList | None = None,
) -> dict[str, Array]:
    """Evaluate a flattened mixed-nloc batch with the dpmodel backend.

    The dpmodel backend reuses the regular one-frame call path for each
    segment described by ``ptr`` and merges the translated outputs back into
    the flat mixed-batch layout.

    Parameters
    ----------
    coord
        Coordinates of all frames stacked along the first axis; rows
        ``ptr[i]:ptr[i + 1]`` belong to frame ``i``.
    atype
        Atom types, stacked along the first axis like ``coord``.
    mixed_batch
        Mapping that must provide ``"batch"`` (one entry per atom) and
        ``"ptr"`` (frame boundaries, ``nframes + 1`` entries).
    box, fparam, charge_spin
        Optional frame-level inputs; entry ``i`` along the first axis is
        handed to frame ``i``.
    aparam
        Optional atom-level input, sliced per frame like ``coord``.
    do_atomic_virial
        Forwarded unchanged to every per-frame call.
    neighbor_list
        Forwarded unchanged to every per-frame call.

    Returns
    -------
    dict[str, Array]
        Per-frame outputs merged by ``_merge_flat_frame_outputs``.

    Raises
    ------
    ValueError
        If ``batch``/``ptr`` are missing, ``ptr`` is not a non-decreasing
        1D array running from 0 to the number of atoms, or an input does
        not have the expected length along its first axis.
    NotImplementedError
        If the model has Hessian output enabled.
    """
    batch = mixed_batch.get("batch")
    ptr = mixed_batch.get("ptr")
    if batch is None or ptr is None:
        raise ValueError("mixed_batch must contain both batch and ptr.")
    if self._enable_hessian:
        raise NotImplementedError(
            "Hessian is not implemented for dpmodel mixed-batch flat calls."
        )

    xp = array_api_compat.array_namespace(coord, atype)
    ptr_np = to_numpy_array(ptr)
    if ptr_np is None:
        raise ValueError("ptr is required for mixed batches.")
    ptr_np = np.asarray(ptr_np, dtype=np.int64)
    if ptr_np.ndim != 1 or ptr_np.size < 2:
        raise ValueError("ptr must be a 1D array with at least two entries.")

    # Plain Python ints are valid slice bounds for every array backend.
    bounds = ptr_np.tolist()
    nframes = len(bounds) - 1
    total_atoms = coord.shape[0]
    if bounds[0] != 0 or bounds[-1] != total_atoms:
        raise ValueError("ptr must start at 0 and end at the number of atoms.")
    # A decreasing ptr would give a negative nloc and an obscure reshape
    # failure deep inside the per-frame call.
    if any(end < start for start, end in pairwise(bounds)):
        raise ValueError("ptr must be non-decreasing.")
    if batch.shape[0] != total_atoms:
        raise ValueError("batch length must match the number of atoms.")

    # Slicing past the end of an array silently yields an empty slice, so
    # check the leading dimension of every optional input up front.
    expected_lengths = (
        ("box", box, nframes),
        ("fparam", fparam, nframes),
        ("charge_spin", charge_spin, nframes),
        ("aparam", aparam, total_atoms),
    )
    for name, value, expected in expected_lengths:
        if value is not None and value.shape[0] != expected:
            raise ValueError(
                f"{name} must have {expected} entries along its first axis, "
                f"got {value.shape[0]}."
            )

    frame_outputs = []
    for frame_idx, (start, end) in enumerate(pairwise(bounds)):
        nloc = end - start
        frame = slice(frame_idx, frame_idx + 1)
        frame_outputs.append(
            self.call(
                xp.reshape(coord[start:end], (1, nloc * 3)),
                xp.reshape(atype[start:end], (1, nloc)),
                box=None if box is None else box[frame],
                fparam=None if fparam is None else fparam[frame],
                aparam=(
                    None
                    if aparam is None
                    else xp.reshape(aparam[start:end], (1, nloc, *aparam.shape[1:]))
                ),
                charge_spin=None if charge_spin is None else charge_spin[frame],
                do_atomic_virial=do_atomic_virial,
                neighbor_list=neighbor_list,
            )
        )

    return self._merge_flat_frame_outputs(frame_outputs)


@staticmethod
def _merge_flat_frame_outputs(
    frame_outputs: list[dict[str, Array]],
) -> dict[str, Array]:
    """Concatenate per-frame outputs into the flat mixed-batch layout.

    ``energy`` and ``virial`` keep their leading frame axis and are stacked
    frame by frame.  ``mask`` is flattened to 1D.  Every other key has its
    two leading axes collapsed into one when it has three or more
    dimensions, and is flattened to 1D otherwise.  The set of keys is taken
    from the first frame.

    Raises
    ------
    ValueError
        If ``frame_outputs`` is empty.
    """
    if not frame_outputs:
        raise ValueError("mixed-batch input must contain at least one frame.")

    framewise_keys = {"energy", "virial"}
    result: dict[str, Array] = {}
    for key in frame_outputs[0]:
        values = [frame_output[key] for frame_output in frame_outputs]
        xp = array_api_compat.array_namespace(values[0])
        if key not in framewise_keys:
            values = [
                xp.reshape(value, (-1, *value.shape[2:]))
                if key != "mask" and value.ndim >= 3
                else xp.reshape(value, (-1,))
                for value in values
            ]
        result[key] = xp.concat(values, axis=0)
    return result
same-nloc batches. Auto batch_size is computed per-nloc-group. - ``mixed_batch=True`` (new format): frames with different nloc can - coexist in one batch (requires padding + mask in collate_fn). - Currently raises ``NotImplementedError`` at collation time. + coexist in one batch. Atom-wise fields are flattened in the collate + function, and string batch-size rules are applied to the total number + of atoms in the mixed batch. Parameters ---------- @@ -264,13 +266,17 @@ class LmdbDataReader: - ``"auto"`` / ``"auto:N"``: ``ceil(N / nloc)`` per nloc group (``N=32`` for bare ``"auto"``). Acts as a *lower* bound — each batch has at least ``N`` atoms, but may exceed ``N`` - by up to ``nloc - 1``. + by up to ``nloc - 1``. With ``mixed_batch=True``, frames are + accumulated until the total atom count first reaches or exceeds + ``N``. - ``"max:N"``: ``max(1, floor(N / nloc))`` per nloc group. Acts as an *upper* bound for groups with ``nloc <= N`` (batch has at most ``N`` atoms). For groups with ``nloc > N`` the ``max(1, ...)`` floor kicks in: ``bsi=1`` and a single-frame batch still carries ``nloc`` atoms, - which exceeds ``N``. + which exceeds ``N``. With ``mixed_batch=True``, frames are + accumulated while the mixed-batch total stays at or below ``N``; + an oversized single frame is kept as a one-frame batch. - ``"filter:N"``: same per-nloc formula as ``"max:N"`` **and** drops every frame whose ``nloc > N`` from the dataset. By construction every retained batch has at most ``N`` atoms. @@ -322,14 +328,16 @@ def __init__( # Safe because we use num_workers=0 in DataLoader. self._txn = self._env.begin() - # Scan per-frame nloc only when needed for same-nloc batching. - # For mixed_batch=True, skip the scan entirely (future: padding handles it). + # Scan per-frame nloc for same-nloc batching and for mixed-batch + # string rules (auto/max/filter), which use the mixed batch's total + # atom count as the budget. 
# ``orig_frame_nlocs`` / ``orig_frame_system_ids`` are indexed by the # *original* LMDB frame index. After a potential ``filter:N`` drop we # rebuild ``self._frame_nlocs`` / ``self._frame_system_ids`` so they # are parallel arrays over the *dataset* index space (0..len(self)); # the dataset-to-original mapping lives in ``self._retained_keys``. - if not mixed_batch: + need_frame_nlocs = (not mixed_batch) or isinstance(batch_size, str) + if need_frame_nlocs: # Fast path: use pre-computed frame_nlocs from metadata if available. # Falls back to scanning each frame's atom_types shape (~10 us/frame). meta_nlocs = meta.get("frame_nlocs") @@ -376,17 +384,6 @@ def __init__( "Expected int, 'auto', 'auto:N', 'max:N', or 'filter:N'." ) - # ``filter:N`` needs per-frame nloc to drop oversized frames; the - # ``mixed_batch=True`` fast path skips the nloc scan entirely, so the - # two options are incompatible. Fail fast rather than silently - # retaining every frame and breaking the documented contract. - if self._filter_rule is not None and mixed_batch: - raise ValueError( - "batch_size='filter:N' is incompatible with mixed_batch=True: " - "per-frame nloc is unavailable in the mixed-batch fast path. " - "Use mixed_batch=False, or switch to 'max:N' / a fixed int." - ) - # Determine which original-index frames survive the filter. Without # ``filter:N`` every frame is retained. if self._filter_rule is not None: @@ -412,7 +409,7 @@ def __init__( # space so that every downstream consumer (nloc_groups, system_groups, # SameNlocBatchSampler, _expand_indices_by_blocks) operates in a # single, self-consistent indexing scheme. 
- if not mixed_batch: + if orig_frame_nlocs: self._frame_nlocs = [orig_frame_nlocs[k] for k in retained_keys] else: self._frame_nlocs = [] @@ -756,6 +753,8 @@ def index(self) -> list[int]: @property def total_batch(self) -> int: if self.mixed_batch: + if self._auto_rule is not None or self._max_rule is not None: + return len(_build_mixed_batches(self, shuffle=False)) return math.ceil(self.nframes / self.batch_size) if self.nframes else 0 total = 0 for nloc, indices in self._nloc_groups.items(): @@ -1332,6 +1331,151 @@ def world_size(self) -> int: return self._world_size +def _build_mixed_batches( + reader: LmdbDataReader, + shuffle: bool, + rng: np.random.Generator | None = None, +) -> list[list[int]]: + """Build mixed-nloc batches using frame order and atom-count budgets. + + Fixed integer ``batch_size`` keeps the historical mixed-batch meaning: + number of frames per batch. String rules use the total number of atoms in + the mixed batch: + + - ``auto:N`` closes a batch when the accumulated atom count first reaches + or exceeds ``N``. + - ``max:N`` closes before adding a frame that would exceed ``N``; a single + frame whose ``nloc > N`` is still emitted as a one-frame batch. + - ``filter:N`` reuses ``max:N`` after oversized frames have been removed by + :class:`LmdbDataReader`. 
+ """ + indices = list(range(len(reader))) + if shuffle: + if rng is None: + rng = np.random.default_rng() + rng.shuffle(indices) + + return list(_iter_mixed_batches(reader, indices)) + + +def _iter_mixed_batches( + reader: LmdbDataReader, + indices: Iterable[int], +) -> Iterator[list[int]]: + """Yield mixed-nloc batches from an ordered frame-index iterable.""" + auto_rule = reader._auto_rule + max_rule = reader._max_rule + if auto_rule is None and max_rule is None: + bs = max(1, reader.batch_size) + current: list[int] = [] + for idx in indices: + current.append(idx) + if len(current) >= bs: + yield current + current = [] + if current: + yield current + return + + if not reader.frame_nlocs: + raise ValueError( + "mixed-batch auto/max/filter batch_size requires per-frame nlocs." + ) + + current: list[int] = [] + current_atoms = 0 + + if auto_rule is not None: + for idx in indices: + current.append(idx) + current_atoms += reader.frame_nlocs[idx] + if current_atoms >= auto_rule: + yield current + current = [] + current_atoms = 0 + else: + assert max_rule is not None + for idx in indices: + nloc = reader.frame_nlocs[idx] + if current and current_atoms + nloc > max_rule: + yield current + current = [] + current_atoms = 0 + current.append(idx) + current_atoms += nloc + if current_atoms >= max_rule: + yield current + current = [] + current_atoms = 0 + + if current: + yield current + + +class MixedBatchSampler: + """Sampler for mixed-nloc LMDB batches. + + It yields lists of frame indices that may have different ``nloc`` values. + Integer ``batch_size`` is interpreted as a frame count; ``auto/max/filter`` + string rules are interpreted as total atom-count budgets. 
+ """ + + def __init__( + self, + reader: LmdbDataReader, + shuffle: bool = True, + seed: int | None = None, + ) -> None: + self._reader = reader + self._shuffle = shuffle + self._seed = seed + + def __iter__(self) -> Iterator[list[int]]: + rng = np.random.default_rng(self._seed) + indices = list(range(len(self._reader))) + if self._shuffle: + rng.shuffle(indices) + yield from _iter_mixed_batches(self._reader, indices) + + def __len__(self) -> int: + return len(_build_mixed_batches(self._reader, shuffle=False)) + + +class DistributedMixedBatchSampler: + """Distributed wrapper for mixed-nloc batch sampling.""" + + def __init__( + self, + reader: LmdbDataReader, + rank: int, + world_size: int, + shuffle: bool = True, + seed: int | None = None, + ) -> None: + self._reader = reader + self._rank = rank + self._world_size = world_size + self._shuffle = shuffle + self._seed = seed if seed is not None else 0 + self._epoch = 0 + + def set_epoch(self, epoch: int) -> None: + self._epoch = epoch + + def __iter__(self) -> Iterator[list[int]]: + rng = np.random.default_rng(self._seed + self._epoch) + indices = list(range(len(self._reader))) + if self._shuffle: + rng.shuffle(indices) + for batch_idx, batch in enumerate(_iter_mixed_batches(self._reader, indices)): + if batch_idx % self._world_size == self._rank: + yield batch + + def __len__(self) -> int: + total = len(_build_mixed_batches(self._reader, shuffle=False)) + return math.ceil(total / self._world_size) + + def make_neighbor_stat_data( lmdb_path: str, type_map: list[str] | None, diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 560ea5a1ba..7dee93e4fe 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -198,10 +198,12 @@ def _make_dp_loader_set( # LMDB path: single string → LmdbDataset if isinstance(training_systems, str) and is_lmdb(training_systems): auto_prob = training_dataset_params.get("auto_prob", None) + mixed_batch = 
training_dataset_params.get("mixed_batch", False) train_data_single = LmdbDataset( training_systems, model_params_single["type_map"], training_dataset_params["batch_size"], + mixed_batch=mixed_batch, auto_prob_style=auto_prob, ) if ( @@ -209,10 +211,12 @@ def _make_dp_loader_set( and isinstance(validation_systems, str) and is_lmdb(validation_systems) ): + val_mixed_batch = validation_dataset_params.get("mixed_batch", False) validation_data_single = LmdbDataset( validation_systems, model_params_single["type_map"], validation_dataset_params["batch_size"], + mixed_batch=val_mixed_batch, ) elif validation_systems is not None: validation_data_single = _make_dp_loader_set( diff --git a/deepmd/pt/loss/ener.py b/deepmd/pt/loss/ener.py index 50d83a4ac9..7821f48dc0 100644 --- a/deepmd/pt/loss/ener.py +++ b/deepmd/pt/loss/ener.py @@ -238,11 +238,39 @@ def forward( more_loss = {} # more_loss['log_keys'] = [] # showed when validation on the fly # more_loss['test_keys'] = [] # showed when doing dp test - atom_norm = 1.0 / natoms - # Normalization exponent controls loss scaling with system size: - # - norm_exp=2 (intensive_ener_virial=True): loss uses 1/N² scaling, making it independent of system size - # - norm_exp=1 (intensive_ener_virial=False, legacy): loss uses 1/N scaling, which varies with system size + + # Detect mixed batch format + mixed_batch = input_dict.get("mixed_batch") + is_mixed_batch = mixed_batch is not None + + atom_norms = None + if is_mixed_batch: + ptr = mixed_batch["ptr"] + natoms_per_frame = ptr[1:] - ptr[:-1] # [nframes] + atom_norms = 1.0 / natoms_per_frame.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION) + atom_norm = None + else: + atom_norm = 1.0 / natoms norm_exp = 2 if self.intensive_ener_virial else 1 + + def get_frame_norm(value: torch.Tensor) -> torch.Tensor: + assert atom_norms is not None + return atom_norms.to(device=value.device, dtype=value.dtype).view( + [-1] + [1] * (value.dim() - 1) + ) + + def weighted_mean(value: torch.Tensor, power: int = 1) 
-> torch.Tensor: + if atom_norms is None: + assert atom_norm is not None + return value.mean() * (atom_norm**power) + return (value * get_frame_norm(value) ** power).mean() + + def normalized_rmse(diff: torch.Tensor) -> torch.Tensor: + if atom_norms is None: + assert atom_norm is not None + return torch.mean(torch.square(diff)).sqrt() * atom_norm + return torch.mean(torch.square(diff * get_frame_norm(diff))).sqrt() + if self.has_e and "energy" in model_pred and "energy" in label: energy_pred = model_pred["energy"] energy_label = label["energy"] @@ -261,35 +289,37 @@ def forward( energy_pred = torch.sum(atom_ener_coeff * atom_ener_pred, dim=1) find_energy = label.get("find_energy", 0.0) pref_e = pref_e * find_energy + diff_e = energy_pred - energy_label if self.loss_func == "mse": - l2_ener_loss = torch.mean(torch.square(energy_pred - energy_label)) + square_ener_diff = torch.square(diff_e) + l2_ener_loss = torch.mean(square_ener_diff) if not self.inference: more_loss["l2_ener_loss"] = self.display_if_exist( l2_ener_loss.detach(), find_energy ) if not self.use_huber: - loss += atom_norm**norm_exp * (pref_e * l2_ener_loss) + loss += pref_e * weighted_mean(square_ener_diff, norm_exp) else: + energy_norm = ( + atom_norm if atom_norms is None else get_frame_norm(energy_pred) + ) l_huber_loss = custom_huber_loss( - atom_norm * energy_pred, - atom_norm * energy_label, + energy_norm * energy_pred, + energy_norm * energy_label, delta=self._huber_delta_energy, ) loss += pref_e * l_huber_loss - rmse_e = l2_ener_loss.sqrt() * atom_norm + rmse_e = normalized_rmse(diff_e) more_loss["rmse_e"] = self.display_if_exist( rmse_e.detach(), find_energy ) # more_loss['log_keys'].append('rmse_e') elif self.loss_func == "mae": - l1_ener_loss = F.l1_loss( - energy_pred.reshape(-1), - energy_label.reshape(-1), - reduction="mean", - ) - loss += atom_norm * (pref_e * l1_ener_loss) + abs_ener_diff = torch.abs(diff_e) + mae_e = weighted_mean(abs_ener_diff) + loss += pref_e * mae_e 
more_loss["mae_e"] = self.display_if_exist( - l1_ener_loss.detach() * atom_norm, + mae_e.detach(), find_energy, ) # more_loss['log_keys'].append('rmse_e') @@ -298,9 +328,9 @@ def forward( f"Loss type {self.loss_func} is not implemented for energy loss." ) if mae: - mae_e = torch.mean(torch.abs(energy_pred - energy_label)) * atom_norm + mae_e = weighted_mean(torch.abs(diff_e)) more_loss["mae_e"] = self.display_if_exist(mae_e.detach(), find_energy) - mae_e_all = torch.mean(torch.abs(energy_pred - energy_label)) + mae_e_all = torch.mean(torch.abs(diff_e)) more_loss["mae_e_all"] = self.display_if_exist( mae_e_all.detach(), find_energy ) @@ -417,6 +447,10 @@ def forward( ) if self.has_gf and "drdq" in label: + if is_mixed_batch: + raise NotImplementedError( + "Generalized force loss is not supported with mixed_batch=True yet." + ) drdq = label["drdq"] find_drdq = label.get("find_drdq", 0.0) pref_gf = pref_gf * find_drdq @@ -446,33 +480,36 @@ def forward( pref_v = pref_v * find_virial diff_v = label["virial"] - model_pred["virial"].reshape(-1, 9) if self.loss_func == "mse": - l2_virial_loss = torch.mean(torch.square(diff_v)) + square_virial_diff = torch.square(diff_v) + l2_virial_loss = torch.mean(square_virial_diff) if not self.inference: more_loss["l2_virial_loss"] = self.display_if_exist( l2_virial_loss.detach(), find_virial ) if not self.use_huber: - loss += atom_norm**norm_exp * (pref_v * l2_virial_loss) + loss += pref_v * weighted_mean(square_virial_diff, norm_exp) else: + virial = model_pred["virial"].reshape(-1, 9) + virial_label = label["virial"].reshape(-1, 9) + virial_norm = ( + atom_norm if atom_norms is None else get_frame_norm(virial) + ) l_huber_loss = custom_huber_loss( - atom_norm * model_pred["virial"].reshape(-1), - atom_norm * label["virial"].reshape(-1), + (virial_norm * virial).reshape(-1), + (virial_norm * virial_label).reshape(-1), delta=self._huber_delta_virial, ) loss += pref_v * l_huber_loss - rmse_v = l2_virial_loss.sqrt() * atom_norm + rmse_v 
def forward_common_atomic_flat(
    self,
    extended_coord: torch.Tensor,
    extended_atype: torch.Tensor,
    extended_batch: torch.Tensor,
    nlist: torch.Tensor,
    mapping: torch.Tensor,
    batch: torch.Tensor,
    ptr: torch.Tensor,
    fparam: torch.Tensor | None = None,
    aparam: torch.Tensor | None = None,
    charge_spin: torch.Tensor | None = None,
    extended_ptr: torch.Tensor | None = None,
    central_ext_index: torch.Tensor | None = None,
    nlist_ext: torch.Tensor | None = None,
    a_nlist: torch.Tensor | None = None,
    a_nlist_ext: torch.Tensor | None = None,
    nlist_mask: torch.Tensor | None = None,
    a_nlist_mask: torch.Tensor | None = None,
    edge_index: torch.Tensor | None = None,
    angle_index: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
    """Forward pass with flat mixed-nloc batch format.

    Runs the descriptor's and the fitting net's ``forward_flat``, applies
    the output statistics, and multiplies every fitting output by the atom
    mask.  The mask itself is returned under the ``"mask"`` key.

    ``extended_ptr`` is accepted but not read by this method.  When
    ``central_ext_index`` is not supplied it is derived from
    ``extended_batch`` and ``ptr`` after the descriptor call, so the
    descriptor still receives ``None``.
    """
    wants_coord_grad = self.do_grad_r() or self.do_grad_c()
    if wants_coord_grad:
        extended_coord.requires_grad_(True)

    nframes = ptr.numel() - 1
    if self.add_chg_spin_ebd and charge_spin is None:
        # Fall back to the descriptor's default charge/spin, one row per frame.
        default_cs_tensor = self.descriptor.get_default_chg_spin()
        if default_cs_tensor is not None:
            default_row = default_cs_tensor.to(device=extended_coord.device)
            charge_spin = torch.tile(default_row.unsqueeze(0), [nframes, 1])

    descriptor_charge_spin = charge_spin if self.add_chg_spin_ebd else None
    descriptor_out = self.descriptor.forward_flat(
        extended_coord,
        extended_atype,
        extended_batch,
        nlist,
        mapping,
        batch,
        ptr,
        fparam=fparam,
        charge_spin=descriptor_charge_spin,
        central_ext_index=central_ext_index,
        nlist_ext=nlist_ext,
        a_nlist=a_nlist,
        a_nlist_ext=a_nlist_ext,
        nlist_mask=nlist_mask,
        a_nlist_mask=a_nlist_mask,
        edge_index=edge_index,
        angle_index=angle_index,
    )

    if central_ext_index is None:
        from deepmd.pt.utils.nlist import (
            get_central_ext_index,
        )

        central_ext_index = get_central_ext_index(extended_batch, ptr)
    atype = extended_atype[central_ext_index]

    fit_ret = self.fitting_net.forward_flat(
        descriptor_out.get("descriptor"),
        atype,
        batch,
        ptr,
        gr=descriptor_out.get("rot_mat"),
        g2=descriptor_out.get("g2"),
        h2=descriptor_out.get("h2"),
        fparam=fparam,
        aparam=aparam,
    )
    fit_ret = self.apply_out_stat(fit_ret, atype)

    atom_mask = self.make_atom_mask(atype).to(torch.int32)
    if self.atom_excl is not None:
        atom_mask *= self.atom_excl(atype.unsqueeze(0)).squeeze(0)

    mask_column = atom_mask[:, None]
    for name in list(fit_ret):
        full_shape = fit_ret[name].shape
        # Explicit product instead of -1 so a zero-atom batch still reshapes.
        per_atom_size = 1
        for dim in full_shape[1:]:
            per_atom_size *= dim
        flattened = fit_ret[name].reshape([full_shape[0], per_atom_size])
        fit_ret[name] = (flattened * mask_column).view(full_shape)
    fit_ret["mask"] = atom_mask

    return fit_ret
def forward_flat(
    self,
    extended_coord: torch.Tensor,
    extended_atype: torch.Tensor,
    extended_batch: torch.Tensor,
    nlist: torch.Tensor,
    mapping: torch.Tensor,
    batch: torch.Tensor,
    ptr: torch.Tensor,
    fparam: torch.Tensor | None = None,
    charge_spin: torch.Tensor | None = None,
    central_ext_index: torch.Tensor | None = None,
    nlist_ext: torch.Tensor | None = None,
    a_nlist: torch.Tensor | None = None,
    a_nlist_ext: torch.Tensor | None = None,
    nlist_mask: torch.Tensor | None = None,
    a_nlist_mask: torch.Tensor | None = None,
    edge_index: torch.Tensor | None = None,
    angle_index: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
    """Compute the descriptor with flat batch format.

    Parameters
    ----------
    extended_coord : torch.Tensor
        Extended coordinates [total_extended_atoms, 3].
    extended_atype : torch.Tensor
        Extended atom types [total_extended_atoms].
    extended_batch : torch.Tensor
        Frame assignment for extended atoms [total_extended_atoms].
    nlist : torch.Tensor
        Neighbor list [total_atoms, nnei].
    mapping : torch.Tensor
        Extended atom -> local flat index mapping [total_extended_atoms].
    batch : torch.Tensor
        Frame assignment for local atoms [total_atoms].
    ptr : torch.Tensor
        Frame boundaries [nframes + 1].
    fparam : torch.Tensor | None
        Frame parameters [nframes, ndf].  Not read by this method.
    charge_spin : torch.Tensor | None
        Frame-level charge and spin conditions with shape [nframes, 2].
        Must be given when charge/spin embedding is enabled.
    central_ext_index : torch.Tensor | None
        Extended-atom indices corresponding to local atoms.  Derived from
        ``extended_batch`` and ``ptr`` when omitted.
    nlist_ext, a_nlist_ext : torch.Tensor | None
        Edge and angle neighbor lists indexing concatenated extended atoms.
    nlist_mask, a_nlist_mask : torch.Tensor | None
        Valid-neighbor masks for flat edge and angle neighbor lists.
    edge_index, angle_index : torch.Tensor | None
        Dynamic graph indices produced by the flat graph preprocessor.

    Returns
    -------
    result : dict[str, torch.Tensor]
        Dictionary containing:
        - 'descriptor': [total_atoms, descriptor_dim]
        - 'rot_mat': [total_atoms, e_dim, 3] or None
        - 'g2': edge embedding or None
        - 'h2': pair representation or None

    Raises
    ------
    ValueError
        If a frame charge lies outside [-100, 99] or a frame spin lies
        outside [0, 99].
    """
    extended_coord = extended_coord.to(dtype=self.prec)

    # Embed every extended atom first; central atoms are gathered below.
    ext_node_ebd = self.type_embedding(extended_atype)

    if self.add_chg_spin_ebd:
        assert charge_spin is not None
        assert self.chg_embedding is not None
        assert self.spin_embedding is not None

        frame_charge = charge_spin[:, 0].to(dtype=torch.int64)
        frame_spin = charge_spin[:, 1].to(dtype=torch.int64)
        if torch.any(frame_charge < -100) or torch.any(frame_charge > 99):
            raise ValueError(
                "charge must be in range [-100, 99], got "
                f"min={frame_charge.min().item()}, "
                f"max={frame_charge.max().item()}"
            )
        if torch.any(frame_spin < 0) or torch.any(frame_spin >= 100):
            raise ValueError(
                "spin must be in range [0, 99], got "
                f"min={frame_spin.min().item()}, max={frame_spin.max().item()}"
            )
        # Broadcast the frame-level values to atoms; charge is shifted by
        # 100 so the embedding index is non-negative.
        atom_charge = frame_charge[extended_batch] + 100
        atom_spin = frame_spin[extended_batch]
        charge_spin_ebd = torch.cat(
            (self.chg_embedding(atom_charge), self.spin_embedding(atom_spin)),
            dim=-1,
        )
        ext_node_ebd = ext_node_ebd + self.act(self.mix_cs_mlp(charge_spin_ebd))

    if central_ext_index is None:
        from deepmd.pt.utils.nlist import (
            get_central_ext_index,
        )

        central_ext_index = get_central_ext_index(extended_batch, ptr)
    central_node_ebd = ext_node_ebd[central_ext_index]

    node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows.forward_flat(
        nlist,
        extended_coord,
        extended_atype,
        extended_batch,
        ext_node_ebd,
        mapping,
        batch,
        ptr,
        central_ext_index=central_ext_index,
        nlist_ext=nlist_ext,
        a_nlist=a_nlist,
        a_nlist_ext=a_nlist_ext,
        nlist_mask=nlist_mask,
        a_nlist_mask=a_nlist_mask,
        edge_index=edge_index,
        angle_index=angle_index,
    )

    if self.concat_output_tebd:
        node_ebd = torch.cat([node_ebd, central_node_ebd], dim=-1)

    out_dtype = env.GLOBAL_PT_FLOAT_PRECISION
    result = {"descriptor": node_ebd.to(dtype=out_dtype)}
    result["rot_mat"] = None if rot_mat is None else rot_mat.to(dtype=out_dtype)
    result["g2"] = None if edge_ebd is None else edge_ebd.to(dtype=out_dtype)
    result["h2"] = None if h2 is None else h2.to(dtype=out_dtype)
    return result
+ rcut, rcut_smth + Cutoff radius and smooth cutoff radius. + radial_only + Whether to return radial-only descriptors. + protection + Small positive value used in radial divisions. + use_exp_switch + Whether to use the exponential switch function. + coord_flat + Optional central atom coordinates with shape ``[nloc, 3]``. + + Returns + ------- + env_mat + Environment matrix with shape ``[nloc, nnei, 4 or 1]``. + diff + Difference vectors with shape ``[nloc, nnei, 3]``. + switch + Switch function values with shape ``[nloc, nnei, 1]``. + """ + nloc, nnei = nlist_flat.shape + nall = extended_coord_flat.shape[0] + + mask = nlist_flat >= 0 + nlist_safe = torch.where(mask, nlist_flat, nall) + + # coord_l: [nloc, 1, 3] + if coord_flat is not None: + coord_l = coord_flat.view(nloc, 1, 3) + else: + coord_l = extended_coord_flat[:nloc].view(nloc, 1, 3) + + # Gather neighbor coordinates + index = nlist_safe.view(-1).unsqueeze(-1).expand(-1, 3) + coord_pad = torch.cat( + [extended_coord_flat, extended_coord_flat[-1:, :] + rcut], dim=0 + ) + coord_r = torch.gather(coord_pad, 0, index) + coord_r = coord_r.view(nloc, nnei, 3) + + # Compute differences and distances + diff = coord_r - coord_l + length = torch.linalg.norm(diff, dim=-1, keepdim=True) + length = length + ~mask.unsqueeze(-1) + + t0 = 1 / (length + protection) + t1 = diff / (length + protection) ** 2 + + weight = ( + compute_smooth_weight(length, rcut_smth, rcut) + if not use_exp_switch + else compute_exp_sw(length, rcut_smth, rcut) + ) + weight = weight * mask.unsqueeze(-1) + + if radial_only: + env_mat = t0 * weight + else: + env_mat = torch.cat([t0, t1], dim=-1) * weight + + diff = diff * mask.unsqueeze(-1) + + # Normalize by mean and stddev + t_avg = mean[atype_flat] # [nloc, nnei, 4] + t_std = stddev[atype_flat] # [nloc, nnei, 4] + env_mat = (env_mat - t_avg) / t_std + + return env_mat, diff, weight diff --git a/deepmd/pt/model/descriptor/repflows.py b/deepmd/pt/model/descriptor/repflows.py index 
def forward_flat(
    self,
    nlist: torch.Tensor,
    extended_coord: torch.Tensor,
    extended_atype: torch.Tensor,
    extended_batch: torch.Tensor,
    extended_atype_embd: torch.Tensor,
    mapping: torch.Tensor,
    batch: torch.Tensor,
    ptr: torch.Tensor,
    central_ext_index: torch.Tensor | None = None,
    nlist_ext: torch.Tensor | None = None,
    a_nlist: torch.Tensor | None = None,
    a_nlist_ext: torch.Tensor | None = None,
    nlist_mask: torch.Tensor | None = None,
    a_nlist_mask: torch.Tensor | None = None,
    edge_index: torch.Tensor | None = None,
    angle_index: torch.Tensor | None = None,
) -> tuple[
    torch.Tensor,
    torch.Tensor | None,
    torch.Tensor | None,
    torch.Tensor | None,
    torch.Tensor | None,
]:
    """Forward pass for a precomputed flat graph batch.

    All frames are concatenated along the atom axis. A synthetic batch
    dimension of size one is added only around the ``RepFlowLayer`` calls
    and removed again before returning.

    Parameters
    ----------
    nlist : torch.Tensor
        Edge neighbor list [total_atoms, nnei]. Entries where
        ``nlist_mask`` is False are replaced by 0 before being handed to
        the layers.
        NOTE(review): the layers receive the central-atom embeddings as
        their "extended" embeddings, so this list presumably has to hold
        local flat indices rather than extended ones - confirm against
        the collate_fn.
    extended_coord : torch.Tensor
        Extended coordinates [total_extended_atoms, 3].
    extended_atype : torch.Tensor
        Extended atom types [total_extended_atoms].
    extended_batch : torch.Tensor
        Frame assignment for extended atoms [total_extended_atoms].
        Not read by this method.
    extended_atype_embd : torch.Tensor
        Type embeddings for extended atoms [total_extended_atoms, tebd_dim].
    mapping : torch.Tensor
        Extended atom -> local flat index mapping [total_extended_atoms].
        Not read by this method.
    batch : torch.Tensor
        Frame assignment for local atoms [total_atoms]. Only its length
        is used, as the number of local atoms.
    ptr : torch.Tensor
        Frame boundaries [nframes + 1]. Not read by this method.
    central_ext_index
        Extended-atom indices corresponding to local atoms.
    nlist_ext, a_nlist_ext
        Edge and angle neighbor lists indexing concatenated extended atoms;
        used to build the environment matrices.
    a_nlist
        Angle neighbor list passed to the layers; masked entries are
        replaced by 0 first.
    nlist_mask, a_nlist_mask
        Valid-neighbor masks for edge and angle neighbor lists.
    edge_index, angle_index
        Dynamic graph indices generated from the flat neighbor lists.
        Required when ``self.use_dynamic_sel`` is set; otherwise they are
        overwritten with zero-filled placeholders.

    Returns
    -------
    node_ebd : torch.Tensor
        Node embeddings [total_atoms, n_dim].
    edge_ebd : torch.Tensor | None
        Edge embeddings.
    h2 : torch.Tensor | None
        Pair representation. With dynamic selection it is indexed by
        ``nlist_mask``, i.e. one row per valid edge.
    rot_mat : torch.Tensor | None
        Rotation matrix [total_atoms, e_dim, 3].
    sw : torch.Tensor | None
        Switch function. With dynamic selection it is indexed by
        ``nlist_mask``, i.e. one value per valid edge.

    Raises
    ------
    RuntimeError
        If any of ``central_ext_index``, ``nlist_ext``, ``a_nlist``,
        ``a_nlist_ext``, ``nlist_mask`` or ``a_nlist_mask`` is None, or if
        dynamic selection is enabled and ``edge_index`` / ``angle_index``
        is None.
    """
    from deepmd.pt.model.descriptor.env_mat import (
        prod_env_mat_flat,
    )

    nloc = batch.shape[0]
    # The flat path never builds neighbor lists itself; everything must be
    # supplied by the caller.
    if (
        central_ext_index is None
        or nlist_ext is None
        or a_nlist is None
        or a_nlist_ext is None
        or nlist_mask is None
        or a_nlist_mask is None
    ):
        raise RuntimeError(
            "Repflows flat forward requires precomputed graph fields from collate_fn."
        )
    # Central-atom coordinates and types are picked out of the extended arrays.
    coord_central = extended_coord[central_ext_index]
    atype = extended_atype[central_ext_index]

    # Edge environment matrix in extended-atom index space.
    dmatrix, diff, sw = prod_env_mat_flat(
        extended_coord,
        nlist_ext,
        atype,
        self.mean,
        self.stddev,
        self.e_rcut,
        self.e_rcut_smth,
        protection=self.env_protection,
        use_exp_switch=self.use_exp_switch,
        coord_flat=coord_central,
    )

    # Drop the trailing singleton axis and zero the switch on padded edges.
    sw = torch.squeeze(sw, -1)
    sw = sw.masked_fill(~nlist_mask, 0.0)

    # Angle environment matrix uses the angle cutoff and angle neighbor list.
    # Only the difference vectors and the switch are kept; statistics are
    # sliced to the first a_sel neighbor slots.
    _, a_diff, a_sw = prod_env_mat_flat(
        extended_coord,
        a_nlist_ext,
        atype,
        self.mean[:, : self.a_sel],
        self.stddev[:, : self.a_sel],
        self.a_rcut,
        self.a_rcut_smth,
        protection=self.env_protection,
        use_exp_switch=self.use_exp_switch,
        coord_flat=coord_central,
    )

    a_sw = torch.squeeze(a_sw, -1)
    a_sw = a_sw.masked_fill(~a_nlist_mask, 0.0)

    # Node embedding for central atoms.
    atype_embd = extended_atype_embd[central_ext_index]
    assert list(atype_embd.shape) == [nloc, self.n_dim]
    node_ebd = self.act(atype_embd)

    # Edge and angle embedding inputs.
    # dmatrix's last axis is split into a 1-wide edge input and a 3-wide h2.
    edge_input, h2 = torch.split(dmatrix, [1, 3], dim=-1)
    if self.edge_init_use_dist:
        # Replace the env-matrix channel by the norm of the difference vector.
        edge_input = safe_for_norm(diff, dim=-1, keepdim=True)

    # Angle input is the normalized cosine between neighbor directions.
    normalized_diff_i = a_diff / (
        safe_for_norm(a_diff, dim=-1, keepdim=True) + 1e-6
    )
    normalized_diff_j = torch.transpose(normalized_diff_i, 1, 2)
    # The (1 - 1e-6) factor scales the cosines slightly below unit magnitude
    # (presumably to keep them strictly inside (-1, 1) - confirm).
    cosine_ij = torch.matmul(normalized_diff_i, normalized_diff_j) * (1 - 1e-6)
    angle_input = cosine_ij.unsqueeze(-1) / (torch.pi**0.5)

    if self.use_dynamic_sel:
        if edge_index is None or angle_index is None:
            raise RuntimeError(
                "Dynamic flat forward requires precomputed edge_index and angle_index."
            )
        # Flatten dynamic-selection tensors to match graph indices.
        edge_input = edge_input[nlist_mask]
        h2 = h2[nlist_mask]
        sw = sw[nlist_mask]
        # An angle (i; j, k) is kept only when both neighbor slots are valid.
        a_nlist_mask_2d = a_nlist_mask[:, :, None] & a_nlist_mask[:, None, :]
        angle_input = angle_input[a_nlist_mask_2d]
        a_sw = (a_sw[:, :, None] * a_sw[:, None, :])[a_nlist_mask_2d]
    else:
        # Zero-filled placeholders so the layers always receive tensors.
        edge_index = torch.zeros([2, 1], device=nlist.device, dtype=nlist.dtype)
        angle_index = torch.zeros([3, 1], device=nlist.device, dtype=nlist.dtype)

    # Edge and angle embeddings.
    if not self.edge_init_use_dist:
        edge_ebd = self.act(self.edge_embd(edge_input))
    else:
        edge_ebd = self.edge_embd(edge_input)
    angle_ebd = self.angle_embd(angle_input)

    # RepFlowLayer expects batched tensors. Use a synthetic one-frame batch
    # while preserving flattened dynamic edge and angle tensors.
    node_ebd_batched = node_ebd.unsqueeze(0)  # [1, nloc, n_dim]
    edge_ebd_batched = (
        edge_ebd.unsqueeze(0) if not self.use_dynamic_sel else edge_ebd
    )
    h2_batched = h2.unsqueeze(0) if not self.use_dynamic_sel else h2
    angle_ebd_batched = (
        angle_ebd.unsqueeze(0) if not self.use_dynamic_sel else angle_ebd
    )
    # Masked neighbor slots are pointed at index 0 so they stay in range.
    nlist_safe = torch.where(nlist_mask, nlist, torch.zeros_like(nlist))
    a_nlist_safe = torch.where(a_nlist_mask, a_nlist, torch.zeros_like(a_nlist))
    nlist_batched = nlist_safe.unsqueeze(0)  # [1, nloc, nnei]
    nlist_mask_batched = nlist_mask.unsqueeze(0)  # [1, nloc, nnei]
    sw_batched = sw.unsqueeze(0) if not self.use_dynamic_sel else sw
    a_nlist_batched = a_nlist_safe.unsqueeze(0)  # [1, nloc, a_nnei]
    a_nlist_mask_batched = a_nlist_mask.unsqueeze(0)  # [1, nloc, a_nnei]
    a_sw_batched = a_sw.unsqueeze(0) if not self.use_dynamic_sel else a_sw

    for ll in self.layers:
        # Flat precomputed graphs already use local atom indexing here.
        # The central-atom embeddings therefore double as the "extended"
        # node embeddings handed to each layer.
        node_ebd_ext_batched = node_ebd_batched

        node_ebd_batched, edge_ebd_batched, angle_ebd_batched = ll.forward(
            node_ebd_ext_batched,
            edge_ebd_batched,
            h2_batched,
            angle_ebd_batched,
            nlist_batched,
            nlist_mask_batched,
            sw_batched,
            a_nlist_batched,
            a_nlist_mask_batched,
            a_sw_batched,
            edge_index=edge_index,
            angle_index=angle_index,
        )

    # Rotation matrix from final edge representation.
    if self.use_dynamic_sel:
        # edge_index[0] gives the owning atom of each flattened edge.
        h2g2 = RepFlowLayer._cal_hg_dynamic(
            edge_ebd_batched,
            h2_batched,
            sw_batched,
            owner=edge_index[0],
            num_owner=nloc,
            nb=1,
            nloc=nloc,
            scale_factor=(self.nnei / self.sel_reduce_factor) ** (-0.5),
        ).squeeze(0)
    else:
        # Use batched versions for _cal_hg, then squeeze
        h2g2 = RepFlowLayer._cal_hg(
            edge_ebd_batched,
            h2_batched,
            nlist_mask_batched,
            sw_batched,
        )
        h2g2 = h2g2.squeeze(0)  # Remove batch dimension

    # Remove batch dimension from outputs
    # (dynamic-selection tensors never received one, so they pass through).
    node_ebd = node_ebd_batched.squeeze(0)
    edge_ebd = (
        edge_ebd_batched.squeeze(0)
        if not self.use_dynamic_sel
        else edge_ebd_batched
    )
    h2 = h2_batched.squeeze(0) if not self.use_dynamic_sel else h2_batched
    sw = sw_batched.squeeze(0) if not self.use_dynamic_sel else sw_batched

    # [nloc, e_dim, 3]
    rot_mat = torch.permute(h2g2, (0, 2, 1))

    return node_ebd, edge_ebd, h2, rot_mat, sw
_get_standard_model_components(model_params: dict, ntypes: int) -> tuple: if "type_embedding" in model_params: @@ -92,9 +116,13 @@ def _get_standard_model_components(model_params: dict, ntypes: int) -> tuple: # descriptor model_params["descriptor"]["ntypes"] = ntypes model_params["descriptor"]["type_map"] = copy.deepcopy(model_params["type_map"]) + _set_default_descriptor_init_seed( + model_params["descriptor"], DEFAULT_DESCRIPTOR_INIT_SEED + ) descriptor = BaseDescriptor(**model_params["descriptor"]) # fitting fitting_net = model_params.get("fitting_net", {}) + _set_default_init_seed(fitting_net, DEFAULT_FITTING_INIT_SEED) fitting_net["type"] = fitting_net.get("type", "ener") fitting_net["ntypes"] = descriptor.get_ntypes() fitting_net["type_map"] = copy.deepcopy(model_params["type_map"]) diff --git a/deepmd/pt/model/model/ener_model.py b/deepmd/pt/model/model/ener_model.py index 28387553fb..4eb32a0807 100644 --- a/deepmd/pt/model/model/ener_model.py +++ b/deepmd/pt/model/model/ener_model.py @@ -73,30 +73,44 @@ def forward( aparam: torch.Tensor | None = None, do_atomic_virial: bool = False, charge_spin: torch.Tensor | None = None, + mixed_batch: dict[str, torch.Tensor] | None = None, ) -> dict[str, torch.Tensor]: - model_ret = self.forward_common( - coord, - atype, - box, - fparam=fparam, - aparam=aparam, - do_atomic_virial=do_atomic_virial, - charge_spin=charge_spin, - ) + if not torch.jit.is_scripting() and mixed_batch is not None: + model_ret = self.forward_common_flat( + coord=coord, + atype=atype, + box=box, + fparam=fparam, + aparam=aparam, + charge_spin=charge_spin, + do_atomic_virial=do_atomic_virial, + mixed_batch=mixed_batch, + ) + else: + model_ret = self.forward_common( + coord, + atype, + box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + charge_spin=charge_spin, + ) if self.get_fitting_net() is not None: model_predict = {} model_predict["atom_energy"] = model_ret["energy"] model_predict["energy"] = model_ret["energy_redu"] 
if self.do_grad_r("energy"): model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2) + else: + if "dforce" in model_ret: + model_predict["force"] = model_ret["dforce"] if self.do_grad_c("energy"): model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) if do_atomic_virial: model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze( -2 ) - else: - model_predict["force"] = model_ret["dforce"] if "mask" in model_ret: model_predict["mask"] = model_ret["mask"] if self._hessian_enabled: @@ -104,6 +118,7 @@ def forward( else: model_predict = model_ret model_predict["updated_coord"] += coord + return model_predict @torch.jit.export diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index 465ff0af19..dd8b7ce066 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -216,6 +216,253 @@ def forward_common( model_predict = self._output_type_cast(model_predict, input_prec) return model_predict + def forward_common_flat_native( + self, + coord: torch.Tensor, + atype: torch.Tensor, + batch: torch.Tensor, + ptr: torch.Tensor, + box: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + charge_spin: torch.Tensor | None = None, + extended_atype: torch.Tensor | None = None, + extended_batch: torch.Tensor | None = None, + extended_image: torch.Tensor | None = None, + extended_ptr: torch.Tensor | None = None, + mapping: torch.Tensor | None = None, + central_ext_index: torch.Tensor | None = None, + nlist: torch.Tensor | None = None, + nlist_ext: torch.Tensor | None = None, + a_nlist: torch.Tensor | None = None, + a_nlist_ext: torch.Tensor | None = None, + nlist_mask: torch.Tensor | None = None, + a_nlist_mask: torch.Tensor | None = None, + edge_index: torch.Tensor | None = None, + angle_index: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + """Forward pass for mixed-nloc batches with a 
precomputed flat graph.""" + if do_atomic_virial: + raise NotImplementedError( + "Atomic virial is not implemented for flat mixed-batch forward." + ) + coord, box, fparam, aparam, input_prec = self._input_type_cast( + coord, box=box, fparam=fparam, aparam=aparam + ) + if self.do_grad_r("energy"): + coord = coord.clone().detach().requires_grad_(True) + if self.do_grad_c("energy") and box is not None: + box = box.clone().detach().requires_grad_(True) + if ( + extended_atype is not None + and extended_batch is not None + and extended_image is not None + and mapping is not None + and nlist is not None + and nlist_ext is not None + and a_nlist is not None + and a_nlist_ext is not None + and nlist_mask is not None + and a_nlist_mask is not None + and central_ext_index is not None + ): + from deepmd.pt.utils.nlist import ( + rebuild_extended_coord_from_flat_graph, + ) + + extended_coord = rebuild_extended_coord_from_flat_graph( + coord, + box, + mapping, + extended_batch, + extended_image, + ) + else: + raise RuntimeError( + "Flat mixed-batch forward requires precomputed graph fields from " + "the LMDB collate_fn." 
+ ) + assert extended_atype is not None + assert extended_batch is not None + assert mapping is not None + assert nlist is not None + model_predict_lower = self.forward_common_lower_flat( + extended_coord, + extended_atype, + extended_batch, + nlist, + mapping, + batch, + ptr, + do_atomic_virial=do_atomic_virial, + fparam=fparam, + aparam=aparam, + charge_spin=charge_spin, + extended_ptr=extended_ptr, + central_ext_index=central_ext_index, + nlist_ext=nlist_ext, + a_nlist=a_nlist, + a_nlist_ext=a_nlist_ext, + nlist_mask=nlist_mask, + a_nlist_mask=a_nlist_mask, + edge_index=edge_index, + angle_index=angle_index, + ) + if self.do_grad_r("energy") or self.do_grad_c("energy"): + model_predict_lower = self._compute_derivatives_flat( + model_predict_lower, + extended_coord, + extended_atype, + extended_batch, + coord, + atype, + batch, + ptr, + box, + do_atomic_virial, + ) + return self._output_type_cast(model_predict_lower, input_prec) + + def forward_common_lower_flat( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + extended_batch: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor, + batch: torch.Tensor, + ptr: torch.Tensor, + do_atomic_virial: bool = False, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + charge_spin: torch.Tensor | None = None, + extended_ptr: torch.Tensor | None = None, + central_ext_index: torch.Tensor | None = None, + nlist_ext: torch.Tensor | None = None, + a_nlist: torch.Tensor | None = None, + a_nlist_ext: torch.Tensor | None = None, + nlist_mask: torch.Tensor | None = None, + a_nlist_mask: torch.Tensor | None = None, + edge_index: torch.Tensor | None = None, + angle_index: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + model_ret = self.atomic_model.forward_common_atomic_flat( + extended_coord, + extended_atype, + extended_batch, + nlist, + mapping, + batch, + ptr, + fparam=fparam, + aparam=aparam, + charge_spin=charge_spin, + extended_ptr=extended_ptr, + 
central_ext_index=central_ext_index, + nlist_ext=nlist_ext, + a_nlist=a_nlist, + a_nlist_ext=a_nlist_ext, + nlist_mask=nlist_mask, + a_nlist_mask=a_nlist_mask, + edge_index=edge_index, + angle_index=angle_index, + ) + nframes = ptr.numel() - 1 + if "energy" in model_ret: + energy_atomic = model_ret["energy"] + energy_redu = energy_atomic.new_zeros( + (nframes, energy_atomic.shape[-1]) + ) + energy_redu.index_add_(0, batch, energy_atomic) + model_ret["energy_redu"] = energy_redu + return model_ret + + def _compute_derivatives_flat( + self, + fit_ret: dict[str, torch.Tensor], + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + extended_batch: torch.Tensor, + coord: torch.Tensor, + atype: torch.Tensor, + batch: torch.Tensor, + ptr: torch.Tensor, + box: torch.Tensor | None, + do_atomic_virial: bool, + ) -> dict[str, torch.Tensor]: + if self.do_grad_r("energy"): + energy_atomic = fit_ret["energy"] + energy_derv_r = torch.autograd.grad( + outputs=energy_atomic.sum(), + inputs=coord, + create_graph=True, + retain_graph=True, + )[0] + fit_ret["energy_derv_r"] = -energy_derv_r.unsqueeze(-2) + fit_ret["dforce"] = -energy_derv_r + + if self.do_grad_c("energy"): + energy_redu = fit_ret["energy_redu"] + if box is not None: + energy_derv_c_redu = torch.autograd.grad( + outputs=energy_redu.sum(), + inputs=box, + create_graph=True, + retain_graph=True, + )[0] + fit_ret["energy_derv_c_redu"] = energy_derv_c_redu.unsqueeze(1) + if do_atomic_virial: + raise NotImplementedError( + "Atomic virial is not implemented for flat mixed-batch " + "forward." 
def get_graph_index_flat(
    nlist_flat: torch.Tensor,
    a_nlist_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Derive edge and angle index tensors from a flat neighbor list.

    Edges are numbered in row-major order over the valid (non-negative)
    entries of ``nlist_flat``. Angles are enumerated in row-major order over
    the ``(atom, j, k)`` slots for which both angle-neighbor slots are valid.

    Parameters
    ----------
    nlist_flat : torch.Tensor
        Neighbor list in flat format [total_atoms, nnei]; negative entries
        are padding.
    a_nlist_mask : torch.Tensor
        Valid angle-neighbor mask with shape [total_atoms, a_sel]. Angle
        slots are matched to the first ``a_sel`` columns of ``nlist_flat``.

    Returns
    -------
    edge_index : torch.Tensor [2, n_edge]
        ``edge_index[0]`` is the row (central atom) of each edge and
        ``edge_index[1]`` is the neighbor index stored in ``nlist_flat``.
    angle_index : torch.Tensor [3, n_angle]
        ``angle_index[0]`` is the central atom of each angle,
        ``angle_index[1]`` the edge id of slot ``j`` and ``angle_index[2]``
        the edge id of slot ``k``. An edge id is ``-1`` when the matching
        ``nlist_flat`` slot is padding.
    """
    n_atoms, n_nei = nlist_flat.shape[:2]
    n_ang = a_nlist_mask.shape[1]
    # Every index tensor shares the dtype and device of the neighbor list.
    opts = {"dtype": nlist_flat.dtype, "device": nlist_flat.device}

    has_neighbor = nlist_flat >= 0  # [total_atoms, nnei]
    # An angle needs both of its neighbor slots to be valid.
    pair_valid = a_nlist_mask.unsqueeze(2) & a_nlist_mask.unsqueeze(1)

    owner = torch.arange(n_atoms, **opts)  # [total_atoms]

    # Edges: (central atom, neighbor) for every valid slot.
    edge_index = torch.stack(
        [
            owner.view(-1, 1).expand(-1, n_nei)[has_neighbor],
            nlist_flat[has_neighbor],
        ],
        dim=0,
    )

    # Lookup table from (atom, slot) to edge id; -1 marks padded slots.
    n_edge = int(has_neighbor.sum().item())
    slot_to_edge = torch.full((n_atoms, n_nei), -1, **opts)
    slot_to_edge[has_neighbor] = torch.arange(n_edge, **opts)
    # Only the first a_sel slots take part in angles.
    angle_slots = slot_to_edge[:, :n_ang]  # [total_atoms, a_sel]

    angle_index = torch.stack(
        [
            owner.view(-1, 1, 1).expand(-1, n_ang, n_ang)[pair_valid],
            angle_slots.unsqueeze(2).expand(-1, -1, n_ang)[pair_valid],
            angle_slots.unsqueeze(1).expand(-1, n_ang, -1)[pair_valid],
        ],
        dim=0,
    )
    return edge_index, angle_index
torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + """Forward pass with flat batch format. + + Parameters + ---------- + descriptor : torch.Tensor + Descriptor [total_atoms, descriptor_dim]. + atype : torch.Tensor + Atom types [total_atoms]. + batch : torch.Tensor + Frame assignment [total_atoms]. + ptr : torch.Tensor + Frame boundaries [nframes + 1]. + gr : torch.Tensor | None + Rotation matrix [total_atoms, e_dim, 3]. + g2 : torch.Tensor | None + Edge embedding. + h2 : torch.Tensor | None + Pair representation. + fparam : torch.Tensor | None + Frame parameters [nframes, ndf]. + aparam : torch.Tensor | None + Atomic parameters [total_atoms, nda]. + + Returns + ------- + result : dict[str, torch.Tensor] + Model predictions in flat atom format. Atom-wise outputs are + flattened back to ``[total_atoms, ...]`` after the regular fitting + network runs on a padded dense batch. + """ + device = descriptor.device + batch = batch.to(device=device, dtype=torch.long) + ptr = ptr.to(device=device, dtype=torch.long) + atype = atype.to(device=device) + + nframes = ptr.numel() - 1 + total_atoms = descriptor.shape[0] + atom_counts = ptr[1:] - ptr[:-1] + max_nloc = int(atom_counts.max().item()) + flat_index = torch.arange(total_atoms, dtype=torch.long, device=device) + local_index = flat_index - ptr[batch] + + descriptor_batch = torch.zeros( + (nframes, max_nloc, descriptor.shape[1]), + dtype=descriptor.dtype, + device=device, + ) + atype_batch = torch.full( + (nframes, max_nloc), + -1, + dtype=atype.dtype, + device=device, + ) + gr_batch = None + if gr is not None: + gr_batch = torch.zeros( + (nframes, max_nloc, *gr.shape[1:]), + dtype=gr.dtype, + device=device, + ) + aparam_batch = None + if aparam is not None: + aparam_batch = torch.zeros( + (nframes, max_nloc, *aparam.shape[1:]), + dtype=aparam.dtype, + device=device, + ) + + descriptor_batch[batch, local_index] = descriptor + 
atype_batch[batch, local_index] = atype + if gr is not None: + assert gr_batch is not None + gr_batch[batch, local_index] = gr + if aparam is not None: + assert aparam_batch is not None + aparam_batch[batch, local_index] = aparam + + result_batch = self.forward( + descriptor_batch, + atype_batch, + gr=gr_batch, + g2=g2, + h2=h2, + fparam=fparam, + aparam=aparam_batch, + ) + + valid_atom_mask = torch.arange( + max_nloc, dtype=torch.long, device=device + ).unsqueeze(0) < atom_counts.unsqueeze(1) + result_flat: dict[str, torch.Tensor] = {} + for key, value in result_batch.items(): + if ( + isinstance(value, torch.Tensor) + and value.dim() >= 2 + and value.shape[0] == nframes + and value.shape[1] == max_nloc + ): + result_flat[key] = value[valid_atom_mask] + else: + result_flat[key] = value + + return result_flat + # make jit happy with torch 2.0.0 exclude_types: list[int] diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 0c860eb708..6e401630e6 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -105,6 +105,7 @@ LmdbDataset, _collate_lmdb_batch, _SameNlocBatchSamplerTorch, + make_lmdb_mixed_batch_collate, ) from deepmd.pt.utils.stat import ( make_stat_input, @@ -129,6 +130,7 @@ get_optimizer_state_dict, set_optimizer_state_dict, ) +from torch.nn.parallel import DistributedDataParallel as DDP try: from torch.distributed.fsdp import ( @@ -139,7 +141,6 @@ from torch.distributed.optim import ( ZeroRedundancyOptimizer, ) -from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import ( DataLoader, ) @@ -150,6 +151,25 @@ log = logging.getLogger(__name__) +_FLAT_GRAPH_INPUT_KEYS = ( + "batch", + "ptr", + "extended_atype", + "extended_batch", + "extended_image", + "extended_ptr", + "mapping", + "central_ext_index", + "nlist", + "nlist_ext", + "a_nlist", + "a_nlist_ext", + "nlist_mask", + "a_nlist_mask", + "edge_index", + "angle_index", +) + class Trainer: def __init__( @@ -273,6 +293,7 @@ def 
get_data_loader( _training_data: DpLoaderSet | LmdbDataset, _validation_data: DpLoaderSet | LmdbDataset | None, _training_params: dict[str, Any], + _task_key: str = "Default", ) -> tuple[ DataLoader, Generator[Any, None, None], @@ -283,19 +304,62 @@ def get_data_loader( def get_dataloader_and_iter_lmdb( _data: LmdbDataset, ) -> tuple[DataLoader, Generator[Any, None, None]]: + _shuffle = _training_params.get("shuffle", True) + _seed = _training_params.get("seed", training_params.get("seed", 42)) + if _seed is None: + _seed = 42 + if _data.mixed_batch: - # TODO [mixed_batch=True]: Replace SameNlocBatchSampler with - # RandomSampler(replacement=False) + padding collate_fn. - # Changes needed: - # 1. _collate_lmdb_batch: pad coord/force/atype to max_nloc, - # add "atom_mask" bool tensor (nframes, max_nloc) - # 2. Use RandomSampler(_data, replacement=False) as sampler - # 3. Use fixed batch_size in DataLoader (not batch_sampler) - # 4. Model forward: apply atom_mask to descriptor/fitting - # 5. Loss: mask out padded atoms in force loss - raise NotImplementedError( - "mixed_batch=True training is not yet supported." + from deepmd.dpmodel.utils.lmdb_data import ( + MixedBatchSampler, ) + + model_for_graph = ( + self.model[_task_key] if self.multi_task else self.model + ) + descriptor = model_for_graph.atomic_model.descriptor + if not hasattr(descriptor, "repflows"): + raise ValueError( + "mixed_batch=True currently requires a flat-graph " + "capable descriptor, for example DPA3/RepFlow." 
+ ) + graph_config = { + "rcut": descriptor.get_rcut(), + "sel": descriptor.get_sel(), + "a_rcut": descriptor.repflows.a_rcut, + "a_sel": descriptor.repflows.a_sel, + "mixed_types": descriptor.mixed_types(), + } + + if self.world_size > 1: + from deepmd.dpmodel.utils.lmdb_data import ( + DistributedMixedBatchSampler, + ) + + _inner_sampler = DistributedMixedBatchSampler( + _data._reader, + rank=self.rank, + world_size=self.world_size, + shuffle=_shuffle, + seed=_seed, + ) + else: + _inner_sampler = MixedBatchSampler( + _data._reader, + shuffle=_shuffle, + seed=_seed, + ) + _batch_sampler = _SameNlocBatchSamplerTorch(_inner_sampler) + _dataloader = DataLoader( + _data, + batch_sampler=_batch_sampler, + num_workers=0, + collate_fn=make_lmdb_mixed_batch_collate(graph_config), + pin_memory=(DEVICE != "cpu"), + ) + _data_iter = cycle_iterator(_dataloader) + return _dataloader, _data_iter + # mixed_batch=False: group frames by nloc, each batch same nloc. # SameNlocBatchSampler yields list[int] per batch, all same nloc. # Auto batch_size is computed per-nloc-group inside the sampler. 
@@ -314,14 +378,15 @@ def get_dataloader_and_iter_lmdb( _data._reader, rank=self.rank, world_size=self.world_size, - shuffle=True, - seed=_training_params.get("seed", None), + shuffle=_shuffle, + seed=_seed, block_targets=_block_targets, ) else: _inner_sampler = SameNlocBatchSampler( _data._reader, - shuffle=True, + shuffle=_shuffle, + seed=_seed, block_targets=_block_targets, ) @@ -618,6 +683,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR: training_data[model_key], validation_data[model_key], training_params["data_dict"][model_key], + _task_key=model_key, ) training_data[model_key].print_summary( @@ -2228,6 +2294,7 @@ def save_ema_model_merged( def get_data( self, is_train: bool = True, task_key: str = "Default" ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: + if is_train: iterator = self.training_data else: @@ -2237,8 +2304,17 @@ def get_data( if iterator is None: return {}, {}, {} batch_data = next(iterator) + + # Detect mixed batch format (has 'batch' and 'ptr' fields) + is_mixed_batch = "batch" in batch_data and "ptr" in batch_data + # === Filter frames with atoms too close (training only) === - if is_train and self.min_pair_dist > 0.0 and "min_pair_dist" in batch_data: + if ( + not is_mixed_batch + and is_train + and self.min_pair_dist > 0.0 + and "min_pair_dist" in batch_data + ): min_dists = batch_data["min_pair_dist"] if isinstance(min_dists, torch.Tensor): valid_mask = min_dists.squeeze(-1) >= self.min_pair_dist @@ -2259,15 +2335,20 @@ def get_data( if isinstance(val, torch.Tensor) and val.shape[0] == n_total: batch_data[key] = val[valid_mask] for key in batch_data.keys(): - if key == "sid" or key == "fid" or key == "box" or "find_" in key: + if key == "sid" or key == "fid" or "find_" in key: + continue + # Skip batch and ptr for now, will handle them separately + elif key == "batch" or key == "ptr": continue elif not isinstance(batch_data[key], list): if batch_data[key] is not None: batch_data[key] = batch_data[key].to(DEVICE, 
non_blocking=True) else: batch_data[key] = [ - item.to(DEVICE, non_blocking=True) for item in batch_data[key] + item.to(DEVICE, non_blocking=True) if item is not None else None + for item in batch_data[key] ] + # we may need a better way to classify which are inputs and which are labels # now wrapper only supports the following inputs: input_keys = [ @@ -2279,6 +2360,13 @@ def get_data( "aparam", "charge_spin", ] + + # Mixed-nloc LMDB batches include precomputed flat-graph tensors. + if is_mixed_batch: + input_keys = input_keys + list(_FLAT_GRAPH_INPUT_KEYS) + batch_data["batch"] = batch_data["batch"].to(DEVICE, non_blocking=True) + batch_data["ptr"] = batch_data["ptr"].to(DEVICE, non_blocking=True) + input_dict = dict.fromkeys(input_keys) label_dict = {} for item_key in batch_data: diff --git a/deepmd/pt/train/wrapper.py b/deepmd/pt/train/wrapper.py index da710f4fdf..03bde44e74 100644 --- a/deepmd/pt/train/wrapper.py +++ b/deepmd/pt/train/wrapper.py @@ -172,6 +172,22 @@ def forward( fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, charge_spin: torch.Tensor | None = None, + batch: torch.Tensor | None = None, + ptr: torch.Tensor | None = None, + extended_atype: torch.Tensor | None = None, + extended_batch: torch.Tensor | None = None, + extended_image: torch.Tensor | None = None, + extended_ptr: torch.Tensor | None = None, + mapping: torch.Tensor | None = None, + central_ext_index: torch.Tensor | None = None, + nlist: torch.Tensor | None = None, + nlist_ext: torch.Tensor | None = None, + a_nlist: torch.Tensor | None = None, + a_nlist_ext: torch.Tensor | None = None, + nlist_mask: torch.Tensor | None = None, + a_nlist_mask: torch.Tensor | None = None, + edge_index: torch.Tensor | None = None, + angle_index: torch.Tensor | None = None, ) -> tuple[Any, Any, Any]: if not self.multi_task: task_key = "Default" @@ -188,6 +204,25 @@ def forward( "aparam": aparam, "charge_spin": charge_spin, } + if batch is not None and ptr is not None: + 
input_dict["mixed_batch"] = { + "batch": batch, + "ptr": ptr, + "extended_atype": extended_atype, + "extended_batch": extended_batch, + "extended_image": extended_image, + "extended_ptr": extended_ptr, + "mapping": mapping, + "central_ext_index": central_ext_index, + "nlist": nlist, + "nlist_ext": nlist_ext, + "a_nlist": a_nlist, + "a_nlist_ext": a_nlist_ext, + "nlist_mask": nlist_mask, + "a_nlist_mask": a_nlist_mask, + "edge_index": edge_index, + "angle_index": angle_index, + } has_spin = getattr(self.model[task_key], "has_spin", False) if callable(has_spin): has_spin = has_spin() @@ -207,7 +242,7 @@ def forward( model_pred = self._forward_without_loss(task_key, input_dict) return model_pred, None, None - natoms = atype.shape[-1] + natoms = atype.shape[-1] if atype.dim() > 1 else atype.shape[0] model_pred, loss, more_loss = self.loss[task_key]( input_dict, self.model[task_key], diff --git a/deepmd/pt/utils/lmdb_dataset.py b/deepmd/pt/utils/lmdb_dataset.py index b7f0e17735..21ed66bc7f 100644 --- a/deepmd/pt/utils/lmdb_dataset.py +++ b/deepmd/pt/utils/lmdb_dataset.py @@ -3,6 +3,7 @@ import logging from collections.abc import ( + Callable, Iterator, ) from typing import ( @@ -15,10 +16,14 @@ Dataset, Sampler, ) +from torch.utils.data._utils.collate import ( + collate_tensor_fn, +) from deepmd.dpmodel.utils.lmdb_data import ( LmdbDataReader, LmdbTestData, + MixedBatchSampler, SameNlocBatchSampler, collate_lmdb_frames, compute_block_targets, @@ -30,17 +35,134 @@ log = logging.getLogger(__name__) +FrameDict = dict[str, Any] +BatchDict = dict[str, Any] +GraphConfig = dict[str, Any] +MixedBatchCollate = Callable[[list[FrameDict]], BatchDict] + # Re-export for backward compatibility __all__ = [ "LmdbDataset", "LmdbTestData", "_collate_lmdb_batch", + "_collate_lmdb_mixed_batch", "is_lmdb", + "make_lmdb_mixed_batch_collate", ] +_ATOMWISE_MIXED_BATCH_KEYS = frozenset( + { + "aparam", + "atom_dos", + "atom_ener", + "atom_ener_coeff", + "atom_pref", + "atomic_weight", + 
"atype", + "coord", + "drdq", + "force", + "force_mag", + "hessian", + "spin", + } +) + + +def _collate_lmdb_mixed_batch(batch: list[FrameDict]) -> BatchDict: + """Collate mixed-nloc frames into flattened atom-wise tensors. -def _collate_lmdb_batch(batch: list[dict[str, Any]]) -> dict[str, Any]: - """Collate a list of frame dicts into a torch batch dict. + Atom-wise fields are concatenated across frames and accompanied by: + + - ``batch``: flattened atom-to-frame assignment with shape ``[total_atoms]``. + - ``ptr``: prefix-sum atom offsets with shape ``[nframes + 1]``. + + Frame-wise fields such as ``energy`` and ``box`` keep the usual batch + dimension via ``torch.stack``. The returned ``sid`` keeps the historical + LMDB collate shape, namely a CPU tensor with shape ``[1]``. + """ + with torch.device("cpu"): + atype_list = [torch.as_tensor(item["atype"]) for item in batch] + counts = torch.tensor( + [int(item.shape[0]) for item in atype_list], + dtype=torch.long, + device=torch.device("cpu"), + ) + ptr = torch.cat( + [ + torch.zeros(1, dtype=torch.long, device=counts.device), + torch.cumsum(counts, dim=0), + ], + dim=0, + ) + atom_batch = torch.repeat_interleave( + torch.arange(len(batch), dtype=torch.long, device=counts.device), + counts, + ) + + example = batch[0] + result: BatchDict = {} + for key in example: + if "find_" in key: + result[key] = batch[0][key] + elif key == "fid": + result[key] = [d[key] for d in batch] + elif key == "type": + continue + elif batch[0][key] is None: + result[key] = None + else: + with torch.device("cpu"): + tensors = [torch.as_tensor(d[key]) for d in batch] + if key in _ATOMWISE_MIXED_BATCH_KEYS: + result[key] = torch.cat(tensors, dim=0) + else: + result[key] = collate_tensor_fn(tensors) + result["batch"] = atom_batch + result["ptr"] = ptr + result["sid"] = torch.tensor([0], dtype=torch.long, device="cpu") + return result + + +def make_lmdb_mixed_batch_collate( + graph_config: GraphConfig | None = None, +) -> MixedBatchCollate: + 
"""Build a collate function for flattened mixed-nloc LMDB batches. + + When ``graph_config`` is provided, the collate function also precomputes the + extended image, neighbor lists, masks, edge index, and angle index consumed + by the flat DPA3 forward path. ``graph_config`` is expected to contain + ``rcut``, ``sel``, ``a_rcut``, ``a_sel``, and ``mixed_types``. + """ + + def collate(batch: list[FrameDict]) -> BatchDict: + result = _collate_lmdb_mixed_batch(batch) + if graph_config is None: + return result + from deepmd.pt.utils.nlist import ( + build_precomputed_flat_graph, + ) + + graph_data = build_precomputed_flat_graph( + result["coord"], + result["atype"], + result["batch"], + result["ptr"], + graph_config["rcut"], + graph_config["sel"], + graph_config["a_rcut"], + graph_config["a_sel"], + mixed_types=graph_config["mixed_types"], + box=result.get("box"), + ) + result.update(graph_data) + return result + + return collate + + +def _collate_lmdb_batch(batch: list[FrameDict]) -> BatchDict: + """Collate a list of frame dicts into a batch dict. Pre-converts per-frame numpy arrays to CPU torch tensors (zero-copy when dtype matches) and delegates stacking to the backend-agnostic @@ -48,18 +170,14 @@ def _collate_lmdb_batch(batch: list[dict[str, Any]]) -> dict[str, Any]: collate yields a torch dict (``sid`` becomes a torch tensor automatically via ``array_api_compat``). - All frames in the batch must have the same nloc (enforced by - SameNlocBatchSampler when mixed_batch=False). For mixed_batch=True, - raises NotImplementedError. + For mixed_batch=True, this function would need padding + mask. + Mixed-nloc batches are flattened atom-wise and augmented with ``batch`` and + ``ptr`` to preserve frame ownership. 
""" if len(batch) > 1: atypes = [d.get("atype") for d in batch if d.get("atype") is not None] if atypes and any(len(a) != len(atypes[0]) for a in atypes): - raise NotImplementedError( - "mixed_batch collation (frames with different atom counts " - "in the same batch) is not yet supported. " - "Padding + mask in collate_fn needed." - ) + return _collate_lmdb_mixed_batch(batch) with torch.device("cpu"): torch_frames: list[dict[str, Any]] = [] @@ -77,14 +195,14 @@ def _collate_lmdb_batch(batch: list[dict[str, Any]]) -> dict[str, Any]: class _SameNlocBatchSamplerTorch(Sampler): - """Torch Sampler adapter around the framework-agnostic SameNlocBatchSampler. + """Torch Sampler adapter around framework-agnostic LMDB batch samplers. PyTorch DataLoader with batch_sampler expects a Sampler that yields - lists of indices. This wraps SameNlocBatchSampler (or - DistributedSameNlocBatchSampler) to satisfy that. + lists of indices. This wraps SameNlocBatchSampler, MixedBatchSampler, or + their distributed variants to satisfy that. """ - def __init__(self, inner: SameNlocBatchSampler) -> None: + def __init__(self, inner: Any) -> None: self._inner = inner def __iter__(self) -> Iterator[list[int]]: @@ -113,12 +231,16 @@ class LmdbDataset(Dataset): - ``int``: fixed batch size for every nloc group. - ``"auto"`` / ``"auto:N"``: ``ceil(N / nloc)`` per nloc group - (``N=32`` for bare ``"auto"``). - - ``"max:N"``: ``max(1, floor(N / nloc))`` per nloc group. + (``N=32`` for bare ``"auto"``). With ``mixed_batch=True``, mixed + batches accumulate frames until their total atom count reaches + ``N``. + - ``"max:N"``: ``max(1, floor(N / nloc))`` per nloc group. With + ``mixed_batch=True``, mixed batches accumulate frames while their + total atom count stays at or below ``N``. - ``"filter:N"``: same per-nloc formula as ``"max:N"`` and drops every frame whose ``nloc > N`` from the dataset. mixed_batch : bool - If True, allow different nloc in the same batch (future). 
+ If True, allow different nloc in the same batch. If False (default), use SameNlocBatchSampler. """ @@ -133,13 +255,7 @@ def __init__( self._reader = LmdbDataReader( lmdb_path, type_map, batch_size, mixed_batch=mixed_batch ) - - if mixed_batch: - # Future: DataLoader with padding collate_fn - raise NotImplementedError( - "mixed_batch=True is not yet supported. " - "Requires padding + mask in collate_fn." - ) + self._batch_sampler: _SameNlocBatchSamplerTorch | None = None # Compute block_targets from auto_prob_style if provided self._block_targets = None @@ -149,27 +265,43 @@ def __init__( self._reader.nsystems, self._reader.system_nframes, ) - if self._block_targets is not None: + if self._block_targets: + if mixed_batch: + raise NotImplementedError( + "auto_prob_style/block weighting is not supported with " + "mixed_batch=True yet." + ) log.info( f"LMDB auto_prob: {len(self._block_targets)} blocks, " f"nsystems={self._reader.nsystems}" ) - # Same-nloc batching: use SameNlocBatchSampler - sampler = SameNlocBatchSampler( - self._reader, - shuffle=True, - block_targets=self._block_targets, - ) - self._batch_sampler = _SameNlocBatchSamplerTorch(sampler) - - with torch.device("cpu"): - self._inner_dataloader = DataLoader( - self, - batch_sampler=self._batch_sampler, - num_workers=0, - collate_fn=_collate_lmdb_batch, + if mixed_batch: + sampler = MixedBatchSampler(self._reader, shuffle=True) + self._batch_sampler = _SameNlocBatchSamplerTorch(sampler) + with torch.device("cpu"): + self._inner_dataloader = DataLoader( + self, + batch_sampler=self._batch_sampler, + num_workers=0, + collate_fn=_collate_lmdb_mixed_batch, + ) + else: + # Same-nloc batching: use SameNlocBatchSampler + sampler = SameNlocBatchSampler( + self._reader, + shuffle=True, + block_targets=self._block_targets, ) + self._batch_sampler = _SameNlocBatchSamplerTorch(sampler) + + with torch.device("cpu"): + self._inner_dataloader = DataLoader( + self, + batch_sampler=self._batch_sampler, + num_workers=0, + 
collate_fn=_collate_lmdb_batch, + ) # Per-nloc-group dataloaders for make_stat_input. # Each group gets its own DataLoader so torch.cat in stat collection @@ -329,7 +461,9 @@ def batch_sizes(self) -> list[int]: @property def systems(self) -> list: """One 'system' per nloc group for stat collection compatibility.""" - return [self] * len(self._nloc_dataloaders) + if self._nloc_dataloaders: + return [self] * len(self._nloc_dataloaders) + return [self] @property def dataloaders(self) -> list: @@ -338,8 +472,12 @@ def dataloaders(self) -> list: Each dataloader yields batches with uniform nloc, so torch.cat in stat collection only concatenates same-shape tensors. """ - return self._nloc_dataloaders + if self._nloc_dataloaders: + return self._nloc_dataloaders + return [self._inner_dataloader] @property def sampler_list(self) -> list: + if self._batch_sampler is None: + return [] return [self._batch_sampler] diff --git a/deepmd/pt/utils/nlist.py b/deepmd/pt/utils/nlist.py index ccea0be79c..38d9733796 100644 --- a/deepmd/pt/utils/nlist.py +++ b/deepmd/pt/utils/nlist.py @@ -10,6 +10,8 @@ to_face_distance, ) +FlatGraphData = dict[str, torch.Tensor] + def extend_input_and_build_neighbor_list( coord: torch.Tensor, @@ -46,6 +48,307 @@ def extend_input_and_build_neighbor_list( return extended_coord, extended_atype, mapping, nlist +def build_precomputed_flat_graph( + coord: torch.Tensor, + atype: torch.Tensor, + batch: torch.Tensor, + ptr: torch.Tensor, + rcut: float, + sel: list[int], + a_rcut: float, + a_sel: int, + mixed_types: bool = False, + box: torch.Tensor | None = None, +) -> FlatGraphData: + """Build graph tensors for flattened mixed-nloc LMDB batches. + + Parameters + ---------- + coord + Flattened local coordinates with shape ``[total_atoms, 3]``. + atype + Flattened local atom types with shape ``[total_atoms]``. + batch + Local atom-to-frame assignment with shape ``[total_atoms]``. + ptr + Prefix-sum local atom offsets with shape ``[nframes + 1]``. 
+ rcut, sel + Edge cutoff and neighbor selection used by the descriptor. + a_rcut, a_sel + Angle cutoff and maximum angle-neighbor count. + mixed_types + Whether neighbor selection ignores atom types. + box + Optional flattened cell tensor with shape ``[nframes, 9]``. + + Returns + ------- + FlatGraphData + Dictionary consumed by the flat forward path. ``*_ext`` neighbor lists + index into the concatenated extended atoms, while ``nlist`` and + ``a_nlist`` map neighbors back to flattened local atom indices. + """ + device = coord.device + nframes = ptr.numel() - 1 + extended_coords_list = [] + extended_atypes_list = [] + extended_batches_list = [] + extended_images_list = [] + extended_to_atom_list = [] + nlists_ext_list = [] + central_indices_list = [] + extended_ptr = torch.zeros(nframes + 1, dtype=torch.long, device=device) + extended_offset = 0 + + for frame_idx in range(nframes): + start_idx = int(ptr[frame_idx].item()) + end_idx = int(ptr[frame_idx + 1].item()) + nloc = end_idx - start_idx + frame_coord = coord[start_idx:end_idx].reshape(1, nloc, 3) + frame_atype = atype[start_idx:end_idx].reshape(1, nloc) + frame_box = box[frame_idx : frame_idx + 1] if box is not None else None + + if frame_box is not None: + box_device = frame_box.to(device, non_blocking=True) + coord_normalized = normalize_coord( + frame_coord, + box_device.reshape(1, 3, 3), + ) + else: + box_device = None + coord_normalized = frame_coord.clone() + + ( + frame_extended_coord, + frame_extended_atype, + frame_mapping, + frame_extended_image, + ) = extend_coord_with_ghosts_with_images( + coord_normalized.reshape(1, -1), + frame_atype, + box_device, + rcut, + frame_box, + ) + frame_nlist_ext = build_neighbor_list( + frame_extended_coord, + frame_extended_atype, + nloc, + rcut, + sel, + distinguish_types=(not mixed_types), + ) + + frame_extended_coord = frame_extended_coord.view(-1, 3) + frame_extended_atype = frame_extended_atype.view(-1) + frame_mapping = frame_mapping.view(-1) + 
frame_extended_image = frame_extended_image.view(-1, 3) + frame_nlist_ext = frame_nlist_ext.view(nloc, -1) + nall_frame = frame_extended_coord.shape[0] + + central_indices_list.append( + torch.arange( + extended_offset, + extended_offset + nloc, + dtype=torch.long, + device=device, + ) + ) + nlists_ext_list.append( + torch.where( + frame_nlist_ext >= 0, + frame_nlist_ext + extended_offset, + frame_nlist_ext, + ) + ) + extended_coords_list.append(frame_extended_coord) + extended_atypes_list.append(frame_extended_atype) + extended_batches_list.append( + torch.full((nall_frame,), frame_idx, dtype=torch.long, device=device) + ) + extended_images_list.append(frame_extended_image) + extended_to_atom_list.append(frame_mapping + start_idx) + extended_offset += nall_frame + extended_ptr[frame_idx + 1] = extended_offset + + extended_coord = torch.cat(extended_coords_list, dim=0) + extended_atype = torch.cat(extended_atypes_list, dim=0) + extended_batch = torch.cat(extended_batches_list, dim=0) + extended_image = torch.cat(extended_images_list, dim=0) + mapping = torch.cat(extended_to_atom_list, dim=0) + central_ext_index = torch.cat(central_indices_list, dim=0) + nlist_ext = torch.cat(nlists_ext_list, dim=0) + nlist_mask = nlist_ext >= 0 + + nall = extended_coord.shape[0] + nlist_ext_clamped = torch.clamp(nlist_ext, min=0, max=nall - 1) + nlist = torch.where( + nlist_mask, + mapping[nlist_ext_clamped], + torch.tensor(-1, dtype=nlist_ext.dtype, device=device), + ) + + coord_central = extended_coord[central_ext_index] + coord_pad = torch.cat([extended_coord, extended_coord[-1:, :] + rcut], dim=0) + nlist_safe = torch.where( + nlist_mask, + nlist_ext, + torch.tensor(nall, dtype=nlist_ext.dtype, device=device), + ) + index = nlist_safe.view(-1).unsqueeze(-1).expand(-1, 3) + coord_nei = torch.gather(coord_pad, 0, index).view(nlist_ext.shape[0], -1, 3) + dist = torch.linalg.norm(coord_nei - coord_central[:, None, :], dim=-1) + a_dist_mask = (dist[:, :a_sel] < a_rcut) & 
nlist_mask[:, :a_sel] + a_nlist_ext = torch.where( + a_dist_mask, + nlist_ext[:, :a_sel], + torch.tensor(-1, dtype=nlist_ext.dtype, device=device), + ) + a_nlist_mask = a_nlist_ext >= 0 + a_nlist_ext_clamped = torch.clamp(a_nlist_ext, min=0, max=nall - 1) + a_nlist = torch.where( + a_nlist_mask, + mapping[a_nlist_ext_clamped], + torch.tensor(-1, dtype=nlist_ext.dtype, device=device), + ) + + from deepmd.pt.model.network.graph_utils_flat import ( + get_graph_index_flat, + ) + + edge_index, angle_index = get_graph_index_flat( + nlist, + a_nlist_mask, + ) + return { + "extended_atype": extended_atype, + "extended_batch": extended_batch, + "extended_image": extended_image, + "extended_ptr": extended_ptr, + "mapping": mapping, + "central_ext_index": central_ext_index, + "nlist": nlist, + "nlist_ext": nlist_ext, + "a_nlist": a_nlist, + "a_nlist_ext": a_nlist_ext, + "nlist_mask": nlist_mask, + "a_nlist_mask": a_nlist_mask, + "edge_index": edge_index, + "angle_index": angle_index, + } + + +def rebuild_extended_coord_from_flat_graph( + coord: torch.Tensor, + box: torch.Tensor | None, + mapping: torch.Tensor, + extended_batch: torch.Tensor, + extended_image: torch.Tensor, +) -> torch.Tensor: + """Reconstruct extended coordinates from precomputed flat graph metadata. + + ``mapping`` maps each extended atom to its source local atom. When ``box`` + is available, ``extended_image`` is applied after wrapping the source local + coordinate back into the corresponding periodic cell. 
+ """ + if box is None: + return coord[mapping] + cell = box.reshape(-1, 3, 3) + atom_cell = cell[extended_batch] + rec_cell, _ = torch.linalg.inv_ex(atom_cell) + coord_inter = torch.einsum("ni,nij->nj", coord[mapping], rec_cell) + coord_wrapped = torch.einsum( + "ni,nij->nj", + torch.remainder(coord_inter, 1.0), + atom_cell, + ) + image = extended_image.to(dtype=box.dtype, device=box.device) + shift_vec = torch.einsum("ni,nij->nj", image, atom_cell) + return coord_wrapped + shift_vec + + +def get_central_ext_index( + extended_batch: torch.Tensor, + ptr: torch.Tensor, +) -> torch.Tensor: + """Return extended-atom indices corresponding to local atoms.""" + nframes = ptr.numel() - 1 + extended_counts = torch.bincount(extended_batch, minlength=nframes) + extended_ptr = torch.cat( + [ + torch.zeros(1, dtype=extended_counts.dtype, device=extended_counts.device), + torch.cumsum(extended_counts, dim=0), + ] + ) + extended_index = torch.arange( + extended_batch.shape[0], + dtype=extended_batch.dtype, + device=extended_batch.device, + ) + frame_local_index = extended_index - extended_ptr[extended_batch] + nloc_per_frame = (ptr[1:] - ptr[:-1]).to(extended_batch.device) + central_mask = frame_local_index < nloc_per_frame[extended_batch] + return torch.nonzero(central_mask, as_tuple=False).view(-1) + + +def extend_input_and_build_neighbor_list_with_images( + coord: torch.Tensor, + atype: torch.Tensor, + rcut: float, + sel: list[int], + mixed_types: bool = False, + box: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Like ``extend_input_and_build_neighbor_list`` but also returns lattice images. + + This helper is intended for sidecar graph precomputation workflows that need a + stable, replayable description of how extended atoms are generated without + changing the existing training path. + + Returns + ------- + extended_coord + Extended coordinates with shape ``[nf, nall, 3]``. 
+ extended_atype + Extended atom types with shape ``[nf, nall]``. + mapping + Extended atom -> local atom index mapping with shape ``[nf, nall]``. + extended_image + Integer lattice image for each extended atom with shape ``[nf, nall, 3]``. + nlist + Neighbor list with shape ``[nf, nloc, nnei]``. + """ + nframes, nloc = atype.shape[:2] + if box is not None: + box_gpu = box.to(coord.device, non_blocking=True) + coord_normalized = normalize_coord( + coord.view(nframes, nloc, 3), + box_gpu.reshape(nframes, 3, 3), + ) + else: + box_gpu = None + coord_normalized = coord.clone() + extended_coord, extended_atype, mapping, extended_image = ( + extend_coord_with_ghosts_with_images( + coord_normalized, + atype, + box_gpu, + rcut, + box, + ) + ) + nlist = build_neighbor_list( + extended_coord, + extended_atype, + nloc, + rcut, + sel, + distinguish_types=(not mixed_types), + ) + extended_coord = extended_coord.view(nframes, -1, 3) + return extended_coord, extended_atype, mapping, extended_image, nlist + + def build_neighbor_list( coord: torch.Tensor, atype: torch.Tensor, @@ -455,9 +758,54 @@ def extend_coord_with_ghosts( mapping extended index to the local index """ + extend_coord, extend_atype, extend_aidx, _ = _extend_coord_with_ghosts_impl( + coord, + atype, + cell, + rcut, + cell_cpu=cell_cpu, + return_image=False, + ) + return extend_coord, extend_atype, extend_aidx + + +def extend_coord_with_ghosts_with_images( + coord: torch.Tensor, + atype: torch.Tensor, + cell: torch.Tensor | None, + rcut: float, + cell_cpu: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Extend coordinates and additionally return the integer lattice image. + + The returned image tensor records which periodic image each extended atom + belongs to. This is useful for sidecar graph serialization where extended + coordinates should be recoverable from the original local coordinates and + the simulation cell. 
+ """ + extend_coord, extend_atype, extend_aidx, extend_image = ( + _extend_coord_with_ghosts_impl( + coord, + atype, + cell, + rcut, + cell_cpu=cell_cpu, + return_image=True, + ) + ) + return extend_coord, extend_atype, extend_aidx, extend_image + + +def _extend_coord_with_ghosts_impl( + coord: torch.Tensor, + atype: torch.Tensor, + cell: torch.Tensor | None, + rcut: float, + cell_cpu: torch.Tensor | None = None, + return_image: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: device = coord.device nf, nloc = atype.shape - # int64 for index aidx = torch.tile( torch.arange(nloc, device=device, dtype=torch.int64).unsqueeze(0), [nf, 1] ) @@ -466,18 +814,17 @@ def extend_coord_with_ghosts( extend_coord = coord.clone() extend_atype = atype.clone() extend_aidx = aidx.clone() + if return_image: + extend_image = torch.zeros((nf, nloc, 3), device=device, dtype=torch.int64) + else: + extend_image = torch.empty((0,), device=device, dtype=torch.int64) else: coord = coord.view([nf, nloc, 3]) cell = cell.view([nf, 3, 3]) cell_cpu = cell_cpu.view([nf, 3, 3]) if cell_cpu is not None else cell - # nf x 3 to_face = to_face_distance(cell_cpu) - # nf x 3 - # *2: ghost copies on + and - directions - # +1: central cell nbuff = torch.ceil(rcut / to_face).to(torch.int64) - # 3 - nbuff = torch.amax(nbuff, dim=0) # faster than torch.max + nbuff = torch.amax(nbuff, dim=0) nbuff_cpu = nbuff.cpu() xi = torch.arange( -nbuff_cpu[0], nbuff_cpu[0] + 1, 1, device="cpu", dtype=torch.int64 @@ -494,20 +841,24 @@ def extend_coord_with_ghosts( xyz = xyz + zi.view(1, 1, -1, 1) * eye_3[2] xyz = xyz.view(-1, 3) xyz = xyz.to(device=device, non_blocking=True) - # ns x 3 shift_idx = xyz[torch.argsort(torch.linalg.norm(xyz, dim=-1))] + # Convert shift_idx to the same dtype as cell to avoid type mismatch + shift_idx = shift_idx.to(dtype=cell.dtype) ns, _ = shift_idx.shape nall = ns * nloc - # nf x ns x 3 shift_vec = torch.einsum("sd,fdk->fsk", shift_idx, cell) - # nf x ns x 
nloc x 3 extend_coord = coord[:, None, :, :] + shift_vec[:, :, None, :] - # nf x ns x nloc extend_atype = torch.tile(atype.unsqueeze(-2), [1, ns, 1]) - # nf x ns x nloc extend_aidx = torch.tile(aidx.unsqueeze(-2), [1, ns, 1]) - return ( - extend_coord.reshape([nf, nall * 3]).to(device), - extend_atype.view([nf, nall]).to(device), - extend_aidx.view([nf, nall]).to(device), - ) + if return_image: + extend_image = torch.tile(shift_idx.view(1, ns, 1, 3), [nf, 1, nloc, 1]) + else: + extend_image = torch.empty((0,), device=device, dtype=torch.int64) + extend_coord_out = extend_coord.reshape([nf, nall * 3]).to(device) + extend_atype_out = extend_atype.view([nf, nall]).to(device) + extend_aidx_out = extend_aidx.view([nf, nall]).to(device) + if return_image: + extend_image_out = extend_image.view([nf, nall, 3]).to(device) + else: + extend_image_out = extend_image + return extend_coord_out, extend_atype_out, extend_aidx_out, extend_image_out diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index ba1a2cb347..b859d27ea2 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -4770,6 +4770,16 @@ def training_data_args() -> list[ "systems. To avoid the overhead, consider pre-cleaning the dataset instead." ) + doc_mixed_batch = ( + "Whether to enable LMDB mixed-batch training with different numbers of atoms " + "per frame. When set to True, the PyTorch LMDB dataloader flattens atom-wise " + "fields and precomputes graph indices in the collate function. In this mode, " + "integer `batch_size` is the number of frames/systems per batch, while " + "`auto:N`, `max:N`, and `filter:N` use the total atom count of the mixed " + "batch as the budget. " + "The alias `mix_batch` is accepted. Default is False." 
+ ) + args = [ Argument( "systems", [list[str], str], optional=False, default=".", doc=doc_systems @@ -4806,6 +4816,14 @@ def training_data_args() -> list[ doc=doc_sys_probs, alias=["sys_weights"], ), + Argument( + "mixed_batch", + bool, + optional=True, + default=False, + alias=["mix_batch"], + doc=doc_mixed_batch + doc_only_pt_supported, + ), Argument( "min_pair_dist", float, @@ -4853,6 +4871,15 @@ def validation_data_args() -> list[ - "prob_sys_size;stt_idx:end_idx:weight;stt_idx:end_idx:weight;..." : the list of systems is divided into blocks. A block is specified by `stt_idx:end_idx:weight`, where `stt_idx` is the starting index of the system, `end_idx` is then ending (not including) index of the system, the probabilities of the systems in this block sums up to `weight`, and the relatively probabilities within this block is proportional to the number of batches in the system.' doc_sys_probs = "A list of float if specified. Should be of the same length as `systems`, specifying the probability of each system." doc_numb_btch = "An integer that specifies the number of batches to be sampled for each validation period." + doc_mixed_batch = ( + "Whether to enable LMDB mixed-batch validation with different numbers of atoms " + "per frame. When set to True, the PyTorch LMDB dataloader flattens atom-wise " + "fields and precomputes graph indices in the collate function. In this mode, " + "integer `batch_size` is the number of frames/systems per batch, while " + "`auto:N`, `max:N`, and `filter:N` use the total atom count of the mixed " + "batch as the budget. " + "The alias `mix_batch` is accepted. Default is False." + ) args = [ Argument( @@ -4900,6 +4927,14 @@ def validation_data_args() -> list[ "numb_batch", ], ), + Argument( + "mixed_batch", + bool, + optional=True, + default=False, + alias=["mix_batch"], + doc=doc_mixed_batch + doc_only_pt_supported, + ), ] doc_validation_data = "Configurations of validation data. 
Similar to that of training data, except that a `numb_btch` argument may be configured" diff --git a/examples/water/dpa3/input_torch_lmdb_mixed_auto.json b/examples/water/dpa3/input_torch_lmdb_mixed_auto.json new file mode 100644 index 0000000000..beaed25696 --- /dev/null +++ b/examples/water/dpa3/input_torch_lmdb_mixed_auto.json @@ -0,0 +1,93 @@ +{ + "_comment": "DPA3 PyTorch LMDB mixed-batch example using auto atom-count batching.", + "model": { + "type_map": [ + "O", + "H" + ], + "descriptor": { + "type": "dpa3", + "repflow": { + "n_dim": 128, + "e_dim": 64, + "a_dim": 32, + "nlayers": 6, + "e_rcut": 6.0, + "e_rcut_smth": 5.3, + "e_sel": 120, + "a_rcut": 4.0, + "a_rcut_smth": 3.5, + "a_sel": 40, + "axis_neuron": 4, + "fix_stat_std": 0.3, + "a_compress_rate": 1, + "a_compress_e_rate": 2, + "a_compress_use_split": true, + "update_angle": true, + "smooth_edge_update": true, + "edge_init_use_dist": true, + "use_exp_switch": true, + "update_style": "res_residual", + "update_residual": 0.1, + "update_residual_init": "const" + }, + "activation_function": "silut:10.0", + "use_tebd_bias": false, + "precision": "float32", + "concat_output_tebd": false, + "seed": 1 + }, + "fitting_net": { + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "precision": "float32", + "activation_function": "silut:10.0", + "seed": 1 + } + }, + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.001, + "stop_lr": 3e-5 + }, + "loss": { + "type": "ener", + "start_pref_e": 0.2, + "limit_pref_e": 20, + "start_pref_f": 100, + "limit_pref_f": 60, + "start_pref_v": 0.02, + "limit_pref_v": 1 + }, + "optimizer": { + "type": "AdamW", + "adam_beta1": 0.9, + "adam_beta2": 0.999, + "weight_decay": 0.001 + }, + "training": { + "stat_file": "./dpa3_lmdb_mixed_auto.hdf5", + "training_data": { + "systems": "../../lmdb_downsample_data/water_training.lmdb", + "batch_size": "auto:128", + "mixed_batch": true, + "_comment": "With mixed_batch=true, auto:128 accumulates frames until 
the total atom count first reaches or exceeds 128." + }, + "validation_data": { + "systems": "../../lmdb_downsample_data/water_validation.lmdb", + "batch_size": "auto:128", + "mixed_batch": true + }, + "numb_steps": 100, + "gradient_max_norm": 5.0, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 10, + "save_freq": 100 + } +} diff --git a/source/tests/common/dpmodel/test_lmdb_data.py b/source/tests/common/dpmodel/test_lmdb_data.py index 0838cca0ab..718b92cd1c 100644 --- a/source/tests/common/dpmodel/test_lmdb_data.py +++ b/source/tests/common/dpmodel/test_lmdb_data.py @@ -15,6 +15,7 @@ LmdbDataReader, LmdbTestData, LmdbTestDataNlocView, + MixedBatchSampler, SameNlocBatchSampler, _expand_indices_by_blocks, compute_block_targets, @@ -960,21 +961,76 @@ def test_invalid_batch_size_strings_rejected(self): LmdbDataReader(self._uniform_path, self._type_map, batch_size=spec) self.assertIn("positive", str(ctx.exception)) - def test_filter_with_mixed_batch_rejected(self): - """``filter:N`` + ``mixed_batch=True`` must fail loudly. + def test_filter_with_mixed_batch_drops_large_frames(self): + """``filter:N`` + ``mixed_batch=True`` drops oversized frames.""" + reader = LmdbDataReader( + self._mixed_path, + self._type_map, + batch_size="filter:10", + mixed_batch=True, + ) + self.assertEqual(len(reader), 8) + self.assertEqual(reader.frame_nlocs, [6, 6, 6, 6, 9, 9, 9, 9]) + self.assertEqual(reader._retained_keys, [0, 1, 2, 3, 4, 5, 6, 7]) - The mixed-batch fast path skips the per-frame nloc scan, so - filter:N cannot honour its documented ``nloc > N`` drop. 
- """ - with self.assertRaises(ValueError) as ctx: - LmdbDataReader( - self._mixed_path, - self._type_map, - batch_size="filter:10", - mixed_batch=True, - ) - self.assertIn("filter", str(ctx.exception)) - self.assertIn("mixed_batch", str(ctx.exception)) + def test_mixed_batch_sampler_auto_stops_after_reaching_budget(self): + """``auto:N`` closes the mixed batch once total atoms reaches N.""" + reader = LmdbDataReader( + self._mixed_path, + self._type_map, + batch_size="auto:20", + mixed_batch=True, + ) + batches = list(MixedBatchSampler(reader, shuffle=False)) + self.assertEqual(batches, [[0, 1, 2, 3], [4, 5, 6], [7, 8], [9]]) + self.assertEqual( + [[reader.frame_nlocs[idx] for idx in batch] for batch in batches], + [[6, 6, 6, 6], [9, 9, 9], [9, 12], [12]], + ) + + def test_mixed_batch_sampler_max_respects_budget_when_possible(self): + """``max:N`` closes before the next frame would exceed N.""" + reader = LmdbDataReader( + self._mixed_path, + self._type_map, + batch_size="max:20", + mixed_batch=True, + ) + batches = list(MixedBatchSampler(reader, shuffle=False)) + self.assertEqual(batches, [[0, 1, 2], [3, 4], [5, 6], [7], [8], [9]]) + for batch in batches: + total_atoms = sum(reader.frame_nlocs[idx] for idx in batch) + self.assertLessEqual(total_atoms, 20) + + def test_mixed_batch_sampler_max_keeps_oversized_single_frame(self): + """``max:N`` keeps a single oversized frame instead of dropping it.""" + reader = LmdbDataReader( + self._mixed_path, + self._type_map, + batch_size="max:10", + mixed_batch=True, + ) + batches = list(MixedBatchSampler(reader, shuffle=False)) + self.assertEqual(batches, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]) + oversized = [ + batch + for batch in batches + if sum(reader.frame_nlocs[idx] for idx in batch) > 10 + ] + self.assertEqual(oversized, [[8], [9]]) + + def test_mixed_batch_sampler_filter_respects_budget(self): + """``filter:N`` removes oversized frames then applies max-style packing.""" + reader = LmdbDataReader( + 
self._mixed_path, + self._type_map, + batch_size="filter:10", + mixed_batch=True, + ) + batches = list(MixedBatchSampler(reader, shuffle=False)) + self.assertEqual(batches, [[0], [1], [2], [3], [4], [5], [6], [7]]) + for batch in batches: + self.assertLessEqual(sum(reader.frame_nlocs[idx] for idx in batch), 10) def test_auto_prob_with_filter_still_works(self): """compute_block_targets + sampler survive a fully-dropped block.""" diff --git a/source/tests/consistent/test_lmdb_data.py b/source/tests/consistent/test_lmdb_data.py index 6e9cecee52..c641e66aec 100644 --- a/source/tests/consistent/test_lmdb_data.py +++ b/source/tests/consistent/test_lmdb_data.py @@ -19,6 +19,7 @@ try: from deepmd.pt.utils.lmdb_dataset import ( LmdbDataset, + _collate_lmdb_mixed_batch, ) INSTALLED_PT = True @@ -222,6 +223,48 @@ def test_mixed_nloc_same_properties(self): self.assertEqual(reader.mixed_batch, ds.mixed_batch) self.assertFalse(reader.mixed_batch) + def test_mixed_batch_same_frame_data(self): + """Reader and dataset produce identical frames when mixed_batch is enabled.""" + with tempfile.TemporaryDirectory() as tmpdir: + path = _create_mixed_nloc_lmdb(f"{tmpdir}/mixed.lmdb") + reader = LmdbDataReader( + path, self._type_map, batch_size=2, mixed_batch=True + ) + ds = LmdbDataset(path, self._type_map, batch_size=2, mixed_batch=True) + self.assertTrue(reader.mixed_batch) + self.assertTrue(ds.mixed_batch) + for i in range(len(reader)): + _assert_frames_equal(self, reader[i], ds[i], i) + + def test_mixed_batch_collate_matches_reader_concatenation(self): + """Mixed-batch PT collation preserves dpmodel reader frame order.""" + with tempfile.TemporaryDirectory() as tmpdir: + path = _create_mixed_nloc_lmdb(f"{tmpdir}/mixed.lmdb") + reader = LmdbDataReader( + path, self._type_map, batch_size=3, mixed_batch=True + ) + ds = LmdbDataset(path, self._type_map, batch_size=3, mixed_batch=True) + frame_ids = [0, 1, 4] + batch = _collate_lmdb_mixed_batch([ds[ii] for ii in frame_ids]) + + 
np.testing.assert_array_equal( + batch["coord"].numpy(), + np.concatenate([reader[ii]["coord"] for ii in frame_ids], axis=0), + ) + np.testing.assert_array_equal( + batch["atype"].numpy(), + np.concatenate([reader[ii]["atype"] for ii in frame_ids], axis=0), + ) + np.testing.assert_array_equal( + batch["force"].numpy(), + np.concatenate([reader[ii]["force"] for ii in frame_ids], axis=0), + ) + expected_ptr = [0] + for ii in frame_ids: + expected_ptr.append(expected_ptr[-1] + reader[ii]["coord"].shape[0]) + self.assertEqual(batch["ptr"].tolist(), expected_ptr) + self.assertEqual(batch["batch"].tolist(), [0] * 6 + [1] * 6 + [2] * 9) + if __name__ == "__main__": unittest.main() diff --git a/source/tests/pt/test_lmdb_dataloader.py b/source/tests/pt/test_lmdb_dataloader.py index a51b75fff5..5306b34eff 100644 --- a/source/tests/pt/test_lmdb_dataloader.py +++ b/source/tests/pt/test_lmdb_dataloader.py @@ -27,6 +27,10 @@ from deepmd.pt.utils.lmdb_dataset import ( LmdbDataset, _collate_lmdb_batch, + make_lmdb_mixed_batch_collate, +) +from deepmd.utils.argcheck import ( + training_data_args, ) from deepmd.utils.data import ( DataRequirementItem, @@ -211,6 +215,16 @@ def test_mixed_type(self, lmdb_dir): ds = LmdbDataset(lmdb_dir, type_map=["O", "H"], batch_size=2) assert ds.mixed_type is True + def test_mixed_batch_init(self, multi_nloc_lmdb): + ds = LmdbDataset( + multi_nloc_lmdb, + type_map=["O", "H"], + batch_size=3, + mixed_batch=True, + ) + assert ds.mixed_batch is True + assert len(ds.dataloaders) == 1 + # ============================================================ # Trainer compatibility interface @@ -356,6 +370,165 @@ def test_collate_none_values(self): ] assert _collate_lmdb_batch(frames)["box"] is None + def test_collate_mixed_nloc_flattens_atomwise(self): + rng = np.random.default_rng(7) + frames = [ + { + "coord": rng.standard_normal((2, 3)), + "atype": np.array([0, 1], dtype=np.int64), + "force": rng.standard_normal((2, 3)), + "atom_ener": rng.standard_normal((2, 
1)), + "drdq": rng.standard_normal((2, 6)), + "energy": np.array([1.0]), + "box": np.arange(9, dtype=np.float64), + "find_energy": 1.0, + "fid": 3, + }, + { + "coord": rng.standard_normal((3, 3)), + "atype": np.array([1, 0, 1], dtype=np.int64), + "force": rng.standard_normal((3, 3)), + "atom_ener": rng.standard_normal((3, 1)), + "drdq": rng.standard_normal((3, 6)), + "energy": np.array([2.0]), + "box": np.arange(9, dtype=np.float64) + 10.0, + "find_energy": 1.0, + "fid": 9, + }, + ] + batch = _collate_lmdb_batch(frames) + assert batch["coord"].shape == (5, 3) + assert batch["atype"].shape == (5,) + assert batch["force"].shape == (5, 3) + assert batch["atom_ener"].shape == (5, 1) + assert batch["drdq"].shape == (5, 6) + assert batch["energy"].shape == (2, 1) + assert batch["box"].shape == (2, 9) + torch.testing.assert_close( + batch["batch"], torch.tensor([0, 0, 1, 1, 1], device="cpu") + ) + torch.testing.assert_close(batch["ptr"], torch.tensor([0, 2, 5], device="cpu")) + torch.testing.assert_close( + batch["coord"], + torch.as_tensor( + np.concatenate([frame["coord"] for frame in frames]), device="cpu" + ), + ) + torch.testing.assert_close( + batch["atype"], + torch.as_tensor( + np.concatenate([frame["atype"] for frame in frames]), device="cpu" + ), + ) + torch.testing.assert_close( + batch["force"], + torch.as_tensor( + np.concatenate([frame["force"] for frame in frames]), device="cpu" + ), + ) + assert batch["fid"] == [3, 9] + + def test_mixed_batch_collate_precomputes_graph(self): + frames = [ + { + "coord": np.array([[0.0, 0.0, 0.0], [0.2, 0.0, 0.0]]), + "atype": np.array([0, 0], dtype=np.int64), + "force": np.zeros((2, 3)), + "energy": np.array([0.0]), + "box": np.eye(3).reshape(9), + "find_energy": 1.0, + "fid": 0, + }, + { + "coord": np.array([[0.0, 0.0, 0.0], [0.2, 0.0, 0.0], [0.0, 0.2, 0.0]]), + "atype": np.array([0, 0, 0], dtype=np.int64), + "force": np.zeros((3, 3)), + "energy": np.array([1.0]), + "box": np.eye(3).reshape(9), + "find_energy": 1.0, + 
"fid": 1, + }, + ] + collate = make_lmdb_mixed_batch_collate( + { + "rcut": 0.8, + "sel": [4], + "a_rcut": 0.8, + "a_sel": 4, + "mixed_types": True, + } + ) + batch = collate(frames) + assert batch["coord"].shape == (5, 3) + torch.testing.assert_close(batch["ptr"], torch.tensor([0, 2, 5], device="cpu")) + for key in ( + "extended_atype", + "extended_batch", + "extended_image", + "extended_ptr", + "mapping", + "central_ext_index", + "nlist", + "nlist_ext", + "a_nlist", + "a_nlist_ext", + "nlist_mask", + "a_nlist_mask", + "edge_index", + "angle_index", + ): + assert key in batch + assert batch["nlist"].shape[0] == 5 + assert batch["edge_index"].shape[0] == 2 + assert batch["angle_index"].shape[0] == 3 + assert torch.equal(batch["nlist_mask"], batch["nlist_ext"] >= 0) + assert torch.equal(batch["a_nlist_mask"], batch["a_nlist_ext"] >= 0) + + def test_mixed_batch_graph_keeps_frame_boundaries(self): + frames = [ + { + "coord": np.array([[0.0, 0.0, 0.0], [0.2, 0.0, 0.0]]), + "atype": np.array([0, 0], dtype=np.int64), + "force": np.zeros((2, 3)), + "energy": np.array([0.0]), + "box": np.eye(3).reshape(9) * 10.0, + "find_energy": 1.0, + "fid": 0, + }, + { + "coord": np.array([[0.1, 0.0, 0.0], [0.3, 0.0, 0.0]]), + "atype": np.array([0, 0], dtype=np.int64), + "force": np.zeros((2, 3)), + "energy": np.array([1.0]), + "box": np.eye(3).reshape(9) * 10.0, + "find_energy": 1.0, + "fid": 1, + }, + ] + batch = make_lmdb_mixed_batch_collate( + { + "rcut": 0.5, + "sel": [2], + "a_rcut": 0.5, + "a_sel": 2, + "mixed_types": True, + } + )(frames) + + for atom_idx, frame_idx in enumerate(batch["batch"].tolist()): + local_neighbors = batch["nlist"][atom_idx][batch["nlist_mask"][atom_idx]] + ext_neighbors = batch["nlist_ext"][atom_idx][batch["nlist_mask"][atom_idx]] + assert torch.all(batch["batch"][local_neighbors] == frame_idx) + assert torch.all(batch["extended_batch"][ext_neighbors] == frame_idx) + + def test_mix_batch_arg_alias(self): + arg = training_data_args() + normalized = 
arg.normalize_value( + {"systems": "train.lmdb", "batch_size": 2, "mix_batch": True}, + trim_pattern="_*", + ) + assert normalized["mixed_batch"] is True + # ============================================================ # Type map remapping (PT-specific: LmdbDataset) @@ -604,6 +777,15 @@ def test_dataset_auto_prob_passthrough(self, auto_prob_lmdb): ) assert ds._block_targets is not None + def test_dataset_auto_prob_default_mixed_batch(self, auto_prob_lmdb): + ds = LmdbDataset( + auto_prob_lmdb, + type_map=["O", "H"], + batch_size=4, + mixed_batch=True, + ) + assert ds._block_targets is None + def test_dataset_auto_prob_none(self, auto_prob_lmdb): ds = LmdbDataset(auto_prob_lmdb, type_map=["O", "H"], batch_size=4) assert ds._block_targets is None @@ -627,6 +809,16 @@ def test_dataset_auto_prob_iteration(self, auto_prob_lmdb): count = sum(len(batch) for batch in ds._batch_sampler) assert count > 300 # expanded + def test_dataset_auto_prob_mixed_batch_raises(self, auto_prob_lmdb): + with pytest.raises(NotImplementedError, match="mixed_batch=True"): + LmdbDataset( + auto_prob_lmdb, + type_map=["O", "H"], + batch_size=4, + mixed_batch=True, + auto_prob_style="prob_sys_size;0:1:0.5;1:3:0.5", + ) + def test_total_batch_matches_auto_prob_sampler(self, auto_prob_lmdb): ds = LmdbDataset( auto_prob_lmdb, diff --git a/source/tests/pt/test_loss.py b/source/tests/pt/test_loss.py index 2519111357..e78ff20389 100644 --- a/source/tests/pt/test_loss.py +++ b/source/tests/pt/test_loss.py @@ -16,6 +16,9 @@ EnergySpinLoss, EnergyStdLoss, ) +from deepmd.pt.utils import ( + env, +) from deepmd.pt.utils.dataset import ( DeepmdDataSetForLoader, ) @@ -370,6 +373,105 @@ def fake_model(): self.assertTrue(np.isnan(pt_more_loss_absent[f"l2_{key}_loss"])) +class TestEnerStdLossMixedBatch(unittest.TestCase): + def test_per_frame_energy_and_virial_normalization(self) -> None: + loss_obj = EnergyStdLoss( + starter_learning_rate=1.0, + start_pref_e=1.0, + limit_pref_e=1.0, + start_pref_v=1.0, + 
limit_pref_v=1.0, + ) + + energy_pred = torch.tensor( + [[10.0], [200.0]], dtype=torch.float64, device=env.DEVICE + ) + energy_label = torch.zeros_like(energy_pred) + virial_pred = torch.stack( + [ + torch.full((9,), 10.0, dtype=torch.float64, device=env.DEVICE), + torch.full((9,), 200.0, dtype=torch.float64, device=env.DEVICE), + ] + ) + virial_label = torch.zeros_like(virial_pred) + + def fake_model(**kwargs): + return { + "energy": energy_pred, + "virial": virial_pred, + } + + _, loss, _ = loss_obj( + { + "mixed_batch": { + "ptr": torch.tensor( + [0, 10, 110], dtype=torch.long, device=env.DEVICE + ) + }, + }, + fake_model, + { + "energy": energy_label, + "find_energy": 1.0, + "virial": virial_label, + "find_virial": 1.0, + }, + natoms=0, + learning_rate=1.0, + ) + + expected_per_term = ( + torch.tensor( + [10.0**2 / 10, 200.0**2 / 100], + dtype=torch.float64, + device=env.DEVICE, + ) + .mean() + .to(loss.dtype) + ) + torch.testing.assert_close(loss, expected_per_term * 2.0) + + def test_generalized_force_rejected(self) -> None: + loss_obj = EnergyStdLoss( + starter_learning_rate=1.0, + start_pref_f=1.0, + limit_pref_f=1.0, + start_pref_gf=1.0, + limit_pref_gf=1.0, + numb_generalized_coord=2, + ) + + def fake_model(**kwargs): + return { + "force": torch.zeros((2, 3), dtype=torch.float64, device=env.DEVICE), + } + + with self.assertRaisesRegex( + NotImplementedError, + "Generalized force loss is not supported with mixed_batch=True yet.", + ): + loss_obj( + { + "mixed_batch": { + "ptr": torch.tensor([0, 2], dtype=torch.long, device=env.DEVICE) + } + }, + fake_model, + { + "force": torch.zeros( + (2, 3), dtype=torch.float64, device=env.DEVICE + ), + "drdq": torch.zeros( + (1, 12), dtype=torch.float64, device=env.DEVICE + ), + "find_force": 1.0, + "find_drdq": 1.0, + }, + natoms=2, + learning_rate=1.0, + ) + + class TestEnerStdLossAePfGf(LossCommonTest): def setUp(self) -> None: self.start_lr = 1.1 diff --git a/source/tests/pt/test_training.py 
b/source/tests/pt/test_training.py index 7851850ba0..77e9afd8a3 100644 --- a/source/tests/pt/test_training.py +++ b/source/tests/pt/test_training.py @@ -25,6 +25,8 @@ patch, ) +import lmdb +import msgpack import numpy as np import torch @@ -1038,6 +1040,180 @@ def test_full_validation_rejects_multitask(self) -> None: normalize(config, multi_task=True) +def _encode_lmdb_array(arr: np.ndarray) -> dict[str, Any]: + arr = np.asarray(arr) + return { + "nd": None, + "type": str(arr.dtype), + "kind": "", + "shape": list(arr.shape), + "data": arr.tobytes(), + } + + +def _make_mixed_lmdb_frame(natoms: int, seed: int) -> dict[str, Any]: + rng = np.random.RandomState(seed) + n_type0 = max(1, natoms // 2) + n_type1 = natoms - n_type0 + atype = np.array([0] * n_type0 + [1] * n_type1, dtype=np.int64) + return { + "atom_numbs": [n_type0, n_type1], + "atom_names": ["O", "H"], + "atom_types": _encode_lmdb_array(atype), + "orig": _encode_lmdb_array(np.zeros(3, dtype=np.float64)), + "cells": _encode_lmdb_array(np.eye(3, dtype=np.float64) * 10.0), + "coords": _encode_lmdb_array((rng.rand(natoms, 3) * 2.0).astype(np.float64)), + "energies": _encode_lmdb_array(np.array(rng.randn(), dtype=np.float64)), + "forces": _encode_lmdb_array(rng.randn(natoms, 3).astype(np.float64)), + } + + +def _create_mixed_batch_lmdb(path: str, frame_specs: list[tuple[int, int]]) -> None: + env = lmdb.open(path, map_size=10 * 1024 * 1024) + frame_fmt = "012d" + frame_nlocs = [] + with env.begin(write=True) as txn: + frame_idx = 0 + for natoms, count in frame_specs: + for _ in range(count): + txn.put( + format(frame_idx, frame_fmt).encode(), + msgpack.packb( + _make_mixed_lmdb_frame(natoms, frame_idx), use_bin_type=True + ), + ) + frame_nlocs.append(natoms) + frame_idx += 1 + txn.put( + b"__metadata__", + msgpack.packb( + { + "nframes": len(frame_nlocs), + "frame_idx_fmt": frame_fmt, + "system_info": {"natoms": [2, 2], "formula": "mixed"}, + "frame_nlocs": frame_nlocs, + "type_map": ["O", "H"], + }, + 
use_bin_type=True, + ), + ) + env.close() + + +def _mixed_batch_dpa3_config(train_path: str, val_path: str) -> dict[str, Any]: + return { + "model": { + "type_map": ["O", "H"], + "descriptor": { + "type": "dpa3", + "repflow": { + "n_dim": 8, + "e_dim": 4, + "a_dim": 4, + "nlayers": 1, + "e_rcut": 1.5, + "e_rcut_smth": 0.2, + "e_sel": 4, + "a_rcut": 1.0, + "a_rcut_smth": 0.1, + "a_sel": 3, + "axis_neuron": 2, + "update_angle": True, + "update_style": "res_residual", + "update_residual_init": "const", + "a_compress_rate": 0, + "n_multi_edge_message": 1, + "smooth_edge_update": True, + }, + "precision": "float64", + "activation_function": "silu", + "concat_output_tebd": False, + }, + "fitting_net": { + "neuron": [8], + "precision": "float64", + "seed": 1, + }, + "data_stat_nbatch": 1, + }, + "learning_rate": { + "type": "exp", + "decay_steps": 10, + "start_lr": 1e-3, + "stop_lr": 1e-8, + }, + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 10, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + "optimizer": { + "type": "Adam", + "adam_beta1": 0.9, + "adam_beta2": 0.999, + "weight_decay": 0.0, + }, + "training": { + "training_data": { + "systems": train_path, + "batch_size": "max:5", + "mixed_batch": True, + }, + "validation_data": { + "systems": val_path, + "batch_size": "max:5", + "mixed_batch": True, + "numb_btch": 1, + }, + "numb_steps": 1, + "seed": 10, + "shuffle": False, + "disp_file": "lcurve.out", + "disp_freq": 1, + "save_freq": 1, + "save_ckpt": "model.ckpt", + "disp_training": False, + }, + } + + +class TestMixedBatchLmdbTraining(unittest.TestCase): + def setUp(self) -> None: + self._cwd = os.getcwd() + self._tmpdir = tempfile.TemporaryDirectory() + os.chdir(self._tmpdir.name) + + def tearDown(self) -> None: + os.chdir(self._cwd) + self._tmpdir.cleanup() + + @TRAINING_TEST_TIMEOUT + def test_mixed_batch_dpa3_lmdb_training_smoke(self) -> None: + tmpdir = Path(self._tmpdir.name) + train_path = 
str(tmpdir / "train.lmdb") + val_path = str(tmpdir / "val.lmdb") + _create_mixed_batch_lmdb(train_path, [(2, 2), (3, 2)]) + _create_mixed_batch_lmdb(val_path, [(2, 1), (3, 1)]) + + trainer = get_trainer(_mixed_batch_dpa3_config(train_path, val_path)) + input_dict, _, _ = trainer.get_data(is_train=True) + + self.assertIn("ptr", input_dict) + self.assertIn("batch", input_dict) + self.assertIn("nlist_ext", input_dict) + self.assertIn("angle_index", input_dict) + self.assertEqual(input_dict["ptr"].shape[0], 3) + + trainer.run() + + self.assertTrue(Path("model.ckpt-1.pt").exists()) + self.assertTrue(Path("model.ckpt.pt").exists()) + + class TestMultiTaskUtils(unittest.TestCase): def test_cascade_top_level_defaults(self) -> None: cfg = {"foo": 1, "model_dict": {"a": {}, "b": {"foo": 2}}} diff --git a/source/tests/universal/dpmodel/model/test_mixed_batch_flat_forward.py b/source/tests/universal/dpmodel/model/test_mixed_batch_flat_forward.py new file mode 100644 index 0000000000..26ef96f202 --- /dev/null +++ b/source/tests/universal/dpmodel/model/test_mixed_batch_flat_forward.py @@ -0,0 +1,202 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for the dpmodel mixed-nloc flat forward interface.""" + +from itertools import ( + pairwise, +) + +import numpy as np +import pytest + +from deepmd.dpmodel.descriptor import ( + DescrptDPA3, +) +from deepmd.dpmodel.descriptor.dpa3 import ( + RepFlowArgs, +) +from deepmd.dpmodel.fitting import ( + EnergyFittingNet, +) +from deepmd.dpmodel.model import ( + EnergyModel, +) + +from ....seed import ( + GLOBAL_SEED, +) + + +def _build_model(with_params: bool = False) -> EnergyModel: + repflow = RepFlowArgs( + n_dim=8, + e_dim=4, + a_dim=4, + nlayers=1, + e_rcut=1.5, + e_rcut_smth=0.2, + e_sel=4, + a_rcut=1.0, + a_rcut_smth=0.1, + a_sel=3, + axis_neuron=2, + update_angle=True, + update_style="res_residual", + update_residual_init="const", + a_compress_rate=0, + n_multi_edge_message=1, + smooth_edge_update=True, + ) + descriptor = 
DescrptDPA3( + 2, + repflow=repflow, + precision="float64", + seed=GLOBAL_SEED, + type_map=["O", "H"], + ) + fitting_kwargs = { + "ntypes": 2, + "dim_descrpt": descriptor.get_dim_out(), + "neuron": [8], + "mixed_types": descriptor.mixed_types(), + "type_map": ["O", "H"], + "precision": "float64", + "seed": GLOBAL_SEED, + } + if with_params: + fitting_kwargs.update( + { + "numb_fparam": 2, + "numb_aparam": 2, + "dim_case_embd": 2, + "default_fparam": [1.0, 1.0], + } + ) + fitting = EnergyFittingNet(**fitting_kwargs) + return EnergyModel(descriptor, fitting, type_map=["O", "H"]) + + +def _mixed_batch() -> tuple[np.ndarray, ...]: + coord = np.array( + [ + [0.0, 0.0, 0.0], + [0.5, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.4, 0.0, 0.0], + [0.0, 0.4, 0.0], + ], + dtype=np.float64, + ) + atype = np.array([0, 1, 0, 1, 1], dtype=np.int64) + batch = np.array([0, 0, 1, 1, 1], dtype=np.int64) + ptr = np.array([0, 2, 5], dtype=np.int64) + box = np.tile(np.eye(3).reshape(1, 9), (2, 1)).astype(np.float64) * 10.0 + return coord, atype, batch, ptr, box + + +def _regular_outputs( + model: EnergyModel, + coord: np.ndarray, + atype: np.ndarray, + ptr: np.ndarray, + box: np.ndarray, + fparam: np.ndarray | None = None, + aparam: np.ndarray | None = None, +) -> list[dict[str, np.ndarray]]: + regular = [] + for frame_idx, (start, end) in enumerate(pairwise(ptr)): + nloc = end - start + frame_aparam = ( + aparam[start:end].reshape(1, nloc, *aparam.shape[1:]) + if aparam is not None + else None + ) + regular.append( + model.call( + coord[start:end].reshape(1, nloc * 3), + atype[start:end].reshape(1, nloc), + box=box[frame_idx : frame_idx + 1], + fparam=fparam[frame_idx : frame_idx + 1] + if fparam is not None + else None, + aparam=frame_aparam, + ) + ) + return regular + + +def test_dpmodel_dpa3_flat_call_matches_regular_per_frame_outputs() -> None: + model = _build_model(with_params=True) + coord, atype, batch, ptr, box = _mixed_batch() + fparam = np.array([[0.1, 0.2], [0.3, 0.4]], 
dtype=np.float64) + aparam = np.arange(10, dtype=np.float64).reshape(5, 2) / 10.0 + + flat = model.call( + coord, + atype, + box=box, + fparam=fparam, + aparam=aparam, + mixed_batch={ + "batch": batch, + "ptr": ptr, + "extended_atype": atype, + "extended_batch": batch, + "extended_image": np.zeros_like(coord, dtype=np.int64), + "extended_ptr": ptr, + "mapping": np.arange(coord.shape[0], dtype=np.int64), + "central_ext_index": np.arange(coord.shape[0], dtype=np.int64), + "nlist": np.full((coord.shape[0], 0), -1, dtype=np.int64), + "nlist_ext": np.full((coord.shape[0], 0), -1, dtype=np.int64), + "a_nlist": np.full((coord.shape[0], 0), -1, dtype=np.int64), + "a_nlist_ext": np.full((coord.shape[0], 0), -1, dtype=np.int64), + "nlist_mask": np.zeros((coord.shape[0], 0), dtype=bool), + "a_nlist_mask": np.zeros((coord.shape[0], 0), dtype=bool), + }, + ) + regular = _regular_outputs(model, coord, atype, ptr, box, fparam, aparam) + + np.testing.assert_allclose( + flat["energy"], + np.concatenate([item["energy"] for item in regular], axis=0), + ) + np.testing.assert_allclose( + flat["atom_energy"], + np.concatenate([item["atom_energy"].reshape(-1, 1) for item in regular]), + ) + np.testing.assert_array_equal( + flat["mask"], + np.concatenate([item["mask"].reshape(-1) for item in regular]), + ) + + +def test_dpmodel_flat_call_requires_batch_and_ptr() -> None: + model = _build_model() + coord, atype, batch, _, box = _mixed_batch() + + with pytest.raises(ValueError, match="mixed_batch must contain both batch and ptr"): + model.call(coord, atype, box=box, mixed_batch={"batch": batch}) + + +def test_dpmodel_flat_call_validates_ptr() -> None: + model = _build_model() + coord, atype, batch, _, box = _mixed_batch() + + with pytest.raises(ValueError, match="end at the number of atoms"): + model.call( + coord, + atype, + box=box, + mixed_batch={ + "batch": batch, + "ptr": np.array([0, 2, 4], dtype=np.int64), + }, + ) + + +def test_dpmodel_flat_call_rejects_hessian() -> None: + model = 
_build_model() + model.enable_hessian() + coord, atype, batch, ptr, box = _mixed_batch() + + with pytest.raises(NotImplementedError, match="Hessian"): + model.call(coord, atype, box=box, mixed_batch={"batch": batch, "ptr": ptr}) diff --git a/source/tests/universal/pt/model/test_mixed_batch_flat_forward.py b/source/tests/universal/pt/model/test_mixed_batch_flat_forward.py new file mode 100644 index 0000000000..ef356a0022 --- /dev/null +++ b/source/tests/universal/pt/model/test_mixed_batch_flat_forward.py @@ -0,0 +1,173 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for the PyTorch mixed-nloc flat forward path.""" + +import pytest +import torch + +from deepmd.dpmodel.descriptor.dpa3 import ( + RepFlowArgs, +) +from deepmd.pt.model.descriptor import ( + DescrptDPA3, +) +from deepmd.pt.model.model import ( + EnergyModel, +) +from deepmd.pt.model.task import ( + EnergyFittingNet, +) +from deepmd.pt.utils.nlist import ( + build_precomputed_flat_graph, +) + +from ....seed import ( + GLOBAL_SEED, +) + + +def _build_model() -> EnergyModel: + repflow = RepFlowArgs( + n_dim=8, + e_dim=4, + a_dim=4, + nlayers=1, + e_rcut=1.5, + e_rcut_smth=0.2, + e_sel=4, + a_rcut=1.0, + a_rcut_smth=0.1, + a_sel=3, + axis_neuron=2, + update_angle=True, + update_style="res_residual", + update_residual_init="const", + a_compress_rate=0, + n_multi_edge_message=1, + smooth_edge_update=True, + ) + descriptor = DescrptDPA3( + 2, + repflow=repflow, + precision="float64", + seed=GLOBAL_SEED, + type_map=["O", "H"], + ) + fitting = EnergyFittingNet( + ntypes=2, + dim_descrpt=descriptor.get_dim_out(), + neuron=[8], + mixed_types=descriptor.mixed_types(), + type_map=["O", "H"], + precision="float64", + seed=GLOBAL_SEED, + ) + return EnergyModel(descriptor, fitting, type_map=["O", "H"]).to("cpu").eval() + + +def _mixed_batch() -> tuple[torch.Tensor, ...]: + coord = torch.tensor( + [ + [0.0, 0.0, 0.0], + [0.5, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.4, 0.0, 0.0], + [0.0, 0.4, 0.0], + ], + 
dtype=torch.float64, + device="cpu", + ) + atype = torch.tensor([0, 1, 0, 1, 1], dtype=torch.long, device="cpu") + batch = torch.tensor([0, 0, 1, 1, 1], dtype=torch.long, device="cpu") + ptr = torch.tensor([0, 2, 5], dtype=torch.long, device="cpu") + box = ( + torch.eye(3, dtype=torch.float64, device="cpu").reshape(1, 9).repeat(2, 1) + * 10.0 + ) + return coord, atype, batch, ptr, box + + +def _flat_graph( + coord: torch.Tensor, + atype: torch.Tensor, + batch: torch.Tensor, + ptr: torch.Tensor, + box: torch.Tensor, +) -> dict[str, torch.Tensor]: + return build_precomputed_flat_graph( + coord, + atype, + batch, + ptr, + rcut=1.5, + sel=[4], + a_rcut=1.0, + a_sel=3, + mixed_types=True, + box=box, + ) + + +def test_dpa3_flat_forward_matches_regular_per_frame_energy_and_force() -> None: + model = _build_model() + coord, atype, batch, ptr, box = _mixed_batch() + graph = _flat_graph(coord, atype, batch, ptr, box) + + flat = model( + coord, + atype, + box=box, + mixed_batch={"batch": batch, "ptr": ptr, **graph}, + ) + regular = [] + for frame_idx in range(ptr.numel() - 1): + start = int(ptr[frame_idx].item()) + end = int(ptr[frame_idx + 1].item()) + regular.append( + model( + coord[start:end].reshape(1, -1), + atype[start:end].reshape(1, -1), + box=box[frame_idx : frame_idx + 1], + ) + ) + + expected_energy = torch.cat([item["energy"] for item in regular], dim=0) + expected_atom_energy = torch.cat( + [item["atom_energy"].reshape(-1, 1) for item in regular], dim=0 + ) + expected_force = torch.cat( + [item["force"].reshape(-1, 3) for item in regular], dim=0 + ) + + torch.testing.assert_close(flat["energy"], expected_energy) + torch.testing.assert_close(flat["atom_energy"], expected_atom_energy) + torch.testing.assert_close(flat["force"], expected_force) + torch.testing.assert_close( + flat["energy"], + torch.stack([flat["atom_energy"][:2].sum(0), flat["atom_energy"][2:].sum(0)]), + ) + assert flat["virial"].shape == (2, 9) + assert torch.isfinite(flat["virial"]).all() + 
assert flat["mask"].tolist() == [1, 1, 1, 1, 1] + + +def test_dpa3_flat_forward_requires_precomputed_graph_fields() -> None: + model = _build_model() + coord, atype, batch, ptr, box = _mixed_batch() + + with pytest.raises(RuntimeError, match="precomputed graph fields"): + model(coord, atype, box=box, mixed_batch={"batch": batch, "ptr": ptr}) + + +def test_dpa3_flat_forward_rejects_atomic_virial() -> None: + model = _build_model() + coord, atype, batch, ptr, box = _mixed_batch() + graph = _flat_graph(coord, atype, batch, ptr, box) + + with pytest.raises(NotImplementedError, match="Atomic virial"): + model( + coord, + atype, + box=box, + do_atomic_virial=True, + mixed_batch={"batch": batch, "ptr": ptr, **graph}, + )