
pytorch_template

AI-agent-friendly PyTorch research pipeline: one YAML config drives preflight, training, Optuna HPO, and real-time TUI monitoring



README

PyTorch Template

English | 한글

License: MIT | Python 3.10+ | PyTorch 2.7+ | Optuna | W&B | TUI Monitor

One YAML. One command. Full research pipeline.

Config-driven experiment pipeline with dual logging, real-time TUI monitor, and AI agent skills.

Pipeline Overview


Why This Template?

| Problem | Solution |
| --- | --- |
| Config errors discovered after hours of GPU time | preflight runs 1-batch forward+backward in seconds |
| Rewriting training loops for every project | Callback-based loop: extend without modifying the core |
| "Which logging do I use?" | Choose wandb or tui per config; CSV is always saved |
| Can't see training progress without W&B | Rust TUI monitor renders loss curves in real-time from CSV |
| Manual hyperparameter tuning | Optuna + PFL pruner prunes unpromising trials early |
| HPO is a black box after it finishes | hpo-report shows parameter importance and boundary warnings |
| Silent overfitting or exploding gradients | Auto-detected by callbacks, logged to W&B/TUI/CSV |
| "It worked on my machine" | Full provenance: Python / PyTorch / CUDA / GPU / git hash |

Quick Start

git clone https://github.com/Axect/pytorch_template.git && cd pytorch_template

# Install (uv recommended)
uv venv && source .venv/bin/activate
uv pip install -U torch wandb rich beaupy numpy optuna matplotlib \
  scienceplots typer tqdm pyyaml pytorch-optimizer pytorch-scheduler

# Check your environment
python -m cli doctor

# Validate → Preview → Train → Analyze
python -m cli preflight configs/run_template.yaml
python -m cli preview configs/run_template.yaml
python -m cli train configs/run_template.yaml --device cuda:0
python -m cli analyze

How It Works

Everything is YAML

project: MyProject
device: cuda:0
logging: tui                                           # 'wandb' or 'tui'
net: model.MLP                                         # any importlib-resolvable path
optimizer: pytorch_optimizer.SPlus
scheduler: pytorch_scheduler.ExpHyperbolicLRScheduler
criterion: torch.nn.MSELoss
data: recipes.regression.data.load_data                # plug in any load_data() here
seeds: [89, 231, 928, 814, 269]                        # multi-seed reproducibility
epochs: 150
batch_size: 256
net_config:
  nodes: 64
  layers: 4
optimizer_config:
  lr: 1.e-1
scheduler_config:
  total_steps: 150
  upper_bound: 300
  min_lr: 1.e-6

All module paths are resolved via importlib. Three layers of validation run before a single GPU cycle is consumed:

  1. Structural: format checks, non-empty seeds, positive epochs/batch_size
  2. Runtime: CUDA availability, all import paths resolve
  3. Semantic: upper_bound >= total_steps, lr positivity, unique seeds
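
The three tiers can be pictured with a minimal, stdlib-only sketch. Everything here is illustrative: `resolve`, `validate`, and `IMPORT_KEYS` are hypothetical names, not the template's actual `config.py` API, and the CUDA check from the runtime tier is left out so the snippet runs anywhere.

```python
import importlib
from typing import Any

# Hypothetical subset of config keys that hold dotted import paths.
IMPORT_KEYS = ("net", "optimizer", "scheduler", "criterion", "data")


def resolve(path: str) -> Any:
    """Turn a dotted path such as 'math.sqrt' into the object it names."""
    module_name, _, attr = path.rpartition(".")
    if not module_name:
        raise ImportError(f"not a dotted path: {path!r}")
    return getattr(importlib.import_module(module_name), attr)


def validate(cfg: dict) -> list[str]:
    """Collect problems tier by tier; an empty list means the config passed."""
    errors: list[str] = []

    # Tier 1, structural: needs only the parsed YAML.
    if not cfg.get("seeds"):
        errors.append("seeds must be non-empty")
    for key in ("epochs", "batch_size"):
        value = cfg.get(key)
        if not isinstance(value, int) or value <= 0:
            errors.append(f"{key} must be a positive integer")

    # Tier 2, runtime: every import path that is present must resolve.
    for key in IMPORT_KEYS:
        if key in cfg:
            try:
                resolve(cfg[key])
            except (ImportError, AttributeError) as exc:
                errors.append(f"{key}: cannot resolve {cfg[key]!r} ({exc})")

    # Tier 3, semantic: relationships between fields.
    seeds = cfg.get("seeds") or []
    if len(set(seeds)) != len(seeds):
        errors.append("seeds must be unique")
    lr = cfg.get("optimizer_config", {}).get("lr")
    if lr is not None and lr <= 0:
        errors.append("lr must be positive")
    sched = cfg.get("scheduler_config", {})
    if "upper_bound" in sched and "total_steps" in sched:
        if sched["upper_bound"] < sched["total_steps"]:
            errors.append("upper_bound must be >= total_steps")

    return errors
```

Returning a list instead of raising on the first failure lets a caller report every problem in one pass, which is the behavior you want from a check that runs before any GPU time is spent.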

Dual Logging

The logging field controls where metrics are displayed:

| logging: | Display | W&B | CSV | latest_model.pt |
| --- | --- | --- | --- | --- |
| wandb | W&B dashboard + periodic console print | Yes | Always | Always |
| tui | Every-epoch terminal output (agent-friendly) | No | Always | Always |

CSV and latest model are always saved, regardless of logging mode. This means:

  • Agents can read metrics.csv and latest_model.pt mid-training
  • The TUI monitor can render loss curves from any run
  • No data is lost even without W&B

Callback Architecture

The training loop emits events; behaviors are independent, priority-ordered callbacks:

| Callback | Priority | Purpose |
| --- | --- | --- |
| NaNDetectionCallback | 5 | Detect NaN loss, signal stop |
| OptimizerModeCallback | 10 | SPlus / ScheduleFree train/eval toggle |
| GradientMonitorCallback | 12 | Track gradient norms, warn on explosion |
| LossPredictionCallback | 70 | Predict final loss via shifted exponential fit |
| OverfitDetectionCallback | 75 | Detect train/val divergence |
| WandbLoggingCallback | 80 | Log to W&B (when logging: wandb) |
| TUILoggingCallback | 80 | Terminal logging (when logging: tui) |
| CSVLoggingCallback | 81 | Write metrics.csv every epoch (always active) |
| PrunerCallback | 85 | Report to Optuna pruner |
| EarlyStoppingCallback | 90 | Patience-based stopping |
| CheckpointCallback | 95 | Periodic + best-model checkpoints |
| LatestModelCallback | 96 | Save latest_model.pt every epoch (always active) |

Add your own by subclassing TrainingCallback, with zero changes to the training loop:

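import torch

from callbacks import TrainingCallback  # assumed import path: callbacks.py in the repo root


class GradientClipCallback(TrainingCallback):
    priority = 15  # runs right after GradientMonitorCallback

    def on_train_step_end(self, trainer, **kwargs):
        # Rescale gradients in place so their total norm is at most 1.0
        torch.nn.utils.clip_grad_norm_(trainer.model.parameters(), 1.0)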
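
To make the priority ordering concrete, here is a small stand-alone sketch of an event dispatcher that runs lower priorities first. `Callback`, `Runner`, and the two demo callbacks are hypothetical stand-ins written for this example; they are not the template's TrainingCallback or CallbackRunner, whose internals may differ.

```python
import math


class Callback:
    """Stand-in base class: hooks are no-ops unless a subclass overrides them."""

    priority = 50

    def on_epoch_end(self, state, **kwargs):
        pass


class Runner:
    """Dispatch an event to every callback, lowest priority number first."""

    def __init__(self, callbacks):
        # sorted() is stable, so equal priorities keep their registration order.
        self.callbacks = sorted(callbacks, key=lambda cb: cb.priority)

    def fire(self, event: str, state, **kwargs):
        for cb in self.callbacks:
            getattr(cb, event)(state, **kwargs)


class NaNDetector(Callback):
    priority = 5  # runs early so later callbacks can see the stop flag

    def on_epoch_end(self, state, **kwargs):
        if math.isnan(state["loss"]):
            state["stop"] = True
        state["order"].append("nan_check")


class Logger(Callback):
    priority = 81

    def on_epoch_end(self, state, **kwargs):
        state["order"].append("log")


# Registration order is deliberately reversed; priority decides execution order.
state = {"loss": float("nan"), "stop": False, "order": []}
Runner([Logger(), NaNDetector()]).fire("on_epoch_end", state)
```

Because ordering lives in a class attribute, a new behavior slots in between existing ones by picking a number, without touching the loop or the other callbacks.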

Loss Prediction

The LossPredictionCallback fits a shifted exponential decay model to the validation loss history:

L(t) = a · exp(-b · t) + c

The predicted final loss is c (the asymptotic floor). The fitting uses EMA-smoothed data with a 3-point anchor method, which:

  • Works with positive and negative losses
  • Handles plateau, oscillation, and non-convergent patterns
  • Returns raw loss values (same units as input)
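
One way such a fit can work is sketched below. This is an assumption about what a "3-point anchor" could mean, not the template's implementation: for three equally spaced samples y0, y1, y2 of a · exp(-b · t) + c, the floor has the closed form c = (y0·y2 - y1²) / (y0 + y2 - 2·y1). The function names `ema` and `estimate_floor` and the default `alpha` are hypothetical.

```python
import math


def ema(values, alpha=0.3):
    """Exponential moving average; alpha=1.0 returns the input unchanged."""
    smoothed = [values[0]]
    for v in values[1:]:
        smoothed.append(alpha * v + (1 - alpha) * smoothed[-1])
    return smoothed


def estimate_floor(losses, alpha=0.3):
    """Estimate c in L(t) = a*exp(-b*t) + c from three equally spaced anchors.

    Returns None when there are too few points or the curve has no curvature
    (a plateau or a straight line), since no exponential can be identified.
    """
    if len(losses) < 3:
        return None
    s = ema(losses, alpha)
    last = len(s) - 1
    step = last // 2  # anchors: last - 2*step, last - step, last
    y0, y1, y2 = s[last - 2 * step], s[last - step], s[last]
    denom = y0 + y2 - 2 * y1
    if abs(denom) < 1e-12:
        return None
    return (y0 * y2 - y1 * y1) / denom


# Synthetic check: a clean curve with known floor c = 0.5.
clean = [2.0 * math.exp(-0.1 * t) + 0.5 for t in range(41)]
floor = estimate_floor(clean, alpha=1.0)
```

The closed form involves only products and sums of the anchors, so it is indifferent to the sign of the loss, which is consistent with the first bullet above.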

Real-time TUI Monitor

A Rust-based terminal UI reads metrics.csv and renders live charts:

TUI Monitor

Features:

  • Loss curves (train/val) with Braille sub-pixel rendering
  • Learning rate schedule and gradient norm history
  • Predicted final loss overlay
  • Log scale toggle: log₁₀ for positive data, symlog₁₀ for mixed-sign
  • Status bar with current metrics and time since last update
  • 1.2 MB static binary, no runtime dependencies
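
For readers unfamiliar with symlog, the sketch below shows one common definition of a symmetric log transform. The monitor's exact formula is not stated here, so treat `symlog10` and its `linthresh` parameter as illustrative assumptions.

```python
import math


def symlog10(x: float, linthresh: float = 1.0) -> float:
    """Symmetric log: sign(x) * log10(1 + |x| / linthresh).

    Behaves like log10 for large |x|, is roughly linear near zero,
    and is defined for negative values, unlike a plain log10.
    """
    return math.copysign(math.log10(1.0 + abs(x) / linthresh), x)
```

The transform is odd, so a loss of -9 maps to the mirror image of a loss of +9, and zero maps to zero.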

Usage:

# Build once (requires Rust toolchain)
cd tools/monitor && cargo build --release && cd ../..

# Run alongside training (in a separate terminal)
python -m cli monitor                                      # auto-detect latest run
python -m cli monitor runs/MyProject/group/42/metrics.csv  # specific file

# Or run the binary directly
./tools/monitor/target/release/training-monitor runs/MyProject/group/42/

| Key | Action |
| --- | --- |
| q / Esc | Quit |
| l | Toggle log scale |
| ← / → | Switch metric tabs (when custom metrics are logged) |

HPO Monitoring

The same binary supports real-time HPO monitoring by reading the Optuna SQLite database:

python -m cli monitor --hpo   # auto-detects .db file

Four tabs: Overview (trial scatter + best convergence), Parameters (per-parameter scatter grid), Best Trial (training curves of current best), Trials (interactive table: select a row and press Enter to view its curves).

| Key | Action |
| --- | --- |
| + / - | Y-axis zoom in/out |
| ↑ / ↓ | Y-axis pan (or row selection in Trials tab) |
| r | Reset Y axis to auto |
| x | Toggle X-axis log (Parameters tab) |

See Chapter 4: HPO - Monitoring in Real-time for details.


Pre-flight Check

Catches problems in seconds, not after hours of GPU time:

                         Pre-flight Check
┌─────────────────────────┬────────┬──────────────────────────────────┐
│ Check                   │ Status │ Detail                           │
├─────────────────────────┼────────┼──────────────────────────────────┤
│ Import paths & device   │  PASS  │                                  │
│ Semantic validation     │  PASS  │                                  │
│ Object instantiation    │  PASS  │                                  │
│ Data loading            │  PASS  │ train=8000, val=2000             │
│ Forward pass            │  PASS  │ output=(256, 1), loss=0.512341   │
│ Shape check             │  PASS  │                                  │
│ Gradient check          │  PASS  │ grad norm=0.034821               │
│ Optimizer step          │  PASS  │                                  │
│ Scheduler step          │  PASS  │                                  │
│ GPU memory              │  PASS  │ peak=42.3 MB (1 batch)           │
└─────────────────────────┴────────┴──────────────────────────────────┘
All pre-flight checks passed.

Use --json for machine-readable output (used by AI agent skills for automated parsing).


HPO with Optuna

# Run HPO
python -m cli train configs/my_run.yaml --optimize-config configs/my_opt.yaml

# Analyze results
python -m cli hpo-report --opt-config configs/my_opt.yaml

The custom PFL (Predicted Final Loss) pruner fits shifted exponential decay to early loss history and prunes trials before they waste GPU time.
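
The pruning decision itself can be imagined along these lines. This is a guess at the general shape of such a rule, not the logic in pruner.py: the function `should_prune`, the `top_k` parameter, and the comparison against finished trials are all assumptions made for illustration.

```python
def should_prune(predicted_final, finished_losses, top_k=5):
    """Decide whether a running trial is unlikely to reach the current top-k.

    predicted_final: estimated asymptotic loss of the running trial, or None
                     when no reliable fit is available yet.
    finished_losses: final losses of completed trials (lower is better).
    """
    if predicted_final is None or len(finished_losses) < top_k:
        return False  # not enough evidence yet: let the trial keep running
    threshold = sorted(finished_losses)[top_k - 1]
    return predicted_final > threshold
```

Refusing to prune until `top_k` trials have finished keeps the first few trials from being cut against an empty or unrepresentative baseline.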

Study: my_study (MyProject_Opt.db)
Trials: 50 total, 38 completed, 11 pruned, 1 failed

Best Trial #23
  Value: 0.003241

         Parameter Importance
┌─────────────────────────┬──────────────────────────────────────┐
│ Parameter               │ Importance                           │
├─────────────────────────┼──────────────────────────────────────┤
│ optimizer_config_lr     │ 0.8741 ██████████████████████████    │
│ net_config_layers       │ 0.1259 ████                          │
└─────────────────────────┴──────────────────────────────────────┘

Boundary Warnings:
  optimizer_config_lr=0.231 at UPPER boundary [1e-3, 1e+0]

A boundary warning means the best value sits near an edge of its search range, so the search would likely benefit from a wider range for that parameter.
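
A boundary check of this kind can be sketched as follows. The rule used by hpo-report is not documented here; the function `boundary_side`, the log-scale comparison, and the 25% `margin` are illustrative assumptions only.

```python
import math


def boundary_side(value, low, high, log_scale=True, margin=0.25):
    """Return 'LOWER' or 'UPPER' if value lies within `margin` of a range edge.

    With log_scale=True the relative position is measured in log10 space,
    which suits parameters such as learning rates sampled log-uniformly.
    """
    if log_scale:
        value, low, high = (math.log10(v) for v in (value, low, high))
    position = (value - low) / (high - low)  # 0.0 at low, 1.0 at high
    if position <= margin:
        return "LOWER"
    if position >= 1.0 - margin:
        return "UPPER"
    return None
```

Under these assumed settings, `boundary_side(0.231, 1e-3, 1.0)` returns `"UPPER"`, since 0.231 sits about 79% of the way across the range in log space.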


AI-Assisted Training (Agent Skills)

This template ships with built-in agent skills that guide the full experiment lifecycle across Claude Code, Codex, and Forge:

You: "Set up HPO for my FluxNet model, version 0.3"

Agent: Creates configs/SolarFlux_v0.3/fluxnet_run.yaml
       Creates configs/SolarFlux_v0.3/fluxnet_opt.yaml
       Runs preflight to catch any config issues
       Launches HPO with SPlus + ExpHyperbolicLR defaults
       Runs hpo-report to analyze results
       Extracts best params → fluxnet_best.yaml
       Launches final multi-seed training

The skill encodes domain knowledge: correct lr ranges for SPlus (1e-3 to 1e+0), why total_steps must not be synced to HPO epochs for hyperbolic schedulers, and how to interpret boundary warnings.

See skills/pytorch-train/ for details.

Migrating Existing Projects

If you have a project based on an older version of this template, the pytorch-migrate skill can detect your current version and apply incremental updates automatically.

Install skills globally (once per agent):

python -m cli update-skills --agent claude          # install to ~/.claude/skills
python -m cli update-skills --agent codex           # install to ~/.codex/skills
python -m cli update-skills --agent forge           # install to ~/forge/skills
python -m cli update-skills --agent claude --copy   # copy instead of symlink

Use in any project:

cd ~/my-project  # any pytorch_template-based project
# In your agent:
/pytorch-migrate

The skill detects which features are missing (v1 through v6) and applies only the needed migrations, preserving your custom models, data loaders, and callbacks.


Extend It

Custom Model

# my_model.py
from torch import nn


class MyTransformer(nn.Module):
    def __init__(self, hparams: dict, device: str = "cpu"):
        super().__init__()
        self.d_model = hparams["d_model"]

# run config (YAML)
net: my_model.MyTransformer
net_config:
  d_model: 256
  nhead: 8

Custom Loss Function

criterion: my_losses.FocalLoss
criterion_config:
  gamma: 2.0
  alpha: 0.25

Custom Data

Create a module with a load_data() function that returns (train_dataset, val_dataset):

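# recipes/myproject/data.py
def load_data():
    # Build or load train_dataset and val_dataset here (placeholders in this snippet)
    return train_dataset, val_dataset

# run config (YAML)
data: recipes.myproject.data.load_data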

See recipes/regression/ and recipes/classification/ for complete examples.
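
For a feel of the contract, here is a self-contained toy `load_data()` built on the map-style dataset protocol (`__len__` plus `__getitem__`), which is what a PyTorch DataLoader consumes. `SineDataset` and the keyword defaults are invented for this example; a real recipe would most likely return tensors, for instance through torch.utils.data.TensorDataset.

```python
import math
import random


class SineDataset:
    """Toy map-style dataset of (x, sin(x)) pairs."""

    def __init__(self, xs):
        self.samples = [(x, math.sin(x)) for x in xs]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


def load_data(n=1000, val_fraction=0.2, seed=0):
    """Return (train_dataset, val_dataset); callable with no arguments."""
    rng = random.Random(seed)  # local RNG keeps the split reproducible
    xs = [rng.uniform(-math.pi, math.pi) for _ in range(n)]
    n_val = round(n * val_fraction)
    return SineDataset(xs[: n - n_val]), SineDataset(xs[n - n_val :])
```

Seeding a local `random.Random` rather than the global generator means the data split stays the same regardless of what the training seeds do elsewhere.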

Custom Metrics

from metrics import MetricRegistry
registry = MetricRegistry(["mse", "mae", "r2", "my_module.MyMetric"])
results = registry.compute(y_pred, y_true)

Project Structure

pytorch_template/
├── cli.py              # CLI: train, preflight, validate, preview, doctor, analyze, hpo-report, monitor
├── config.py           # RunConfig (frozen, 3-tier validation) + OptimizeConfig
├── util.py             # Trainer, run(), predict_final_loss()
├── callbacks.py        # 12 built-in callbacks + CallbackRunner
├── checkpoint.py       # CheckpointManager + SeedManifest
├── provenance.py       # Environment capture + config hashing
├── pruner.py           # PFL pruner for Optuna
├── metrics.py          # Metric registry (MSE, MAE, R2)
├── model.py            # Built-in MLP
├── configs/            # YAML config templates
├── recipes/            # Example recipes (regression, classification)
├── tools/monitor/      # Rust TUI monitor (ratatui)
├── tests/              # Unit tests
├── docs/               # Human Skill Guide
└── skills/             # Agent skills (pytorch-train, pytorch-migrate)

Output per Seed Run

runs/{project}/{group}/{seed}/
├── model.pt            # Final model state_dict
├── latest_model.pt     # Updated every epoch
├── metrics.csv         # Real-time CSV log (all metrics)
├── env_snapshot.yaml   # Environment metadata
├── run_metadata.yaml   # Training metadata
├── best.pt             # Best checkpoint (if enabled)
└── latest.pt           # Full checkpoint (if enabled)

CLI Reference

| Command | Description |
| --- | --- |
| train <config> [--device DEV] [--optimize-config OPT] | Train or run HPO |
| preflight <config> [--device DEV] [--json] | 1-batch forward+backward check |
| validate <config> | Structural + runtime config validation |
| preview <config> | Show model architecture and param count |
| doctor | Check Python, PyTorch, CUDA, wandb, packages |
| hpo-report [--db DB] [--opt-config OPT] [--top-k K] [--json] | HPO analysis: param importance, boundary warnings |
| analyze [--project P] [--group G] [--seed S] | Evaluate a trained model |
| monitor [PATH] [--interval MS] [--list] | Launch real-time TUI monitor (or list available runs) |
| update-skills [--agent AGENT] [--copy] [--uninstall] | Install/update agent skills for Claude, Codex, or Forge |

All commands are invoked via python -m cli <command>.


Documentation

|  | AI Agent Skill | Human Guide |
| --- | --- | --- |
| Location | skills/pytorch-train/ | docs/ |
| Teaches | Config rules, param ranges, CLI commands | Design decisions, trade-offs, workflow intuition |

Read the Human Skill Guide: 5 chapters covering the full pipeline.

License

MIT

Acknowledgments

Release History

v0.3.0 (urgency: High, 4/8/2026)

New Features

  • Dual logging: logging: wandb (default) or logging: tui for agent-friendly terminal output
  • CSVLoggingCallback (always active): writes metrics.csv every epoch with dynamic column expansion
  • TUILoggingCallback: structured per-epoch terminal output replacing W&B
  • LatestModelCallback (always active): saves latest_model.pt every epoch
  • Rust TUI monitor (tools/monitor/): real-time loss curve visualization from metrics.csv
  • Provenance tracking

v0.2.0 (urgency: Medium, 3/27/2026)

What's New

Pre-flight Check

Run 1 batch forward+backward before training to catch config errors in seconds:

python -m cli preflight configs/run_template.yaml --device cuda:0

Detects shape mismatches, NaN/Inf gradients, scheduler param issues, and estimates GPU memory.

HPO Analysis

After HPO, understand what Optuna found:

python -m cli hpo-report --opt-config configs/my_opt.yaml

Shows parameter importance (fANOVA), boundary warnings, and top-K trial comparison

