深度学习训练中的数据增强

数据增强(Data Augmentation),也叫数据增广,是机器学习和深度学习中的一种技术,它通过转换数据来扩充训练样本,提高训练样本的多样性和数量。数据增强的主要目的是创建一个更加丰富和多样化的训练数据集,这有助于模型学习到更加泛化的的特征,提高模型的鲁棒性。在计算机视觉领域,主要的数据增强方法有随机裁剪、旋转、翻转、缩放、平移、亮度调整、对比度调整、添加噪声等。这些方法能够有效的扩充训练数据集,使得模型对于不同的尺度、角度和光照条件下的图像具有更好的识别能力。下图就是一副遥感影像数据增加旋转和翻转等操作进行数据增强的示例。



为了比较数据增强的效果,我做了一个试验,把一个遥感影像样本库中的数据用最简单语义分割网络FCN来尝试做语义分割。原始训练样本库如下:



其中,红色的为道路,白色的为建筑,绿色为耕地,天蓝色为水系,黄色为草地,蓝色为园地。默认的样本库共有1673幅遥感影像以及预期对应的1673幅标签图像数据。
首先来看一下数据增强的实现方法。Pytorch为我们提供了很多数据增强的实现方法,而且借助torchvision的transforms方法,可以很方便的实现数据增强。一般来说,在transforms中定义需要做的数据增强方法,再用Compse方法组合起来,在数据的Dataset中进行数据增强的处理。如Pytorch官方文档的介绍:
import torch
from torchvision.transforms import v2

H, W = 32, 32
img = torch.randint(0, 256, size=(3, H, W), dtype=torch.uint8)

transforms = v2.Compose([
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
img = transforms(img)

这段代码定义了一个数据增强方法,对数据进行随机裁剪,并在水平方向随机翻转,转成浮点型的数据类型,并做了一个归一化的操作,再由Compose方法组合在一起。当然这个是transforms的新版本,如果用老版本的也是可以的,取决你按照的torch和torchvision版本。
可以参考如下的官方链接,我这里就不逐个解读了:
新版:Transforming and augmenting images — Torchvision 0.20 documentation (pytorch.org)
老版:torchvision.transforms — Torchvision 0.11.0 documentation (pytorch.org)
通过数据增强的方法,可以在Dataset的getitem方法中进行数据增强处理,完整的Dataset代码如下:

import os  
from PIL import Image  
from torch.utils.data import Dataset, DataLoader  
from torchvision import transforms  
import torch
import cv2
import numpy as np
import random

seed = np.random.randint(2147483647)  # make a seed with numpy generator
random.seed(seed) 

def read_samples(sample_dir, is_train=True):
    # 读取样本库图像并标注
    feature_list, label_list = [], []
    for file_name in os.listdir(sample_dir+'//img2_1024'):
        feature_list.append((os.path.join(sample_dir, 'img2_1024', f'{file_name}')))
        label_list.append((os.path.join(sample_dir, 'mask_1024', f'{file_name}')))
    return feature_list, label_list

class Js08mDataset(torch.utils.data.Dataset):
    def __init__(self, is_train, sample_dir, transform_img=None, transform_mask=None):
        self.transform_img = transform_img
        self.transform_mask = transform_mask
        self.feature_list, self.label_list = read_samples(sample_dir)
        print('read ' + str(len(self.feature_list)) + ' examples')

    def normalize_image(self, img):
        return img.float() / 255

    def __getitem__(self, idx):
        img = cv2.imread(self.feature_list[idx], -1)
        mask = cv2.imread(self.label_list[idx], -1)
        img = torch.tensor(img)
        img = img.permute(2, 0, 1)
        mask = torch.tensor(mask)
        mask = mask.unsqueeze(0)
        img = self.normalize_image(img)
        if self.transform_img:
            torch.random.manual_seed(seed)
            img = self.transform_img(img)
        if self.transform_mask:
            torch.random.manual_seed(seed)
            mask = self.transform_mask(mask)
        mask = mask.squeeze(0)
        return img, mask

    def __len__(self):
        return len(self.feature_list)
        
def load_data_voc(batch_size):
    sample_dir = 'D:\\zj\\sample\\js08m'
    rot_degree = random.choice([0, 90, 180, 270])
    print("rot_degree:{}".format(rot_degree))
    transform_img = transforms.Compose([  
        transforms.RandomCrop(384),  # 随机裁剪
        transforms.RandomRotation((rot_degree, rot_degree))
    ])
    transform_mask = transforms.Compose([  
        transforms.RandomCrop(384),  # 随机裁剪  
        transforms.RandomRotation((rot_degree, rot_degree))
    ])
    num_workers = 1
    train_iter = torch.utils.data.DataLoader(Js08mDataset(True, sample_dir, transform_img, transform_mask), batch_size, shuffle=True,
                                             drop_last=True, num_workers=num_workers)
    test_iter = torch.utils.data.DataLoader(Js08mDataset(False, sample_dir, transform_img, transform_mask), batch_size, drop_last=True,
                                            num_workers=num_workers)
    return train_iter, test_iter

if __name__ == "__main__":
    import matplotlib.pyplot as plt
    train_iter, test_iter = load_data_voc(batch_size=4)
    for index, (image, label) in enumerate(train_iter):
        image = image.numpy().transpose((0, 2, 3, 1))
        print(image[0].shape)
        img = image[0]
        img_uint8 = (img * 255).astype(np.uint8)  
        cv2.imwrite("1.png", img_uint8)
        mask = label.numpy()
        print(mask.shape)
        cv2.imwrite("1.tif", mask[0])
        break
#输出
rot_degree:270
read 1673 examples
read 1673 examples
(384, 384, 3)
(4, 384, 384)

代码本身不难,要注意的是,程序开始要设置一个统一的随机种子,如此,在数据增强处理中,影像和标签的随机裁剪和随机旋转会做同样的处理,否则可能影像的裁剪范围和旋转角度和标签图像的裁剪范围和旋转角度不同,这样就无法一一对应,训练时候当然就会出问题,从而导致无法好好的训练,因此一定要先设置一个随机种子,使得影像和标签同步增强变换。当然,如果是做图像分类的话则不需要这么做,因为图像分类的标签仅仅是一个数字,用来判断是哪个类别,那么需要做数据增强就是影像本身,不过如果是语义分割和实例分割等应用,那么标签图像也要做同步变换才可以。
数据增强后的结果如下,随机从原始数据中裁剪384*384大小的区域,同时在[0,90,180,270]四个角度中随机选择一个角度进行旋转。



之后,我们可以利用做过数据增强的样本库进行训练,为了方便比较,我用同样的模型分别在未做数据增强和做了数据增强的数据集上进行训练。
先来看未作数据增强的样本的训练结果:

2024-10-12 09:53:19,837 - __main__ - DEBUG - epoch0----> train loss = 1.075703, Time 00:02:23
2024-10-12 09:53:19,838 - __main__ - DEBUG - epoch0----> acc = 0.651488
2024-10-12 09:55:41,606 - __main__ - DEBUG - epoch1----> train loss = 0.852746, Time 00:02:21
2024-10-12 09:55:41,613 - __main__ - DEBUG - epoch1----> acc = 0.731762
2024-10-12 09:58:02,680 - __main__ - DEBUG - epoch2----> train loss = 0.795102, Time 00:02:21
2024-10-12 09:58:02,687 - __main__ - DEBUG - epoch2----> acc = 0.751905
2024-10-12 10:00:22,949 - __main__ - DEBUG - epoch3----> train loss = 0.748344, Time 00:02:20
2024-10-12 10:00:22,955 - __main__ - DEBUG - epoch3----> acc = 0.765419
2024-10-12 10:02:43,191 - __main__ - DEBUG - epoch4----> train loss = 0.705971, Time 00:02:20
2024-10-12 10:02:43,197 - __main__ - DEBUG - epoch4----> acc = 0.776413
2024-10-12 10:05:02,869 - __main__ - DEBUG - epoch5----> train loss = 0.707136, Time 00:02:19
2024-10-12 10:05:02,875 - __main__ - DEBUG - epoch5----> acc = 0.775926
......
2024-10-12 17:34:02,862 - __main__ - DEBUG - epoch199----> train loss = 0.348782, Time 00:02:21
2024-10-12 17:34:02,865 - __main__ - DEBUG - epoch199----> acc = 0.883190

可以看到,训练200个epoch,最终总体精度达到88.3%。
下面看看做了数据增强的样本库的训练结果:

2024-10-17 08:54:33,525 - __main__ - DEBUG - epoch0----> train loss = 1.262130, Time 00:01:08
2024-10-17 08:54:33,526 - __main__ - DEBUG - epoch0----> acc = 0.553920
2024-10-17 08:55:43,016 - __main__ - DEBUG - epoch1----> train loss = 0.932994, Time 00:01:09
2024-10-17 08:55:43,016 - __main__ - DEBUG - epoch1----> acc = 0.703861
2024-10-17 08:56:50,672 - __main__ - DEBUG - epoch2----> train loss = 0.840231, Time 00:01:07
2024-10-17 08:56:50,672 - __main__ - DEBUG - epoch2----> acc = 0.737803
2024-10-17 08:57:58,084 - __main__ - DEBUG - epoch3----> train loss = 0.786046, Time 00:01:07
2024-10-17 08:57:58,084 - __main__ - DEBUG - epoch3----> acc = 0.754506
2024-10-17 08:59:05,808 - __main__ - DEBUG - epoch4----> train loss = 0.753338, Time 00:01:07
2024-10-17 08:59:05,809 - __main__ - DEBUG - epoch4----> acc = 0.762777
2024-10-17 09:00:14,201 - __main__ - DEBUG - epoch5----> train loss = 0.736403, Time 00:01:08
2024-10-17 09:00:14,202 - __main__ - DEBUG - epoch5----> acc = 0.767858
......
2024-10-17 12:40:56,003 - __main__ - DEBUG - epoch199----> train loss = 0.051287, Time 00:01:07
2024-10-17 12:40:56,004 - __main__ - DEBUG - epoch199----> acc = 0.978977

同样训练200个epoch,最终总体精度达到97.9%,当然这个我觉得有点太高了,可能训练迭代太多了,有点过拟合了,但总的来说,进行了数据增强后的训练效果是有明显提升的。
测试结果图如下,这里灰色是道路,深红色是建筑,绿色是耕地,蓝色是水系,粉红色是园地,黄色是草地,由于园地和草地样本较少,所以测试效果略差。不过总体效果还行,细节部分不清晰,有些是由于有点过拟合,有些是由于FCN模型比较基础,对于特征提取能力还不够强,如果换成其他效果比较好的模型,测试的效果还会更好。


©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 206,723评论 6 481
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 88,485评论 2 382
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 152,998评论 0 344
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 55,323评论 1 279
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 64,355评论 5 374
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,079评论 1 285
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,389评论 3 400
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,019评论 0 259
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 43,519评论 1 300
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,971评论 2 325
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,100评论 1 333
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,738评论 4 324
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,293评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,289评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,517评论 1 262
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,547评论 2 354
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,834评论 2 345

推荐阅读更多精彩内容