1.5 神经网络入门-神经元实现

作者: 9c0ddf06559c | 来源:发表于2018-09-24 18:23 被阅读5次

    1.5 神经元实现

    • 加载数据集并按批次读取（只保留类别 0 和 1）

      def load_data(filename):
          """Unpickle one CIFAR batch file and return its examples and labels.

          The file is read in binary mode and unpickled with
          ``encoding='bytes'``, so the dict keys are ``bytes``.

          Returns:
              A ``(data, labels)`` pair: the values stored under ``b'data'``
              and ``b'labels'``.
          """
          # NOTE(review): pickle can execute arbitrary code -- only load
          # batch files from a trusted source.
          with open(filename, 'rb') as batch_file:
              batch = pickle.load(batch_file, encoding='bytes')
          return batch[b'data'], batch[b'labels']
      
      # Minimal stand-in for a TensorFlow-style DataSet: keeps all examples in
      # memory and serves them as consecutive mini-batches.
      class CifarData:
          """In-memory CIFAR dataset restricted to a subset of class labels.

          Only examples whose label is in ``classes`` (by default 0 and 1, i.e.
          a two-class subset) are kept. Pixel values are rescaled from 0..255
          to -1..1.
          """

          def __init__(self, filenames, need_shuffle, classes=(0, 1)):
              """Load, filter and normalise the examples from ``filenames``.

              Args:
                  filenames: batch files readable by ``load_data``.
                  need_shuffle: if true, shuffle now and reshuffle whenever a
                      pass over the data is exhausted; if false, serve examples
                      in file order and raise once they run out.
                  classes: labels to keep; defaults to ``(0, 1)``.

              Raises:
                  ValueError: if no example in ``filenames`` has a label in
                      ``classes`` (including when ``filenames`` is empty).
              """
              keep_labels = list(classes)  # np.isin mishandles sets; use a list
              all_data = []
              all_labels = []
              for filename in filenames:
                  data, labels = load_data(filename)
                  labels = np.asarray(labels)
                  # Vectorised row filter instead of a Python loop per example.
                  keep = np.isin(labels, keep_labels)
                  all_data.append(np.asarray(data)[keep])
                  all_labels.append(labels[keep])
              if sum(len(lbls) for lbls in all_labels) == 0:
                  raise ValueError("no examples with labels in %r" % (classes,))
              self._data = np.vstack(all_data)
              # Normalise: map 0..255 pixel values to the range -1..1.
              self._data = self._data / 127.5 - 1
              self._labels = np.hstack(all_labels)
              self._num_examples = self._data.shape[0]
              self._need_shuffle = need_shuffle
              self._indicator = 0
              if self._need_shuffle:
                  self._shuffle_data()

          def _shuffle_data(self):
              """Apply one random permutation to data and labels together."""
              # e.g. [0,1,2,3,4,5] -> [2,1,4,0,3,5]; rows and labels stay paired.
              p = np.random.permutation(self._num_examples)
              self._data = self._data[p]
              self._labels = self._labels[p]

          def next_batch(self, batch_size):
              """Return the next ``batch_size`` examples as ``(data, labels)``.

              If fewer than ``batch_size`` examples remain in the current pass,
              a shuffling dataset reshuffles and starts a new pass (the leftover
              tail is skipped), while a non-shuffling dataset raises.

              Raises:
                  ValueError: if ``batch_size`` is negative or larger than the
                      whole dataset.
                  Exception: if the dataset does not shuffle and fewer than
                      ``batch_size`` examples are left.
              """
              if batch_size < 0:
                  # A negative size would slice the wrong rows and move the
                  # cursor backwards.
                  raise ValueError(
                      "batch size must be non-negative, got %r" % (batch_size,))
              if batch_size > self._num_examples:
                  # Checked first so a shuffling dataset does not reshuffle
                  # pointlessly and a non-shuffling one does not report the
                  # misleading "have no more examples".
                  raise ValueError("batch size is larger than all examples")
              end_indicator = self._indicator + batch_size
              if end_indicator > self._num_examples:
                  if not self._need_shuffle:
                      raise Exception("have no more examples")
                  self._shuffle_data()
                  self._indicator = 0
                  end_indicator = batch_size
              batch_data = self._data[self._indicator:end_indicator]
              batch_labels = self._labels[self._indicator:end_indicator]
              self._indicator = end_indicator
              return batch_data, batch_labels
              
      # Training batches data_batch_1 .. data_batch_5 plus the single test
      # batch, all located under CIFAR_DIR.
      train_filename = [
          os.path.join(CIFAR_DIR, 'data_batch_{}'.format(batch_no))
          for batch_no in range(1, 6)
      ]
      test_filenames = [os.path.join(CIFAR_DIR, 'test_batch')]

      # The training set is shuffled; the test set is served in file order.
      train_data = CifarData(train_filename, True)
      test_data = CifarData(test_filenames, False)

      # Draw one batch of 10 (presumably a quick sanity check); note that this
      # advances train_data's position.
      batch_data, batch_labels = train_data.next_batch(10)
      
    • 训练模型并定期测试准确率

      init = tf.global_variables_initializer()
      batch_size = 20
      train_steps = 100000
      test_steps = 100

      # The test dataset is not shuffled, so every evaluation sees exactly the
      # same batches. Build them once here instead of re-reading and
      # re-normalising the test files on every evaluation. next_batch uses
      # basic slicing, which returns NumPy views, so keeping the batches does
      # not copy the pixel data.
      test_data = CifarData(test_filenames, False)
      test_batches = [test_data.next_batch(batch_size) for _ in range(test_steps)]

      with tf.Session() as sess:
          sess.run(init)
          for i in range(train_steps):
              step = i + 1
              batch_data, batch_labels = train_data.next_batch(batch_size)
              loss_val, acc_val, _ = sess.run(
                  [loss, accuracy, train_op],
                  feed_dict={x: batch_data, y: batch_labels})
              if step % 500 == 0:
                  print('[Train] Step: %d, loss: %4.5f, acc: %4.5f'
                        % (step, loss_val, acc_val))

              if step % 5000 == 0:
                  # Mean accuracy over the fixed set of test batches.
                  all_test_acc_val = [
                      sess.run(accuracy,
                               feed_dict={x: test_batch_data, y: test_batch_labels})
                      for test_batch_data, test_batch_labels in test_batches
                  ]
                  test_acc = np.mean(all_test_acc_val)
                  print('[Test ] Step: %d, acc: %4.5f' % (step, test_acc))
      

    相关文章

      网友评论

        本文标题:1.5 神经网络入门-神经元实现

        本文链接:https://www.haomeiwen.com/subject/tslhoftx.html