美文网首页
2019-03-04深度学习——从头搭建一个简单的NN-clas

2019-03-04深度学习——从头搭建一个简单的NN-clas

作者: Hie_9e55 | 来源:发表于2019-03-05 16:23 被阅读0次

正在学习斯坦福的cs231n课程,该课程使用的是CIFAR-10数据集

该数据集可在管网下载
http://www.cs.toronto.edu/~kriz/cifar.html

下载并解压,得到


image.png

如何导入数据

CIFAR-10数据集由pickle产生,因此也由pickle导入

import pickle
    def load_file(filename):
        with open(filename, 'rb') as fo:
            data = pickle.load(fo, encoding='latin1')
        return data

    filename = 'D:/Download/cifar-10-batches-py/data_batch_1'
    data = load_file(filename)
    print(data.keys())//得到当前文件的一些基本信息

当前文件的一些基本信息
dict_keys(['batch_label', 'labels', 'data', 'filenames'])

NN分类的思想

NN分类并不需要训练,只需要将要判断的图和已有数据进行比较即可

比较时计算目标图与每一个数据图的范数一,范数一最小的数据图所属类别即为目标图类别

关于范数一与范数二 image.png

代码如下

import numpy as np
import pickle
filename = 'xxx'
filename_test = 'xxx'

class NearestNeighbor:
    """docstring for NearestNeighbor"""
    def __init__(self):
        pass

# 导入数据
    def load_file(self, filename):
        with open(filename, 'rb') as fo:
            data = pickle.load(fo, encoding='latin1')
        return data

# 训练模型,NN只是简单的导入即可,X是数据,n*3072,Y是数据标签,n*1
    def train(self, X, y):
        self.Xtr = X
        self.ytr = y

# 使用模型进行预测,X是test集的数据
    def predict(self, X):
        num_test = X.shape[0]# test数据个数
        Ypred = np.zeros(num_test)# 初始化预测结果
        
        for i in range(num_test):
            distances = np.sum(np.abs(self.Xtr - X[i,:]), axis = 1)# 计算范数一
            min_index = np.argmin(distances)# 寻范数一最小的数据
            Ypred[i] = self.ytr[min_index]# 得到预测结果

        return Ypred

net = NearestNeighbor()

data = net.load_file(filename)
test_batch = net.load_file(filename_test)

net.train(data['data'], data['labels'])
result = net.predict(test_batch['data'])

print(result)

相关文章

网友评论

      本文标题:2019-03-04深度学习——从头搭建一个简单的NN-clas

      本文链接:https://www.haomeiwen.com/subject/tziwuqtx.html