常用生成模型 -- GAN

1、GAN简介

GAN,对抗生成网络, 包括两个网络,即生成器G与判别器D。它们的功能分别是:

(1)G是一个生成图片的网络,它接收一个随机的噪声z,通过这个噪声生成图片,记做G(z)。相当于decoder。

(2)D是一个判别网络,判别一张图片是不是“真实的”。它的输入参数是x,x代表一张图片,输出D(x)代表x为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表不可能是真实的图片。相当于encoder。

在训练过程中,生成网络G的目标就是尽量生成真实的图片去欺骗判别网络D。而D的目标就是尽量把G生成的图片和真实的图片分别开来。这样,G和D构成了一个动态的“博弈过程”。在最理想的状态下,G可以生成足以“以假乱真”的图片G(z)。对于D来说,它难以判定G生成的图片究竟是不是真实的,因此D(G(z)) = 0.5。这样我们就得到了一个生成式的模型G,它可以用来生成图片。

 

2、GAN博弈过程

第1代生成器输入随机噪声,输出一张生成的图片,和真实图片一起喂入第1代鉴别器,鉴别器可以判断出这个样本是来自生成器还是真实图片,然后第2代生成器调节参数,生成一张更为真实的图片,继续喂给第1代鉴别器,第1代鉴别器无法判断这个样本是来自生成器还是真实图片,于是第1代鉴别器调节参数,有了第2代鉴别器,以此类推,直到鉴别器无法判断这个样本是来自生成器还是真实图片。

鉴别器工作:输入一张图像,判断图像和真实图像是否属于同一个数据分布,输出real或fake。

生成器工作:输入随机噪声,生成图像,喂给鉴别器判断真伪,鉴别器判断出来是fake,于是生成器继续生成更真实的图片,直到鉴别器判断结果为1,也就是认为生成器生成的图像是真实图像。

 

3、GAN原理

(1)极大似然估计 -- 已知结果,倒推出现这种结果最可能的参数。

(2)生成器通过函数G尽可能学习一个分布,使得它非常接近数据真实分布,鉴别器尽可能学习到之间的差异,GAN的目标就是求得,D的训练目标是使两个分布之间的差异最大,G的训练目标是使得两个分布之间的差异最小。所以我们的优化目标就是:

(3)生成器可以迭代很多种,这里假设我们只有三种:G1、G2、G3,得到每种生成器G下,D的变化曲线如图,就是在3种生成器G下,找到一个D使得真实数据和生成数据的分布差异最大(图中红色点位置的D),就是找到一个G,使得差异最小,也就是G3。

(4)数据之间的分布差异是什么?

(交叉熵损失函数)

 

4、损失函数

这个公式看似复杂,其实只要我们理解了GAN的博弈过程,就可以很清楚的了解这个公式的含义了,我们直到,GAN是单独交替迭代训练的,所以这个目标函数也是分别对判别器和生成器进行优化的,首先对判别器进行优化,表达形式如下:

其中D(X)表示对真实的样本进行判别,这里,我们希望它的判别结果越接近于1越好,所以损失函数为log(D(x)),而z是随机的输入,G(z)表示生成的样本,对于生成的样本,我们希望判别器的判别结果D(G(z))越接近于0越好,也就是让总数值最大,所以总体表达形式如上所示。

在完成对判别模型的优化之后,便是对生成模型进行优化,在这里,生成模型的优化很简单,只需要让判别的结果D(G(z))接近于1就可以了,也就是让总数值最小。

 

5、实践中的GAN

在每次迭代中:

(1)训练D:从真实数据分布中采样m个样本,从学习到数据分布中采样m个噪声样本,用噪声样本放入生成器中生成m个图像,梯度下降更新参数,使得两个分布之间的距离最大。

(2)训练G:从学习到数据分布中采样m个噪声样本,梯度下降更新参数,使得两个分布之间的距离最小。

(3)D训练多次,G只训练一次。,因为G不一定产生的图片总是朝着更有利的方向,G的频繁更新会导致模型不稳定。

实际训练过程中,我们会用代替。因为训练初始时,梯度大有利于加快训练速度,而训练快结束时,我们希望梯度小,使得模型不要发生震荡。有图可知,的图像刚好符合我们的预期。

 

6、GAN不稳定的原因

(1)很难使得G和D同时收敛,更新参数使得G梯度下降的同时,可能引起D的模型更差。

(2)生成器G发生模式崩溃,对于不同的输入生成类似的图片。

(3)生成器梯度消失问题,如果D很快达到最优,而G生成的图像和真实分布相差很远,此时G很难继续更新到最优。

 

7、GAN为什么会产生模式崩溃现象?

目前DNN只能预测连续分布,而源数据分布往往是具有间断点的非连续分布,所以在训练过程中,DNN无法学习到具有间断点的非连续分布。如果目标概率测度的支集具有多个联通分支,GAN训练得到的又是连续映射,则有可能连续映射的值域集中在某一个连通分支上,这就是模式崩溃(mode collapse),如果强行用一个连续映射来覆盖所有的连通分支,那么这一连续映射的值域必然会覆盖之外的一些区域,即GAN会生成一些没有现实意义的图片。

模式崩溃更多内容参见:https://blog.csdn.net/qq_32172681/article/details/99676858

 

8、代码解读

代码地址:https://github.com/ikostrikov/TensorFlow-VAE-GAN-DRAW

(1)生成器(decoder)

 

(2)鉴别器(encoder)

 

(3)损失函数计算

生成器的目标是愚弄辨别器蒙混过关,需要达到的目标是对于生成的图片,输出为1(正好和鉴别器相反)。

辨别器对假数据的损失原理相同,最终达到的目标是对于所有的真实图片,输出为1;对于所有的假图片,输出为0。 

代码如下:

def __get_discrinator_loss(self, D1, D2):'''Loss for the discriminator networkArgs:D1: logits computed with a discriminator networks from real imagesD2: logits computed with a discriminator networks from generated imagesReturns:Cross entropy loss, positive samples have implicit labels 1, negative 0s'''return (losses.sigmoid_cross_entropy(D1, tf.ones(tf.shape(D1))) +losses.sigmoid_cross_entropy(D2, tf.zeros(tf.shape(D1))))def __get_generator_loss(self, D2):'''Loss for the genetor. Maximize probability of generating images thatdiscrimator cannot differentiate.Returns:see the paper'''return losses.sigmoid_cross_entropy(D2, tf.ones(tf.shape(D2)))

(4)整体结构

 

 

参考文章地址:

https://www.cnblogs.com/baiting/p/8314936.html

https://www.jianshu.com/p/40feb1aa642a

https://blog.csdn.net/wang2008start/article/details/76443576

 


本文来自互联网用户投稿,文章观点仅代表作者本人,不代表本站立场,不承担相关法律责任。如若转载,请注明出处。 如若内容造成侵权/违法违规/事实不符,请点击【内容举报】进行投诉反馈!

相关文章

立即
投稿

微信公众账号

微信扫一扫加关注

返回
顶部