1.写在前面
最近组里有个项目与目标识别有关,去网上找了一下,发现目前SOTA的目标识别算法基本都是one-stage的,比如SSD、DSSD、RetinaNet、YOLO等,但是速度上YOLO是最快的。而且看了下YOLO主页,作者的风格我很喜欢。所以仔细研究了一下。本文的内容基于GluonCV、OpenCV和YoloV3,运行平台为Ubuntu16.04版本。ps:因为组里采购的服务器还没到,目前只能在我自己笔记本的虚拟机上跑,而虚拟机的显卡是模拟出来的,无法安装CUDA和CUDNN(这个坑也是我安装CUDA遇到了各种坑后发现的),各位有条件的还是使用CUDA+CUDNN环境,速度会快不少。
2.环境搭建
2.1 GluonCV
GuonCV是一个计算机视觉深度学习的工具箱,功能非常强大,包含了图像分类,目标识别,语义分割,实例分割等。GluonCV的安装在他们主页上面有介绍,安装很简单,python2和python3都可以,但是你的pip版本要大于9.0,同时还要安装一个mxnet框架。同时他们主页还提供了一些简单的demo教你使用,还可以查询API的源代码。
2.2 OpenCV
OpenCV是一个用于图像处理、分析、机器视觉方面的开源函数库. 无论你是做科学研究,还是商业应用,OpenCV都可以作为你理想的工具库,因为,对于这两者,它完全是免费的。该库采用C及C++语言编写,可以在windows, linux, mac OSX系统上面运行。该库的所有代码都经过优化,计算效率很高,因为,它更专注于设计成为一种用于实时系统的开源库。opencv采用C语言进行优化,而且,在多核机器上面,其运行速度会更快。它的一个目标是提供友好的机器视觉接口函数,从而使得复杂的机器视觉产品可以加速面世。该库包含了横跨工业产品检测、医学图像处理、安防、用户界面、摄像头标定、三维成像、机器视觉等领域的超过500个接口函数。
OpenCV安装很简单,直接pip install opencv-python
即可。你也可以使用源代码安装,官网的下载速度很痛苦,我给个OpenCV3.4.7版本的链接,需要的朋友可以自取:
https://pan.baidu.com/s/1Zts9WR7VtH-2L0e9fIaNHw
提取码:498k
源码的安装教程网上很多,我贴一个别人https://jingyan.baidu.com/article/a3761b2be162951576f9aace.html,需要安装cmake工具,没有安装的直接apt install cmake
就可以了。
2.3 YoloV3
YoloV3在他们主页有很详细的教程(基于darknet),有兴趣可以去看下他们的论文,写的很有趣,传统的识别方法是当做一个分类问题,而作者当做一个回归问题来处理,同时并不像传统算法那样需要很多滑动窗口,他是end to end直接输出结果,这也是他们的名字YOLO(you only look once)的由来。同时推荐新手使用darknet,他是一个很轻量级的框架,但是内容很多,且易于上手。
3.代码
代码主要分为三个模块,utils模块,detection模块和main模块。
3.1 utils模块
utils模块包括data_preset.py,yolov3.py,bbox.py等文件
[图片上传失败...(image-4f99d3-1569487002238)]
3.2 detection模块
detection模块包括model,mobilefacedetnet.py等文件
[图片上传失败...(image-7077ae-1569487002238)]
3.3 main模块
main模块包括cap.py函数,其实就是执行函数。使用python3 cap.py
执行就行。ps:我设置了一些命令行参数,比如--video
选择本地视频,--camera
选择摄像头,--gpu
选择是否使用GPU。大家可以使用python3 cap.py -h
查看使用方法,比如
[图片上传失败...(image-4ec4d7-1569487002238)]
cap.py代码如下:
from mxnet import nd
import gluoncv as gcv
from mxnet.gluon.nn import BatchNorm
from gluoncv.data.transforms import presets
from matplotlib import pyplot as plt
sys.path.append(os.path.abspath(os.path.dirname(__file__)) + os.sep + '../MobileFace_Detection/utils/')
from data_presets import data_trans
sys.path.append(os.path.abspath(os.path.dirname(__file__)) + os.sep + '../MobileFace_Detection/')
from mobilefacedetnet import mobilefacedetnet_v2
sys.path.append(os.path.abspath(os.path.dirname(__file__)) + os.sep + '../MobileFace_Tracking/')
from mobileface_sort_v1 import Sort
def parse_args():
parser = argparse.ArgumentParser(description='Test with YOLO networks.')
parser.add_argument('--model', type=str,
default='../MobileFace_Detection/model/mobilefacedet_v2_gluoncv.params',
help='Pretrained model path.')
parser.add_argument('--video', type=str, default='friends1.mp4',
help='Test video path.')
parser.add_argument('--camera', type=int, default=None,
help='Camera select')
parser.add_argument('--gpus', type=str, default='',
help='Default is cpu , you can specify 1,3 for example with GPUs.')
parser.add_argument('--pretrained', type=str, default='True',
help='Load weights from previously saved parameters.')
parser.add_argument('--thresh', type=float, default=0.5,
help='Threshold of object score when visualize the bboxes.')
parser.add_argument('--sort_max_age', type=int, default=10,
help='Threshold of object score when visualize the bboxes.')
parser.add_argument('--sort_min_hits', type=int, default=3,
help='Threshold of object score when visualize the bboxes.')
parser.add_argument('--output', type=str,
default='./tracking_result/result_friends1_tracking.avi',
help='Output video path and name.')
args = parser.parse_args()
return args
def main():
args = parse_args()
# context list
ctx = [mx.gpu(int(i)) for i in args.gpus.split(',') if i.strip()]
ctx = [mx.cpu()] if not ctx else ctx
net = mobilefacedetnet_v2(args.model)
net.set_nms(0.45, 200)
net.collect_params().reset_ctx(ctx = ctx)
mot_tracker = Sort(args.sort_max_age, args.sort_min_hits)
img_short = 256
colors = np.random.rand(32, 3) * 255
winName = 'MobileFace for face detection and tracking'
cv2.namedWindow(winName, cv2.WINDOW_NORMAL)
if args.camera == None:
cap = cv2.VideoCapture(args.video)
else:
cap = cv2.VideoCapture(args.camera)
output_video = args.output
# video_writer = cv2.VideoWriter(output_video, cv2.VideoWriter_fourcc('M','J','P','G'), 30, (round(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),round(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))))
video_writer = cv2.VideoWriter(output_video, cv2.VideoWriter_fourcc('M','J','P','G'), 30, (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))))
# while(cap.isOpened()):
while cv2.waitKey(1) < 0:
ret, frame = cap.read()
if not ret:
print("Done processing !!!")
print("Output file is stored as ", output_video)
cv2.waitKey(3000)
break
dets = []
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame_nd = nd.array(frame_rgb)
x, img = data_trans(frame_nd, short=img_short)
x = x.as_in_context(ctx[0])
# ids, scores, bboxes = [xx[0].asnumpy() for xx in net(x)]
tic = time.time()
result = net(x)
toc = time.time() - tic
#print('Detection inference time:%fms' % (toc*1000))
ids, scores, bboxes = [xx[0].asnumpy() for xx in result]
h, w, c = frame.shape
scale = float(img_short) / float(min(h, w))
for i, bbox in enumerate(bboxes):
if scores[i]< args.thresh:
continue
xmin, ymin, xmax, ymax = [int(x/scale) for x in bbox]
# result = [xmin, ymin, xmax, ymax, ids[i], scores[i]]
result = [xmin, ymin, xmax, ymax, ids[i]]
dets.append(result)
dets = np.array(dets)
tic = time.time()
trackers = mot_tracker.update(dets)
toc = time.time() - tic
#print('Tracking time:%fms' % (toc*1000))
for d in trackers:
color = (int(colors[int(d[4]) % 32, 0]), int(colors[int(d[4]) % 32,1]), int(colors[int(d[4]) % 32, 2]))
cv2.rectangle(frame, (int(d[0]), int(d[1])), (int(d[2]), int(d[3])), color, 3)
# cv2.putText(frame, str('%s%0.2f' % (net.classes[int(d[4])], d[5])),
# (d[0], d[1] - 5), cv2.FONT_HERSHEY_COMPLEX , 0.8, color, 2)
cv2.putText(frame, str('%s%d' % ('face', d[4])),
(int(d[0]), int(d[1]) - 5), cv2.FONT_HERSHEY_COMPLEX , 0.8, color, 2)
video_writer.write(frame.astype(np.uint8))
cv2.imshow(winName, frame)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
cap.release()
cv2.destroyAllWindows()
if __name__ == '__main__':
warnings.filterwarnings("ignore")
main()
4.后续
项目我会放到我的GitHub上,更新了会告诉大家,如果有想要的可以联系我maplect@sina.com,我看到会发给你。