Character Controllers Using Motion VAEs의 conditional VAE 부분을 기존에도 구현해서 사용하고 있었는데...다시 보니까 빠뜨린게 왜이렇게 많냐ㅠ 다시 확실하게 정리할 필요가 있을 것 같아서 코드 분석 & 논문에서 필요한 내용을 가져와 정리할 예정이다.
전체 코드는 아래의 링크에서 볼 수 있다.
https://github.com/electronicarts/character-motion-vaes/tree/main
아래의 코드는 논문에서 사용하는 PoseVAE 코드이다.
class PoseMixtureVAE(nn.Module):
def __init__(
self,
frame_size,
latent_size,
num_condition_frames,
num_future_predictions,
normalization,
num_experts,
):
super().__init__()
self.frame_size = frame_size
self.latent_size = latent_size
self.num_condition_frames = num_condition_frames
self.num_future_predictions = num_future_predictions
self.mode = normalization.get("mode")
self.data_max = normalization.get("max")
self.data_min = normalization.get("min")
self.data_avg = normalization.get("avg")
self.data_std = normalization.get("std")
hidden_size = 256
args = (
frame_size,
latent_size,
hidden_size,
num_condition_frames,
num_future_predictions,
)
self.encoder = Encoder(*args)
self.decoder = MixedDecoder(*args, num_experts)
def normalize(self, t):
if self.mode == "minmax":
return 2 * (t - self.data_min) / (self.data_max - self.data_min) - 1
elif self.mode == "zscore":
return (t - self.data_avg) / self.data_std
elif self.mode == "none":
return t
else:
raise ValueError("Unknown normalization mode")
def denormalize(self, t):
if self.mode == "minmax":
return (t + 1) * (self.data_max - self.data_min) / 2 + self.data_min
elif self.mode == "zscore":
return t * self.data_std + self.data_avg
elif self.mode == "none":
return t
else:
raise ValueError("Unknown normalization mode")
def encode(self, x, c):
_, mu, logvar = self.encoder(x, c)
return mu, logvar
def forward(self, x, c):
z, mu, logvar = self.encoder(x, c)
return self.decoder(z, c), mu, logvar
def sample(self, z, c, deterministic=False):
return self.decoder(z, c)
encode 함수의 역할
def encode(self, x, c):
_, mu, logvar = self.encoder(x, c)
return mu, logvar
encode 함수는 Input data x와 c를 latent space에 투영한다. 이때 x는 현재 프레임의 포즈, C에는 과거 프레임의 포즈가 들어간다.
latent space애서 데이터 분포를 정의하기 위해 encode는 평균과 분산의 로그 값을 출력하게 되는데, 추후에 이 두 값은 latent space에서 데이터 포인트를 샘플링하기 위한 파라미터로 사용된다.
Encoder 클래스의 forward 함수 및 encode 함수를 살펴보면, 아래와 같다.
def forward(self, x, c):
mu, logvar = self.encode(x, c)
z = self.reparameterize(mu, logvar)
return z, mu, logvar
def encode(self, x, c):
h1 = F.elu(self.fc1(torch.cat((x, c), dim=1)))
h2 = F.elu(self.fc2(torch.cat((x, h1), dim=1)))
s = torch.cat((x, h2), dim=1)
return self.mu(s), self.logvar(s)
encode 함수에서 매 layer마다 현재 프레임을 넘겨주는 부분이 특징적이다.
그리고 mu 값과 logvar 값을 출력할 때도 그냥 출력하는 것이 아니라 각각 layer를 통과시켜서 출력시킨다. 위의 코드에서 self.mu와 self.logvar 모두 아래와 같이 Linear layer를 의미하고 있다.
self.mu = nn.Linear(frame_size + hidden_size, latent_size)
self.logvar = nn.Linear(frame_size + hidden_size, latent_size)
이렇게 출력된 mu값과 logvar를 사용해 forward 함수 내부에서 reparameterize를 시키고 있다. reparameterize는 mu값과 logvar값을 사용하여 잠재공간상에서 샘플링을 수행하고, 샘플링된 값 z를 return한다.
forward 함수 분석
위에서 encode 함수의 역할에 대해 살펴보았다. encode 함수는 latent space 상에 mapping된 data의 평균과 분산 값을 구하고 이 값을 바탕으로 sampling을 진행한다. z가 latent space 상에서 샘플링된 값, mu가 latent space 상에서 데이터의 평균, logvar가 latent space 상에서 분산에 log를 취한 값이다.
def forward(self, x, c):
z, mu, logvar = self.encoder(x, c)
return self.decoder(z, c), mu, logvar
forward 함수는 mu와 logvar는 그대로 출력하고 z와 c를 decoder를 통과시킨 값을 출력하고 있다. decoder는 latent space 상에 mapping 된 z값을 다시 원본 데이터 형태로 복구하는 역할을 하며, Mixture of Experts (MoE) 구조로 되어있다. decoder 부분 코드는 복잡하여 첨부하지 않았지만, posterior collapse를 방지하기 위해 매 layer마다 latent vector를 넘겨주었다고 하니 이 부분에 주의해야할 것 같다.
Training loss 분석
아래의 코드는 VAE를 training 시키기 위해 loss를 계산하는 부분의 코드이다.
VAE를 training 시키기 위해 총 2개의 loss를 사용하게 된다. -> 1) kl divergence loss와 2) reconstruction loss
vae_output, mu, logvar = pose_vae(flattened_truth, condition)
vae_output = vae_output.view(output_shape)
kl_loss = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum().clamp(max=0)
kl_loss /= logvar.numel()
recon_loss = (vae_output - ground_truth).pow(2).mean(dim=(0, -1))
recon_loss = recon_loss.mul(future_weights).sum()
return (vae_output, mu, logvar), (recon_loss, kl_loss)
1) kl divergence loss
kl_loss = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum().clamp(max=0)
위의 식은 latent space 상에서의 분포와 표준 정규 분포 사이의 차이를 의미하며 해당 loss를 통해 latent space의 분포가 정규 분포에 가깝도록 만든다. 이때 clamp 함수를 통해 kl_loss 값이 음수가 되지 않도록 한다.
2) reconstruction loss
reconstruction loss는 decoder를 통과한 데이터가 얼마나 원본 데이터를 잘 reconstruct 했는지 측정하는 loss이다.
이제 좀 정리가 된 느낌~