美文网首页
(六)TensorFlow.js的Iris数据集示例

(六)TensorFlow.js的Iris数据集示例

作者: zqyadam | 来源:发表于2020-01-30 22:43 被阅读0次

这是一个使用Iris数据集进行神经网络分类训练的示例。

Iris数据集介绍

Iris数据集是常用的分类实验数据集,由Fisher, 1936收集整理。Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集。数据集包含150个数据样本,分为3类,每类50个数据,每个数据包含4个属性。可通过花萼长度,花萼宽度,花瓣长度,花瓣宽度4个属性预测鸢尾花卉属于(Setosa,Versicolour,Virginica)三个种类中的哪一类。 ----百度百科

基本流程

基本流程

文件说明

该示例共包含三个文件:

  • index.html: 主页面
  • data.js: 用来生成数据
  • script.js: 主脚本

index.html

编写一个简陋的表格,用于在预测时进行数据输入

<!DOCTYPE html>
<html lang="en">
  <head>
    <meta charset="UTF-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <meta http-equiv="X-UA-Compatible" content="ie=edge" />
    <title>IRIS</title>
    <script src="script.js"></script>
  </head>
  <body>
    <form action="" onsubmit="predict(this);return false;">
      花萼长度:<input type="text" name="a" /><br />
      花萼宽度:<input type="text" name="b" /><br />
      花瓣长度:<input type="text" name="c" /><br />
      花瓣宽度:<input type="text" name="d" /><br />
      <button type="submit">预测</button>
    </form>
  </body>
</html>

data.js

代码太长了,摘取一部分:

import * as tf from "@tensorflow/tfjs";

export const IRIS_CLASSES = ["山鸢尾", "变色鸢尾", "维吉尼亚鸢尾"];
export const IRIS_NUM_CLASSES = IRIS_CLASSES.length;

// Iris flowers data. Source:
//   https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data
const IRIS_DATA = [
  [5.1, 3.5, 1.4, 0.2, 0],
  [4.9, 3.0, 1.4, 0.2, 0],
  [4.7, 3.2, 1.3, 0.2, 0],
  [4.6, 3.1, 1.5, 0.2, 0],
  [5.0, 3.6, 1.4, 0.2, 0],
  [5.4, 3.9, 1.7, 0.4, 0],
  [4.6, 3.4, 1.4, 0.3, 0],
  ...
  [6.8, 3.2, 5.9, 2.3, 2],
  [6.7, 3.3, 5.7, 2.5, 2],
  [6.7, 3.0, 5.2, 2.3, 2],
  [6.3, 2.5, 5.0, 1.9, 2],
  [6.5, 3.0, 5.2, 2.0, 2],
  [6.2, 3.4, 5.4, 2.3, 2],
  [5.9, 3.0, 5.1, 1.8, 2]
];


function convertToTensors(data, targets, testSplit) {
  const numExamples = data.length;
  if (numExamples !== targets.length) {
    throw new Error("data and split have different numbers of examples");
  }

  // Randomly shuffle `data` and `targets`.
  const indices = [];
  for (let i = 0; i < numExamples; ++i) {
    indices.push(i);
  }
  tf.util.shuffle(indices);

  const shuffledData = [];
  const shuffledTargets = [];
  for (let i = 0; i < numExamples; ++i) {
    shuffledData.push(data[indices[i]]);
    shuffledTargets.push(targets[indices[i]]);
  }

  // Split the data into a training set and a tet set, based on `testSplit`.
  const numTestExamples = Math.round(numExamples * testSplit);
  const numTrainExamples = numExamples - numTestExamples;

  const xDims = shuffledData[0].length;

  // Create a 2D `tf.Tensor` to hold the feature data.
  const xs = tf.tensor2d(shuffledData, [numExamples, xDims]);

  // Create a 1D `tf.Tensor` to hold the labels, and convert the number label
  // from the set {0, 1, 2} into one-hot encoding (.e.g., 0 --> [1, 0, 0]).
  const ys = tf.oneHot(tf.tensor1d(shuffledTargets).toInt(), IRIS_NUM_CLASSES);

  // Split the data into training and test sets, using `slice`.
  const xTrain = xs.slice([0, 0], [numTrainExamples, xDims]);
  const xTest = xs.slice([numTrainExamples, 0], [numTestExamples, xDims]);
  const yTrain = ys.slice([0, 0], [numTrainExamples, IRIS_NUM_CLASSES]);
  const yTest = ys.slice([0, 0], [numTestExamples, IRIS_NUM_CLASSES]);
  return [xTrain, yTrain, xTest, yTest];
}


export function getIrisData(testSplit) {
  return tf.tidy(() => {
    const dataByClass = [];
    const targetsByClass = [];
    for (let i = 0; i < IRIS_CLASSES.length; ++i) {
      dataByClass.push([]);
      targetsByClass.push([]);
    }
    for (const example of IRIS_DATA) {
      const target = example[example.length - 1];
      const data = example.slice(0, example.length - 1);
      dataByClass[target].push(data);
      targetsByClass[target].push(target);
    }

    const xTrains = [];
    const yTrains = [];
    const xTests = [];
    const yTests = [];
    for (let i = 0; i < IRIS_CLASSES.length; ++i) {
      const [xTrain, yTrain, xTest, yTest] = convertToTensors(
        dataByClass[i],
        targetsByClass[i],
        testSplit
      );
      xTrains.push(xTrain);
      yTrains.push(yTrain);
      xTests.push(xTest);
      yTests.push(yTest);
    }

    const concatAxis = 0;
    return [
      tf.concat(xTrains, concatAxis),
      tf.concat(yTrains, concatAxis),
      tf.concat(xTests, concatAxis),
      tf.concat(yTests, concatAxis)
    ];
  });
}

script.js

在这个文件中,实现了上面流程图中的大部分内容。

获取数据

window.onload函数中,调用getIrisData函数,输入验证集所占比例,返回训练集的数据、标签,验证集的数据、标签,并且都是已经转换成Tensor格式的。

import * as tfvis from "@tensorflow/tfjs-vis";
import * as tf from "@tensorflow/tfjs";
import { getIrisData, IRIS_CLASSES } from "./data";

window.onload = async () => {
  const [xTrain, yTrain, xTest, yTest] = getIrisData(0.15);
  xTrain.print();
};

这里打印了一下xTrain,看下格式是什么样的。


xTrain格式

搭建神经网络

搭建一个拥有2个全连接层的模型,使用的依然是sequential,这里的inputShape是4,与数据集的输入长度相同。

  let model = tf.sequential();
  model.add(
    tf.layers.dense({ units: 15, inputShape: [4], activation: "relu" })
  );
  model.add(tf.layers.dense({ units: 3, activation: "softmax" }));

编译模型

使用categoricalCrossentropy作为损失函数,adam作为优化器,衡量标准为准确度。

  model.compile({
    loss: "categoricalCrossentropy",
    optimizer: tf.train.adam(0.1),
    metrics: ["accuracy"]
  });

训练模型

validataionData里面输入的是验证集的数据和标签。使用tfjs-vis库对训练过程进行可视化,查看的内容包括训练集的lossaccuracy,验证集的lossaccuracy,在每个epoch结束后更新训练效果。

  await model.fit(xTrain, yTrain, {
    epochs: 100,
    validationData: [xTest, yTest],
    callbacks: tfvis.show.fitCallbacks(
      {
        name: "训练效果"
      },
      ["loss", "val_loss", "acc", "val_acc"],
      { callbacks: ["onEpochEnd"] }
    )
  });
训练效果

使用神经网络进行推断

编写一个predict函数,放在window.onload中,输入为表单,将输入的数据转化为Tensor后,输入到模型中,并得出预测的向量,通过寻找最大值的位置,找到对应的标签。

  window.predict = form => {
    const input = tf.tensor([
      [form.a.value * 1, form.b.value * 1, form.c.value * 1, form.d.value * 1]
    ]);

    const pred = model.predict(input);
    alert(
      `预测结果:${
        IRIS_CLASSES[pred.argMax(1).dataSync(0)]
      },概率:${pred.max().dataSync()}`
    );
  };
输入表单 预测结果

代码在这👇

链接:https://pan.baidu.com/s/1eHAwsZwQTTplYzC0X38lkw
提取码:8lvh

相关文章

网友评论

      本文标题:(六)TensorFlow.js的Iris数据集示例

      本文链接:https://www.haomeiwen.com/subject/bbhbthtx.html