与 `layer` 一样,`network` 也是这个深度学习框架中的核心数据结构。
`network` 结构体的成员(类型 名称)如下:
int n;//
int batch;//
int *seen;
float epoch;//
int subdivisions;
float momentum;//
float decay;//
layer *layers;//
float *output;
learning_rate_policy policy;//
float learning_rate;//
float gamma;//
float scale;//
float power;//
int time_steps;
int step;//
int max_batches;
float *scales;
int *steps;
int num_steps;
int burn_in;
int adam;//
float B1;
float B2;
float eps;
int inputs;
int outputs;
int truths;
int notruth;
int h, w, c;
int max_crop;
int min_crop;
int center;
float angle;
float aspect;
float exposure;
float saturation;
float hue;
int gpu_index;
tree *hierarchy;
float *input;
float *truth;
float *delta;
float *workspace;
int train;
int index;
float *cost;
#ifdef GPU
float *input_gpu;
float *truth_gpu;
float *delta_gpu;
float *output_gpu;
#endif
与 `network` 相关的主要函数声明如下:
/* Network-level API declarations (forward/backward/update are shown in full below). */
/* Returns the rate that update_network multiplies by each layer's learning_rate_scale. */
float get_current_rate(network net);
int get_current_batch(network net);
void free_network(network net);
void compare_networks(network n1, network n2, data d);
char *get_layer_string(LAYER_TYPE a);
network make_network(int n);
/* Runs every layer's forward() in order, then calls calc_network_cost(). */
void forward_network(network net);
/* Runs layers' backward() from last to first, stopping at a layer with stopbackward set. */
void backward_network(network net);
/* Calls each layer's update() with the current rate, momentum and decay. */
void update_network(network net);
float train_network(network net, data d);
float train_network_sgd(network net, data d, int n);
float train_network_datum(network net);
matrix network_predict_data(network net, data test);
float *network_predict(network net, float *input);
float network_accuracy(network net, data d);
float *network_accuracies(network net, data d, int n);
float network_accuracy_multi(network net, data d, int n);
void top_predictions(network net, int n, int *index);
image get_network_image(network net);
image get_network_image_layer(network net, int i);
layer get_network_output_layer(network net);
int get_predicted_class_network(network net);
void print_network(network net);
void visualize_network(network net);
int resize_network(network *net, int w, int h);
void set_batch_network(network *net, int b);
network load_network(char *cfg, char *weights, int clear);
load_args get_base_args(network net);
/* Called at the end of forward_network. */
void calc_network_cost(network net);
其中最值得先看的是下面三个函数,它们有助于理解整体的训练流程:
## forward_network(...)
前向传播:按顺序执行每一层的 `forward`,并把当前层的输出作为下一层的输入。
void forward_network(network net)
{
    /* Forward pass: run each layer in order, chaining its output into net.input. */
    for (int idx = 0; idx < net.n; ++idx) {
        layer cur = net.layers[idx];
        net.index = idx;

        /* Reset this layer's delta buffer (outputs * batch floats) via fill_cpu
         * with value 0 before running the layer. */
        if (cur.delta) {
            fill_cpu(cur.outputs * cur.batch, 0, cur.delta, 1);
        }

        cur.forward(cur, net);

        /* The next layer reads this layer's output as its input. */
        net.input = cur.output;

        /* A layer with its truth flag set redirects net.truth to its own output. */
        if (cur.truth) {
            net.truth = cur.output;
        }
    }
    calc_network_cost(net);
}
## backward_network(...)
反向传播(BP)计算梯度:从最后一层开始向前,依次执行每一层的 `backward`。
void backward_network(network net)
{
    /* Snapshot of the caller's network; layer 0 gets the original input/delta back. */
    network orig = net;

    /* Walk layers from last to first. */
    for (int idx = net.n - 1; idx >= 0; --idx) {
        layer cur = net.layers[idx];
        if (cur.stopbackward) {
            break;
        }

        if (idx > 0) {
            /* Intermediate layer: its input/delta are the previous layer's output/delta. */
            layer below = net.layers[idx - 1];
            net.input = below.output;
            net.delta = below.delta;
        } else {
            /* First layer: restore the network state the caller passed in. */
            net = orig;
        }

        net.index = idx;
        cur.backward(cur, net);
    }
}
## update_network(...)
更新参数:按当前学习率(乘以各层的 `learning_rate_scale`)调用每一层的 `update`。
void update_network(network net)
{
    /* Batch value handed to every layer's update(): batch * subdivisions. */
    const int update_batch = net.batch * net.subdivisions;
    /* Network-wide rate, scaled per layer by learning_rate_scale below. */
    const float rate = get_current_rate(net);

    for (int idx = 0; idx < net.n; ++idx) {
        layer cur = net.layers[idx];
        /* Layers without an update callback are skipped. */
        if (!cur.update) {
            continue;
        }
        cur.update(cur, update_batch, rate * cur.learning_rate_scale,
                   net.momentum, net.decay);
    }
}
网友评论