Part 1: Walking Through the Entire Pipeline on the COCO Dataset
Program entry point:
Single-GPU training:
python tools/train.py ${CONFIG_FILE}
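For example, assuming the Faster R-CNN config that ships with the v1.x repository (the file name is illustrative and may differ in your checkout):
python tools/train.py configs/faster_rcnn_r50_fpn_1x.py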
The core of train.py is as follows:
def main():
    # 1. Read the config file
    cfg = Config.fromfile(args.config)
    # 2. Build the model
    model = build_detector(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    # 3. Build the dataset
    datasets = [build_dataset(cfg.data.train)]
    # 4. Start training
    train_detector(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=args.validate,
        timestamp=timestamp,
        meta=meta)
1. Reading the config file
2. Building the model
3. Building the dataset
cfg.data.train provides four parameters:
data = dict(
    ...
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        pipeline=train_pipeline),
    ...)
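Here dataset_type and data_root are plain variables defined near the top of the same config file. For COCO they are typically set as follows (values taken from the stock COCO configs; adjust data_root to your local layout):
dataset_type = 'CocoDataset'   # must match the name registered in DATASETS
data_root = 'data/coco/'       # local root directory of the COCO dataset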
build_dataset() is located in mmdet/datasets/builder.py and is defined as follows:
def build_dataset(cfg, default_args=None):
    if isinstance(cfg, (list, tuple)):
        dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
    elif cfg['type'] == 'RepeatDataset':
        dataset = RepeatDataset(
            build_dataset(cfg['dataset'], default_args), cfg['times'])
    elif isinstance(cfg.get('ann_file'), (list, tuple)):
        dataset = _concat_dataset(cfg, default_args)
    else:
        dataset = build_from_cfg(cfg, DATASETS, default_args)  # assume this branch is taken
    return dataset
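As a side note, the RepeatDataset branch above is taken when the config wraps the training set so that it is traversed several times per epoch. In that case the train entry of data would look roughly like this (an illustrative snippet; the repeat factor is arbitrary):
    train=dict(
        type='RepeatDataset',
        times=3,                      # illustrative repeat factor
        dataset=dict(
            type=dataset_type,
            ann_file=data_root + 'annotations/instances_train2017.json',
            img_prefix=data_root + 'train2017/',
            pipeline=train_pipeline)),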
Here, DATASETS is a Registry() instance:
DATASETS = Registry('dataset')
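Conceptually, a registry is just a mapping from class names to classes: the @DATASETS.register_module decorator adds an entry, and registry.get() looks it up later. Below is a minimal, simplified sketch of that idea (not the actual mmcv/mmdet implementation; the class and names are stand-ins):
class SimpleRegistry(object):
    """Simplified sketch of a name -> class registry."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def register_module(self, cls):
        # used as a decorator: @DATASETS.register_module
        self._module_dict[cls.__name__] = cls
        return cls

    def get(self, key):
        return self._module_dict.get(key)


DATASETS = SimpleRegistry('dataset')


@DATASETS.register_module
class MyDataset(object):
    pass


print(DATASETS.get('MyDataset'))  # <class '__main__.MyDataset'>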
build_from_cfg() is located in mmdet/utils/registry.py and is defined as follows:
def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        obj: The constructed object.
    """
    assert isinstance(cfg, dict) and 'type' in cfg
    assert isinstance(default_args, dict) or default_args is None
    args = cfg.copy()
    obj_type = args.pop('type')
    if mmcv.is_str(obj_type):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError('{} is not in the {} registry'.format(
                obj_type, registry.name))
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError('type must be a str or valid type, but got {}'.format(
            type(obj_type)))
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    return obj_cls(**args)
What is ultimately returned is an instance of the looked-up class, constructed with **args (i.e. the remaining keys of the config dict plus any default_args).
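For the cfg.data.train dict shown earlier, the call therefore boils down to something like the following (an illustrative equivalence, not extra code in the repo; the keyword values are the ones from the config):
# Equivalent to build_from_cfg(cfg.data.train, DATASETS):
# 'type' is popped to look up the class, the remaining keys become kwargs.
dataset = CocoDataset(
    ann_file=data_root + 'annotations/instances_train2017.json',
    img_prefix=data_root + 'train2017/',
    pipeline=train_pipeline)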
The reason the code above works is that the CocoDataset class has already been registered in mmdet/datasets/coco.py:
import logging
import os.path as osp
import tempfile
import mmcv
import numpy as np
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from mmdet.core import eval_recalls
from mmdet.utils import print_log
from .custom import CustomDataset
from .registry import DATASETS
@DATASETS.register_module
class CocoDataset(CustomDataset):
    CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
               'train', 'truck', 'boat', 'traffic_light', 'fire_hydrant',
               'stop_sign', 'parking_meter', 'bench', 'bird', 'cat', 'dog',
               'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
               'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
               'skis', 'snowboard', 'sports_ball', 'kite', 'baseball_bat',
               'baseball_glove', 'skateboard', 'surfboard', 'tennis_racket',
               'bottle', 'wine_glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
               'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
               'hot_dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
               'potted_plant', 'bed', 'dining_table', 'toilet', 'tv', 'laptop',
               'mouse', 'remote', 'keyboard', 'cell_phone', 'microwave',
               'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
               'vase', 'scissors', 'teddy_bear', 'hair_drier', 'toothbrush')

    ...
The CocoDataset class does not define an __init__() method of its own; that method is implemented in its parent class, CustomDataset:
import os.path as osp
import mmcv
import numpy as np
from torch.utils.data import Dataset
from mmdet.core import eval_map, eval_recalls
from .pipelines import Compose
from .registry import DATASETS
@DATASETS.register_module
class CustomDataset(Dataset):
    """Custom dataset for detection.

    Annotation format:
    [
        {
            'filename': 'a.jpg',
            'width': 1280,
            'height': 720,
            'ann': {
                'bboxes': <np.ndarray> (n, 4),
                'labels': <np.ndarray> (n, ),
                'bboxes_ignore': <np.ndarray> (k, 4), (optional field)
                'labels_ignore': <np.ndarray> (k, 4) (optional field)
            }
        },
        ...
    ]

    The `ann` field is optional for testing.
    """

    CLASSES = None

    def __init__(self,
                 ann_file,
                 pipeline,
                 data_root=None,
                 img_prefix='',
                 seg_prefix=None,
                 proposal_file=None,
                 test_mode=False,
                 filter_empty_gt=True):
        self.ann_file = ann_file
        self.data_root = data_root
        self.img_prefix = img_prefix
        self.seg_prefix = seg_prefix
        self.proposal_file = proposal_file
        self.test_mode = test_mode
        self.filter_empty_gt = filter_empty_gt
        ...
The CustomDataset class is registered in advance as well. So the arguments in return obj_cls(**args) are ultimately passed to CustomDataset.__init__(). However, most of the COCO-specific data-handling methods are overridden in mmdet/datasets/coco.py, so for anything that actually processes the dataset you still need to look at mmdet/datasets/coco.py.
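For orientation, one quick way to see which methods CocoDataset defines itself (as opposed to inheriting from CustomDataset) is to inspect its class dict; the exact list depends on the mmdet version, but it is expected to include methods such as load_annotations and get_ann_info:
from mmdet.datasets import CocoDataset

# Methods defined directly on CocoDataset (overrides/additions),
# not the ones inherited from CustomDataset.
own_methods = sorted(name for name, attr in vars(CocoDataset).items()
                     if callable(attr))
print(own_methods)
# expected to include e.g. 'load_annotations', 'get_ann_info',
# '_filter_imgs', '_parse_ann_info' (version-dependent)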
4. Starting training
train_detector() is defined in mmdet/apis/train.py; its source is as follows:
def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):
    logger = get_root_logger(cfg.log_level)

    # start training
    if distributed:
        _dist_train(
            model,
            dataset,
            cfg,
            validate=validate,
            logger=logger,
            timestamp=timestamp,
            meta=meta)
    else:
        _non_dist_train(
            model,
            dataset,
            cfg,
            validate=validate,
            logger=logger,
            timestamp=timestamp,
            meta=meta)
As you can see, the training procedure is split into a distributed branch and a non-distributed branch. The source of the non-distributed branch is as follows:
def _non_dist_train(model,
                    dataset,
                    cfg,
                    validate=False,
                    logger=None,
                    timestamp=None,
                    meta=None):
    if validate:
        raise NotImplementedError('Built-in validation is not implemented '
                                  'yet in not-distributed training. Use '
                                  'distributed training or test.py and '
                                  '*eval.py scripts instead.')
    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.imgs_per_gpu,
            cfg.data.workers_per_gpu,
            cfg.gpus,
            dist=False,
            seed=cfg.seed) for ds in dataset
    ]
    # put model on gpus
    model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()

    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)
    runner = Runner(
        model,
        batch_processor,
        optimizer,
        cfg.work_dir,
        logger=logger,
        meta=meta)
    # an ugly walkaround to make the .log and .log.json filenames the same
    runner.timestamp = timestamp
    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=False)
    else:
        optimizer_config = cfg.optimizer_config

    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
Here a data_loader is first built from the dataset. build_dataloader() is located in mmdet/datasets/loader/build_loader.py, and its contents are as follows:
def build_dataloader(dataset,
                     imgs_per_gpu,
                     workers_per_gpu,
                     num_gpus=1,
                     dist=True,
                     shuffle=True,
                     seed=None,
                     **kwargs):
    """Build PyTorch DataLoader.

    In distributed training, each GPU/process has a dataloader.
    In non-distributed training, there is only one dataloader for all GPUs.

    Args:
        dataset (Dataset): A PyTorch dataset.
        imgs_per_gpu (int): Number of images on each GPU, i.e., batch size of
            each GPU.
        workers_per_gpu (int): How many subprocesses to use for data loading
            for each GPU.
        num_gpus (int): Number of GPUs. Only used in non-distributed training.
        dist (bool): Distributed training/test or not. Default: True.
        shuffle (bool): Whether to shuffle the data at every epoch.
            Default: True.
        kwargs: any keyword argument to be used to initialize DataLoader

    Returns:
        DataLoader: A PyTorch dataloader.
    """
    rank, world_size = get_dist_info()
    if dist:
        # DistributedGroupSampler will definitely shuffle the data to satisfy
        # that images on each GPU are in the same group
        if shuffle:
            sampler = DistributedGroupSampler(dataset, imgs_per_gpu,
                                              world_size, rank)
        else:
            sampler = DistributedSampler(
                dataset, world_size, rank, shuffle=False)
        batch_size = imgs_per_gpu
        num_workers = workers_per_gpu
    else:
        sampler = GroupSampler(dataset, imgs_per_gpu) if shuffle else None
        batch_size = num_gpus * imgs_per_gpu
        num_workers = num_gpus * workers_per_gpu

    init_fn = partial(
        worker_init_fn, num_workers=num_workers, rank=rank,
        seed=seed) if seed is not None else None

    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        collate_fn=partial(collate, samples_per_gpu=imgs_per_gpu),
        pin_memory=False,
        worker_init_fn=init_fn,
        **kwargs)

    return data_loader
What is returned in the end is a standard PyTorch DataLoader.
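As a quick sanity check, iterating one batch from the returned loader looks roughly like this (the key names come from the Collect step of the default train_pipeline and may differ in a customized config; the values are wrapped in mmcv DataContainer objects by the custom collate function):
# build a loader for the dataset created earlier (non-distributed, 1 GPU)
data_loader = build_dataloader(
    dataset, imgs_per_gpu=2, workers_per_gpu=2, num_gpus=1, dist=False)

data_batch = next(iter(data_loader))
print(data_batch.keys())
# expected to be something like:
# dict_keys(['img', 'img_meta', 'gt_bboxes', 'gt_labels'])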
The training loop itself is implemented with the Runner class from the mmcv library, located at mmcv/mmcv/runner/runner.py. The Runner class takes the following constructor arguments:
class Runner(object):
    """A training helper for PyTorch.

    Args:
        model (:obj:`torch.nn.Module`): The model to be run.
        batch_processor (callable): A callable method that process a data
            batch. The interface of this method should be
            `batch_processor(model, data, train_mode) -> dict`
        optimizer (dict or :obj:`torch.optim.Optimizer`): If it is a dict,
            runner will construct an optimizer according to it.
        work_dir (str, optional): The working directory to save checkpoints
            and logs.
        log_level (int): Logging level.
        logger (:obj:`logging.Logger`): Custom logger. If `None`, use the
            default logger.
        meta (dict | None): A dict records some import information such as
            environment info and seed, which will be logged in logger hook.
    """

    def __init__(self,
                 model,
                 batch_processor,
                 optimizer=None,
                 work_dir=None,
                 log_level=logging.INFO,
                 logger=None,
                 meta=None):
Here, batch_processor is responsible for running the forward pass of the model and packaging the losses into the expected output dict.
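In mmdet/apis/train.py the batch_processor that gets passed in looks roughly like the following (simplified here; parse_losses is a helper in the same file that sums the individual loss terms and collects the values to log):
def batch_processor(model, data, train_mode):
    # forward pass; the model returns a dict of individual loss terms
    losses = model(**data)
    # reduce them to a single scalar loss plus values for logging
    loss, log_vars = parse_losses(losses)

    outputs = dict(
        loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))
    return outputs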
After the Runner is instantiated, training is ultimately kicked off by calling its run() method:
runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
The source of run() is as follows:
def run(self, data_loaders, workflow, max_epochs, **kwargs):
    """Start running.

    Args:
        data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
            and validation.
        workflow (list[tuple]): A list of (phase, epochs) to specify the
            running order and epochs. E.g, [('train', 2), ('val', 1)] means
            running 2 epochs for training and 1 epoch for validation,
            iteratively.
        max_epochs (int): Total training epochs.
    """
    assert isinstance(data_loaders, list)
    assert mmcv.is_list_of(workflow, tuple)
    assert len(data_loaders) == len(workflow)

    self._max_epochs = max_epochs
    work_dir = self.work_dir if self.work_dir is not None else 'NONE'
    self.logger.info('Start running, host: %s, work_dir: %s',
                     get_host_info(), work_dir)
    self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
    self.call_hook('before_run')

    while self.epoch < max_epochs:
        for i, flow in enumerate(workflow):
            mode, epochs = flow
            if isinstance(mode, str):  # self.train()
                if not hasattr(self, mode):
                    raise ValueError(
                        'runner has no method named "{}" to run an epoch'.
                        format(mode))
                epoch_runner = getattr(self, mode)
            elif callable(mode):  # custom train()
                epoch_runner = mode
            else:
                raise TypeError('mode in workflow must be a str or '
                                'callable function, not {}'.format(
                                    type(mode)))
            for _ in range(epochs):
                if mode == 'train' and self.epoch >= max_epochs:
                    return
                epoch_runner(data_loaders[i], **kwargs)

    time.sleep(1)  # wait for some hooks like loggers to finish
    self.call_hook('after_run')
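For reference, the stock configs typically drive run() with values along these lines (taken from the common 1x schedule; treat them as an example rather than a requirement):
# typical settings in the config file (1x schedule)
workflow = [('train', 1)]   # one training epoch per cycle, no val phase
total_epochs = 12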
In the loop above, epoch_runner ends up being either train or val, so
epoch_runner(data_loaders[i], **kwargs)
is equivalent to
train(data_loaders[i], **kwargs) or val(data_loaders[i], **kwargs)
The source of train() and val() is as follows:
def train(self, data_loader, **kwargs):
    self.model.train()
    self.mode = 'train'
    self.data_loader = data_loader
    self._max_iters = self._max_epochs * len(data_loader)
    self.call_hook('before_train_epoch')
    for i, data_batch in enumerate(data_loader):
        self._inner_iter = i
        self.call_hook('before_train_iter')
        outputs = self.batch_processor(
            self.model, data_batch, train_mode=True, **kwargs)
        if not isinstance(outputs, dict):
            raise TypeError('batch_processor() must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'],
                                   outputs['num_samples'])
        self.outputs = outputs
        self.call_hook('after_train_iter')
        self._iter += 1

    self.call_hook('after_train_epoch')
    self._epoch += 1

def val(self, data_loader, **kwargs):
    self.model.eval()
    self.mode = 'val'
    self.data_loader = data_loader
    self.call_hook('before_val_epoch')
    for i, data_batch in enumerate(data_loader):
        self._inner_iter = i
        self.call_hook('before_val_iter')
        with torch.no_grad():
            outputs = self.batch_processor(
                self.model, data_batch, train_mode=False, **kwargs)
        if not isinstance(outputs, dict):
            raise TypeError('batch_processor() must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'],
                                   outputs['num_samples'])
        self.outputs = outputs
        self.call_hook('after_val_iter')

    self.call_hook('after_val_epoch')
From this you can see that the forward pass is indeed ultimately performed through batch_processor(). At this point, we have walked through essentially all of the code in the training pipeline.