美文网首页
第三课 关于AlexNet内存爆炸问题

第三课 关于AlexNet内存爆炸问题

作者: chenyihang | 来源:发表于2018-01-17 16:22 被阅读0次

    1.课程地址

    http://zh.gluon.ai/chapter_convolutional-neural-networks/alexnet-gluon.html

    2.解决

    原因:来自https://discuss.gluon.ai/t/topic/3792 xiaoming

    内存炸裂是因为’load_data_fashion_mnist‘函数的原因,这个函数会把fashionMNIST数据集的所有图片都先resize,然后存储到内存里面。 你这里resize = 224,然后它会把整个数据集的60000张图片一起resize,这时候数据集的数据就有60000 * 224 * 224 * 3。这个用float32存储需要30多g的内存。

    办法:来自https://discuss.gluon.ai/t/topic/1258/49 xiaoming

    删除了,然后把相应的功能放到class DataLoader里了。
    提醒一下:原来的transform是作为gluon.data.vision.FashionMNIST的参数的。而我将transform的操作放到class DataLoader内部,而在外部只是多加了一个resize的参数。
    其实我这样写少了很多功能,万一tranform的操作需要更改的话,就要去改class DataLoader的定义了。 所以如果想实现跟gluon.data.vision.FashionMNIST的参数transform一样多的功能的话,最好把整个transform函数作为class DataLoader的一个参数,然后可以在 yield里调用这个transform。
    如下修改:

    class DataLoader(object):
        """similiar to gluon.data.DataLoader, but might be faster.
    
        The main difference this data loader tries to read more exmaples each
        time. But the limits are 1) all examples in dataset have the same shape, 2)
        data transfomer needs to process multiple examples at each time
        """
        def __init__(self, dataset, batch_size, shuffle, transform):
            self.dataset = dataset
            self.batch_size = batch_size
            self.shuffle = shuffle
            self.transform = transform
    
        def __iter__(self):
            data = self.dataset[:]
            X = data[0]
            y = nd.array(data[1])
            n = X.shape[0]
            if self.shuffle:
                idx = np.arange(n)
                np.random.shuffle(idx)
                X = nd.array(X.asnumpy()[idx])
                y = nd.array(y.asnumpy()[idx])
    
            for i in range(n//self.batch_size):
                if self.transform is not None:
                    yield self.transform(X[i*self.batch_size:(i+1)*self.batch_size], 
                                         y[i*self.batch_size:(i+1)*self.batch_size])
                else:
                    yield (X[i*self.batch_size:(i+1)*self.batch_size],
                           y[i*self.batch_size:(i+1)*self.batch_size])
    
        def __len__(self):
            return len(self.dataset)//self.batch_size
    
    def load_data_fashion_mnist(batch_size, resize=None, root="~/.mxnet/datasets/fashion-mnist"):
        """download the fashion mnist dataest and then load into memory"""
        def transform_mnist(data, label):
            # transform a batch of examples
            if resize:
                n = data.shape[0]
                new_data = nd.zeros((n, resize, resize, data.shape[3]))
                for i in range(n):
                    new_data[i] = image.imresize(data[i], resize, resize)
                data = new_data
            # change data from batch x height x weight x channel to batch x channel x height x weight
            return nd.transpose(data.astype('float32'), (0,3,1,2))/255, label.astype('float32')
        
        mnist_train = gluon.data.vision.FashionMNIST(root=root, train=True, transform=None)
        mnist_test = gluon.data.vision.FashionMNIST(root=root, train=False, transform=None)
        train_data = DataLoader(mnist_train, batch_size, shuffle=True, transform = transform_mnist)
        test_data = DataLoader(mnist_test, batch_size, shuffle=False, transform = transform_mnist)
        return (train_data, test_data)
    

    参考地址:
    https://discuss.gluon.ai/t/topic/3792
    https://discuss.gluon.ai/t/topic/1258/45

    相关文章

      网友评论

          本文标题:第三课 关于AlexNet内存爆炸问题

          本文链接:https://www.haomeiwen.com/subject/xejuoxtx.html