本文是 GAN 在 MNIST 数据集上生成假的手写数字图片的一个实例,具体是用 pytorch 实现的。
先来看下训练的结果,下面两张图中,上面的是真实手写数字图片,下面的是在训练了30个 epoch 之后 GAN 生成的假图片。整体来说,效果还是蛮不错的。
GAN 由一个生成器和一个对抗器组成,在该任务中,生成器的输入是一堆随机生成的噪音,其输出为生成的假图片,其目标是让生成的假图片尽可能的像真实图片;而判别器的输入是一张图片,其输出是这张图片是真实图片的概率。
在训练时,需要先训练判别器再训练生成器,如果判别器的好坏决定着生成器的效果。在生成器训练时,先将一堆噪音输入到生成器并得到假图片,然后再将假图片输入到判别器进行判别,然后将判别结果与真实标签(注意不是假标签,因为生成器的目标是尽可能的模拟真实图片)进行比对形成损失函数。判别器的训练分为两个部分,一是对真实图片进行判别,二是对假图片进行判别,得到的判别结果分别与真实标签和假标签对比形成真实图片的损失和假图片的损失,两者相加就是判别器的总损失。
这个代码有个蛋疼的地方是,如果用 keras 引入 MNIST 数据集,则在训练时损失函数会很快就趋近于 0 了,训练效果很差,而用 torchvision 时则没这个问题。
1 | import os |
- 本文作者: 俎志昂
- 本文链接: zuzhiang.cn/2020/02/21/GAN-example/
- 版权声明: 本博客所有文章除特别声明外,均采用 Apache License 2.0 许可协议。转载请注明出处!