浏览器安装
使用脚本标签(script tags)
将以下脚本标签添加到主HTML文件中:
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js"></script>
示例
//定义一个线性回归模型。
const model = tf.sequential();
model.add(tf.layers.dense({units: 1, inputShape: [1]}));
model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
// 为训练生成一些合成数据
const xs = tf.tensor2d([1, 2, 3, 4], [4, 1]);
const ys = tf.tensor2d([1, 3, 5, 7], [4, 1]);
// 使用数据训练模型
model.fit(xs, ys, {epochs: 10}).then(() => {
// 在该模型从未看到过的数据点上使用模型进行推理
model.predict(tf.tensor2d([5], [1, 1])).print();
// 打开浏览器开发工具查看输出
});
使用NPM安装
您可以使用 npm cli工具或是yarn安装TensorFlow.js。
从NPM安装可以使用Parcel, WebPack或是 Rollup这样的构建工具进行打包。
yarn add @tensorflow/tfjs
或
npm install @tensorflow/tfjs
示例
import * as tf from '@tensorflow/tfjs';
//定义一个线性回归模型。
const model = tf.sequential();
model.add(tf.layers.dense({units: 1, inputShape: [1]}));
model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
// 为训练生成一些合成数据
const xs = tf.tensor2d([1, 2, 3, 4], [4, 1]);
const ys = tf.tensor2d([1, 3, 5, 7], [4, 1]);
// 使用数据训练模型
model.fit(xs, ys, {epochs: 10}).then(() => {
// 在该模型从未看到过的数据点上使用模型进行推理
model.predict(tf.tensor2d([5], [1, 1])).print();
// 打开浏览器开发工具查看输出
});
Node.js 安装
Node.js版本的tensorflow.js可以分为CPU版、GPU版、纯JavaScript版本,可以使用 npm cli工具或是yarn安装TensorFlow.js。
CPU版
yarn add @tensorflow/tfjs-node
或
npm install @tensorflow/tfjs-node
GPU版(仅限Linux)
需要系统具有支持CUDA的NVIDIA®GPU
yarn add @tensorflow/tfjs-node-gpu
或
npm install @tensorflow/tfjs-node-gpu
纯JavaScript版本
yarn add @tensorflow/tfjs
或
npm install @tensorflow/tfjs
实例
const tf = require('@tensorflow/tfjs');
// 可选加载绑定:
// 如果使用GPU运行,请使用'@tensorflow/tfjs-node-gpu'
require('@tensorflow/tfjs-node');
// 训练一个简单模型:
const model = tf.sequential();
model.add(tf.layers.dense({units: 100, activation: 'relu', inputShape: [10]}));
model.add(tf.layers.dense({units: 1, activation: 'linear'}));
model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
const xs = tf.randomNormal([100, 10]);
const ys = tf.randomNormal([100, 1]);
model.fit(xs, ys, {
epochs: 100,
callbacks: {
onEpochEnd: (epoch, log) => console.log(`Epoch ${epoch}: loss = ${log.loss}`)
}
});
网友评论