Dataloader重要参数与内部机制

@[TOC]

一、pytorch数据输入

Dataset负责生产数据,DataLoader负责数据的分批(batch_size)、采样(sampler)、传输
Pytorch版本:1.0.1

1. Dataset

继承torch.utils.data.Dataset,实现两个函数即可:

  • def len(self) 数据总数
  • def getitem(self, index) 根据下标获取其中一条数据

2. DataLoader

将Dataset作为参数,构造一个torch.utils.data.DataLoader对象即可。
DataLoader其他参数见下文。

二、Dataloader参数汇总

  • dataset(Dataset):
    传入的数据集

  • batch_size(int, optional):
    每个batch有多少个样本

  • shuffle(bool, optional):
    在每个epoch开始的时候,对数据进行重新打乱

  • sampler(Sampler, optional):
    自定义从数据集中取样本的策略,如果指定这个参数,那么shuffle必须为False

  • batch_sampler(Sampler, optional):
    与sampler类似,但是一次只返回一个batch的indices(索引),需要注意的是,一旦指定了这个参数,那么batch_size,shuffle,sampler,drop_last就不能再制定了(互斥——Mutually exclusive)

  • num_workers (int, optional):
    这个参数决定了有几个进程来处理data loading。0意味着所有的数据都会被load进主进程。(默认为0)

  • collate_fn (callable, optional):
    将一个list的sample组成一个mini-batch的函数

  • pin_memory (bool, optional):
    如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中.

  • drop_last (bool, optional):
    如果设置为True:这个是对最后的未完成的batch来说的,比如你的batch_size设置为64,而一个epoch只有100个样本,那么训练的时候后面的36个就被扔掉了…
    如果为False(默认),那么会继续正常执行,只是最后的batch_size会小一点。

  • timeout(numeric, optional):
    如果是正数,表明等待从worker进程中收集一个batch等待的时间,若超出设定的时间还没有收集到,那就不收集这个内容了。这个numeric应总是大于等于0。默认为0

  • worker_init_fn (callable, optional):
    每个worker初始化函数 If not None, this will be called on each
    worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)

2.1 sampler:分布式训练需DistributedSampler

train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)

DataLoader构造函数中相关代码:

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)  ##如果shuffer就随机  
                else:
                    sampler = SequentialSampler(dataset)  ##否则顺序采样  
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler  

batch_sampler是sampler的封装,可自定义批次数据的构造。默认BatchSampler相关源码:

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)      ##遍历sampler获取数据,满batch_size就yield  
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

2.2 collate_fn:将batch的数据重新组装

例如cirtorch中将数据拆成input_data和target两个数据。
因Dataset中get_item返回input_data和target两个值,如果不用该函数,每个batch的数据应该是[batch_size,2(先input_data再target),,,],经过该函数将变成([batch_size,,,],[batch_size,,]),第一个数据全是input_data,第二个数据全是target。

2.3 pin_memory=True:提高数据从cpu到gpu传输效率

pin_memory可在cpu主存(内存)中分配不可交换到swap(缓存)的内存。。默认内存分配中的数据都可交换到swap中,那CUDA驱动会通过DRAM机制将数据从内存传到GPU显存时会复制2次(先复制到一临时不可见pinned固定内存,再往显存中复制),因此pin_memory=True可提高约2倍cpu到gpu传输效率(.cuda()或 .to(device)的时候)。相见CPU和GPU内存交互

【拓展】Elasticsearch中的Memlock(内存锁定)可申请固定大小且不可交换内存空间。

三、DataLoader的并行

# Our data model looks like this (queues are indicated with curly brackets):
    #
    #                main process                              ||
    #                     |                                    ||
    #               {index_queue}                              ||
    #                     |                                    ||
    #              worker processes                            ||     DATA
    #                     |                                    ||
    #            {worker_result_queue}                         ||     FLOW
    #                     |                                    ||
    #      pin_memory_thread of main process                   ||   DIRECTION
    #                     |                                    ||
    #               {data_queue}                               ||
    #                     |                                    ||
    #                data output                               \/
    #
    # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
    #      `pin_memory=False`.
    
  • 基于multiprocessing多进程
  • 每个子进程的输入输出,通过两个主要的队列(multiprocessing.Queue()): index_queue要处理的下标、worker_result_queue要返回的下标。
  • 每个worker一次产生一个batch的数据
  • 返回batch数据前放入下一个批次数据下标
  • 构造函数子进程初始化:
            self.index_queues = []
            self.workers = []
            for i in range(self.num_workers):
                index_queue = multiprocessing.Queue() # 1.每个子进程一个队列放要处理的下标
                index_queue.cancel_join_thread()
                w = multiprocessing.Process(
                    target=_utils.worker._worker_loop, # 每个子进程循环执行的函数  
                    args=(self.dataset, index_queue,
                          self.worker_result_queue, self.done_event, #2.self.worker_result_queue 多子进程公用要返回batch数据的队列  
                          self.collate_fn, base_seed + i,
                          self.worker_init_fn, i))
                w.daemon = True
                # NB: Process.start() actually take some time as it needs to
                #     start a process and pass the arguments over via a pipe.
                #     Therefore, we only add a worker to self.workers list after
                #     it started, so that we do not call .join() if program dies
                #     before it starts, and __del__ tries to join but will get:
                #     AssertionError: can only join a started process.
                w.start()
                self.index_queues.append(index_queue)
                self.workers.append(w)

3.1 index_queue 要处理的数据下标

每个worker有一个index_queue dataloader.py#L544-L552
每个worker从index_queue取要处理的下标 dataloader.py#L124
dataloader输出一次数据前先往index_queue中放一次下标, _process_next_batch函数:

    def _process_next_batch(self, batch):
        self.rcvd_idx += 1
        self._put_indices()  ## 先放下一批数据下标
        if isinstance(batch, ExceptionWrapper):
            raise batch.exc_type(batch.exc_msg)
        return batch         ## 再返回该批数据

_put_indices依次往不同worker所属的index_queue中放 dataloader.py#L644-L652

完整的dataloader next函数:

    def __next__(self):
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = pin_memory_batch(batch)
            return batch

        # check if the next sample has already been generated
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch) ## 5. 之前以及取出来该下标数据,直接返回

        if self.batches_outstanding == 0:
            self._shutdown_workers()
            raise StopIteration

        while True:  ## 1.直到取的数据下标正确才return
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self._get_batch()  ## 2.从worker_result_queue中获取数据
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # store out-of-order samples
                self.reorder_dict[idx] = batch  ## 3.下标不对先存一下
                continue
            return self._process_next_batch(batch) ## 4.内部先放下一批数据下标再返回batch数据  

3.2 worker_result_queue 返回结果

每个worker一直在执行的循环_worker_loop,其中worker_result_queue作为_worker_loop函数的data_queue传入(dataloader.py#L544-L552),相见:

def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed, init_fn, worker_id):
    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
    # logic of this function.

    try:
        global _use_shared_memory
        _use_shared_memory = True

        # Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
        # module's handlers are executed after Python returns from C low-level
        # handlers, likely when the same fatal signal happened again already.
        # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
        _set_worker_signal_handlers()

        torch.set_num_threads(1)
        random.seed(seed)
        torch.manual_seed(seed)

        data_queue.cancel_join_thread()

        if init_fn is not None:
            init_fn(worker_id)

        watchdog = ManagerWatchdog()

        while watchdog.is_alive():
            try:
                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) ##从index_queue中获取要处理的下标
            except queue.Empty:
                continue
            if r is None:
                # Received the final signal
                assert done_event.is_set()
                return
            elif done_event.is_set():
                # Done event is set. But I haven't received the final signal
                # (None) yet. I will keep continuing until get it, and skip the
                # processing steps.
                continue
            idx, batch_indices = r
            try:
                samples = collate_fn([dataset[i] for i in batch_indices]) ##1.根据下标取样本数据  
            except Exception:
                # It is important that we don't store exc_info in a variable,
                # see NOTE [ Python Traceback Reference Cycle Problem ]
                data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
            else: ## 2. 没有抛异常就将样本数据放入结果返回队列  
                data_queue.put((idx, samples))
                del samples
    except KeyboardInterrupt:
        # Main process will raise KeyboardInterrupt anyways.
        pass

参考文献

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 204,189评论 6 478
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,577评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 150,857评论 0 337
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,703评论 1 276
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,705评论 5 366
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,620评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,995评论 3 396
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,656评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,898评论 1 298
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,639评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,720评论 1 330
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,395评论 4 319
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,982评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,953评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,195评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 44,907评论 2 349
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,472评论 2 342

推荐阅读更多精彩内容