PyTorch 是一个开源的深度学习框架,由 Facebook 的人工智能研究团队(FAIR)开发和维护。pytorch下面有很多子模块,常用的子模块如下:
torch.nn
-
torch.nn
提供了构建神经网络所需的各种模块、损失函数和容器类。用户可以通过继承nn.Module
来定义自己的模型。 -
torch.nn.functional
提供了许多函数式的接口,这些接口可以在定义神经网络的前向传播过程中使用。与 torch.nn 提供的模块化接口(如 nn.Linear, nn.Conv2d 等)不同,torch.nn.functional 提供的是函数式的接口,这使得定义神经网络的结构更加灵活
torch.optim
torch.optim
包含了多种优化算法,如 SGD、Adam、RMSprop 等,可以轻松地应用于模型参数的更新。
torch.utils.data
torch.utils.data
提供了强大的数据加载和处理工具,包括Dataset
和 DataLoader
类,支持批处理、打乱和并行加载。torch.utils.data的其他常用函数包括:Sampler
示例代码
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
网友评论