美文网首页
学习写python包——data_models.py

学习写python包——data_models.py

作者: KK_f2d5 | 来源:发表于2023-12-25 11:25 被阅读0次

scDataset 类

from collections import Counter
from typing import Optional

import anndata
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import scanpy
import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

class scDataset(Dataset):
    """A class that represent a single cell dataset."""

    def __init__(self, X, Y, study=None):
        self.X = X
        self.Y = Y
        self.study = study
        self.classes = set(self.Y)

    def __len__(self):
        return len(self.Y)

    def __getitem__(self, idx):
        # data, label, study
        return self.X[idx].A, self.Y[idx], self.study[idx]

这段代码定义了 scDataset 类,它是一个用于表示单细胞数据集的 Python 类。这个类继承自 torch.utils.data.Dataset,使其兼容于 PyTorch 的数据加载和处理机制。下面我将详细解释这个类及其方法的功能:

导入的模块
collections.Counter: 用于计数不同元素的出现次数。
typing.Optional: 用于类型注解,表示参数可以为 None。
anndata: 用于处理单细胞数据的库。
pytorch_lightning: 简化 PyTorch 模型训练的库。
scanpy: 用于单细胞数据分析的库。

构造函数 init(self, X, Y, study=None)
参数:

X: 数据矩阵,通常是一个稀疏矩阵,包含细胞的基因表达数据。
Y: 标签数组,包含与 X 中每个样本相对应的标签(例如,细胞类型)。
study: 可选参数,包含与 X 中每个样本相对应的研究或实验信息。
属性初始化:

self.X: 存储传入的数据矩阵 X。
self.Y: 存储传入的标签 Y。
self.study: 存储传入的研究信息 study。
self.classes: 从 Y 中提取的唯一标签集合,表示数据中包含的不同类别。
len(self) 方法
返回数据集中样本的总数。这是通过计算 Y(标签数组)的长度来实现的。
getitem(self, idx) 方法
参数: idx - 请求的样本索引。
返回: 三元组 (data, label, study),其中:
data: 索引 idx 处的样本数据(从 X 中提取)。.A 用于将稀疏矩阵转换为常规数组。
label: 索引 idx 处的样本标签(从 Y 中提取)。
study: 索引 idx 处的样本对应的研究信息(从 study 中提取)。

DataLoader

在 PyTorch Lightning 框架中,MetricLearningDataModule 类继承自 pl.LightningDataModule。

DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=True,
            sampler=self.get_sampler_weights(self.train_dataset),
            collate_fn=self.collate)

DataLoader 是一个极其重要的工具,用于批量加载数据并为训练提供必要的输入。为了使数据能够被 DataLoader 正确读取和处理,需要遵循特定的格式和协议。这就是为什么 scDataset 类按照特定方式实现的原因。
如何 scDataset 被 DataLoader 读取:

  1. 继承自 Dataset:

scDataset 继承自 PyTorch 的 Dataset 类,这意味着它需要实现两个方法:lengetitem
len 返回数据集中的样本数。
getitem 根据给定的索引返回相应的样本。

  1. 数据格式:

getitem 返回一个包含数据、标签和研究信息的元组。这是标准的 PyTorch 数据格式,允许 DataLoader 以一致的方式处理不同类型的数据集。

get_sampler_weights

其中, get_sampler_weights(self, dataset)
这个函数用于根据数据集生成加权随机采样器。加权随机采样器在数据不平衡的情况下非常有用,它可以确保在训练过程中各类别被均等地表示。

   def get_sampler_weights(self, dataset: scDataset) -> WeightedRandomSampler:
        """Get weighted random sampler.
        WeightedRandomSampler
            A WeightedRandomSampler object.
        """
        if dataset.study is None:
            class_sample_count = Counter(dataset.Y)
            sample_weights = torch.Tensor(
                [1.0 / class_sample_count[t] for t in dataset.Y]
            )
        else:
            class_sample_count = Counter(dataset.Y)
            study_sample_count = Counter(dataset.study)
            sample_weights = torch.Tensor(
                [
                    1.0
                    / class_sample_count[dataset.Y[i]]
                    / np.log(study_sample_count[dataset.study[i]])
                    for i in range(len(dataset.Y))
                ]
            )
        return WeightedRandomSampler(sample_weights, len(sample_weights))

实现逻辑
参数:dataset 是一个 scDataset 类的实例,包含数据集的特征、标签和其他信息。

处理:

如果 dataset 没有提供 study 信息,则根据类别标签 Y 计算每个类别的样本计数。然后,为每个样本计算权重,权重为类别的倒数。
如果提供了 study 信息,则同时考虑类别和研究的影响。在这种情况下,样本的权重是类别和研究的频率的对数的倒数。
返回:返回一个 WeightedRandomSampler 对象,用于在数据加载过程中按照计算的权重随机选择样本。

collate

def collate(self, batch):
        """Collate tensors.

        Parameters
        ----------
        batch:
            Batch to collate.

        Returns
        -------
        tuple
            A Tuple[torch.Tensor, torch.Tensor, list] containing information
            on the collated tensors.
        """
        profiles, labels, studies = tuple(
            map(list, zip(*batch))
        )  # tuple([list(t) for t in zip(*batch)])
        return (
            torch.squeeze(torch.Tensor(np.vstack(profiles))),
            torch.Tensor(labels),
            studies)

DataLoader 通过 collate_fn 参数接收一个函数,该函数定义了如何将多个样本组合成一个批次。这在处理不规则大小或不同类型的数据时尤其重要。

collate 函数的工作流程如下:

输入:batch,一个包含多个从 scDataset.getitem 返回的元组的列表。
处理:
使用 zip(*batch) 将批次中的元素分解为单独的列表(profiles, labels, studies)。
将每个列表转换为适当的 PyTorch 张量或保持为列表(如研究信息)。
对于数据 profiles,使用 np.vstack 将它们垂直堆叠成一个 NumPy 数组,然后转换为一个 PyTorch 张量。
返回:一个包含处理后的数据张量、标签张量和研究信息列表的元组

为什么这里需要一个collate_fn?

在 PyTorch 中,collate_fn 用于在数据加载过程中将多个样本组合成一个批次。通常,如果你的数据集返回的每个样本是一个简单的张量(比如图片或标签),你不需要提供一个自定义的 collate_fn,因为 PyTorch 的默认 collate_fn 已经可以处理这种情况。

然而,如果你的数据集返回的是复杂的数据结构或需要特殊处理(比如不同的数据类型组合、不规则的张量形状等),那么你可能需要提供一个自定义的 collate_fn 来正确地处理这些数据。

为什么 scDataset 类需要 collate_fn?
scDataset 类返回三种不同类型的数据:self.X(数据矩阵),self.Y(标签),和 self.study(研究)。这些数据可能需要特殊处理才能合并为一个批次,尤其是当它们包含不同类型的数据时。例如:

数据转换:self.X 可能是一个稀疏矩阵,需要转换为密集张量。
数据维度对齐:如果 self.X 中的样本有不同的形状,可能需要进行填充或裁剪以确保它们可以合并。
额外信息合并:self.Y 和 self.study 可能需要特殊处理才能与 self.X 正确对应。
什么时候不需要写 collate_fn?
如果你的数据集返回的每个样本已经是一个规则的张量,且不需要任何特殊的预处理或后处理,那么就不需要提供自定义的 collate_fn。在这种情况下,PyTorch 的默认 collate_fn 足以应对大多数情况,它会自动将多个样本堆叠成一个批次。例如,如果你的数据集只返回一组图片和对应的标签,而且所有图片都有相同的形状,那么默认的 collate_fn 就足够了。

相关文章

网友评论

      本文标题:学习写python包——data_models.py

      本文链接:https://www.haomeiwen.com/subject/zcdhndtx.html