美文网首页
临时代码

临时代码

作者: BG大龍 | 来源:发表于2019-08-23 20:02 被阅读0次

nn_train代码块

import matplotlib
from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from keras.models import Sequential
from keras.layers import Dropout
from keras.layers.core import Dense
from keras.optimizers import SGD
from keras import initializers
from keras import regularizers
from my_utils import utils_paths
import matplotlib.pyplot as plt
import numpy as np
import argparse
import random
import pickle
import cv2
import os


#--dataset、 --model、 --label-bin、 --plot
# 1 输入参数
ap = argparse.ArgumentParser()
ap.add_argument("-d", "--dataset", required=True,
    help="path to input dataset of images")
ap.add_argument("-m", "--model", required=True,
    help="path to output trained model")
ap.add_argument("-l", "--label-bin", required=True,
    help="path to output label binarizer")
ap.add_argument("-p", "--plot", required=True,
    help="path to output accuracy/loss plot")
args = vars(ap.parse_args())

print("[INFO] 开始读取数据")
data = []
labels = []


# 2 拿到图像数据路径,方便后续读取
imagePaths = sorted(list(utils_paths.list_images(args["dataset"])))
random.seed(42)   #随机种子
random.shuffle(imagePaths)    #洗牌操作


# 3 遍历读取数据
for imagePath in imagePaths:
    # 3.1 读取图像数据,由于使用神经网络,需要给定成一维
    image = cv2.imread(imagePath)   #读取图像
    image = cv2.resize(image, (32, 32)).flatten()   #重新定义大小,并且进行flatten扁平化
    data.append(image)   #在data的末尾,追加image数据

    # 3.2 读取标签
    label = imagePath.split(os.path.sep)[-2]
    labels.append(label)


# 4 scale图像数据
data = np.array(data, dtype="float") / 255.0
labels = np.array(labels)


# 5 数据集切分
(trainX, testX, trainY, testY) = train_test_split(data,
    labels, test_size=0.25, random_state=42)


# 6 转换标签,one-hot格式
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)


# 7 网络模型结构:3072-512-256-3
model = Sequential()
# kernel_regularizer=regularizers.l2(0.01)
# keras.initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None)
# initializers.random_normal
# #model.add(Dropout(0.8))
model.add(Dense(512, input_shape=(3072,), activation="relu" ,kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None),kernel_regularizer=regularizers.l2(0.01)))
model.add(Dropout(0.5))
model.add(Dense(256, activation="relu",kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None),kernel_regularizer=regularizers.l2(0.01)))
model.add(Dropout(0.5))
model.add(Dense(len(lb.classes_), activation="softmax",kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None),kernel_regularizer=regularizers.l2(0.01)))


# 8 初始化超参数
INIT_LR = 0.001
EPOCHS = 2000


# 9 给定损失函数和评估方法
print("[INFO] 准备训练网络...")
opt = SGD(lr=INIT_LR)
model.compile(loss="categorical_crossentropy", optimizer=opt,
    metrics=["accuracy"])


# 10 训练网络模型
H = model.fit(trainX, trainY, validation_data=(testX, testY),
    epochs=EPOCHS, batch_size=32)


# 11 测试网络模型
print("[INFO] 正在评估模型")
predictions = model.predict(testX, batch_size=32)
print(classification_report(testY.argmax(axis=1),
    predictions.argmax(axis=1), target_names=lb.classes_))


# 12 当训练完成时,绘制结果曲线
N = np.arange(0, EPOCHS)
plt.style.use("ggplot")
plt.figure()
#plt.plot(N[150:], H.history["loss"][150:], label="train_loss")
#plt.plot(N[150:], H.history["val_loss"][150:], label="val_loss")
plt.plot(N[150:], H.history["acc"][150:], label="train_acc")
plt.plot(N[150:], H.history["val_acc"][150:], label="val_acc")
plt.title("Training Loss and Accuracy (Simple NN)")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend()
plt.savefig(args["plot"])


# 13 保存模型到本地
print("[INFO] 正在保存模型")
model.save(args["model"])
f = open(args["label_bin"], "wb")
f.write(pickle.dumps(lb))
f.close()

相关文章

  • 临时代码

    nn_train代码块

  • Git临时保存代码

    在日常开发过程中,有时需要临时性修复bug或者配合同事提交其依赖的代码,这时你在本地开发的代码由于未开发完成就不能...

  • 备忘录-响应状态码

    http状态返回代码 1xx(临时响应) 表示临时响应并需要请求者继续执行操作的状态代码。 http状态返回代码 ...

  • Java 常用返回码整理

    http状态返回代码 1xx(临时响应) 表示临时响应并需要请求者继续执行操作的状态代码。 http状态返回代码 ...

  • 用Java发起HTTP请求返回状态码(含状态码列表)

    http状态返回代码 1xx(临时响应)表示临时响应并需要请求者继续执行操作的状态代码。 http状态返回代码 代...

  • 2019-12-26 http状态返回代码

    http状态返回代码 1xx(临时响应)表示临时响应并需要请求者继续执行操作的状态代码。 http状态返回代码 代...

  • http响应状态码大全

    http状态返回代码 1xx(临时响应)表示临时响应并需要请求者继续执行操作的状态代码。 http状态返回代码 代...

  • Http状态码汇总

    Http状态码汇总: 1xx(临时响应)表示临时响应并需要请求者继续执行操作的状态代码。 代码 说明100 ...

  • 常用状态码记录

    1xx(临时响应) 表示临时响应并需要请求者继续执行操作的状态代码。 代码 说明 100 (继续) 请求者应当...

  • http常见错误码总结

    1xx(临时响应) 表示临时响应并需要请求者继续执行操作的状态代码。 代码 说明 100 (继续) 请求者应当继续...

网友评论

      本文标题:临时代码

      本文链接:https://www.haomeiwen.com/subject/hecxectx.html