Caffe代码导读(3):caffe的模型训练流程

caffe train 整体流程图

程序入口:main()

int main(int argc, char** argv) {
     .....
     return GetBrewFunction(caffe::string(argv[1]))();
     ....
}

g_brew_map实现过程,首先通过 typedef定义函数指针 typedef int (*BrewFunction)(); 这个是用typedef定义函数指针方法。这个程序定义一个BrewFunction函数指针类型,在caffe.cpp 中 BrewFunction 作为GetBrewFunction()函数的返回类型,可以是 train(),test(),device_query(),time() 这四个函数指针的其中一个。在train(),test(),中可以调用solver类的函数,从而进入到net,进入到每一层,运行整个caffe程序。然后对每个函数注册。

1. RegisterBrewFunction(train)
2. RegisterBrewFunction(test)
3. RegisterBrewFunction(device_query)
4. RegisterBrewFunction(time)
    train: 训练或者调整一个模型
    test : 在测试集上测试一个模型
    device_query : 打印GPU的调试信息
    time: 压测一个模型的执行时间

如果需要,可以增加其他的方式,然后通过RegisterBrewFunction()函数注册一下即可。

调用train()函数

接着调用train()函数,train函数中主要有三个方法ReadSolverParamsFromTextFileOrDie、CreateSolver、Solve。

// Train / Finetune a model.
int train() {
  ......
  caffe::SolverParameter solver_param;
  caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param);//从-solver参数读取solver_param
  ......
  shared_ptr<caffe::Solver<float> >
      solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));//从参数创建solver,同样采用string到函数指针的映射实现,用到了工厂模式

  if (FLAGS_snapshot.size()) {//迭代snapshot次后保存模型一次
    LOG(INFO) << "Resuming from " << FLAGS_snapshot;
    solver->Restore(FLAGS_snapshot.c_str());
  } else if (FLAGS_weights.size()) {//若采用finetuning,则拷贝weight到指定模型
    CopyLayers(solver.get(), FLAGS_weights);
  }

  if (gpus.size() > 1) {
    caffe::P2PSync<float> sync(solver, NULL, solver->param());
    sync.Run(gpus);
  } else {
    LOG(INFO) << "Starting Optimization";
    solver->Solve();//开始训练网络
  }
  LOG(INFO) << "Optimization Done.";
  return 0;
}

ReadSolverParamsFromTextFileOrDie

caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param)解析-solver指定的solver.prototxt的文件内容到solver_param中

CreateSolver

CreateSolver函数构建solver和net,该函数是初始化的入口,会通过执行Solver的构造函数,调用 void Solver<Dtype>::Init(const SolverParameter& param),该函数内有InitTrainNet()、InitTestNets()。对于InitTrainNet函数:

......
net_.reset(new Net<Dtype>(net_param));

调用Net类的构造函数,然后执行Init()操作,该函数具体的内容如下图和源码所示:

template <typename Dtype>
void Net<Dtype>::Init(const NetParameter& in_param) {
  ........//过滤校验参数FilterNet
  FilterNet(in_param, &filtered_param);
  .........//插入Splits层
  InsertSplits(filtered_param, &param);
  .......// 构建网络中输入输出存储结构
  bottom_vecs_.resize(param.layer_size());
  top_vecs_.resize(param.layer_size());
  bottom_id_vecs_.resize(param.layer_size());
  param_id_vecs_.resize(param.layer_size());
  top_id_vecs_.resize(param.layer_size());
  bottom_need_backward_.resize(param.layer_size());

  for (int layer_id = 0; layer_id < param.layer_size(); ++layer_id) {
   ...//创建层
 layers_.push_back(LayerRegistry<Dtype>::CreateLayer(layer_param));
    layer_names_.push_back(layer_param.name());
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating Layer " << layer_param.name();
    bool need_backward = false;

    // Figure out this layer's input and output
    for (int bottom_id = 0; bottom_id < layer_param.bottom_size();
         ++bottom_id) {
      const int blob_id = AppendBottom(param, layer_id, bottom_id,
                                       &available_blobs, &blob_name_to_idx);


   ........//创建相关blob
    // If the layer specifies that AutoTopBlobs() -> true and the LayerParameter
    // specified fewer than the required number (as specified by
    // ExactNumTopBlobs() or MinTopBlobs()), allocate them here.
    Layer<Dtype>* layer = layers_[layer_id].get();
    if (layer->AutoTopBlobs()) {
      const int needed_num_top =
          std::max(layer->MinTopBlobs(), layer->ExactNumTopBlobs());
      for (; num_top < needed_num_top; ++num_top) {
        // Add "anonymous" top blobs -- do not modify available_blobs or
        // blob_name_to_idx as we don't want these blobs to be usable as input
        // to other layers.
        AppendTop(param, layer_id, num_top, NULL, NULL);
      }
    }


    .....//执行SetUp()
    // After this layer is connected, set it up.
    layers_[layer_id]->SetUp(bottom_vecs_[layer_id], top_vecs_[layer_id]);
    LOG_IF(INFO, Caffe::root_solver())
        << "Setting up " << layer_names_[layer_id];
    for (int top_id = 0; top_id < top_vecs_[layer_id].size(); ++top_id) {
      if (blob_loss_weights_.size() <= top_id_vecs_[layer_id][top_id]) {
        blob_loss_weights_.resize(top_id_vecs_[layer_id][top_id] + 1, Dtype(0));
      }
      blob_loss_weights_[top_id_vecs_[layer_id][top_id]] = layer->loss(top_id);
      LOG_IF(INFO, Caffe::root_solver())
          << "Top shape: " << top_vecs_[layer_id][top_id]->shape_string();
      if (layer->loss(top_id)) {
        LOG_IF(INFO, Caffe::root_solver())
            << "    with loss weight " << layer->loss(top_id);
      }
      memory_used_ += top_vecs_[layer_id][top_id]->count();
    }
    LOG_IF(INFO, Caffe::root_solver())
        << "Memory required for data: " << memory_used_ * sizeof(Dtype);
    const int param_size = layer_param.param_size();
    const int num_param_blobs = layers_[layer_id]->blobs().size();
    CHECK_LE(param_size, num_param_blobs)
        << "Too many params specified for layer " <<

Net::Init()

SetUp是怎么构建的呢?

virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {}

 void SetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {
    InitMutex();
    CheckBlobCounts(bottom, top);
    LayerSetUp(bottom, top);
    Reshape(bottom, top);
    SetLossWeights(top);
  }

初始化的总体流程大概就是新建一个Solver对象,然后调用Solver类的构造函数,然后在Solver的构造函数中又会新建Net类实例,在Net类的构造函数中又会新建各个layer的实例,一直具体到设置每个Blob,大概就完成了网络初始化的工作了。

Solve

train函数中CreateSolver()执行完成后,接下来是具体训练过程,执行Solve()函数---->Step()--->结束

Solve的具体内容和代码:

template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {
  CHECK(Caffe::root_solver());
  LOG(INFO) << "Solving " << net_->name();
  LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();

  // For a network that is trained by the solver, no bottom or top vecs
  // should be given, and we will just provide dummy vecs.
  int start_iter = iter_;
  Step(param_.max_iter() - iter_);

  // overridden by setting snapshot_after_train := false
  if (param_.snapshot_after_train()
      && (!param_.snapshot() || iter_ % param_.snapshot() != 0)) {
    Snapshot();
  }

  // display loss
  if (param_.display() && iter_ % param_.display() == 0) {
    int average_loss = this->param_.average_loss();
    Dtype loss;
    net_->Forward(&loss);

    UpdateSmoothedLoss(loss, start_iter, average_loss);


  if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
    TestAll();
  }
}

然后开始执行Step函数,具体内容和代码:

template <typename Dtype>  
void Solver<Dtype>::Step(int iters)  
{  
    // 起始迭代步数  
    const int start_iter = iter_;  
    // 终止迭代步数  
    const int stop_iter = iter_ + iters;  

    // 判断是否已经完成设定步数  
    while (iter_ < stop_iter)  
    {  
        // 将net_中的Bolb梯度参数置为零  
        net_->ClearParamDiffs();  

        ...  

        // accumulate the loss and gradient  
        Dtype loss = 0;  
        for (int i = 0; i < param_.iter_size(); ++i)  
        {  
            // 正向传导和反向传导,并计算loss  
            loss += net_->ForwardBackward();  
        }  
        loss /= param_.iter_size();  

        // 为了输出结果平滑,将临近的average_loss个loss数值进行平均,存储在成员变量smoothed_loss_中  
        UpdateSmoothedLoss(loss, start_iter, average_loss);  

        // BP算法更新权重  
        ApplyUpdate();  

        // Increment the internal iter_ counter -- its value should always indicate  
        // the number of times the weights have been updated.  
        ++iter_;  
    }  
}

while循环中先调用了网络类Net::ForwardBackward()成员函数进行正向传导和反向传导,并计算loss

Dtype ForwardBackward() {
    Dtype loss;
    //正向传导
    Forward(&loss);
    //反向传导
    Backward();
    return loss;
  }

而Fordward函数中调用了ForwardFromTo,而FordwardFromTo又调用了每个layer的Fordward。反向传导函数Backward()调用了BackwardFromTo(int start, int end)函数。正向传导和反向传导结束后,再调用SGDSolver::ApplyUpdate()成员函数进行权重更新。

  • ForwardBackward:按顺序调用了Forward和Backward。
  • ForwardFromTo(int start, int end):执行从start层到end层的前向传递,采用简单的for循环调用。,forward只要计算损失loss
  • BackwardFromTo(int start, int end):和前面的ForwardFromTo函数类似,调用从start层到end层的反向传递。backward主要根据loss来计算梯度,caffe通过自动求导并反向组合每一层的梯度来计算整个网络的梯度。
  • ToProto函数完成网络的序列化到文件,循环调用了每个层的ToProto函数
template <typename Dtype>  
void SGDSolver<Dtype>::ApplyUpdate()  
{  
    // 获取当前学习速率  
    Dtype rate = GetLearningRate();  
    if (this->param_.display() && this->iter_ % this->param_.display() == 0)  
    {  
        LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;  
    }  

    // 在计算当前梯度的时候,如果该值超过了阈值clip_gradients,则将梯度直接设置为该阈值  
    // 此处阈值设为-1,即不起作用  
    ClipGradients();  

    // 逐层更新网络中的可学习层  
    for (int param_id = 0; param_id < this->net_->learnable_params().size();  
       ++param_id)  
    {  
        // 归一化  
        Normalize(param_id);  
        // L2范数正则化添加衰减权重  
        Regularize(param_id);  
        // 随机梯度下降法计算更新值  
        ComputeUpdateValue(param_id, rate);  
    }  
    // 更新权重  
    this->net_->Update();  
}

ApplyUpdate

最后将迭代次数++iter_,继续while循环,直到迭代次数完成。 这就是整个网络的训练过程。

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 204,445评论 6 478
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,889评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 151,047评论 0 337
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,760评论 1 276
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,745评论 5 367
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,638评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,011评论 3 398
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,669评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,923评论 1 299
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,655评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,740评论 1 330
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,406评论 4 320
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,995评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,961评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,197评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,023评论 2 350
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,483评论 2 342

推荐阅读更多精彩内容