drkh-n/sat-unet-generalization

SAT-UNet: Improving Generalization and Cross-Dataset Inference

This is the official implementation of the SAT-UNET model, as described in our paper:
GhalamClouds: Remote Sensing Images for Kazakhstan and Clouds Segmentation Using SAT-UNet


Description

SAT-UNET is a deep learning model for cloud semantic segmentation, designed for remote sensing imagery. This repository provides all necessary code, configuration files, and utilities to train, validate, and infer with SAT-UNET, including multi-GPU support and experiment logging.

🚀 New in this Version

This repository has been updated to include methods for improving cross-dataset generalization, specifically targeting the domain shift between satellites (e.g., Landsat-8 vs. KazSTSat).

Key updates include:

  • Sat-SlideMix Augmentation: An implementation of the batch-level augmentation technique adapted from Hopkins et al. (2025).

  • Percentile-Based Normalization: A robust linear normalization strategy that rescales each band using its 1st-99th percentiles.
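The percentile-based normalization above can be sketched as follows. This is a minimal illustration, assuming per-band statistics and clipping of outliers to the 1st/99th percentiles; the exact implementation in this repository may compute statistics differently (e.g., per image or per dataset).

```python
import numpy as np

def percentile_normalize(band: np.ndarray, lo: float = 1.0, hi: float = 99.0) -> np.ndarray:
    """Linearly rescale a single band to [0, 1] using its 1st-99th percentiles.

    Values outside the percentile range are clipped, which makes the scaling
    robust to sensor outliers (hot pixels, saturated clouds).
    """
    p_lo, p_hi = np.percentile(band, [lo, hi])
    band = np.clip(band.astype(np.float32), p_lo, p_hi)
    # Small epsilon guards against a degenerate (constant) band.
    return (band - p_lo) / (p_hi - p_lo + 1e-8)
```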

Key Findings: In our recent experiments, percentile-based normalization significantly speeds up convergence and improves training stability, while Sat-SlideMix drives the major gains in performance metrics; the combination of the two yields the best cross-dataset inference.
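One plausible sketch of a batch-level, SlideMix-style augmentation is shown below: each image and its mask are circularly shifted by a random spatial offset, so scene content "slides" across the patch. This is an interpretation of the technique adapted from Hopkins et al. (2025), written with NumPy for brevity; the formulation in the paper and in this repository may differ.

```python
import numpy as np

def sat_slidemix(images: np.ndarray, masks: np.ndarray,
                 max_shift: float = 0.5, rng=None):
    """Circularly shift each (image, mask) pair by a random offset.

    images: (B, C, H, W) array; masks: (B, 1, H, W) array.
    max_shift bounds the offset as a fraction of height/width.
    """
    rng = rng or np.random.default_rng()
    b, _, h, w = images.shape
    out_img, out_msk = images.copy(), masks.copy()
    for i in range(b):
        dy = int(rng.integers(0, max(1, int(h * max_shift))))
        dx = int(rng.integers(0, max(1, int(w * max_shift))))
        # np.roll wraps content around, so no pixels are lost.
        out_img[i] = np.roll(images[i], shift=(dy, dx), axis=(-2, -1))
        out_msk[i] = np.roll(masks[i], shift=(dy, dx), axis=(-2, -1))
    return out_img, out_msk
```

Shifting the mask with the same offset keeps image/label alignment intact, which is what makes this usable for segmentation.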

New SOTA: We achieved a new F1-score record of 97.87% on the 38-Cloud benchmark (surpassing the previously reported 97.69%).


Project Structure

  • checkpoints/
    Saved model weights and checkpoints.

  • configs/
    YAML configuration files for experiments (e.g., satunet.yml).

  • data/
    Data loading, augmentation, and dataset definitions.

  • docs/
    Documentation and requirements.

  • models/
    Model architectures, attention modules, and custom loss functions.

  • utils/
    Utility scripts for logging, checkpointing, patch merging, and more.

  • train.py
    Main training script.

  • inference.py
    Inference script for running predictions on new data.

  • train_multi_gpu.py
    Multi-GPU training; invoked by setting the training.type parameter to multi_gpu in the YAML configuration file.

  • trainer.py
    Single-GPU training; invoked by setting the training.type parameter to torch in the YAML configuration file.


Dataset Structure

The dataset should be organized into train, valid, and test (optional) folders.
Each split contains subfolders named after each color band (prefixes train_, valid_, test_). Example structure:

	Dataset Root Folder
	├── train_red
	├── train_green
	├── train_blue
	├── train_nir
	├── train_gt
	├── valid_red
	├── valid_green
	├── valid_blue
	├── valid_nir
	└── valid_gt
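A small helper like the one below can collect the per-band patch files for a split and check that the band folders are in sync. It is a sketch built only on the folder layout above; the filename convention inside each folder (here assumed identical across bands) is an assumption, not something this repository specifies.

```python
from pathlib import Path

def band_paths(root, split="train",
               bands=("red", "green", "blue", "nir", "gt")):
    """Return {band: sorted patch filenames} for one split.

    Raises if the band folders do not contain the same set of patch
    names, which would silently misalign bands during loading.
    """
    root = Path(root)
    per_band = {b: sorted(p.name for p in (root / f"{split}_{b}").iterdir())
                for b in bands}
    reference = set(per_band[bands[0]])
    if not all(set(v) == reference for v in per_band.values()):
        raise ValueError("band folders are out of sync")
    return per_band
```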


Sat-UNet Architecture

Here is the visual diagram of our proposed Sat-UNet model.

SAT-UNET Architecture

Spatial Attention Block

Here is the visual diagram of our proposed NLP-inspired attention block for Sat-UNet model.

SAT-UNET Architecture


Sat-SlideMix and 1-99th Normalization Visualization Results

We conducted extensive experiments to validate the impact of our new augmentation and normalization pipeline.

Models are trained on 38-Cloud and tested on GhalamClouds80.

Note: No Norm means simple normalization, i.e., division by the 8-bit maximum value (255).

(Figure: qualitative comparison. Columns show the NIR input, ground truth (GT), and predictions for No Norm/No Aug, Norm/No Aug, No Norm/Sat-SlideMix, and Norm/Sat-SlideMix.)

Configuration

All experiment parameters are set in the YAML config file, e.g., configs/satunet.yml.
To replicate experiments or run your own:

  • Model parameters:

    • xp_config.model_name: Model to use (e.g., SatUNet)
    • model.in_channels, model.out_channels: Input/output channels
  • Training parameters:

    • training.type: Training mode (lightning, torch, or multi_gpu)
    • training.multi_gpu.world_size: Number of GPUs for multi-GPU training
    • trainer.max_epochs: Number of training epochs
    • trainer.precision: Precision (e.g., bf16-mixed)
  • Data parameters:

    • data.data_root: Path to dataset
    • data.train_colors, data.valid_colors: Folder names for each modality
    • data.train_batch_size, data.valid_batch_size: Batch sizes
  • Optimization:

    • model.optim_params: Optimizer settings (type, learning rate, weight decay, etc.)
  • Early Stopping

    • early_stop.patience: Number of epochs without validation-loss improvement before training stops
  • Logging:

    • logs.type: Logging backend (comet, tb_logs, etc.)
    • logs.comet.api_key: Set your Comet API key (can be set via environment variable COMET_API_KEY)
    • NOTE: comment out lines 35-36 in train.py to avoid logging your Comet API key.

Change these parameters as needed to match your experimental setup.
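Put together, a configuration skeleton using the keys listed above might look like the following. The values, and any keys not mentioned in the list, are illustrative assumptions only; configs/satunet.yml is the authoritative layout.

```yaml
xp_config:
  model_name: SatUNet
model:
  in_channels: 4
  out_channels: 1
  optim_params:          # optimizer type, learning rate, weight decay, ...
    type: adamw
    lr: 1.0e-4
    weight_decay: 1.0e-2
training:
  type: lightning        # lightning | torch | multi_gpu
  multi_gpu:
    world_size: 2
trainer:
  max_epochs: 100
  precision: bf16-mixed
data:
  data_root: /path/to/dataset
  train_colors: [train_red, train_green, train_blue, train_nir, train_gt]
  valid_colors: [valid_red, valid_green, valid_blue, valid_nir, valid_gt]
  train_batch_size: 8
  valid_batch_size: 8
early_stop:
  patience: 10
logs:
  type: comet            # comet | tb_logs | ...
```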


Setup

  • Python 3.12.7
  • CUDA 12.1
  • CUDNN 8.9.7

How to Run Training

  1. Install dependencies:
	pip install -r docs/requirements.txt
  2. Run training:
	export COMET_API_KEY=YOUR_API_KEY
	python3 train.py -p configs/satunet.yml

How to Run Inference

  1. No thresholding (probability map):
	python3 inference.py -d DataFolder -p configs/satunet.yml -c checkpoints/model.ckpt -s results
  2. Thresholding:
	python3 inference.py -d DataFolder -p configs/satunet.yml -c checkpoints/model.ckpt -t 0.5 -s results

Citation

If you use GhalamClouds datasets and/or SAT-UNET in your research, please cite our paper:

	@article{AIMYSHEV2025,
    title = {GhalamClouds: Remote Sensing Images for Kazakhstan and Clouds Segmentation Using SAT-UNet},
    journal = {Advances in Space Research},
    year = {2025},
    issn = {0273-1177},
    doi = {10.1016/j.asr.2025.09.098},
    url = {https://www.sciencedirect.com/science/article/pii/S0273117725011299},
    author = {Dias Aimyshev and Beket Tulegenov and Berik Smadyarov and Darkhan Nurzhakyp},
    keywords = {Attention, Deep Learning, Image Processing, Remote Sensing, UNet},
    abstract = {Cloud detection and segmentation represent a fundamental step in the analysis of optical remote sensing data. However, the performance of deep learning models heavily depends on the availability of task-specific datasets, which are often lacking. To address these challenges, we introduce two versions of the GhalamClouds dataset, derived from KazSTSat imagery, containing over 40,000 manually annotated cloud patches across six spectral channels and diverse landscapes of Kazakhstan. We also propose Sat-UNet, a U-Net backbone augmented with a spatial attention block designed to refine skip connections. Extensive experiments demonstrate that proposed model confidently competes with state-of-the-art models with significantly less parameters. Ablation studies further show that the proposed attention mechanism consistently outperforms the baseline UNet as well as SE and CBAM modules, with notable gains in F1 score, AUC, and Accuracy. We expect that GhalamClouds and Sat-UNet will provide valuable resources for advancing cloud segmentation research and improving model robustness in remote sensing applications.}
}

About

This repository contains a training setup that uses percentile-based normalization and Sat-SlideMix augmentation adapted from "Data Augmentation Approaches for Satellite Imagery" (Hopkins et al., 2025). It aims to improve generalization between the 38-Cloud and GhalamClouds80 datasets.
