Tensorflow核心代码解析之计算图篇其四:图的构建之二

作者: manofmountain | 来源:发表于2018-06-25 20:27 被阅读9次

    介绍

    在上篇“图的构建”中,我们有讲过TF里面是如何将一个GraphDef转换为真正执行所用的Graph的。照此理解的话,我们使用TF 上层API构建出的深度学习模型最终在底层生成了GraphDef;然后再进一步由上篇讲过的两个函数(ConvertGraphDefToGraph和ImportGraphDef)之一生成Graph用于执行训练或推理。

    但是实际上,TF内部的真正实现并非如此即"High level API 写的程序 -> 生成一张GraphDef表示的图 -> 转换为最终执行计算所用的Graph"。

    它实际的转换步骤有些小Tricky。先是High level API写程序时直接构成Graph -> 转换为由GraphDef表示的图 -> 再变换为最终执行所需的Graph。 注意在此步骤序列当中,第三步最终生成出来的Graph与第一步直接由我们用API所描述的Graph是不相同的,它有经过ConvertGraphDefToGraph时,多加了一些功能部分像BackEdges,Source/Sink节点以及这两个特殊节点与原图之间的边,还有一些完全检查、约束等等。

    本章当中,我们主要分析下TF里面如何在前端使用那些High level API在底层构建Graph,然后再转换为GraphDef的。它的主要实现可见于class GraphDefBuilder 当中。详细内容可见于tensorflow/core/graph/graph_constructor.h与tensorflow/core/graph/graph_constructor.cc。

    Graph的初步构建

    GraphDefBuilder类实现了大部分由前端API构建出底层的初步Graph这一工作。它在完成这一工作时具体又使用了像NodeBuilder这样的负责某一节点构建的类的功能。至于从初步生成的Graph到GraphDef这一步骤转换则主要由我们此系列文章中第一篇所介绍过的class Graph来完成。

    • GraphDefBuilder

    以下一段注释代码中,我们能看出GraphDefBuilder的使用。

    //   GraphDefBuilder b;
    //   using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    //   Node* na = Const(7, b.opts());
    //   // Note: WithName() returns a copy, opts is unchanged.
    //   Node* nb = Const(5, b.opts().WithName("control-input"));
    //   Node* nc = Identity(na, b.opts().WithControlInput(nb));
    //   GraphDef graph_def;
    //   Status status = b.ToGraphDef(&graph_def);
    //   if (!status.ok()) { /* Handle error */ }
    //
    // In tests you can skip the status handling via:
    //   GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
    //   ...
    //   b.ToGraphDef(&graph_def);
    
    

    以下则为class GraphDefBuilder的具体定义。我们会发现它有一个主要的内部class定义,Options。这个选项类提供了大部分的构建Node的功能。几乎GraphDefBuilder内部的每个成员函数都会使用一个Options的引用参数来负责具体实现成员函数的功能。

    class GraphDefBuilder {
     public:
      // Options for adding a Node to a Graph.
      class Options {
       public:
        // Sets the Graph (that Nodes will be added to) and the status.  The
        // status may be set to nullptr, in which case errors cause CHECK
        // failures.  The graph and status must outlive *this.
        Options(Graph* graph, Status* status);
        ~Options();
    
        // Methods for setting options.  These are const methods: they
        // return a copy of *this with the option set.
        Options WithName(StringPiece name) const;
        Options WithDevice(StringPiece device) const;
        Options WithControlInput(Node* control_input) const;
        Options WithControlInputs(gtl::ArraySlice<Node*> control_inputs) const;
        // Given the Op type name, return a name for a node of that type.
        // Uses the value set in WithName() if that has been called.  Otherwise,
        // returns a name built out of the Op type name.
        string GetNameForOp(StringPiece op) const;
    
        // Sets the device, adds control inputs, adds attrs, and calls Finalize().
        // If Finalize returns an error, it is saved and this function returns
        // nullptr.
        Node* FinalizeBuilder(NodeBuilder* builder) const;
    
        // Updates the associated status, if any, or calls TF_CHECK_OK if none.
        void UpdateStatus(const Status& status) const;
    
        // Accessor
        const OpRegistryInterface* op_registry() const {
          return graph_->op_registry();
        }
    
       private:
        Options WithNameImpl(StringPiece name);
        Options WithDeviceImpl(StringPiece device);
        Options WithControlInputImpl(Node* control_input);
        Options WithControlInputsImpl(gtl::ArraySlice<Node*> control_inputs);
        template <class T>
        Options WithAttrImpl(StringPiece name, T&& value) {
          attrs_.emplace_back(std::string(name), AttrValue());
          SetAttrValue(std::forward<T>(value), &attrs_.back().second);
          return *this;
        }
    
        Graph* const graph_;
        Status* const status_;
        string name_;
        string device_;
        std::vector<Node*> control_inputs_;
        std::vector<std::pair<string, AttrValue>> attrs_;
      };
    };
    

    我们下面分开讲述GraphDefBuilder内部的主要成员变量及函数。可以看得到几乎每个函数都有一个为Options的参数。

    
     // Start building a new graph.
      explicit GraphDefBuilder(
          const OpRegistryInterface* op_registry = OpRegistry::Global())
          : graph_(op_registry), opts_(&graph_, &status_) {}
    
      // Gets the Options with the associated Graph and Status.
      const Options& opts() const { return opts_; }
    
      // Once all the nodes have been added, call this to get whether it was
      // successful, and if so fill *graph_def.
      Status ToGraphDef(GraphDef* graph_def) const;
      // Adds the function and gradient definitions in `fdef_lib` to this graph's op
      // registry. Ignores duplicate functions, and returns a bad status if an
      // imported function differs from an existing function or op with the same
      // name.
      Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
        return graph_.AddFunctionLibrary(fdef_lib);
      }
    
      // Returns whether a user-defined function with `name` already exists in the
      // graph.
      bool HasFunction(const string& name) {
        return graph_.flib_def().Find(name) != nullptr;
      }
    
     private:
      Graph graph_;
      Status status_;
      Options opts_;
    };
    

    以下为Options中FinalizeBuilder的实现,我们会发现基本Graph之上的每个Node都是由NodeBuilder来构建的(它会参考使用Graph_提供给它的OpLibraries)。最终可以使用NodeBuilder的Finalize函数来将生成的接点插入到Graph当中。

    Node* GraphDefBuilder::Options::FinalizeBuilder(NodeBuilder* builder) const {
      builder->ControlInputs(control_inputs_);
      if (!device_.empty()) builder->Device(device_);
      for (const auto& attr : attrs_) {
        builder->Attr(attr.first, attr.second);
      }
    
      Node* returned_node;
      UpdateStatus(builder->Finalize(graph_, &returned_node));
      return returned_node;
    }
    
    • GraphDefBuilderToGraph

    由以上GraphDefBuilder的定义,我们知道它里面已经包含了一个Graph成员,而且在具体由NodeBuilder一点点构建Graph时正是在这个成员变量上构建的。那么我们实现GraphDefBuilderToGraph似乎不成问题,简单说就是直接返回Graph成员变量就好了。

    可实际上,当下TF中会先将此Graph成员变量转变为一个GraphDef,然后再进一步调用ConvertGraphDefToGraph将它塑造为完整可用的最终Graph。此一转换步骤,我们在上一篇blog中有过提及,它会涉及到许多功能的完善、添加像Source/Sink nodes与图的连接,完整、安全性检查及控制节点的加入等等。

    // Converts the `GraphDef` being built by `builder` to a `Graph` and
    // stores it in `*graph`.
    // TODO(josh11b): Make this faster; right now it converts
    // Graph->GraphDef->Graph.  This cleans up the graph (e.g. adds
    // edges from the source and to the sink node, resolves back edges
    // by name), and makes sure the resulting graph is valid.
    Status GraphDefBuilderToGraph(const GraphDefBuilder& builder, Graph* graph);
    

    参考文献

    相关文章

      网友评论

        本文标题:Tensorflow核心代码解析之计算图篇其四:图的构建之二

        本文链接:https://www.haomeiwen.com/subject/wqluyftx.html