经过了数据准备阶段,得到roidb和 imdb,下面利用得到的数据roidb进入网络的训练阶段:
model_paths = train_net(solver, roidb, output_dir,pretrained_model=init_model,max_iters=max_iters)
# 假设:net_name=‘ZF’,那么solver:ZF_faster_rcnn_alt_opt_stage1_rpn_solver60k80k.pt,
# max_iters:80000
进入train_net函数:
首先,使用filter_roidb函数对roidb进行过滤,过滤掉无效的图片数据。
然后,创建 SolverWrapper类:
sw = SolverWrapper(solver_prototxt, roidb, output_dir,pretrained_model=pretrained_model)
进入 SolverWrapper类的初始化函数:
注意,cfg.TRAIN.BBOX_REG =False。所以,这两个if语句都不执行。
self.solver = caffe.SGDSolver(solver_prototxt):利用stage1_rpn_solver60k80k.pt文件初始化网络的优化器
self.solver.net.layers[0].set_roidb(roidb):根据roi_data_layer.layer文件中的set_roidb函数来对roidb进行随机打乱
回到train_net函数,执行:model_paths = sw.train_model(max_iters),返回值是一个列表。
进入 sw.train_model函数:
每200次打印一次结果。
每10000次保存一次网络,注意,在stage1_rpn_solver60k80k.pt文件中,snapshot:0,因此,这里没用使用caffe自带的snapshot来保存网络结果,而是用的自己定义的snapshot。
进入snapshot,发现这个函数返回值是一个:filename文件(保存的网络的绝对路径),因此model_paths 返回的结果是一个filename文件的列表。这也是train_net函数的返回结果。
回到train_rpn函数中:
将得到的model_paths 列表中的元组只保留最后一个,其余的全部移除,也就是只保留最新的那个网络结果,然后把这个结果以字典的形式推入进程队列中。
最后,回到:p = mp.Process(target=train_rpn,kwargs=mp_kwargs),注意,这里的p只是创建进程,接下来,我们启动进程:p.start(),从进程队列中取出刚才的字典的value:rpn_stage1_out = mp_queue.get(),然后等待进程结束:p.join()
这样,我们就得到了训练的rnp网络:rpn_stage1_out,然后把它用在下一步中。