本文主要讲解有关生成对抗网络(GAN)的相关知识。
一、判别模型和生成模型
机器学习中的模型一般有两种:1. 决策函数 Y=f(X);2. 条件概率分布 P(Y|X)
根据通过学习数据来获取这两种模型的方法,可以分为判别方法和生成方法。判别方法是由数据直接学习决策函数或条件概率分布作为预测模型,即判别模型;而生成模型是由数据学习联合概率分布 P(X,Y),然后由 P(Y|X)=p(X,Y)/P(X) 求出概率分布 P(Y|X) 作为预测模型,即生成模型。
二、生成对抗网络(GAN)
生成对抗网络(GAN)启发自博弈论中的零和博弈(即两人的利益之和为零,一方的所得正式另一方的所失),GAN 模型由生成模型(generative model)和对抗模型(discriminative model)组成。生成模型 G 捕捉样本数据的分布,用服从某一分布的噪声 z 来生成一个类似真是训练数据的样本,追求效果越像真实样本越好;判别模型 D 是一个二分类器,估计一个样本来自训练数据(而非生成数据)的概率。生成器就类似于造假币的人,而判别器就类似于验钞机,生成器的目的就是其造的假币要骗过验钞机。
GAN 的目标函数:
$$
\min_G\max_DV(D,G)=E_{x-p_{data}(x)}[\log D(x)]+E_{z-p_z(x)}[\log(1-D(G(z)))]
$$
其中 D(x) 表示真实数据通过判别器 D 的输出结果,而 D(G(z)) 是噪声 z 通过生成器 G 生成的假数据通过判别器 D 的输出结果。所以对于判别器 D 来说,需要最大化目标函数,即通过判别器让真币为真的概率越大越好,而假币为真的概率越小越好。对于生成器 G 来说,需要最小化目标函数,即让假币越真越好。
在训练的时候需要先训练判别器,再训练生成器。
令 $C(G)=\max_DV(G,D)$,则可以通过推导得到:$C(G)=-\log(4)+2\cdot JSD(p_{data}||p_g)$,其中 JSD 表示 Jensen-Shannon divergence,即 JS 散度。$p_{data}$ 是真实数据的分布,而 $p_g$ 是生成的假数据的分布。
GAN 存在训练过程不稳定的问题,这一方面是因为 GAN 自身的缺陷,另一方面是因为生成器和判别器的能力不匹配;此外生成器只会生成一两种类别的样本。
GAN 的一个改进是 WGAN。当真实数据的分布和假数据的分布互不重叠时,JS 散度值会趋近于一个常数,其导数接近于0,这就导致了梯度消失。所以重新定义了一种 Wasserstein-1 距离来代替原来的 JS 散度:
$$
W(P_r,P_g)=\inf_{\gamma-\Pi(P_r,P_g)}E_{(x,y)}[||x-y||]
$$
即使 $P_r$ 和 $P_g$ 互不重叠,wasserstein 距离依旧可以清楚的反应两个分布的距离。
目标函数也变为了:
$$
\max_{f_w}E_{x-P_r}[f_w(x)]-E_{z-P_z}[f_w(G(z))]
$$
$$
\min_G-E_{z-P_z}[f_w(G(z))]
$$
WGAN 很好的解决了训练不稳定和模式崩溃的问题。
GAN 只能随机产生一个类别,CGAN 可以指定类别来生成。
- 本文作者: 俎志昂
- 本文链接: zuzhiang.cn/2020/02/14/GAN/
- 版权声明: 本博客所有文章除特别声明外,均采用 Apache License 2.0 许可协议。转载请注明出处!