这是一个使用TensorFlow.js进行异或(XOR)分类的示例,使用Parcel作为打包工具.
这个例子主要需要以下三个文件:
- index.html:主页面
- script.js:主脚本
- data.js:用来生成数据
index.html文件
在这个文件中,主要是引入script.js
文件及编写一个简单的界面,代码如下:
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<meta http-equiv="X-UA-Compatible" content="ie=edge">
<title>XOR</title>
<script src="script.js"></script>
</head>
<body>
<form action="" onsubmit="predict(this);return false;">
x: <input type="text" name="x">
y: <input type="text" name="y">
<button type="submit">预测</button>
</form>
</body>
</html>
data.js文件
这个文件是用来准备异或数据,生成训练及测试样本
export function getData(numSamples) {
let points = [];
function genGauss(cx, cy, label) {
for (let i = 0; i < numSamples / 2; i++) {
let x = normalRandom(cx);
let y = normalRandom(cy);
points.push({ x, y, label });
}
}
genGauss(2, 2, 0);
genGauss(-2, -2, 0);
genGauss(-2, 2, 1);
genGauss(2, -2, 1);
return points;
}
/**
* Samples from a normal distribution. Uses the seedrandom library as the
* random generator.
*
* @param mean The mean. Default is 0.
* @param variance The variance. Default is 1.
*/
function normalRandom(mean = 0, variance = 1) {
let v1, v2, s;
do {
v1 = 2 * Math.random() - 1;
v2 = 2 * Math.random() - 1;
s = v1 * v1 + v2 * v2;
} while (s > 1);
let result = Math.sqrt((-2 * Math.log(s)) / s) * v1;
return mean + Math.sqrt(variance) * result;
}
数据格式如下图:
前10个训练数据
返回的是一个数组,里面每一项是一个对象,其中包含了x、y、label三个键
script.js文件
该文件为这个示例的核心文件,包含了获取数据、搭建模型、训练模型及使用模型的几个步骤,下面来分步说明一下:
获取并查看训练数据
首先需要在文件中加载data.js文件,这样才可以使用getData()方法获取数据,其参数为获取训练数据的数量,这里我们获取了400个训练数据。
import { getData } from "./data";
const data = getData(400);
同时,在这里使用了tfjs-vis这个工具,用来帮助我们对数据进行可视化,不过可视化的过程需要在window.onload函数中进行。
import * as tfvis from "@tensorflow/tfjs-vis";
import * as tf from "@tensorflow/tfjs";
import { getData } from "./data";
window.onload = async () => {
const data = getData(400);
tfvis.render.scatterplot(
{ name: "训练数据" },
{
values: [data.filter(p => p.label === 1), data.filter(p => p.label === 0)]
}
);
};
可视化训练数据
将数据转化成Tensor
通过下面两行代码将训练数据转化为Tensor
const inputs = tf.tensor(data.map(p => [p.x, p.y]));
const labels = tf.tensor(data.map(p => p.label));
搭建模型
搭建模型的使用的是Layers API,这样可以快速的搭建一个2层全连接层的模型,代码如下:
// 创建一个模型
const model = tf.sequential();
// 添加第一个全连接层,有10个神经元,激活函数为relu,由于这里是第一层,所以必须制定inputShape
model.add(tf.layers.dense({ inputShape: [2], units: 10, activation: "relu" }));
// 添加第二个全连接层,也是输出层,有1个神经元,使用sigmoid函数作为激活函数,用于二分类
model.add(tf.layers.dense({ units: 1, activation: "sigmoid" }));
// 编译模型,需要给出损失函数和优化器,这里使用的是logloss损失函数,使用adam优化器
model.compile({ loss: tf.losses.logLoss, optimizer: tf.train.adam(0.1) });
这里的4行代码首先创建了一个对象model,是一个sequential类型的Model,然后向这个对象里添加了2个全连接层。
训练模型
训练模型使用model
中的fit
方法,还有一种fitDataset
方法,这里不说。
fit
方法接收训练数据、相应的标签、一个对象参数,具体请看这里。
await model.fit(inputs, labels, {
epochs: 10,
callbacks: tfvis.show.fitCallbacks({ name: "训练效果" }, ["loss"],{callbacks: ['onEpochEnd']})
});
这里设置epochs
为10,并使用tfjs-vis
工具来对训练过程进行可视化,查看了loss
的训练过程,每个epoch
结束时更新显示。
loss虽然不是很低,但是已经足够了~
使用模型进行推断
模型训练好了,下一步就是使用了。之前在index.html
文件中编写了一个简单的界面,如下:
点击预测按钮时,将调用predict
方法,并将表单传递进去。
下面编写predict
方法
// 这个方法放在window.onload方法里面
window.predict = async form => {
const pred = await model.predict(
tf.tensor([[form.x.value * 1, form.y.value * 1]])
);
alert(`预测结果:${pred.dataSync()}`);
};
看看效果吧~
代码写完了,该训练的也训练完了,来看看效果吧
先来看看整体的界面~
输入个数据看看~
看样子,预测结果还是挺准的么~~
网友评论