This article was migrated from CSDN:
http://blog.csdn.net/mounty_fsc/article/details/51088173
1 Solver
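1.1 Introduction
The Solver orchestrates the optimization of a network. Its responsibilities are:
- Scaffold the optimization bookkeeping, create the training network for learning, and create the test network(s) for evaluation
- Iteratively optimize by calling forward / backward and updating the weights
- Periodically evaluate the test networks
- Snapshot the model and solver state throughout the optimization
1.2 Source code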
/**
 * @brief An interface for classes that perform optimization on Net%s.
 *
 * Requires implementation of ApplyUpdate to compute a parameter update
 * given the current state of the Net parameters.
 */
template <typename Dtype>
class Solver {
 public:
  explicit Solver(const SolverParameter& param,
      const Solver* root_solver = NULL);
  explicit Solver(const string& param_file, const Solver* root_solver = NULL);
  void Init(const SolverParameter& param);
  void InitTrainNet();
  void InitTestNets();
  ...
  // The main entry of the solver function. In default, iter will be zero. Pass
  // in a non-zero iter number to resume training for a pre-trained net.
  virtual void Solve(const char* resume_file = NULL);
  inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
  void Step(int iters);
  ...
 protected:
  // Make and apply the update value for the current iteration.
  virtual void ApplyUpdate() = 0;
  ...
  SolverParameter param_;
  int iter_;
  int current_step_;
  shared_ptr<Net<Dtype> > net_;
  vector<shared_ptr<Net<Dtype> > > test_nets_;
  vector<Callback*> callbacks_;
  vector<Dtype> losses_;
  Dtype smoothed_loss_;
  // The root solver that holds root nets (actually containing shared layers)
  // in data parallelism
  const Solver* const root_solver_;
  ...
};
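Notes:
- shared_ptr<Net<Dtype> > net_ points to the training network, and vector<shared_ptr<Net<Dtype> > > test_nets_ holds the pointers to the test networks, so there can be more than one test network.
- The training and test networks usually differ in implementation, but the vast majority of their layers are the same.
- Each training method implements the core computation of the parameter update by overriding a virtual function. The header above declares it as ApplyUpdate( ); older Caffe versions called it ComputeUpdateValue( ), the name used in the rest of these notes.
- The train( ) function in caffe.cpp trains the model: it instantiates a Solver object and, once it is initialized, calls its Solve( ) method. Solve( ) mainly runs the following two functions iteratively:
ComputeUpdateValue();
net_->Update();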
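The division of labour described in these notes (the base class owns the iteration loop, the subclass supplies the update rule) can be sketched with a toy example. This is only a minimal illustration, not Caffe code: ToyNet, ToySolver, and ToySGDSolver are hypothetical names, and the "network" is a single scalar weight with loss (w - 3)^2.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for Net: one scalar weight with loss (w - 3)^2.
struct ToyNet {
  double weight = 0.0;
  double grad = 0.0;
  // Computes the loss and stores d(loss)/d(weight) in grad.
  double ForwardBackward() {
    const double diff = weight - 3.0;
    grad = 2.0 * diff;
    return diff * diff;
  }
};

// Base class: owns the iteration loop and defers the update rule to
// subclasses through a pure virtual ApplyUpdate(), as in the header above.
class ToySolver {
 public:
  explicit ToySolver(ToyNet* net) : net_(net), iter_(0) {
    assert(net != nullptr);
  }
  virtual ~ToySolver() {}
  // Runs `iters` iterations: forward/backward, then the subclass update.
  void Step(int iters) {
    for (int i = 0; i < iters; ++i) {
      losses_.push_back(net_->ForwardBackward());
      ApplyUpdate();
      ++iter_;
    }
  }
  int iter() const { return iter_; }
  const std::vector<double>& losses() const { return losses_; }

 protected:
  virtual void ApplyUpdate() = 0;
  ToyNet* net_;
  int iter_;
  std::vector<double> losses_;
};

// Plain gradient descent with a fixed learning rate.
class ToySGDSolver : public ToySolver {
 public:
  ToySGDSolver(ToyNet* net, double lr) : ToySolver(net), lr_(lr) {}

 protected:
  void ApplyUpdate() override { net_->weight -= lr_ * net_->grad; }

 private:
  double lr_;
};
```

Constructing a ToySGDSolver with a learning rate of 0.1 and calling Step(100) drives the weight towards 3, since each step shrinks the error by a factor of 0.8. A different training method would only need to supply a different ApplyUpdate().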
1.3 Solver methods
- Stochastic Gradient Descent (type: "SGD")
- AdaDelta (type: "AdaDelta")
- Adaptive Gradient (type: "AdaGrad")
- Adam (type: "Adam")
- Nesterov’s Accelerated Gradient (type: "Nesterov")
- RMSprop (type: "RMSProp")
See reference 1 for details.
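To give a feel for how the type string selects a different update rule, here is a small sketch of two of the methods listed above, written for a single scalar weight. It uses the common textbook forms of momentum SGD and AdaGrad; the function ApplyRule, the struct UpdateState, and the eps constant are hypothetical and are not taken from Caffe's implementation.

```cpp
#include <cassert>
#include <cmath>
#include <stdexcept>
#include <string>

// Hypothetical per-parameter optimizer state.
struct UpdateState {
  // Momentum buffer for "SGD", running sum of squared gradients for "AdaGrad".
  double history = 0.0;
};

// Returns the new weight after one update of the given type.
double ApplyRule(const std::string& type, double weight, double grad,
                 double lr, double momentum, UpdateState* state) {
  assert(state != nullptr);
  if (type == "SGD") {
    // v <- momentum * v - lr * grad;  w <- w + v
    state->history = momentum * state->history - lr * grad;
    return weight + state->history;
  }
  if (type == "AdaGrad") {
    // h <- h + grad^2;  w <- w - lr * grad / (sqrt(h) + eps)
    const double eps = 1e-8;  // assumed small constant to avoid division by zero
    state->history += grad * grad;
    return weight - lr * grad / (std::sqrt(state->history) + eps);
  }
  throw std::invalid_argument("unknown solver type: " + type);
}
```

Both rules consume the same gradient; they differ only in the history they keep and in how they scale the step, which is why a solver type can be swapped without touching the network.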
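2 The Caffe class
Caffe is a singleton class that holds commonly used Caffe members, such as the handles of the CUDA libraries (cublas, curand) that Caffe uses, and the random number generator used inside Caffe.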
// common.hpp
// A singleton class to hold common caffe stuff, such as the handler that
// caffe is going to use for cublas, curand, etc.
class Caffe {
 public:
  ~Caffe();
  // Thread local context for Caffe. Moved to common.cpp instead of
  // including boost/thread.hpp to avoid a boost/NVCC issues (#1009, #1010)
  // on OSX. Also fails on Linux with CUDA 7.0.18.
  static Caffe& Get();
  enum Brew { CPU, GPU };
  ...
 protected:
#ifndef CPU_ONLY
  cublasHandle_t cublas_handle_;
  curandGenerator_t curand_generator_;
#endif
  shared_ptr<RNG> random_generator_;
  Brew mode_;
  int solver_count_;
  bool root_solver_;
 private:
  // The private constructor to avoid duplicate instantiation.
  Caffe();
  DISABLE_COPY_AND_ASSIGN(Caffe);
};
// common.cpp
namespace caffe {
// Make sure each thread can have different values.
static boost::thread_specific_ptr<Caffe> thread_instance_;
Caffe& Caffe::Get() {
  if (!thread_instance_.get()) {
    thread_instance_.reset(new Caffe());
  }
  return *(thread_instance_.get());
}
...
Caffe::Caffe()
    : cublas_handle_(NULL), curand_generator_(NULL), random_generator_(),
      mode_(Caffe::CPU), solver_count_(1), root_solver_(true) {
  // Try to create a cublas handler, and report an error if failed (but we will
  // keep the program running as one might just want to run CPU code).
  if (cublasCreate(&cublas_handle_) != CUBLAS_STATUS_SUCCESS) {
    LOG(ERROR) << "Cannot create Cublas handle. Cublas won't be available.";
  }
  // Try to create a curand handler.
  if (curandCreateGenerator(&curand_generator_, CURAND_RNG_PSEUDO_DEFAULT)
      != CURAND_STATUS_SUCCESS ||
      curandSetPseudoRandomGeneratorSeed(curand_generator_, cluster_seedgen())
      != CURAND_STATUS_SUCCESS) {
    LOG(ERROR) << "Cannot create Curand generator. Curand won't be available.";
  }
}
...
}  // namespace caffe
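Notes:
- Caffe is a singleton class; its constructor is private.
- The singleton is held by static boost::thread_specific_ptr<Caffe> thread_instance_, which ensures that in a multithreaded environment each thread has its own Caffe instance.
- The singleton is obtained through Get(): Caffe::Get() returns the instance held by thread_instance_.
- thread_instance_ is initially NULL; on the first access, new Caffe() is executed.
- new Caffe() runs the constructor, which in fact only creates the cublas and curand handles.
- Stepping through in a debugger shows that cublasCreate(), while creating the cublas handle, spawns two additional threads.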
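The per-thread singleton described in these notes can be sketched without Boost or CUDA. The example below is an illustration under assumptions, not Caffe's code: it swaps boost::thread_specific_ptr for the C++11 thread_local keyword, and the names Context and InstanceSeenByNewThread are hypothetical.

```cpp
#include <cassert>
#include <thread>

// Hypothetical stand-in for the Caffe singleton (no CUDA handles).
class Context {
 public:
  enum Brew { CPU, GPU };
  // Returns the calling thread's instance, creating it on first access.
  static Context& Get() {
    // thread_local gives every thread its own lazily constructed object,
    // playing the role of boost::thread_specific_ptr in the original.
    thread_local Context instance;
    return instance;
  }
  Brew mode() const { return mode_; }
  void set_mode(Brew mode) { mode_ = mode; }
  // Copying is disabled, like DISABLE_COPY_AND_ASSIGN in the original.
  Context(const Context&) = delete;
  Context& operator=(const Context&) = delete;

 private:
  // Private constructor: instances can only be created through Get().
  Context() : mode_(CPU) {}
  Brew mode_;
};

// Starts a thread, records the mode it observes, and returns the address of
// the instance it saw (used only for comparison, never dereferenced).
const Context* InstanceSeenByNewThread(Context::Brew* observed_mode) {
  assert(observed_mode != nullptr);
  const Context* address = nullptr;
  std::thread worker([&] {
    address = &Context::Get();
    *observed_mode = Context::Get().mode();
  });
  worker.join();
  return address;
}
```

One practical consequence of this design: a setting changed through Get() in one thread (for example switching the mode to GPU) is not visible in a newly started thread, which gets a fresh instance with the default values.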
3 Batch
template <typename Dtype>
class Batch {
 public:
  Blob<Dtype> data_, label_;
};
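Notes:
- Batch wraps a batch of samples. Unlike Datum, which is database-oriented and holds exactly one sample (image and label), Batch is network-oriented: one Batch corresponds to a whole batch of samples.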
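The relationship between per-sample records and a batch can be sketched as follows. This is a simplified illustration with hypothetical types: Sample stands in for Datum, ToyBatch for Batch, and plain std::vector storage for the two Blobs; the real classes carry more information (for example image shape).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for Datum: one sample and its label.
struct Sample {
  std::vector<float> pixels;
  int label;
};

// Hypothetical stand-in for Batch: contiguous storage for many samples.
struct ToyBatch {
  std::vector<float> data;   // num * dim values, one sample after another
  std::vector<float> label;  // num values
  std::size_t num = 0;       // number of samples in the batch
  std::size_t dim = 0;       // number of values per sample
};

// Packs equally sized samples into a single batch.
ToyBatch PackBatch(const std::vector<Sample>& samples) {
  ToyBatch batch;
  if (samples.empty()) return batch;
  batch.num = samples.size();
  batch.dim = samples[0].pixels.size();
  batch.data.reserve(batch.num * batch.dim);
  batch.label.reserve(batch.num);
  for (const Sample& s : samples) {
    assert(s.pixels.size() == batch.dim);  // all samples must share one size
    batch.data.insert(batch.data.end(), s.pixels.begin(), s.pixels.end());
    batch.label.push_back(static_cast<float>(s.label));
  }
  return batch;
}
```

Sample i then occupies data[i * dim] through data[(i + 1) * dim - 1], with its label at label[i], so the network can consume the whole batch in one forward pass.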