A detailed look at 9 mainstream GAN loss functions, including the original GAN, WGAN-GP, and more, with PyTorch implementations and a discussion of how different losses affect GAN training.

Original title: Mathematical Principles and PyTorch Implementations of 9 Mainstream GAN Loss Functions: From Classic Models to Modern Variants

Original author: 数据派THU
Original content:

This article is about 4,500 words; a five-minute read is recommended.

By analyzing the classic GAN loss function and its many variants in detail, this article highlights the respective strengths of each type of loss function.

How well a generative adversarial network (GAN) trains depends to a large extent on the choice of loss function. This article first reviews the theoretical foundations of the classic GAN loss, then uses PyTorch to implement a range of loss functions, including the original GAN, the Least Squares GAN (LS-GAN), the Wasserstein GAN (WGAN), and WGAN with gradient penalty (WGAN-GP).

The way GANs work is often compared to a carefully staged artistic process: the generator plays the role of the creator, continually producing new work, while the discriminator acts as a demanding critic, constantly providing feedback for improvement. This adversarial mechanism drives the two networks to improve together through competition. The form this feedback takes, namely the design of the loss function, has a decisive influence on how well the whole system learns.
Fundamentals of GANs and the Classic Loss Function
1. The Original GAN

The original GAN formulates training as a minimax game between the generator and the discriminator:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$

where:

- D(x) is the probability the discriminator assigns to the input x being a real sample
- G(z) is the generator's mapping from random noise z to a synthetic image
- p_{data}(x) is the real data distribution
- p_z(z) is the noise prior, usually a standard normal distribution
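The implementation below relies on the fact that binary cross-entropy reproduces exactly the two log terms of the value function: with target 1, BCE(p, 1) = -log p, and with target 0, BCE(p, 0) = -log(1 - p). Here is a small numerical check, added for illustration and not part of the original code:

import torch
import torch.nn as nn

bce = nn.BCELoss()
d_real = torch.tensor([0.9])  # a hypothetical D(x)
d_fake = torch.tensor([0.2])  # a hypothetical D(G(z))
# BCE with target 1 equals -log D(x); with target 0 it equals -log(1 - D(G(z)))
assert torch.isclose(bce(d_real, torch.ones(1)), -torch.log(d_real[0]))
assert torch.isclose(bce(d_fake, torch.zeros(1)), -torch.log(1 - d_fake[0]))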
import torch
import torch.nn as nn

# Implementation of the original GAN loss
class OriginalGANLoss:
    def __init__(self, device):
        self.device = device
        self.criterion = nn.BCELoss()

    def discriminator_loss(self, real_output, fake_output):
        # Target label 1.0 for real samples
        real_labels = torch.ones_like(real_output, device=self.device)
        # Target label 0.0 for generated samples
        fake_labels = torch.zeros_like(fake_output, device=self.device)
        # Discriminator loss on real samples
        real_loss = self.criterion(real_output, real_labels)
        # Discriminator loss on generated samples
        fake_loss = self.criterion(fake_output, fake_labels)
        # The total loss is the sum of the two parts
        d_loss = real_loss + fake_loss
        return d_loss

    def generator_loss(self, fake_output):
        # The generator wants the discriminator to classify its samples as real
        target_labels = torch.ones_like(fake_output, device=self.device)
        g_loss = self.criterion(fake_output, target_labels)
        return g_loss
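To make the role of the loss class concrete, here is a minimal sketch of one alternating training step. G, D, opt_G, opt_D, real_imgs, and latent_dim are hypothetical placeholders, not defined in the article:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
loss_fn = OriginalGANLoss(device)

def train_step(G, D, opt_G, opt_D, real_imgs, latent_dim):
    # Discriminator update: detach fakes so gradients do not flow into G
    z = torch.randn(real_imgs.size(0), latent_dim, device=device)
    d_loss = loss_fn.discriminator_loss(D(real_imgs), D(G(z).detach()))
    opt_D.zero_grad()
    d_loss.backward()
    opt_D.step()
    # Generator update: re-sample noise and score fresh fakes
    z = torch.randn(real_imgs.size(0), latent_dim, device=device)
    g_loss = loss_fn.generator_loss(D(G(z)))
    opt_G.zero_grad()
    g_loss.backward()
    opt_G.step()
    return d_loss.item(), g_loss.item()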
2. Non-Saturating Loss

In the minimax formulation the generator minimizes log(1 - D(G(z))), which saturates early in training: when the discriminator confidently rejects generated samples, the gradient flowing back to the generator becomes vanishingly small. The non-saturating variant instead has the generator maximize log D(G(z)), which provides strong gradients precisely when the generator performs poorly.
class NonSaturatingGANLoss:
    def __init__(self, device):
        self.device = device
        self.criterion = nn.BCELoss()

    def discriminator_loss(self, real_output, fake_output):
        # Identical to the original GAN
        real_labels = torch.ones_like(real_output, device=self.device)
        fake_labels = torch.zeros_like(fake_output, device=self.device)
        real_loss = self.criterion(real_output, real_labels)
        fake_loss = self.criterion(fake_output, fake_labels)
        d_loss = real_loss + fake_loss
        return d_loss

    def generator_loss(self, fake_output):
        # Non-saturating loss: directly maximize log(D(G(z)))
        target_labels = torch.ones_like(fake_output, device=self.device)
        # The same BCE criterion is used, but the target asks D to label G(z) as real
        g_loss = self.criterion(fake_output, target_labels)
        return g_loss
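The difference between the two generator objectives is easy to see numerically. In this small illustration (added; not from the original article), the discriminator confidently rejects a fake with D(G(z)) = 0.01, and we compare the gradient each objective sends back:

d_fake = torch.tensor(0.01, requires_grad=True)

# Saturating objective: minimize log(1 - D(G(z)))
torch.log(1 - d_fake).backward()
print(d_fake.grad)  # about -1.01: a weak signal despite the poor sample

d_fake.grad = None
# Non-saturating objective: minimize -log(D(G(z)))
(-torch.log(d_fake)).backward()
print(d_fake.grad)  # about -100: a strong signal exactly when G is doing badly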
GAN Variants: Implementation and Analysis

3. Least Squares GAN (LS-GAN)

LS-GAN replaces the cross-entropy criterion with a least-squares objective, so samples are penalized by how far their raw scores lie from the target values rather than by a saturating log-probability.
class LSGANLoss:
    def __init__(self, device):
        self.device = device
        # LS-GAN uses MSE loss instead of BCE loss
        self.criterion = nn.MSELoss()

    def discriminator_loss(self, real_output, fake_output):
        # Target value 1.0 for real samples
        real_labels = torch.ones_like(real_output, device=self.device)
        # Target value 0.0 for generated samples
        fake_labels = torch.zeros_like(fake_output, device=self.device)
        # MSE loss on real samples
        real_loss = self.criterion(real_output, real_labels)
        # MSE loss on generated samples
        fake_loss = self.criterion(fake_output, fake_labels)
        d_loss = real_loss + fake_loss
        return d_loss

    def generator_loss(self, fake_output):
        # The generator wants its samples to be scored as real
        target_labels = torch.ones_like(fake_output, device=self.device)
        g_loss = self.criterion(fake_output, target_labels)
        return g_loss
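A small illustration (added) of why the least-squares criterion behaves differently from BCE: MSE penalizes even samples on the correct side of the decision boundary if their scores sit far from the target value, which is what keeps LS-GAN gradients from vanishing:

mse = nn.MSELoss()
score = torch.tensor([3.0])          # a raw discriminator score, no sigmoid
print(mse(score, torch.ones(1)))     # against the real target 1.0 -> 4.0
print(mse(score, torch.zeros(1)))    # against the fake target 0.0 -> 9.0

Note that with this loss the discriminator's final layer is typically linear, since the MSE is computed on raw scores rather than probabilities.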
4. Wasserstein GAN (WGAN)

WGAN replaces the Jensen-Shannon divergence implicitly minimized by the original GAN with the Wasserstein-1 (earth mover's) distance. By the Kantorovich-Rubinstein duality,

$$W(p_{data}, p_g) = \sup_{\|f\|_L \leq 1} \left( \mathbb{E}_{x \sim p_{data}}[f(x)] - \mathbb{E}_{x \sim p_g}[f(x)] \right)$$

where the supremum runs over all 1-Lipschitz functions f. The critic approximates f, and here the Lipschitz constraint is enforced by clipping its weights to a small range.
class WGANLoss:
    def __init__(self, device, clip_value=0.01):
        self.device = device
        self.clip_value = clip_value

    def discriminator_loss(self, real_output, fake_output):
        # The WGAN discriminator (called the critic) directly maximizes the gap
        # between its scores on real and generated samples.
        # Note that no sigmoid activation is applied to its output.
        d_loss = -torch.mean(real_output) + torch.mean(fake_output)
        return d_loss

    def generator_loss(self, fake_output):
        # The generator maximizes the critic's score on generated samples
        g_loss = -torch.mean(fake_output)
        return g_loss

    def weight_clipping(self, critic):
        # Weight clipping: constrain the critic's parameters to a fixed range
        for p in critic.parameters():
            p.data.clamp_(-self.clip_value, self.clip_value)
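A sketch of the WGAN training schedule with these pieces in place. critic, G, opt_D, opt_G, real_imgs, and latent_dim are placeholders; the five critic steps per generator step and the clipping after each critic update follow the original WGAN recipe, which also recommends a momentum-free optimizer such as RMSprop:

loss_fn = WGANLoss(device, clip_value=0.01)

for _ in range(5):  # n_critic updates per generator update
    z = torch.randn(real_imgs.size(0), latent_dim, device=device)
    d_loss = loss_fn.discriminator_loss(critic(real_imgs), critic(G(z).detach()))
    opt_D.zero_grad()
    d_loss.backward()
    opt_D.step()
    # Enforce the Lipschitz constraint after every critic step
    loss_fn.weight_clipping(critic)

z = torch.randn(real_imgs.size(0), latent_dim, device=device)
g_loss = loss_fn.generator_loss(critic(G(z)))
opt_G.zero_grad()
g_loss.backward()
opt_G.step()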
5. WGAN with Gradient Penalty (WGAN-GP)

Weight clipping is a blunt way to enforce the Lipschitz constraint: it can waste critic capacity and cause vanishing or exploding gradients. WGAN-GP instead penalizes the critic's gradient norm at points interpolated between real and generated samples, pushing it toward 1:

$$L_D = -\mathbb{E}[D(x)] + \mathbb{E}[D(G(z))] + \lambda \, \mathbb{E}_{\hat{x}}\left[ (\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2 \right]$$
class WGANGP:
    def __init__(self, device, lambda_gp=10):
        self.device = device
        self.lambda_gp = lambda_gp

    def discriminator_loss(self, real_output, fake_output, real_samples, fake_samples, discriminator):
        # Basic Wasserstein distance term
        d_loss = -torch.mean(real_output) + torch.mean(fake_output)
        # Gradient penalty: interpolate randomly between real and generated samples
        alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=self.device)
        interpolates = alpha * real_samples + (1 - alpha) * fake_samples
        interpolates.requires_grad_(True)
        # Critic output on the interpolated samples
        d_interpolates = discriminator(interpolates)
        # Compute gradients of the critic's output w.r.t. the interpolates
        grad_outputs = torch.ones_like(d_interpolates, device=self.device, requires_grad=False)
        gradients = torch.autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=grad_outputs,
            create_graph=True,
            retain_graph=True,
            only_inputs=True
        )[0]
        # L2 norm of the per-sample gradients, penalized toward 1
        gradients = gradients.view(gradients.size(0), -1)
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        # Add the gradient penalty term
        d_loss = d_loss + self.lambda_gp * gradient_penalty
        return d_loss

    def generator_loss(self, fake_output):
        # Same as WGAN
        g_loss = -torch.mean(fake_output)
        return g_loss
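Unlike the earlier losses, the WGAN-GP discriminator loss needs the raw sample batches and the critic itself in order to build the gradient-penalty graph. A sketch of a single critic step follows; critic, G, opt_D, real_imgs, and latent_dim are again placeholders:

loss_fn = WGANGP(device, lambda_gp=10)

z = torch.randn(real_imgs.size(0), latent_dim, device=device)
fake_imgs = G(z).detach()
d_loss = loss_fn.discriminator_loss(
    critic(real_imgs), critic(fake_imgs),
    real_imgs, fake_imgs, critic
)
opt_D.zero_grad()
d_loss.backward()
opt_D.step()
# No weight clipping here: the penalty term replaces it

Note that alpha is sampled with shape (N, 1, 1, 1), so this implementation assumes 4-D image batches. The WGAN-GP paper recommends Adam with beta1 = 0 and beta2 = 0.9 for the optimizers.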
6. Conditional GAN (CGAN)

CGAN conditions both the generator and the discriminator on auxiliary information y, such as a class label. The value function becomes

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x \mid y)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z \mid y)))]$$
class CGANLoss:
    def __init__(self, device):
        self.device = device
        self.criterion = nn.BCELoss()

    def discriminator_loss(self, real_output, fake_output):
        # The CGAN discriminator loss matches the original GAN; only the
        # inputs are extended with the conditioning information.
        real_labels = torch.ones_like(real_output, device=self.device)
        fake_labels = torch.zeros_like(fake_output, device=self.device)
        real_loss = self.criterion(real_output, real_labels)
        fake_loss = self.criterion(fake_output, fake_labels)
        d_loss = real_loss + fake_loss
        return d_loss

    def generator_loss(self, fake_output):
        # Same as the original GAN
        target_labels = torch.ones_like(fake_output, device=self.device)
        g_loss = self.criterion(fake_output, target_labels)
        return g_loss

# Example network structure for a CGAN generator
import numpy as np  # needed for np.prod below

class ConditionalGenerator(nn.Module):
    def __init__(self, latent_dim, n_classes, img_shape):
        super(ConditionalGenerator, self).__init__()
        self.img_shape = img_shape
        self.label_emb = nn.Embedding(n_classes, n_classes)
        self.model = nn.Sequential(
            # The input is the noise vector concatenated with the condition
            nn.Linear(latent_dim + n_classes, 128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z, labels):
        # Embed the condition
        c = self.label_emb(labels)
        # Concatenate noise and condition
        x = torch.cat([z, c], 1)
        # Generate the image
        img = self.model(x)
        img = img.view(img.size(0), *self.img_shape)
        return img
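A usage sketch for the conditional generator above; latent_dim = 100, n_classes = 10, and img_shape = (1, 28, 28) are assumed example values:

G = ConditionalGenerator(latent_dim=100, n_classes=10, img_shape=(1, 28, 28))
z = torch.randn(16, 100)                # a batch of 16 noise vectors
labels = torch.randint(0, 10, (16,))    # the desired class for each sample
imgs = G(z, labels)
print(imgs.shape)                       # torch.Size([16, 1, 28, 28])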
7. Information Maximizing GAN (InfoGAN)

InfoGAN splits the generator's input into unstructured noise z and a latent code c, and augments the GAN objective with a term that maximizes a variational lower bound on the mutual information between c and the generated sample:

$$\min_{G, Q} \max_D V(D, G) - \lambda I(c; G(z, c))$$

Individual dimensions of c then tend to capture interpretable attributes of the output.
class InfoGANLoss:
    def __init__(self, device, lambda_info=1.0):
        self.device = device
        self.criterion = nn.BCELoss()
        self.lambda_info = lambda_info
        # Cross-entropy loss for the discrete latent code
        self.discrete_criterion = nn.CrossEntropyLoss()
        # MSE for the continuous latent code, equivalent to a fixed-variance
        # Gaussian negative log-likelihood
        self.continuous_criterion = nn.MSELoss()

    def discriminator_loss(self, real_output, fake_output):
        # Same discriminator loss as the original GAN
        real_labels = torch.ones_like(real_output, device=self.device)
        fake_labels = torch.zeros_like(fake_output, device=self.device)
        real_loss = self.criterion(real_output, real_labels)
        fake_loss = self.criterion(fake_output, fake_labels)
        d_loss = real_loss + fake_loss
        return d_loss

    def generator_info_loss(self, fake_output, q_discrete, q_continuous, c_discrete, c_continuous):
        # Adversarial part: fool the discriminator
        target_labels = torch.ones_like(fake_output, device=self.device)
        g_loss = self.criterion(fake_output, target_labels)
        # Mutual-information part
        # Discrete latent code
        info_disc_loss = self.discrete_criterion(q_discrete, c_discrete)
        # Continuous latent code
        info_cont_loss = self.continuous_criterion(q_continuous, c_continuous)
        # Total loss
        total_loss = g_loss + self.lambda_info * (info_disc_loss + info_cont_loss)
        return total_loss, info_disc_loss, info_cont_loss
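A sketch (added) of how the latent codes and the combined loss fit together. The auxiliary Q network that predicts the codes from generated images is not defined in the article, so its outputs q_discrete and q_continuous are stubbed with random tensors of the expected shapes:

batch, n_cat, n_cont = 32, 10, 2

# Sampled latent codes, fed into the generator alongside the noise z
c_discrete = torch.randint(0, n_cat, (batch,))             # category indices
c_continuous = torch.empty(batch, n_cont).uniform_(-1, 1)  # continuous codes

# Stand-ins for the Q network's predictions and D's probability on fakes
q_discrete = torch.randn(batch, n_cat)      # category logits
q_continuous = torch.randn(batch, n_cont)   # predicted continuous codes
fake_output = torch.rand(batch, 1)

loss_fn = InfoGANLoss(torch.device('cpu'), lambda_info=1.0)
total, info_disc, info_cont = loss_fn.generator_info_loss(
    fake_output, q_discrete, q_continuous, c_discrete, c_continuous)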
8. Energy-Based GAN (EBGAN)

EBGAN treats the discriminator as an energy function that assigns low energy to real samples and high energy to generated ones, hinged at a margin m:

$$L_D = \mathbb{E}[D(x)] + \mathbb{E}[\max(0, \, m - D(G(z)))], \qquad L_G = \mathbb{E}[D(G(z))]$$
class EBGANLoss:
    def __init__(self, device, margin=10.0):
        self.device = device
        self.margin = margin

    def discriminator_loss(self, real_energy, fake_energy):
        # The discriminator lowers the energy of real samples and raises the
        # energy of generated samples, up to the margin.
        # Hinge loss on the generated samples
        hinge_loss = torch.mean(torch.clamp(self.margin - fake_energy, min=0))
        # Total loss
        d_loss = torch.mean(real_energy) + hinge_loss
        return d_loss

    def generator_loss(self, fake_energy):
        # The generator lowers the energy of its samples
        g_loss = torch.mean(fake_energy)
        return g_loss
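The article defines only the loss; in the EBGAN paper the energy is the reconstruction error of an autoencoder-style discriminator. A sketch of that convention, with a hypothetical autoencoder module:

def energy(autoencoder, x):
    # Per-sample reconstruction MSE serves as the energy E(x)
    recon = autoencoder(x)
    return ((recon - x) ** 2).flatten(1).mean(dim=1)

loss_fn = EBGANLoss(torch.device('cpu'), margin=10.0)
# d_loss = loss_fn.discriminator_loss(energy(ae, real_imgs), energy(ae, fake_imgs))
# g_loss = loss_fn.generator_loss(energy(ae, fake_imgs))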
9. f-GAN

f-GAN generalizes adversarial training to an arbitrary f-divergence through its variational lower bound:

$$D_f(P \,\|\, Q) \geq \sup_T \left( \mathbb{E}_{x \sim P}[T(x)] - \mathbb{E}_{x \sim Q}[f^*(T(x))] \right)$$

where f^* is the Fenchel conjugate of the function f defining the divergence, and T(x) = g_f(V(x)) is the critic's raw output V(x) passed through a divergence-specific activation g_f. Different choices of f recover the KL divergence, the JS divergence, the Hellinger distance, total variation, and others.
import math

class FGANLoss:
    def __init__(self, device, divergence_type='kl'):
        self.device = device
        self.divergence_type = divergence_type

    def activation_function(self, x):
        # Divergence-specific output activation g_f
        if self.divergence_type == 'kl':  # KL divergence
            return x
        elif self.divergence_type == 'js':  # JS divergence
            # log(2) - log(1 + exp(-x)): bounded above by log(2), as in the
            # f-GAN paper, so the conjugate -log(2 - exp(t)) stays defined
            return math.log(2.0) - torch.log(1 + torch.exp(-x))
        elif self.divergence_type == 'hellinger':  # Hellinger distance
            return 1 - torch.exp(-x)
        elif self.divergence_type == 'total_variation':  # Total variation distance
            return 0.5 * torch.tanh(x)
        else:
            return x  # default to KL divergence

    def conjugate_function(self, t):
        # Fenchel conjugate f* for each divergence
        if self.divergence_type == 'kl':
            return torch.exp(t - 1)
        elif self.divergence_type == 'js':
            return -torch.log(2 - torch.exp(t))
        elif self.divergence_type == 'hellinger':
            return t / (1 - t)
        elif self.divergence_type == 'total_variation':
            return t
        else:
            return torch.exp(t - 1)  # default to KL divergence

    def discriminator_loss(self, real_output, fake_output):
        # Critic loss: both raw outputs pass through the activation g_f, and
        # the conjugate f* is applied to the activated fake scores, following
        # the variational bound E_P[g_f(V)] - E_Q[f*(g_f(V))]
        activated_real = self.activation_function(real_output)
        activated_fake = self.activation_function(fake_output)
        d_loss = -torch.mean(activated_real) + torch.mean(self.conjugate_function(activated_fake))
        return d_loss

    def generator_loss(self, fake_output):
        # Generator loss: maximize E[g_f(V(G(z)))], the non-saturating
        # alternative suggested in the f-GAN paper
        activated_fake = self.activation_function(fake_output)
        g_loss = -torch.mean(activated_fake)
        return g_loss
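A usage sketch (added): raw, unactivated critic scores for a real and a fake batch, evaluated under a chosen divergence. The scores here are random placeholders standing in for a critic network's outputs:

loss_fn = FGANLoss(torch.device('cpu'), divergence_type='js')
real_output = torch.randn(32, 1)  # raw critic scores on real samples
fake_output = torch.randn(32, 1)  # raw critic scores on generated samples
d_loss = loss_fn.discriminator_loss(real_output, fake_output)
g_loss = loss_fn.generator_loss(fake_output)
print(d_loss.item(), g_loss.item())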