(Caffe) Basic Classes: Solver, Caffe, Batch (Part 2)

Author: 沤江一流 | Published: 2016-04-07 18:28 | Views: 912

    This article was migrated from CSDN:
    http://blog.csdn.net/mounty_fsc/article/details/51088173

    1 Solver

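    1.1 Overview

    The Solver orchestrates the optimization of the network. Its responsibilities are:

    1. Scaffold the optimization bookkeeping, create the training network for learning, and create the test network(s) for evaluation
    2. Iteratively optimize by calling forward / backward and updating the weights
    3. Periodically evaluate the test networks
    4. Snapshot the model and solver state throughout the optimization

    1.2 Source code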

    /**
     * @brief An interface for classes that perform optimization on Net%s.
     *
     * Requires implementation of ApplyUpdate to compute a parameter update
     * given the current state of the Net parameters.
     */
    template <typename Dtype>
    class Solver {
     public:
      explicit Solver(const SolverParameter& param,
          const Solver* root_solver = NULL);
      explicit Solver(const string& param_file, const Solver* root_solver = NULL);
      void Init(const SolverParameter& param);
      void InitTrainNet();
      void InitTestNets();
     ...
      // The main entry of the solver function. In default, iter will be zero. Pass
      // in a non-zero iter number to resume training for a pre-trained net.
      virtual void Solve(const char* resume_file = NULL);
      inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
      void Step(int iters);
    ...
    
     protected:
      // Make and apply the update value for the current iteration.
      virtual void ApplyUpdate() = 0;
      ...
    
      SolverParameter param_;
      int iter_;
      int current_step_;
      shared_ptr<Net<Dtype> > net_;
      vector<shared_ptr<Net<Dtype> > > test_nets_;
      vector<Callback*> callbacks_;
      vector<Dtype> losses_;
      Dtype smoothed_loss_;
    
      // The root solver that holds root nets (actually containing shared layers)
      // in data parallelism
      const Solver* const root_solver_;
    ...
    };
    

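    Notes:

    1. shared_ptr<Net<Dtype> > net_ points to the training network, and vector<shared_ptr<Net<Dtype> > > test_nets_ holds pointers to the test networks, so there can be more than one test network.

    2. The training network and the test networks usually differ in implementation, but the vast majority of their layers are the same.

    3. Each optimization method implements its core parameter-update computation by overriding the pure virtual function ApplyUpdate() (older Caffe versions named this hook ComputeUpdateValue()).

    4. The train() function in caffe.cpp trains the model: it instantiates a Solver and, after initialization, calls its Solve() method. In older Caffe versions Solve() mainly iterates the two calls below; in the version listed above, this work appears to be driven by Step() and wrapped in ApplyUpdate().
      ComputeUpdateValue();
      net_->Update();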
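
    The relationship between Step() and ApplyUpdate() described in the notes can be sketched with a toy example. This is not Caffe code: ToyNet, ToySolver and ToySGDSolver are hypothetical names, the "network" is just a vector of weights with loss 0.5 * sum(w^2), and the update rule is plain gradient descent with a fixed learning rate. It only illustrates how a base class can own the iteration loop while subclasses supply the update.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical stand-in for Net<Dtype>: the loss is 0.5 * sum(w_i^2),
// so the gradient with respect to each weight is the weight itself.
struct ToyNet {
  std::vector<double> weights;
  std::vector<double> grads;

  explicit ToyNet(std::vector<double> w)
      : weights(std::move(w)), grads(weights.size(), 0.0) {}

  // Forward pass: returns the loss for the current weights.
  double Forward() const {
    double loss = 0.0;
    for (double w : weights) loss += 0.5 * w * w;
    return loss;
  }

  // Backward pass: fills grads with dLoss/dw.
  void Backward() { grads = weights; }
};

// Mirrors the shape of Solver: Step() drives the loop, and subclasses
// supply the update rule through the pure virtual ApplyUpdate().
class ToySolver {
 public:
  explicit ToySolver(ToyNet* net) : net_(net), iter_(0) {
    assert(net != nullptr);
  }
  virtual ~ToySolver() {}

  void Step(int iters) {
    for (int i = 0; i < iters; ++i) {
      net_->Forward();
      net_->Backward();
      ApplyUpdate();  // subclass-specific update
      ++iter_;
    }
  }

  int iter() const { return iter_; }

 protected:
  virtual void ApplyUpdate() = 0;

  ToyNet* net_;
  int iter_;
};

// Plain gradient descent with a fixed learning rate (an assumption made
// for brevity; no momentum, weight decay or learning-rate schedule).
class ToySGDSolver : public ToySolver {
 public:
  ToySGDSolver(ToyNet* net, double lr) : ToySolver(net), lr_(lr) {}

 protected:
  void ApplyUpdate() override {
    for (std::size_t i = 0; i < net_->weights.size(); ++i) {
      net_->weights[i] -= lr_ * net_->grads[i];
    }
  }

 private:
  double lr_;
};
```

    Constructing a ToyNet, wrapping it in a ToySGDSolver and calling Step(10) runs ten forward / backward / update iterations, after which Forward() reports a smaller loss than before.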

    1.3 Solver methods

    • Stochastic Gradient Descent (type: "SGD")
    • AdaDelta (type: "AdaDelta")
    • Adaptive Gradient (type: "AdaGrad")
    • Adam (type: "Adam")
    • Nesterov’s Accelerated Gradient (type: "Nesterov")
    • RMSprop (type: "RMSProp")

    See reference [1] for details.

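    2 The Caffe class

    The Caffe class is a singleton that holds commonly used Caffe state, such as the handles for the CUDA libraries cublas and curand, and it also provides the random number generator used within Caffe.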

    
    // common.hpp
    // A singleton class to hold common caffe stuff, such as the handler that
    // caffe is going to use for cublas, curand, etc.
    class Caffe {
     public:
      ~Caffe();
    
      // Thread local context for Caffe. Moved to common.cpp instead of
      // including boost/thread.hpp to avoid a boost/NVCC issues (#1009, #1010)
      // on OSX. Also fails on Linux with CUDA 7.0.18.
      static Caffe& Get();
    
      enum Brew { CPU, GPU };
    ...
    
    protected:
    #ifndef CPU_ONLY
      cublasHandle_t cublas_handle_;
      curandGenerator_t curand_generator_;
    #endif
      shared_ptr<RNG> random_generator_;
    
      Brew mode_;
      int solver_count_;
      bool root_solver_;
    
     private:
      // The private constructor to avoid duplicate instantiation.
      Caffe();
      DISABLE_COPY_AND_ASSIGN(Caffe);
    };
    
    //common.cpp
    
    namespace caffe {
    
    // Make sure each thread can have different values.
    static boost::thread_specific_ptr<Caffe> thread_instance_;
    
    Caffe& Caffe::Get() {
      if (!thread_instance_.get()) {
        thread_instance_.reset(new Caffe());
      }
      return *(thread_instance_.get());
    }
    
    ...
    Caffe::Caffe()
        : cublas_handle_(NULL), curand_generator_(NULL), random_generator_(),
        mode_(Caffe::CPU), solver_count_(1), root_solver_(true) {
      // Try to create a cublas handler, and report an error if failed (but we will
      // keep the program running as one might just want to run CPU code).
      if (cublasCreate(&cublas_handle_) != CUBLAS_STATUS_SUCCESS) {
        LOG(ERROR) << "Cannot create Cublas handle. Cublas won't be available.";
      }
      // Try to create a curand handler.
      if (curandCreateGenerator(&curand_generator_, CURAND_RNG_PSEUDO_DEFAULT)
          != CURAND_STATUS_SUCCESS ||
          curandSetPseudoRandomGeneratorSeed(curand_generator_, cluster_seedgen())
          != CURAND_STATUS_SUCCESS) {
        LOG(ERROR) << "Cannot create Curand generator. Curand won't be available.";
      }
    }
    ...
    
    }  // namespace caffe
    
    

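    Notes:

    1. Caffe is a singleton class with a private constructor.
    2. The singleton is held by static boost::thread_specific_ptr<Caffe> thread_instance_, so in a multi-threaded environment each thread gets its own Caffe instance.
    3. The instance is obtained through Get(): Caffe::Get() returns the instance held by thread_instance_.
    4. thread_instance_ is initially NULL; on the first access in a thread, Get() calls new Caffe().
    5. new Caffe() runs the constructor, which essentially just creates the cublas and curand handles.
    6. Stepping through in a debugger shows that cublasCreate(), while creating the cublas handle, spawns two additional threads.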
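
    The per-thread singleton pattern in notes 1 to 4 can be illustrated with a minimal sketch. Note the swap: it uses the C++11 thread_local keyword instead of boost::thread_specific_ptr, and the class name Context and the helper SolverCountInNewThread() are hypothetical. It is not how common.cpp is written, only a small demonstration that each thread sees its own instance.

```cpp
#include <thread>

// Hypothetical per-thread singleton, loosely modelled on the Caffe class.
class Context {
 public:
  // Returns the calling thread's instance, created on first use in
  // that thread.
  static Context& Get() {
    thread_local Context instance;
    return instance;
  }

  int solver_count() const { return solver_count_; }
  void set_solver_count(int v) { solver_count_ = v; }

 private:
  // Private constructor and deleted copy operations prevent any
  // instantiation other than through Get().
  Context() : solver_count_(1) {}
  Context(const Context&) = delete;
  Context& operator=(const Context&) = delete;

  int solver_count_;
};

// Reads solver_count from a freshly started thread. Because that thread
// has its own Context, it sees the default value, not the caller's.
int SolverCountInNewThread() {
  int seen = 0;
  std::thread worker([&seen] { seen = Context::Get().solver_count(); });
  worker.join();
  return seen;
}
```

    If the calling thread first sets Context::Get().set_solver_count(4), it keeps reading 4 from its own instance, while SolverCountInNewThread() still returns the default of 1.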

    3 Batch

    template <typename Dtype>
    class Batch {
     public:
      Blob<Dtype> data_, label_;
    };
    

    Notes:

    • Batch wraps a batch of samples. Unlike Datum, which is database-oriented and holds a single sample (image and label), Batch is network-oriented and holds a whole batch of samples.
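
    The difference between the two can be made concrete with a toy sketch. ToyDatum, ToyBatch and MakeBatch() are hypothetical names, and flat std::vector buffers stand in for Blob<Dtype>; the real Caffe data layers do considerably more (decoding, transformation, prefetching). The sketch only shows several single-sample records being packed into one pair of data and label buffers.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for Datum: one sample as read from the database.
struct ToyDatum {
  std::vector<float> pixels;
  int label;
};

// Hypothetical stand-in for Batch<Dtype>, with Blob replaced by flat
// vectors for illustration.
struct ToyBatch {
  std::vector<float> data;   // N * sample_size values, sample after sample
  std::vector<float> label;  // N values, one per sample
};

// Packs N single-sample records into one batch. All samples are assumed
// to have the same number of pixels.
ToyBatch MakeBatch(const std::vector<ToyDatum>& samples) {
  ToyBatch batch;
  if (samples.empty()) return batch;

  const std::size_t sample_size = samples[0].pixels.size();
  batch.data.reserve(samples.size() * sample_size);
  batch.label.reserve(samples.size());

  for (const ToyDatum& d : samples) {
    assert(d.pixels.size() == sample_size);
    batch.data.insert(batch.data.end(), d.pixels.begin(), d.pixels.end());
    batch.label.push_back(static_cast<float>(d.label));
  }
  return batch;
}
```

    Passing two ToyDatum records of two pixels each yields a ToyBatch whose data holds four values and whose label holds two, which is the one-to-many relationship described above.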

    [1].http://caffe.berkeleyvision.org/tutorial/solver.html
