TensorFlow是深度学习中使用人数最多的框架,本文快速尝试一下其能力,方便入门
添加依赖
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
<version>1.13.1</version>
</dependency>
定义图模型
示例完成一个简单的函数:
f(x, y) = z = a*x + b*y
其中a, b是常量,x, y是变量
- 定义Graph
Graph graph = new Graph()
- 定义常量
Operation a = graph.opBuilder("Const", "a")
.setAttr("dtype", DataType.fromClass(Double.class))
.setAttr("value", Tensor.<Double>create(3.0, Double.class))
.build();
Operation b = graph.opBuilder("Const", "b")
.setAttr("dtype", DataType.fromClass(Double.class))
.setAttr("value", Tensor.<Double>create(2.0, Double.class))
.build()
- 定义变量
Operation x = graph.opBuilder("Placeholder", "x")
.setAttr("dtype", DataType.fromClass(Double.class))
.build();
Operation y = graph.opBuilder("Placeholder", "y")
.setAttr("dtype", DataType.fromClass(Double.class))
.build();
- 定义函数
Operation ax = graph.opBuilder("Mul", "ax")
.addInput(a.output(0))
.addInput(x.output(0))
.build();
Operation by = graph.opBuilder("Mul", "by")
.addInput(b.output(0))
.addInput(y.output(0))
.build();
Operation z = graph.opBuilder("Add", "z")
.addInput(ax.output(0))
.addInput(by.output(0))
.build();
可以看出来,用Java定义图模型比较麻烦,但是使用Python会简单很多
执行
Session session = new Session(graph);
Tensor<Double> tensor = session.runner().fetch("z")
.feed("x", Tensor.create(3.0, Double.class))
.feed("y", Tensor.create(6.0, Double.class))
.run().get(0).expect(Double.class);
System.out.println(tensor.doubleValue());
图模型保存及加载
- 保存模型
Path path = Paths.get("tensor.model");
byte[] bytes = graph.toGraphDef();
Files.write(path, bytes);
- 加载模型
Graph graph = new Graph();
byte[] bytes = Files.readAllBytes(path);
graph.importGraphDef(bytes);
ps: 模型可以在不同语言通用,所以可以使用python训练模型,然后提供给其他语言使用,比如Java
结果
最后输出结果:21.0
网友评论