不知道大家在日常生活中有没有朋友或是妹子看见一辆车,突然问道:这是什么车?
或者我们逛论坛时经常会看见有人传一张照片然后问:这是什么车?
就像下图这样
其他人还好,如果是妹子问,你答不出来,咱面子上挂不住啊~
但是除了老司机,我们不可能每辆车都能说出名称来,尤其是仅仅凭借汽车内饰或者一些局部特写。
那么我们能不能制作一个汽车识别程序,让程序自动识别车型呢。
今天我决定通过卷积神经网络,从头实现一个汽车识别分类器。
大体流程如下
- 爬取数据、清洗数据
- 搭建模型、训练模型
- 模型识别汽车图像
先表明一下操作环境:
操作系统:ubuntu18.04
python版本:3.7
pytorch版本:1.2
GPU:1060
CUDA:10.1
另外:ubuntu自带的输入法真难用,如果有错别字请自行联想:)
爬取数据篇
这里爬取的目标是汽车之家图片板块下关注度排行的各类型(如微型车、小型车、SUV等)下的汽车各9种。合集应该是81种(但是我爬取了90种,可能是HTML里多了一类汽车的链接)
点击其中一个汽车(比如奔驰AMG GT),是此汽车的分类目录,如下图
将此分类下的图像全部爬取下来。以便网络可以识别汽车的不同角度以及各种局部信息。
爬取的结果如以下图展示:
可以看到图片数据全部放在了all_cars这个文件夹下,并且按车类放在了各自的文件夹里。一共16万多张图片,占用3.5G。
另外值得一提的是,为了节省抓取时间和训练时间以及储存空间~。~
我抓取的是展示用的缩略图(图像尺寸为200 * 180),并非为1024 * 768的原图。
下面贴出爬虫代码(car.py)
import requests
from lxml import etree
from os.path import join
import os
import time
base_url = 'https://car.autohome.com.cn'
headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/69.0.3497.92 Safari/537.36'
}
# 获取html文档,方便下面的函数调用
def get_html(url):
r = requests.get(url, headers=headers)
r.encoding = r.apparent_encoding
html = etree.HTML(r.text)
return html
# 获取汽车的名称和图册链接
def get_car(url):
html = get_html(url)
links = html.xpath('//div[@id="levelContent"]/div/ul/li/a/@href')
for link in links:
link = base_url + link
get_calsses(link)
# 获取汽车的各分类的图册链接
def get_calsses(url):
html = get_html(url)
links = html.xpath('//ul[@class="search-pic-sortul"]/li/a/@href')
for link in links:
link = base_url+link
get_img_url(link)
# 获取图片地址
def get_img_url(url):
html = get_html(url)
links = html.xpath('//div[@class="uibox-con carpic-list03 border-b-solid"]/ul/li/a/img/@src')
name = html.xpath('//h2[@class="fn-left cartab-title-name"]/a/text()')[0]
for link in links:
link = 'https:' + link
# print(link)
save_img(name, link)
try:
next_page = html.xpath('//div[@class="page"]/a[last()]/@href')[0]
next_page = base_url + next_page
if 'html' in next_page:
get_img_url(next_page)
time.sleep(0.5)
except:
pass
# 保存图片
def save_img(name, url):
img_name = url.split('_')[-1]
img = requests.get(url, headers=headers)
root = join('cars', name)
filename = join(root, img_name)
if not os.path.exists(root):
os.makedirs(root)
with open(filename, 'wb') as f:
f.write(img.content)
print('正在下载:', filename)
if __name__ == '__main__':
url = 'https://car.autohome.com.cn/pic/index.html'
get_car(url)
print('下在完成!')
最后说一下,爬虫爬取下来的可能会有错误的,会给训练带来麻烦,这里再贴出清理错误的图片的代码。
import imghdr
import os
from torchvision.datasets import ImageFolder
data_set = ImageFolder('./all_cars')
print(data_set.imgs)
for img, _ in data_set.imgs:
img_type = imghdr.what(img)
if img_type == None:
os.remove(img)
print('已删除无效文件:', img)
OK, 数据爬取和清洗已经介绍完了,改天我换个搜狗输入法,继续~