Skip to content

deepecgkit

Top-level exports from the DeepECG-Kit package.

deepecgkit

DeepECG-Kit: Deep learning library for ECG analysis and arrhythmia classification.

ECGDataModule

Data module for ECG datasets.

Handles dataset creation, train/val/test splitting, and DataLoader construction.

Source code in deepecgkit/datasets/modules.py
class ECGDataModule:
    """Data module for ECG datasets.

    Handles dataset creation, train/val/test splitting, and DataLoader construction.
    """

    def __init__(
        self,
        dataset: Optional[Union[BaseECGDataset, Dataset]] = None,
        dataset_class: Optional[Type[BaseECGDataset]] = None,
        data_dir: Optional[str] = None,
        sampling_rate: int = 500,
        leads: Optional[list] = None,
        transform: Optional[callable] = None,
        target_transform: Optional[callable] = None,
        download: bool = False,
        batch_size: int = 32,
        val_split: float = 0.2,
        test_split: float = 0.1,
        num_workers: int = 4,
        seed: int = 42,
        stratify: bool = True,
        pin_memory: bool = False,
        persistent_workers: bool = True,
        prefetch_factor: int = 2,
        verbose: bool = True,
        dataset_kwargs: Optional[Dict] = None,
    ):
        """Initialize the ECG data module.

        Args:
            dataset: PyTorch dataset instance (optional)
            dataset_class: Class of dataset to create (optional)
            data_dir: Directory containing ECG data files (optional, uses dataset's default if None)
            sampling_rate: Sampling rate of the ECG signals (Hz)
            leads: List of leads to use (e.g., ['I', 'II', 'III'] for standard leads)
            transform: Optional transform to be applied to the ECG signals
            target_transform: Optional transform to be applied to the labels
            download: Whether to download the dataset if it doesn't exist
            batch_size: Batch size for data loaders
            val_split: Fraction of data to use for validation
            test_split: Fraction of data to use for testing
            num_workers: Number of worker processes for data loading
            seed: Random seed for reproducibility
            stratify: Whether to use stratified splitting based on labels
            pin_memory: Whether DataLoaders should pin host memory (useful for CUDA transfers)
            persistent_workers: Keep worker processes alive between epochs (only applies when num_workers > 0)
            prefetch_factor: Number of batches prefetched per worker (only applies when num_workers > 0)
            verbose: Whether to print split sizes and grouping info during setup()
            dataset_kwargs: Additional keyword arguments to pass to the dataset class
        """
        self.dataset = dataset
        self.dataset_class = dataset_class
        self.data_dir = data_dir
        self.sampling_rate = sampling_rate
        self.leads = leads
        self.transform = transform
        self.target_transform = target_transform
        self.download = download
        self.batch_size = batch_size
        self.val_split = val_split
        self.test_split = test_split
        self.num_workers = num_workers
        self.seed = seed
        self.stratify = stratify
        self.pin_memory = pin_memory
        self.persistent_workers = persistent_workers
        self.prefetch_factor = prefetch_factor
        self.verbose = verbose
        self.dataset_kwargs = dataset_kwargs or {}

        # Populated by setup(); dataloader accessors raise until then.
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def setup(self, stage: Optional[str] = None):
        """Set up the datasets.

        Args:
            stage: Current stage ('fit', 'validate', 'test', or 'predict')
        """
        if self.dataset is None:
            if self.dataset_class is None:
                raise ValueError("dataset_class must be provided if dataset is None")

            if self.data_dir is None:
                self.data_dir = self.dataset_class.get_default_data_dir()

            init_kwargs = {
                "data_dir": self.data_dir,
                "sampling_rate": self.sampling_rate,
                "transform": self.transform,
                "target_transform": self.target_transform,
                "download": self.download,
            }

            if self.leads is not None:
                init_kwargs["leads"] = self.leads

            # User-supplied kwargs take precedence over the defaults above.
            init_kwargs.update(self.dataset_kwargs)

            self.dataset = self.dataset_class(**init_kwargs)

        stratify_labels = self._build_stratify_labels()

        # Extract groups for patient-level splitting (prevents data leakage)
        groups = None
        record_names = getattr(self.dataset, "record_names", None)
        # len() works for both lists and numpy arrays; a bare truthiness test
        # would raise ValueError for a multi-element numpy array.
        if record_names is not None and len(record_names) > 0:
            groups = np.array(record_names)
            if self.verbose:
                n_groups = len(np.unique(groups))
                print(f"Using patient-level splitting ({n_groups} groups)")

        splitter = DataSplitter(
            dataset=self.dataset,
            val_split=self.val_split,
            test_split=self.test_split,
            seed=self.seed,
            stratify=stratify_labels,
            groups=groups,
        )
        self.train_dataset, self.val_dataset, self.test_dataset = splitter.split()

        if self.verbose:
            # First line previously said "Dataset size" while actually
            # reporting the training-split size.
            print(f"Training set size: {len(self.train_dataset)}")
            print(f"Validation set size: {len(self.val_dataset)}")
            print(f"Test set size: {len(self.test_dataset)}")

    def _build_stratify_labels(self):
        """Collect per-sample labels for stratified splitting.

        Returns:
            An array of stratification keys, or None when stratification is
            disabled, the dataset is too small, label inspection fails, or
            some class has fewer than two samples.
        """
        # Need at least 2 samples per class for 4 classes
        if not self.stratify or len(self.dataset) < 8:
            return None
        try:
            # NOTE: this materializes every label, which means one full pass
            # over the dataset (including any transforms).
            all_labels = []
            for i in range(len(self.dataset)):
                _, label = self.dataset[i]
                all_labels.append(label)
            stacked = torch.stack(all_labels).numpy()

            # For multi-label (2D), convert rows to string keys for stratification
            if stacked.ndim > 1:
                stratify_labels = np.array(
                    ["_".join(str(int(v)) for v in row) for row in stacked]
                )
            else:
                stratify_labels = stacked

            # Check if we have enough samples per class for stratification
            _, counts = np.unique(stratify_labels, return_counts=True)
            if np.min(counts) < 2:
                return None  # Disable stratification
            return stratify_labels
        except Exception:
            # Labels may not be stackable tensors; fall back to no stratification.
            return None

    def _make_loader(self, split_dataset, shuffle: bool) -> DataLoader:
        """Build a DataLoader for one split using the module's shared settings.

        Raises:
            RuntimeError: If setup() has not been called yet.
        """
        if split_dataset is None:
            raise RuntimeError("Call setup() before using the data module")
        return DataLoader(
            split_dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            # These options are only valid when worker processes exist.
            persistent_workers=self.persistent_workers and self.num_workers > 0,
            prefetch_factor=self.prefetch_factor if self.num_workers > 0 else None,
        )

    def train_dataloader(self) -> DataLoader:
        """Get the training data loader (shuffled)."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        """Get the validation data loader."""
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self) -> DataLoader:
        """Get the test data loader."""
        return self._make_loader(self.test_dataset, shuffle=False)

    def get_metadata(self) -> Dict:
        """Get metadata about the dataset.

        Returns:
            Dictionary containing dataset metadata
        """
        if self.dataset is None:
            raise RuntimeError("Call setup() before getting metadata")
        # Rich ECG datasets describe themselves; generic torch datasets get
        # a minimal summary assembled from the module's own settings.
        if isinstance(self.dataset, BaseECGDataset):
            return self.dataset.get_metadata()
        return {
            "dataset_size": len(self.dataset),
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
        }

    def print_metadata(self):
        """Print metadata about the dataset."""
        print(self.get_metadata())

setup

setup(stage: Optional[str] = None)

Set up the datasets.

Parameters:

Name Type Description Default
stage Optional[str]

Current stage ('fit', 'validate', 'test', or 'predict')

None
Source code in deepecgkit/datasets/modules.py
def setup(self, stage: Optional[str] = None):
    """Set up the datasets.

    Args:
        stage: Current stage ('fit', 'validate', 'test', or 'predict')
    """
    # NOTE: `stage` is accepted for Lightning-style API compatibility but is
    # not referenced anywhere in this body.
    # Lazily construct the dataset from dataset_class when no instance was given.
    if self.dataset is None:
        if self.dataset_class is None:
            raise ValueError("dataset_class must be provided if dataset is None")

        if self.data_dir is None:
            self.data_dir = self.dataset_class.get_default_data_dir()

        init_kwargs = {
            "data_dir": self.data_dir,
            "sampling_rate": self.sampling_rate,
            "transform": self.transform,
            "target_transform": self.target_transform,
            "download": self.download,
        }

        if self.leads is not None:
            init_kwargs["leads"] = self.leads

        # User-supplied kwargs override the defaults assembled above.
        init_kwargs.update(self.dataset_kwargs)

        self.dataset = self.dataset_class(**init_kwargs)

    stratify_labels = None
    if (
        self.stratify and len(self.dataset) >= 8
    ):  # Need at least 2 samples per class for 4 classes
        try:
            # Materializes every label: one full pass over the dataset,
            # including any configured transforms.
            all_labels = []
            for i in range(len(self.dataset)):
                _, label = self.dataset[i]
                all_labels.append(label)
            stacked = torch.stack(all_labels).numpy()

            # For multi-label (2D), convert rows to string keys for stratification
            if stacked.ndim > 1:
                stratify_labels = np.array(
                    ["_".join(str(int(v)) for v in row) for row in stacked]
                )
            else:
                stratify_labels = stacked

            # Check if we have enough samples per class for stratification
            _, counts = np.unique(stratify_labels, return_counts=True)
            if np.min(counts) < 2:
                stratify_labels = None  # Disable stratification
        except Exception:
            # Labels may not be stackable tensors; best-effort fallback.
            stratify_labels = None  # Fallback to no stratification

    # Extract groups for patient-level splitting (prevents data leakage)
    groups = None
    if hasattr(self.dataset, "record_names") and self.dataset.record_names:
        groups = np.array(self.dataset.record_names)
        if self.verbose:
            n_groups = len(np.unique(groups))
            print(f"Using patient-level splitting ({n_groups} groups)")

    splitter = DataSplitter(
        dataset=self.dataset,
        val_split=self.val_split,
        test_split=self.test_split,
        seed=self.seed,
        stratify=stratify_labels,
        groups=groups,
    )
    self.train_dataset, self.val_dataset, self.test_dataset = splitter.split()

    if self.verbose:
        # NOTE(review): this first line reports the *training* split size
        # despite the "Dataset size" label.
        print(f"Dataset size: {len(self.train_dataset)}")
        print(f"Validation set size: {len(self.val_dataset)}")
        print(f"Test set size: {len(self.test_dataset)}")

train_dataloader

train_dataloader() -> DataLoader

Get the training data loader.

Source code in deepecgkit/datasets/modules.py
def train_dataloader(self) -> DataLoader:
    """Return the DataLoader over the training split (shuffled each epoch)."""
    if self.train_dataset is None:
        raise RuntimeError("Call setup() before using the data module")
    use_workers = self.num_workers > 0
    # persistent_workers / prefetch_factor are only meaningful with workers.
    loader_kwargs = dict(
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=self.num_workers,
        pin_memory=self.pin_memory,
        persistent_workers=self.persistent_workers and use_workers,
        prefetch_factor=self.prefetch_factor if use_workers else None,
    )
    return DataLoader(self.train_dataset, **loader_kwargs)

val_dataloader

val_dataloader() -> DataLoader

Get the validation data loader.

Source code in deepecgkit/datasets/modules.py
def val_dataloader(self) -> DataLoader:
    """Return the DataLoader over the validation split (fixed order)."""
    if self.val_dataset is None:
        raise RuntimeError("Call setup() before using the data module")
    use_workers = self.num_workers > 0
    # Worker-only options must be disabled/None when num_workers == 0.
    return DataLoader(
        self.val_dataset,
        batch_size=self.batch_size,
        shuffle=False,
        num_workers=self.num_workers,
        pin_memory=self.pin_memory,
        persistent_workers=use_workers and self.persistent_workers,
        prefetch_factor=self.prefetch_factor if use_workers else None,
    )

test_dataloader

test_dataloader() -> DataLoader

Get the test data loader.

Source code in deepecgkit/datasets/modules.py
def test_dataloader(self) -> DataLoader:
    """Get the test data loader."""
    if self.test_dataset is None:
        raise RuntimeError("Call setup() before using the data module")
    return DataLoader(
        self.test_dataset,
        batch_size=self.batch_size,
        shuffle=False,
        num_workers=self.num_workers,
        pin_memory=self.pin_memory,
        persistent_workers=self.persistent_workers and self.num_workers > 0,
        prefetch_factor=self.prefetch_factor if self.num_workers > 0 else None,
    )

get_metadata

get_metadata() -> Dict

Get metadata about the dataset.

Returns:

Type Description
Dict

Dictionary containing dataset metadata

Source code in deepecgkit/datasets/modules.py
def get_metadata(self) -> Dict:
    """Get metadata about the dataset.

    Returns:
        Dictionary containing dataset metadata
    """
    dataset = self.dataset
    if dataset is None:
        raise RuntimeError("Call setup() before getting metadata")
    if isinstance(dataset, BaseECGDataset):
        # Rich ECG datasets know how to describe themselves.
        return dataset.get_metadata()
    # Generic torch datasets: report what the data module itself knows.
    return {
        "dataset_size": len(dataset),
        "batch_size": self.batch_size,
        "num_workers": self.num_workers,
    }

print_metadata

print_metadata()

Print metadata about the dataset.

Source code in deepecgkit/datasets/modules.py
def print_metadata(self):
    """Print metadata about the dataset."""
    metadata = self.get_metadata()
    print(metadata)

ECGTrainer

Trainer for ECG signal classification and regression models.

Wraps a plain nn.Module and provides fit/test methods with built-in early stopping, checkpointing, LR scheduling, and CSV metric logging.

Parameters:

Name Type Description Default
model

The ECG model to train (any nn.Module)

required
train_config

Dictionary containing training configuration:
- learning_rate: Learning rate for optimizer
- scheduler: Dict with 'factor' and 'patience' for ReduceLROnPlateau
- binary_classification: Bool, if True uses BCE loss for binary tasks
- multi_label: Bool, if True uses BCE loss for multi-label tasks
- task_type: 'classification' or 'regression'
- pos_weight: Optional list of positive class weights for BCE loss

required
device

Device to train on ('auto', 'cpu', 'cuda', 'mps')

'auto'
use_plateau_scheduler

If True, uses ReduceLROnPlateau, else StepLR

True
Example

model = KanResWideX(input_channels=1, output_size=4)
config = {
    "learning_rate": 0.001,
    "scheduler": {"factor": 0.5, "patience": 10},
    "binary_classification": False,
}
trainer = ECGTrainer(model=model, train_config=config)
trainer.fit(data_module, epochs=50)

Source code in deepecgkit/training/train.py
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
class ECGTrainer:
    """
    Trainer for ECG signal classification and regression models.

    Wraps a plain nn.Module and provides fit/test methods with built-in
    early stopping, checkpointing, LR scheduling, and CSV metric logging.

    Args:
        model: The ECG model to train (any nn.Module)
        train_config: Dictionary containing training configuration:
            - learning_rate: Learning rate for optimizer
            - scheduler: Dict with 'factor' and 'patience' for ReduceLROnPlateau
            - binary_classification: Bool, if True uses BCE loss for binary tasks
            - multi_label: Bool, if True uses BCE loss for multi-label tasks
            - task_type: 'classification' or 'regression'
            - pos_weight: Optional list of positive class weights for BCE loss
        device: Device to train on ('auto', 'cpu', 'cuda', 'mps')
        use_plateau_scheduler: If True, uses ReduceLROnPlateau, else StepLR

    Example:
        >>> model = KanResWideX(input_channels=1, output_size=4)
        >>> config = {
        ...     "learning_rate": 0.001,
        ...     "scheduler": {"factor": 0.5, "patience": 10},
        ...     "binary_classification": False,
        ... }
        >>> trainer = ECGTrainer(model=model, train_config=config)
        >>> trainer.fit(data_module, epochs=50)
    """

    def __init__(self, model, train_config, device="auto", use_plateau_scheduler=True):
        self.model = model
        self.train_config = train_config

        self.learning_rate = train_config["learning_rate"]
        self.scheduler_factor = train_config["scheduler"]["factor"]
        self.scheduler_patience = train_config["scheduler"]["patience"]
        self.use_plateau_scheduler = use_plateau_scheduler
        self.multi_label = train_config.get("multi_label", False)

        # Loss selection: BCE-with-logits for binary/multi-label tasks,
        # cross-entropy for multi-class classification, MSE for regression.
        if self.multi_label or train_config.get("binary_classification", False):
            pos_weight = train_config.get("pos_weight")
            if pos_weight is not None:
                pos_weight = torch.tensor(pos_weight, dtype=torch.float32)
            self.criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        elif train_config.get("task_type", "classification") == "classification":
            self.criterion = torch.nn.CrossEntropyLoss()
        else:
            self.criterion = torch.nn.MSELoss()

        # Resolve the compute device; "auto" prefers CUDA, then Apple MPS.
        if device == "auto":
            if torch.cuda.is_available():
                self.device = torch.device("cuda")
            elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                self.device = torch.device("mps")
            else:
                self.device = torch.device("cpu")
        else:
            self.device = torch.device(device)

        self.model.to(self.device)
        self.criterion.to(self.device)

        self.optimizer = None
        self.scheduler = None
        self.test_predictions = []
        self.test_targets = []
        self.test_probabilities = []
        self.log_dir = None
        self.best_checkpoint_path = None
        self.best_val_loss = float("inf")
        # Bug fix: previously only assigned inside fit(), so calling
        # _run_epoch() before fit() raised AttributeError.
        self._gradient_clip_val = None

    @property
    def _is_binary(self):
        # Binary = BCE loss without the multi-label flag.
        return isinstance(self.criterion, torch.nn.BCEWithLogitsLoss) and not self.multi_label

    @property
    def _is_classification(self):
        # Any task producing class probabilities (CE or BCE based).
        return isinstance(self.criterion, (torch.nn.CrossEntropyLoss, torch.nn.BCEWithLogitsLoss))

    def _calculate_loss(self, y_hat, y):
        """Compute the configured loss, coercing target dtype per task type."""
        if self.multi_label:
            return self.criterion(y_hat, y.float())
        if self._is_binary:
            # Binary head outputs shape (..., 1); squeeze to match targets.
            return self.criterion(y_hat.squeeze(-1), y.float())
        if isinstance(self.criterion, torch.nn.CrossEntropyLoss):
            return self.criterion(y_hat, y.long())
        return self.criterion(y_hat.float(), y.float())

    def _compute_acc(self, y_hat, y):
        """Accuracy from raw logits (threshold at 0 for BCE heads)."""
        with torch.no_grad():
            if self.multi_label:
                preds = (y_hat > 0).float()
                return (preds == y.float()).float().mean().item()
            if self._is_binary:
                preds = (y_hat.squeeze(-1) > 0).long()
                return (preds == y.long()).float().mean().item()
            return (torch.argmax(y_hat, dim=1) == y.long()).float().mean().item()

    def _setup_optimizer(self):
        """Create the Adam optimizer and the configured LR scheduler."""
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
        if self.use_plateau_scheduler:
            self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                mode="min",
                factor=self.scheduler_factor,
                patience=self.scheduler_patience,
            )
        else:
            self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=1, gamma=0.9)

    def _run_epoch(self, dataloader, train=True):
        """Run one full pass over `dataloader`.

        Returns:
            (avg_loss, avg_acc) — avg_acc is None for regression tasks.
        """
        from contextlib import nullcontext  # stdlib no-op context manager

        if train:
            self.model.train()
        else:
            self.model.eval()

        total_loss = 0.0
        total_acc = 0.0
        n_batches = 0

        # Disable autograd during evaluation; no-op context while training.
        ctx = nullcontext() if train else torch.no_grad()
        with ctx:
            for batch in dataloader:
                x, y = batch
                x, y = x.to(self.device), y.to(self.device)

                y_hat = self.model(x)
                loss = self._calculate_loss(y_hat, y)

                if train:
                    # Models may expose an extra L2 penalty term.
                    if hasattr(self.model, "l2_regularization_loss"):
                        loss = loss + self.model.l2_regularization_loss()
                    self.optimizer.zero_grad()
                    loss.backward()
                    if self._gradient_clip_val is not None:
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self._gradient_clip_val
                        )
                    self.optimizer.step()

                total_loss += loss.item()
                if self._is_classification:
                    total_acc += self._compute_acc(y_hat, y)
                n_batches += 1

        # max(n, 1) guards against an empty dataloader.
        avg_loss = total_loss / max(n_batches, 1)
        avg_acc = total_acc / max(n_batches, 1) if self._is_classification else None
        return avg_loss, avg_acc

    def fit(
        self,
        data_module,
        epochs=50,
        early_stopping_patience=10,
        checkpoint_dir=None,
        log_dir=None,
        progress_bar=True,
        gradient_clip_val=None,
        save_top_k=3,
    ):
        """Train the model.

        Args:
            data_module: ECGDataModule (or any object with train_dataloader/val_dataloader)
            epochs: Maximum number of training epochs
            early_stopping_patience: Stop after this many epochs without val_loss improvement
            checkpoint_dir: Directory to save checkpoints (None to disable)
            log_dir: Directory to save CSV metrics log (None to disable)
            progress_bar: Whether to show a tqdm progress bar
            gradient_clip_val: Max gradient norm for clipping (None to disable)
            save_top_k: Number of best checkpoints to keep
        """
        self._gradient_clip_val = gradient_clip_val
        self._setup_optimizer()

        if hasattr(data_module, "setup"):
            data_module.setup(stage="fit")

        train_loader = data_module.train_dataloader()
        val_loader = data_module.val_dataloader()

        if checkpoint_dir is not None:
            checkpoint_dir = Path(checkpoint_dir)
            checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # CSV metric logging is optional; the file is closed in `finally`.
        csv_writer = None
        csv_file = None
        if log_dir is not None:
            self.log_dir = str(log_dir)
            os.makedirs(log_dir, exist_ok=True)
            metrics_path = Path(log_dir) / "metrics.csv"
            csv_file = open(metrics_path, "w", newline="")
            fieldnames = ["epoch", "train_loss", "val_loss"]
            if self._is_classification:
                fieldnames.extend(["train_acc", "val_acc"])
            csv_writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
            csv_writer.writeheader()

        self.best_val_loss = float("inf")
        patience_counter = 0
        saved_checkpoints = []  # (val_loss, path) pairs, kept sorted ascending

        use_tqdm = progress_bar and tqdm is not None
        epoch_iter = tqdm(range(epochs), desc="Training") if use_tqdm else range(epochs)

        try:
            for epoch in epoch_iter:
                train_loss, train_acc = self._run_epoch(train_loader, train=True)
                val_loss, val_acc = self._run_epoch(val_loader, train=False)

                # Plateau scheduler steps on the monitored metric; StepLR per epoch.
                if self.use_plateau_scheduler:
                    self.scheduler.step(val_loss)
                else:
                    self.scheduler.step()

                if use_tqdm:
                    desc = f"Epoch {epoch + 1}/{epochs} | train_loss={train_loss:.4f} val_loss={val_loss:.4f}"
                    if self._is_classification:
                        desc += f" val_acc={val_acc:.4f}"
                    epoch_iter.set_description(desc)

                if csv_writer is not None:
                    row = {
                        "epoch": epoch + 1,
                        "train_loss": f"{train_loss:.6f}",
                        "val_loss": f"{val_loss:.6f}",
                    }
                    if self._is_classification:
                        row["train_acc"] = f"{train_acc:.6f}"
                        row["val_acc"] = f"{val_acc:.6f}"
                    csv_writer.writerow(row)
                    # Flush per epoch so the log survives interruption.
                    csv_file.flush()

                if val_loss < self.best_val_loss:
                    self.best_val_loss = val_loss
                    patience_counter = 0

                    if checkpoint_dir is not None:
                        ckpt_path = (
                            checkpoint_dir / f"epoch={epoch + 1:02d}-val_loss={val_loss:.4f}.pt"
                        )
                        self.save_checkpoint(str(ckpt_path), epoch=epoch + 1)
                        self.best_checkpoint_path = str(ckpt_path)
                        saved_checkpoints.append((val_loss, str(ckpt_path)))
                        # Keep only the save_top_k lowest-loss checkpoints:
                        # after sorting ascending, pop() removes the worst.
                        saved_checkpoints.sort(key=lambda x: x[0])
                        while len(saved_checkpoints) > save_top_k:
                            _, old_path = saved_checkpoints.pop()
                            if os.path.exists(old_path):
                                os.remove(old_path)
                else:
                    patience_counter += 1
                    if patience_counter >= early_stopping_patience:
                        if not use_tqdm:
                            print(f"Early stopping at epoch {epoch + 1}")
                        break
        finally:
            if csv_file is not None:
                csv_file.close()

    def test(self, data_module):
        """Evaluate the model on the test set.

        Args:
            data_module: ECGDataModule (or any object with test_dataloader)

        Returns:
            Dict with test_loss and test_acc (if classification)
        """
        if hasattr(data_module, "setup"):
            data_module.setup(stage="test")

        test_loader = data_module.test_dataloader()
        return self._evaluate_loader(test_loader)

    def validate(self, data_module):
        """Evaluate the model on the validation set.

        Args:
            data_module: ECGDataModule (or any object with val_dataloader)

        Returns:
            Dict with val_loss and val_acc (if classification)
        """
        if hasattr(data_module, "setup"):
            data_module.setup(stage="validate")

        val_loader = data_module.val_dataloader()
        return self._evaluate_loader(val_loader)

    def _evaluate_loader(self, dataloader):
        """Evaluate without gradients, accumulating predictions/targets/probs.

        Side effect: resets and repopulates test_predictions, test_targets,
        and test_probabilities for later retrieval via get_test_results().
        """
        self.model.eval()
        self.test_predictions = []
        self.test_targets = []
        self.test_probabilities = []

        total_loss = 0.0
        total_acc = 0.0
        n_batches = 0

        with torch.no_grad():
            for batch in dataloader:
                x, y = batch
                x, y = x.to(self.device), y.to(self.device)
                y_hat = self.model(x)
                loss = self._calculate_loss(y_hat, y)
                total_loss += loss.item()

                if self._is_classification:
                    if self.multi_label:
                        probs = torch.sigmoid(y_hat)
                        preds = (probs > 0.5).long()
                        acc = (preds == y.long()).float().mean().item()
                        self.test_predictions.append(preds.cpu())
                        self.test_targets.append(y.long().cpu())
                        self.test_probabilities.append(probs.cpu())
                    elif self._is_binary:
                        # Expand the single sigmoid output to two-class probs.
                        probs_pos = torch.sigmoid(y_hat.squeeze(-1))
                        probs = torch.stack([1 - probs_pos, probs_pos], dim=1)
                        preds = (probs_pos > 0.5).long()
                        acc = (preds == y.long()).float().mean().item()
                        self.test_predictions.append(preds.cpu())
                        self.test_targets.append(y.long().cpu())
                        self.test_probabilities.append(probs.cpu())
                    else:
                        probs = torch.softmax(y_hat, dim=1)
                        preds = torch.argmax(probs, dim=1)
                        acc = (preds == y.long()).float().mean().item()
                        self.test_predictions.append(preds.cpu())
                        self.test_targets.append(y.long().cpu())
                        self.test_probabilities.append(probs.cpu())
                    total_acc += acc

                n_batches += 1

        avg_loss = total_loss / max(n_batches, 1)
        results = {"test_loss": avg_loss}
        if self._is_classification:
            results["test_acc"] = total_acc / max(n_batches, 1)
        return results

    def get_test_results(self):
        """Get test predictions, targets, and probabilities as numpy arrays.

        Returns:
            Tuple of (predictions, targets, probabilities) as numpy arrays,
            or (None, None, None) if no test results available.
        """
        if not self.test_predictions:
            return None, None, None
        return (
            torch.cat(self.test_predictions).numpy(),
            torch.cat(self.test_targets).numpy(),
            torch.cat(self.test_probabilities).numpy(),
        )

    def save_checkpoint(self, path, epoch=None):
        """Save a checkpoint.

        Args:
            path: File path to save to
            epoch: Current epoch number (optional)
        """
        checkpoint = {
            "model_state_dict": self.model.state_dict(),
            "train_config": self.train_config,
            "epoch": epoch,
            "best_val_loss": self.best_val_loss,
        }
        if self.optimizer is not None:
            checkpoint["optimizer_state_dict"] = self.optimizer.state_dict()
        torch.save(checkpoint, path)

    @classmethod
    def load_checkpoint(cls, path, model=None, device="auto"):
        """Load a trainer from a checkpoint.

        Args:
            path: Path to checkpoint file
            model: Model instance to load weights into. Required.
            device: Device to load onto

        Returns:
            ECGTrainer instance with loaded weights
        """
        if model is None:
            raise ValueError("model argument is required for load_checkpoint")

        # SECURITY: weights_only=False unpickles arbitrary objects (needed for
        # the embedded train_config dict) — only load trusted checkpoint files.
        checkpoint = torch.load(path, map_location="cpu", weights_only=False)

        trainer = cls(
            model=model,
            train_config=checkpoint["train_config"],
            device=device,
        )
        trainer.model.load_state_dict(checkpoint["model_state_dict"])
        trainer.best_val_loss = checkpoint.get("best_val_loss", float("inf"))
        return trainer

    @staticmethod
    def seed_everything(seed):
        """Set random seeds for reproducibility.

        Args:
            seed: Random seed value
        """
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

fit

fit(
    data_module,
    epochs=50,
    early_stopping_patience=10,
    checkpoint_dir=None,
    log_dir=None,
    progress_bar=True,
    gradient_clip_val=None,
    save_top_k=3,
)

Train the model.

Parameters:

Name Type Description Default
data_module

ECGDataModule (or any object with train_dataloader/val_dataloader)

required
epochs

Maximum number of training epochs

50
early_stopping_patience

Stop after this many epochs without val_loss improvement

10
checkpoint_dir

Directory to save checkpoints (None to disable)

None
log_dir

Directory to save CSV metrics log (None to disable)

None
progress_bar

Whether to show a tqdm progress bar

True
gradient_clip_val

Max gradient norm for clipping (None to disable)

None
save_top_k

Number of best checkpoints to keep

3
Source code in deepecgkit/training/train.py
def fit(
    self,
    data_module,
    epochs=50,
    early_stopping_patience=10,
    checkpoint_dir=None,
    log_dir=None,
    progress_bar=True,
    gradient_clip_val=None,
    save_top_k=3,
):
    """Train the model.

    Runs up to ``epochs`` train/validation epochs with early stopping on
    validation loss, optional CSV metrics logging, and top-k checkpointing.
    A checkpoint is only written for epochs that improve on the best
    validation loss seen so far.

    Args:
        data_module: ECGDataModule (or any object with train_dataloader/val_dataloader)
        epochs: Maximum number of training epochs
        early_stopping_patience: Stop after this many epochs without val_loss improvement
        checkpoint_dir: Directory to save checkpoints (None to disable)
        log_dir: Directory to save CSV metrics log (None to disable)
        progress_bar: Whether to show a tqdm progress bar
        gradient_clip_val: Max gradient norm for clipping (None to disable)
        save_top_k: Number of best checkpoints to keep
    """
    # Stash the clip value so _run_epoch can apply it during backprop.
    self._gradient_clip_val = gradient_clip_val
    self._setup_optimizer()

    # Lightning-style data modules prepare their datasets lazily in setup().
    if hasattr(data_module, "setup"):
        data_module.setup(stage="fit")

    train_loader = data_module.train_dataloader()
    val_loader = data_module.val_dataloader()

    if checkpoint_dir is not None:
        checkpoint_dir = Path(checkpoint_dir)
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

    csv_writer = None
    csv_file = None
    if log_dir is not None:
        self.log_dir = str(log_dir)
        os.makedirs(log_dir, exist_ok=True)
        metrics_path = Path(log_dir) / "metrics.csv"
        csv_file = open(metrics_path, "w", newline="")
        fieldnames = ["epoch", "train_loss", "val_loss"]
        if self._is_classification:
            fieldnames.extend(["train_acc", "val_acc"])
        csv_writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
        csv_writer.writeheader()

    self.best_val_loss = float("inf")
    patience_counter = 0
    # (val_loss, path) pairs for the checkpoints currently kept on disk.
    saved_checkpoints = []

    use_tqdm = progress_bar and tqdm is not None
    epoch_iter = tqdm(range(epochs), desc="Training") if use_tqdm else range(epochs)

    try:
        for epoch in epoch_iter:
            train_loss, train_acc = self._run_epoch(train_loader, train=True)
            val_loss, val_acc = self._run_epoch(val_loader, train=False)

            # Plateau-style schedulers (presumably ReduceLROnPlateau)
            # need the monitored metric; others step once per epoch.
            if self.use_plateau_scheduler:
                self.scheduler.step(val_loss)
            else:
                self.scheduler.step()

            if use_tqdm:
                desc = f"Epoch {epoch + 1}/{epochs} | train_loss={train_loss:.4f} val_loss={val_loss:.4f}"
                if self._is_classification:
                    desc += f" val_acc={val_acc:.4f}"
                epoch_iter.set_description(desc)

            if csv_writer is not None:
                row = {
                    "epoch": epoch + 1,
                    "train_loss": f"{train_loss:.6f}",
                    "val_loss": f"{val_loss:.6f}",
                }
                if self._is_classification:
                    row["train_acc"] = f"{train_acc:.6f}"
                    row["val_acc"] = f"{val_acc:.6f}"
                csv_writer.writerow(row)
                # Flush each row so metrics survive a crash mid-training.
                csv_file.flush()

            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                patience_counter = 0

                if checkpoint_dir is not None:
                    ckpt_path = (
                        checkpoint_dir / f"epoch={epoch + 1:02d}-val_loss={val_loss:.4f}.pt"
                    )
                    self.save_checkpoint(str(ckpt_path), epoch=epoch + 1)
                    self.best_checkpoint_path = str(ckpt_path)
                    # Keep only the save_top_k best checkpoints: sort
                    # ascending by val_loss, then drop worst from the end.
                    saved_checkpoints.append((val_loss, str(ckpt_path)))
                    saved_checkpoints.sort(key=lambda x: x[0])
                    while len(saved_checkpoints) > save_top_k:
                        _, old_path = saved_checkpoints.pop()
                        if os.path.exists(old_path):
                            os.remove(old_path)
            else:
                patience_counter += 1
                if patience_counter >= early_stopping_patience:
                    # tqdm already shows progress; only print when silent.
                    if not use_tqdm:
                        print(f"Early stopping at epoch {epoch + 1}")
                    break
    finally:
        # Close the metrics file even if training raises or stops early.
        if csv_file is not None:
            csv_file.close()

test

test(data_module)

Evaluate the model on the test set.

Parameters:

Name Type Description Default
data_module

ECGDataModule (or any object with test_dataloader)

required

Returns:

Type Description

Dict with test_loss and test_acc (if classification)

Source code in deepecgkit/training/train.py
def test(self, data_module):
    """Evaluate the model on the test set.

    Args:
        data_module: ECGDataModule (or any object with test_dataloader)

    Returns:
        Dict with test_loss and test_acc (if classification)
    """
    if hasattr(data_module, "setup"):
        data_module.setup(stage="test")

    test_loader = data_module.test_dataloader()
    return self._evaluate_loader(test_loader)

validate

validate(data_module)

Evaluate the model on the validation set.

Parameters:

Name Type Description Default
data_module

ECGDataModule (or any object with val_dataloader)

required

Returns:

Type Description

Dict with val_loss and val_acc (if classification)

Source code in deepecgkit/training/train.py
def validate(self, data_module):
    """Evaluate the model on the validation set.

    Args:
        data_module: ECGDataModule (or any object with val_dataloader)

    Returns:
        Dict with val_loss and val_acc (if classification)
    """
    # Allow Lightning-style data modules to build their validation split.
    if hasattr(data_module, "setup"):
        data_module.setup(stage="validate")
    return self._evaluate_loader(data_module.val_dataloader())

get_test_results

get_test_results()

Get test predictions, targets, and probabilities as numpy arrays.

Returns:

Type Description

Tuple of (predictions, targets, probabilities) as numpy arrays,

or (None, None, None) if no test results available.

Source code in deepecgkit/training/train.py
def get_test_results(self):
    """Get test predictions, targets, and probabilities as numpy arrays.

    Returns:
        Tuple of (predictions, targets, probabilities) as numpy arrays,
        or (None, None, None) if no test results available.
    """
    if not self.test_predictions:
        return None, None, None
    return (
        torch.cat(self.test_predictions).numpy(),
        torch.cat(self.test_targets).numpy(),
        torch.cat(self.test_probabilities).numpy(),
    )

save_checkpoint

save_checkpoint(path, epoch=None)

Save a checkpoint.

Parameters:

Name Type Description Default
path

File path to save to

required
epoch

Current epoch number (optional)

None
Source code in deepecgkit/training/train.py
def save_checkpoint(self, path, epoch=None):
    """Save a checkpoint.

    Args:
        path: File path to save to
        epoch: Current epoch number (optional)
    """
    # Always persist weights, config, and training progress; optimizer
    # state is included only once an optimizer has been created.
    payload = {
        "model_state_dict": self.model.state_dict(),
        "train_config": self.train_config,
        "epoch": epoch,
        "best_val_loss": self.best_val_loss,
    }
    optimizer = self.optimizer
    if optimizer is not None:
        payload["optimizer_state_dict"] = optimizer.state_dict()
    torch.save(payload, path)

load_checkpoint classmethod

load_checkpoint(path, model=None, device='auto')

Load a trainer from a checkpoint.

Parameters:

Name Type Description Default
path

Path to checkpoint file

required
model

Model instance to load weights into. Required.

None
device

Device to load onto

'auto'

Returns:

Type Description

ECGTrainer instance with loaded weights

Source code in deepecgkit/training/train.py
@classmethod
def load_checkpoint(cls, path, model=None, device="auto"):
    """Load a trainer from a checkpoint.

    Args:
        path: Path to checkpoint file
        model: Model instance to load weights into. Required.
        device: Device to load onto

    Returns:
        ECGTrainer instance with loaded weights

    Raises:
        ValueError: If ``model`` is not provided.
    """
    # Validate arguments before doing any (potentially slow) disk I/O.
    if model is None:
        raise ValueError("model argument is required for load_checkpoint")

    # NOTE(review): weights_only=False deserializes arbitrary pickled
    # objects — only load checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location="cpu", weights_only=False)

    trainer = cls(
        model=model,
        train_config=checkpoint["train_config"],
        device=device,
    )
    trainer.model.load_state_dict(checkpoint["model_state_dict"])
    # Older checkpoints may lack this key; fall back to "no best yet".
    trainer.best_val_loss = checkpoint.get("best_val_loss", float("inf"))
    return trainer

seed_everything staticmethod

seed_everything(seed)

Set random seeds for reproducibility.

Parameters:

Name Type Description Default
seed

Random seed value

required
Source code in deepecgkit/training/train.py
@staticmethod
def seed_everything(seed):
    """Seed all random number generators used during training.

    Covers Python's ``random`` module, NumPy, and PyTorch (CPU and all
    CUDA devices when CUDA is available).

    Args:
        seed: Random seed value
    """
    # Order matters only for reproducibility of this call sequence itself;
    # each library's generator is seeded independently.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

KanResWideX

Bases: Module

KanRes-Wide-X model for ECG signal classification.

A convolutional neural network architecture designed for ECG signal analysis with residual connections and wide blocks for improved feature extraction.

Parameters:

Name Type Description Default
input_channels int

Number of input channels (default: 1 for single-lead ECG)

1
output_size int

Number of output classes or regression targets

4
base_channels int

Base number of channels for the first layer (default: 64)

64
Example

model = KanResWideX(input_channels=1, output_size=4)
x = torch.randn(32, 1, 3000)
output = model(x)
print(output.shape)  # [32, 4]

features = model.extract_features(x)
print(features.shape)  # (32, 64)

Source code in deepecgkit/models/kanres_x.py
@register_model(
    name="kanres",
    description="KAN-ResNet architecture with wide layers",
)
class KanResWideX(nn.Module):
    """KanRes-Wide-X network for ECG signal classification.

    Stacks a convolutional stem, three wide residual (KanRes) modules at a
    constant channel width, and a global-average-pooled linear head.

    Args:
        input_channels: Number of input channels (default: 1 for single-lead ECG)
        output_size: Number of output classes or regression targets
        base_channels: Base number of channels for the first layer (default: 64)

    Example:
        >>> model = KanResWideX(input_channels=1, output_size=4)
        >>> x = torch.randn(32, 1, 3000)
        >>> output = model(x)
        >>> print(output.shape)  # [32, 4]

        >>> features = model.extract_features(x)
        >>> print(features.shape)  # (32, 64)
    """

    def __init__(self, input_channels: int = 1, output_size: int = 4, base_channels: int = 64):
        super().__init__()

        # Stem maps the raw input leads to the working channel width.
        self.input_layer = ConvBlock(input_channels, base_channels)
        # Three residual modules, all at base_channels width.
        self.res_modules = nn.Sequential(
            *(KanResModule(base_channels) for _ in range(3))
        )
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self._feature_dim = base_channels
        self.classifier = nn.Linear(base_channels, output_size)

    @property
    def feature_dim(self) -> int:
        # Width of the pooled feature vector fed to the classifier head.
        return self._feature_dim

    def extract_features(self, x: torch.Tensor) -> torch.Tensor:
        """Return pooled per-sample features (batch, base_channels)."""
        pooled = self.global_pool(self.res_modules(self.input_layer(x)))
        return pooled.squeeze(-1)

    def forward(self, x):
        """Run the full network: features followed by the linear head."""
        return self.classifier(self.extract_features(x))

    @classmethod
    def from_pretrained(
        cls,
        weights: str,
        map_location: Optional[Union[str, torch.device]] = None,
        force_download: bool = False,
        **kwargs,
    ) -> "KanResWideX":
        """Load a pretrained KanResWideX model.

        Args:
            weights: Name of pretrained weights (e.g., "kanres-af-30s") or path to weights file
            map_location: Device to map weights to (e.g., "cpu", "cuda")
            force_download: If True, re-download weights even if cached
            **kwargs: Override default model parameters from the weight registry

        Returns:
            Model with pretrained weights loaded

        Example:
            >>> model = KanResWideX.from_pretrained("kanres-af-30s")
            >>> model = KanResWideX.from_pretrained("kanres-af-30s", map_location="cuda")
            >>> model = KanResWideX.from_pretrained("/path/to/weights.pt", output_size=2)
        """
        if Path(weights).exists():
            # Local file: the caller supplies any non-default model kwargs.
            model = cls(**kwargs)
            state_dict = torch.load(Path(weights), map_location=map_location, weights_only=True)
        else:
            # Registry name: registry defaults first, caller overrides win.
            info = get_weight_info(weights)
            model = cls(**{**info["model_kwargs"], **kwargs})
            state_dict = load_pretrained_weights(weights, map_location, force_download)

        model.load_state_dict(state_dict)
        return model

from_pretrained classmethod

from_pretrained(
    weights: str,
    map_location: Optional[Union[str, device]] = None,
    force_download: bool = False,
    **kwargs,
) -> KanResWideX

Load a pretrained KanResWideX model.

Parameters:

Name Type Description Default
weights str

Name of pretrained weights (e.g., "kanres-af-30s") or path to weights file

required
map_location Optional[Union[str, device]]

Device to map weights to (e.g., "cpu", "cuda")

None
force_download bool

If True, re-download weights even if cached

False
**kwargs

Override default model parameters from the weight registry

{}

Returns:

Type Description
KanResWideX

Model with pretrained weights loaded

Example

model = KanResWideX.from_pretrained("kanres-af-30s")
model = KanResWideX.from_pretrained("kanres-af-30s", map_location="cuda")
model = KanResWideX.from_pretrained("/path/to/weights.pt", output_size=2)

Source code in deepecgkit/models/kanres_x.py
@classmethod
def from_pretrained(
    cls,
    weights: str,
    map_location: Optional[Union[str, torch.device]] = None,
    force_download: bool = False,
    **kwargs,
) -> "KanResWideX":
    """Load a pretrained KanResWideX model.

    Args:
        weights: Name of pretrained weights (e.g., "kanres-af-30s") or path to weights file
        map_location: Device to map weights to (e.g., "cpu", "cuda")
        force_download: If True, re-download weights even if cached
        **kwargs: Override default model parameters from the weight registry

    Returns:
        Model with pretrained weights loaded

    Example:
        >>> model = KanResWideX.from_pretrained("kanres-af-30s")
        >>> model = KanResWideX.from_pretrained("kanres-af-30s", map_location="cuda")
        >>> model = KanResWideX.from_pretrained("/path/to/weights.pt", output_size=2)
    """
    if Path(weights).exists():
        # Local weights file: the caller supplies any non-default kwargs.
        model = cls(**kwargs)
        state_dict = torch.load(Path(weights), map_location=map_location, weights_only=True)
    else:
        # Registry name: registry defaults first, caller overrides win.
        info = get_weight_info(weights)
        model = cls(**{**info["model_kwargs"], **kwargs})
        state_dict = load_pretrained_weights(weights, map_location, force_download)

    model.load_state_dict(state_dict)
    return model

read_csv

read_csv(
    csv_file: str,
    delimiter: str = ",",
    transpose: bool = False,
    skip_header: bool = True,
    dtype: Optional[type] = None,
) -> Tuple[np.ndarray, Dict[str, int]]

Read CSV file and return data array and header mapping.

Parameters:

Name Type Description Default
csv_file str

Path to the CSV file

required
delimiter str

Column delimiter (default: ',')

','
transpose bool

Whether to transpose the data array

False
skip_header bool

Whether to skip the first row as header

True
dtype Optional[type]

Data type for the numpy array

None

Returns:

Type Description
Tuple[ndarray, Dict[str, int]]

Tuple of (data_array, header_mapping)

Source code in deepecgkit/utils/__init__.py
def read_csv(
    csv_file: str,
    delimiter: str = ",",
    transpose: bool = False,
    skip_header: bool = True,
    dtype: Optional[type] = None,
) -> Tuple[np.ndarray, Dict[str, int]]:
    """
    Read CSV file and return data array and header mapping.

    Args:
        csv_file: Path to the CSV file
        delimiter: Column delimiter (default: ',')
        transpose: Whether to transpose the data array
        skip_header: Whether to skip the first row as header
        dtype: Data type for the numpy array

    Returns:
        Tuple of (data_array, header_mapping) where header_mapping maps
        column name -> column index (empty when skip_header is False).

    Raises:
        FileNotFoundError: If csv_file does not exist.
        ValueError: If the file is empty, contains no data rows, or
            cannot be parsed/converted.
    """
    if not os.path.exists(csv_file):
        raise FileNotFoundError(f"CSV file not found: {csv_file}")

    data: List[List[str]] = []
    header: Dict[str, int] = {}

    try:
        # newline="" is required by the csv module so that quoted fields
        # containing line breaks are parsed correctly.
        with open(csv_file, newline="") as f:
            csv_data = csv.reader(f, delimiter=delimiter)

            if skip_header:
                try:
                    temp = next(csv_data)
                except StopIteration as err:
                    raise ValueError("File is empty") from err
                # Map column name -> column index.
                header = {k: v for v, k in enumerate(temp)}

            data.extend(csv_data)

        if not data:
            raise ValueError("No data found in CSV file")

        data_array = np.array(data, dtype=dtype)

        if transpose:
            data_array = np.transpose(data_array)

        return data_array, header

    except ValueError:
        # Deliberate validation errors (and dtype-conversion failures,
        # which already carry a meaningful message) propagate unchanged
        # instead of being double-wrapped.
        raise
    except Exception as err:
        # Unexpected parse/I-O failures are normalized to ValueError so
        # callers have a single exception type to handle.
        raise ValueError(f"Error reading CSV file: {err!s}") from err