object_detectionAPI源码阅读笔记(3-train.py)

本文是以Faster RCNN为脉络进行分析。 SDD等类似吧!!! 我还没看。作为一个菜鸟,阅读代码一般是从第一个文件开始看。在我的思维里,Faster RCNN是从CNN等基层框架中抽取feature map进行检测,所以就想在train.py和trainer.py中想找到loss和输出等函数,好像没有。看到显性的损失函数,或者输出。
So.....,还是先看train.py

在train.py的导入文件有如下:

train.py

#train.py
import functools
import json
import os
import tensorflow as tf
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import trainer
from object_detection.builders import input_reader_builder
from object_detection.builders import model_builder
from object_detection.utils import config_util

这里的model_builder引起了我的注意,那就跳到model_builder.py文件吧。model_builder.py这里导入了很多内容

model_builder.py

#model_builder.py
from object_detection.builders import anchor_generator_builder
from object_detection.builders import box_coder_builder
from object_detection.builders import box_predictor_builder
from object_detection.builders import hyperparams_builder
from object_detection.builders import image_resizer_builder
from object_detection.builders import losses_builder
from object_detection.builders import matcher_builder
from object_detection.builders import post_processing_builder
from object_detection.builders import region_similarity_calculator_builder as sim_calc
from object_detection.core import box_predictor
from object_detection.meta_architectures import faster_rcnn_meta_arch
from object_detection.meta_architectures import rfcn_meta_arch
from object_detection.meta_architectures import ssd_meta_arch
from object_detection.models import faster_rcnn_inception_resnet_v2_feature_extractor as frcnn_inc_res
from object_detection.models import faster_rcnn_inception_v2_feature_extractor as frcnn_inc_v2
from object_detection.models import faster_rcnn_nas_feature_extractor as frcnn_nas
from object_detection.models import faster_rcnn_resnet_v1_feature_extractor as frcnn_resnet_v1
from object_detection.models.embedded_ssd_mobilenet_v1_feature_extractor import EmbeddedSSDMobileNetV1FeatureExtractor
from object_detection.models.ssd_inception_v2_feature_extractor import SSDInceptionV2FeatureExtractor
from object_detection.models.ssd_inception_v3_feature_extractor import SSDInceptionV3FeatureExtractor
from object_detection.models.ssd_mobilenet_v1_feature_extractor import SSDMobileNetV1FeatureExtractor
from object_detection.protos import model_pb2

这里的faster_rcnn_inception_resnet_v2_feature_extractor在说明文档里有提到的。请看object_detectionAPI源码阅读笔记
So...faster_rcnn_inception_resnet_v2_feature_extractor.py就是我要的啊,在配置文档里有提到,这个是进行特这提取的。也是DetectionModels (object_detection/core/model.py)的子类。为什么是子类接下来会有说明。

#model_builder.py
from object_detection.models import faster_rcnn_inception_resnet_v2_feature_extractor as frcnn_inc_res
from object_detection.models import faster_rcnn_inception_v2_feature_extractor as frcnn_inc_v2
from object_detection.models import faster_rcnn_nas_feature_extractor as frcnn_nas
from object_detection.models import faster_rcnn_resnet_v1_feature_extractor as frcnn_resnet_v1
from object_detection.models.embedded_ssd_mobilenet_v1_feature_extractor import EmbeddedSSDMobileNetV1FeatureExtractor
from object_detection.models.ssd_inception_v2_feature_extractor import SSDInceptionV2FeatureExtractor
from object_detection.models.ssd_inception_v3_feature_extractor import SSDInceptionV3FeatureExtractor
from object_detection.models.ssd_mobilenet_v1_feature_extractor import SSDMobileNetV1FeatureExtractor

那么就从faster_rcnn_inception_v2_feature_extractor.py看导入的文件吧。

faster_rcnn_inception_v2_feature_extractor.py

##faster_rcnn_inception_v2_feature_extractor.py

import tensorflow as tf

from object_detection.meta_architectures import faster_rcnn_meta_arch
from nets import inception_resnet_v2

这里导入了inception_resnet_v2模型,我感觉老祖宗找到了,这里的基本CNN模型就是inception_resnet_v2模型,这是一个基本的CNN框架,nets还有很多基本的网络框架,包括vgg,alexnet等。

所以faster_rcnn_inception_v2_feature_extractor.py看样子就是对基本框架进行提取的文件了。

但是这里导入了faster_rcnn_meta_arch
我们看看faster_rcnn_meta_arch.py

faster_rcnn_meta_arch.py

#faster_rcnn_meta_arch.py
from abc import abstractmethod
from functools import partial
import tensorflow as tf

from object_detection.anchor_generators import grid_anchor_generator
from object_detection.core import balanced_positive_negative_sampler as sampler
from object_detection.core import box_list
from object_detection.core import box_list_ops
from object_detection.core import box_predictor
from object_detection.core import losses
from object_detection.core import model
from object_detection.core import post_processing
from object_detection.core import standard_fields as fields
from object_detection.core import target_assigner
from object_detection.utils import ops
from object_detection.utils import shape_utils

这里找到了model.py这可是所有检测模型的基类

model.py

#model.py
from abc import ABCMeta
from abc import abstractmethod

from object_detection.core import standard_fields as fields

发现果然是基类,没导入什么内容,这也算是检测模型的老祖宗了,简陋到不行啊。

class DetectionModel(object):
  """Abstract base class for detection models."""
  __metaclass__ = ABCMeta

  def __init__(self, num_classes):
    """Constructor.

发现基类DetectionModel(object)就是在这里个文件实现的,
基类的功能就是:如下
inputs (images tensor) -> preprocess -> predict -> loss ->outputs (loss tensor)

1.目录脉络 train.py:
1.model.py -> faster_rcnn_meta_arch.py ->faster_rcnn_inception_v2_feature_extractor.py
2.inception_resnet_v2 ->faster_rcnn_inception_v2_feature_extractor.py ->model_builder.py
3.model_builder.py -> train.py

2.目录脉络 eval.py:
1.model.py -> faster_rcnn_meta_arch.py ->faster_rcnn_inception_v2_feature_extractor.py
2.inception_resnet_v2 ->faster_rcnn_inception_v2_feature_extractor.py ->model_builder.py
3.model_builder.py -> eval.py
这里有一张haixwang的图

trianer.py

def _create_losses(input_queue, create_model_fn, train_config):
    """Creates loss function for a DetectionModel.

    Args:
    input_queue: BatchQueue object holding enqueued tensor_dicts.
    create_model_fn: A function to create the DetectionModel.
    train_config: a train_pb2.TrainConfig protobuf.
    """
    
    # 创建一个检测模型
    detection_model = create_model_fn()
    
    # 读入数据 使用get_inputs()
    (images, _, groundtruth_boxes_list, groundtruth_classes_list,
    groundtruth_masks_list, groundtruth_keypoints_list) = get_inputs(
       input_queue,
       detection_model.num_classes,
       train_config.merge_multiple_label_boxes)
    
    # 对数据进行归一化
    images = [detection_model.preprocess(image) for image in images]
    images = tf.concat(images, 0)
    if any(mask is None for mask in groundtruth_masks_list):
    groundtruth_masks_list = None
    if any(keypoints is None for keypoints in groundtruth_keypoints_list):
    groundtruth_keypoints_list = None
    
    # 获取真实标签数据
    detection_model.provide_groundtruth(groundtruth_boxes_list,
                                      groundtruth_classes_list,
                                      groundtruth_masks_list,
                                      groundtruth_keypoints_list)
                                     
    # 进行预测吧 
    prediction_dict = detection_model.predict(images)

    # 产生损失
    losses_dict = detection_model.loss(prediction_dict)
    for loss_tensor in losses_dict.values():
    tf.losses.add_loss(loss_tensor)

这里的_create_losses()产生了loss会被送入到train进行优化训练。

  • train()

def train(create_tensor_dict_fn, create_model_fn, train_config, master, task,
          num_clones, worker_replicas, clone_on_cpu, ps_tasks, worker_job_name,
          is_chief, train_dir):
  """Training function for detection models.

  Args:
    create_tensor_dict_fn: 创建输入张量函数
    create_model_fn:a function that creates a DetectionModel and generates losses.(创建一个损失函数) 
    train_config: 训练配置文件
    master: 分布式训练设别的名字
    task: The task id of this training instance.
    num_clones: The number of clones to run per machine.
    worker_replicas: The number of work replicas to train with.
    clone_on_cpu: True if clones should be forced to run on CPU.
    ps_tasks: Number of parameter server tasks.
    worker_job_name: Name of the worker job.
    is_chief: Whether this replica is the chief replica.
    train_dir: 训练文件的保存目录
  """

到这里,估计差不多了训练流程走的差不都了,这里实现的trainer.train()是最后的配置。我也看到loss了(参数create_model_fn:a function that creates a DetectionModel and generates losses)

eval.py 中用到DetectionModel

预测总体过程:inputs (images tensor) -> preprocess -> predict -> postprocess -> outputs (boxes tensor, scores tensor, classes tensor, num_detections tensor)
eval.py 中导入如下包

# eval.py 
import functools
import os
import tensorflow as tf

import evaluator
from object_detection.builders import input_reader_builder
from object_detection.builders import model_builder
from object_detection.utils import config_util
from object_detection.utils import label_map_util

与train类似,其中的evaluator才是DetectionModel真正使用者。

  • _extract_prediction_tensors()

def _extract_prediction_tensors(model,
                                create_input_dict_fn,
                                ignore_groundtruth=False):
  """Restores the model in a tensorflow session.

  Args:
    model: model to perform predictions with.
    create_input_dict_fn: function to create input tensor dictionaries.
    ignore_groundtruth: whether groundtruth should be ignored.

  Returns:
    tensor_dict: A tensor dictionary with evaluations.
  """
  # 创建数据输入队列
  input_dict = create_input_dict_fn()
  prefetch_queue = prefetcher.prefetch(input_dict, capacity=500)
  input_dict = prefetch_queue.dequeue()
  original_image = tf.expand_dims(input_dict[fields.InputDataFields.image], 0)
  
  # 创建检测模型
  preprocessed_image = model.preprocess(tf.to_float(original_image))
  
  # 进行预测
  prediction_dict = model.predict(preprocessed_image)
  
  # 进行后处理
  detections = model.postprocess(prediction_dict)

  # 获取这是标签
  groundtruth = None
  if not ignore_groundtruth:
    groundtruth = {
        fields.InputDataFields.groundtruth_boxes:
            input_dict[fields.InputDataFields.groundtruth_boxes],
        fields.InputDataFields.groundtruth_classes:
            input_dict[fields.InputDataFields.groundtruth_classes],
        fields.InputDataFields.groundtruth_area:
            input_dict[fields.InputDataFields.groundtruth_area],
        fields.InputDataFields.groundtruth_is_crowd:
            input_dict[fields.InputDataFields.groundtruth_is_crowd],
        fields.InputDataFields.groundtruth_difficult:
            input_dict[fields.InputDataFields.groundtruth_difficult]
    }
    if fields.InputDataFields.groundtruth_group_of in input_dict:
      groundtruth[fields.InputDataFields.groundtruth_group_of] = (
          input_dict[fields.InputDataFields.groundtruth_group_of])
    if fields.DetectionResultFields.detection_masks in detections:
      groundtruth[fields.InputDataFields.groundtruth_instance_masks] = (
          input_dict[fields.InputDataFields.groundtruth_instance_masks])

  return eval_util.result_dict_for_single_example(
      original_image,
      input_dict[fields.InputDataFields.source_id],
      detections,
      groundtruth,
      class_agnostic=(
          fields.DetectionResultFields.detection_classes not in detections),
      scale_to_absolute=True)

这里关于检测模型的详细内容请继续阅读吧!!!

参考:
TensorFlow Object Detection API 源码(1) DetectionModel

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 199,902评论 5 468
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 84,037评论 2 377
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 146,978评论 0 332
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 53,867评论 1 272
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 62,763评论 5 360
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,104评论 1 277
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,565评论 3 390
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,236评论 0 254
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,379评论 1 294
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,313评论 2 317
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,363评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,034评论 3 315
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,637评论 3 303
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,719评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 30,952评论 1 255
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 42,371评论 2 346
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 41,948评论 2 341

推荐阅读更多精彩内容