Skip to content

deepecgkit.datasets

Dataset classes, data modules, and preprocessing utilities for ECG analysis.

Base Classes

BaseECGDataset

Bases: Dataset, ABC

Base class for all ECG datasets in deepecg-kit.

This class defines the common interface and functionality for all ECG datasets. Each specific dataset implementation should inherit from this class and implement the required methods.

Source code in deepecgkit/datasets/base.py
class BaseECGDataset(Dataset, ABC):
    """Base class for all ECG datasets in deepecg-kit.

    This class defines the common interface and functionality for all ECG datasets.
    Each specific dataset implementation should inherit from this class and implement
    the required methods (``download``, ``_load_data``, ``__getitem__``, ``__len__``,
    ``num_classes``, and ``class_names``).
    """

    @classmethod
    def get_default_data_dir(cls) -> Path:
        """Get the default data directory for this dataset.

        Returns:
            Path to the default data directory, derived from the subclass name
            (``~/.deepecgkit/datasets/<classname>``).
        """
        return Path.home() / ".deepecgkit" / "datasets" / cls.__name__.lower()

    def __init__(
        self,
        data_dir: Optional[Union[str, Path]] = None,
        sampling_rate: int = 500,
        leads: Optional[List[str]] = None,
        transform: Optional[callable] = None,
        target_transform: Optional[callable] = None,
        download: bool = False,
        force_download: bool = False,
    ):
        """Initialize the base ECG dataset.

        Args:
            data_dir: Directory where the dataset is stored or will be downloaded.
                     If None, uses the default data directory.
            sampling_rate: Sampling rate of the ECG signals (Hz)
            leads: List of leads to use (e.g., ['I', 'II', 'III'] for standard leads)
            transform: Optional transform to be applied to the ECG signals
            target_transform: Optional transform to be applied to the labels
            download: Whether to download the dataset if it doesn't exist locally.
            force_download: Whether to force re-download even if the dataset exists
        """
        self.data_dir = Path(data_dir) if data_dir is not None else self.get_default_data_dir()
        self.sampling_rate = sampling_rate
        self.leads = leads
        self.transform = transform
        self.target_transform = target_transform
        self.home_dir = Path.home()

        # force_download always clears and re-fetches; otherwise only fetch
        # when the data is missing locally AND the caller opted in.
        if force_download:
            self._clear_dataset()
            self.download()
        elif download and not self.data_dir.exists():
            self.download()

        self._load_data()

    def _clear_dataset(self):
        """Clear the dataset directory for re-download."""
        if self.data_dir.exists():
            print(f"Clearing existing dataset at {self.data_dir}...")
            shutil.rmtree(self.data_dir)
        self.data_dir.mkdir(parents=True, exist_ok=True)

    @abstractmethod
    def download(self):
        """Download the dataset if it doesn't exist.

        Subclasses implement the actual fetching logic.
        """
        pass

    @abstractmethod
    def _load_data(self):
        """Load the dataset into memory."""
        pass

    @abstractmethod
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get a sample from the dataset.

        Args:
            idx: Index of the sample to get

        Returns:
            Tuple of (ecg_signal, label) where:
                - ecg_signal: Tensor of shape (num_leads, signal_length)
                - label: Tensor of shape (num_classes,) for classification or (num_beats,) for segmentation
        """
        pass

    @abstractmethod
    def __len__(self) -> int:
        """Get the number of samples in the dataset."""
        pass

    @property
    @abstractmethod
    def num_classes(self) -> int:
        """Get the number of classes in the dataset."""
        pass

    @property
    @abstractmethod
    def class_names(self) -> List[str]:
        """Get the names of the classes in the dataset."""
        pass

    def get_class_distribution(self) -> Dict[str, int]:
        """Get the distribution of classes in the dataset.

        Returns:
            Mapping of class name to sample count. The base implementation
            returns an empty dict; subclasses may override to report counts.
        """
        return {}

    def _print_class_distribution(self):
        """Print class distribution statistics (no-op when the subclass reports none)."""
        distribution = self.get_class_distribution()
        if not distribution:
            return
        print("\nClass distribution:")
        for class_name, count in distribution.items():
            print(f"  {class_name}: {count}")

    def get_metadata(self) -> Dict:
        """Get metadata about the dataset.

        Returns:
            Dictionary containing metadata:
            - sampling_rate: Sampling rate of the signals
            - num_leads: Number of leads (None when leads were not specified)
            - lead_names: Names of the leads
            - num_classes: Number of classes
            - class_names: Names of the classes
            - dataset_size: Number of samples
        """
        return {
            "sampling_rate": self.sampling_rate,
            "num_leads": len(self.leads) if self.leads else None,
            "lead_names": self.leads,
            "num_classes": self.num_classes,
            "class_names": self.class_names,
            "dataset_size": len(self),
        }

num_classes abstractmethod property

num_classes: int

Get the number of classes in the dataset.

class_names abstractmethod property

class_names: List[str]

Get the names of the classes in the dataset.

get_default_data_dir classmethod

get_default_data_dir() -> Path

Get the default data directory for this dataset.

Returns:

Type Description
Path

Path to the default data directory

Source code in deepecgkit/datasets/base.py
@classmethod
def get_default_data_dir(cls) -> Path:
    """Get the default data directory for this dataset.

    Returns:
        Path to the default data directory
    """
    # Directory name is derived from the subclass name,
    # e.g. ~/.deepecgkit/datasets/myecgdataset
    return Path.home() / ".deepecgkit" / "datasets" / cls.__name__.lower()

download abstractmethod

download()

Download the dataset if it doesn't exist.

Source code in deepecgkit/datasets/base.py
@abstractmethod
def download(self):
    """Download the dataset if it doesn't exist.

    Abstract: subclasses implement the actual fetching logic.
    """
    pass

get_class_distribution

get_class_distribution() -> Dict[str, int]

Get the distribution of classes in the dataset.

Source code in deepecgkit/datasets/base.py
def get_class_distribution(self) -> Dict[str, int]:
    """Get the distribution of classes in the dataset.

    Returns:
        Mapping of class name to sample count. The base implementation
        returns an empty dict; subclasses may override to report counts.
    """
    return {}

get_metadata

get_metadata() -> Dict

Get metadata about the dataset.

Returns:

Type Description
Dict

Dictionary containing metadata such as:
  • sampling_rate: Sampling rate of the signals
  • num_leads: Number of leads
  • lead_names: Names of the leads
  • num_classes: Number of classes
  • class_names: Names of the classes
  • dataset_size: Number of samples
  • signal_length: Length of each signal
Source code in deepecgkit/datasets/base.py
def get_metadata(self) -> Dict:
    """Get metadata about the dataset.

    Returns:
        Dictionary containing metadata:
        - sampling_rate: Sampling rate of the signals
        - num_leads: Number of leads (None when leads were not specified)
        - lead_names: Names of the leads
        - num_classes: Number of classes
        - class_names: Names of the classes
        - dataset_size: Number of samples
    """
    return {
        "sampling_rate": self.sampling_rate,
        "num_leads": len(self.leads) if self.leads else None,
        "lead_names": self.leads,
        "num_classes": self.num_classes,
        "class_names": self.class_names,
        "dataset_size": len(self),
    }

Data Module

ECGDataModule

Data module for ECG datasets.

Handles dataset creation, train/val/test splitting, and DataLoader construction.

Source code in deepecgkit/datasets/modules.py
class ECGDataModule:
    """Data module for ECG datasets.

    Handles dataset creation, train/val/test splitting, and DataLoader construction.
    """

    def __init__(
        self,
        dataset: Optional[Union[BaseECGDataset, Dataset]] = None,
        dataset_class: Optional[Type[BaseECGDataset]] = None,
        data_dir: Optional[str] = None,
        sampling_rate: int = 500,
        leads: Optional[list] = None,
        transform: Optional[callable] = None,
        target_transform: Optional[callable] = None,
        download: bool = False,
        batch_size: int = 32,
        val_split: float = 0.2,
        test_split: float = 0.1,
        num_workers: int = 4,
        seed: int = 42,
        stratify: bool = True,
        pin_memory: bool = False,
        persistent_workers: bool = True,
        prefetch_factor: int = 2,
        verbose: bool = True,
        dataset_kwargs: Optional[Dict] = None,
    ):
        """Initialize the ECG data module.

        Args:
            dataset: PyTorch dataset instance (optional)
            dataset_class: Class of dataset to create (optional)
            data_dir: Directory containing ECG data files (optional, uses dataset's default if None)
            sampling_rate: Sampling rate of the ECG signals (Hz)
            leads: List of leads to use (e.g., ['I', 'II', 'III'] for standard leads)
            transform: Optional transform to be applied to the ECG signals
            target_transform: Optional transform to be applied to the labels
            download: Whether to download the dataset if it doesn't exist
            batch_size: Batch size for data loaders
            val_split: Fraction of data to use for validation
            test_split: Fraction of data to use for testing
            num_workers: Number of worker processes for data loading
            seed: Random seed for reproducibility
            stratify: Whether to use stratified splitting based on labels
            pin_memory: Whether DataLoaders should pin host memory
            persistent_workers: Keep worker processes alive between epochs
                (only effective when num_workers > 0)
            prefetch_factor: Batches prefetched per worker (only effective when
                num_workers > 0)
            verbose: Whether to print split sizes and grouping info in setup()
            dataset_kwargs: Additional keyword arguments to pass to the dataset class
        """
        self.dataset = dataset
        self.dataset_class = dataset_class
        self.data_dir = data_dir
        self.sampling_rate = sampling_rate
        self.leads = leads
        self.transform = transform
        self.target_transform = target_transform
        self.download = download
        self.batch_size = batch_size
        self.val_split = val_split
        self.test_split = test_split
        self.num_workers = num_workers
        self.seed = seed
        self.stratify = stratify
        self.pin_memory = pin_memory
        self.persistent_workers = persistent_workers
        self.prefetch_factor = prefetch_factor
        self.verbose = verbose
        self.dataset_kwargs = dataset_kwargs or {}

        # Populated by setup().
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def setup(self, stage: Optional[str] = None):
        """Set up the datasets.

        Builds the dataset if only a class was given, computes optional
        stratification labels and patient-level groups, then splits into
        train/val/test subsets.

        Args:
            stage: Current stage ('fit', 'validate', 'test', or 'predict')
        """
        if self.dataset is None:
            self._build_dataset()

        stratify_labels = self._compute_stratify_labels()
        groups = self._extract_groups()

        splitter = DataSplitter(
            dataset=self.dataset,
            val_split=self.val_split,
            test_split=self.test_split,
            seed=self.seed,
            stratify=stratify_labels,
            groups=groups,
        )
        self.train_dataset, self.val_dataset, self.test_dataset = splitter.split()

        if self.verbose:
            # This line previously printed the training-set size labeled
            # "Dataset size", which was misleading.
            print(f"Training set size: {len(self.train_dataset)}")
            print(f"Validation set size: {len(self.val_dataset)}")
            print(f"Test set size: {len(self.test_dataset)}")

    def _build_dataset(self):
        """Instantiate the dataset from dataset_class when no instance was given."""
        if self.dataset_class is None:
            raise ValueError("dataset_class must be provided if dataset is None")

        if self.data_dir is None:
            self.data_dir = self.dataset_class.get_default_data_dir()

        init_kwargs = {
            "data_dir": self.data_dir,
            "sampling_rate": self.sampling_rate,
            "transform": self.transform,
            "target_transform": self.target_transform,
            "download": self.download,
        }

        if self.leads is not None:
            init_kwargs["leads"] = self.leads

        # User-supplied kwargs take precedence over the defaults above.
        init_kwargs.update(self.dataset_kwargs)

        self.dataset = self.dataset_class(**init_kwargs)

    def _compute_stratify_labels(self) -> Optional[np.ndarray]:
        """Collect per-sample labels for stratified splitting.

        Returns:
            An array of labels (multi-label rows are encoded as strings), or
            None when stratification is disabled, the dataset is too small,
            any class has fewer than 2 samples, or label collection fails.
        """
        if not (
            self.stratify and len(self.dataset) >= 8
        ):  # Need at least 2 samples per class for 4 classes
            return None
        try:
            # NOTE: this touches every sample once, which may be slow for
            # large on-disk datasets.
            all_labels = []
            for i in range(len(self.dataset)):
                _, label = self.dataset[i]
                all_labels.append(label)
            stacked = torch.stack(all_labels).numpy()

            # For multi-label (2D), convert rows to string keys for stratification
            if stacked.ndim > 1:
                labels = np.array(["_".join(str(int(v)) for v in row) for row in stacked])
            else:
                labels = stacked

            # Check if we have enough samples per class for stratification
            _, counts = np.unique(labels, return_counts=True)
            if np.min(counts) < 2:
                return None  # Disable stratification
            return labels
        except Exception:
            return None  # Fallback to no stratification

    def _extract_groups(self) -> Optional[np.ndarray]:
        """Extract group ids for patient-level splitting (prevents data leakage).

        Returns:
            Array of record names when the dataset exposes a non-empty
            ``record_names`` attribute, otherwise None.
        """
        if not (hasattr(self.dataset, "record_names") and self.dataset.record_names):
            return None
        groups = np.array(self.dataset.record_names)
        if self.verbose:
            n_groups = len(np.unique(groups))
            print(f"Using patient-level splitting ({n_groups} groups)")
        return groups

    def _make_loader(self, split_dataset, shuffle: bool) -> DataLoader:
        """Build a DataLoader with the module's shared settings.

        Raises:
            RuntimeError: If setup() has not been called yet.
        """
        if split_dataset is None:
            raise RuntimeError("Call setup() before using the data module")
        return DataLoader(
            split_dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            # Worker-only options must be disabled when num_workers == 0.
            persistent_workers=self.persistent_workers and self.num_workers > 0,
            prefetch_factor=self.prefetch_factor if self.num_workers > 0 else None,
        )

    def train_dataloader(self) -> DataLoader:
        """Get the training data loader."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        """Get the validation data loader."""
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self) -> DataLoader:
        """Get the test data loader."""
        return self._make_loader(self.test_dataset, shuffle=False)

    def get_metadata(self) -> Dict:
        """Get metadata about the dataset.

        Returns:
            Dictionary containing dataset metadata. Delegates to the dataset's
            own get_metadata() for BaseECGDataset instances; otherwise returns
            basic loader configuration.

        Raises:
            RuntimeError: If setup() has not been called yet.
        """
        if self.dataset is None:
            raise RuntimeError("Call setup() before getting metadata")
        if isinstance(self.dataset, BaseECGDataset):
            return self.dataset.get_metadata()
        return {
            "dataset_size": len(self.dataset),
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
        }

    def print_metadata(self):
        """Print metadata about the dataset."""
        print(self.get_metadata())

setup

setup(stage: Optional[str] = None)

Set up the datasets.

Parameters:

Name Type Description Default
stage Optional[str]

Current stage ('fit', 'validate', 'test', or 'predict')

None
Source code in deepecgkit/datasets/modules.py
def setup(self, stage: Optional[str] = None):
    """Set up the datasets.

    Builds the dataset if only a class was given, computes optional
    stratification labels and patient-level groups, then splits into
    train/val/test subsets.

    Args:
        stage: Current stage ('fit', 'validate', 'test', or 'predict')
    """
    # Lazily construct the dataset when only a dataset class was supplied.
    if self.dataset is None:
        if self.dataset_class is None:
            raise ValueError("dataset_class must be provided if dataset is None")

        if self.data_dir is None:
            self.data_dir = self.dataset_class.get_default_data_dir()

        init_kwargs = {
            "data_dir": self.data_dir,
            "sampling_rate": self.sampling_rate,
            "transform": self.transform,
            "target_transform": self.target_transform,
            "download": self.download,
        }

        if self.leads is not None:
            init_kwargs["leads"] = self.leads

        # User-supplied kwargs take precedence over the defaults above.
        init_kwargs.update(self.dataset_kwargs)

        self.dataset = self.dataset_class(**init_kwargs)

    stratify_labels = None
    if (
        self.stratify and len(self.dataset) >= 8
    ):  # Need at least 2 samples per class for 4 classes
        try:
            # NOTE(review): this iterates the whole dataset once to gather
            # labels, which may be slow for large on-disk datasets.
            all_labels = []
            for i in range(len(self.dataset)):
                _, label = self.dataset[i]
                all_labels.append(label)
            stacked = torch.stack(all_labels).numpy()

            # For multi-label (2D), convert rows to string keys for stratification
            if stacked.ndim > 1:
                stratify_labels = np.array(
                    ["_".join(str(int(v)) for v in row) for row in stacked]
                )
            else:
                stratify_labels = stacked

            # Check if we have enough samples per class for stratification
            _, counts = np.unique(stratify_labels, return_counts=True)
            if np.min(counts) < 2:
                stratify_labels = None  # Disable stratification
        except Exception:
            stratify_labels = None  # Fallback to no stratification

    # Extract groups for patient-level splitting (prevents data leakage)
    groups = None
    if hasattr(self.dataset, "record_names") and self.dataset.record_names:
        groups = np.array(self.dataset.record_names)
        if self.verbose:
            n_groups = len(np.unique(groups))
            print(f"Using patient-level splitting ({n_groups} groups)")

    splitter = DataSplitter(
        dataset=self.dataset,
        val_split=self.val_split,
        test_split=self.test_split,
        seed=self.seed,
        stratify=stratify_labels,
        groups=groups,
    )
    self.train_dataset, self.val_dataset, self.test_dataset = splitter.split()

    if self.verbose:
        # NOTE(review): the first line prints the TRAINING-set size but is
        # labeled "Dataset size" — likely a misleading message.
        print(f"Dataset size: {len(self.train_dataset)}")
        print(f"Validation set size: {len(self.val_dataset)}")
        print(f"Test set size: {len(self.test_dataset)}")

train_dataloader

train_dataloader() -> DataLoader

Get the training data loader.

Source code in deepecgkit/datasets/modules.py
def train_dataloader(self) -> DataLoader:
    """Get the training data loader.

    Returns:
        DataLoader over the training split with shuffling enabled.

    Raises:
        RuntimeError: If setup() has not been called yet.
    """
    if self.train_dataset is None:
        raise RuntimeError("Call setup() before using the data module")
    return DataLoader(
        self.train_dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=self.num_workers,
        pin_memory=self.pin_memory,
        # Worker-only options must be disabled when num_workers == 0.
        persistent_workers=self.persistent_workers and self.num_workers > 0,
        prefetch_factor=self.prefetch_factor if self.num_workers > 0 else None,
    )

val_dataloader

val_dataloader() -> DataLoader

Get the validation data loader.

Source code in deepecgkit/datasets/modules.py
def val_dataloader(self) -> DataLoader:
    """Get the validation data loader.

    Returns:
        DataLoader over the validation split without shuffling.

    Raises:
        RuntimeError: If setup() has not been called yet.
    """
    if self.val_dataset is None:
        raise RuntimeError("Call setup() before using the data module")
    return DataLoader(
        self.val_dataset,
        batch_size=self.batch_size,
        shuffle=False,
        num_workers=self.num_workers,
        pin_memory=self.pin_memory,
        # Worker-only options must be disabled when num_workers == 0.
        persistent_workers=self.persistent_workers and self.num_workers > 0,
        prefetch_factor=self.prefetch_factor if self.num_workers > 0 else None,
    )

test_dataloader

test_dataloader() -> DataLoader

Get the test data loader.

Source code in deepecgkit/datasets/modules.py
def test_dataloader(self) -> DataLoader:
    """Get the test data loader.

    Returns:
        DataLoader over the test split without shuffling.

    Raises:
        RuntimeError: If setup() has not been called yet.
    """
    if self.test_dataset is None:
        raise RuntimeError("Call setup() before using the data module")
    return DataLoader(
        self.test_dataset,
        batch_size=self.batch_size,
        shuffle=False,
        num_workers=self.num_workers,
        pin_memory=self.pin_memory,
        # Worker-only options must be disabled when num_workers == 0.
        persistent_workers=self.persistent_workers and self.num_workers > 0,
        prefetch_factor=self.prefetch_factor if self.num_workers > 0 else None,
    )

get_metadata

get_metadata() -> Dict

Get metadata about the dataset.

Returns:

Type Description
Dict

Dictionary containing dataset metadata

Source code in deepecgkit/datasets/modules.py
def get_metadata(self) -> Dict:
    """Get metadata about the dataset.

    Returns:
        Dictionary containing dataset metadata. Delegates to the dataset's
        own get_metadata() for BaseECGDataset instances; otherwise returns
        basic loader configuration.

    Raises:
        RuntimeError: If setup() has not been called yet.
    """
    if self.dataset is None:
        raise RuntimeError("Call setup() before getting metadata")
    if isinstance(self.dataset, BaseECGDataset):
        return self.dataset.get_metadata()
    return {
        "dataset_size": len(self.dataset),
        "batch_size": self.batch_size,
        "num_workers": self.num_workers,
    }

print_metadata

print_metadata()

Print metadata about the dataset.

Source code in deepecgkit/datasets/modules.py
def print_metadata(self):
    """Print metadata about the dataset to stdout."""
    print(self.get_metadata())

Data Splitting

DataSplitter

Handles dataset splitting into train, validation, and test sets.

Source code in deepecgkit/datasets/splitting.py
class DataSplitter:
    """Handles dataset splitting into train, validation, and test sets."""

    def __init__(
        self,
        dataset: Dataset,
        val_split: float = 0.2,
        test_split: float = 0.1,
        seed: Optional[int] = 42,
        stratify: Optional[np.ndarray] = None,
        groups: Optional[np.ndarray] = None,
    ):
        """Initialize the data splitter.

        Args:
            dataset: PyTorch dataset to split
            val_split: Fraction of data to use for validation
            test_split: Fraction of data to use for testing
            seed: Random seed for reproducibility. If None, the random split
                is unseeded (non-deterministic).
            stratify: Array of labels for stratified splitting (optional)
            groups: Array of group identifiers (e.g. patient/record IDs) to keep
                all samples from the same group in the same split (optional).
                Prevents data leakage from correlated samples.

        Raises:
            ValueError: If val_split + test_split >= 1.0.
        """
        self.dataset = dataset
        self.val_split = val_split
        self.test_split = test_split
        self.seed = seed
        self.stratify = stratify
        self.groups = groups

        if val_split + test_split >= 1.0:
            raise ValueError("val_split + test_split must be less than 1.0")

    def split(self) -> Tuple[Dataset, Dataset, Dataset]:
        """Split the dataset into train, validation, and test sets.

        Returns:
            Tuple of (train_dataset, val_dataset, test_dataset)
        """
        if self.groups is not None and len(self.groups) > 0:
            return self._split_by_groups()

        total_size = len(self.dataset)

        # Handle small datasets by ensuring at least 1 sample per split if possible
        val_size = max(1, int(total_size * self.val_split)) if self.val_split > 0 else 0
        test_size = max(1, int(total_size * self.test_split)) if self.test_split > 0 else 0

        # Adjust for very small datasets
        if total_size <= 2:
            val_size = 0
            test_size = 0
        elif total_size == 3:
            val_size = 1
            test_size = 0
        elif val_size + test_size >= total_size:
            # Reduce splits for small datasets
            val_size = 1 if self.val_split > 0 else 0
            test_size = 1 if self.test_split > 0 and total_size > val_size + 1 else 0

        train_size = total_size - val_size - test_size

        if self.stratify is not None and total_size > 2:
            try:
                indices = np.arange(total_size)
                if val_size + test_size > 0:
                    # First carve off val+test together, then split that pool.
                    train_idx, temp_idx = train_test_split(
                        indices,
                        test_size=(val_size + test_size) / total_size,
                        stratify=self.stratify,
                        random_state=self.seed,
                    )
                    if test_size > 0 and val_size > 0:
                        val_idx, test_idx = train_test_split(
                            temp_idx,
                            test_size=test_size / (val_size + test_size),
                            stratify=self.stratify[temp_idx],
                            random_state=self.seed,
                        )
                    elif test_size > 0:
                        test_idx = temp_idx
                        val_idx = []
                    else:
                        val_idx = temp_idx
                        test_idx = []
                else:
                    train_idx = indices
                    val_idx = []
                    test_idx = []

                train_dataset = Subset(self.dataset, train_idx)
                val_dataset = (
                    Subset(self.dataset, val_idx) if len(val_idx) > 0 else Subset(self.dataset, [])
                )
                test_dataset = (
                    Subset(self.dataset, test_idx)
                    if len(test_idx) > 0
                    else Subset(self.dataset, [])
                )
                return train_dataset, val_dataset, test_dataset
            except ValueError:
                pass  # Fall through to random splitting

        if val_size + test_size == 0:
            train_dataset = self.dataset
            val_dataset = Subset(self.dataset, [])
            test_dataset = Subset(self.dataset, [])
        else:
            # BUGFIX: manual_seed(None) raises, so only seed the generator
            # when a seed was actually provided.
            generator = torch.Generator()
            if self.seed is not None:
                generator.manual_seed(self.seed)
            train_dataset, val_dataset, test_dataset = random_split(
                self.dataset,
                [train_size, val_size, test_size],
                generator=generator,
            )

        return train_dataset, val_dataset, test_dataset

    def _split_by_groups(self) -> Tuple[Dataset, Dataset, Dataset]:
        """Split by groups so all samples from one group stay in the same split.

        This implements inter-patient splitting for medical datasets where
        segments from the same recording/patient are highly correlated.

        Raises:
            ValueError: If fewer than 3 distinct groups are available.
        """
        unique_groups = np.unique(self.groups)
        n_groups = len(unique_groups)

        if n_groups < 3:
            raise ValueError(f"Need at least 3 groups for train/val/test split, got {n_groups}")

        val_n = max(1, int(n_groups * self.val_split)) if self.val_split > 0 else 0
        test_n = max(1, int(n_groups * self.test_split)) if self.test_split > 0 else 0

        if val_n + test_n >= n_groups:
            val_n = 1 if self.val_split > 0 else 0
            test_n = 1 if self.test_split > 0 and n_groups > val_n + 1 else 0

        if val_n + test_n == 0:
            # Nothing to hold out: put everything in the training split
            # (previously this fell into sklearn with test_size=0 and raised).
            return (
                Subset(self.dataset, list(range(len(self.groups)))),
                Subset(self.dataset, []),
                Subset(self.dataset, []),
            )

        # Compute majority label per group for stratified group splitting
        group_stratify = None
        if self.stratify is not None:
            group_to_labels: dict = {}
            for g, label in zip(self.groups, self.stratify):
                group_to_labels.setdefault(g, []).append(label)
            group_stratify = np.array(
                [Counter(group_to_labels[g]).most_common(1)[0][0] for g in unique_groups]
            )
            _, counts = np.unique(group_stratify, return_counts=True)
            if np.min(counts) < 2:
                group_stratify = None

        group_indices = np.arange(n_groups)

        try:
            train_gi, temp_gi = train_test_split(
                group_indices,
                test_size=(val_n + test_n) / n_groups,
                stratify=group_stratify,
                random_state=self.seed,
            )
            if val_n > 0 and test_n > 0:
                # Drop stratification for the second split if any class in the
                # held-out pool has fewer than 2 groups.
                temp_stratify = group_stratify[temp_gi] if group_stratify is not None else None
                if temp_stratify is not None:
                    _, tc = np.unique(temp_stratify, return_counts=True)
                    if np.min(tc) < 2:
                        temp_stratify = None
                val_gi, test_gi = train_test_split(
                    temp_gi,
                    test_size=test_n / (val_n + test_n),
                    stratify=temp_stratify,
                    random_state=self.seed,
                )
            elif test_n > 0:
                test_gi = temp_gi
                val_gi = np.array([], dtype=int)
            else:
                val_gi = temp_gi
                test_gi = np.array([], dtype=int)
        except ValueError:
            # Stratification failed, split without it
            train_gi, temp_gi = train_test_split(
                group_indices,
                test_size=(val_n + test_n) / n_groups,
                random_state=self.seed,
            )
            if val_n > 0 and test_n > 0:
                val_gi, test_gi = train_test_split(
                    temp_gi,
                    test_size=test_n / (val_n + test_n),
                    random_state=self.seed,
                )
            elif test_n > 0:
                test_gi = temp_gi
                val_gi = np.array([], dtype=int)
            else:
                val_gi = temp_gi
                test_gi = np.array([], dtype=int)

        # Map group indices back to sample indices
        train_groups = set(unique_groups[train_gi])
        val_groups = set(unique_groups[val_gi]) if len(val_gi) > 0 else set()
        test_groups = set(unique_groups[test_gi]) if len(test_gi) > 0 else set()

        train_idx = [i for i, g in enumerate(self.groups) if g in train_groups]
        val_idx = [i for i, g in enumerate(self.groups) if g in val_groups]
        test_idx = [i for i, g in enumerate(self.groups) if g in test_groups]

        return (
            Subset(self.dataset, train_idx),
            Subset(self.dataset, val_idx) if val_idx else Subset(self.dataset, []),
            Subset(self.dataset, test_idx) if test_idx else Subset(self.dataset, []),
        )

split

split() -> Tuple[Dataset, Dataset, Dataset]

Split the dataset into train, validation, and test sets.

Returns:

Type Description
Tuple[Dataset, Dataset, Dataset]

Tuple of (train_dataset, val_dataset, test_dataset)

Source code in deepecgkit/datasets/splitting.py
def split(self) -> Tuple[Dataset, Dataset, Dataset]:
    """Split the dataset into train, validation, and test sets.

    Uses group-aware splitting when `self.groups` is provided; otherwise
    performs (optionally stratified) index-level splitting with small-dataset
    safeguards, falling back to `random_split` when stratification is not
    requested or fails.

    Returns:
        Tuple of (train_dataset, val_dataset, test_dataset)
    """
    # Grouped splitting keeps all samples of a group in the same partition.
    if self.groups is not None and len(self.groups) > 0:
        return self._split_by_groups()

    total_size = len(self.dataset)

    # Handle small datasets by ensuring at least 1 sample per split if possible
    val_size = max(1, int(total_size * self.val_split)) if self.val_split > 0 else 0
    test_size = max(1, int(total_size * self.test_split)) if self.test_split > 0 else 0

    # Adjust for very small datasets
    if total_size <= 2:
        val_size = 0
        test_size = 0
    elif total_size == 3:
        # With exactly 3 samples, favor a validation sample over a test sample.
        val_size = 1
        test_size = 0
    elif val_size + test_size >= total_size:
        # Reduce splits for small datasets
        val_size = 1 if self.val_split > 0 else 0
        test_size = 1 if self.test_split > 0 and total_size > val_size + 1 else 0

    train_size = total_size - val_size - test_size

    if self.stratify is not None and total_size > 2:
        try:
            indices = np.arange(total_size)
            if val_size + test_size > 0:
                # First carve off the combined val+test pool, stratified on labels.
                train_idx, temp_idx = train_test_split(
                    indices,
                    test_size=(val_size + test_size) / total_size,
                    stratify=self.stratify,
                    random_state=self.seed,
                )
                if test_size > 0 and val_size > 0:
                    # Then split the held-out pool into val and test.
                    # NOTE(review): assumes self.stratify supports fancy indexing
                    # (e.g. a numpy array) — confirm at call sites.
                    val_idx, test_idx = train_test_split(
                        temp_idx,
                        test_size=test_size / (val_size + test_size),
                        stratify=self.stratify[temp_idx],
                        random_state=self.seed,
                    )
                elif test_size > 0:
                    test_idx = temp_idx
                    val_idx = []
                else:
                    val_idx = temp_idx
                    test_idx = []
            else:
                # Nothing held out: everything goes to training.
                train_idx = indices
                val_idx = []
                test_idx = []

            train_dataset = Subset(self.dataset, train_idx)
            val_dataset = (
                Subset(self.dataset, val_idx) if len(val_idx) > 0 else Subset(self.dataset, [])
            )
            test_dataset = (
                Subset(self.dataset, test_idx)
                if len(test_idx) > 0
                else Subset(self.dataset, [])
            )
            return train_dataset, val_dataset, test_dataset
        except ValueError:
            pass  # Fall through to random splitting

    if val_size + test_size == 0:
        # No held-out sets requested (or dataset too small): return empty val/test.
        train_dataset = self.dataset
        val_dataset = Subset(self.dataset, [])
        test_dataset = Subset(self.dataset, [])
    else:
        # Seeded generator keeps the random split reproducible across runs.
        train_dataset, val_dataset, test_dataset = random_split(
            self.dataset,
            [train_size, val_size, test_size],
            generator=torch.Generator().manual_seed(self.seed),
        )

    return train_dataset, val_dataset, test_dataset

Preprocessing

ECGStandardizer

Source code in deepecgkit/datasets/preprocessing.py
class ECGStandardizer:
    """Standardize ECG signals: resample, clip/pad to a fixed length, normalize.

    Args:
        target_sampling_rate: Sampling rate (Hz) signals are resampled to.
        target_length: Fixed output length in samples. Mutually exclusive with
            `target_duration_seconds`; None disables clipping/padding.
        target_duration_seconds: Fixed output duration in seconds, converted
            to samples at `target_sampling_rate`.
        normalization: One of "zscore", "minmax", or "none".
        clip_method: Where to clip/pad: "center", "start", or "end".

    Raises:
        ValueError: If both length options are given, or an unknown
            `normalization`/`clip_method` is requested.
    """

    _NORMALIZATIONS = ("zscore", "minmax", "none")
    _CLIP_METHODS = ("center", "start", "end")

    def __init__(
        self,
        target_sampling_rate: int = 300,
        target_length: Optional[int] = None,
        target_duration_seconds: Optional[float] = None,
        normalization: str = "zscore",
        clip_method: str = "center",
    ):
        # Fail fast on bad configuration instead of erroring mid-pipeline.
        if normalization not in self._NORMALIZATIONS:
            raise ValueError(f"Unknown normalization method: {normalization}")
        if clip_method not in self._CLIP_METHODS:
            raise ValueError(f"Unknown clip_method: {clip_method}")
        if target_length is not None and target_duration_seconds is not None:
            raise ValueError("Cannot specify both target_length and target_duration_seconds")

        self.target_sampling_rate = target_sampling_rate
        self.normalization = normalization
        self.clip_method = clip_method

        if target_duration_seconds is not None:
            self.target_length = int(target_duration_seconds * target_sampling_rate)
        else:
            self.target_length = target_length

    def resample(self, ecg_signal: np.ndarray, original_sampling_rate: int) -> np.ndarray:
        """Resample the signal from `original_sampling_rate` to the target rate.

        Accepts (samples,) or (leads, samples) arrays; resampling is applied
        along the last axis. Returns the input unchanged when rates match.
        """
        if original_sampling_rate == self.target_sampling_rate:
            return ecg_signal

        num_samples = ecg_signal.shape[-1]
        target_samples = int(num_samples * self.target_sampling_rate / original_sampling_rate)

        # scipy resamples along any axis; this replaces the previous per-lead
        # Python loop with a single vectorized FFT-based call.
        return signal.resample(ecg_signal, target_samples, axis=-1)

    def normalize(self, ecg_signal: np.ndarray) -> np.ndarray:
        """Normalize the signal per lead (statistics over the last axis)."""
        if self.normalization == "zscore":
            mean = np.mean(ecg_signal, axis=-1, keepdims=True)
            std = np.std(ecg_signal, axis=-1, keepdims=True)
            # Epsilon avoids division by zero on flat leads.
            return (ecg_signal - mean) / (std + 1e-8)

        elif self.normalization == "minmax":
            min_val = np.min(ecg_signal, axis=-1, keepdims=True)
            max_val = np.max(ecg_signal, axis=-1, keepdims=True)
            range_val = max_val - min_val
            # Flat leads (zero range) are only shifted, not scaled.
            return np.where(
                range_val > 1e-8, (ecg_signal - min_val) / range_val, ecg_signal - min_val
            )

        elif self.normalization == "none":
            return ecg_signal

        else:
            raise ValueError(f"Unknown normalization method: {self.normalization}")

    def clip_or_pad(self, ecg_signal: np.ndarray) -> np.ndarray:
        """Clip or zero-pad the signal to `target_length` along the last axis.

        Unlike the previous implementation, a 1-D input now yields a 1-D
        output in all cases (it used to come back 2-D whenever clipping or
        padding was actually performed).
        """
        if self.target_length is None or ecg_signal.shape[-1] == self.target_length:
            return ecg_signal

        current_length = ecg_signal.shape[-1]

        # Work in (leads, samples) internally; restore the original rank at the end.
        was_1d = ecg_signal.ndim == 1
        if was_1d:
            ecg_signal = ecg_signal[np.newaxis, :]

        if current_length < self.target_length:
            result = self._pad_signal(ecg_signal, current_length)
        else:
            result = self._clip_signal(ecg_signal, current_length)

        return result[0] if was_1d else result

    def _pad_signal(self, ecg_signal: np.ndarray, current_length: int) -> np.ndarray:
        """Zero-pad a (leads, samples) signal up to `target_length`."""
        pad_width = self.target_length - current_length
        if self.clip_method == "center":
            pad_left = pad_width // 2
            pad_right = pad_width - pad_left
            return np.pad(ecg_signal, ((0, 0), (pad_left, pad_right)), mode="constant")
        if self.clip_method == "start":
            # "start" keeps the signal at the start and pads the end.
            return np.pad(ecg_signal, ((0, 0), (0, pad_width)), mode="constant")
        return np.pad(ecg_signal, ((0, 0), (pad_width, 0)), mode="constant")

    def _clip_signal(self, ecg_signal: np.ndarray, current_length: int) -> np.ndarray:
        """Clip a (leads, samples) signal down to `target_length`."""
        excess = current_length - self.target_length
        if self.clip_method == "center":
            start = excess // 2
            return ecg_signal[:, start : start + self.target_length]
        if self.clip_method == "start":
            return ecg_signal[:, : self.target_length]
        return ecg_signal[:, -self.target_length :]

    def __call__(self, ecg_signal: np.ndarray, original_sampling_rate: int) -> np.ndarray:
        """Apply the full pipeline: resample, clip/pad, then normalize."""
        ecg_signal = self.resample(ecg_signal, original_sampling_rate)
        ecg_signal = self.clip_or_pad(ecg_signal)
        ecg_signal = self.normalize(ecg_signal)
        return ecg_signal

ECGSegmenter

Source code in deepecgkit/datasets/preprocessing.py
class ECGSegmenter:
    """Slice an ECG signal into fixed-length, optionally overlapping windows.

    Args:
        segment_duration_seconds: Window length in seconds.
        sampling_rate: Sampling rate of the input signal (Hz).
        overlap: Fractional overlap between consecutive windows, in [0, 1).

    Raises:
        ValueError: If `overlap` is outside [0, 1).
    """

    def __init__(
        self,
        segment_duration_seconds: float,
        sampling_rate: int,
        overlap: float = 0.0,
    ):
        # overlap >= 1 would yield a non-positive stride and previously
        # crashed with "range() arg 3 must not be zero" at segmentation time.
        if not 0.0 <= overlap < 1.0:
            raise ValueError(f"overlap must be in [0, 1), got {overlap}")

        self.segment_duration_seconds = segment_duration_seconds
        self.sampling_rate = sampling_rate
        self.overlap = overlap
        self.segment_length = int(segment_duration_seconds * sampling_rate)
        # Floor at 1 so rounding of overlaps close to 1 cannot collapse the stride to 0.
        self.stride = max(1, int(self.segment_length * (1 - overlap)))

    def segment(self, ecg_signal: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Segment the signal into windows.

        Args:
            ecg_signal: Array of shape (samples,) or (leads, samples).

        Returns:
            Tuple of (segments, start_indices) where segments has shape
            (num_segments, leads, segment_length). Both are empty arrays
            when the signal is shorter than one segment.
        """
        if ecg_signal.ndim == 1:
            ecg_signal = ecg_signal[np.newaxis, :]

        _, signal_length = ecg_signal.shape

        if signal_length < self.segment_length:
            return np.array([]), np.array([])

        segments = []
        start_indices = []

        # Trailing samples that do not fill a full window are dropped.
        for start in range(0, signal_length - self.segment_length + 1, self.stride):
            segment = ecg_signal[:, start : start + self.segment_length]
            segments.append(segment)
            start_indices.append(start)

        return np.array(segments), np.array(start_indices)

RhythmAnnotationExtractor

Source code in deepecgkit/datasets/preprocessing.py
class RhythmAnnotationExtractor:
    """Convert WFDB rhythm annotations into per-sample integer labels.

    Annotation `aux_note` strings (e.g. "(AFIB") mark the start of a rhythm
    region that extends until the next annotation; unknown strings map to 0.

    Args:
        sampling_rate: Sampling rate (Hz) the labels are expressed in.
        binary_classification: If True, collapse labels to AF (1) vs other (0).
    """

    # Rhythm aux_note string -> class index; unknown strings default to 0 (Normal).
    RHYTHM_MAP: ClassVar[Dict[str, int]] = {
        "(AFIB": 1,
        "(AFL": 2,
        "(J": 3,
        "(N": 0,
    }

    def __init__(
        self,
        sampling_rate: int,
        binary_classification: bool = False,
    ):
        self.sampling_rate = sampling_rate
        self.binary_classification = binary_classification

    def extract_labels(
        self,
        annotation,
        signal_length: int,
        original_sampling_rate: Optional[int] = None,
    ) -> np.ndarray:
        """Build a per-sample label array of length `signal_length`.

        Args:
            annotation: WFDB-style annotation object with `sample` (indices at
                the original rate) and `aux_note` (rhythm strings) attributes.
            signal_length: Length of the (already resampled) signal in samples
                at `self.sampling_rate`.
            original_sampling_rate: Rate the annotation indices refer to; when
                it differs from `self.sampling_rate`, indices are rescaled.

        Returns:
            int64 array of length `signal_length`; all zeros when the
            annotation object lacks the expected attributes.
        """
        labels = np.zeros(signal_length, dtype=np.int64)

        if not hasattr(annotation, "aux_note") or not hasattr(annotation, "sample"):
            return labels

        scale = 1.0
        if original_sampling_rate is not None and original_sampling_rate != self.sampling_rate:
            scale = self.sampling_rate / original_sampling_rate

        num_annotations = len(annotation.sample)
        for i, (sample_idx, aux_note) in enumerate(zip(annotation.sample, annotation.aux_note)):
            scaled_idx = int(sample_idx * scale)
            if scaled_idx >= signal_length:
                # Annotations are ordered; later ones fall outside too.
                break

            rhythm_code = self.RHYTHM_MAP.get(aux_note.strip(), 0)

            if self.binary_classification:
                rhythm_code = 1 if rhythm_code == 1 else 0

            if i + 1 < num_annotations:
                # Region ends at the next annotation (rescaled, capped at the end).
                next_sample = min(int(annotation.sample[i + 1] * scale), signal_length)
            else:
                # Bug fix: the final region runs to the end of the target-rate
                # signal. Previously the `signal_length` sentinel was also
                # multiplied by `scale`, truncating the last region whenever
                # the annotations were being downsampled (scale < 1).
                next_sample = signal_length

            labels[scaled_idx:next_sample] = rhythm_code

        return labels

    def segment_with_labels(
        self, labels: np.ndarray, segment_start_indices: np.ndarray, segment_length: int
    ) -> np.ndarray:
        """Assign each segment the majority per-sample label within its window.

        Ties are broken toward the smallest label value (np.unique ordering).
        """
        segment_labels = []

        for start_idx in segment_start_indices:
            segment_label_region = labels[start_idx : start_idx + segment_length]

            unique, counts = np.unique(segment_label_region, return_counts=True)
            majority_label = unique[np.argmax(counts)]

            segment_labels.append(majority_label)

        return np.array(segment_labels)

convert_to_tensor

convert_to_tensor(
    data: Union[ndarray, Tensor],
    dtype: dtype = torch.float32,
) -> torch.Tensor
Source code in deepecgkit/datasets/preprocessing.py
def convert_to_tensor(
    data: Union[np.ndarray, torch.Tensor], dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """Return *data* as a ``torch.Tensor`` of the requested dtype.

    Existing tensors are only cast; NumPy arrays are wrapped via
    ``torch.from_numpy`` before the dtype cast.
    """
    tensor = data if isinstance(data, torch.Tensor) else torch.from_numpy(data)
    return tensor.to(dtype)

Dataset Implementations

AFClassificationDataset

Bases: BaseECGDataset

PhysioNet/Computing in Cardiology Challenge 2017 AF Classification Dataset.

This dataset contains over 10,000 single-lead ECG recordings of 30-60 seconds duration for atrial fibrillation (AF) classification. Each recording is labeled as one of four categories: Normal (N), Atrial Fibrillation (A), Other rhythm (O), or Noisy (~).

The recordings are from an AliveCor device and represent patient-initiated recordings.

Reference

Clifford GD, Liu C, Moody B, Li-wei HL, Silva I, Li Q, Johnson AE, Mark RG. AF classification from a short single lead ECG recording: The PhysioNet/computing in cardiology challenge 2017. In 2017 Computing in Cardiology (CinC) 2017 Sep 24 (pp. 1-4). IEEE.

URL

https://physionet.org/content/challenge-2017/1.0.0/

Source code in deepecgkit/datasets/af_classification.py
@register_dataset(
    name="af-classification",
    input_channels=1,
    description="PhysioNet 2017 AF Classification (4 classes, single-lead)",
)
class AFClassificationDataset(BaseECGDataset):
    """PhysioNet/Computing in Cardiology Challenge 2017 AF Classification Dataset.

    This dataset contains over 10,000 single-lead ECG recordings of 30-60 seconds duration
    for atrial fibrillation (AF) classification. Each recording is labeled as one of four
    categories: Normal (N), Atrial Fibrillation (A), Other rhythm (O), or Noisy (~).

    The recordings are from AliveCor device and represent patient-initiated recordings.

    Reference:
        Clifford GD, Liu C, Moody B, Li-wei HL, Silva I, Li Q, Johnson AE, Mark RG.
        AF classification from a short single lead ECG recording: The PhysioNet/computing
        in cardiology challenge 2017. In 2017 Computing in Cardiology (CinC) 2017 Sep 24 (pp. 1-4). IEEE.

    URL:
        https://physionet.org/content/challenge-2017/1.0.0/
    """

    # Human-readable class names, indexed by the integer labels in LABEL_MAPPING.
    CLASS_LABELS: ClassVar[List[str]] = [
        "Normal",
        "AF",
        "Other",
        "Noisy",
    ]

    # Preferred label file; _load_data falls back to older REFERENCE*.csv revisions.
    REFERENCE_FILE: ClassVar[str] = "REFERENCE-v3.csv"

    # Raw CSV label symbol -> integer class index.
    LABEL_MAPPING: ClassVar[Dict[str, int]] = {
        "N": 0,
        "A": 1,
        "O": 2,
        "~": 3,
    }

    # Single-lead recordings.
    LEADS: ClassVar[List[str]] = ["ECG"]
    DOWNLOAD_URL: ClassVar[str] = "https://physionet.org/content/challenge-2017/get-zip/1.0.0/"
    # Recordings are stored at 300 Hz; they are resampled to `sampling_rate` on load.
    NATIVE_SAMPLING_RATE: ClassVar[int] = 300

    def __init__(
        self,
        data_dir: Optional[Union[str, Path]] = None,
        sampling_rate: int = 300,
        segment_duration_seconds: float = 10.0,
        segment_overlap: float = 0.0,
        normalization: str = "zscore",
        subset: str = "training",
        transform: Optional[callable] = None,
        target_transform: Optional[callable] = None,
        download: bool = False,
        force_download: bool = False,
        version: str = "1.0.0",
        verbose: bool = True,
    ):
        """Initialize the AF Classification dataset.

        Args:
            data_dir: Directory where the dataset is stored or downloaded.
            sampling_rate: Target sampling rate (Hz); recordings are resampled
                from the native 300 Hz when different.
            segment_duration_seconds: Length of each extracted segment (seconds).
            segment_overlap: Fractional overlap between consecutive segments.
            normalization: Normalization method used by ECGStandardizer.
            subset: Either 'training' or 'validation'.
            transform: Optional transform applied to each signal tensor.
            target_transform: Optional transform applied to each label tensor.
            download: Whether to download the dataset if missing.
            force_download: Whether to re-download even if present.
            version: Dataset version string.
            verbose: Whether to print progress information.

        Raises:
            ValueError: If `subset` is not 'training' or 'validation'.
        """
        if subset not in ["training", "validation"]:
            raise ValueError("subset must be either 'training' or 'validation'")

        self.version = version
        self.subset = subset
        self.verbose = verbose
        self.segment_duration_seconds = segment_duration_seconds
        self.segment_overlap = segment_overlap

        self.standardizer = ECGStandardizer(
            target_sampling_rate=sampling_rate,
            normalization=normalization,
        )

        self.segmenter = ECGSegmenter(
            segment_duration_seconds=segment_duration_seconds,
            sampling_rate=sampling_rate,
            overlap=segment_overlap,
        )

        # Populated by _load_data(); one entry per extracted segment.
        self.signals: List[np.ndarray] = []
        self.labels: List[int] = []
        self.record_names: List[str] = []
        self.reference_data: Optional[pd.DataFrame] = None

        # NOTE(review): the base initializer presumably triggers download()
        # and _load_data() as needed — confirm against BaseECGDataset.
        super().__init__(
            data_dir=data_dir,
            sampling_rate=sampling_rate,
            leads=self.LEADS,
            transform=transform,
            target_transform=target_transform,
            download=download,
            force_download=force_download,
        )

    def download(self):
        """Download the AF Classification dataset from PhysioNet as a single ZIP."""
        if self.verbose:
            print(f"Downloading PhysioNet Challenge 2017 dataset to {self.data_dir}")
            print("Note: This is a large dataset (~1.4GB). Download may take a while.")

        os.makedirs(self.data_dir, exist_ok=True)

        zip_path = self.data_dir / "challenge-2017-1.0.0.zip"
        download_file(
            self.DOWNLOAD_URL,
            zip_path,
            desc="Downloading PhysioNet Challenge 2017",
            max_retries=5,
        )

        if self.verbose:
            print("Extracting dataset...")

        with zipfile.ZipFile(zip_path, "r") as zf:
            zf.extractall(self.data_dir)

        # Normalize the extracted layout: the archive may contain a versioned
        # top-level folder and inner archives (e.g. training2017.zip).
        self._flatten_nested_dir()
        self._extract_inner_zips()

        # Delete the archive to reclaim disk space.
        zip_path.unlink(missing_ok=True)

        if self.verbose:
            print("Download complete!")

    def _flatten_nested_dir(self):
        """Flatten a nested extraction directory if present from a previous download."""
        for child in self.data_dir.iterdir():
            # Skip known data folders and macOS metadata; any other directory
            # is assumed to be the versioned wrapper folder.
            if not child.is_dir() or child.name in ("__MACOSX", "training2017", "sample2017"):
                continue
            if self.verbose:
                print(f"Flattening nested directory {child.name}...")
            for item in child.iterdir():
                dest = self.data_dir / item.name
                # Replace stale copies so the move cannot fail on collisions.
                if dest.exists():
                    if dest.is_dir():
                        shutil.rmtree(dest)
                    else:
                        dest.unlink()
                shutil.move(str(item), str(dest))
            child.rmdir()
            # Only one wrapper directory is expected; stop after flattening it.
            return

    def _extract_inner_zips(self):
        """Extract inner zip files like training2017.zip."""
        for zf_path in self.data_dir.glob("*.zip"):
            folder_name = zf_path.stem
            target_dir = self.data_dir / folder_name
            # Already extracted on a previous run.
            if target_dir.exists():
                continue
            if self.verbose:
                print(f"Extracting {zf_path.name}...")
            with zipfile.ZipFile(zf_path, "r") as zf:
                zf.extractall(self.data_dir)
            zf_path.unlink()

    def _load_data(self):
        """Load the AF Classification dataset into memory."""
        if self.verbose:
            print(f"Loading AF Classification dataset ({self.subset} subset)...")

        # Repair layouts left by interrupted or older downloads before reading.
        self._flatten_nested_dir()
        self._extract_inner_zips()

        reference_path = self.data_dir / self.REFERENCE_FILE
        if not reference_path.exists():
            # Fall back to older label-file revisions shipped with the dataset.
            for ref_file in ["REFERENCE.csv", "REFERENCE-v2.csv", "REFERENCE-v1.csv"]:
                alt_path = self.data_dir / ref_file
                if alt_path.exists():
                    reference_path = alt_path
                    break

        if not reference_path.exists():
            raise FileNotFoundError(f"Reference file not found in {self.data_dir}")

        self.reference_data = pd.read_csv(
            reference_path, header=None, names=["record_name", "label"]
        )

        if self.subset == "training":
            data_folder = self.data_dir / "training2017"
        else:
            data_folder = self.data_dir / "sample2017"

        if not data_folder.exists():
            raise FileNotFoundError(
                f"Data folder {data_folder} not found. Please download the dataset first."
            )

        iterator = (
            tqdm(self.reference_data.iterrows(), desc="Loading records")
            if self.verbose
            else self.reference_data.iterrows()
        )

        for _, row in iterator:
            record_name = row["record_name"]
            label = row["label"]

            mat_file = data_folder / f"{record_name}.mat"

            if mat_file.exists():
                try:
                    mat_data = scipy.io.loadmat(str(mat_file))

                    # The signal is stored under varying keys across files;
                    # try the known ones, then the first non-metadata key.
                    if "val" in mat_data:
                        signal = mat_data["val"].flatten()
                    elif "ecg" in mat_data:
                        signal = mat_data["ecg"].flatten()
                    elif "data" in mat_data:
                        signal = mat_data["data"].flatten()
                    else:
                        signal_key = [k for k in mat_data.keys() if not k.startswith("__")][0]
                        signal = mat_data[signal_key].flatten()

                    if label not in self.LABEL_MAPPING:
                        if self.verbose:
                            print(f"Warning: Unknown label '{label}' for record {record_name}")
                        continue

                    label_idx = self.LABEL_MAPPING[label]

                    # Add a lead axis: (samples,) -> (1, samples).
                    signal = signal[np.newaxis, :]
                    resampled = self.standardizer.resample(signal, self.NATIVE_SAMPLING_RATE)

                    # Recordings shorter than one segment are skipped entirely.
                    segments, _ = self.segmenter.segment(resampled)
                    if len(segments) == 0:
                        continue

                    # Every segment inherits the record-level label.
                    for segment in segments:
                        normalized = self.standardizer.normalize(segment)
                        self.signals.append(normalized)
                        self.labels.append(label_idx)
                        self.record_names.append(record_name)

                except Exception as e:
                    # Best-effort loading: a corrupt record must not abort the run.
                    if self.verbose:
                        print(f"Error loading record {record_name}: {e}")
                    continue

        if self.verbose:
            print(f"Successfully loaded {len(self.signals)} segments from {self.subset} subset")

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the (signal, label) tensor pair for segment `idx`."""
        signal = self.signals[idx]
        label = self.labels[idx]

        signal = convert_to_tensor(signal, dtype=torch.float32)
        label = torch.tensor(label, dtype=torch.long)

        if self.transform is not None:
            signal = self.transform(signal)
        if self.target_transform is not None:
            label = self.target_transform(label)

        return signal, label

    def __len__(self) -> int:
        """Return the number of extracted segments."""
        return len(self.labels)

    @property
    def num_classes(self) -> int:
        """Get the number of classes in the dataset."""
        return len(self.CLASS_LABELS)

    @property
    def class_names(self) -> List[str]:
        """Get the names of the classes in the dataset."""
        return self.CLASS_LABELS

    def get_record_info(self, idx: int) -> Dict:
        """Return metadata (record name, label, class name, shape) for segment `idx`."""
        record_name = self.record_names[idx]
        label_idx = self.labels[idx]

        return {
            "record_name": record_name,
            "label": label_idx,
            "class_name": self.CLASS_LABELS[label_idx],
            "signal_shape": self.signals[idx].shape,
        }

    def get_class_distribution(self) -> Dict[str, int]:
        """Return a mapping of class name -> number of segments with that label."""
        unique, counts = np.unique(self.labels, return_counts=True)
        return {self.class_names[int(label)]: int(count) for label, count in zip(unique, counts)}

num_classes property

num_classes: int

Get the number of classes in the dataset.

class_names property

class_names: List[str]

Get the names of the classes in the dataset.

download

download()

Download the AF Classification dataset from PhysioNet as a single ZIP.

Source code in deepecgkit/datasets/af_classification.py
def download(self):
    """Download the AF Classification dataset from PhysioNet as a single ZIP."""
    if self.verbose:
        print(f"Downloading PhysioNet Challenge 2017 dataset to {self.data_dir}")
        print("Note: This is a large dataset (~1.4GB). Download may take a while.")

    os.makedirs(self.data_dir, exist_ok=True)

    zip_path = self.data_dir / "challenge-2017-1.0.0.zip"
    download_file(
        self.DOWNLOAD_URL,
        zip_path,
        desc="Downloading PhysioNet Challenge 2017",
        max_retries=5,
    )

    if self.verbose:
        print("Extracting dataset...")

    with zipfile.ZipFile(zip_path, "r") as zf:
        zf.extractall(self.data_dir)

    # Normalize the extracted layout: the archive may contain a versioned
    # top-level folder and inner archives (e.g. training2017.zip).
    self._flatten_nested_dir()
    self._extract_inner_zips()

    # Delete the archive to reclaim disk space.
    zip_path.unlink(missing_ok=True)

    if self.verbose:
        print("Download complete!")

LTAFDBDataset

Bases: BaseECGDataset

Long-Term AF Database (LTAFDB) Dataset.

Contains 84 long-term (typically 24-hour) two-lead ECG recordings of subjects with paroxysmal or sustained atrial fibrillation. Rhythm annotations indicate Normal, AF, Atrial Flutter, and Junctional rhythm segments.

Reference

Petrutiu S, Sahakian AV, Swiryn S. Abrupt changes in fibrillatory wave characteristics at the termination of paroxysmal atrial fibrillation in humans. Europace. 2007;9(7):466-470.

URL

https://physionet.org/content/ltafdb/1.0.0/

Source code in deepecgkit/datasets/ltafdb.py
@register_dataset(
    name="ltafdb",
    input_channels=2,
    num_classes=4,
    description="Long-Term AF Database (2-lead, binary or 4-class)",
)
class LTAFDBDataset(BaseECGDataset):
    """Long-Term AF Database (LTAFDB) Dataset.

    Contains 84 long-term (typically 24-hour) two-lead ECG recordings of subjects
    with paroxysmal or sustained atrial fibrillation. Rhythm annotations indicate
    Normal, AF, Atrial Flutter, and Junctional rhythm segments.

    Reference:
        Petrutiu S, Sahakian AV, Swiryn S. Abrupt changes in fibrillatory wave
        characteristics at the termination of paroxysmal atrial fibrillation in humans.
        Europace. 2007;9(7):466-470.

    URL:
        https://physionet.org/content/ltafdb/1.0.0/
    """

    CLASS_LABELS: ClassVar[List[str]] = ["Normal", "AF", "AFL", "J"]
    LABEL_MAPPING: ClassVar[Dict[str, int]] = {
        "(N": 0,
        "(AFIB": 1,
        "(AFL": 2,
        "(J": 3,
    }
    LEADS: ClassVar[List[str]] = ["ECG1", "ECG2"]
    SAMPLING_RATE: ClassVar[int] = 128

    DOWNLOAD_URL: ClassVar[str] = "https://physionet.org/content/ltafdb/get-zip/1.0.0/"

    def __init__(
        self,
        data_dir: Optional[Union[str, Path]] = None,
        sampling_rate: int = 300,
        segment_duration_seconds: float = 10.0,
        segment_overlap: float = 0.0,
        binary_classification: bool = False,
        use_both_leads: bool = False,
        normalization: str = "zscore",
        max_segments_per_record: Optional[int] = None,
        transform: Optional[callable] = None,
        target_transform: Optional[callable] = None,
        download: bool = False,
        force_download: bool = False,
        verbose: bool = True,
    ):
        """Initialize the LTAFDB dataset.

        Args:
            data_dir: Directory where the dataset is stored or downloaded.
            sampling_rate: Target sampling rate (Hz); records are resampled
                from the native 128 Hz when different.
            segment_duration_seconds: Length of each extracted segment (seconds).
            segment_overlap: Fractional overlap between consecutive segments.
            binary_classification: If True, collapse rhythm labels to AF vs non-AF.
            use_both_leads: Use both ECG leads instead of only the first.
            normalization: Normalization method used by ECGStandardizer.
            max_segments_per_record: Optional cap on segments kept per record
                (randomly subsampled when exceeded).
            transform: Optional transform applied to each signal.
            target_transform: Optional transform applied to each label.
            download: Whether to download the dataset if missing.
            force_download: Whether to re-download even if present.
            verbose: Whether to print progress information.
        """
        self.segment_duration_seconds = segment_duration_seconds
        self.segment_overlap = segment_overlap
        self.binary_classification = binary_classification
        self.use_both_leads = use_both_leads
        self.max_segments_per_record = max_segments_per_record
        self.verbose = verbose

        self.standardizer = ECGStandardizer(
            target_sampling_rate=sampling_rate,
            normalization=normalization,
        )

        self.segmenter = ECGSegmenter(
            segment_duration_seconds=segment_duration_seconds,
            sampling_rate=sampling_rate,
            overlap=segment_overlap,
        )

        self.rhythm_extractor = RhythmAnnotationExtractor(
            sampling_rate=sampling_rate, binary_classification=binary_classification
        )

        # Populated by _load_data(); these may become memory-mapped arrays
        # when loaded from the on-disk cache.
        self.signals: np.ndarray = np.array([])
        self.labels: np.ndarray = np.array([], dtype=np.int64)
        self.record_names: List[str] = []

        # NOTE(review): the base initializer presumably triggers download()
        # and _load_data() as needed — confirm against BaseECGDataset.
        super().__init__(
            data_dir=data_dir,
            sampling_rate=sampling_rate,
            leads=self.LEADS if use_both_leads else [self.LEADS[0]],
            transform=transform,
            target_transform=target_transform,
            download=download,
            force_download=force_download,
        )

    def _cache_key(self) -> str:
        """Stable 12-character hash of the preprocessing settings.

        Used to name the on-disk cache directory so that different
        configurations never share cached segments.
        """
        parts = [
            f"sr{self.sampling_rate}",
            f"dur{self.segment_duration_seconds}",
            f"ovlp{self.segment_overlap}",
            f"bin{self.binary_classification}",
            f"leads{self.use_both_leads}",
            f"norm{self.standardizer.normalization}",
            f"max{self.max_segments_per_record}",
        ]
        digest = hashlib.md5("_".join(parts).encode())
        return digest.hexdigest()[:12]

    def _cache_dir(self) -> Path:
        """Directory holding the preprocessed, memory-mapped cache for these settings."""
        cache_name = f"cache_{self._cache_key()}"
        return self.data_dir / cache_name

    def download(self):
        """Download and extract the Long-Term AF Database from PhysioNet."""
        if self.verbose:
            print(f"Downloading Long-Term AF Database to {self.data_dir}")
            print("Note: This is a large dataset (~1.7GB). Download may take a while.")

        os.makedirs(self.data_dir, exist_ok=True)

        zip_path = self.data_dir / "ltafdb-1.0.0.zip"
        download_file(
            self.DOWNLOAD_URL,
            zip_path,
            desc="Downloading LTAFDB",
            max_retries=5,
        )

        if self.verbose:
            print("Extracting dataset...")

        with zipfile.ZipFile(zip_path, "r") as zf:
            zf.extractall(self.data_dir)

        # The archive extracts into a versioned folder; move its contents up
        # so the record files live directly in data_dir.
        nested_dir = self.data_dir / "ltafdb-1.0.0"
        if nested_dir.is_dir():
            for item in nested_dir.iterdir():
                dest = self.data_dir / item.name
                # Replace stale copies so the move cannot fail on collisions.
                if dest.exists():
                    if dest.is_dir():
                        shutil.rmtree(dest)
                    else:
                        dest.unlink()
                shutil.move(str(item), str(dest))
            nested_dir.rmdir()

        # Delete the archive to reclaim disk space.
        zip_path.unlink(missing_ok=True)

        if self.verbose:
            print("Download complete!")

    def _resolve_record_dir(self) -> Path:
        """Locate the directory containing the WFDB .hea/.atr record files.

        Prefers `data_dir` itself, then the PhysioNet zip layout under
        `data_dir/files`; falls back to `data_dir` when neither has headers.
        """
        for candidate in (self.data_dir, self.data_dir / "files"):
            if candidate.is_dir() and any(candidate.glob("*.hea")):
                return candidate
        return self.data_dir

    def _discover_records(self, record_dir: Path) -> List[str]:
        """Return record names that have both a .hea header and an .atr annotation file.

        Sorted by (length, name) so numeric record identifiers order naturally.
        """
        headers = {path.stem for path in record_dir.glob("*.hea")}
        annotations = {path.stem for path in record_dir.glob("*.atr")}
        complete = headers.intersection(annotations)
        return sorted(complete, key=lambda name: (len(name), name))

    def _load_data(self):
        if self.verbose:
            print("Loading Long-Term AF Database...")

        cache_dir = self._cache_dir()
        signals_path = cache_dir / "signals.npy"
        labels_path = cache_dir / "labels.npy"
        record_names_path = cache_dir / "record_names.npy"

        if signals_path.exists() and labels_path.exists():
            if self.verbose:
                print(f"Loading from cache: {cache_dir.name}")
            self.signals = np.load(signals_path, mmap_mode="r")
            self.labels = np.load(labels_path, mmap_mode="r")
            if record_names_path.exists():
                self.record_names = np.load(record_names_path, allow_pickle=True).tolist()
            else:
                self.record_names = ["unknown"] * len(self.labels)
            if self.verbose:
                print(f"Loaded {len(self.labels)} segments from cache (memory-mapped)")
                self._print_class_distribution()
            return

        record_dir = self._resolve_record_dir()
        record_names = self._discover_records(record_dir)
        if not record_names:
            raise FileNotFoundError(
                f"No valid records found in {self.data_dir}. "
                "Expected .hea and .atr files from LTAFDB."
            )

        if self.verbose:
            print(f"Found {len(record_names)} records — processing (first run only)...")

        cache_dir.mkdir(parents=True, exist_ok=True)
        tmp_signals = cache_dir / "signals_tmp.npy"

        num_leads = 2 if self.use_both_leads else 1
        seg_len = self.segmenter.segment_length

        writer = None
        write_offset = 0
        label_chunks: List[np.ndarray] = []
        all_record_names: List[str] = []
        total_segments = 0

        for record_name in record_names:
            record_path = record_dir / record_name

            try:
                record = wfdb.rdrecord(str(record_path))
                annotation = wfdb.rdann(str(record_path), "atr")

                signals = record.p_signal.T

                if not self.use_both_leads:
                    signals = signals[0:1, :]

                standardized_signal = self.standardizer.resample(signals, self.SAMPLING_RATE)

                labels = self.rhythm_extractor.extract_labels(
                    annotation, standardized_signal.shape[-1], self.SAMPLING_RATE
                )

                segments, start_indices = self.segmenter.segment(standardized_signal)

                if len(segments) == 0:
                    continue

                if (
                    self.max_segments_per_record is not None
                    and len(segments) > self.max_segments_per_record
                ):
                    indices = np.random.choice(
                        len(segments),
                        self.max_segments_per_record,
                        replace=False,
                    )
                    segments = segments[indices]
                    start_indices = start_indices[indices]

                segment_labels = self.rhythm_extractor.segment_with_labels(
                    labels, start_indices, self.segmenter.segment_length
                )

                normalized = np.stack([self.standardizer.normalize(seg) for seg in segments])

                if writer is None:
                    estimated_total = len(record_names) * len(segments)
                    writer = np.lib.format.open_memmap(
                        str(tmp_signals),
                        mode="w+",
                        dtype=np.float32,
                        shape=(estimated_total, num_leads, seg_len),
                    )

                needed = write_offset + len(normalized)
                if needed > writer.shape[0]:
                    new_size = max(needed, writer.shape[0] * 2)
                    writer = np.lib.format.open_memmap(
                        str(tmp_signals),
                        mode="r+",
                        dtype=np.float32,
                        shape=(new_size, num_leads, seg_len),
                    )

                writer[write_offset : write_offset + len(normalized)] = normalized
                label_chunks.append(np.array(segment_labels, dtype=np.int64))
                all_record_names.extend([record_name] * len(segments))
                write_offset += len(normalized)
                total_segments += len(segments)

                if self.verbose:
                    print(
                        f"Loaded {record_name}: {len(segments)} segments (total: {total_segments})"
                    )

            except Exception as e:
                if self.verbose:
                    print(f"Error loading {record_name}: {e}")
                continue

        if write_offset == 0:
            shutil.rmtree(cache_dir)
            self.signals = np.array([])
            self.labels = np.array([], dtype=np.int64)
            self.record_names = []
            return

        del writer

        final_signals = np.lib.format.open_memmap(
            str(tmp_signals), mode="r+", shape=(write_offset, num_leads, seg_len)
        )
        np.save(signals_path, final_signals[:write_offset])
        del final_signals
        tmp_signals.unlink(missing_ok=True)

        all_labels = np.concatenate(label_chunks)
        np.save(labels_path, all_labels)
        np.save(record_names_path, np.array(all_record_names))

        self.signals = np.load(signals_path, mmap_mode="r")
        self.labels = np.load(labels_path, mmap_mode="r")
        self.record_names = all_record_names

        if self.verbose:
            print(f"\nCached {len(self.labels)} segments to {cache_dir.name} (memory-mapped)")
            self._print_class_distribution()

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the (signal, label) pair at *idx* as torch tensors.

        The signal becomes a float32 tensor and the label a long tensor;
        any configured transforms are applied afterwards.
        """
        raw_signal = self.signals[idx]
        raw_label = self.labels[idx]

        out_signal = convert_to_tensor(raw_signal, dtype=torch.float32)
        out_label = torch.tensor(raw_label, dtype=torch.long)

        if self.transform is not None:
            out_signal = self.transform(out_signal)
        if self.target_transform is not None:
            out_label = self.target_transform(out_label)

        return out_signal, out_label

    def __len__(self) -> int:
        """Number of cached segments in the dataset."""
        return len(self.labels)

    @property
    def num_classes(self) -> int:
        """Number of target classes (2 in binary mode)."""
        if self.binary_classification:
            return 2
        return len(self.CLASS_LABELS)

    @property
    def class_names(self) -> List[str]:
        """Human-readable class names matching label indices."""
        return ["Non-AF", "AF"] if self.binary_classification else self.CLASS_LABELS

    def get_record_info(self, idx: int) -> Dict:
        """Return provenance metadata for the sample at *idx*."""
        label_value = int(self.labels[idx])
        info = {
            "record_name": self.record_names[idx],
            "label": label_value,
            "class_name": self.class_names[label_value],
            "signal_shape": self.signals[idx].shape,
        }
        return info

    def get_class_distribution(self) -> Dict[str, int]:
        """Count the number of samples per class name."""
        distribution: Dict[str, int] = {}
        for label_value, count in zip(*np.unique(self.labels, return_counts=True)):
            distribution[self.class_names[int(label_value)]] = int(count)
        return distribution

MITBIHAFDBDataset

Bases: BaseECGDataset

MIT-BIH Atrial Fibrillation Database (AFDB) Dataset.

Contains 25 long-term (10-hour) two-lead ECG recordings from subjects with atrial fibrillation (mostly paroxysmal). Rhythm annotations indicate Normal, AF, Atrial Flutter, and Junctional rhythm segments.

Reference

Moody GB, Mark RG. A new method for detecting atrial fibrillation using R-R intervals. Computers in Cardiology. 1983;10:227-230.

URL

https://physionet.org/content/afdb/1.0.0/

Source code in deepecgkit/datasets/mitbih_afdb.py
@register_dataset(
    name="mitbih-afdb",
    input_channels=2,
    num_classes=4,
    description="MIT-BIH AF Database (2-lead, binary or 4-class)",
)
class MITBIHAFDBDataset(BaseECGDataset):
    """MIT-BIH Atrial Fibrillation Database (AFDB) Dataset.

    Contains 25 long-term (10-hour) two-lead ECG recordings from subjects with
    atrial fibrillation (mostly paroxysmal). Rhythm annotations indicate Normal,
    AF, Atrial Flutter, and Junctional rhythm segments.

    Reference:
        Moody GB, Mark RG. A new method for detecting atrial fibrillation using R-R
        intervals. Computers in Cardiology. 1983;10:227-230.

    URL:
        https://physionet.org/content/afdb/1.0.0/
    """

    # Class names in label-index order; indices correspond to LABEL_MAPPING values.
    CLASS_LABELS: ClassVar[List[str]] = ["Normal", "AF", "AFL", "J"]
    # WFDB rhythm annotation aux strings -> integer class label.
    LABEL_MAPPING: ClassVar[Dict[str, int]] = {
        "(N": 0,
        "(AFIB": 1,
        "(AFL": 2,
        "(J": 3,
    }
    LEADS: ClassVar[List[str]] = ["ECG1", "ECG2"]
    # Native sampling rate of the AFDB recordings (Hz); _load_data resamples
    # from this rate to the user-requested target rate.
    SAMPLING_RATE: ClassVar[int] = 250
    DOWNLOAD_URL: ClassVar[str] = "https://physionet.org/content/afdb/get-zip/1.0.0/"

    # Record identifiers expected on disk (.hea/.atr pairs).
    # NOTE(review): only 23 names are listed although the docstring mentions
    # 25 recordings — presumably the records that ship without waveform files
    # are omitted; confirm against the PhysioNet RECORDS index.
    RECORD_NAMES: ClassVar[List[str]] = [
        "04015",
        "04043",
        "04048",
        "04126",
        "04746",
        "04908",
        "04936",
        "05091",
        "05121",
        "05261",
        "06426",
        "06453",
        "06995",
        "07162",
        "07859",
        "07879",
        "07910",
        "08215",
        "08219",
        "08378",
        "08405",
        "08434",
        "08455",
    ]

    def __init__(
        self,
        data_dir: Optional[Union[str, Path]] = None,
        sampling_rate: int = 300,
        segment_duration_seconds: float = 10.0,
        segment_overlap: float = 0.0,
        binary_classification: bool = False,
        use_both_leads: bool = False,
        normalization: str = "zscore",
        transform: Optional[callable] = None,
        target_transform: Optional[callable] = None,
        download: bool = False,
        force_download: bool = False,
        verbose: bool = True,
    ):
        """Initialize the MIT-BIH AFDB dataset.

        Args:
            data_dir: Directory where the dataset is stored or will be downloaded.
            sampling_rate: Target sampling rate the signals are resampled to (Hz).
            segment_duration_seconds: Duration of each extracted segment (seconds).
            segment_overlap: Overlap between consecutive segments (forwarded to
                ECGSegmenter; presumably a fraction — confirm there).
            binary_classification: If True, use two classes (Non-AF vs. AF)
                instead of the four rhythm classes.
            use_both_leads: If True, keep both leads; otherwise only the first.
            normalization: Normalization method forwarded to ECGStandardizer.
            transform: Optional transform applied to each signal tensor.
            target_transform: Optional transform applied to each label tensor.
            download: Whether to download the dataset if it doesn't exist locally.
            force_download: Whether to re-download even if the dataset exists.
            verbose: Whether to print progress information.
        """
        self.segment_duration_seconds = segment_duration_seconds
        self.segment_overlap = segment_overlap
        self.binary_classification = binary_classification
        self.use_both_leads = use_both_leads
        self.verbose = verbose

        self.standardizer = ECGStandardizer(
            target_sampling_rate=sampling_rate,
            normalization=normalization,
        )

        self.segmenter = ECGSegmenter(
            segment_duration_seconds=segment_duration_seconds,
            sampling_rate=sampling_rate,
            overlap=segment_overlap,
        )

        self.rhythm_extractor = RhythmAnnotationExtractor(
            sampling_rate=sampling_rate, binary_classification=binary_classification
        )

        # Populated by _load_data().
        self.signals: List[np.ndarray] = []
        self.labels: List[int] = []
        self.record_names: List[str] = []

        super().__init__(
            data_dir=data_dir,
            sampling_rate=sampling_rate,
            leads=self.LEADS if use_both_leads else [self.LEADS[0]],
            transform=transform,
            target_transform=target_transform,
            download=download,
            force_download=force_download,
        )

    def download(self):
        """Download the AFDB zip from PhysioNet and extract it into data_dir."""
        if self.verbose:
            print(f"Downloading MIT-BIH AFDB to {self.data_dir}")
            print("Note: This is a ~440MB download.")

        os.makedirs(self.data_dir, exist_ok=True)

        zip_path = self.data_dir / "afdb-1.0.0.zip"
        download_file(
            self.DOWNLOAD_URL,
            zip_path,
            desc="Downloading MIT-BIH AFDB",
            max_retries=5,
        )

        if self.verbose:
            print("Extracting dataset...")

        with zipfile.ZipFile(zip_path, "r") as zf:
            zf.extractall(self.data_dir)

        # Flatten the top-level 'afdb-1.0.0' folder the zip extracts into,
        # replacing any stale items already present in data_dir.
        nested_dir = self.data_dir / "afdb-1.0.0"
        if nested_dir.is_dir():
            for item in nested_dir.iterdir():
                dest = self.data_dir / item.name
                if dest.exists():
                    if dest.is_dir():
                        shutil.rmtree(dest)
                    else:
                        dest.unlink()
                shutil.move(str(item), str(dest))
            nested_dir.rmdir()

        # PhysioNet zips may nest records inside a files/ subdirectory
        files_dir = self.data_dir / "files"
        if files_dir.is_dir() and any(files_dir.glob("*.hea")):
            for item in files_dir.iterdir():
                dest = self.data_dir / item.name
                if dest.exists():
                    if dest.is_dir():
                        shutil.rmtree(dest)
                    else:
                        dest.unlink()
                shutil.move(str(item), str(dest))
            files_dir.rmdir()

        zip_path.unlink(missing_ok=True)

        if self.verbose:
            print("Download complete!")

    def _load_data(self):
        """Read every record, segment it, and fill signals/labels/record_names.

        Each record is read with wfdb, optionally reduced to one lead,
        resampled from the native rate, rhythm-labelled, cut into segments,
        and normalized segment-by-segment. Unreadable records are skipped
        (best-effort), with a message when verbose.
        """
        if self.verbose:
            print("Loading MIT-BIH AFDB data...")

        # PhysioNet zips may nest records inside a files/ subdirectory
        files_dir = self.data_dir / "files"
        if (
            not any(self.data_dir.glob("*.hea"))
            and files_dir.is_dir()
            and any(files_dir.glob("*.hea"))
        ):
            self.data_dir = files_dir

        if not any(self.data_dir.glob("*.hea")):
            raise FileNotFoundError(
                f"No record files found in {self.data_dir}. Expected .hea files from MIT-BIH AFDB."
            )

        for record_name in self.RECORD_NAMES:
            record_path = self.data_dir / record_name

            # Skip records that are missing on disk instead of failing hard.
            if not (self.data_dir / f"{record_name}.hea").exists():
                if self.verbose:
                    print(f"Skipping {record_name} (not found)")
                continue

            try:
                record = wfdb.rdrecord(str(record_path))
                annotation = wfdb.rdann(str(record_path), "atr")

                # wfdb gives (samples, channels); transpose to (channels, samples).
                signals = record.p_signal.T

                if not self.use_both_leads:
                    signals = signals[0:1, :]

                # Resample from the native 250 Hz to the configured target rate.
                standardized_signal = self.standardizer.resample(signals, self.SAMPLING_RATE)

                # NOTE(review): the native rate is passed alongside the
                # resampled signal's length — presumably extract_labels
                # rescales annotation sample indices; confirm in the helper.
                labels = self.rhythm_extractor.extract_labels(
                    annotation, standardized_signal.shape[-1], self.SAMPLING_RATE
                )

                segments, start_indices = self.segmenter.segment(standardized_signal)

                if len(segments) == 0:
                    continue

                segment_labels = self.rhythm_extractor.segment_with_labels(
                    labels, start_indices, self.segmenter.segment_length
                )

                # Normalize each segment independently before storing.
                for segment, label in zip(segments, segment_labels):
                    normalized_segment = self.standardizer.normalize(segment)
                    self.signals.append(normalized_segment)
                    self.labels.append(label)
                    self.record_names.append(record_name)

                if self.verbose:
                    print(f"Loaded {record_name}: {len(segments)} segments")

            except Exception as e:
                if self.verbose:
                    print(f"Error loading {record_name}: {e}")

        if self.verbose:
            print(f"Total segments loaded: {len(self.signals)}")
            self._print_class_distribution()

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the (signal, label) pair at idx as float32/long tensors."""
        signal = self.signals[idx]
        label = self.labels[idx]

        signal = convert_to_tensor(signal, dtype=torch.float32)
        label = torch.tensor(label, dtype=torch.long)

        if self.transform is not None:
            signal = self.transform(signal)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return signal, label

    def __len__(self) -> int:
        """Number of loaded segments."""
        return len(self.labels)

    @property
    def num_classes(self) -> int:
        """Number of target classes (2 in binary mode, else 4)."""
        return 2 if self.binary_classification else len(self.CLASS_LABELS)

    @property
    def class_names(self) -> List[str]:
        """Class names matching label indices for the active mode."""
        if self.binary_classification:
            return ["Non-AF", "AF"]
        return self.CLASS_LABELS

    def get_record_info(self, idx: int) -> Dict:
        """Return provenance metadata for the sample at idx."""
        return {
            "record_name": self.record_names[idx],
            "label": self.labels[idx],
            "class_name": self.class_names[self.labels[idx]],
            "signal_shape": self.signals[idx].shape,
        }

    def get_class_distribution(self) -> Dict[str, int]:
        """Map each class name to its sample count."""
        unique, counts = np.unique(self.labels, return_counts=True)
        return {self.class_names[int(label)]: int(count) for label, count in zip(unique, counts)}

PTBXLDataset

Bases: BaseECGDataset

PTB-XL ECG Dataset.

PTB-XL is a large publicly available electrocardiography dataset containing 21,837 clinical 12-lead ECGs from 18,885 patients of 10 second length. Each ECG is annotated with up to 71 different diagnostic statements conforming to the SCP-ECG standard.

The dataset supports multiple diagnostic classification tasks: - Superclass: 5 diagnostic superclasses (NORM, MI, STTC, CD, HYP) - Subclass: 23 diagnostic subclasses - Diagnostic: All individual diagnostic SCP codes (~44 statements) - Form: 19 form statements - Rhythm: 12 rhythm statements - All: All SCP statement codes (~71 statements)

Reference

Wagner P, Strodthoff N, Bousseljot RD, Kreiseler D, Lunze FI, Samek W, Schaeffter T. PTB-XL, a large publicly available electrocardiography dataset. Scientific Data. 2020 May 25;7(1):154.

URL

https://physionet.org/content/ptb-xl/1.0.3/

Source code in deepecgkit/datasets/ptbxl.py
@register_dataset(
    name="ptbxl",
    input_channels=12,
    num_classes=5,
    description="PTB-XL 12-lead ECG dataset (multi-task)",
)
class PTBXLDataset(BaseECGDataset):
    """PTB-XL ECG Dataset.

    PTB-XL is a large publicly available electrocardiography dataset containing
    21,837 clinical 12-lead ECGs from 18,885 patients of 10 second length.
    Each ECG is annotated with up to 71 different diagnostic statements conforming
    to the SCP-ECG standard.

    The dataset supports multiple diagnostic classification tasks:
    - Superclass: 5 diagnostic superclasses (NORM, MI, STTC, CD, HYP)
    - Subclass: 23 diagnostic subclasses
    - Diagnostic: All individual diagnostic SCP codes (~44 statements)
    - Form: 19 form statements
    - Rhythm: 12 rhythm statements
    - All: All SCP statement codes (~71 statements)

    Reference:
        Wagner P, Strodthoff N, Bousseljot RD, Kreiseler D, Lunze FI, Samek W, Schaeffter T.
        PTB-XL, a large publicly available electrocardiography dataset.
        Scientific Data. 2020 May 25;7(1):154.

    URL:
        https://physionet.org/content/ptb-xl/1.0.3/
    """

    # The 5 diagnostic superclasses used by the "superclass" task.
    CLASS_LABELS_SUPERCLASS: ClassVar[List[str]] = ["NORM", "MI", "STTC", "CD", "HYP"]
    # Diagnostic subclass SCP codes used by the "subclass" task.
    CLASS_LABELS_SUBCLASS: ClassVar[List[str]] = [
        "NORM",
        "IMI",
        "ASMI",
        "ILMI",
        "AMI",
        "ALMI",
        "INJAS",
        "LMI",
        "INJAL",
        "ISCAL",
        "ISCAN",
        "INJIN",
        "INJLA",
        "PMI",
        "INJIL",
        "ISCIN",
        "ISCIL",
        "ISCAS",
        "LAFB",
        "IRBBB",
        "LPFB",
        "CRBBB",
        "CLBBB",
    ]
    # Form statement codes used by the "form" task.
    CLASS_LABELS_FORM: ClassVar[List[str]] = [
        "NDT",
        "NST_",
        "DIG",
        "LNGQT",
        "ABQRS",
        "PVC",
        "STD_",
        "VCLVH",
        "QWAVE",
        "LOWT",
        "NT_",
        "PAC",
        "LPR",
        "INVT",
        "LVOLT",
        "HVOLT",
        "TAB_",
        "STE_",
        "PRC(S)",
    ]
    # Rhythm statement codes used by the "rhythm" task.
    CLASS_LABELS_RHYTHM: ClassVar[List[str]] = [
        "SR",
        "AFIB",
        "STACH",
        "SARRH",
        "SBRAD",
        "PACE",
        "SVARR",
        "BIGU",
        "AFLT",
        "SVTAC",
        "PSVT",
        "TRIGU",
    ]

    # Superclass name -> label index.
    LABEL_MAPPING_SUPERCLASS: ClassVar[Dict[str, int]] = {
        label: i for i, label in enumerate(CLASS_LABELS_SUPERCLASS)
    }

    # The 12 standard lead names; _load_data assumes this order matches the
    # channel order in the stored records — TODO confirm against the headers.
    LEADS: ClassVar[List[str]] = [
        "I",
        "II",
        "III",
        "AVR",
        "AVL",
        "AVF",
        "V1",
        "V2",
        "V3",
        "V4",
        "V5",
        "V6",
    ]
    # Source sampling rates offered by PTB-XL: 500 Hz and 100 Hz recordings.
    SAMPLING_RATE_HR: ClassVar[int] = 500
    SAMPLING_RATE_LR: ClassVar[int] = 100

    def __init__(
        self,
        data_dir: Optional[Union[str, Path]] = None,
        sampling_rate: int = 500,
        task: str = "superclass",
        use_high_resolution: bool = True,
        folds: Optional[List[int]] = None,
        leads: Optional[List[str]] = None,
        normalization: str = "zscore",
        multi_label: bool = True,
        transform: Optional[callable] = None,
        target_transform: Optional[callable] = None,
        download: bool = False,
        force_download: bool = False,
        verbose: bool = True,
    ):
        """Initialize the PTB-XL dataset.

        Args:
            data_dir: Directory where the dataset is stored or will be downloaded.
                     If None, uses ~/.deepecgkit/datasets/ptbxldataset
            sampling_rate: Target sampling rate for the ECG signals (Hz)
            task: Classification task - one of "superclass", "subclass",
                  "diagnostic", "form", "rhythm", or "all"
            use_high_resolution: Whether to use 500Hz (True) or 100Hz (False) recordings
            folds: List of folds to include (1-10). None means all folds.
                   Recommended: folds 1-8 for training, 9 for validation, 10 for testing
            leads: List of lead names to use. None means all 12 leads.
            normalization: Normalization method - "zscore", "minmax", or "none"
            multi_label: If True, returns multi-hot encoded labels. If False, returns
                        single label (first/primary diagnosis)
            transform: Optional transform to be applied to the ECG signals
            target_transform: Optional transform to be applied to the labels
            download: Whether to download the dataset if it doesn't exist
            force_download: Whether to force re-download even if the dataset exists
            verbose: Whether to print progress information

        Raises:
            ValueError: If ``task`` is not one of the supported task names.
        """
        valid_tasks = ["superclass", "subclass", "diagnostic", "form", "rhythm", "all"]
        if task not in valid_tasks:
            raise ValueError(f"task must be one of {valid_tasks}, got '{task}'")

        self.task = task
        self.use_high_resolution = use_high_resolution
        self.folds = folds
        self.multi_label = multi_label
        self.verbose = verbose
        # Kept separately from the base class's lead handling for indexing
        # into the 12-channel records in _load_data.
        self._leads = leads

        # Rate the on-disk records were sampled at (depends on resolution).
        source_sampling_rate = (
            self.SAMPLING_RATE_HR if use_high_resolution else self.SAMPLING_RATE_LR
        )

        self.standardizer = ECGStandardizer(
            target_sampling_rate=sampling_rate,
            target_duration_seconds=10.0,
            normalization=normalization,
            clip_method="center",
        )

        # Populated by _load_data().
        self.signals: List[np.ndarray] = []
        self.labels: List[np.ndarray] = []
        self.record_names: List[str] = []
        self.metadata_df: Optional[pd.DataFrame] = None
        self.scp_statements: Optional[pd.DataFrame] = None
        self.source_sampling_rate = source_sampling_rate

        super().__init__(
            data_dir=data_dir,
            sampling_rate=sampling_rate,
            leads=leads if leads else self.LEADS,
            transform=transform,
            target_transform=target_transform,
            download=download,
            force_download=force_download,
        )

    @staticmethod
    def _fix_record_list(records: List[str]) -> List[str]:
        """Fix corrupted RECORDS file from PhysioNet.

        The PTB-XL RECORDS file on PhysioNet has a missing newline that
        concatenates two record paths (e.g. 'records100/...records500/...').
        This splits them back into separate entries.
        """
        repaired: List[str] = []
        for entry in records:
            # Only entries containing BOTH prefixes are fused; pass the rest through.
            if "records100" not in entry or "records500" not in entry:
                repaired.append(entry)
                continue
            cut = entry.index("records500")
            repaired.append(entry[:cut])
            repaired.append(entry[cut:])
        return repaired

    # Single-zip download of PTB-XL v1.0.3 from PhysioNet (used by download()).
    DOWNLOAD_URL: ClassVar[str] = (
        "https://physionet.org/static/published-projects/ptb-xl/"
        "ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3.zip"
    )

    def download(self):
        """Download the PTB-XL dataset from PhysioNet as a single ZIP."""
        if self.verbose:
            print(f"Downloading PTB-XL dataset to {self.data_dir}")
            print("Note: This is a large dataset (~2.5GB). Download may take a while.")

        os.makedirs(self.data_dir, exist_ok=True)

        zip_path = self.data_dir / "ptb-xl-1.0.3.zip"
        download_file(
            self.DOWNLOAD_URL,
            zip_path,
            desc="Downloading PTB-XL",
            max_retries=5,
        )

        if self.verbose:
            print("Extracting dataset...")

        with zipfile.ZipFile(zip_path, "r") as zf:
            zf.extractall(self.data_dir)

        # Flatten the long-named folder the zip extracts into, replacing any
        # stale items already present in data_dir.
        nested_dir = (
            self.data_dir / "ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3"
        )
        if nested_dir.is_dir():
            for item in nested_dir.iterdir():
                dest = self.data_dir / item.name
                if dest.exists():
                    if dest.is_dir():
                        shutil.rmtree(dest)
                    else:
                        dest.unlink()
                shutil.move(str(item), str(dest))
            nested_dir.rmdir()

        # Remove the archive once extracted to save disk space.
        zip_path.unlink(missing_ok=True)

        if self.verbose:
            print("Download complete!")

    def _load_data(self):
        """Load the PTB-XL dataset into memory.

        Reads the metadata CSV, optionally filters by stratification fold,
        then loads each record with wfdb: select leads, resample if the
        source rate differs from the target, normalize, and extract task
        labels from the record's SCP codes. Records that fail to load are
        skipped with a message when verbose.

        Raises:
            FileNotFoundError: If ptbxl_database.csv is missing from data_dir.
        """
        if self.verbose:
            print("Loading PTB-XL dataset...")

        metadata_path = self.data_dir / "ptbxl_database.csv"
        if not metadata_path.exists():
            raise FileNotFoundError(
                f"Metadata file not found at {metadata_path}. "
                "Use download=True or manually download from: "
                "https://physionet.org/content/ptb-xl/1.0.3/"
            )

        self.metadata_df = pd.read_csv(metadata_path, index_col="ecg_id")
        # scp_codes is stored as a stringified dict in the CSV; parse it back.
        self.metadata_df.scp_codes = self.metadata_df.scp_codes.apply(ast.literal_eval)

        # scp_statements.csv maps SCP codes to categories (diagnostic class etc.);
        # it is optional here but required by some tasks (see _get_label_columns).
        scp_path = self.data_dir / "scp_statements.csv"
        if scp_path.exists():
            self.scp_statements = pd.read_csv(scp_path, index_col=0)
        else:
            self.scp_statements = None

        if self.folds is not None:
            self.metadata_df = self.metadata_df[self.metadata_df.strat_fold.isin(self.folds)]
            if self.verbose:
                print(f"Using folds: {self.folds} ({len(self.metadata_df)} records)")

        label_columns = self._get_label_columns()

        # Map requested lead names to channel indices; unknown names are
        # silently dropped here.
        if self._leads:
            lead_indices = [self.LEADS.index(lead) for lead in self._leads if lead in self.LEADS]
        else:
            lead_indices = list(range(len(self.LEADS)))

        iterator = (
            tqdm(self.metadata_df.iterrows(), total=len(self.metadata_df), desc="Loading records")
            if self.verbose
            else self.metadata_df.iterrows()
        )

        for ecg_id, row in iterator:
            try:
                record_path = (
                    self.data_dir / row.filename_hr
                    if self.use_high_resolution
                    else self.data_dir / row.filename_lr
                )
                # wfdb expects the path without the .hea/.dat extension.
                record_path = str(record_path).replace(".hea", "").replace(".dat", "")

                record = wfdb.rdrecord(record_path)
                # wfdb gives (samples, channels); transpose to (channels, samples).
                signal = record.p_signal.T

                signal = signal[lead_indices, :]

                # Only resample when source and target rates differ.
                if self.source_sampling_rate != self.sampling_rate:
                    signal = self.standardizer.resample(signal, self.source_sampling_rate)

                signal = self.standardizer.normalize(signal)

                labels = self._extract_labels(row.scp_codes, label_columns)

                self.signals.append(signal.astype(np.float32))
                self.labels.append(labels)
                self.record_names.append(str(ecg_id))

            except Exception as e:
                if self.verbose:
                    print(f"Error loading record {ecg_id}: {e}")
                continue

        if self.verbose:
            print(f"Successfully loaded {len(self.signals)} records")
            self._print_class_distribution()

    def _get_label_columns(self) -> List[str]:
        """Return the ordered class names for the configured task.

        The four fixed tasks use the hardcoded class lists; "diagnostic" and
        "all" derive their columns from scp_statements.csv when available.
        """
        static_tasks = {
            "superclass": self.CLASS_LABELS_SUPERCLASS,
            "subclass": self.CLASS_LABELS_SUBCLASS,
            "form": self.CLASS_LABELS_FORM,
            "rhythm": self.CLASS_LABELS_RHYTHM,
        }
        if self.task in static_tasks:
            return static_tasks[self.task]

        if self.task == "diagnostic":
            if self.scp_statements is None:
                raise ValueError(
                    "scp_statements.csv is required for the 'diagnostic' task. "
                    "Ensure the dataset is fully downloaded."
                )
            diagnostic_rows = self.scp_statements[self.scp_statements.diagnostic == 1.0]
            return sorted(diagnostic_rows.index.tolist())

        # task == "all"
        if self.scp_statements is not None:
            return sorted(self.scp_statements.index.tolist())
        # Fallback: union of hardcoded lists
        combined = (
            self.CLASS_LABELS_SUPERCLASS
            + self.CLASS_LABELS_SUBCLASS
            + self.CLASS_LABELS_FORM
            + self.CLASS_LABELS_RHYTHM
        )
        return sorted(set(combined))

    def _extract_labels(self, scp_codes: Dict[str, float], label_columns: List[str]) -> np.ndarray:
        """Extract labels from SCP codes.

        Args:
            scp_codes: Mapping of SCP statement code -> likelihood. Values are
                compared against a 50.0 threshold (presumably a 0-100 scale —
                confirm against the PTB-XL metadata description).
            label_columns: Ordered class names for the configured task.

        Returns:
            Multi-hot float32 vector of length ``len(label_columns)`` when
            ``multi_label`` is True; otherwise a 0-d int64 array holding a
            single class index (0 as the fallback when nothing matches).
        """
        if self.multi_label:
            labels = np.zeros(len(label_columns), dtype=np.float32)
            # For the superclass task, SCP codes must first be mapped to their
            # diagnostic superclass via scp_statements.
            if self.task == "superclass" and self.scp_statements is not None:
                for code, likelihood in scp_codes.items():
                    if likelihood >= 50.0 and code in self.scp_statements.index:
                        stmt = self.scp_statements.loc[code]
                        if pd.notna(stmt.diagnostic_class):
                            superclass = stmt.diagnostic_class
                            if superclass in label_columns:
                                idx = label_columns.index(superclass)
                                labels[idx] = 1.0
            else:
                # Other tasks use the SCP code itself as the class name.
                for code, likelihood in scp_codes.items():
                    if code in label_columns:
                        idx = label_columns.index(code)
                        labels[idx] = 1.0 if likelihood >= 50.0 else 0.0
            return labels
        else:
            # Single-label: pick the superclass with the highest summed
            # likelihood, falling back to the first qualifying code, then 0.
            if self.task == "superclass" and self.scp_statements is not None:
                superclass_counts = {cls: 0.0 for cls in self.CLASS_LABELS_SUPERCLASS}
                for code, likelihood in scp_codes.items():
                    if likelihood >= 50.0 and code in self.scp_statements.index:
                        stmt = self.scp_statements.loc[code]
                        if pd.notna(stmt.diagnostic_class):
                            superclass = stmt.diagnostic_class
                            if superclass in superclass_counts:
                                superclass_counts[superclass] += likelihood
                if any(v > 0 for v in superclass_counts.values()):
                    primary_class = max(superclass_counts, key=superclass_counts.get)
                    return np.array(
                        self.CLASS_LABELS_SUPERCLASS.index(primary_class), dtype=np.int64
                    )
            for code, likelihood in scp_codes.items():
                if code in label_columns and likelihood >= 50.0:
                    return np.array(label_columns.index(code), dtype=np.int64)
            return np.array(0, dtype=np.int64)

    def _print_class_distribution(self):
        """Print a per-class sample-count summary for the current task."""
        if len(self.labels) == 0:
            return

        class_columns = self._get_label_columns()
        print(f"\nClass distribution ({self.task}):")

        if self.multi_label:
            # Sum each multi-hot column; only report classes that occur.
            stacked = np.stack(self.labels)
            for col_idx, name in enumerate(class_columns):
                col_total = int(stacked[:, col_idx].sum())
                if col_total > 0:
                    print(f"  {name}: {col_total}")
        else:
            values, freqs = np.unique(self.labels, return_counts=True)
            for value, freq in zip(values, freqs):
                if value < len(class_columns):
                    print(f"  {class_columns[value]}: {freq}")

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get a sample from the dataset.

        Args:
            idx: Index of the sample to get

        Returns:
            Tuple of (ecg_signal, label) where:
                - ecg_signal: Tensor of shape (num_leads, signal_length)
                - label: Tensor of shape (num_classes,) for multi-label or scalar for single-label
        """
        ecg = convert_to_tensor(self.signals[idx], dtype=torch.float32)

        # Multi-label targets are float multi-hot vectors; single-label are ints.
        label_dtype = torch.float32 if self.multi_label else torch.long
        target = torch.tensor(self.labels[idx], dtype=label_dtype)

        if self.transform is not None:
            ecg = self.transform(ecg)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return ecg, target

    def __len__(self) -> int:
        """Return the number of samples in the dataset (one label per sample)."""
        return len(self.labels)

    @property
    def num_classes(self) -> int:
        """Get the number of classes in the dataset."""
        columns = self._get_label_columns()
        return len(columns)

    @property
    def class_names(self) -> List[str]:
        """Get the names of the classes in the dataset.

        Positions in this list correspond to label indices.
        """
        return self._get_label_columns()

    def get_record_info(self, idx: int) -> Dict:
        """Return PTB-XL metadata for the sample at ``idx``.

        Args:
            idx: Index of the sample

        Returns:
            Dictionary containing record information
        """
        record_id = int(self.record_names[idx])
        meta = self.metadata_df.loc[record_id]

        # Multi-hot labels are reported as a list; single labels as an int index.
        if self.multi_label:
            label_value = self.labels[idx].tolist()
        else:
            label_value = int(self.labels[idx])

        return {
            "ecg_id": record_id,
            "patient_id": meta.patient_id,
            "age": meta.age,
            "sex": meta.sex,
            "signal_shape": self.signals[idx].shape,
            "scp_codes": meta.scp_codes,
            "strat_fold": meta.strat_fold,
            "labels": label_value,
        }

    def get_class_distribution(self) -> Dict[str, int]:
        """Count how many samples belong to each class.

        Returns:
            Dictionary mapping class names to their counts
        """
        names = self._get_label_columns()
        counts = dict.fromkeys(names, 0)

        if self.multi_label:
            # Column sums of the multi-hot matrix give per-class counts.
            stacked = np.stack(self.labels)
            for col, name in enumerate(names):
                counts[name] = int(stacked[:, col].sum())
        else:
            # Tally each in-range integer class index.
            for raw in self.labels:
                class_idx = int(raw) if isinstance(raw, np.ndarray) else raw
                if class_idx < len(names):
                    counts[names[class_idx]] += 1

        return counts

    def get_folds_split(
        self,
        train_folds: Optional[List[int]] = None,
        val_folds: Optional[List[int]] = None,
        test_folds: Optional[List[int]] = None,
    ) -> Dict[str, "PTBXLDataset"]:
        """Create train/val/test splits based on stratified folds.

        The PTB-XL dataset comes with 10 pre-defined stratified folds.
        Recommended split: folds 1-8 for training, 9 for validation, 10 for testing.

        Args:
            train_folds: Folds for training (default: 1-8)
            val_folds: Folds for validation (default: 9)
            test_folds: Folds for testing (default: 10)

        Returns:
            Dictionary with 'train', 'val', 'test' PTBXLDataset instances
        """
        # Fall back to the recommended PTB-XL split when folds are not given.
        fold_choices = {
            "train": train_folds if train_folds is not None else list(range(1, 9)),
            "val": val_folds if val_folds is not None else [9],
            "test": test_folds if test_folds is not None else [10],
        }

        # Every split shares this dataset's configuration; only folds differ.
        shared = {
            "data_dir": self.data_dir,
            "sampling_rate": self.sampling_rate,
            "task": self.task,
            "use_high_resolution": self.use_high_resolution,
            "leads": self._leads,
            "normalization": self.standardizer.normalization,
            "multi_label": self.multi_label,
            "transform": self.transform,
            "target_transform": self.target_transform,
            "download": False,
            "verbose": self.verbose,
        }

        return {
            split: PTBXLDataset(folds=folds, **shared)
            for split, folds in fold_choices.items()
        }

num_classes property

num_classes: int

Get the number of classes in the dataset.

class_names property

class_names: List[str]

Get the names of the classes in the dataset.

download

download()

Download the PTB-XL dataset from PhysioNet as a single ZIP.

Source code in deepecgkit/datasets/ptbxl.py
def download(self):
    """Download the PTB-XL dataset from PhysioNet as a single ZIP."""
    if self.verbose:
        print(f"Downloading PTB-XL dataset to {self.data_dir}")
        print("Note: This is a large dataset (~2.5GB). Download may take a while.")

    os.makedirs(self.data_dir, exist_ok=True)

    archive = self.data_dir / "ptb-xl-1.0.3.zip"
    download_file(
        self.DOWNLOAD_URL,
        archive,
        desc="Downloading PTB-XL",
        max_retries=5,
    )

    if self.verbose:
        print("Extracting dataset...")

    with zipfile.ZipFile(archive, "r") as archive_file:
        archive_file.extractall(self.data_dir)

    # The ZIP unpacks into a long versioned directory; flatten it so the data
    # lives directly under self.data_dir.
    nested = (
        self.data_dir / "ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3"
    )
    if nested.is_dir():
        for entry in nested.iterdir():
            target = self.data_dir / entry.name
            if target.exists():
                # Remove any stale copy left by a previous (partial) extraction.
                if target.is_dir():
                    shutil.rmtree(target)
                else:
                    target.unlink()
            shutil.move(str(entry), str(target))
        nested.rmdir()

    # The archive is no longer needed once extracted.
    archive.unlink(missing_ok=True)

    if self.verbose:
        print("Download complete!")

get_record_info

get_record_info(idx: int) -> Dict

Get record information for a specific sample.

Parameters:

Name Type Description Default
idx int

Index of the sample

required

Returns:

Type Description
Dict

Dictionary containing record information

Source code in deepecgkit/datasets/ptbxl.py
def get_record_info(self, idx: int) -> Dict:
    """Return PTB-XL metadata for the sample at ``idx``.

    Args:
        idx: Index of the sample

    Returns:
        Dictionary containing record information
    """
    record_id = int(self.record_names[idx])
    meta = self.metadata_df.loc[record_id]

    # Multi-hot labels are reported as a list; single labels as an int index.
    if self.multi_label:
        label_value = self.labels[idx].tolist()
    else:
        label_value = int(self.labels[idx])

    return {
        "ecg_id": record_id,
        "patient_id": meta.patient_id,
        "age": meta.age,
        "sex": meta.sex,
        "signal_shape": self.signals[idx].shape,
        "scp_codes": meta.scp_codes,
        "strat_fold": meta.strat_fold,
        "labels": label_value,
    }

get_class_distribution

get_class_distribution() -> Dict[str, int]

Get the distribution of classes in the dataset.

Returns:

Type Description
Dict[str, int]

Dictionary mapping class names to their counts

Source code in deepecgkit/datasets/ptbxl.py
def get_class_distribution(self) -> Dict[str, int]:
    """Count how many samples belong to each class.

    Returns:
        Dictionary mapping class names to their counts
    """
    names = self._get_label_columns()
    counts = dict.fromkeys(names, 0)

    if self.multi_label:
        # Column sums of the multi-hot matrix give per-class counts.
        stacked = np.stack(self.labels)
        for col, name in enumerate(names):
            counts[name] = int(stacked[:, col].sum())
    else:
        # Tally each in-range integer class index.
        for raw in self.labels:
            class_idx = int(raw) if isinstance(raw, np.ndarray) else raw
            if class_idx < len(names):
                counts[names[class_idx]] += 1

    return counts

get_folds_split

get_folds_split(
    train_folds: Optional[List[int]] = None,
    val_folds: Optional[List[int]] = None,
    test_folds: Optional[List[int]] = None,
) -> Dict[str, PTBXLDataset]

Create train/val/test splits based on stratified folds.

The PTB-XL dataset comes with 10 pre-defined stratified folds. Recommended split: folds 1-8 for training, 9 for validation, 10 for testing.

Parameters:

Name Type Description Default
train_folds Optional[List[int]]

Folds for training (default: 1-8)

None
val_folds Optional[List[int]]

Folds for validation (default: 9)

None
test_folds Optional[List[int]]

Folds for testing (default: 10)

None

Returns:

Type Description
Dict[str, PTBXLDataset]

Dictionary with 'train', 'val', 'test' PTBXLDataset instances

Source code in deepecgkit/datasets/ptbxl.py
def get_folds_split(
    self,
    train_folds: Optional[List[int]] = None,
    val_folds: Optional[List[int]] = None,
    test_folds: Optional[List[int]] = None,
) -> Dict[str, "PTBXLDataset"]:
    """Create train/val/test splits based on stratified folds.

    The PTB-XL dataset comes with 10 pre-defined stratified folds.
    Recommended split: folds 1-8 for training, 9 for validation, 10 for testing.

    Args:
        train_folds: Folds for training (default: 1-8)
        val_folds: Folds for validation (default: 9)
        test_folds: Folds for testing (default: 10)

    Returns:
        Dictionary with 'train', 'val', 'test' PTBXLDataset instances
    """
    # Fall back to the recommended PTB-XL split when folds are not given.
    fold_choices = {
        "train": train_folds if train_folds is not None else list(range(1, 9)),
        "val": val_folds if val_folds is not None else [9],
        "test": test_folds if test_folds is not None else [10],
    }

    # Every split shares this dataset's configuration; only folds differ.
    shared = {
        "data_dir": self.data_dir,
        "sampling_rate": self.sampling_rate,
        "task": self.task,
        "use_high_resolution": self.use_high_resolution,
        "leads": self._leads,
        "normalization": self.standardizer.normalization,
        "multi_label": self.multi_label,
        "transform": self.transform,
        "target_transform": self.target_transform,
        "download": False,
        "verbose": self.verbose,
    }

    return {
        split: PTBXLDataset(folds=folds, **shared)
        for split, folds in fold_choices.items()
    }

UnifiedAFDataset

Bases: BaseECGDataset

Unified AF Dataset combining multiple PhysioNet AF databases.

Combines samples from the PhysioNet 2017 Challenge, MIT-BIH AFDB, and LTAFDB into a single dataset for AF classification. Supports both binary (AF vs Non-AF) and 4-class (Normal, AF, AFL, J) classification modes.

PhysioNet 2017 labels are remapped to the unified scheme:
  • Normal (N) → Normal, AF (A) → AF
  • Other (O) and Noisy (~) are dropped in 4-class mode
  • Other (O) → Non-AF and Noisy (~) is dropped in binary mode
Source code in deepecgkit/datasets/unified_af.py
@register_dataset(
    name="unified-af",
    input_channels=1,
    num_classes=4,
    description="Unified AF dataset combining PhysioNet 2017, MIT-BIH, LTAFDB",
)
class UnifiedAFDataset(BaseECGDataset):
    """Unified AF Dataset combining multiple PhysioNet AF databases.

    Combines samples from the PhysioNet 2017 Challenge, MIT-BIH AFDB, and LTAFDB
    into a single dataset for AF classification. Supports both binary (AF vs Non-AF)
    and 4-class (Normal, AF, AFL, J) classification modes.

    PhysioNet 2017 labels are remapped to the unified scheme:
      - Normal (N) → Normal, AF (A) → AF
      - Other (O) and Noisy (~) are dropped in 4-class mode
      - Other (O) → Non-AF and Noisy (~) is dropped in binary mode
    """

    CLASS_LABELS: ClassVar[List[str]] = ["Normal", "AF", "AFL", "J"]
    LEADS: ClassVar[List[str]] = ["ECG"]

    AVAILABLE_DATASETS: ClassVar[Dict[str, Type[BaseECGDataset]]] = {
        "physionet2017": AFClassificationDataset,
        "mitbih_afdb": MITBIHAFDBDataset,
        "ltafdb": LTAFDBDataset,
    }

    def _resolve_dataset_dir(self, dataset_name: str) -> Path:
        """Locate the default on-disk directory for a constituent dataset."""
        return self.AVAILABLE_DATASETS[dataset_name].get_default_data_dir()

    def __init__(
        self,
        data_dir: Optional[Union[str, Path]] = None,
        sampling_rate: int = 300,
        segment_duration_seconds: float = 10.0,
        datasets: Optional[List[str]] = None,
        binary_classification: bool = False,
        normalization: str = "zscore",
        transform: Optional[callable] = None,
        target_transform: Optional[callable] = None,
        download: bool = False,
        force_download: bool = False,
        verbose: bool = True,
        dataset_kwargs: Optional[Dict[str, Dict]] = None,
    ):
        """Initialize the unified AF dataset.

        Args:
            data_dir: Forwarded to BaseECGDataset; sub-datasets resolve their
                own default directories independently.
            sampling_rate: Sampling rate passed to every sub-dataset (Hz)
            segment_duration_seconds: Segment length forwarded to sub-datasets
            datasets: Sub-dataset names to combine (default: all three)
            binary_classification: If True, use AF vs Non-AF labels
            normalization: Normalization scheme name. NOTE(review): not stored
                or forwarded anywhere in this initializer — confirm intended.
            transform: Optional transform applied to signals
            target_transform: Optional transform applied to labels
            download: Whether to download missing sub-datasets
            force_download: Whether to force re-download
            verbose: Whether to print progress information
            dataset_kwargs: Per-dataset constructor keyword overrides
        """
        # These attributes are set before super().__init__, which may invoke
        # download()/_load_data() — they must be available by then.
        self.segment_duration_seconds = segment_duration_seconds
        self.binary_classification = binary_classification
        self.verbose = verbose
        self.dataset_kwargs = dataset_kwargs or {}

        # Default to every supported source dataset.
        chosen = ["physionet2017", "mitbih_afdb", "ltafdb"] if datasets is None else datasets

        # Fail fast on the first unrecognized dataset name.
        for name in chosen:
            if name not in self.AVAILABLE_DATASETS:
                raise ValueError(
                    f"Unknown dataset: {name}. "
                    f"Available: {list(self.AVAILABLE_DATASETS.keys())}"
                )

        self.dataset_names = chosen
        self.datasets = []
        self.dataset_sizes = []

        super().__init__(
            data_dir=data_dir,
            sampling_rate=sampling_rate,
            leads=self.LEADS,
            transform=transform,
            target_transform=target_transform,
            download=download,
            force_download=force_download,
        )

    def download(self):
        """Fetch each constituent dataset that is not already on disk."""
        for name in self.dataset_names:
            dataset_class = self.AVAILABLE_DATASETS[name]
            target_dir = self._resolve_dataset_dir(name)

            # A non-empty directory means the dataset is already present.
            if target_dir.exists() and any(target_dir.iterdir()):
                if self.verbose:
                    print(f"\n{'=' * 60}")
                    print(f"Skipping {name}: already exists at {target_dir}")
                    print(f"{'=' * 60}")
                continue

            if self.verbose:
                print(f"\n{'=' * 60}")
                print(f"Downloading dataset: {name}")
                print(f"{'=' * 60}")

            target_dir.mkdir(parents=True, exist_ok=True)

            # Bypass __init__ (which would try to load data) and build a bare
            # instance carrying only the attributes download() relies on.
            stub = object.__new__(dataset_class)
            stub.data_dir = target_dir
            stub.verbose = self.verbose
            stub.download()

    def _get_dataset_kwargs(self, dataset_name: str) -> Dict:
        """Assemble constructor kwargs for the named sub-dataset."""
        kwargs = {
            "sampling_rate": self.sampling_rate,
            "transform": self.transform,
            "segment_duration_seconds": self.segment_duration_seconds,
            "download": False,
        }

        if dataset_name in ("mitbih_afdb", "ltafdb"):
            kwargs["target_transform"] = self.target_transform
            kwargs["binary_classification"] = self.binary_classification
        # For physionet2017, target_transform is intentionally omitted — it is
        # applied after label remapping in _RemappedDataset (see _load_data).

        # Caller-supplied per-dataset overrides win over the defaults above.
        kwargs.update(self.dataset_kwargs.get(dataset_name, {}))

        return kwargs

    def _load_data(self):
        """Instantiate each configured sub-dataset and concatenate them.

        For every name in ``self.dataset_names`` the matching dataset class is
        constructed from its resolved directory. PhysioNet 2017 samples are
        wrapped in ``_RemappedDataset`` so their labels match the unified
        scheme (samples with unmappable labels are dropped). Sub-datasets that
        fail to load are logged and skipped; raises RuntimeError if none load.
        Populates ``self.datasets``, ``self.dataset_sizes`` and
        ``self.concat_dataset``.
        """
        if self.verbose:
            print("\nLoading unified AF dataset...")

        for dataset_name in self.dataset_names:
            if self.verbose:
                print(f"\n{'=' * 60}")
                print(f"Loading: {dataset_name}")
                print(f"{'=' * 60}")

            dataset_class = self.AVAILABLE_DATASETS[dataset_name]
            dataset_dir = self._resolve_dataset_dir(dataset_name)

            kwargs = self._get_dataset_kwargs(dataset_name)
            kwargs["data_dir"] = dataset_dir
            kwargs["verbose"] = self.verbose

            try:
                dataset = dataset_class(**kwargs)

                # Remap PhysioNet 2017 labels to the unified scheme
                if dataset_name == "physionet2017":
                    remap = (
                        _PHYSIONET2017_BINARY_REMAP
                        if self.binary_classification
                        else _PHYSIONET2017_LABEL_REMAP
                    )
                    original_len = len(dataset)
                    dataset = _RemappedDataset(
                        dataset, remap, target_transform=self.target_transform
                    )
                    # Remapping silently drops samples whose label has no
                    # unified equivalent; report how many were lost.
                    dropped = original_len - len(dataset)
                    if dropped > 0 and self.verbose:
                        print(f"  Dropped {dropped} samples with unmappable labels (Other/Noisy)")

                self.datasets.append(dataset)
                self.dataset_sizes.append(len(dataset))

                if self.verbose:
                    print(f"Loaded {len(dataset)} samples from {dataset_name}")
                    if hasattr(dataset, "get_class_distribution"):
                        dist = dataset.get_class_distribution()
                        print(f"Class distribution: {dist}")

            # Best-effort: one failing source must not abort the whole load.
            except Exception as e:
                logger.warning("Failed to load %s: %s", dataset_name, e)
                if self.verbose:
                    print(f"Failed to load {dataset_name}: {e}")
                continue

        if len(self.datasets) == 0:
            raise RuntimeError("No datasets were successfully loaded")

        self.concat_dataset = ConcatDataset(self.datasets)

        if self.verbose:
            print(f"\n{'=' * 60}")
            print("Unified dataset created")
            print(f"Total datasets: {len(self.datasets)}")
            print(f"Total samples: {len(self.concat_dataset)}")
            print(f"{'=' * 60}")
            self._print_overall_distribution()

    def _print_overall_distribution(self):
        """Print the pooled class distribution across all loaded sub-datasets."""
        if len(self.datasets) == 0:
            return

        # Pool raw labels from every sub-dataset that exposes them.
        pooled = []
        for ds in self.datasets:
            if hasattr(ds, "labels"):
                pooled.extend(ds.labels)

        if len(pooled) == 0:
            return

        values, freqs = np.unique(pooled, return_counts=True)
        total = len(pooled)
        print("\nOverall class distribution:")
        for value, freq in zip(values, freqs):
            share = (freq / total) * 100
            print(f"  {self.class_names[value]}: {freq} ({share:.1f}%)")

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the (signal, label) pair at the global index ``idx``."""
        # Delegate indexing to the underlying ConcatDataset of sub-datasets.
        return self.concat_dataset[idx]

    def __len__(self) -> int:
        """Return the total number of samples across all sub-datasets."""
        return len(self.concat_dataset)

    @property
    def num_classes(self) -> int:
        """Number of target classes: 2 in binary mode, else the full label set."""
        if self.binary_classification:
            return 2
        return len(self.CLASS_LABELS)

    @property
    def class_names(self) -> List[str]:
        """Class names matching the active mode (binary or 4-class)."""
        return ["Non-AF", "AF"] if self.binary_classification else self.CLASS_LABELS

    def get_dataset_info(self) -> Dict:
        """Summarize which sub-datasets are loaded, their sizes, and classes."""
        summary = {
            "num_datasets": len(self.datasets),
            "dataset_names": self.dataset_names,
            "dataset_sizes": self.dataset_sizes,
            "total_samples": len(self),
            "num_classes": self.num_classes,
            "class_names": self.class_names,
        }
        return summary

    def get_class_distribution(self) -> Dict[str, int]:
        """Map each class name to its pooled sample count across sub-datasets."""
        # Pool raw labels from every sub-dataset that exposes them.
        pooled = []
        for ds in self.datasets:
            if hasattr(ds, "labels"):
                pooled.extend(ds.labels)

        if len(pooled) == 0:
            return {}

        values, freqs = np.unique(pooled, return_counts=True)
        names = self.class_names
        return {names[int(v)]: int(f) for v, f in zip(values, freqs)}