# PyTorch Training Inspector

A production-grade debugging and monitoring framework for PyTorch training loops. It automatically tracks loss curves, gradient norms, learning rates, GPU memory, and throughput -- and fires real-time alerts the moment something goes wrong.
Built because deep learning training failures are expensive. Models can run for hours before gradient explosion, NaN loss, or a misconfigured scheduler becomes obvious. This tool catches those issues at the step they happen, with a diagnosis and actionable recommendations attached.
## Features

- Real-time anomaly detection -- NaN/Inf loss, gradient explosions, training plateaus, frozen learning rate, and predicted OOM crashes, all with root-cause diagnosis and fix suggestions
- Zero-friction integration -- wrap your existing loop in one context manager, no refactoring needed
- Comprehensive metrics -- loss, accuracy, gradient norms (global and per-layer), learning rate across all param groups, GPU memory (allocated vs reserved), steps/sec, and custom user metrics
- Interactive dashboard -- Plotly HTML report with six panels, saved as a standalone file
- Alert routing -- stdout by default, or plug in Slack webhooks and SMTP email
- Checkpoint analysis -- inspect saved weights, compare two runs layer by layer
- CPU profiling -- thin wrapper around `torch.profiler` for Chrome trace output (see the plain-PyTorch sketch below)
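For context, this is the underlying PyTorch primitive the wrapper builds on -- plain `torch.profiler` exporting a Chrome trace, not this project's own API:

```python
import torch
from torch.profiler import profile, ProfilerActivity

model = torch.nn.Linear(128, 10)
x = torch.randn(32, 128)

# Profile CPU activity for the code under measurement
with profile(activities=[ProfilerActivity.CPU]) as prof:
    model(x)

prof.export_chrome_trace("trace.json")  # open in chrome://tracing or Perfetto
```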
## Installation

```bash
git clone https://github.com/Olajide-Badejo/PyTorch-Training-Inspector.git
cd pytorch-training-inspector
pip install -r requirements.txt
pip install -e .
```

## Quick start

Wrap your existing loop in the `step_context` context manager:

```python
from inspector import TrainingInspector
# model, optimizer, criterion, dataloader, and num_epochs come from your usual setup
inspector = TrainingInspector(model, optimizer)

for epoch in range(num_epochs):
    inspector.current_epoch = epoch
    for x, y in dataloader:
        with inspector.step_context():
            loss = criterion(model(x), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            inspector.log_metrics({'loss': loss.item()})

inspector.save_metrics('metrics.csv')
inspector.generate_dashboard('dashboard.html')
inspector.remove_hooks()
```

That is the entire integration. Every monitor and detector runs automatically.
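`log_metrics` takes a plain dict, so custom user metrics ride along with the built-in ones. A small sketch continuing the loop above (`accuracy` is just an example name for a classification run):

```python
# Inside the step context: any extra scalars can be logged alongside the loss.
preds = model(x).argmax(dim=1)
accuracy = (preds == y).float().mean().item()
inspector.log_metrics({'loss': loss.item(), 'accuracy': accuracy})
```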
## Anomaly detection

| Anomaly | Detector | What triggers it |
|---|---|---|
| NaN / Inf loss | `NaNDetector` | Loss becomes `nan` or `inf` |
| Gradient explosion | `ExplosionDetector` | Norm exceeds a hard limit or spikes 5 sigma above the EMA baseline |
| Training plateau | `StallDetector` | No loss improvement for N steps |
| Frozen learning rate | `LRMismatchDetector` | LR unchanged for 200+ steps |
| LR spike | `LRMismatchDetector` | LR jumps >10x in one step |
| OOM warning | `OOMDetector` | GPU memory on track to hit the limit within 50 steps |
Each alert prints the type, step, message, and a bulleted list of recommended fixes.
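For intuition, the EMA-based explosion rule in the table amounts to a running z-score check on the gradient norm. A minimal sketch of that idea -- illustrative only, not `ExplosionDetector`'s actual internals (the class name, `beta`, and `hard_limit` values here are assumptions):

```python
class EmaExplosionCheck:
    """Illustrative 5-sigma EMA rule; not the library's real implementation."""

    def __init__(self, beta=0.98, hard_limit=100.0):
        self.beta = beta              # EMA smoothing factor (assumed value)
        self.hard_limit = hard_limit  # absolute norm ceiling (assumed value)
        self.ema = None
        self.ema_var = 0.0

    def __call__(self, grad_norm: float) -> bool:
        if self.ema is None:          # first observation seeds the baseline
            self.ema = grad_norm
            return grad_norm > self.hard_limit
        # Update running mean and variance of the gradient norm
        self.ema = self.beta * self.ema + (1 - self.beta) * grad_norm
        self.ema_var = self.beta * self.ema_var + (1 - self.beta) * (grad_norm - self.ema) ** 2
        sigma = self.ema_var ** 0.5
        # Fire on the hard limit or a 5-sigma spike above the baseline
        return grad_norm > self.hard_limit or grad_norm > self.ema + 5 * sigma
```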
## Configuration

Detector thresholds are tunable through `InspectorConfig`:

```python
from inspector.utils.config import InspectorConfig

cfg = InspectorConfig(
    grad_norm_threshold=50.0,  # flag explosions earlier
    stall_patience=100,
    nan_alert_threshold=1,
)
inspector = TrainingInspector(model, optimizer, config=cfg.to_dict())
```

## Alert routing

Alerts go to stdout by default; to route them elsewhere, pass a notifier:

```python
from inspector.alerts import SlackNotifier
notifier = SlackNotifier(webhook_url="https://hooks.slack.com/services/...")
inspector = TrainingInspector(model, optimizer, alert_callback=notifier)
```
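Any callable can serve as `alert_callback`; the bundled notifiers are just prebuilt ones. A hypothetical sketch, assuming the callback receives a single alert object exposing the type, step, and message fields that printed alerts contain:

```python
def log_alert_to_file(alert):
    # Hypothetical: assumes the alert object carries the same type/step/message
    # fields that the default stdout alerts print.
    with open('alerts.log', 'a') as f:
        f.write(f"[step {alert.step}] {alert.type}: {alert.message}\n")

inspector = TrainingInspector(model, optimizer, alert_callback=log_alert_to_file)
```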
## Dashboard

After training (or at any checkpoint):

```python
inspector.generate_dashboard('dashboard.html')
```

The report opens in any browser, no server required. Panels:
- Training loss (raw + EMA-smoothed)
- Gradient norm with explosion markers
- Learning rate schedule
- GPU memory (allocated vs reserved)
- Training throughput (steps/sec)
- Step time distribution
## Project structure

```
pytorch-training-inspector/
|-- inspector/
|   |-- core/            -- TrainingInspector, MetricsCollector, hook utilities
|   |-- monitors/        -- Loss, gradient, LR, GPU, throughput, activation
|   |-- detectors/       -- NaN, explosion, stall, LR mismatch, OOM
|   |-- visualization/   -- Plotly dashboard, live terminal monitor, HTML report
|   |-- alerts/          -- Slack and email notifiers
|   `-- utils/           -- Config, checkpoint analysis, profiler wrapper
|-- examples/            -- Runnable examples (basic, advanced, DDP, checkpoint)
|-- tests/               -- pytest test suite (monitors, detectors, integration)
|-- benchmarks/          -- Overhead benchmark vs bare training loop
`-- docs/                -- User guide, API reference, anomaly patterns, perf tips
```
## Examples

```bash
# Minimal working example on synthetic data
python examples/basic_usage.py

# Custom alert callback, activation monitoring, live terminal display
python examples/advanced_monitoring.py

# Save and resume with checkpoint analysis
python examples/resume_from_checkpoint.py

# Multi-GPU DDP (requires torchrun)
torchrun --nproc_per_node=2 examples/multi_gpu_training.py
```

## Tests

```bash
python -m pytest tests/ -v
```

Coverage target is >80% across monitors, detectors, and the integration pipeline.
## Benchmarks

```bash
python benchmarks/overhead_benchmark.py
```

Typical result on a CPU-only machine with a small MLP:

```
Label                      Mean (ms)   Std (ms)   P95 (ms)   Steps/s
---------------------------------------------------------------------
Baseline (no inspector)        2.841      0.312      3.201     352.0
With TrainingInspector         2.873      0.318      3.250     348.1

Inspector overhead: +1.13%
```

The target is <2% overhead. GPU training with larger models shows similar results because GPU work dominates the step time.
## Requirements

- Python 3.8+
- PyTorch 2.0+
- NumPy, Pandas, Plotly

Optional:

- `slack-sdk` for Slack alerts
- `dash` for a live web dashboard
- `pynvml` for GPU utilization percentage
- `pytest` for running tests
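To pull in all of the optional extras at once (the same PyPI package names listed above):

```bash
pip install slack-sdk dash pynvml pytest
```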
- Contribution guide: CONTRIBUTING.md
- Security policy: SECURITY.md
- Citation metadata: CITATION.cff
## License

MIT