想从Tensorflow循环生成对抗网络开始。但是发现从最难的内容入手还是?太复杂了所以搜索了一下他的始祖也就是深度卷积生成对抗网络。
https://mlnotebook.github.io/post/GAN2/
推荐这个给力的博客我们的派送程序分为四个部分。
gantut_imgfuncs.py: 一个部分是图像数据的读入处理
gantut_datafuncs.py: 第二个部分实现卷积神经网络
gantut_gan.py:第三个部分实现生成对抗网络
gantut_trainer.py:第四个部分,把所有程序。合到一块儿
数据
数据处理部分我们需要先下载数据
将图像数据处理到模型能处理的大小
使用数据集 CelebA databse. 这里是我们提供的下载链接: https://drive.google.com/drive/folders/0B7EVK8r0v71pTUZsaXdaSnZBZzg.你需要进入文件夹下载 “img_align_celeba.zip” . 直接下载连接是
Download and extract this folder into ~/GAN/raw_images to find it contains 200,000+ examples of celebrity faces. Even though the .zip says ‘align’ in the name, we still need to resize the images and thus may need to realign them too.
数据预处理
为了处理这个量级的图片我们需要一个自动化的方法来重新改变他的图片大小。和裁剪图片你说我们使用 OpenFace.
打开终端进入你现在的工作目录。或者创建新的工作目录。我使用~/GAN,并且按照下面的指示 git clone Openface :
cd~/GAN
git clone https://github.com/cmusatyalab/openface.git openface
进入openface 文件夹 , 安装要求(其实requirements.txt中), so do this:
cd./openface
sudo pip install -r requirements.txt
安装完成,现在我们安装python模块
./models/get-models.sh
搞定以后
sudo pip install --upgrade scipy
现在处理图像的工具装好了。对图像的处理是非常关键的。他保证我们我们处理的图像,都是一个型号,而且网络可以好好的训练它好好的。如果一张图片的眼睛在最底部,另外一张图片的脸又在最上方。那这样的话就不方便训练了。
在工作目录`~/GAN’:
./openface/util/align-dlib.py ./raw_images align innerEyesAndBottomLip ./aligned --size 64
这将会整理好所有的图片,并把它们放在 ./raw_images。中然后将他们裁剪到 64 x 64大小 并且在./aligned中。这会花很长的时间(200,000+ 图片!).
现在我们来写一下函数进行图像处理。和数据的正则化。
之前我们已经下载了数据,进行了数据的预处理在进行下一步之前,我们的文件应该长成这样。
~/GAN
|- raw
|-- 00001.jpg
|-- ...
|- aligned
|-- 00001.jpg
|-- ...
|- gantut_imgfuncs.py
|- gantut_datafuncs.py
|- gantut_gan.py
|- gantut_trainer.py
图像函数
我们需要读入一个图像集。我们也需要能够输出产生的图像。另外我们希望能够安全的做图像的转换。和图像裁剪保证我们有正确的输入格式。
读入图像
这些函数时使得我们能够从硬盘读取我们所需要的数据:
1. get_image which calls
2. imread and
3. transform which calls
4. center_crop
imread() 函数
我们处理的是标准格式的图像。我们的网络支持.jpg, .jpeg , .png 作为输入
我们使用scipy.misc 库的 scipy.misc.imread 函数
Inputs
path: 图像的路径
Returns
图像
""" 读取图像 (part of get_image function)
"""
def imread(path):
return scipy.misc.imread(path, mode='RGB').astype(np.float)
transform() [to top][100]
我们这个函数是为了让图像数据都能够变成同样大小。所以我们应该读取图像(image)和我们想要它变成的大小(desired width),还有他需不需要被裁减whether to perform the cropping or not。也许我们已经裁减了我们的图像。因为我们已经做过注册/整理( registration/alignment) etc.
我们检查需不需要裁剪图像。就是看看我们有没有调用center_crop函数。没有的话我们就原来的图像了
在返回我们裁剪的图像之前我们做正则化处理。一般来说,像素的每个通道值得范围都在[0 255][0 255] 之间。最好不要让这个特点破坏你的数据。所以我们要把数据进行正则化处理。把他弄到[−1 1][−1 1] 之间。 方法很简单,就是减去 255/2=127.5, 再减去一。
接下来我们定义裁剪图像的函数 cropping function, 注意返回的图象只是一个数组。
Inputs
image: 要转换的图像数据
npx: 要转换的图像数据的大小。 [npx x npx]
is_crop: 要不要裁剪图像 [True or False]
Returns
裁剪以及正则化了的图像。
""" Transforms the image by cropping and resizing and
normalises intensity values between -1 and 1
"""
def transform(image, npx=64, is_crop=True):
if is_crop:
cropped_image = center_crop(image, npx)
else:
cropped_image = image
return np.array(cropped_image)/127.5-1.
center_crop()
如果requested,我们便进行图像的裁剪处理。因为我们处理的是正方形图像,[64×64][64×64]大小。 不如加入一个快速的选择。We can add a quick option to change that with short if statements looking at the crop_w argument to this function. 我们的输入是现在的图像和图像的宽和高
为了找到正方形的中心。我们取宽和高的一半记住,我们必须要对这个结果取整,才能得到一个确定的像素值。虽然这并不能保证我们得到一个完美的[64×64][64×64]图像我们后面再来讲这个事情吧。
我们还是用这个包scipy , 这次用scipy.misc.imresize takes in an image array and the desired size and outputs a resized image. 我们让他读书,我们的数组(虽然由于原图像的大小不一,他不一定是一个完美的正方形图片)然后 imresize 会进行推测保证我们的新图片最后变得还挺不错得。
Inputs
x, crop_h, crop_w, resize_w
Returns
the cropped image
"""
Crops the input image at the centre pixel
"""
def center_crop(x, crop_h, crop_w=None, resize_w=64):
if crop_w is None:
crop_w = crop_h
h, w = x.shape[:2]
j = int(round((h - crop_h)/2.))
i = int(round((w - crop_w)/2.))
return scipy.misc.imresize(x[j:j+crop_h, i:i+crop_w], [resize_w, resize_w])
get_image()
这个写了的话就可以一步到位找到我的图像而不用分开调用数个函数了。他会调用 imread 和 transform 函数。
Parameters
is_crop:
Inputs
image_path, image_size:(输出图像的大小)
Returns
the cropped image
"""
Loads the image and crops it to 'image_size'
"""
def get_image(image_path, image_size, is_crop=True):
return transform(imread(image_path), image_size, is_crop)
Saving Functions
在我们训练网络的时候。我们希望看到一些结果。之前都不住都是把图像从硬盘里,放到网络中去。现在我们想要从网络里拿出来一些深沉的图像。函数大概有这些:
1. save_images which calls
2. inverse_transform and
3. imsave which calls
4. merge
inverse_transform()
首先我们要把像素变回图片。先从 [−1 1][−1 1] 变 [0 1][0 1]
Inputs
images: the image to be transformed
Returns
the transformed image
"""
This turns the intensities back to a normal range
"""
def inverse_transform(images):
return(images+1.)/2.
merge()
我们创建一个例子图像数组这样我们能够时不时地从网络中看到一些图像例子。就能知道整个训练过程进行的怎么样了。我们需要输入图像并且告诉电脑我们需要多少个输出图像
首先我们要知道图像的宽和高。 注意图像是一个图像集合。他们用同样的宽和高。
我们把定义为最后的图像数组。并且把它初始化成全零的矩阵。对于RGB图像他有三列
然后我们遍历图像集中的每一个图像。%能够让我们得到她的子集。 // 这个符号能够对除法结果取整。
Inputs
images: the set of input images
size: [height, width] of the array
Returns
an array of images as a single image
"""
Takes a set of 'images' and creates an array from them.
"""
def merge(images, size):
h, w = images.shape[1], images.shape[2]
img = np.zeros((int(h * size[0]), int(w * size[1]),3))
for idx, image in enumerate(images):
i = idx % size[1]
j = idx // size[1]
img[j*h:j*h+h, i*w:i*w+w, :] = image
return img
imsave()
现在我们把像素变到正常的范围 [0 255][0 255] 之后再用这个scipy.misc.imsave变回图像
Inputs
images: the set of input images
size: [height, width] of the array
path: the save location
Returns
an image saved to disk
""" Takes a set of `images` and calls the merge function. Converts
the array to image data and saves to disk.
"""
def imsave(images, size, path):
img = merge(images, size)
return scipy.misc.imsave(path, (255*img).astype(np.uint8))
save_images()
Finally, let’s create the wrapper to pull this together:
Inputs
images: the images to be saves
size: the size of the img array [width height]
image_path: where the array is to be stored on disk
"""
takes an image and saves it to disk. Redistributes
intensity values [-1 1] from [0 255]
"""
def save_images(images, size, image_path):
return imsave(inverse_transform(images), size, image_path)
import numpy as np
import scipy.misc
THE GAN: 下面有三个部分
dataset_files(),
GAN Class( _init_(), discriminator(), generator(),build_model(),save(),load(),train() ),
Data Functions ( batch_norm(), conv2d(), relu(),linear(), conv2d_transpose())
from __future__ import division
import os
import time
import math
import itertools
from glob import glob
import tensorflow as tf
import numpy as np
from six.moves import xrange
#IMPORT OUR IMAGE AND DATA FUNCTIONS
from gantut_datafuncs import *
from gantut_imgfuncs import *
dataset_files()
这个文件主要就是一些housekeeping - 保证我们只处理支持的文件类型。这种办法,我比较欣赏B. Amos blog 的做法。我们定义的支持的文件类型。然后我们就返回可用于训练的文件清单
itertools.chain.from_iterable 函数可以用来在文件夹中产生一个含所有文件的清单
SUPPORTED_EXTENSIONS = ["png", "jpg", "jpeg"]
"""
Returns the list of all SUPPORTED image files in the directory
"""
def dataset_files(root):
return list(itertools.chain.from_iterable(
glob(os.path.join(root, "*.{}".format(ext))) for ext in SUPPORTED_EXTENSIONS))
DCGAN()
这里先做一个 DCGAN 类 (i.e. Deep Convolutional Generative Adversarial Network). 我们需要create:
__init__: to initialise the model and set parameters
build_model: creates the model (or ‘graph’ in TensorFlow-speak) by calling…
generator: defines the generator network
discriminator: defines the discriminator network
train: is called to begin the training of the network with data
save: saves the TensorFlow checkpoints of the GAN
load: loads the TensorFlow checkpoints of the GAN
We create an instance of our GAN class with DCGAN(args) and be returned a DCGAN object with the above methods. Let’s code.
__init__()
初始化 GAN对象, 我们需要一些初始化参数 :
def__init__(self, sess, image_size=64, is_crop=False, batch_size=64, sample_size=64, z_dim=100, gf_dim=64, df_dim=64, gfc_dim=1024, dfc_dim=1024, c_dim=3, checkpoint_dir=None, lam=0.1):
参数有:
sess: the TensorFlow session to run in
image_size: the width of the images, which should be the same as the height as we like square inputs
is_crop: whether to crop the images or leave them as they are
batch_size: number of images to use in each run
sample_size: number of z samples to take on each run, should be equal to batch_size
z_dim: number of samples to take for each z (什么是z sample)
gf_dim: dimension of generator filters in first conv layer
df_dim: dimenstion of discriminator filters in first conv layer
gfc_dim: dimension of generator units for fully-connected layer
dfc_gim: dimension of discriminator units for fully-connected layer
c_dim: number of image cannels (gray=1, RGB=3)
checkpoint_dir: where to store the TensorFlow checkpoints
lam: small constant weight for the sum of contextual and perceptual loss
这些都是GAN的可控变量。我们需要把这些输入传送到类的self里, 这样才能access到他们。我们还需要添加两行:
1.添加一个验证环节确保image_size 是2的次方。这里采用位运算方法 。 另外验证一个图像大小是不是大于 [8×8][8×8]
2. 取 image_shape (图像宽和高还有通道数 (gray or RBG).
#image_size must be power of 2 and 8+
assert(image_size & (image_size - 1) == 0 and image_size >= 8)
self.sess = sess
self.is_crop = is_crop
self.batch_size = batch_size
self.image_size = image_size
self.sample_size = sample_size
self.image_shape = [image_size, image_size, c_dim]
self.z_dim = z_dim
self.gf_dim = gf_dim
self.df_dim = df_dim
self.gfc_dim = gfc_dim
self.dfc_dim = dfc_dim
self.lam = lam
self.c_dim = c_dim
必须要做‘batch normalisation’ 以保证我们的图像不会相互之间有特别大的区别。我们需要对每一个生成器和鉴别器的卷积层做‘batch normalisation。我们在这里先给他做初始化。但是在gantut_datafuncs.py 里面定义它
#batchnorm (from funcs.py)
self.d_bns = [batch_norm(name='d_bn{}'.format(i,)) for i in range(4)]
log_size = int(math.log(image_size) / math.log(2))
self.g_bns = [batch_norm(name='g_bn{}'.format(i,)) for i in range(log_size)]
从这里看出,鉴别器有四层。对生成器我们需要更多的层次。生成器是从一个简单的向量开始但是需要上采样到 image_size大小。 通过log(image size)/log(2) 可以算出我们需要多少次。上采样log(image size)/log(2) 。2的num of layers次方=64。所以是八层 。 注意我们使用迭代器创造了一些网络层次对象他们的名字是g_bn1, g_bn2 等。
T为了完成这个__init__() 函数我们设置目录checkpoint 以便TensorFlow保存,指导他创建模型并命名为‘DCGAN.model’.
self.checkpoint_dir = checkpoint_dir
self.build_model()
self.model_name="DCGAN.model"
discriminator()
我们需要鉴别器鉴别一个真实的图像,保存变量,然后使用相同的变量来鉴别一个假的图像。这样的话如果一个图像是伪造的,却欺骗了鉴别器,说明生成器工作的不错。
这显然是一个分类任务。我们将通过使用sigmoid 返回一个概率。完整的输出也会被return。
"""输入是图像, kernel (filter)的大小(维度), 还有name(方便以后识别) """
def discriminator(self, image, reuse=False):
with tf.variable_scope("discriminator") as scope:
"""添加 tf.variable_scope() 方便 TensorBoard 可视化"""
if reuse:
scope.reuse_variables()
"""discriminator()方法中使用变量 reuse- 当使用fake images的时候就设置为True """
h0 = lrelu(conv2d(image, self.df_dim, name='d_h00_conv'))
h1 = lrelu(self.d_bns[0](conv2d(h0, self.df_dim*2, name='d_h1_conv'), self.is_training))
h2 = lrelu(self.d_bns[1](conv2d(h1, self.df_dim*4, name='d_h2_conv'), self.is_training))
h3 = lrelu(self.d_bns[2](conv2d(h2, self.df_dim*8, name='d_h3_conv'), self.is_training))
"""4层网络。 注意d_bns objects. 这个必须在non-linear lrelu function 之前"""
h4 = linear(tf.reshape(h3, [-1,8192]),1,'d_h4_lin')
""" 最后一层是线性层"""
return tf.nn.sigmoid(h4), h4
generator()
一个简单的输入-从已知分布pzpz随机采样的vector zz
生成器是一个反向的鉴别器是一个反卷积。从特定的值开始必须采用线性变换才能准备好将它们放入其他的神经网络层 。一开始我们并不知道权重和偏置项。我们需要保证我们从线性层输出这些并添加了 with_w=True."""
def generator(self, z):
with tf.variable_scope("generator") as scope:
self.z_, self.h0_w, self.h0_b = linear(z, self.gf_dim*8*4*4, 'g_h0_lin', with_w=True)
"""第一个隐藏层 hs[0] 需要 reshape 成很小的图像形状的数组这样我们才能把她送入神经网络进行上采样。最后变成[64×64][64×64]的图像。 """
"""所以我们拿已经线性变化过了的z-values ,reshape 到 [4x4xnum kernels][4x4xnum kernels]. 别忘了 -1 to do this for all images in the batch. 和之前一样必须batch-norm 再传入非线性层 """
hs = [None]
hs[0] = tf.reshape(self.z_, [-1, 4, 4, self.gf_dim * 8])
hs[0] = tf.nn.relu(self.g_bns[0](hs[0], self.is_training))
i=1 #iteration number
depth_mul = 8 #depth decreases as spatial component increases
size=8 #size increases as depth decreases
"""在每一层我们做的是:
给layer一个名字 使用 inverse convolution
采用non-linearity"""
""" inverse convolution函数会读入很小的正方形图像进行上采样使用一些我们将学习的权重把它变成更大的图像
我从第一层layer i=1开始这里我没希望他的size=8 而 第0层size=4。所以每一层增加一倍This will increase by a factor of 2 at each layer.
和卷积神经网络中一样。对于更大的图像我们使用更少的过滤器。所以每经过一层 depth_mul 减少一倍
当图像的大小和输入图像一样时,我们终止循环。“”“
while size < self.image_size:
hs.append(None)
name='g_h{}'.format(i)
hs[i], _, _ = conv2d_transpose(hs[i-1], [self.batch_size, size, size, self.gf_dim*depth_mul], name=name, with_w=True)
hs[i] = tf.nn.relu(self.g_bns[i](hs[i], self.is_training))
i += 1
depth_mul //= 2
size *= 2
"""最后一层把最后的output拿来进行反卷积,生成最终的fake image ,然后进入鉴别器测试"""
hs.append(None)
name = 'g_h{}'.format(i)
hs[i], _, _ = conv2d_transpose(hs[i-1], [self.batch_size, size, size, 3], name=name, with_w=True)
return tf.nn.tanh(hs[i])
--------
build_model()
把图像数据生成器鉴别器弄到一起。它包括了一些我们以后会用到的tf.placeholder部分.
我们需要知道模型是在training 还是 inference 。所以设置一个placeholder self.training.
对于图像数据自己也需要一个placeholder 。因为每个epoch会有不同的data batch导入。这些是对于 real_images.
当我们把 z vectors放进GAN (served by another placeholder) 我们也需要使用Tensorboard检测输出。通过加入 tf.summary.histogram()我们可以看到z vectors再每个epoch有多不同
def build_model(self):
self.is_training = tf.placeholder(tf.bool, name='is_training')
self.images = tf.placeholder(tf.float32, [None] + self.image_shape,
name='real_images')
# Real Image self.lowres_images
self.lowres_images = tf.reduce_mean(tf.reshape(self.images, [self.batch_size, self.lowres_size, self.lowres, self.lowres_size, self.lowres, self.c_dim]), [2,4])
self.z = tf.placeholder(tf.float32, [None, self.z_dim], name='z')
self.z_sum = tf.summary.histogram("z", self.z)
接下来我们告诉这个图。怎么使用生成器把z变成图像。我们依然会产生一个lowres 版本。现在把真是的图像放进鉴别器。取出概率和左后一层数据 然后重新用一样的鉴别器参数测试生成器产生的fake image。输出‘real_image’ 和 fake image概率的直方图和现有的生成器产生fake image
self.G = self.generator(self.z)
#fake image self.lowres_G
self.lowres_G = tf.reduce_mean
(tf.reshape(self.G, [self.batch_size, self.lowres_size, self.lowres,
self.lowres_size, self.lowres, self.c_dim]), [2,4])
# real images D, d
self.D, self.D_logits = self.discriminator(self.images)
# fake images D_, d_
self.D_, self.D_logits_ = self.discriminator(self.G, reuse=True)
# real images
self.d_sum = tf.summary.histogram("d", self.D)
# fake images
self.d__sum = tf.summary.histogram("d_", self.D_)
self.G_sum = tf.summary.image("G", self.G)
Now for some of the necessary calculations needed to be able to update the network. Let’s find the ‘loss’ on the current outputs. We will utilise a very efficient loss function here the tf.nn.sigmoid_cross_entropy_with_logits. We want to calculate a few things:
how well did the discriminator do at letting true images through (i.e. comparing D to 1)
how often was the discriminator fooled by the generator (i.e. comparing D_ to 1)
how often did the generator fail at making realistic images (i.e. comparing D_to 0).
We’ll add the discriminator losses up (1 + 2) and create a TensorBoard summary statistic (a scalar value) for the discriminator and generator losses in this epoch. These are what we will optimise during training.
To keep everything tidy, we’ll group the discriminator and generator variables into d_vars and g_vars respectively.
self.d_loss_real = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits, labels=tf.ones_like(self.D)))
self.d_loss_fake = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits_, labels=tf.zeros_like(self.D_)))
self.g_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits_, labels=tf.ones_like(self.D_)))
self.d_loss_real_sum = tf.summary.scalar("d_loss_real", self.d_loss_real) self.d_loss_fake_sum = tf.summary.scalar("d_loss_fake", self.d_loss_fake)
self.d_loss = self.d_loss_real + self.d_loss_fake
self.g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)
self.d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)
t_vars = tf.trainable_variables()
self.d_vars = [var forvarint_varsif'd_'invar.name]
self.g_vars = [varforvarint_varsif'g_'invar.name]
我们不想丢掉我们的进度啊。所以每次我们使用tf.Saver() 函数来保存最近的变量。
self.saver = tf.train.Saver(max_to_keep=1)
save()
我们想保存一个 checkpoint (他保存了我们所学习的所有权重。) 我们调用这个函数检查输出目录是否存在,如果不存在的话,就生成一个输出目录。然后我们调用tf.train.Saver.save() , 他会take现有的session sess, the save directory, model name 和 track the number of steps that’ve been done.
def save(self, checkpoint_dir, step):
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir)
self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name), global_step=step)
load()
我们已经花了很长的时间去练拳种。我们不想每一次我先开始所以这个函数可以帮我们load the most recent checkpoint in the save directory. 如果没有checkpoints的话就返回false。
def load(self, checkpoint_dir):
print(" [*] Reading checkpoints...")
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path: self.saver.restore(self.sess, ckpt.model_checkpoint_path)
return True
else:
return False
train()
下面这个方法是见证奇迹的时刻。我们调用DCGAN.train(config)让网络他们开始博弈。我们把config argument放到后面讲,不过简单来说其实也就是一堆tf用的超参数。train()是这样的:
我们先把数据(using our dataset_files function)给他并保证他被randomly shuffled. 我们需要保证相邻的图像没有共同之处。我们有一个检查的函数assert(len(data) > 0)来保证我们没有把空目录传进去。
def train(self, config):
data = dataset_files(config.dataset)
np.random.shuffle(data)
assert(len(data) >0)
我们用adam 来训练( tf.train.AdamOptimizer() from Kingma et al (2014))在我们设置我们的鉴别器和生成器。 the discriminator (d_optim) and the generator (g_optim).
d_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1).minimize(self.d_loss, var_list=self.d_vars)
g_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1).minimize(self.g_loss, var_list=self.g_vars)
接下来我们初始化所有变量。Next we will initialize all variables in the network (depending on TensorFlow version) 为TensorBoard产生一些tf.summary变量, 把我们想要track的tf summary组合在一起。
try:
tf.global_variables_initializer().run()
except:
tf.initialize_all_variables().run()
self.g_sum = tf.summary.merge([self.z_sum, self.d__sum, self.G_sum, self.d_loss_fake_sum, self.g_loss_sum])
self.d_sum = tf.summary.merge([self.z_sum, self.d_sum, self.d_loss_real_sum, self.d_loss_sum])
self.writer = tf.summary.FileWriter("./logs", self.sess.graph)
现在我们sample已知分布pzpz 来得到向量zz。 使用 np.random.uniform distribution。再tf board上要注意看这个,告诉GAN类要输出 zz的直方图,他看起来应该像uniform 分布
sample已经shuffle了的输入的real image files, 把 sample_size量级的 images 带进training process. 我们将在输出一些例子,调用损失函数的时候经常使用到这些
用 get_image() load数据,确保他们在一个 np.array里.
sample_z = np.random.uniform(-1,1, size=(self.sample_size, self.z_dim))
sample_files = data[0:self.sample_size]
sample = [get_image(sample_file, self.image_size, is_crop=self.is_crop) for sample_file in sample_files]
sample_images = np.array(sample).astype(np.float32)
设置epoch counter , 获得 start time 。确保我们load 先前的TensorFlow checkpoints from TensorFlow before we start again from scratch.
counter =1
start_time = time.time()
if self.load(self.checkpoint_dir):
print(""" An existing model was found - delete the directory or specify a new one with --checkpoint_dir """)
else:
print(""" No model found - initializing a new one""")
现在我们要真正开始训练了。对于每一个epoch we’ve assigned in config我们产生两个 minibatches: 一批真图像一批生成器生成的假图像。然后在更新生成器之前更新鉴别器。这里也把loss values 写到TensorBoard summary里. 注意两点:
1.通过使用具体的变量调用sess.run() with specified variables in the first (or fetch attribute) 生成器可以在更新鉴别器的同时保持稳定,反之亦然。
2.鉴别器被更新了两次。这是为了保证鉴别器的损失函数不会很快的收敛到0。
for epoch in xrange(config.epoch):
data = dataset_files(config.dataset)
batch_idxs = min(len(data), config.train_size) // self.batch_size
for idx in xrange(0, batch_idxs):
batch_files = data[idx*config.batch_size:(idx+1)*config.batch_size]
batch = [get_image(batch_file, self.image_size, is_crop=self.is_crop) for batch_file in batch_files]
batch_images = np.array(batch).astype(np.float32)
batch_z = np.random.uniform(-1,1, [config.batch_size, self.z_dim]).astype(np.float32)
#update D network
_, summary_str = self.sess.run([d_optim, self.d_sum], feed_dict={self.images: batch_images, self.z: batch_z, self.is_training:True})
self.writer.add_summary(summary_str, counter)
#update G network
_, summary_str = self.sess.run([g_optim, self.g_sum], feed_dict={self.z: batch_z, self.is_training:True})
self.writer.add_summary(summary_str, counter)
#run g_optim twice to make sure that d_loss does not go to zero
_, summary_str = self.sess.run([g_optim, self.g_sum], feed_dict={self.z: batch_z, self.is_training:True})self.writer.add_summary(summary_str, counter)
为了得到反向传播需要的errors 我们计算d_loss_fake, d_loss_real and g_loss.
zz 向量穿过graph 得到 fake loss 和 generator loss,
对于real loss使用 real batch_images .
errD_fake = self.d_loss_fake.eval({self.z: batch_z, self.is_training:False})
errD_real = self.d_loss_real.eval({self.images: batch_images, self.is_training:False})
errG = self.g_loss.eval({self.z: batch_z, self.is_training:False})
每100 个 minibatches 评估现在的生成器self.G , 计算一个我们之前sample过的小的图像集的loss。输出生成器的结果,使用save_images() 函数来产生 image array。
counter +=1
print("Epoch [{:2d}] [{:4d}/{:4d}] time: {:4.4f}, d_loss: {:.8f}".format(epoch, idx, batch_idxs, time.time() - start_time, errD_fake + errD_real, errG))
if np.mod(counter,100) ==1:
samples, d_loss, g_loss = self.sess.run([self.G, self.d_loss, self.g_loss], feed_dict={self.z: sample_z, self.images: sample_images, self.is_training:False})
save_images(samples, [8,8],'./samples/train_{:02d}-{:04d}.png'.format(epoch, idx))
print("[Sample] d_loss: {:.8f}, g_loss: {:.8f}".format(d_loss, g_loss))
保存来自网络的现有weights
if np.mod(counter,500) ==2: self.save(config.checkpoint_dir, counter)