这是一个使用Iris数据集进行神经网络分类训练的示例。
Iris数据集介绍
Iris数据集是常用的分类实验数据集,由Fisher, 1936收集整理。Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集。数据集包含150个数据样本,分为3类,每类50个数据,每个数据包含4个属性。可通过花萼长度,花萼宽度,花瓣长度,花瓣宽度4个属性预测鸢尾花卉属于(Setosa,Versicolour,Virginica)三个种类中的哪一类。 ----百度百科
基本流程
基本流程文件说明
该示例共包含三个文件:
- index.html: 主页面
- data.js: 用来生成数据
- script.js: 主脚本
index.html
编写一个简陋的表格,用于在预测时进行数据输入
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<meta http-equiv="X-UA-Compatible" content="ie=edge" />
<title>IRIS</title>
<script src="script.js"></script>
</head>
<body>
<form action="" onsubmit="predict(this);return false;">
花萼长度:<input type="text" name="a" /><br />
花萼宽度:<input type="text" name="b" /><br />
花瓣长度:<input type="text" name="c" /><br />
花瓣宽度:<input type="text" name="d" /><br />
<button type="submit">预测</button>
</form>
</body>
</html>
data.js
代码太长了,摘取一部分:
import * as tf from "@tensorflow/tfjs";
export const IRIS_CLASSES = ["山鸢尾", "变色鸢尾", "维吉尼亚鸢尾"];
export const IRIS_NUM_CLASSES = IRIS_CLASSES.length;
// Iris flowers data. Source:
// https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data
const IRIS_DATA = [
[5.1, 3.5, 1.4, 0.2, 0],
[4.9, 3.0, 1.4, 0.2, 0],
[4.7, 3.2, 1.3, 0.2, 0],
[4.6, 3.1, 1.5, 0.2, 0],
[5.0, 3.6, 1.4, 0.2, 0],
[5.4, 3.9, 1.7, 0.4, 0],
[4.6, 3.4, 1.4, 0.3, 0],
...
[6.8, 3.2, 5.9, 2.3, 2],
[6.7, 3.3, 5.7, 2.5, 2],
[6.7, 3.0, 5.2, 2.3, 2],
[6.3, 2.5, 5.0, 1.9, 2],
[6.5, 3.0, 5.2, 2.0, 2],
[6.2, 3.4, 5.4, 2.3, 2],
[5.9, 3.0, 5.1, 1.8, 2]
];
function convertToTensors(data, targets, testSplit) {
const numExamples = data.length;
if (numExamples !== targets.length) {
throw new Error("data and split have different numbers of examples");
}
// Randomly shuffle `data` and `targets`.
const indices = [];
for (let i = 0; i < numExamples; ++i) {
indices.push(i);
}
tf.util.shuffle(indices);
const shuffledData = [];
const shuffledTargets = [];
for (let i = 0; i < numExamples; ++i) {
shuffledData.push(data[indices[i]]);
shuffledTargets.push(targets[indices[i]]);
}
// Split the data into a training set and a tet set, based on `testSplit`.
const numTestExamples = Math.round(numExamples * testSplit);
const numTrainExamples = numExamples - numTestExamples;
const xDims = shuffledData[0].length;
// Create a 2D `tf.Tensor` to hold the feature data.
const xs = tf.tensor2d(shuffledData, [numExamples, xDims]);
// Create a 1D `tf.Tensor` to hold the labels, and convert the number label
// from the set {0, 1, 2} into one-hot encoding (.e.g., 0 --> [1, 0, 0]).
const ys = tf.oneHot(tf.tensor1d(shuffledTargets).toInt(), IRIS_NUM_CLASSES);
// Split the data into training and test sets, using `slice`.
const xTrain = xs.slice([0, 0], [numTrainExamples, xDims]);
const xTest = xs.slice([numTrainExamples, 0], [numTestExamples, xDims]);
const yTrain = ys.slice([0, 0], [numTrainExamples, IRIS_NUM_CLASSES]);
const yTest = ys.slice([0, 0], [numTestExamples, IRIS_NUM_CLASSES]);
return [xTrain, yTrain, xTest, yTest];
}
export function getIrisData(testSplit) {
return tf.tidy(() => {
const dataByClass = [];
const targetsByClass = [];
for (let i = 0; i < IRIS_CLASSES.length; ++i) {
dataByClass.push([]);
targetsByClass.push([]);
}
for (const example of IRIS_DATA) {
const target = example[example.length - 1];
const data = example.slice(0, example.length - 1);
dataByClass[target].push(data);
targetsByClass[target].push(target);
}
const xTrains = [];
const yTrains = [];
const xTests = [];
const yTests = [];
for (let i = 0; i < IRIS_CLASSES.length; ++i) {
const [xTrain, yTrain, xTest, yTest] = convertToTensors(
dataByClass[i],
targetsByClass[i],
testSplit
);
xTrains.push(xTrain);
yTrains.push(yTrain);
xTests.push(xTest);
yTests.push(yTest);
}
const concatAxis = 0;
return [
tf.concat(xTrains, concatAxis),
tf.concat(yTrains, concatAxis),
tf.concat(xTests, concatAxis),
tf.concat(yTests, concatAxis)
];
});
}
script.js
在这个文件中,实现了上面流程图中的大部分内容。
获取数据
在window.onload
函数中,调用getIrisData
函数,输入验证集所占比例,返回训练集的数据、标签,验证集的数据、标签,并且都是已经转换成Tensor格式的。
import * as tfvis from "@tensorflow/tfjs-vis";
import * as tf from "@tensorflow/tfjs";
import { getIrisData, IRIS_CLASSES } from "./data";
window.onload = async () => {
const [xTrain, yTrain, xTest, yTest] = getIrisData(0.15);
xTrain.print();
};
这里打印了一下xTrain,看下格式是什么样的。
xTrain格式
搭建神经网络
搭建一个拥有2个全连接层的模型,使用的依然是sequential
,这里的inputShape
是4,与数据集的输入长度相同。
let model = tf.sequential();
model.add(
tf.layers.dense({ units: 15, inputShape: [4], activation: "relu" })
);
model.add(tf.layers.dense({ units: 3, activation: "softmax" }));
编译模型
使用categoricalCrossentropy
作为损失函数,adam作为优化器,衡量标准为准确度。
model.compile({
loss: "categoricalCrossentropy",
optimizer: tf.train.adam(0.1),
metrics: ["accuracy"]
});
训练模型
validataionData
里面输入的是验证集的数据和标签。使用tfjs-vis
库对训练过程进行可视化,查看的内容包括训练集的loss
、accuracy
,验证集的loss
、accuracy
,在每个epoch
结束后更新训练效果。
await model.fit(xTrain, yTrain, {
epochs: 100,
validationData: [xTest, yTest],
callbacks: tfvis.show.fitCallbacks(
{
name: "训练效果"
},
["loss", "val_loss", "acc", "val_acc"],
{ callbacks: ["onEpochEnd"] }
)
});
训练效果
使用神经网络进行推断
编写一个predict
函数,放在window.onload
中,输入为表单,将输入的数据转化为Tensor
后,输入到模型中,并得出预测的向量,通过寻找最大值的位置,找到对应的标签。
window.predict = form => {
const input = tf.tensor([
[form.a.value * 1, form.b.value * 1, form.c.value * 1, form.d.value * 1]
]);
const pred = model.predict(input);
alert(
`预测结果:${
IRIS_CLASSES[pred.argMax(1).dataSync(0)]
},概率:${pred.max().dataSync()}`
);
};
输入表单
预测结果
代码在这👇
链接:https://pan.baidu.com/s/1eHAwsZwQTTplYzC0X38lkw
提取码:8lvh
网友评论