美文网首页
numpy与图片与tensor的互相转换

numpy与图片与tensor的互相转换

作者: 老枪打天下 | 来源:发表于2020-10-25 20:17 被阅读0次

    我想自己动手实现 ResNet-18,但它要求输入图片的尺寸是 224×224。
    MNIST 图片的尺寸是 28×28,而且是单通道,所以需要先变成三通道(用 np.tile),再 resize 成 224×224,最后还要 transpose 成 (C, H, W) 的维度顺序。
    注意:这里的 transpose(2,0,1) 是 numpy 数组的用法!torch 张量的 transpose 一次只能交换两个维度,要一次重排多个维度应使用 permute(2,0,1)。

    def resizeimg(rawimgs, size=224):
        """Turn a batch of single-channel 28x28 images into 3-channel size x size images.

        Each item is reshaped to (28, 28, 1), tiled to three channels with
        np.tile, upsampled with cv2.resize and transposed to channels-first
        (C, H, W) layout.

        Args:
            rawimgs: batch tensor whose items each hold 28*28 values
                (assumes e.g. (1, 28, 28) MNIST tensors from ToTensor).
            size: output height and width (default 224, the ResNet-18 input size).

        Returns:
            Float tensor of shape (len(rawimgs), 3, size, size).
        """
        # Placeholder for the output, sized from the actual batch so a short
        # final batch does not leave uninitialised rows in the result.
        res = torch.empty(len(rawimgs), 3, size, size)
        for ix, img in enumerate(rawimgs):
            img = img.view(28, 28, -1).numpy()   # (28, 28, 1), HWC layout for OpenCV
            img = np.tile(img, 3)                # repeat along the last axis -> (28, 28, 3)
            img = cv2.resize(img, (size, size))  # -> (size, size, 3)
            # numpy's transpose takes a full axis order: HWC -> CHW
            res[ix] = torch.from_numpy(img.transpose(2, 0, 1))
        return res
    

    全部代码如下:

    # @Time : 2020/10/25 8:02 
    
    # @Author : xx
    
    # @File : architectures.py 
    
    # @Software: PyCharm
    
    # @description=''
    import sys
    import math
    import itertools
    from  torch.optim import Adam
    import torch
    from torch.nn import functional as F
    from torch.nn import CrossEntropyLoss
    from torch.autograd import Variable,Function
    from torchvision.models import resnet18
    import torch.nn as nn
    from torch.utils.data import DataLoader,random_split
    from torchvision.datasets import MNIST,ImageFolder
    from torchvision.transforms  import Compose,ToTensor,RandomHorizontalFlip,Resize
    import numpy as np
    import cv2
    # Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")
    def load_mnist(root='D:/big4data/project/code/day005/data', rate=0.02, batch_size=20):
        """Load the MNIST training set and return a DataLoader over a small random subset.

        Only the fraction ``rate`` of the training images is used for training.
        The remainder forms a second split that is currently discarded.

        Args:
            root: directory holding the MNIST files (download=False, so they must already exist).
            rate: fraction of the dataset placed in the training split.
            batch_size: mini-batch size of the returned loader.

        Returns:
            Shuffling DataLoader over the training split.
        """
        whole_set = MNIST(root=root, download=False, train=True, transform=Compose([ToTensor()]))
        length = len(whole_set)
        train_size = int(rate * length)
        # Derive the second size by subtraction: flooring both products
        # independently can make the sizes sum to less than `length`, which
        # random_split rejects.
        validate_size = length - train_size
        # random_split(dataset, lengths) returns one Subset per requested length.
        train_set, _validate_set = random_split(whole_set, [train_size, validate_size])
        return DataLoader(dataset=train_set, shuffle=True, batch_size=batch_size)
    def make_layers(in_channels,out_channels,block_num,stride = 1):
        """Build one ResNet stage made of ``block_num`` residual blocks.

        The first block maps in_channels -> out_channels with ``stride``; the
        remaining blocks keep out_channels at stride 1. A 1x1 conv + BatchNorm
        projection shortcut is attached to the first block only when its input
        and output shapes differ (channel change or stride != 1). Otherwise the
        block uses an identity shortcut, as in the standard ResNet-18 design.

        Args:
            in_channels: channels entering the stage.
            out_channels: channels produced by every block of the stage.
            block_num: number of residual blocks (the first one is always built).
            stride: stride of the first block (2 halves the spatial size).

        Returns:
            nn.Sequential containing the blocks in order.
        """
        shortcut = None
        if stride != 1 or in_channels != out_channels:
            # Conv bias is redundant right before BatchNorm, which re-centres each channel.
            shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        layers = [ResBlock(in_channels, out_channels, stride, shortcut)]
        layers.extend(ResBlock(out_channels, out_channels) for _ in range(1, block_num))
        return nn.Sequential(*layers)
    class ResBlock(nn.Module):
        """Basic residual block computing relu(left(x) + skip(x)).

        Args:
            in_channel: channels of the input feature map.
            out_channel: channels produced by both 3x3 convolutions.
            stride: stride of the first 3x3 convolution.
            shortcut: optional module applied to x before the addition; when
                None, x itself is added, so its shape must already match.
        """

        def __init__(self, in_channel, out_channel, stride=1, shortcut=None):
            super().__init__()
            first_conv = nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=1, bias=False)
            first_bn = nn.BatchNorm2d(out_channel)
            activation = nn.ReLU(inplace=True)
            second_conv = nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False)
            second_bn = nn.BatchNorm2d(out_channel)
            # Main path: conv3x3 -> BN -> ReLU -> conv3x3 -> BN.
            self.left = nn.Sequential(first_conv, first_bn, activation, second_conv, second_bn)
            # Skip path; None means identity.
            self.right = shortcut

        def forward(self, x):
            main = self.left(x)
            if self.right is None:
                skip = x
            else:
                skip = self.right(x)
            return F.relu(main + skip)
    def resizeimg(rawimgs, size=224):
        """Turn a batch of single-channel 28x28 images into 3-channel size x size images.

        Each item is reshaped to (28, 28, 1), tiled to three channels with
        np.tile, upsampled with cv2.resize and transposed to channels-first
        (C, H, W) layout for the network.

        Args:
            rawimgs: batch tensor whose items each hold 28*28 values
                (assumes e.g. (1, 28, 28) MNIST tensors from ToTensor).
            size: output height and width (default 224, the ResNet-18 input size).

        Returns:
            Float tensor of shape (len(rawimgs), 3, size, size).
        """
        # Size the output from the actual batch: a hard-coded 20 left
        # uninitialised rows for a short final batch and overflowed on larger ones.
        res = torch.empty(len(rawimgs), 3, size, size)
        for ix, img in enumerate(rawimgs):
            img = img.view(28, 28, -1).numpy()   # (28, 28, 1), HWC layout for OpenCV
            img = np.tile(img, 3)                # repeat along the last axis -> (28, 28, 3)
            img = cv2.resize(img, (size, size))  # -> (size, size, 3)
            # numpy's transpose takes a full axis order: HWC -> CHW
            res[ix] = torch.from_numpy(img.transpose(2, 0, 1))
        return res
    
    class Resnet(nn.Module):
        """ResNet-18-style classifier for 3-channel images.

        Layout: 7x7/2 conv stem + 3x3/2 max-pool, four stages of two residual
        blocks (64, 128, 256, 512 channels), global average pooling and a
        linear head.

        Args:
            num_classes: number of output logits (default 10).
        """
        def __init__(self, num_classes=10):
            super(Resnet, self).__init__()
            self.pre = nn.Sequential(
                nn.Conv2d(3,64,7,2,3,bias=False),
                nn.BatchNorm2d(64),
                nn.ReLU(True),
                nn.MaxPool2d(3,2,1))
            self.layer1 = make_layers(64,64,2)
            self.layer2 = make_layers(64,128,2,stride=2)
            self.layer3 = make_layers(128,256,2,stride=2)
            self.layer4 = make_layers(256,512,2,stride=2)
            # For a 224x224 input the last feature map is 7x7 (stride 32), so this
            # equals the former AvgPool2d(7). Adaptive pooling also reduces any
            # other input size to 1x1, keeping the classifier input at 512 features.
            self.avg = nn.AdaptiveAvgPool2d(1)
            self.classifier = nn.Sequential(nn.Linear(512, num_classes))

        def forward(self,x):
            """Return class logits of shape (batch, num_classes) for an image batch x."""
            x = self.pre(x)
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)
            x = self.avg(x)
            x = x.view(x.size(0),-1)   # (batch, 512, 1, 1) -> (batch, 512)
            out = self.classifier(x)
            return out
    # def net_train()
    
    class NetTrainer:
        """Minimal training loop for Resnet on the MNIST subset returned by load_mnist().

        Attributes:
            CUDA: True when a GPU is available; model and batches are moved to it.
            epoch, lr, interval: hyper-parameters; the loss is reported every
                ``interval`` epochs.
            loss: running Python-float sum of per-batch losses since the last report.
        """
        def __init__(self):
            self.CUDA = torch.cuda.is_available()
            self.epoch = 100
            self.lr = 5e-4
            self.interval = 5
            self.datatrain = load_mnist()
            self.loss_f = CrossEntropyLoss()
            self.net = Resnet()
            self.loss = 0.0
            if self.CUDA:
                self.net = self.net.cuda()
            self.optimizer = Adam(self.net.parameters(),lr = self.lr)

        def train_one(self):
            """Run one pass over the training loader, updating the network after every batch."""
            for x, y in self.datatrain:
                x = resizeimg(x)
                if self.CUDA:
                    x = x.cuda()
                    y = y.cuda()

                y_ = self.net(x)
                loss = self.loss_f(y_, y)
                # .item() yields a plain float; accumulating the tensor itself would
                # chain every batch's autograd graph into self.loss.
                self.loss += loss.item()
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()

        def train(self,epoch = 100,lr = 1e-3):
            """Train for ``epoch`` epochs with learning rate ``lr``.

            Every ``self.interval`` epochs, prints the accumulated per-batch loss
            divided by the interval, then resets the accumulator.
            """
            self.epoch = epoch
            self.lr = lr
            # The optimizer was created in __init__ with the previous lr; apply
            # the new value, otherwise the `lr` argument would have no effect.
            for group in self.optimizer.param_groups:
                group['lr'] = self.lr
            for n in range(self.epoch):
                print('训练轮数%2d:'%(n))
                self.train_one()
                if (n+1)%self.interval==0:
                    print('loss:  ',self.loss/self.interval)
                    self.loss = 0.0
    
    
    
    
    
    
    
    if __name__=='__main__':
        # NetTrainer builds its own Resnet, data loader and optimizer; a separate
        # model created here would only waste (GPU) memory without being used.
        nettrainer = NetTrainer()
        nettrainer.train()
    

    如果本地没有数据集,就把 load_mnist 里的 download=False 改成 download=True(注意 Python 中是首字母大写的 True)!

    相关文章

      网友评论

          本文标题:numpy与图片与tensor的互相转换

          本文链接:https://www.haomeiwen.com/subject/cqgnmktx.html