这个 demo 展示了如何将图片转换成印象派风格, 非常有趣, 本文将从数据流的角度分析作者是如何做的.
Demo 的具体步骤和效果可见 GitHub.
从图片生成训练数据
读取每个像素 RGB 的值, 并归一化.
下图是由(0,0)坐标
像素生成的 Pixel
对象, RGB 的值都是由原值除以 255 得到的.
Pixel -> Example
接着通过 pixelToExample
方法, 每个像素会被转换为一个 Example
对象.
RGB 分别对应3个 FeatureVector
, 其中 C
是 channel 的缩写, 存储了对应的颜色, $target
的值会作为输出参与训练.
$target
这个键名是在 image_impressionism.conf 中配置的, 该文件还记录了数据路径/模型参数/特征转换方法和参数
等其他配置.
下图是一个 Example
实例.
这是完全展开后转成 json 格式的示例.
{
"example": [
{
"stringFeatures": {
"C": [
"Red"
]
},
"floatFeatures": {
"$target": {
"": 0.5529411764705883
}
}
},
{
"stringFeatures": {
"C": [
"Green"
]
},
"floatFeatures": {
"$target": {
"": 0.5725490196078431
}
}
},
{
"stringFeatures": {
"C": [
"Blue"
]
},
"floatFeatures": {
"$target": {
"": 0.5568627450980392
}
}
}
],
"context": {
"floatFeatures": {
"LOC": {
"X": 0,
"Y": 0
}
}
}
}
存储
最后用 thrift 将特征对象序列化后压缩存储.
sc.parallelize(pixels)
.map(x => pixelToExample(x, true))
.map(Util.encode)
.saveAsTextFile(output, classOf[GzipCodec])
训练模型
特征转换
所有特征会按顺序依次应用3类 transform: context_transform
, item_transform
, combined_transform
.
每类 transform 的名称及相关参数都需要在 image_impressionism.conf
中配置.
Context Transform
Context transform 会转换 Example
中的 context
属性, 并将新特征存入 context.stringFeatures
.
quantize_pixel_location {
transform: multiscale_grid_quantize
# Grid up the pixels into squares of the following sizes.
# Use relatively prime grids to create jitter.
buckets : [ 3.0, 7.0, 17.0, 31.0, 47.0, 67.0, 79.0, 89.0, 97.0 ]
field1: "LOC"
value1: "Y"
value2: "X"
output: "QLOC"
}
此处用的是 multiscale_grid_quantize
方法, 实现该方法的类为 MultiscaleGridQuantizeTransform.
该方法将二维平面划分成不同大小的正方形格子, 然后将每个格子里的点都映射到该格子.
即用更大的粒度来描述平面, 借此消除局部差异, 提取共同特征.
格子的 ID 由其边长和左上角点的坐标拼接生成.
例如 [3.0]=(0.0,0.0)
包含了 (0,0),(0,1),(0,2),(1,0),(1,1),(1,2),(2,0),(2,1),(2,2)
9个点.
buckets
配置了会用哪些边长的格子, field1/value1/value2
可结合之前的特征示例体会含义, output
是新特征的名字.
下图是坐标(0,0)
转换后得到的9个新特征.
Item Transform
Item transform 将转换 Example.example
中每个 FeatureVector
, 转换后的新特征会存入 stringFeatures
.
identity_transform {
transform: list
transforms: []
}
list 表示将逐个应用 transforms 列表中的变换, 空列表意味着不做任何转换.
Combined Transform
Combined transform 将会把 context
和 example 中每个 FeatureVector
结合起来, 并存入后者的 stringFeatures
中.
代码实现分为两步:
- 拷贝
context.stringFeatures
至FeatureVector.stringFeatures
. - 对
example
中每个FeatureVector
的stringFeatures
应用配置中指定的 transform.
C_X_QLOC {
transform: cross
field1: "C" // Color channel
field2: "QLOC" // Quantized location
output: "C_x_QLOC"
}
combined_transform {
transform: list
transforms: [
C_X_QLOC
]
}
cross 对应的类为 CrossTransform, 它会把 field1/field2
的值拼接起来作为 output
的值.
需要注意的是:
- 转换前每个
Example
对象中example
属性有3个FeatureVector
. - 转换后3个
FeatureVector
将分别转换为1个Example
对象, 每个Example
对象的example
属性只有1个FeatureVector
.
下图为转换后的一个 Example
对象:
训练
Aerosolve 的训练算法都是基于 Spark 实现的, 所以和训练有关的代码都放在一个独立的子项目 training 中.
这个 demo 用的是线性回归模型, 训练方法为 SGD (Stochastic Gradient Descent), 代码实现在 LinearRankerTrainer 中.
每次迭代计算权重前, 会从 FeatureVector
取出相关的目标值和特征, 如下图:
权重训练完后的形式为 ((feature family, feature), weigth)
, 如下图:
保存模型
模型由两部分组成: 一个 ModelHeader
和 若干个 ModelRecord
(ModelHeader
实际上会保存为成一个特殊的 ModelRecord
).
线性模型的 ModelHeader
只用到了两个属性: modelType
的值会设置为 linear
, numRecords
会设为 weights.size
.
模型也是用 thrift 序列化后存储.
应用模型输出印象派图片
宏观视角
前面从微观角度观察了整个数据流, 接着我们从宏观的角度看看它和线性模型是怎么对接的.
线性模型实质是一个方程组, 训练权重的过程即求解自变量系数的过程.
该 demo 的方程组共有 num_pixels * 3
个方程, 每个像素会对应3个方程,
这是因为每个像素有3个 color channel (red/green/blue).
方程的因变量 y 即 r/g/b 归一化后的值.
自变量的个数 = 不同特征的总个数
, 自变量只有0和1两种取值, 0表示该方程中不含此特征, 1表示包含.
特征有3类:
- Red, Green, Blue
- 离散化后的坐标, 例如:
[3.0]=(0.0,0.0)
- 前两类的交叉组合, 例如:
Red^[3.0]=(0.0,0.0)
第2类特征总数 num_loc
和像素个数有关.
num_loc = sum([(int(image_width / b) + 1) * (int(image_height / b) + 1) for b in buckets])
总的特征个数 = 3 + num_loc + num_loc * 3 = num_loc * 4 + 3
每个方程只有少数特征对应的自变量取值为1, 所以不同类特征的影响范围是不同的.
- 第1类会影响 1/3 的方程.
- 第2, 3类特征只会影响和格子中像素有关的方程.
- 第2类会影响
边长^2 * 3
个, 第3类为边长^2
个.
单张图片是怎么生成的
输入数据是每个点的坐标, 带入模型后按 color channel 输出预测值, 合并后就得到了该点的 RGB 值.
其过程就像在大方格上摞小方格, 最终的高度即预测值.
这样的预测结果肯定好于只做单一划分的方法, 即包含了整体信息, 也包含了局部差异.
我觉得 airbnb 预测房价也应该是类似的想法.
动图是怎么生成的
生成每一帧的方法和单张图片一样, 只是每帧用到的权重个数不一样.
假设总共有 N
个权重, 第 i
帧只会用前 i/(N-1)
个来绘制图像, i∈{0, 1, ..., N-1}
.
这样图片就会渐渐的由模糊变清晰.
其他
- Readme 中对最红, 最蓝的解释不太恰当, 详见 Google Group.
- 项目还在发展阶段, 从新旧代码的质量就能看出来. 例如会看到一些复制粘贴的实现.
- 代码中有一些重复计算的问题. 例如理论上 context 只会计算一次, 但实际会计算多次, 不过我觉得影响不大.
- 还未成为性能瓶颈. 我测试了修改后的速度, 并未提升多少. 该demo的时间多花在文件读取上, 每轮迭代约5分钟, 约一半时间是在读文件.
- 新代码中重复计算的情况有所改善, 说明作者知道这个事情.
- 项目还在发展期, 不需要过早优化.
Debug 的小技巧
- 换一个像素较少的图片, 这样会大大节省每个步骤的时间.
- 换图后需修改
image_impressionism.conf
中make_impression
的宽高, 让生成的图片大小更合适. - 将
build.gradle
中 spark 的依赖从 provided 改为 compile. - 将
JobRunner.scala
作为 debug 的入口, 记得加上.setMaster("local")
. -
com.airbnb.aerosolve:training
的版本有点低, 可更换为最新版. - Ubuntu 用户如不想编译安装 thrift, 可用 docker thrift. Mac 用户可用 brew 安装.