介绍
在上篇“图的构建”中,我们有讲过TF里面是如何将一个GraphDef转换为真正执行所用的Graph的。照此理解的话,我们使用TF 上层API构建出的深度学习模型最终在底层生成了GraphDef;然后再进一步由上篇讲过的两个函数(ConvertGraphDefToGraph和ImportGraphDef)之一生成Graph用于执行训练或推理。
但是实际上,TF内部的真正实现并非如此即"High level API 写的程序 -> 生成一张GraphDef表示的图 -> 转换为最终执行计算所用的Graph"。
它实际的转换步骤有些小Tricky。先是High level API写程序时直接构成Graph -> 转换为由GraphDef表示的图 -> 再变换为最终执行所需的Graph。 注意在此步骤序列当中,第三步最终生成出来的Graph与第一步直接由我们用API所描述的Graph是不相同的,它有经过ConvertGraphDefToGraph时,多加了一些功能部分像BackEdges,Source/Sink节点以及这两个特殊节点与原图之间的边,还有一些完全检查、约束等等。
本章当中,我们主要分析下TF里面如何在前端使用那些High level API在底层构建Graph,然后再转换为GraphDef的。它的主要实现可见于class GraphDefBuilder 当中。详细内容可见于tensorflow/core/graph/graph_constructor.h与tensorflow/core/graph/graph_constructor.cc。
Graph的初步构建
GraphDefBuilder类实现了大部分由前端API构建出底层的初步Graph这一工作。它在完成这一工作时具体又使用了像NodeBuilder这样的负责某一节点构建的类的功能。至于从初步生成的Graph到GraphDef这一步骤转换则主要由我们此系列文章中第一篇所介绍过的class Graph来完成。
- GraphDefBuilder
以下一段注释代码中,我们能看出GraphDefBuilder的使用。
// GraphDefBuilder b;
// using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
// Node* na = Const(7, b.opts());
// // Note: WithName() returns a copy, opts is unchanged.
// Node* nb = Const(5, b.opts().WithName("control-input"));
// Node* nc = Identity(na, b.opts().WithControlInput(nb));
// GraphDef graph_def;
// Status status = b.ToGraphDef(&graph_def);
// if (!status.ok()) { /* Handle error */ }
//
// In tests you can skip the status handling via:
// GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
// ...
// b.ToGraphDef(&graph_def);
以下则为class GraphDefBuilder的具体定义。我们会发现它有一个主要的内部class定义,Options。这个选项类提供了大部分的构建Node的功能。几乎GraphDefBuilder内部的每个成员函数都会使用一个Options的引用参数来负责具体实现成员函数的功能。
class GraphDefBuilder {
public:
// Options for adding a Node to a Graph.
class Options {
public:
// Sets the Graph (that Nodes will be added to) and the status. The
// status may be set to nullptr, in which case errors cause CHECK
// failures. The graph and status must outlive *this.
Options(Graph* graph, Status* status);
~Options();
// Methods for setting options. These are const methods: they
// return a copy of *this with the option set.
Options WithName(StringPiece name) const;
Options WithDevice(StringPiece device) const;
Options WithControlInput(Node* control_input) const;
Options WithControlInputs(gtl::ArraySlice<Node*> control_inputs) const;
// Given the Op type name, return a name for a node of that type.
// Uses the value set in WithName() if that has been called. Otherwise,
// returns a name built out of the Op type name.
string GetNameForOp(StringPiece op) const;
// Sets the device, adds control inputs, adds attrs, and calls Finalize().
// If Finalize returns an error, it is saved and this function returns
// nullptr.
Node* FinalizeBuilder(NodeBuilder* builder) const;
// Updates the associated status, if any, or calls TF_CHECK_OK if none.
void UpdateStatus(const Status& status) const;
// Accessor
const OpRegistryInterface* op_registry() const {
return graph_->op_registry();
}
private:
Options WithNameImpl(StringPiece name);
Options WithDeviceImpl(StringPiece device);
Options WithControlInputImpl(Node* control_input);
Options WithControlInputsImpl(gtl::ArraySlice<Node*> control_inputs);
template <class T>
Options WithAttrImpl(StringPiece name, T&& value) {
attrs_.emplace_back(std::string(name), AttrValue());
SetAttrValue(std::forward<T>(value), &attrs_.back().second);
return *this;
}
Graph* const graph_;
Status* const status_;
string name_;
string device_;
std::vector<Node*> control_inputs_;
std::vector<std::pair<string, AttrValue>> attrs_;
};
};
我们下面分开讲述GraphDefBuilder内部的主要成员变量及函数。可以看得到几乎每个函数都有一个为Options的参数。
// Start building a new graph.
explicit GraphDefBuilder(
const OpRegistryInterface* op_registry = OpRegistry::Global())
: graph_(op_registry), opts_(&graph_, &status_) {}
// Gets the Options with the associated Graph and Status.
const Options& opts() const { return opts_; }
// Once all the nodes have been added, call this to get whether it was
// successful, and if so fill *graph_def.
Status ToGraphDef(GraphDef* graph_def) const;
// Adds the function and gradient definitions in `fdef_lib` to this graph's op
// registry. Ignores duplicate functions, and returns a bad status if an
// imported function differs from an existing function or op with the same
// name.
Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
return graph_.AddFunctionLibrary(fdef_lib);
}
// Returns whether a user-defined function with `name` already exists in the
// graph.
bool HasFunction(const string& name) {
return graph_.flib_def().Find(name) != nullptr;
}
private:
Graph graph_;
Status status_;
Options opts_;
};
以下为Options中FinalizeBuilder的实现,我们会发现基本Graph之上的每个Node都是由NodeBuilder来构建的(它会参考使用Graph_提供给它的OpLibraries)。最终可以使用NodeBuilder的Finalize函数来将生成的接点插入到Graph当中。
Node* GraphDefBuilder::Options::FinalizeBuilder(NodeBuilder* builder) const {
builder->ControlInputs(control_inputs_);
if (!device_.empty()) builder->Device(device_);
for (const auto& attr : attrs_) {
builder->Attr(attr.first, attr.second);
}
Node* returned_node;
UpdateStatus(builder->Finalize(graph_, &returned_node));
return returned_node;
}
- GraphDefBuilderToGraph
由以上GraphDefBuilder的定义,我们知道它里面已经包含了一个Graph成员,而且在具体由NodeBuilder一点点构建Graph时正是在这个成员变量上构建的。那么我们实现GraphDefBuilderToGraph似乎不成问题,简单说就是直接返回Graph成员变量就好了。
可实际上,当下TF中会先将此Graph成员变量转变为一个GraphDef,然后再进一步调用ConvertGraphDefToGraph将它塑造为完整可用的最终Graph。此一转换步骤,我们在上一篇blog中有过提及,它会涉及到许多功能的完善、添加像Source/Sink nodes与图的连接,完整、安全性检查及控制节点的加入等等。
// Converts the `GraphDef` being built by `builder` to a `Graph` and
// stores it in `*graph`.
// TODO(josh11b): Make this faster; right now it converts
// Graph->GraphDef->Graph. This cleans up the graph (e.g. adds
// edges from the source and to the sink node, resolves back edges
// by name), and makes sure the resulting graph is valid.
Status GraphDefBuilderToGraph(const GraphDefBuilder& builder, Graph* graph);
网友评论