
Some Useful Code

Author: 英文名字叫dawntown | Published 2020-02-14 21:26

The following code comes from a free public course. These are templates commonly used in deep learning that you can reuse in your own projects; rather than rewriting them from scratch, I have copied them here for reference.

    import torch
    import torchvision
    import torchvision.transforms as transforms
    import matplotlib.pyplot as plt
    import d2lzh_pytorch as d2l  # the course's helper package; adjust to your local name

    mnist_train = torchvision.datasets.FashionMNIST(root='/home/kesci/input/FashionMNIST2065', train=True, download=True, transform=transforms.ToTensor())
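
The training loop at the end of the post also needs a test split, a batch size, and the two data iterators, none of which appear in the original. A minimal sketch, assuming the same dataset root and the course's usual batch size of 256:

    mnist_test = torchvision.datasets.FashionMNIST(root='/home/kesci/input/FashionMNIST2065', train=False, download=True, transform=transforms.ToTensor())

    batch_size = 256
    train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
    test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False)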
    
    def get_fashion_mnist_labels(labels):
        text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                       'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
        return [text_labels[int(i)] for i in labels]
    
    def show_fashion_mnist(images, labels):
        d2l.use_svg_display()
        # the _ holds the Figure object, which we do not use
        _, figs = plt.subplots(1, len(images), figsize=(12, 12))
        for f, img, lbl in zip(figs, images, labels):
            f.imshow(img.view((28, 28)).numpy())
            f.set_title(lbl)
            f.axes.get_xaxis().set_visible(False)
            f.axes.get_yaxis().set_visible(False)
        plt.show()

    X, y = [], []
    for i in range(10):
        X.append(mnist_train[i][0])  # append the i-th image (feature) to X
        y.append(mnist_train[i][1])  # append the i-th label to y
    show_fashion_mnist(X, get_fashion_mnist_labels(y))
    
    def softmax(X):
        X_exp = X.exp()
        partition = X_exp.sum(dim=1, keepdim=True)
        # note: no max subtraction, so very large logits can overflow
        return X_exp / partition  # broadcasting divides each row by its row sum
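
`net` below uses `num_inputs`, `W`, and `b`, which the post never defines. A minimal sketch of the standard softmax-regression initialization (28x28 = 784 input pixels, 10 classes), included here as an assumption:

    num_inputs, num_outputs = 784, 10
    W = (0.01 * torch.randn(num_inputs, num_outputs)).requires_grad_()
    b = torch.zeros(num_outputs, requires_grad=True)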
    
    def net(X):
        return softmax(torch.mm(X.view((-1, num_inputs)), W) + b)
    
    def cross_entropy(y_hat, y):
        return - torch.log(y_hat.gather(1, y.view(-1, 1)))
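
A quick worked example of the `gather` call: for each row of `y_hat` it picks out the predicted probability of the true class, and the per-sample loss is minus its log. The numbers below are made up for illustration:

    y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
    y = torch.tensor([0, 2])
    # gather picks 0.1 (class 0, row 0) and 0.5 (class 2, row 1),
    # so the losses are -log(0.1) ≈ 2.30 and -log(0.5) ≈ 0.69
    print(cross_entropy(y_hat, y))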
    
    def accuracy(y_hat, y):
        return (y_hat.argmax(dim=1) == y).float().mean().item()
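
Continuing the toy example above: the row-wise argmax of `y_hat` is `[2, 2]`, which matches only the second label in `y = [0, 2]`, so the accuracy is 0.5:

    print(accuracy(y_hat, y))  # 0.5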
    
    def evaluate_accuracy(data_iter, net):
        acc_sum, n = 0.0, 0
        for X, y in data_iter:
            acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
            n += y.shape[0]
        return acc_sum / n
    
    num_epochs, lr = 5, 0.1
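
`train_ch3` below calls `d2l.sgd` when no optimizer is passed. In the course's helper library this is plain minibatch SGD; a sketch of what it does, in case the `d2l` package is not available:

    def sgd(params, lr, batch_size):
        # the loss was summed over the batch, so divide the gradient
        # by batch_size to get the average update
        for param in params:
            param.data -= lr * param.grad / batch_size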
    
    def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size,
                  params=None, lr=None, optimizer=None):
        for epoch in range(num_epochs):
            train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
            for X, y in train_iter:
                y_hat = net(X)
                l = loss(y_hat, y).sum()
                
                # zero the gradients
                if optimizer is not None:
                    optimizer.zero_grad()
                elif params is not None and params[0].grad is not None:
                    for param in params:
                        param.grad.data.zero_()
                
                l.backward()
                if optimizer is None:
                    d2l.sgd(params, lr, batch_size)
                else:
                    optimizer.step()

                train_l_sum += l.item()
                train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
                n += y.shape[0]
            test_acc = evaluate_accuracy(test_iter, net)
            print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'
                  % (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))
    
    train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, batch_size, [W, b], lr)
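
After training, a natural sanity check (not in the original post) is to plot a few test images with their true and predicted labels, reusing the helpers defined above:

    X, y = next(iter(test_iter))
    true_labels = get_fashion_mnist_labels(y.numpy())
    pred_labels = get_fashion_mnist_labels(net(X).argmax(dim=1).numpy())
    titles = [t + '\n' + p for t, p in zip(true_labels, pred_labels)]
    show_fashion_mnist(X[0:9], titles[0:9])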
    
