用tensorflow实现拉普拉斯金字塔(Laplacian pyramid)的上采样

前言

之前在写代码的时候遇到了需要拉普拉斯金字塔(Laplacian pyramid)。图片生成金字塔的时候是python代码,但是生成金字塔以后需要用tensor处理,也就是需要在tensor状态下一层一层恢复金字塔。看这篇简书的你应该已经知道什么是拉普拉斯金字塔。他的upsampling不只是简单的tf.image.resize()或者dilation。这个就有点费劲了。当时卡了我很久,stackoverflow上也找不到有人问这个问题。后来做出来了一个东西倒是能正确的upsampling,但是在bp的时候不能返传gradient。。。那还训练个P啊。。。终于经过研究,勤学苦练的我(手动滑稽)终于找到了用tensor方程实现upsampling的方法。在这里分享出来希望对别人有帮助。

背景介绍

我用的是python平台的tensorflow。tensorflow建立graph之后seas.run()的时候是不能运行python代码的。如果想把python代码写入tensor,硬来的话有一个方法,就是用tf.py_func()(现在tf.py_func已经被deprecated了,我当时用的时候还有)。他可以把python方程转换成tensorflow方程。但是有个弊端。我们知道tensorflow是用自己定义的方程tf.啥()运作的。他自己定义的方程,当然他自己也知道怎么返传gradient。如果自己定义的python方程转换成tensorflow方程,系统肯定不知道你定义的东西怎么返传呀。所以tf.py_func上游的variable们都没法更新了。所以我的办法是霸王硬上弓,生用tensorflow的方程一步步实现upsampling。

步骤

先回忆一下laplacian upsampling downsampling是怎么弄的。我的downsampling是用opencv里的方程cv2.pyrdown()。现在目的是用tensorflow实现cv2.pyrup()

  1. 将原图行列,和外面一圈填入0,如下图所示。


    Intuition of dilation
  2. 用4倍于下采样的高斯滤波器进行一次卷积。

Opencv下采样用的高斯滤波器长这样:

Gaussian kernel in Opencv cv2.pyrDown()

那么乘以4,再卷积就好了。tf.nn.conv2d()就可以卷积。现在问题是什么方法能方便的把一张图像行列填0。
由于图片的大小是不确定的,用数学上矩阵相乘的方式是不推荐的(也可能有时候那样的矩阵就不存在,不知道我的线性代数不太好。。)。那么直接一点的方法就是用写代码的思维来做。这里找到了一个可以把图像的列填入0的方法。

    | ? ? ? |        | ? 0 ? 0 ? |
A = | ? ? ? |  --->  | ? 0 ? 0 ? |
    | ? ? ? |        | ? 0 ? 0 ? |

代码是这样:

# Input
a = tf.constant(np.arange(9).reshape(3,3), tf.float32)
#[[0. 1. 2.]
# [3. 4. 5.]
# [6. 7. 8.]]

# 创建一个和a一样大小的全是0的矩阵
b = tf.zeros_like(a)
c = tf.reshape(tf.stack([a,b], 2),
               [-1, tf.shape(a)[1]+tf.shape(b)[1]])[:,:-1]


with tf.Session() as sess:
   print(sess.run(c))
#[[0. 0. 1. 0. 2.]
# [3. 0. 4. 0. 5.]
# [6. 0. 7. 0. 8.]]

a和b都是二维矩阵,这里tf.stack([a,b], 2)会拓展出第三维,并把0矩阵沿着第三维放到a后面。

Explain tf.stack

[-1, tf.shape(a)[1]+tf.shape(b)[1]]这个是[-1, 6]。也就是把stack的3d矩阵3x3x2=18个元素reshape成6列。3x6=18,自然也就是3行了。reshape的时候会自动把第三维的0矩阵穿插到第一个矩阵之间。

Reshape

最后[:,:-1]意思是去掉最后一列。
Discard the last column

有了这个就好办了。虽然不知道怎么穿插0到行里面,但是tensorflow里有转置矩阵的方法tf.transpose()。这样行变成列以后再干一遍,不就都有0了嘛。
Insert 0s between rows

最后,最外面再padding一圈就可以了,我用的reflect padding,最后最后恢复的效果比padding 0好一些。
Reflect padding

这是上面所有的可执行代码:

def dilatezeros(imgs):
    zeros = tf.zeros_like(imgs)
    column_zeros = tf.reshape(tf.stack([imgs, zeros], 2), [-1, tf.shape(imgs)[1] + tf.shape(zeros)[1]])[:,:-1]

    row_zeros = tf.transpose(column_zeros)

    zeros = tf.zeros_like(row_zeros)
    dilated = tf.reshape(tf.stack([row_zeros, zeros], 2), [-1, tf.shape(row_zeros)[1] + tf.shape(zeros)[1]])[:,:-1]
    dilated = tf.transpose(dilated)

    paddings = tf.constant([[0, 1], [0, 1]])
    dilated = tf.pad(dilated, paddings, "REFLECT")

    dilated = tf.expand_dims(dilated, axis=0)
    dilated = tf.expand_dims(dilated, axis=3)
    return dilated

第二步卷积高斯kernel就简单了。先创建出上面那个downsampling用的kernel:

def call2dtensorgaussfilter():
    return tf.constant([[1./256., 4./256., 6./256., 4./256., 1./256.],
                        [4./256., 16./256., 24./256., 16./256., 4./256.],
                        [6./256., 24./256., 36./256., 24./256., 6./256.],
                        [4./256., 16./256., 24./256., 16./256., 4./256.],
                        [1./256., 4./256., 6./256., 4./256., 1./256.]])

再应用上去:

def applygaussian(imgs):
    gauss_f = call2dtensorgaussfilter()
    gauss_f = tf.expand_dims(gauss_f, axis=2)
    gauss_f = tf.expand_dims(gauss_f, axis=3)

    result = tf.nn.conv2d(imgs, gauss_f * 4, strides=[1, 1, 1, 1], padding="VALID")
    result = tf.squeeze(result, axis=0)
    result = tf.squeeze(result, axis=2)
    return result

padding='VALID'意思是卷积时候如果矩阵长度不整除stride的话会丢掉矩阵剩余的部分。padding还有一个parameter是SAME。意思是不丢掉,不整除的话会自动padding到整除为止,再做卷积。这里有关于padding更详细的讲解。咱们stride是1,所以不存在不整除。
按照上面的做法做一次就恢复了一层。但是laplacian pyramid一般都不会只分两层。如果多余2层怎么办呢?其实只要有一个循环重复上面的动作就可以了。tensorflow里还真有循环操作tf.while_loop。这个怎么用就不详细讲了,官方document里有介绍。CSDN里也有介绍的帖子,这个这个stackoverflow里也有例子。总之是要写两个方程,一个是condition,一个是loop的body。condition很简单:

def cond(output_bot, i, n):
    return tf.less(i, n)

loop里叫上面的dilatezerosapplygaussian就行了。只是要注意因为我在卷积的时候用的padding='VALID',所以卷积结束会小两圈,所以loop里要再padding一下,才能保证结果的大小不变。

# funcs for tf.while_loop ====================================
def body(output_bot, i, n):
    paddings = tf.constant([[0, 0], [2, 2], [2, 2], [0, 0]])
    output_bot = dilatezeros(output_bot)
    output_bot = tf.pad(output_bot, paddings, "REFLECT")
    output_bot = applygaussian(output_bot)
    return output_bot, tf.add(i, 1), n

这样第二步也完成了。 下面是完整的代码:

上采样完整可执行代码

def call2dtensorgaussfilter():
    return tf.constant([[1./256., 4./256., 6./256., 4./256., 1./256.],
                        [4./256., 16./256., 24./256., 16./256., 4./256.],
                        [6./256., 24./256., 36./256., 24./256., 6./256.],
                        [4./256., 16./256., 24./256., 16./256., 4./256.],
                        [1./256., 4./256., 6./256., 4./256., 1./256.]])

def applygaussian(imgs):
    gauss_f = call2dtensorgaussfilter()
    gauss_f = tf.expand_dims(gauss_f, axis=2)
    gauss_f = tf.expand_dims(gauss_f, axis=3)

    result = tf.nn.conv2d(imgs, gauss_f * 4, strides=[1, 1, 1, 1], padding="VALID")
    result = tf.squeeze(result, axis=0)
    result = tf.squeeze(result, axis=2)
    return result

def dilatezeros(imgs):
    zeros = tf.zeros_like(imgs)
    column_zeros = tf.reshape(tf.stack([imgs, zeros], 2), [-1, tf.shape(imgs)[1] + tf.shape(zeros)[1]])[:,:-1]

    row_zeros = tf.transpose(column_zeros)

    zeros = tf.zeros_like(row_zeros)
    dilated = tf.reshape(tf.stack([row_zeros, zeros], 2), [-1, tf.shape(row_zeros)[1] + tf.shape(zeros)[1]])[:,:-1]
    dilated = tf.transpose(dilated)

    paddings = tf.constant([[0, 1], [0, 1]])
    dilated = tf.pad(dilated, paddings, "REFLECT")

    dilated = tf.expand_dims(dilated, axis=0)
    dilated = tf.expand_dims(dilated, axis=3)
    return dilated

# funcs for tf.while_loop ====================================
def body(bottom, i, n):
    paddings = tf.constant([[0, 0], [2, 2], [2, 2], [0, 0]])
    bottom = dilatezeros(bottom)
    bottom = tf.pad(bottom, paddings, "REFLECT")
    bottom = applygaussian(bottom)
    return bottom, tf.add(i, 1), n

def cond(bottom, i, n):
    return tf.less(i, n)

用法

这几个代码怎么用呢?首先要设定一个循环次数,比如要upsampling3次,就n=tf.constant(3), i=tf.constant(0)。然后叫

# 注意这里是tensor操作,所以n和i不能是scaler,也得是tensor才行
bottom, i, n = tf.while_loop(cond, body, [bottom, i, n],  shape_invariants=[tf.TensorShape([None, None]), i.get_shape(),n.get_shape()])

还有一个问题。一般在训练数据的时候,tensor默认的格式是[batchsize, height, width, channel]。上面做的是batch里的一张图的恢复,如果要batch里每张图都恢复怎么做呢?很好办,用python代码写一个loop,把每张图从batch里slice出来,再传给tf.while_loop()就好了。这里有我的一篇专门理解tf.slice()的文章。下面是可执行的slice循环代码:

# 计算得到bottom的size,用你的自己方法去找h和w。lev_scale是我的金字塔的层数
h, w = calshape(height, width, lev_scale)
# dynamic shape到static shape的转换
tfbot_upsampling = tf.reshape(bottom, [config.train.batch_size_ft, h, w])

new_bottom = 0
for index in range(config.train.batch_size_ft):
    # 切出一张图
    fullsize_bottom = tf.squeeze(tf.slice(tfbot_upsampling, [index, 0, 0], [1, -1, -1]))

    i = tf.constant(0)
    n = tf.constant(int(lev_scale))
    fullsize_bottom, i, n = tf.while_loop(cond, body, [fullsize_bottom,i,n], shape_invariants=[tf.TensorShape([None, None]), i.get_shape(),n.get_shape()])
    fullsize_bottom = tf.expand_dims(fullsize_bottom, axis=0)
    if index == 0:
        new_bottom = fullsize_bottom 
    else:
       # 恢复tensor的shape,按batch再concat回去
        new_bottom = tf.concat([new_bot, fullsize_bottom], axis=0)

注意

有两点要注意:

  1. Opencv里laplacian pyramid的恢复是要在灰度图下进行的,也就是channel=1。所以彩色图需要用cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)变成灰度图,再进行金字塔运算。
  2. 这个代码本意只是把最底层的高斯层恢复到最大,而不是把金子塔恢复到原图。所以中间没有➕恢复一次以后上一层的laplacian特征。如果想恢复整个金字塔的盆友可以在这个代码上改一下,在body里加上对应的laplacian层就可以了,应该很简单。
  3. 在进行upsampling操作之前,tensorflow必须知道你传入的图片的具体size,不能是dynamic shape(就是不能是[bs,?,?,1])。所以要把size数值管理好,传入之前reshape成相应的大小。
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 194,390评论 5 459
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 81,821评论 2 371
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 141,632评论 0 319
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 52,170评论 1 263
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 61,033评论 4 355
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 46,098评论 1 272
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 36,511评论 3 381
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 35,204评论 0 253
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 39,479评论 1 290
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 34,572评论 2 309
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 36,341评论 1 326
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 32,213评论 3 312
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 37,576评论 3 298
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 28,893评论 0 17
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 30,171评论 1 250
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 41,486评论 2 341
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 40,676评论 2 335

推荐阅读更多精彩内容