利用 TensorFlow 和 MNIST 数据集演示 GAN 的构建

自打关注深度学习这个领域就不时的看到和 Generative Adversarial Network, GAN 相关的东西,也一直非常好奇这个被 LeCun 称为深度学习近年来最大的突破的东西到底是什么样子的。正好在 Udacity 的课堂里遇到了,在完成了通过 GAN 来完成人脸生成的项目后,在这里做一个总结,加深一下对于 GAN 这个网络的理解。为了便于本地试验,这里展示的是利用 MNIST 数据集来训练一个简单的 GAN 来生成手写数字的过程。注意文中代码和示例图片来自 Udacity 深度学习纳米学位课程,版权归 Udacity 所有。

深度神经网络最令人诟病一点就在于其决策过程的不可解释性,你无从知道网络中的单元提取了哪些特征来完成了一项分类或识别任务。比如在图片识别任务中,即便你可以提取隐藏层的 feature map 来可视化出来相应层的情况,其图像在人类看来是抽象而诡异甚至有些惊悚的。这一点其实在我看来是十分正常的,也不应该像很多媒体的解读方式那样过分的夸大,事实上,人脑的加工过程有谁可以可视化出来呢?只不过我们对于人类行为的可预测性是有把握的,所以不像对于新生技术那样容易催生恐惧。

而 GAN 最为聪明之处在于既然人类无法理解网络内部的生成过程,索性不用人脑和人类对于图像的理解方式去理解中间过程,而是用另一个类似结构的神经网络,二者的相互理解过程也就是对抗 Adversarial 的过程。其实现的大致思路是:

  • 作为生成器的一个典型代表,GAN 的一个典型应用是通过模型来生成类似已有数据集的图片来实现数据扩增,因此可以首先建立一个通过多层神经网络实现的生成器,其主要作用是通过对于符合一定分布规律的原始数据进行处理,进而得到一个符合另一特定分布情况的结果图像。这里要求这个网络至少包含一个隐藏层,否则网络就不具有足够的学习和泛化能力,这个网络在 GAN 中被称为生成器 Generator。例如在下面的示例图片中,生成器的输入是符合某个分布特征的随机数字:在后续的代码示例中采用的是 (-1, 1) 之间的均匀分布

  • 在获得了生成器之后,还要建立一个类似结构的可以完成图像识别任务的分类器,其特殊之处在于这个网络的输出层只对输入是来自原始数据集还是由生成器网络生成的结果做一个真假判断,这个网络在 GAN 中称为识别器 Discriminator

High level overview of GAN with MNIST

在看到代码之前我一直以为 GAN 的实现会比较复杂,但真正看到代码之后就像看到 E = mc2 一样,发现其是如此的简洁,优雅,直观,不得不佩服 Ian Goodfellow 强大的思路。闲话到此为止,网络架构和实现代码如下:

Network Architecture
%matplotlib inline
import pickle as pkl
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

# load data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data')

# define the model input for both Generator and Discirminator
def model_inputs(real_dim, z_dim):
    inputs_real = tf.placeholder(tf.float32, (None, real_dim), name='input_real') 
    inputs_z = tf.placeholder(tf.float32, (None, z_dim), name='input_z')
    
    return inputs_real, inputs_z

# define the Generator
def generator(z, out_dim, n_units=128, reuse=False, alpha=0.01):
    with tf.variable_scope('generator', reuse=reuse):
        # Hidden layer
        h1 = tf.layers.dense(z, n_units, activation=None)
        # Leaky ReLU
        h1 = tf.maximum(alpha * h1, h1)
        
        # Logits and tanh output
        logits = tf.layers.dense(h1, out_dim, activation=None)
        out = tf.tanh(logits)
        
        return out

# define the Discriminator
def discriminator(x, n_units=128, reuse=False, alpha=0.01):
    with tf.variable_scope('discriminator', reuse=reuse):
        # Hidden layer
        h1 = tf.layers.dense(x, n_units, activation=None)
        # Leaky ReLU
        h1 = tf.maximum(alpha * h1, h1)
        
        logits = tf.layers.dense(h1, 1, activation=None)
        out = tf.sigmoid(logits)
        
        return out, logits

这里之所以要定义这个 variable_scope 是由于在后续的训练中,需要分别更新生成器和判别器的参数,为了提取参数而特别设置的。另外值得注意的是,激活函数需要采用 Leaky ReLU 来保证梯度可以从判别器传回到生成器。

# build the network
tf.reset_default_graph()
# Create our input placeholders
input_real, input_z = model_inputs(input_size, z_size)

# Build the model
g_model = generator(input_z, input_size, n_units=g_hidden_size, alpha=alpha)
# g_model is the generator output

d_model_real, d_logits_real = discriminator(input_real, n_units=d_hidden_size, alpha=alpha)
d_model_fake, d_logits_fake = discriminator(g_model, reuse=True, n_units=d_hidden_size, alpha=alpha)

# Calculate losses
d_loss_real = tf.reduce_mean(
                  tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real, 
                                                          labels=tf.ones_like(d_logits_real) * (1 - smooth)))
d_loss_fake = tf.reduce_mean(
                  tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake, 
                                                          labels=tf.zeros_like(d_logits_real)))
d_loss = d_loss_real + d_loss_fake

g_loss = tf.reduce_mean(
             tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                                     labels=tf.ones_like(d_logits_fake)))

在这里新引入的一个操作是 label smoothing,其目的在于适度的放低要求以促进收敛。而针对损失函数这部分,由于希望判别器将真实数据识别为 1, 而将生成器生成的数据识别为 0,因此需要分别计算这两部分的损失函数。

# Optimizers
learning_rate = 0.002

# Get the trainable_variables, split into G and D parts
t_vars = tf.trainable_variables()
g_vars = [var for var in t_vars if var.name.startswith('generator')]
d_vars = [var for var in t_vars if var.name.startswith('discriminator')]

d_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list=d_vars)
g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)

这一段代码非常重要,正式因为选择了间歇性的训练才使得网络的对抗得以实现。

# Size of input image to discriminator
input_size = 784
# Size of latent vector to generator
z_size = 100
# Sizes of hidden layers in generator and discriminator
g_hidden_size = 128
d_hidden_size = 128
# Leak factor for leaky ReLU
alpha = 0.01
# Smoothing 
smooth = 0.1

下面代码部分为比较常见的训练代码结构:

batch_size = 100
epochs = 100
samples = []
losses = []
# Only save generator variables
saver = tf.train.Saver(var_list=g_vars)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for e in range(epochs):
        for ii in range(mnist.train.num_examples//batch_size):
            batch = mnist.train.next_batch(batch_size)
            
            # Get images, reshape and rescale to pass to D
            batch_images = batch[0].reshape((batch_size, 784))
            batch_images = batch_images*2 - 1
            
            # Sample random noise for G
            batch_z = np.random.uniform(-1, 1, size=(batch_size, z_size))
            
            # Run optimizers
            _ = sess.run(d_train_opt, feed_dict={input_real: batch_images, input_z: batch_z})
            _ = sess.run(g_train_opt, feed_dict={input_z: batch_z})
        
        # At the end of each epoch, get the losses and print them out
        train_loss_d = sess.run(d_loss, {input_z: batch_z, input_real: batch_images})
        train_loss_g = g_loss.eval({input_z: batch_z})
            
        print("Epoch {}/{}...".format(e+1, epochs),
              "Discriminator Loss: {:.4f}...".format(train_loss_d),
              "Generator Loss: {:.4f}".format(train_loss_g))    
        # Save losses to view after training
        losses.append((train_loss_d, train_loss_g))
        
        # Sample from generator as we're training for viewing afterwards
        sample_z = np.random.uniform(-1, 1, size=(16, z_size))
        gen_samples = sess.run(
                       generator(input_z, input_size, n_units=g_hidden_size, reuse=True, alpha=alpha),
                       feed_dict={input_z: sample_z})
        samples.append(gen_samples)
        saver.save(sess, './checkpoints/generator.ckpt')

# Save training generator samples
with open('train_samples.pkl', 'wb') as f:
    pkl.dump(samples, f)

为了监控训练,可以提取训练过程中的参数来识别训练结果。实际上在学习过程中可以发现 GAN 的训练对于超参数的选择十分敏感,并且在后续的 DCGAN 学习中,作者们甚至通过调整 Adam 中的指数加权平均参数 beta1 来实现较好的训练效果。Ian Goodfellow 在 Andrew Ng 的访谈里也提到自己现在 40% 的时间话在研究如何 Stablize GAN,当时没理解是什么意思,直到自己训练了 DCGAN 之后才知道原来 GAN 的训练对于超参数是如此的敏感。

def view_samples(epoch, samples):
    fig, axes = plt.subplots(figsize=(7,7), nrows=4, ncols=4, sharey=True, sharex=True)
    for ax, img in zip(axes.flatten(), samples[epoch]):
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        im = ax.imshow(img.reshape((28,28)), cmap='Greys_r')
    
    return fig, axes

rows, cols = 10, 6
fig, axes = plt.subplots(figsize=(7,12), nrows=rows, ncols=cols, sharex=True, sharey=True)

for sample, ax_row in zip(samples[::int(len(samples)/rows)], axes):
    for img, ax in zip(sample[::int(len(sample)/cols)], ax_row):
        ax.imshow(img.reshape((28,28)), cmap='Greys_r')
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
Generated result as the training goes

参考阅读

  1. Tips and tricks to make GANs work

  2. Generative Adversarial Networks for beginners

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 206,378评论 6 481
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 88,356评论 2 382
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 152,702评论 0 342
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 55,259评论 1 279
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 64,263评论 5 371
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,036评论 1 285
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,349评论 3 400
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,979评论 0 259
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 43,469评论 1 300
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,938评论 2 323
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,059评论 1 333
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,703评论 4 323
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,257评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,262评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,485评论 1 262
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,501评论 2 354
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,792评论 2 345

推荐阅读更多精彩内容