torch.utils.data.DataLoader
主要用于数据读取的一个接口,一般在Pytorch中训练模型时用到,详细见torch.utils.data — PyTorch master documentation
DataLoader(dataset,batch_size=1,shuffle=False,sampler=None,batch_sampler=None,num_workers=0,collate_fn=None,pin_memory=False,drop_last=False,timeout=0,worker_init_fn=None)
dataset是这个主要的组成部分,表示数据加载的对象
batch_size表示每次加载多少个样本,默认为1
shuffle表示是否打乱顺序,默认为否
drop_last=False表示不舍去所有数据量除以每批次个数多余的部分
后面的几个参数没用过,看不懂
torch.nn.Sequential()
表示一种容器,主要用于神经网络模块,保证网络的流动顺序,自带forward方法
torch.nn.ModuleDict()
nn.ModuleDict同样也是一种容器
name=nn.ModuleDict({ 可供选择的网络/可供选择的激活函数})对应的forward:x=self.name[选择的序号](x)
class Modeldict(nn.Module):
def __init__(self):
super(Modeldict,self).__init__()
self.choices = nn.ModuleDict({
"conv1": nn.Conv2d(10,10,3),
"pool": nn.MaxPool2d(3) })
self.activations = nn.ModuleDict({
"relu": nn.ReLU(),
"prelu": nn.PReLU() })
def forward(self,x,choice,act):
x = self.choices[choice](x)
x = self.activations[act](x) return x
#input = img
model = Modeldict()
out = model(input,"pool","prelu")
优化函数Adam自适应优化算法,
torch.optim.Adam(model.parameters()) 需要对Model中所有生成的参数进行优化
网友评论