介绍
本章中我们继续讲解图构建当中非常重要的一个部分即具体节点的构建。我们明白真正的计算图是由一个个节点组成的,为了构建出最终的图来,我们首先要构建出一个个的节点,然后再将这些节点组织,拼装,行成最终的计算图。
节点的构建主要通过class NodeBuilder来完成。这个class我们在前几篇中也做过介绍,本质上它是由更底层的NodeDefBuilder来协助完成的Node节点的构建。
详细代码可见于:tensorflow/core/graph/node_builder.h。
NodeBuilder
- NodeBuilder的使用
下面为NodeBuilder class在用户端API层面的使用。同样的代码结构我们在上一篇中也有见过。
// This is a helper for creating a Node and adding it to a Graph.
// Internally, it uses a NodeDefBuilder to automatically set attrs
// that can be inferred from the inputs, and use default values
// (where they exist) for unspecified attrs. Example usage:
//
// Node* node;
// Status status = NodeBuilder(node_name, op_name)
// .Input(...)
// .Attr(...)
// .Finalize(&graph, &node);
// if (!status.ok()) return status;
// // Use node here.
- NodeOut
在class NodeBuilder里面有一个内含的struct NodeOut。它主要用于描述Node的输出Tensor。其代码如下,并不复杂。
class NodeBuilder {
public:
// For specifying the output of a Node to provide to one of the Input()
// functions below. It supports both regular inputs (where you are
// connecting to an existing Node*), and inputs from outside the graph
// (or haven't been added to the graph yet, like back edges, where
// you don't have a Node*). Both types can be mixed, e.g. in an
// ArraySlice.
struct NodeOut {
// For referencing an existing Node.
NodeOut(Node* n, int32 i = 0);
// For referencing Nodes not in the graph being built. It is
// useful when preparing a graph for ExtendSession or creating a
// back edge to a node that hasn't been added to the graph yet,
// but will be.
NodeOut(StringPiece name, int32 i, DataType t);
// Default constructor for std::vector<NodeOut>.
NodeOut();
Node* node;
// error is set to true if:
// * the NodeOut was default constructed and never overwritten,
// * a nullptr Node* was passed to the NodeOut constructor, or
// * an out-of-range index was passed to the NodeOut constructor.
bool error;
string name;
int32 index;
DataType dt;
};
};
- NodeBuilder的功能详解
NodeBuilder里面有个叫NodeDefBuilder的成员。基本上NodeBuilder的所有有用的功能成员函数都是外包它的成员的成员函数来完成。具体的功能函数有Input, Attr, Finalize等等。其实现多为通过调用NodeDefBuilder与Graph的成员函数来完成。。等将来我们过Lib中的核心class与Function时再具体分析吧。
class NodeBuilder {
public:
// Specify the name and the Op (either via an OpDef or the name of
// the Op plus a registry) for the Node. Other fields are
// specified by calling the methods below.
// REQUIRES: The OpDef must satisfy ValidateOpDef().
NodeBuilder(StringPiece name, StringPiece op_name,
const OpRegistryInterface* op_registry = OpRegistry::Global());
NodeBuilder(StringPiece name, const OpDef* op_def);
//由以下注释可见我们的一切Node构建都要围绕具体OpDef的定义来进行
// You must call one Input() function per input_arg in the Op,
// *and in the same order as the input_args appear in the OpDef.*
// For inputs that take a single tensor.
NodeBuilder& Input(Node* src_node, int src_index = 0);
NodeBuilder& Input(NodeOut src);
// For inputs that take a list of tensors.
NodeBuilder& Input(gtl::ArraySlice<NodeOut> src_list);
// Require that this node run after src_node(s).
NodeBuilder& ControlInput(Node* src_node);
NodeBuilder& ControlInputs(gtl::ArraySlice<Node*> src_nodes);
// Sets the "requested device spec" in the NodeDef (not the
// "assigned device" in the Node).
NodeBuilder& Device(StringPiece device_spec);
// Set the value of an attr. attr_name must match the name of one of
// attrs defined by the Op, and value must have the corresponding type
// (see SetAttrValue() in ../framework/attr_value_util.h for legal
// types for value). Note that attrs will be set automatically if
// they can be determined by the inputs.
template <class T>
NodeBuilder& Attr(StringPiece attr_name, T&& value);
template <class T>
NodeBuilder& Attr(StringPiece attr_name, std::initializer_list<T> value);
// Validates the described node and adds it to *graph, adding edges
// for all (non-back) inputs. If created_node is not nullptr,
// *created_node will be set to the new node (or nullptr on error).
Status Finalize(Graph* graph, Node** created_node) const;
// Accessors for the values set in the constructor.
const string& node_name() const { return def_builder_.node_name(); }
const OpDef& op_def() const { return def_builder_.op_def(); }
private:
NodeDefBuilder def_builder_;
std::vector<NodeOut> inputs_;
std::vector<Node*> control_inputs_;
std::vector<string> errors_;
};
网友评论