美文网首页
(五)TensorFlow.js的XOR示例

(五)TensorFlow.js的XOR示例

作者: zqyadam | 来源:发表于2020-01-30 00:05 被阅读0次

这是一个使用TensorFlow.js进行异或(XOR)分类的示例,使用Parcel作为打包工具.

这个例子主要需要以下三个文件:

  • index.html:主页面
  • script.js:主脚本
  • data.js:用来生成数据

index.html文件

在这个文件中,主要是引入script.js文件及编写一个简单的界面,代码如下:

<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <meta http-equiv="X-UA-Compatible" content="ie=edge">
  <title>XOR</title>
  <script src="script.js"></script>
</head>
<body>
  <form action="" onsubmit="predict(this);return false;">
    x: <input type="text" name="x">
    y: <input type="text" name="y">
    <button type="submit">预测</button>
  </form>
</body>
</html>

data.js文件

这个文件是用来准备异或数据,生成训练及测试样本

export function getData(numSamples) {
  let points = [];

  function genGauss(cx, cy, label) {
    for (let i = 0; i < numSamples / 2; i++) {
      let x = normalRandom(cx);
      let y = normalRandom(cy);
      points.push({ x, y, label });
    }
  }

  genGauss(2, 2, 0);
  genGauss(-2, -2, 0);
  genGauss(-2, 2, 1);
  genGauss(2, -2, 1);
  return points;
}

/**
 * Samples from a normal distribution. Uses the seedrandom library as the
 * random generator.
 *
 * @param mean The mean. Default is 0.
 * @param variance The variance. Default is 1.
 */
function normalRandom(mean = 0, variance = 1) {
  let v1, v2, s;
  do {
    v1 = 2 * Math.random() - 1;
    v2 = 2 * Math.random() - 1;
    s = v1 * v1 + v2 * v2;
  } while (s > 1);

  let result = Math.sqrt((-2 * Math.log(s)) / s) * v1;
  return mean + Math.sqrt(variance) * result;
}

数据格式如下图:


前10个训练数据

返回的是一个数组,里面每一项是一个对象,其中包含了x、y、label三个键

script.js文件

该文件为这个示例的核心文件,包含了获取数据、搭建模型、训练模型及使用模型的几个步骤,下面来分步说明一下:

获取并查看训练数据

首先需要在文件中加载data.js文件,这样才可以使用getData()方法获取数据,其参数为获取训练数据的数量,这里我们获取了400个训练数据。

import { getData } from "./data";

const data = getData(400);

同时,在这里使用了tfjs-vis这个工具,用来帮助我们对数据进行可视化,不过可视化的过程需要在window.onload函数中进行。

import * as tfvis from "@tensorflow/tfjs-vis";
import * as tf from "@tensorflow/tfjs";
import { getData } from "./data";

window.onload = async () => {
  const data = getData(400);

  tfvis.render.scatterplot(
    { name: "训练数据" },
    {
      values: [data.filter(p => p.label === 1), data.filter(p => p.label === 0)]
    }
  );
};
可视化训练数据

将数据转化成Tensor

通过下面两行代码将训练数据转化为Tensor

  const inputs = tf.tensor(data.map(p => [p.x, p.y]));
  const labels = tf.tensor(data.map(p => p.label));

搭建模型

搭建模型的使用的是Layers API,这样可以快速的搭建一个2层全连接层的模型,代码如下:

  // 创建一个模型
  const model = tf.sequential();   
  // 添加第一个全连接层,有10个神经元,激活函数为relu,由于这里是第一层,所以必须制定inputShape
  model.add(tf.layers.dense({ inputShape: [2], units: 10, activation: "relu" })); 
  // 添加第二个全连接层,也是输出层,有1个神经元,使用sigmoid函数作为激活函数,用于二分类
  model.add(tf.layers.dense({ units: 1, activation: "sigmoid" }));
  // 编译模型,需要给出损失函数和优化器,这里使用的是logloss损失函数,使用adam优化器
  model.compile({ loss: tf.losses.logLoss, optimizer: tf.train.adam(0.1) });

这里的4行代码首先创建了一个对象model,是一个sequential类型的Model,然后向这个对象里添加了2个全连接层。

训练模型

训练模型使用model中的fit方法,还有一种fitDataset方法,这里不说。

fit方法接收训练数据相应的标签一个对象参数,具体请看这里

  await model.fit(inputs, labels, {
    epochs: 10,
    callbacks: tfvis.show.fitCallbacks({ name: "训练效果" }, ["loss"],{callbacks: ['onEpochEnd']})
  });

这里设置epochs为10,并使用tfjs-vis工具来对训练过程进行可视化,查看了loss的训练过程,每个epoch结束时更新显示。

可视化训练过程

loss虽然不是很低,但是已经足够了~

使用模型进行推断

模型训练好了,下一步就是使用了。之前在index.html文件中编写了一个简单的界面,如下:

简单的界面

点击预测按钮时,将调用predict方法,并将表单传递进去。
下面编写predict方法

  // 这个方法放在window.onload方法里面
  window.predict = async form => {
    const pred = await model.predict(
      tf.tensor([[form.x.value * 1, form.y.value * 1]])
    );
    alert(`预测结果:${pred.dataSync()}`);
  };

看看效果吧~

代码写完了,该训练的也训练完了,来看看效果吧
先来看看整体的界面~

整体界面

输入个数据看看~

输入个数据 大概在这里 预测结果

看样子,预测结果还是挺准的么~~

相关文章

网友评论

      本文标题:(五)TensorFlow.js的XOR示例

      本文链接:https://www.haomeiwen.com/subject/gmhbthtx.html