network

作者: 陈继科 | 来源:发表于2017-05-04 15:08 被阅读214次

    和 layer 一样，network 也是该深度学习框架中的重要数据结构。

    其成员定义如下（每行依次为类型、名称，部分附有含义说明）：

    // Members of the network struct (the enclosing struct declaration is not shown in this excerpt).
    int n;// number of layers in `layers`; forward/backward/update iterate i over [0, n)
    int batch;// batch size; update_network passes batch*subdivisions to each layer's update
    int *seen;// NOTE(review): presumably a running count of samples seen so far — confirm
    float epoch;// NOTE(review): presumably the current training epoch — confirm
    int subdivisions;// multiplied with batch to form the update batch in update_network
    float momentum;// passed unchanged to each layer's update()
    float decay;// passed unchanged to each layer's update()
    layer *layers;// the n layers: run first-to-last in forward_network, last-to-first in backward_network
    float *output;// NOTE(review): presumably the final network output — confirm
    learning_rate_policy policy;// learning-rate schedule; presumably consumed by get_current_rate — confirm

    float learning_rate;// base learning rate; update_network uses get_current_rate(net), presumably derived from this
    float gamma;// NOTE(review): presumably a learning-rate policy parameter — confirm
    float scale;// NOTE(review): presumably a learning-rate policy parameter — confirm
    float power;// NOTE(review): presumably a learning-rate policy parameter — confirm
    int time_steps;
    int step;// NOTE(review): presumably a learning-rate policy parameter — confirm
    int max_batches;
    float *scales;// NOTE(review): presumably paired with steps for a step-wise schedule — confirm
    int   *steps;
    int num_steps;
    int burn_in;
    
    int adam;// NOTE(review): presumably non-zero selects Adam, with B1/B2/eps as its hyper-parameters — confirm
    float B1;
    float B2;
    float eps;
    
    // NOTE(review): the fields below presumably describe input/output sizes and
    // data-augmentation settings — confirm against the code that reads them.
    int inputs;
    int outputs;
    int truths;
    int notruth;
    int h, w, c;
    int max_crop;
    int min_crop;
    int center;
    float angle;
    float aspect;
    float exposure;
    float saturation;
    float hue;
    
    int gpu_index;
    tree *hierarchy;
    
    
    
    float *input;// current layer input: forward_network sets it to each layer's output, backward_network to the previous layer's output
    float *truth;// forward_network points this at the output of any layer whose truth flag is set
    float *delta;// backward_network points this at the previous layer's delta before calling backward
    float *workspace;
    int train;
    int index;// index of the layer currently being run (set in forward_network and backward_network)
    float *cost;
    
    #ifdef GPU
    // Present only when compiled with GPU defined; presumably device-side
    // counterparts of input/truth/delta/output — confirm.
    float *input_gpu;
    float *truth_gpu;
    float *delta_gpu;
    float *output_gpu;
    #endif
    

    围绕 network 提供了许多重要函数，声明如下：
    // Network-level API declarations. Most functions take `network` by value
    // (a struct copy); resize_network and set_batch_network take a pointer.
    float get_current_rate(network net);// update_network uses its result as the base learning rate
    int get_current_batch(network net);
    void free_network(network net);
    void compare_networks(network n1, network n2, data d);
    char *get_layer_string(LAYER_TYPE a);

    network make_network(int n);
    void forward_network(network net);// runs each layer's forward in order, then calc_network_cost
    void backward_network(network net);// runs layer backward passes from last to first (stops at a stopbackward layer)
    void update_network(network net);// calls each layer's update hook with the current rate, momentum and decay

    float train_network(network net, data d);
    float train_network_sgd(network net, data d, int n);
    float train_network_datum(network net);

    matrix network_predict_data(network net, data test);
    float *network_predict(network net, float *input);
    float network_accuracy(network net, data d);
    float *network_accuracies(network net, data d, int n);
    float network_accuracy_multi(network net, data d, int n);
    void top_predictions(network net, int n, int *index);
    image get_network_image(network net);
    image get_network_image_layer(network net, int i);
    layer get_network_output_layer(network net);
    int get_predicted_class_network(network net);
    void print_network(network net);
    void visualize_network(network net);
    int resize_network(network *net, int w, int h);// pointer parameter: presumably modifies the network in place — confirm
    void set_batch_network(network *net, int b);// pointer parameter: presumably modifies the network in place — confirm
    network load_network(char *cfg, char *weights, int clear);
    load_args get_base_args(network net);
    void calc_network_cost(network net);// called by forward_network after the last layer
    重点先看下面三个函数，有助于理解代码：

    ## forward_network(...)

    前向传播：按顺序执行每一层的 forward，并把上一层的输出作为下一层的输入，最后计算网络的 cost。

    /* Forward pass: push net.input through every layer in order, then
     * compute the network cost. `net` is received by value, so the
     * input/truth/index changes made here affect only this local copy
     * (and the copies handed to each layer's forward function). */
    void forward_network(network net)
    {
        int idx = 0;
        while (idx < net.n) {
            layer cur = net.layers[idx];
            net.index = idx;
            /* Reset this layer's delta buffer (outputs * batch entries) to 0. */
            if (cur.delta) fill_cpu(cur.outputs * cur.batch, 0, cur.delta, 1);
            cur.forward(cur, net);
            /* This layer's output feeds the next layer. */
            net.input = cur.output;
            /* Layers flagged as truth expose their output as net.truth. */
            if (cur.truth) net.truth = cur.output;
            ++idx;
        }
        calc_network_cost(net);
    }
    
    

    ## backward_network(...)

    反向传播（BP）：从最后一层到第一层依次调用各层的 backward 计算梯度；遇到设置了 stopbackward 的层时停止。

    /* Backward pass: visit layers from last to first. Before each layer's
     * backward call, net.input / net.delta are pointed at the previous
     * layer's output / delta buffers. Layer 0 has no predecessor, so it
     * gets the network state exactly as it was passed in. The walk stops
     * at the first layer (from the end) whose stopbackward flag is set. */
    void backward_network(network net)
    {
        network saved = net; /* pristine copy, handed to layer 0 */
        int k;
        for (k = net.n - 1; k >= 0; --k) {
            layer cur = net.layers[k];
            if (cur.stopbackward) break;
            if (k > 0) {
                layer before = net.layers[k - 1];
                net.input = before.output;
                net.delta = before.delta;
            } else {
                net = saved;
            }
            net.index = k;
            cur.backward(cur, net);
        }
    }
    

    ## update_network(...)

    更新参数（parameters）：对每个提供了 update 函数的层，用当前学习率（乘以该层的 learning_rate_scale）、momentum 和 decay 更新其参数。

    /* Parameter update: call the update hook of every layer that has one.
     * The batch size reported to each layer is batch * subdivisions, and
     * each layer scales the current global rate by its own
     * learning_rate_scale. Momentum and decay are passed through as-is. */
    void update_network(network net)
    {
        const float base_rate = get_current_rate(net);
        const int total_batch = net.batch * net.subdivisions;
        int j;
        for (j = 0; j < net.n; ++j) {
            layer cur = net.layers[j];
            if (!cur.update) continue; /* layer has no trainable state */
            cur.update(cur, total_batch, base_rate * cur.learning_rate_scale,
                       net.momentum, net.decay);
        }
    }
    

    相关文章

      网友评论

          本文标题:network

          本文链接:https://www.haomeiwen.com/subject/zhgrtxtx.html