FCN在VOC数据集的实践

本次参考《动手学深度学习》(此书用的是MXNet,本次实践使用的是pytorch框架)进行FCN在VOC2012数据集上的实践。
首先需要下载VOC数据集,书中已经讲下载已经封装好了,直接调用一下函数即可。当然不想装MXNet也可以很容易通过百度下载。

from mxnet.gluon import data as gdata, utils as gutils
import os
import sys
import tarfile

def download_voc_pascal(data_dir='../data'):
    voc_dir = os.path.join(data_dir, 'VOCdevkit/VOC2012')
    url = ('http://host.robots.ox.ac.uk/pascal/VOC/voc2012'
           '/VOCtrainval_11-May-2012.tar')
    sha1 = '4e443f8a2eca6b1dac8a6c57641b67dd40621a49'
    fname = gutils.download(url, data_dir, sha1_hash=sha1)
    with tarfile.open(fname, 'r') as f:
        f.extractall(data_dir)
    return voc_dir

voc_dir = download_voc_pascal()

数据集将会放置在../data/VOCdevkit/VOC2012路径下。进入../data/VOCdevkit/VOC2012路径后,我们可以获取数据集的不同组成部分。其中ImageSets/Segmentation路径包含了指定训练和测试样本的文本文件,而JPEGImages和SegmentationClass路径下分别包含了样本的输入图像和标签。这里的标签也是图像格式,其尺寸和它所标注的输入图像的尺寸相同。标签中颜色相同的像素属于同一个语义类别。下面定义read_images函数将输入图像和标签全部读进内存。

voc_root = '../data/VOCdevkit/VOC2012'

# 读取图片和标签路径成为列表
def read_images(root=voc_root, train=True):
    txt_fname = root + '/ImageSets/Segmentation/' + ('train.txt' if train else 'val.txt')
    with open(txt_fname, 'r') as f:
        images = f.read().split()
    data = [os.path.join(root, 'JPEGImages', i+'.jpg') for i in images]
    label = [os.path.join(root, 'SegmentationClass', i+'.png') for i in images]
    return data, label

在标签图像中,白色和黑色分别代表边框和背景,而其他不同的颜色则对应不同的类别。
接下来,我们列出标签中每个RGB颜色的值及其标注的类别。

classes = ['background','aeroplane','bicycle','bird','boat',
           'bottle','bus','car','cat','chair','cow','diningtable',
           'dog','horse','motorbike','person','potted plant',
           'sheep','sofa','train','tv/monitor']
colormap = [[0,0,0],[128,0,0],[0,128,0], [128,128,0], [0,0,128],
            [128,0,128],[0,128,128],[128,128,128],[64,0,0],[192,0,0],
            [64,128,0],[192,128,0],[64,0,128],[192,0,128],
            [64,128,128],[192,128,128],[0,64,0],[128,64,0],
            [0,192,0],[128,192,0],[0,64,128]]

有了上面定义的两个常量以后,我们可以很容易地查找标签中每个像素的类别索引。

cm2lbl = np.zeros(256**3) # 256**3个色彩
for i,cm in enumerate(colormap): # 将其中21个色彩对应索引
    cm2lbl[(cm[0]*256+cm[1])*256+cm[2]] = i

# 图像转索引矩阵
def image2label(im):
    data = np.array(im, dtype='int32')
    idx = (data[:, :, 0] * 256 + data[:, :, 1]) * 256 + data[:, :, 2]
    return np.array(cm2lbl[idx], dtype='int64')

我们可以测试一下:


可以看到像素[128,0,0]对应的索引为(128256+0)256+0,索引出的值为1,然后再通过1对classes进行索引,得到‘aeroplane’。
函数image2label()将一个图像标签转换成一个二维矩阵,其中每个值的取值范围为0-20,代表着它是属于21个类别中的哪一类像素。

在之前图像分类中,我们通过缩放图像使其符合模型的输入形状。然而在语义分割里,这样做需要将预测的像素类别重新映射回原始尺寸的输入图像。这样的映射难以做到精确,尤其在不同语义的分割区域。为了避免这个问题,我们将图像裁剪成固定尺寸而不是缩放。具体来说,我们使用随机裁剪,并对输入图像和标签裁剪相同区域。
貌似pytorch里没有这个实现方法,这里用下面的函数进行图像和标签的随机裁剪。

import random
# 随机裁剪图像和标签
def rand_crop(data,label, height, width):
    
    x1 = random.randint(0, data.size[0] - width)
    y1 = random.randint(0, data.size[1] - height)
    x2 = x1 + width
    y2 = y1 + height
            
    data=data.crop((x1, y1, x2, y2))
    label=label.crop((x1, y1, x2, y2))
 
    return data,label

运行此代码两次查看效果:

data = Image.open('../data/VOCdevkit/VOC2012/JPEGImages/2007_000032.jpg')
label = Image.open('../data/VOCdevkit/VOC2012/SegmentationClass/2007_000032.png').convert('RGB')
data, label = rand_crop(data, label, 200, 300)
plt.subplot(1,2,1)
plt.imshow(data)
plt.subplot(1,2,2)
plt.imshow(label)

自定义数据集

def img_transforms(im, label, crop_size):
    im, label = rand_crop(im, label, *crop_size)
    im_tfs = tfs.Compose([
        tfs.ToTensor(),
        tfs.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    
    im = im_tfs(im)
    label = image2label(label)
    label = torch.from_numpy(label)
    return im, label

class VOCSegDataset(Dataset):
    '''
    voc dataset
    '''
    def __init__(self, train, crop_size, transforms):
        self.crop_size = crop_size
        self.transforms = transforms
        data_list, label_list = read_images(train=train)
        self.data_list = self._filter(data_list)
        self.label_list = self._filter(label_list)
        print('Read ' + str(len(self.data_list)) + ' images')
        
    def _filter(self, images): # 过滤掉图片大小小于 crop 大小的图片
        return [im for im in images if (Image.open(im).size[1] >= self.crop_size[0] and 
                                        Image.open(im).size[0] >= self.crop_size[1])]
        
    def __getitem__(self, idx):
        img = self.data_list[idx]
        label = self.label_list[idx]
        img = Image.open(img)
        label = Image.open(label).convert('RGB')
        img, label = self.transforms(img, label, self.crop_size)
        return img, label
    
    def __len__(self):
        return len(self.data_list)

读取数据集:

input_shape = (320, 480)
voc_train = VOCSegDataset(True, input_shape, img_transforms)
voc_test = VOCSegDataset(False, input_shape, img_transforms)

train_data = DataLoader(voc_train, 16, shuffle=True, num_workers=1)
valid_data = DataLoader(voc_test, 16, num_workers=1)

FCN模型在此笔记已经实现了。
我们只需要将网络实例化,在定义谢谢参数即可进行训练了。

device = torch.device('cuda')
net = fcn(num_classes)
criterion = nn.NLLLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=1e-3, weight_decay=1e-4)

好了,开始训练!

net = net.to(device)

for e in range(80):
    print('epoch:{}/80'.format(e))
    print('-' * 20)
    train_loss = 0
    train_acc = 0
    train_acc_cls = 0
    train_mean_iu = 0
    train_fwavacc = 0
    
    prev_time = datetime.now()
    net = net.train()
    for data, label in train_data:
        im = data.to(device)
        label = label.to(device)
#         print(im.shape, label.shape)
        # forward
        out = net(im)
#         print(out.shape, type(out))
        loss = criterion(nn.LogSoftmax(dim=1)(out), label)
        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        
        label_pred = out.max(dim=1)[1].data.cpu().numpy()
        label_true = label.data.cpu().numpy()
        for lbt, lbp in zip(label_true, label_pred):
            acc, acc_cls, mean_iu, fwavacc = label_accuracy_score(lbt, lbp, num_classes)
            train_acc += acc
            train_acc_cls += acc_cls
            train_mean_iu += mean_iu
            train_fwavacc += fwavacc
    print('Train Loss: {:.5f}, Train Acc: {:.5f}, Train Mean IU: {:.5f}'.format(train_loss / len(train_data), train_acc / len(voc_train), train_mean_iu / len(voc_train)))   
    net = net.eval()
    eval_loss = 0
    eval_acc = 0
    eval_acc_cls = 0
    eval_mean_iu = 0
    eval_fwavacc = 0
    for data, label in valid_data:
        im = data.to(device)
        label = label.to(device)
        # forward
        out = net(im)
        loss = criterion(nn.LogSoftmax(dim=1)(out), label)
        eval_loss += loss.item()
        
        label_pred = out.max(dim=1)[1].data.cpu().numpy()
        label_true = label.data.cpu().numpy()
        for lbt, lbp in zip(label_true, label_pred):
            acc, acc_cls, mean_iu, fwavacc = label_accuracy_score(lbt, lbp, num_classes)
            eval_acc += acc
            eval_acc_cls += acc_cls
            eval_mean_iu += mean_iu
            eval_fwavacc += fwavacc                                                              
    print('Valid Loss: {:.5f}, Valid Acc: {:.5f}, Valid Mean IU: {:.5f} '.format(eval_loss / len(valid_data), eval_acc / len(voc_test), eval_mean_iu / len(voc_test)))    
    cur_time = datetime.now()
    h, remainder = divmod((cur_time - prev_time).seconds, 3600)
    m, s = divmod(remainder, 60)
    time_str = 'Time: {:.0f}:{:.0f}:{:.0f}'.format(h, m, s)
    print(time_str)
    print()

部分训练日志如下:

epoch:76/80
--------------------
Train Loss: 0.31527, Train Acc: 0.89842, Train Mean IU: 0.54611
Valid Loss: 0.42469, Valid Acc: 0.86795, Valid Mean IU: 0.50423 
Time: 0:0:45

epoch:77/80
--------------------
Train Loss: 0.31699, Train Acc: 0.89788, Train Mean IU: 0.55095
Valid Loss: 0.42091, Valid Acc: 0.87019, Valid Mean IU: 0.51235 
Time: 0:0:45

epoch:78/80
--------------------
Train Loss: 0.32034, Train Acc: 0.89675, Train Mean IU: 0.54414
Valid Loss: 0.42173, Valid Acc: 0.86962, Valid Mean IU: 0.51024 
Time: 0:0:45

可以看到经过80次迭代,验证集的meanIOU为51%左右。
最后可视化一下结果。第一列为原始图像,第二列为图像标签,第三列为预测的结果。


©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 204,590评论 6 478
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 86,808评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 151,151评论 0 337
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,779评论 1 277
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,773评论 5 367
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,656评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,022评论 3 398
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,678评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 41,038评论 1 299
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,659评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,756评论 1 330
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,411评论 4 321
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,005评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,973评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,203评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,053评论 2 350
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,495评论 2 343

推荐阅读更多精彩内容