探索人体姿态估计新范式:利用VQ-VAE将姿态转化为离散标记组合,有效捕捉关键点间关系,提升模型泛化能力。
原文标题:用离散标记重塑人体姿态:VQ-VAE实现关键点组合关系编码
原文作者:数据派THU
冷月清谈:
怜星夜思:
2、文章中使用了合成的火柴人数据集,这虽然简化了实验,但与真实人体姿态相比存在差距。你认为在真实人体姿态数据集上应用这种方法,会遇到哪些新的挑战?又该如何应对?
3、文章中提到基于离散标记的姿态表示方法,为后续的姿态分析和理解任务提供了新的可能性。除了文章中提到的动作识别和姿态分类,你还能想到哪些其他的下游任务可以利用这种离散的姿态表示?这种表示方法在这些任务中可能有哪些优势?
原文内容
来源:Deephub Imba
本文共5500字,建议阅读6分钟
本文介绍了使用离散标记重塑人体姿态的过程。
在人体姿态估计领域,传统方法通常将关键点作为基本处理单元,这些关键点在人体骨架结构上代表关节位置(如肘部、膝盖和头部)的空间坐标。现有模型对这些关键点的预测主要采用两种范式:直接通过坐标回归或间接通过热图(heat map,即图像空间中的密集概率分布)进行估计。尽管这些方法在实际应用中取得了显著效果,但它们往往将每个关键点作为独立单元处理,未能充分利用人体骨架结构中固有的关键点间组合关系。
合成火柴人数据集
基于组合编码的VQ-VAE
class MLPBlock(nn.Module):
    """Feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: Size of the last input dimension; the output keeps the same size.
        inter_dim: Width of the hidden layer between the two linear maps.
        dropout_ratio: Dropout probability used after the activation and
            after the second linear layer.
    """

    def __init__(self, dim, inter_dim, dropout_ratio):
        super().__init__()

        self.ff = nn.Sequential(
            nn.Linear(dim, inter_dim),
            nn.GELU(),
            nn.Dropout(dropout_ratio),
            nn.Linear(inter_dim, dim),
            nn.Dropout(dropout_ratio)
        )

    def forward(self, x):
        """Apply the feed-forward stack along the last dimension of ``x``."""
        return self.ff(x)


class MixerLayer(nn.Module):
    """MLP-Mixer style layer that mixes across tokens, then across channels.

    The input is expected as ``(batch, token_dim, hidden_dim)``: both layer
    norms act on the last axis (``hidden_dim``), and the token MLP is applied
    after swapping the last two axes so that it acts on ``token_dim``.

    Args:
        hidden_dim: Channel size (last axis of the input).
        hidden_inter_dim: Hidden width of the channel-mixing MLP.
        token_dim: Number of tokens (second axis of the input).
        token_inter_dim: Hidden width of the token-mixing MLP.
        dropout_ratio: Dropout probability used inside both MLP blocks.
    """

    # The constructor must be spelled ``__init__``; a method named ``init``
    # is never called by ``nn.Module`` and leaves the layer without submodules.
    def __init__(self,
                 hidden_dim,
                 hidden_inter_dim,
                 token_dim,
                 token_inter_dim,
                 dropout_ratio):
        super().__init__()

        self.layernorm1 = nn.LayerNorm(hidden_dim)
        self.MLP_token = MLPBlock(token_dim, token_inter_dim, dropout_ratio)
        self.layernorm2 = nn.LayerNorm(hidden_dim)
        self.MLP_channel = MLPBlock(hidden_dim, hidden_inter_dim, dropout_ratio)

    def forward(self, x):
        """Return ``x + token_mix + channel_mix`` with the same shape as ``x``."""
        # Token mixing: normalise, move tokens to the last axis, mix, move back.
        y = self.layernorm1(x)
        y = y.transpose(2, 1)
        y = self.MLP_token(y)
        y = y.transpose(2, 1)

        # Channel mixing operates on the token-mixed residual.
        z = self.layernorm2(x + y)
        z = self.MLP_channel(z)

        out = x + y + z
        return out
class CompositionalEncoder(nn.Module):
    """Encode a set of keypoints into M token features of size M.

    Shape legend used in the comments below:
        B = batch, K = numberOfKeypoints, D = dimensionOfKeypoints,
        H = linearProjectionSize, M = codebookTokenDimension.

    Args:
        numberOfKeypoints: Number of keypoints K per pose.
        dimensionOfKeypoints: Coordinate dimension D of each keypoint.
        linearProjectionSize: Hidden size H each keypoint is projected to.
        numberOfMixerBlocks: Number N of stacked ``MixerLayer`` blocks.
        codebookTokenDimension: M; used both as the number of output tokens
            and as the size of each token feature.
        internalMixerSize: Hidden width of the channel-mixing MLP.
        internalMixerTokenSize: Hidden width of the token-mixing MLP.
        mixerDropout: Dropout probability inside the mixer blocks.
    """

    def __init__(self,
                 numberOfKeypoints=11,
                 dimensionOfKeypoints=2,
                 linearProjectionSize=128,
                 numberOfMixerBlocks=4,
                 codebookTokenDimension=64,
                 internalMixerSize=64,
                 internalMixerTokenSize=32,
                 mixerDropout=0.1):
        super(CompositionalEncoder, self).__init__()

        self.numberOfKeypoints = numberOfKeypoints  # K
        self.dimensionOfKeypoints = dimensionOfKeypoints  # D
        self.linearProjectionSize = linearProjectionSize  # H
        self.numberOfMixerBlocks = numberOfMixerBlocks  # N
        self.codebookTokenDimension = codebookTokenDimension  # M
        self.internalMixerSize = internalMixerSize
        self.internalMixerTokenSize = internalMixerTokenSize
        self.mixerDropout = mixerDropout

        # Projects BxKxD to BxKxH.
        self.initial_linear = nn.Linear(self.dimensionOfKeypoints,
                                        self.linearProjectionSize)

        # Each mixer block keeps the BxKxH shape.
        self.mixer_layers = nn.ModuleList([MixerLayer(self.linearProjectionSize,
                                                      self.internalMixerSize,
                                                      self.numberOfKeypoints,
                                                      self.internalMixerTokenSize,
                                                      self.mixerDropout) for _ in range(self.numberOfMixerBlocks)])

        self.mixer_layer_norm = nn.LayerNorm(self.linearProjectionSize)  # BxKxH

        # BxHxK -> BxHxM
        self.token_linear = nn.Linear(self.numberOfKeypoints,
                                      self.codebookTokenDimension)

        # BxMxH -> BxMxM
        self.feature_embed = nn.Linear(self.linearProjectionSize,
                                       self.codebookTokenDimension)

    def forward(self, x):
        """Map keypoints of shape BxDxK to token features of shape BxMxM."""
        # BxDxK -> BxKxD
        x = x.transpose(2, 1)

        # BxKxD -> BxKxH
        x = self.initial_linear(x)

        # BxKxH -> BxKxH
        for mixer in self.mixer_layers:
            x = mixer(x)

        # BxKxH -> BxKxH
        x = self.mixer_layer_norm(x)

        # BxKxH -> BxHxK
        x = x.transpose(2, 1)

        # BxHxK -> BxHxM
        x = self.token_linear(x)

        # BxHxM -> BxMxH
        x = x.transpose(2, 1)

        # BxMxH -> BxMxM
        x = self.feature_embed(x)
        return x
class CodebookVQ(nn.Module):
    """Vector-quantization codebook updated by exponential moving average (EMA).

    The codebook and its EMA statistics are registered as buffers, not
    parameters, so they are never touched by an optimizer; they are updated
    inside ``forward`` whenever the module is in training mode.

    NOTE(review): the quantized output is built only from the buffer and the
    argmin indices, so no gradient reaches ``encode_feat`` through it (there
    is no straight-through estimator here).

    Args:
        codebookDimension: Size of every code vector (and of every input token).
        numberOfCodebookTokens: Number of entries in the codebook.
        decay: EMA decay factor for the cluster sizes and summed features.
        epsilon: Smoothing constant that keeps cluster sizes away from zero.
    """

    def __init__(self, codebookDimension, numberOfCodebookTokens, decay=0.99, epsilon=1e-5):
        super(CodebookVQ, self).__init__()
        self.codebookDimension = codebookDimension
        self.numberOfCodebookTokens = numberOfCodebookTokens
        self.decay = decay
        self.epsilon = epsilon

        self.register_buffer('codebook', torch.empty(numberOfCodebookTokens, codebookDimension))
        self.codebook.data.normal_()

        self.register_buffer('ema_cluster_size', torch.zeros(numberOfCodebookTokens))
        self.register_buffer('ema_w', torch.empty(numberOfCodebookTokens, codebookDimension))
        self.ema_w.data.normal_()

    def forward(self, encode_feat):
        """Quantize every token to its nearest codebook entry.

        Args:
            encode_feat: Tensor of shape ``(B, M, codebookDimension)``.

        Returns:
            A tuple ``(quantized, encoding_indices)`` where ``quantized`` has
            shape ``(B, M, codebookDimension)`` and ``encoding_indices`` has
            shape ``(B, M)``.
        """
        B = encode_feat.shape[0]
        M = encode_feat.shape[1]
        # reshape (not view) so non-contiguous inputs are accepted as well.
        flat = encode_feat.reshape(-1, self.codebookDimension)  # [B*M, dim]

        # Squared L2 distance: ||x||^2 - 2 x.c + ||c||^2. The last term must be
        # added; subtracting it biases the argmin toward large-norm codes.
        distances = (
            flat.pow(2).sum(1, keepdim=True)
            - 2 * flat @ self.codebook.t()
            + self.codebook.pow(2).sum(1)
        )  # [B*M, num_tokens]

        # Index of the nearest codebook entry for every token.
        encoding_indices = torch.argmin(distances, dim=1)  # [B*M]
        encodings = F.one_hot(encoding_indices, self.numberOfCodebookTokens).type(flat.dtype)  # [B*M, num_tokens]

        # Quantized output: look the selected code vectors up.
        quantized = encodings @ self.codebook  # [B*M, dim]

        if self.training:
            # EMA update. Run it without autograd so the buffers never get
            # attached to the encoder's graph across training steps.
            with torch.no_grad():
                ema_counts = encodings.sum(0)  # [num_tokens]
                dw = encodings.t() @ flat  # [num_tokens, dim]

                self.ema_cluster_size.mul_(self.decay).add_(ema_counts, alpha=1 - self.decay)
                self.ema_w.mul_(self.decay).add_(dw, alpha=1 - self.decay)

                # Laplace-smoothed cluster sizes avoid dividing by zero for
                # codes that have not been selected yet.
                n = self.ema_cluster_size.sum()
                cluster_size = (
                    (self.ema_cluster_size + self.epsilon)
                    / (n + self.numberOfCodebookTokens * self.epsilon) * n
                )

                self.codebook.copy_(self.ema_w / cluster_size.unsqueeze(1))

        # Restore the batch layout; the last axis is the code dimension, which
        # only coincides with M when the caller uses M tokens of size M.
        quantized = quantized.view(B, M, self.codebookDimension)
        encoding_indices = encoding_indices.view(B, M)
        return quantized, encoding_indices
- 码本包含 num_codes(即代码中的 numberOfCodebookTokens)个代码向量条目
- 每个输入标记根据L2距离独立选择最相近的码本向量
- 码本在训练过程中通过EMA机制进行自我更新,确保码本适应训练数据分布
class PoseDecoder(nn.Module):
    """Decode M token features of size M back into keypoint coordinates.

    Shape legend used in the comments below:
        B = batch, M = codebookTokenDimension, K = numberOfKeypoints,
        H = hiddenDimensionSize, D = keypointDimension.

    Args:
        codebookTokenDimension: M; both the number of input tokens and the
            size of each token feature.
        numberOfKeypoints: Number of keypoints K to reconstruct.
        keypointDimension: Coordinate dimension D of each keypoint.
        hiddenDimensionSize: Hidden size H used inside the mixer blocks.
        numberOfMixerBlocks: Number of stacked ``MixerLayer`` blocks.
        mixerInternalDimensionSize: Hidden width of the channel-mixing MLP.
        mixerTokenInternalDimensionSize: Hidden width of the token-mixing MLP.
        mixerDropout: Dropout probability inside the mixer blocks.
    """

    def __init__(self,
                 codebookTokenDimension=64,
                 numberOfKeypoints=11,
                 keypointDimension=2,
                 hiddenDimensionSize=128,
                 numberOfMixerBlocks=4,
                 mixerInternalDimensionSize=64,
                 mixerTokenInternalDimensionSize=128,
                 mixerDropout=0.1):
        super(PoseDecoder, self).__init__()

        self.codebookTokenDimension = codebookTokenDimension
        self.numberOfKeypoints = numberOfKeypoints
        self.keypointDimension = keypointDimension
        self.hiddenDimensionSize = hiddenDimensionSize
        self.mixerInternalDimensionSize = mixerInternalDimensionSize
        self.mixerTokenInternalDimensionSize = mixerTokenInternalDimensionSize
        self.mixerDropout = mixerDropout
        self.numberOfMixerBlocks = numberOfMixerBlocks

        # BxMxM -> BxMxK
        self.linear_token = nn.Linear(self.codebookTokenDimension, self.numberOfKeypoints)
        # BxKxM -> BxKxH
        self.initial_linear = nn.Linear(self.codebookTokenDimension, self.hiddenDimensionSize)

        # Each mixer block keeps the BxKxH shape.
        self.mixer_layers = nn.ModuleList([MixerLayer(self.hiddenDimensionSize,
                                                      self.mixerInternalDimensionSize,
                                                      self.numberOfKeypoints,
                                                      self.mixerTokenInternalDimensionSize,
                                                      self.mixerDropout) for _ in range(self.numberOfMixerBlocks)])

        self.decoder_layer_norm = nn.LayerNorm(self.hiddenDimensionSize)
        # BxKxH -> BxKxD
        self.recover_embed = nn.Linear(self.hiddenDimensionSize, self.keypointDimension)

    def forward(self, x):
        """Map token features of shape BxMxM to keypoints of shape BxDxK."""
        # BxMxM -> BxMxK
        x = self.linear_token(x)

        # BxMxK -> BxKxM
        x = x.transpose(2, 1)

        # BxKxM -> BxKxH
        x = self.initial_linear(x)

        # BxKxH -> BxKxH
        for mixer in self.mixer_layers:
            x = mixer(x)

        # BxKxH -> BxKxH
        x = self.decoder_layer_norm(x)

        # BxKxH -> BxKxD
        x = self.recover_embed(x)

        # BxKxD -> BxDxK
        x = x.transpose(2, 1)
        return x

# Single-stage training: encoder, codebook and decoder are trained together.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every random number generator in use and pin cuDNN to deterministic
# kernels so that runs are repeatable.
seed = 1
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# --- Dataset ---
dataset = StickFigureDataset(
    num_samples=10000,
    image_size=64,
    core_radius=1,
    limb_radius=5
)
loader = DataLoader(dataset, batch_size=64, shuffle=True)

# --- Model hyper-parameters ---
numberOfKeypoints = 13
dimensionOfKeypoints = 2
linearProjectionSize = 256
numberOfMixerBlocks = 16
codebookTokenDimension = 64
internalMixerSize = 64
internalMixerTokenSize = 32
mixerDropout = 0.1

encoder = CompositionalEncoder(numberOfKeypoints=numberOfKeypoints,
                               dimensionOfKeypoints=dimensionOfKeypoints,
                               linearProjectionSize=linearProjectionSize,
                               numberOfMixerBlocks=numberOfMixerBlocks,
                               codebookTokenDimension=codebookTokenDimension,
                               internalMixerSize=internalMixerSize,
                               internalMixerTokenSize=internalMixerTokenSize,
                               mixerDropout=mixerDropout).to(device)
# The number of codebook entries is deliberately set to the token dimension.
codebook = CodebookVQ(codebookDimension=codebookTokenDimension,
                      numberOfCodebookTokens=codebookTokenDimension,
                      decay=0.99,
                      epsilon=1e-5).to(device)
decoder = PoseDecoder(codebookTokenDimension=codebookTokenDimension,
                      numberOfKeypoints=numberOfKeypoints,
                      keypointDimension=dimensionOfKeypoints,
                      hiddenDimensionSize=linearProjectionSize,
                      numberOfMixerBlocks=numberOfMixerBlocks,
                      mixerInternalDimensionSize=internalMixerSize,
                      mixerTokenInternalDimensionSize=internalMixerTokenSize,
                      mixerDropout=mixerDropout).to(device)

# Only the encoder and decoder are handed to the optimizer; the codebook is
# not included here.
optimizer = torch.optim.Adam(
    list(encoder.parameters()) +
    list(decoder.parameters()),
    lr=1e-4
)

encoder.train()
codebook.train()
decoder.train()

num_epochs = 20
beta = 0.25  # weight of the commitment loss

for epoch in range(num_epochs):
    epoch_loss = 0.0
    num_batches = 0

    for imgs, gt_keypoints in loader:
        # presumably (B, K, 2) -> (B, 2, K); verify against StickFigureDataset
        keypoints = gt_keypoints.permute(0, 2, 1).to(device)

        optimizer.zero_grad()
        token_feats = encoder(keypoints)
        quantized, _ = codebook(token_feats)
        reconstructed = decoder(quantized)

        recon_loss = F.smooth_l1_loss(reconstructed, keypoints)
        # Pulls the encoder output toward the (detached) selected codes.
        commitment_loss = F.mse_loss(quantized.detach(), token_feats)
        loss = recon_loss + beta * commitment_loss

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    avg_loss = epoch_loss / num_batches
    print(f"Epoch {epoch+1}/{num_epochs} - Average Loss: {avg_loss:.4f}")
Epoch 1/20 - Average Loss: 18.5585
Epoch 2/20 - Average Loss: 14.4645
Epoch 3/20 - Average Loss: 11.6697
Epoch 4/20 - Average Loss: 9.7948
Epoch 5/20 - Average Loss: 8.3735
Epoch 6/20 - Average Loss: 7.2084
Epoch 7/20 - Average Loss: 6.5090
Epoch 8/20 - Average Loss: 6.0753
Epoch 9/20 - Average Loss: 5.6844
Epoch 10/20 - Average Loss: 5.4609
Epoch 11/20 - Average Loss: 5.3141
Epoch 12/20 - Average Loss: 5.2014
Epoch 13/20 - Average Loss: 5.1606
Epoch 14/20 - Average Loss: 5.1018
Epoch 15/20 - Average Loss: 5.1005
Epoch 16/20 - Average Loss: 5.0874
Epoch 17/20 - Average Loss: 5.0735
Epoch 18/20 - Average Loss: 5.0267
Epoch 19/20 - Average Loss: 5.0190
Epoch 20/20 - Average Loss: 5.0247
# Two-stage training: (1) pretrain encoder + decoder without quantization,
# (2) freeze the encoder and train a fresh decoder on quantized tokens while
# the codebook updates itself.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every random number generator in use and pin cuDNN to deterministic
# kernels so that runs are repeatable.
seed = 1
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# --- Dataset ---
dataset = StickFigureDataset(
    num_samples=10000,
    image_size=64,
    core_radius=1,
    limb_radius=5
)
loader = DataLoader(dataset, batch_size=64, shuffle=True)

# --- Model hyper-parameters ---
numberOfKeypoints = 13
dimensionOfKeypoints = 2
linearProjectionSize = 256
numberOfMixerBlocks = 16
codebookTokenDimension = 64
internalMixerSize = 64
internalMixerTokenSize = 32
mixerDropout = 0.1

encoder = CompositionalEncoder(numberOfKeypoints=numberOfKeypoints, dimensionOfKeypoints=dimensionOfKeypoints, linearProjectionSize=linearProjectionSize, numberOfMixerBlocks=numberOfMixerBlocks, codebookTokenDimension=codebookTokenDimension, internalMixerSize=internalMixerSize, internalMixerTokenSize=internalMixerTokenSize, mixerDropout=mixerDropout).to(device)
codebook = CodebookVQ(codebookDimension=codebookTokenDimension, numberOfCodebookTokens=codebookTokenDimension, decay=0.99, epsilon=1e-5).to(device)
decoder = PoseDecoder(codebookTokenDimension=codebookTokenDimension, numberOfKeypoints=numberOfKeypoints, keypointDimension=dimensionOfKeypoints, hiddenDimensionSize=linearProjectionSize, numberOfMixerBlocks=numberOfMixerBlocks, mixerInternalDimensionSize=internalMixerSize, mixerTokenInternalDimensionSize=internalMixerTokenSize, mixerDropout=mixerDropout).to(device)

optimizer = torch.optim.Adam(
    list(encoder.parameters()) +
    list(decoder.parameters()),
    lr=1e-4
)

encoder.train()
codebook.train()
decoder.train()

num_epochs = 20
beta = 0.25  # weight of the commitment loss

# Not read anywhere in this script; stage 1 simply never calls the codebook.
skipQuantization = True

# --- Stage 1: encoder/decoder pretraining without quantization ---
print("Encoder pretraining")
for epoch in range(num_epochs):
    epoch_loss = 0.0
    num_batches = 0

    for imgs, gt_keypoints in loader:
        keypoints = gt_keypoints.permute(0, 2, 1).to(device)  # [B, D, K]

        optimizer.zero_grad()
        token_feats = encoder(keypoints)  # (B, M, M)
        reconstructed = decoder(token_feats)  # (B, D, K)

        loss = F.smooth_l1_loss(reconstructed, keypoints)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    avg_loss = epoch_loss / num_batches
    print(f"Epoch {epoch+1}/{num_epochs} - Average Loss: {avg_loss:.4f}")

# Freeze the encoder after the initial training.
# NOTE(review): the encoder stays in train mode, so its dropout is still
# active in stage 2 — confirm whether encoder.eval() is intended here.
for param in encoder.parameters():
    param.requires_grad = False

# Reset the decoder (re-instantiating it does exactly that).
decoder = PoseDecoder(codebookTokenDimension=codebookTokenDimension, numberOfKeypoints=numberOfKeypoints, keypointDimension=dimensionOfKeypoints, hiddenDimensionSize=linearProjectionSize, numberOfMixerBlocks=numberOfMixerBlocks, mixerInternalDimensionSize=internalMixerSize, mixerTokenInternalDimensionSize=internalMixerTokenSize, mixerDropout=mixerDropout).to(device)

# Rebuild the optimizer over the decoder only (the codebook could optionally
# be included if needed).
optimizer = torch.optim.Adam(
    list(decoder.parameters()),
    lr=1e-4
)

# --- Stage 2: codebook and decoder training on quantized tokens ---
print("Codebook and Decoder training")
for epoch in range(num_epochs):
    epoch_loss = 0.0
    num_batches = 0

    for imgs, gt_keypoints in loader:
        keypoints = gt_keypoints.permute(0, 2, 1).to(device)

        optimizer.zero_grad()
        token_feats = encoder(keypoints)  # (B, M, M)
        quantized, _ = codebook(token_feats)  # (B, M, M)
        reconstructed = decoder(quantized)  # (B, D, K)

        recon_loss = F.smooth_l1_loss(reconstructed, keypoints)
        # With the encoder frozen this term yields no parameter updates from
        # this optimizer; it still contributes to the reported loss value.
        commitment_loss = F.mse_loss(quantized.detach(), token_feats)
        loss = recon_loss + beta * commitment_loss

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    avg_loss = epoch_loss / num_batches
    print(f"Epoch {epoch+1}/{num_epochs} - Average Loss: {avg_loss:.4f}")
Encoder pretraining
Epoch 1/20 - Average Loss: 18.5178
Epoch 2/20 - Average Loss: 14.1350
Epoch 3/20 - Average Loss: 10.7014
Epoch 4/20 - Average Loss: 8.3755
Epoch 5/20 - Average Loss: 6.5254
Epoch 6/20 - Average Loss: 4.8045
Epoch 7/20 - Average Loss: 3.7144
Epoch 8/20 - Average Loss: 2.9114
Epoch 9/20 - Average Loss: 2.2571
Epoch 10/20 - Average Loss: 1.7662
Epoch 11/20 - Average Loss: 1.4548
Epoch 12/20 - Average Loss: 1.2346
Epoch 13/20 - Average Loss: 1.0853
Epoch 14/20 - Average Loss: 0.9722
Epoch 15/20 - Average Loss: 0.9048
Epoch 16/20 - Average Loss: 0.8413
Epoch 17/20 - Average Loss: 0.7932
Epoch 18/20 - Average Loss: 0.7520
Epoch 19/20 - Average Loss: 0.7124
Epoch 20/20 - Average Loss: 0.6845
Codebook and Decoder training
Epoch 1/20 - Average Loss: 18.7635
Epoch 2/20 - Average Loss: 14.3140
Epoch 3/20 - Average Loss: 10.7931
Epoch 4/20 - Average Loss: 8.4169
Epoch 5/20 - Average Loss: 6.4152
Epoch 6/20 - Average Loss: 4.8894
Epoch 7/20 - Average Loss: 3.9022
Epoch 8/20 - Average Loss: 3.1704
Epoch 9/20 - Average Loss: 2.6313
Epoch 10/20 - Average Loss: 2.1175
Epoch 11/20 - Average Loss: 1.8104
Epoch 12/20 - Average Loss: 1.6105
Epoch 13/20 - Average Loss: 1.4768
Epoch 14/20 - Average Loss: 1.3906
Epoch 15/20 - Average Loss: 1.3409
Epoch 16/20 - Average Loss: 1.2982
Epoch 17/20 - Average Loss: 1.2638
Epoch 18/20 - Average Loss: 1.2331
Epoch 19/20 - Average Loss: 1.2075
Epoch 20/20 - Average Loss: 1.1834