要使用纯 PyTorch 训练模型,而不是依赖于 PyTorch Lightning,需要手动实现训练循环、验证、测试步骤以及优化器的配置。
- 准备数据集和数据加载器(Data Loaders)
使用 scDataset 类来创建训练、验证和测试数据集。创建 PyTorch 的 DataLoader 实例,用于加载数据。 - 构建模型
初始化模型(如 Encoder 和 Decoder)。 - 定义优化器和损失函数
设置优化器,例如 Adam 或 SGD。定义损失函数,例如三元组损失和均方误差损失。 - 训练循环
对数据进行迭代,执行正向传播、计算损失、进行反向传播和优化器步骤。 - 验证和测试
在训练过程中或之后,对验证和测试数据集进行评估。
以下是如何在 PyTorch 中实现这些步骤的示例代码:
import torch
from torch.utils.data import DataLoader
# 1. 准备数据集和数据加载器
train_dataset = scDataset(...) # 使用适当的参数填充
train_loader = DataLoader(train_dataset, batch_size=..., num_workers=..., sampler=...)
val_dataset = scDataset(...)
val_loader = DataLoader(val_dataset, batch_size=..., num_workers=...)
# 2. 构建模型
encoder = Encoder(...)
decoder = Decoder(...)
# 3. 定义优化器和损失函数
optimizer = torch.optim.Adam([...], lr=...)
triplet_loss_fn = TripletLoss(...)
mse_loss_fn = torch.nn.MSELoss()
# 4. 训练循环
for epoch in range(num_epochs):
for batch in train_loader:
cells, labels, _ = batch
optimizer.zero_grad()
embeddings = encoder(cells)
reconstructions = decoder(embeddings)
triplet_loss = triplet_loss_fn(embeddings, labels, ...)
reconstruction_loss = mse_loss_fn(cells, reconstructions)
loss = ... # 根据需要组合损失
loss.backward()
optimizer.step()
# 5. 验证步骤
with torch.no_grad():
for batch in val_loader:
cells, labels, _ = batch
embeddings = encoder(cells)
reconstructions = decoder(embeddings)
# 计算和记录验证损失
...
# 测试步骤类似于验证步骤,只是使用测试数据集
网友评论