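Preface
After my earlier article on using TF (TensorFlow) to build a Prisma-like app on iOS, many readers got in touch to discuss the approach and ask for a demo, so there is clearly strong interest in this feature.
That article did not dig into the principles behind Google's implementation; it simply got the parameters and the computation graph running on an iOS device. Because compiling TF and setting up the project is complicated in itself, I did not publish the source code for download.
This time I worked through the implementation logic in Google's paper and then used iOS's new Metal framework, with part of the computation accelerated on the device GPU, to render images in the Prisma style (the network parameters are still the ones Google trained). I ran into plenty of pitfalls along the way and am sharing them here in the hope that they help your own research 😅.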
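Structure
Anyone who has read the paper "A Learned Representation for Artistic Style" should be fairly familiar with the overall network structure Google proposed, shown in the figure below:
The part at the back, drawn with dashed lines, is the VGG network. It is trained the same way as in the 2015 paper "A Neural Algorithm of Artistic Style" and is not the focus of Google's optimization this time.
The focus is the style transfer network at the front; the trained parameters we obtain all belong to this part. It is a feed-forward network that generates the image: once its parameters are trained, producing an image takes a single forward pass. Compared with generating an image by iterating back and forth through the VGG network, this saves a great deal of time, and because synthesis is so fast it can also run locally on a mobile device.
The structure of the style transfer network is as follows: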
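The network consists of 3 convolutions + 5 residual blocks + 2 upsampling layers + 1 convolution. A residual block is just two convolutions whose output is added to the block's input, and an upsampling layer first enlarges the image with nearest-neighbor interpolation and then applies a convolution. That gives 16 convolutions in total (3 + 5×2 + 2 + 1 = 16), and each convolution is followed first by batch normalization and then by the activation function. (Apple's MPSCNN library can apply the activation directly as part of the convolution, so at first I put BN after the activation, and the generated images kept coming out wrong 😭.)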
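Padding mode: the paper specifies reflect padding, which Apple's Metal does not support. I wrote my own 😢, only to find in the end that zero padding was what actually gave correct results. I do not know whether Apple's convolution implementation differs or whether the padding mode only applies during training; I will look into it when I have time.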
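For readers who want to compare the two padding modes themselves, here is a minimal sketch in plain C of how each one reads a sample that falls outside the image. It assumes "reflect" means mirroring around the edge without repeating the edge pixel (the convention TensorFlow calls REFLECT); the function names are hypothetical and are not part of Metal or of my project.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical helper: maps an out-of-range coordinate back into [0, size)
   by mirroring around the edges without repeating the edge sample.
   Valid for -size < i < 2 * size - 1. */
static int reflect_index(int i, int size)
{
    assert(size > 1);
    if (i < 0) i = -i;
    if (i >= size) i = 2 * (size - 1) - i;
    return i;
}

/* Zero padding: anything outside the single-channel, row-major image reads as 0. */
static float sample_zero(const float *img, int w, int h, int x, int y)
{
    if (x < 0 || x >= w || y < 0 || y >= h) return 0.0f;
    return img[(size_t)y * w + x];
}

/* Reflect padding: anything outside the image reads the mirrored pixel. */
static float sample_reflect(const float *img, int w, int h, int x, int y)
{
    return img[(size_t)reflect_index(y, h) * w + reflect_index(x, w)];
}
```

A convolution that fetches its border taps through one of these two functions only differs in the outermost pixels of the output, which is one way to check which mode a given set of trained weights expects.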
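That is the core network structure of the whole implementation. In theory, with the parameters and the network definition in hand, we can implement it ourselves without the TF computation graph. That avoids the tedious TF integration and compilation, and makes it much easier to debug the network, control memory, and so on.
It is not that simple, though: Apple's Metal framework still lacks kernels for many deep-network operations and only wraps part of the convolution work. Below I share the implementations of a few of the more important algorithms.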
Batch-Normalization
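BN (batch normalization) is really the core of this network: it is where the different style images are told apart. When you pick a different style, every convolution stays the same and only the BN parameters change, which is what changes the style of the generated image. (The paper calls its variant conditional instance normalization.)
I had hoped Metal would ship a BN implementation, but I searched and found none. I considered writing a kernel so that BN could run on the GPU alongside the convolutions, but learning kernel programming from scratch turned out to be too much, so I implemented it on the CPU instead: after each convolution the image is copied out, BN runs on the CPU, and then the activation function is applied. (I still hope Apple will provide a BN kernel later 😊.)
The implementation: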
- (void)batch_norm:(MPSImage *)image styles:(float *)styles shift:(float *)shift
{
    NSUInteger w = image.texture.width;
    NSUInteger h = image.texture.height;
    NSUInteger featureNum = image.featureChannels;
    // styleNum is not declared in this method; it is assumed to be defined
    // elsewhere (for example as an instance variable) as the number of trained styles.
    // shift holds two styleNum x featureNum matrices: the beta rows first, then the gamma rows.
    // Multiplying the 1 x styleNum style weights by each matrix blends the per-style parameters.
    float *gamma = calloc(featureNum, sizeof(float));
    float *beta = calloc(featureNum, sizeof(float));
    vDSP_mmul(styles, 1, shift, 1, beta, 1, 1, featureNum, styleNum);
    vDSP_mmul(styles, 1, shift+featureNum*styleNum, 1, gamma, 1, 1, featureNum, styleNum);
    // An MPSImage packs its feature channels four to a texture slice.
    NSUInteger numSlices = (featureNum + 3) / 4;
    NSUInteger numComponents = featureNum < 3 ? featureNum : 4;
    NSUInteger channels = featureNum < 3 ? featureNum : numSlices * 4;
    // Copy every slice out of the texture as half floats, then widen to float.
    float16_t *htemp = calloc(w*h*channels, sizeof(float16_t));
    for (NSUInteger i = 0; i < numSlices; i++) {
        [image.texture getBytes:htemp+w*h*numComponents*i bytesPerRow:w*numComponents*2 bytesPerImage:0 fromRegion:MTLRegionMake3D(0, 0, 0, w, h, 1) mipmapLevel:0 slice:i];
    }
    float *temp = calloc(w*h*channels, sizeof(float));
    [self halfTofloat:htemp floatp:temp width:w height:h channel:channels];
    // vDSP_normalize reports the mean and the standard deviation (not the variance).
    float mean, stddev;
    for (NSUInteger i = 0; i < featureNum; i++) {
        NSUInteger slice = i / 4;
        NSUInteger stride = i % 4;
        // First sample of channel i; consecutive samples are numComponents apart.
        float *channel = temp+slice*w*h*numComponents+stride;
        vDSP_normalize(channel, numComponents, channel, numComponents, &mean, &stddev, w*h);
        if (stddev == 0) {
            // A constant channel cannot be normalized; write zeros instead.
            vDSP_vfill(&stddev, channel, numComponents, w*h);
        }
        // Scale by gamma and shift by beta for the selected style mix.
        vDSP_vsmul(channel, numComponents, &gamma[i], channel, numComponents, w*h);
        vDSP_vsadd(channel, numComponents, &beta[i], channel, numComponents, w*h);
    }
    // Narrow back to half floats and write every slice back into the texture.
    [self floatToHalf:temp halfp:htemp width:w height:h channel:channels];
    for (NSUInteger i = 0; i < numSlices; i++) {
        [image.texture replaceRegion:MTLRegionMake3D(0, 0, 0, w, h, 1) mipmapLevel:0 slice:i withBytes:htemp+w*h*numComponents*i bytesPerRow:w*numComponents*2 bytesPerImage:0];
    }
    free(temp);
    free(htemp);
    free(gamma);
    free(beta);
}
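The two vDSP_mmul calls in the method above are what make the style selectable, so here is a minimal plain-C sketch of just that step. It assumes the parameter table is laid out as styleNum rows of featureNum values, which is how the method reads it; the function name blend_style_params is hypothetical.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical helper: out[c] = sum over s of weights[s] * table[s * channels + c].
   table is style_count x channels, row-major. */
static void blend_style_params(const float *weights, const float *table,
                               size_t style_count, size_t channels, float *out)
{
    assert(style_count > 0 && channels > 0);
    for (size_t c = 0; c < channels; c++) {
        float acc = 0.0f;
        for (size_t s = 0; s < style_count; s++) {
            acc += weights[s] * table[s * channels + c];
        }
        out[c] = acc;
    }
}
```

With one-hot weights such as {0, 1} the result is simply the second style's row, and with fractional weights such as {0.5, 0.5} it is a mix of the two rows, which is how a single set of convolution weights can render different styles or blend between them.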
Nearest-Neighbor
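Apple does not provide this fill algorithm directly either. BlitCommandEncoder has some related methods, but they seemed awkward to use for what is a very simple algorithm. Since BN already runs on the CPU and this step is only called twice, I implemented it on the CPU as well.
The principle is simple: after the image is enlarged, the pixels around each original pixel are filled with that pixel's value.
The implementation: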
- (void)ResizeNearestNeighbor:(MPSImage *)source destinationImage:(MPSImage *)destinationImage
{
    NSUInteger w = source.texture.width;
    NSUInteger h = source.texture.height;
    NSUInteger w2 = destinationImage.texture.width;
    NSUInteger h2 = destinationImage.texture.height;
    NSUInteger featureNum = source.featureChannels;
    // An MPSImage packs its feature channels four to a texture slice.
    NSUInteger numSlices = (featureNum + 3) / 4;
    NSUInteger numComponents = featureNum < 3 ? featureNum : 4;
    NSUInteger channels = featureNum < 3 ? featureNum : numSlices * 4;
    float16_t *htemp1 = calloc(w*h*channels, sizeof(float16_t));
    float16_t *htemp2 = calloc(w2*h2*channels, sizeof(float16_t));
    // Copy every source slice out of the texture.
    for (int i = 0; i < numSlices; i++) {
        [source.texture getBytes:htemp1+w*h*numComponents*i bytesPerRow:w*numComponents*2 bytesPerImage:0 fromRegion:MTLRegionMake3D(0, 0, 0, w, h, 1) mipmapLevel:0 slice:i];
    }
    // Source/destination size ratios in 16.16 fixed point.
    int x_ratio = (int)((w<<16)/w2) +1;
    int y_ratio = (int)((h<<16)/h2) +1;
    int x2, y2;
    for (int k = 0; k < featureNum; k++) {
        int slice = k / 4;
        int stride = k % 4;
        for (int i = 0; i < h2; i++) {
            for (int j = 0; j < w2; j++) {
                // Source pixel (x2, y2) that destination pixel (j, i) copies from.
                x2 = ((j*x_ratio)>>16);
                y2 = ((i*y_ratio)>>16);
                htemp2[slice*w2*h2*numComponents+(i*w2+j)*numComponents+stride] = htemp1[slice*w*h*numComponents+((y2*w)+x2)*numComponents+stride];
            }
        }
    }
    // Write every enlarged slice into the destination texture.
    for (int i = 0; i < numSlices; i++) {
        [destinationImage.texture replaceRegion:MTLRegionMake3D(0, 0, 0, w2, h2, 1) mipmapLevel:0 slice:i withBytes:htemp2+w2*h2*numComponents*i bytesPerRow:w2*numComponents*2 bytesPerImage:0];
    }
    free(htemp1);
    free(htemp2);
}
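Stripped of the texture copies, the resize itself is only a few lines. The following is a minimal plain-C sketch, not the code from my project: it works on float buffers, uses plain integer division instead of the 16.16 fixed-point ratios above, and the name resize_nearest is hypothetical.

```c
#include <assert.h>
#include <stddef.h>

/* Nearest-neighbor resize of an interleaved image with `components` values per pixel.
   Destination pixel (x, y) copies source pixel (x * src_w / dst_w, y * src_h / dst_h). */
static void resize_nearest(const float *src, size_t src_w, size_t src_h,
                           float *dst, size_t dst_w, size_t dst_h, size_t components)
{
    assert(src_w > 0 && src_h > 0 && dst_w > 0 && dst_h > 0 && components > 0);
    for (size_t y = 0; y < dst_h; y++) {
        size_t sy = y * src_h / dst_h;
        for (size_t x = 0; x < dst_w; x++) {
            size_t sx = x * src_w / dst_w;
            const float *from = src + (sy * src_w + sx) * components;
            float *to = dst + (y * dst_w + x) * components;
            for (size_t c = 0; c < components; c++) {
                to[c] = from[c];
            }
        }
    }
}
```

Doubling a 2x2 image this way turns every source pixel into a 2x2 block of the same value, which is exactly the fill described above.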
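Full network implementation
Finally, here is the whole network, following the structure and connection order in the paper. All the convolutions are objects of classes that inherit from MPSCNNConvolution. The code is a bit long: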
- (MPSImage *)forward:(CGImageRef)srcImage width:(int)width height:(int)height styles:(float *)styles
{
    id<MTLCommandBuffer> commandbuffer = [commandQueue commandBuffer];
    int w = width;
    int h = height;
    MTKTextureLoader *loader = [[MTKTextureLoader alloc] initWithDevice:mtDevice];
    id<MTLTexture> srcTexture = [loader newTextureWithCGImage:srcImage options:nil error:nil];
    MPSImage *cc1Image = [[MPSImage alloc] initWithTexture:srcTexture featureChannels:3];
    // MPSImage *tImage = [[MPSImage alloc] initWithTexture:srcTexture featureChannels:3];
    // MPSImageDescriptor *cc1Des = [MPSImageDescriptor imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat16 width:w height:h featureChannels:3];
    // MPSImage *cc1Image = [[MPSImage alloc] initWithDevice:mtDevice imageDescriptor:cc1Des];
    // [cc1Image.texture replaceRegion:MTLRegionMake3D(0, 0, 0, w, h, 1) mipmapLevel:0 withBytes:srcImage bytesPerRow:w*4*2];
    // contract
    MPSImageDescriptor *cc2Des = [MPSImageDescriptor imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat16 width:w height:h featureChannels:32];
    MPSImage *cc2Image = [[MPSImage alloc] initWithDevice:mtDevice imageDescriptor:cc2Des];
    [contractConv1 encodeToCommandBuffer:commandbuffer sourceImage:cc1Image destinationImage:cc2Image device:mtDevice];
    [commandbuffer commit];
    [commandbuffer waitUntilCompleted];
    [self batch_norm:cc2Image styles:styles shift:cc1Shift];
    commandbuffer = [commandQueue commandBuffer];
    [relu encodeToCommandBuffer:commandbuffer sourceImage:cc2Image destinationImage:cc2Image];
    w /= 2;
    h /= 2;
    MPSImageDescriptor *cc3Des = [MPSImageDescriptor imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat16 width:w height:h featureChannels:64];
    MPSImage *cc3Image = [[MPSImage alloc] initWithDevice:mtDevice imageDescriptor:cc3Des];
    [contractConv2 encodeToCommandBuffer:commandbuffer sourceImage:cc2Image destinationImage:cc3Image device:mtDevice];
    [commandbuffer commit];
    [commandbuffer waitUntilCompleted];
    [self batch_norm:cc3Image styles:styles shift:cc2Shift];
    commandbuffer = [commandQueue commandBuffer];
    [relu encodeToCommandBuffer:commandbuffer sourceImage:cc3Image destinationImage:cc3Image];
    w /= 2;
    h /= 2;
    MPSImageDescriptor *rcDes = [MPSImageDescriptor imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat16 width:w height:h featureChannels:128];
    MPSImage *rc11Image = [[MPSImage alloc] initWithDevice:mtDevice imageDescriptor:rcDes];
    [contractConv3 encodeToCommandBuffer:commandbuffer sourceImage:cc3Image destinationImage:rc11Image device:mtDevice];
    [commandbuffer commit];
    [commandbuffer waitUntilCompleted];
    [self batch_norm:rc11Image styles:styles shift:cc3Shift];
    // residual
    commandbuffer = [commandQueue commandBuffer];
    [relu encodeToCommandBuffer:commandbuffer sourceImage:rc11Image destinationImage:rc11Image];
    MPSImage *rc12Image = [[MPSImage alloc] initWithDevice:mtDevice imageDescriptor:rcDes];
    [residual1Conv1 encodeToCommandBuffer:commandbuffer sourceImage:rc11Image destinationImage:rc12Image device:mtDevice];
    [commandbuffer commit];
    [commandbuffer waitUntilCompleted];
    [self batch_norm:rc12Image styles:styles shift:rc11Shift];
    commandbuffer = [commandQueue commandBuffer];
    [relu encodeToCommandBuffer:commandbuffer sourceImage:rc12Image destinationImage:rc12Image];
    MPSImage *rc21Image = [[MPSImage alloc] initWithDevice:mtDevice imageDescriptor:rcDes];
    [residual1Conv2 encodeToCommandBuffer:commandbuffer sourceImage:rc12Image destinationImage:rc21Image device:mtDevice];
    [commandbuffer commit];
    [commandbuffer waitUntilCompleted];
    [self batch_norm:rc21Image styles:styles shift:rc12Shift];
    [self addImage:rc11Image B:rc21Image C:rc21Image];

    commandbuffer = [commandQueue commandBuffer];
    MPSImage *rc22Image = [[MPSImage alloc] initWithDevice:mtDevice imageDescriptor:rcDes];
    [residual2Conv1 encodeToCommandBuffer:commandbuffer sourceImage:rc21Image destinationImage:rc22Image device:mtDevice];
    [commandbuffer commit];
    [commandbuffer waitUntilCompleted];
    [self batch_norm:rc22Image styles:styles shift:rc21Shift];
    commandbuffer = [commandQueue commandBuffer];
    [relu encodeToCommandBuffer:commandbuffer sourceImage:rc22Image destinationImage:rc22Image];
    MPSImage *rc31Image = [[MPSImage alloc] initWithDevice:mtDevice imageDescriptor:rcDes];
    [residual2Conv2 encodeToCommandBuffer:commandbuffer sourceImage:rc22Image destinationImage:rc31Image device:mtDevice];
    [commandbuffer commit];
    [commandbuffer waitUntilCompleted];
    [self batch_norm:rc31Image styles:styles shift:rc22Shift];
    [self addImage:rc21Image B:rc31Image C:rc31Image];

    commandbuffer = [commandQueue commandBuffer];
    MPSImage *rc32Image = [[MPSImage alloc] initWithDevice:mtDevice imageDescriptor:rcDes];
    [residual3Conv1 encodeToCommandBuffer:commandbuffer sourceImage:rc31Image destinationImage:rc32Image device:mtDevice];
    [commandbuffer commit];
    [commandbuffer waitUntilCompleted];
    [self batch_norm:rc32Image styles:styles shift:rc31Shift];
    commandbuffer = [commandQueue commandBuffer];
    [relu encodeToCommandBuffer:commandbuffer sourceImage:rc32Image destinationImage:rc32Image];
    MPSImage *rc41Image = [[MPSImage alloc] initWithDevice:mtDevice imageDescriptor:rcDes];
    [residual3Conv2 encodeToCommandBuffer:commandbuffer sourceImage:rc32Image destinationImage:rc41Image device:mtDevice];
    [commandbuffer commit];
    [commandbuffer waitUntilCompleted];
    [self batch_norm:rc41Image styles:styles shift:rc32Shift];
    [self addImage:rc31Image B:rc41Image C:rc41Image];

    commandbuffer = [commandQueue commandBuffer];
    MPSImage *rc42Image = [[MPSImage alloc] initWithDevice:mtDevice imageDescriptor:rcDes];
    [residual4Conv1 encodeToCommandBuffer:commandbuffer sourceImage:rc41Image destinationImage:rc42Image device:mtDevice];
    [commandbuffer commit];
    [commandbuffer waitUntilCompleted];
    [self batch_norm:rc42Image styles:styles shift:rc41Shift];
    commandbuffer = [commandQueue commandBuffer];
    [relu encodeToCommandBuffer:commandbuffer sourceImage:rc42Image destinationImage:rc42Image];
    MPSImage *rc51Image = [[MPSImage alloc] initWithDevice:mtDevice imageDescriptor:rcDes];
    [residual4Conv2 encodeToCommandBuffer:commandbuffer sourceImage:rc42Image destinationImage:rc51Image device:mtDevice];
    [commandbuffer commit];
    [commandbuffer waitUntilCompleted];
    [self batch_norm:rc51Image styles:styles shift:rc42Shift];
    [self addImage:rc41Image B:rc51Image C:rc51Image];

    commandbuffer = [commandQueue commandBuffer];
    MPSImage *rc52Image = [[MPSImage alloc] initWithDevice:mtDevice imageDescriptor:rcDes];
    [residual5Conv1 encodeToCommandBuffer:commandbuffer sourceImage:rc51Image destinationImage:rc52Image device:mtDevice];
    [commandbuffer commit];
    [commandbuffer waitUntilCompleted];
    [self batch_norm:rc52Image styles:styles shift:rc51Shift];
    commandbuffer = [commandQueue commandBuffer];
    [relu encodeToCommandBuffer:commandbuffer sourceImage:rc52Image destinationImage:rc52Image];
    MPSImage *temp = [[MPSImage alloc] initWithDevice:mtDevice imageDescriptor:rcDes];
    [residual5Conv2 encodeToCommandBuffer:commandbuffer sourceImage:rc52Image destinationImage:temp device:mtDevice];
    [commandbuffer commit];
    [commandbuffer waitUntilCompleted];
    [self batch_norm:temp styles:styles shift:rc52Shift];
    [self addImage:rc51Image B:temp C:temp];
    // upsampling
    commandbuffer = [commandQueue commandBuffer];
    w *= 2;
    h *= 2;
    MPSImageDescriptor *ec1Des = [MPSImageDescriptor imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat16 width:w height:h featureChannels:128];
    MPSImage *ec1Image = [[MPSImage alloc] initWithDevice:mtDevice imageDescriptor:ec1Des];
    [self ResizeNearestNeighbor:temp destinationImage:ec1Image];
    MPSImageDescriptor *temp2Des = [MPSImageDescriptor imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat16 width:w height:h featureChannels:64];
    MPSImage *temp2 = [[MPSImage alloc] initWithDevice:mtDevice imageDescriptor:temp2Des];
    [expandConv1 encodeToCommandBuffer:commandbuffer sourceImage:ec1Image destinationImage:temp2 device:mtDevice];
    [commandbuffer commit];
    [commandbuffer waitUntilCompleted];
    [self batch_norm:temp2 styles:styles shift:ec1Shift];
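    commandbuffer = [commandQueue commandBuffer];
    [relu encodeToCommandBuffer:commandbuffer sourceImage:temp2 destinationImage:temp2];
    // The CPU resize below reads temp2 directly, so the ReLU has to finish on the GPU first.
    [commandbuffer commit];
    [commandbuffer waitUntilCompleted];
    commandbuffer = [commandQueue commandBuffer];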
    w *= 2;
    h *= 2;
    MPSImageDescriptor *ec2Des = [MPSImageDescriptor imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat16 width:w height:h featureChannels:64];
    MPSImage *ec2Image = [[MPSImage alloc] initWithDevice:mtDevice imageDescriptor:ec2Des];
    [self ResizeNearestNeighbor:temp2 destinationImage:ec2Image];
    MPSImageDescriptor *ec3Des = [MPSImageDescriptor imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat16 width:w height:h featureChannels:32];
    MPSImage *ec3Image = [[MPSImage alloc] initWithDevice:mtDevice imageDescriptor:ec3Des];
    [expandConv2 encodeToCommandBuffer:commandbuffer sourceImage:ec2Image destinationImage:ec3Image device:mtDevice];
    [commandbuffer commit];
    [commandbuffer waitUntilCompleted];
    [self batch_norm:ec3Image styles:styles shift:ec2Shift];
    commandbuffer = [commandQueue commandBuffer];
    [relu encodeToCommandBuffer:commandbuffer sourceImage:ec3Image destinationImage:ec3Image];
    MPSImageDescriptor *destDes = [MPSImageDescriptor imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat16 width:w height:h featureChannels:3];
    MPSImage *destImage = [[MPSImage alloc] initWithDevice:mtDevice imageDescriptor:destDes];
    [expandConv3 encodeToCommandBuffer:commandbuffer sourceImage:ec3Image destinationImage:destImage device:mtDevice];
    [commandbuffer commit];
    [commandbuffer waitUntilCompleted];
    [self batch_norm:destImage styles:styles shift:ec3Shift];
    commandbuffer = [commandQueue commandBuffer];
    [sigmoid encodeToCommandBuffer:commandbuffer sourceImage:destImage destinationImage:destImage];
    [commandbuffer commit];
    [commandbuffer waitUntilCompleted];
    return destImage;
}
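The addImage:B:C: helper called at the end of every residual block is not listed in this post. Assuming it is an element-wise sum C = A + B over the values copied out of the textures, the arithmetic can be sketched in plain C as below; the name add_images and the float buffers are assumptions for illustration, not the actual helper.

```c
#include <assert.h>
#include <stddef.h>

/* Element-wise sum c[i] = a[i] + b[i]. c may alias a or b, which matches
   the calls above where the result is written back into the second image. */
static void add_images(const float *a, const float *b, float *c, size_t count)
{
    assert(a != NULL && b != NULL && c != NULL);
    for (size_t i = 0; i < count; i++) {
        c[i] = a[i] + b[i];
    }
}
```

On iOS the same loop could be handed to vDSP_vadd from the Accelerate framework, in keeping with the vDSP calls already used for BN.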
Conclusion
I have nothing more to say ☠️, so here are a few screenshots of the app running 😊.