基础

作者: 三点水_787a | 来源:发表于2019-03-13 14:56 被阅读0次

    一、数据导入部分

    torch.utils.data.Dataset,这是一个抽象类,在pytorch中所有和数据相关的类都要继承这个类来实现。

    torchvision.datasets.ImageFolder接口实现数据导入。torchvision.datasets.ImageFolder会返回一个列表(比如image_datasets[‘train’]或者image_datasets[‘val]),列表中的每个值都是一个tuple,每个tuple包含图像和标签信息。列表list是不能作为模型输入的,因此在PyTorch中需要用另一个类来封装list,那就是:

    torch.utils.data.DataLoader,它可以将list类型的输入数据封装成Tensor数据格式,以备模型使用。这里是对图像和标签分别封装成一个Tensor。

    data_transforms是一个字典。主要是进行一些图像预处理,比如resize、crop等。实现的时候采用的是torchvision.transforms模块,比如torchvision.transforms.Compose是用来管理所有transforms操作的,torchvision.transforms.RandomSizedCrop是做crop的。需要注意的是对于torchvision.transforms.RandomSizedCrop和transforms.RandomHorizontalFlip()等,输入对象都是PILImage,也就是用python的PIL库读进来的图像内容,而transforms.Normalize([0.5, 0.5,0.4], [0.2, 0.2,0.5])的作用对象需要是一个Tensor,因此在transforms.Normalize([0.5, 0.5, 0.4],[0.2, 0.2, 0.5])之前有一个

    transforms.ToTensor()就是用来生成Tensor的。另外transforms.Scale(256)其实就是resize操作,目前已经被transforms.Resize类取代了。

    将Tensor数据类型封装成Variable数据类型后就可以作为模型的输入了,用torch.autograd.Variable将Tensor封装成模型真正可以用的Variable数据类型。Variable可以看成是tensor的一种包装,其不仅包含了tensor的内容,还包含了梯度等信息。

    二、模块导入

    torchvision.models用来导入模块

    torch.nn模块来定义网络的所有层,比如卷积、降采样、损失层等等

    torch.optim模块定义优化函数

    三、训练

    在每个epoch开始时都要更新学习率:scheduler.step()

    设置模型状态为训练状态:model.train(True)

    先将网络中的所有梯度置0:optimizer.zero_grad()

    网络的前向传播:outputs =

    model(inputs)

    然后将输出的outputs和原来导入的labels作为loss函数的输入就可以得到损失了:loss =

    criterion(outputs, labels)

    输出的outputs也是torch.autograd.Variable格式,得到输出后(网络的全连接层的输出)还希望能到到模型预测该样本属于哪个类别的信息,这里采用torch.max。torch.max()的第一个输入是tensor格式,所以用outputs.data而不是outputs作为输入;第二个参数1是代表dim的意思,也就是取每一行的最大值,其实就是我们常见的取概率最大的那个index;第三个参数loss也是torch.autograd.Variable格式。

     _, preds =

    torch.max(outputs.data, 1)

    计算得到loss后就要回传损失。要注意的是这是在训练的时候才会有的操作,测试时候只有forward过程。

    loss.backward()

    回传损失过程中会计算梯度,然后需要根据这些梯度更新参数,optimizer.step()就是用来更新参数的。optimizer.step()后,你就可以从optimizer.param_groups[0][‘params’]里面看到各个层的梯度和权值信息。

    optimizer.step()

    这样一个batch数据的训练就结束了!不断重复这样的训练过程。

    相关文章

      网友评论

          本文标题:基础

          本文链接:https://www.haomeiwen.com/subject/lrttmqtx.html