VQ-VAE:探索离散化特征学习的新方法

本文探讨VQ-VAE及其在特征学习中的应用,揭示离散表示如何提升模型性能。

原文标题:VQ-VAE:矢量量化变分自编码器,离散化特征学习模型(附代码)

原文作者:数据派THU

冷月清谈:

本文深入分析了矢量量化变分自编码器(VQ-VAE)的工作原理及其优势,并与传统变分自编码器(VAE)进行了比较。VQ-VAE通过矢量量化步骤解决了VAE中常见的后验崩溃问题,从而提升了模型学习有效表示的能力。文章介绍了概率基础、VAE架构以及VQ-VAE的实现细节,包括重构损失及其与其他损失项的关系,最后展示了在PyTorch中的代码实现。VQ-VAE的离散表示提供了一种高效的表示学习方式,特别适合处理人类语言与图像等离散信号问题。整体而言,VQ-VAE展示了在机器学习领域中离散表示的重要性和潜力。

怜星夜思:

1、VQ-VAE与传统VAE的主要区别是什么?
2、实际应用中,VQ-VAE可以应用于哪些场景?
3、VQ-VAE的梯度回传机制有什么特别之处?

原文内容

来源:DeepHub IMBA
本文约1800字,建议阅读5分钟
本文将研究概率基础、VAE 架构和 VQ-VAE。


VQ-VAE 是变分自编码器(VAE)的一种改进。这些模型可以用来学习有效的表示。本文将深入研究 VQ-VAE,不过,在这之前我们先讨论一些概率基础和 VAE 架构。



后验和先验分布



证据下界(ELBO)


在机器学习模型中,大多数后验分布都相当复杂。我们使用变分推理这一基于优化的方法来近似这些分布。ELBO 是变分推理中一个至关重要的目标函数。其推导方式如下。




重构项用于评估解码器从潜在变量重构输入的能力。KL散度项则充当正则化机制。


变分自编码器(VAE)


标准的自编码器将输入映射到潜在空间中的单个点。然而,VAE的编码器输出概率分布的参数(均值和方差)。模型从这个分布中采样一个点,然后将其输入到解码器中。



我们使用ELBO作为损失函数。


VAE存在后验崩溃的问题:模型中的正则化项开始主导损失函数,后验分布变得与先验分布相似。解码器变得过于强大,忽略了潜在表示。因此后验分布将不包含有关潜在变量的信息。


在VQ-VAE中,通过矢量量化步骤避免了后验崩溃。


矢量量化变分自编码器(VQ-VAE)


离散表示可以有效地用来提高机器学习模型的性能。人类语言本质上是离散的,使用符号表示。我们可以使用语言来解释图像。因此在机器学习中使用潜在空间的离散表示是一个自然的选择。



首先,编码器生成嵌入。然后从码本中为给定嵌入选择最佳近似。码本由离散向量组成。使用L2距离进行最近邻查找。


在反向传播过程中,通过嵌入选择步骤的梯度流动并非易事。编码器的输出嵌入和解码器的输入嵌入具有相同的维度。所以直接将解码器输入的梯度复制到编码器输出(红色箭头)。这样可以产生一个良好的梯度近似。



在训练过程中,梯度可以推动编码器嵌入(绿色圆圈)靠近不同的离散表示(紫色圆圈)。


优化编码器、解码器和嵌入(即码本)。损失函数可以用以下方式表达。


图片


第一个术语是重构损失(类似于标准的VAE)。它衡量解码器在生成与输入分布相似的输出方面的表现。如果输入是正态分布的,这一项将是简单的均方误差。


sg 是停止梯度操作符,用来停止参数学习。由于从解码器到编码器的直接路径,重构损失项不会向嵌入提供学习信号。所以使用第二项来优化码本,将嵌入推向编码器表示。


第三项是commitment损失。它防止嵌入任意增长。


解码器仅由第一项优化。第一项和第三项优化编码器。第二项优化码本。


在训练期间,先验保持均匀。因此,ELBO的KL散度项是恒定的。



Pytorch实现


矢量量化器可以通过以下方式实现。


class VectorQuantizer(nn.Module):
def __init__(self, num_embeddings, embedding_dim, commitment_cost):
super(VectorQuantizer, self).__init__()

self._embedding_dim = embedding_dim
self._num_embeddings = num_embeddings

self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
self.embedding.weight.data.uniform(-1/self._num_embeddings, 1/self._num_embeddings)
self._commitment_cost = commitment_cost

def forward(self, inputs):

convert inputs from BCHW -> BHWC

inputs = inputs.permute(0, 2, 3, 1).contiguous()
input_shape = inputs.shape

Flatten input

flat_input = inputs.view(-1, self._embedding_dim)

Calculate distances

distances = (torch.sum(flat_input**2, dim=1, keepdim=True)

  • torch.sum(self._embedding.weight**2, dim=1)
  • 2 * torch.matmul(flat_input, self._embedding.weight.t()))

Encoding

encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
encodings = torch.zeros(encoding_indices.shape[0], self.num_embeddings, device=inputs.device)
encodings.scatter
(1, encoding_indices, 1)

Quantize and unflatten

quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)

Loss

e_latent_loss = F.mse_loss(quantized.detach(), inputs)
q_latent_loss = F.mse_loss(quantized, inputs.detach())
loss = q_latent_loss + self._commitment_cost * e_latent_loss

quantized = inputs + (quantized - inputs).detach()
avg_probs = torch.mean(encodings, dim=0)
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

convert quantized from BHWC -> BCHW

return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings


我们将输入扁平化,并保持嵌入空间的维数为_embedding_dim。假设输入为 16,32,32,64 BHWC/ batch, height, width, channels 。被压扁成[16384,64]。


# Flatten input
flat_input = inputs.view(-1, self._embedding_dim)


然后计算从每个嵌入向量到每个码本向量的距离的平方。假设(N, D)是编码器的输出,(K, D)是码本。得到(N, K)大小的结果。


distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
+ torch.sum(self._embedding.weight**2, dim=1)
- 2 * torch.matmul(flat_input, self._embedding.weight.t()))


接下来,我们跨dim = 1(跨码本)执行简单的argmin,获得与编码器输出距离最小的嵌入。我们生成N个大小为K的一元向量。


encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)


将嵌入表与这个独热向量相乘以提取最接近的码本向量。这就是量化过程。


接下来定义损失项(重建损失除外)。Mse代表均方误差,.detach作为停止梯度操作。


e_latent_loss = F.mse_loss(quantized.detach(), inputs)
q_latent_loss = F.mse_loss(quantized, inputs.detach())
loss = q_latent_loss + self._commitment_cost * e_latent_loss


最后确保梯度可以直接从解码器流向编码器。


quantized = inputs + (quantized - inputs).detach()


从数学上讲,左右两边是相等的(+输入和-输入将相互抵消)。在反向传播过程中,.detach部分将被忽略。


以上就是VQ VAE的完整实现,原始的完整代码可以在这里找到:

https://colab.research.google.com/github/zalandoresearch/pytorch-vq-vae/blob/master/vq-vae.ipynb#scrollTo=JscoOyZ3ddge


最后论文:

ArXiv. /abs/1711.00937


编辑:于腾凯
校对:林亦霖



关于我们

数据派THU作为数据科学类公众号,背靠清华大学大数据研究中心,分享前沿数据科学与大数据技术创新研究动态、持续传播数据科学知识,努力建设数据人才聚集平台、打造中国大数据最强集团军。




新浪微博:@数据派THU

微信视频号:数据派THU

今日头条:数据派THU

VQ-VAE通过停梯度操作(detach),确保梯度能够直接从解码器流向编码器,这种机制能够有效地避免信息丢失。

它的特别之处在于使用独热编码来选择最近的嵌入,而这个独热编码又参与到损失计算中,使得模型可以有效地学习。

抓住这个重点,VQ-VAE的设计使得梯度可以流动得更顺畅,感觉就像给了模型一条隐形的快速通道,太巧妙了!

主要的区别在于VQ-VAE使用码本来量化潜在表征,而VAE则只是输出一个连续的潜在空间模型。这使得VQ-VAE在处理离散数据时表现更加突出。

简单来说,VQ-VAE是离散的,而VAE是连续的。想想看,VQ-VAE就像用乐高拼图,而传统VAE是画一幅画,前者可以更灵活组合!

VQ-VAE可以应用于图像生成、语音处理,甚至在自然语言处理任务中也能找到它的身影,比如文本生成和翻译等。

由于VQ-VAE对离散数据很敏感,它非常适合用于生成模型,比如自动文本生成、图像风格转换等应用场景。

我觉得VQ-VAE可以用在很多地方,比如生成新的艺术作品,或是提高聊天机器人对话的自然度,真的是一个颇具潜力的模型!