
Basic Training

This example trains any registered model on any registered dataset with sensible defaults.

Source: examples/train_basic.py

Usage

python examples/train_basic.py
python examples/train_basic.py --model resnet --dataset af-classification --epochs 30
python examples/train_basic.py --model tcn --dataset af-classification --batch-size 64 --download
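The script's argument parsing is not shown on this page. The following hypothetical sketch shows one way the flags in the commands above could be wired up with the standard library's argparse. The defaults (kanres, af-classification, 50 epochs, batch size 32) are assumptions borrowed from the walkthrough values and were not read from examples/train_basic.py.

```python
import argparse


def parse_args(argv=None):
    """Hypothetical parser mirroring the flags used in the usage examples."""
    parser = argparse.ArgumentParser(
        description="Train a registered model on a registered dataset."
    )
    # Assumed defaults, taken from the values used in the walkthrough below.
    parser.add_argument("--model", default="kanres", help="Registered model name")
    parser.add_argument("--dataset", default="af-classification", help="Registered dataset name")
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--batch-size", type=int, default=32)
    # store_true: off unless the flag is passed on the command line.
    parser.add_argument("--download", action="store_true", help="Download the dataset if missing")
    return parser.parse_args(argv)


args = parse_args(
    ["--model", "tcn", "--dataset", "af-classification", "--batch-size", "64", "--download"]
)
print(args.model, args.batch_size, args.epochs, args.download)  # tcn 64 50 True
```

argparse exposes --batch-size as the attribute batch_size. Treating --download as an opt-in flag is also an assumption of this sketch.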

Walkthrough

1. Set up the data

from deepecgkit.datasets import ECGDataModule
from deepecgkit.registry import get_dataset, get_dataset_info, get_model
from deepecgkit.training import ECGTrainer

ECGTrainer.seed_everything(42)

dataset_info = get_dataset_info("af-classification")
input_channels = dataset_info["input_channels"]
num_classes = dataset_info["num_classes"]

dataset_class = get_dataset("af-classification")
data_module = ECGDataModule(
    dataset_class=dataset_class,
    batch_size=32,
    num_workers=4,
    val_split=0.2,
    test_split=0.1,
    seed=42,
    stratify=True,
    download=True,
)
data_module.setup(stage="fit")
data_module.print_metadata()

The registry supplies the dataset's input channel count and number of classes, so the model can be configured automatically instead of hard-coding them.
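To make the split arguments in step 1 (val_split, test_split, stratify, seed) concrete, here is a minimal standard-library sketch of a seeded, stratified split. It assumes that val_split and test_split are fractions of the whole dataset and that stratify=True means every class is split in the same proportions. stratified_split is a hypothetical helper, and ECGDataModule's actual implementation (rounding, ordering, underlying library) may differ.

```python
import random
from collections import defaultdict


def stratified_split(labels, val_split=0.2, test_split=0.1, seed=42):
    """Hypothetical helper: split indices into (train, val, test), class by class.

    Assumes val_split and test_split are fractions of the whole dataset.
    """
    rng = random.Random(seed)  # local RNG, so the same seed gives the same split
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    train, val, test = [], [], []
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        n_test = round(len(indices) * test_split)
        n_val = round(len(indices) * val_split)
        test.extend(indices[:n_test])
        val.extend(indices[n_test:n_test + n_val])
        train.extend(indices[n_test + n_val:])  # whatever remains is training data
    return train, val, test


# Imbalanced toy labels: 80 samples of class 0, 20 of class 1.
labels = [0] * 80 + [1] * 20
train, val, test = stratified_split(labels)
print(len(train), len(val), len(test))  # 70 20 10
print(sum(labels[i] for i in val))      # 4 of the 20 validation samples are class 1
```

Splitting within each class keeps the 80/20 class ratio in every subset. This matters when one rhythm class is much rarer than the others.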

2. Create the model

model_class = get_model("kanres")
model = model_class(input_channels=input_channels, output_size=num_classes)

Swap "kanres" for any registered model name; the rest of the code stays the same.
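Swapping models works because each model is looked up by name and shares the same constructor signature. The sketch below illustrates that pattern with a toy decorator-based registry. register_model, lookup_model, MODEL_REGISTRY, and the toy classes are hypothetical and are not deepecgkit's registry code. The sketch assumes only what the walkthrough shows: every model accepts input_channels and output_size.

```python
MODEL_REGISTRY = {}


def register_model(name):
    """Decorator that records a model class under a string name."""
    def decorator(cls):
        MODEL_REGISTRY[name] = cls
        return cls
    return decorator


def lookup_model(name):
    """Return the class registered under `name`, with a helpful error otherwise."""
    try:
        return MODEL_REGISTRY[name]
    except KeyError:
        available = ", ".join(sorted(MODEL_REGISTRY))
        raise KeyError(f"unknown model {name!r}; available: {available}") from None


@register_model("toy-mlp")
class ToyMLP:
    def __init__(self, input_channels, output_size):
        self.input_channels = input_channels
        self.output_size = output_size


@register_model("toy-cnn")
class ToyCNN(ToyMLP):
    pass


# The constructor call is identical whichever name is looked up.
for name in ("toy-mlp", "toy-cnn"):
    model = lookup_model(name)(input_channels=1, output_size=4)
    print(name, type(model).__name__, model.output_size)
```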

3. Configure training

train_config = {
    "learning_rate": 1e-3,
    "scheduler": {"factor": 0.5, "patience": 5},
    "binary_classification": num_classes == 2,
    "task_type": "classification",
}
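The scheduler keys resemble the factor and patience arguments of a reduce-on-plateau scheduler such as PyTorch's torch.optim.lr_scheduler.ReduceLROnPlateau. This page does not say how ECGTrainer uses them. Assuming plateau semantics, the toy class below shows the effect: the learning rate is multiplied by 0.5 once validation loss has gone more than 5 epochs without improving. PlateauScheduler is a hypothetical name, and it omits details such as PyTorch's improvement threshold and cooldown.

```python
class PlateauScheduler:
    """Toy reduce-on-plateau logic for a metric where lower is better."""

    def __init__(self, lr, factor=0.5, patience=5):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Update with one epoch's validation loss and return the learning rate."""
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
            # Reduce only once more than `patience` epochs pass without improvement.
            if self.bad_epochs > self.patience:
                self.lr *= self.factor
                self.bad_epochs = 0
        return self.lr


scheduler = PlateauScheduler(lr=1e-3, factor=0.5, patience=5)
val_losses = [1.0, 0.9, 0.8] + [0.85] * 6  # three improvements, then a plateau
lrs = [scheduler.step(loss) for loss in val_losses]
print(lrs[-1])  # 0.0005: halved on the sixth epoch without improvement
```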

4. Train and evaluate

trainer = ECGTrainer(model=model, train_config=train_config)
trainer.fit(
    data_module,
    epochs=50,
    early_stopping_patience=10,
    checkpoint_dir="runs/kanres-af/checkpoints",
)
trainer.test(data_module)

if trainer.best_checkpoint_path:
    print(f"Best checkpoint: {trainer.best_checkpoint_path}")
    print(f"Best val_loss: {trainer.best_val_loss:.4f}")

The trainer automatically keeps the top 3 checkpoints, ranked by validation loss, and stops training early if no improvement occurs for early_stopping_patience (here 10) consecutive epochs.
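Here is a minimal sketch of that bookkeeping, assuming the monitored metric is validation loss (lower is better). It keeps the three best epochs and signals a stop once the patience count is reached without a new best. EarlyStoppingWithTopK is a hypothetical name. ECGTrainer's own implementation is not shown on this page, and it also handles writing and deleting the checkpoint files.

```python
import heapq


class EarlyStoppingWithTopK:
    """Toy tracker: keep the top_k lowest-val_loss epochs, stop after `patience` stale epochs."""

    def __init__(self, patience=10, top_k=3):
        self.patience = patience
        self.top_k = top_k
        self.best_val_loss = float("inf")
        self.epochs_without_improvement = 0
        self._heap = []  # entries are (-val_loss, epoch), so the worst kept loss pops first

    def update(self, epoch, val_loss):
        """Record one epoch's validation loss; return True if training should stop."""
        heapq.heappush(self._heap, (-val_loss, epoch))
        if len(self._heap) > self.top_k:
            heapq.heappop(self._heap)  # drop the worst; a real trainer would delete that file
        if val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience

    def kept(self):
        """Return the kept (val_loss, epoch) pairs, best first."""
        return sorted((-neg_loss, epoch) for neg_loss, epoch in self._heap)


tracker = EarlyStoppingWithTopK(patience=10, top_k=3)
val_losses = [0.9, 0.7, 0.6, 0.65, 0.55, 0.56, 0.57, 0.58, 0.59, 0.6, 0.61, 0.62, 0.63, 0.64, 0.65]
stopped_at = None
for epoch, loss in enumerate(val_losses):
    if tracker.update(epoch, loss):
        stopped_at = epoch
        break
print(stopped_at, tracker.kept())  # 14 [(0.55, 4), (0.56, 5), (0.57, 6)]
```

The best epoch (4) is followed by ten epochs without a new best, so the loop stops at epoch 14. The three retained entries are the lowest losses seen.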