import tensorflow as tf
# tf.config.run_functions_eagerly(True) # 调试时使用
class GAN(tf.keras.Model):
def __init__(self, data_dim, latent_dim, hidden_dims1, hidden_dims2):
super(GAN, self).__init__()
self.data_dim = data_dim
self.latent_dim = latent_dim
self.hidden_dims1 = hidden_dims1
self.hidden_dims2 = hidden_dims2
# 定义判别器
input_d = tf.keras.layers.Input(shape=(data_dim,))
if len(hidden_dims1) != 0:
discriminator = tf.keras.layers.Dense(hidden_dims1[0], activation='relu')(input_d)
for dim in hidden_dims1[1:]:
discriminator = tf.keras.layers.Dense(dim, activation='relu')(discriminator)
discriminator = tf.keras.layers.Dense(1, activation='sigmoid')(discriminator)
else:
discriminator = tf.keras.layers.Dense(1, activation='sigmoid')(input_d)
self.discriminator = tf.keras.Model(input_d, discriminator, name='discriminator')
self.discriminator.summary()
# 定义生成器
input_g = tf.keras.layers.Input(shape=(latent_dim,)) # 生成
if len(hidden_dims2) != 0:
generator = tf.keras.layers.Dense(hidden_dims2[0], activation='relu')(input_g)
for dim in hidden_dims2[1:]:
generator = tf.keras.layers.Dense(dim, activation='relu')(generator)
generator = tf.keras.layers.Dense(data_dim, activation='sigmoid')(generator)
else:
generator = tf.keras.layers.Dense(data_dim, activation='sigmoid')(input_g)
self.generator = tf.keras.Model(input_g, generator, name='generator')
self.generator.summary()
# ----优化器和记录d_loss和g_loss
self.d_optimizer = None
self.g_optimizer = None
self.loss_func = None
self.d_loss = None
self.g_loss = None
def compile(self, d_optimizer, g_optimizer, loss_func):
super().compile()
# 优化器和损失函数
self.d_optimizer = d_optimizer
self.g_optimizer = g_optimizer
self.loss_func = loss_func
self.d_loss = tf.keras.metrics.Mean(name="d_loss")
self.g_loss = tf.keras.metrics.Mean(name="g_loss")
def train_step(self, real_datas):
# 从潜在空间进行采样
batch_size = tf.shape(real_datas)[0]
random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
# 通过生成器产生假样本
generated_datas = self.generator(random_latent_vectors)
# 组合真实样本和假样本
combined_datas = tf.concat([generated_datas, real_datas], axis=0)
labels = tf.concat(
[tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
) # 标签0和1,0:真实样本, 1:fake(keras官网是这样)
# Add random noise to the labels - important trick!(keras官网例子这么做)
labels += 0.05 * tf.random.uniform(tf.shape(labels))
# 训练判别器,训练5次判别器
for i in range(5):
with tf.GradientTape() as tape:
predictions = self.discriminator(combined_datas)
d_loss = self.loss_func(labels, predictions)
self.d_loss.update_state(d_loss)
grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
self.d_optimizer.apply_gradients(
zip(grads, self.discriminator.trainable_weights)
)
# 从潜在空间进行采样
random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
# 假定这些样本的都是正常的
misleading_labels = tf.zeros((batch_size, 1))
# 训练生成器,训练1次
with tf.GradientTape() as tape:
predictions = self.discriminator(self.generator(random_latent_vectors))
g_loss = self.loss_func(misleading_labels, predictions)
self.g_loss.update_state(g_loss)
grads = tape.gradient(g_loss, self.generator.trainable_weights)
self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))
# 一次迭代结束
return {
"d_loss": self.d_loss.result(),
"g_loss": self.g_loss.result(),
}
一维变量的GAN
最后编辑于 :
©著作权归作者所有,转载或内容合作请联系作者
- 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
- 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
- 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...