Super Kawaii Cute Cat Kaoani Scheduled sampling 기법에 대한 분석 (Character Controller Using Motion VAEs)

연구/논문 리뷰

Scheduled sampling 기법에 대한 분석 (Character Controller Using Motion VAEs)

치킨고양이짱아 2023. 10. 7. 18:52
728x90
728x90

scheduled sampling에 대한 개념은 어느정도 이해를 하였으나, 구현을 하려고 보니 막상 디테일적인 부분에서 막히는게 많았다... 그래서 코드까지 보면서 좀 더 자세히 분석을 해보려고 한다.

(이 글은 Character Controller Using Motion VAEs 논문을 참고하여 본 글을 작성하였다.)

 

1) 각 Roll out마다 prediction length 정하기

scheduled sampling을 구현하기 위해서는 prediction length L를 각 roll-out마다 정하여야한다.  각 Roll-out에서 motion capture database로부터 start frame을 random하게 sample하고 L step동안  pose prediction을 진행한다. L은 실제 런타임 분포를 시뮬레이션할 수 있도록 충분히 커야한다. (Character Controller using Motion VAE 논문에서는 L = 8 (1/4초 분량) 을 사용하였다.)

 

2) end-of-clip problem

각 Roll out에서 motion capture database로부터 start frame을 random하게 sample한다고 언급했었다. 이 frame으로부터 L만큼의 subsequent frame이 연속적으로 필요하게 된다. 따라서 L만큼의 subsequent frame이 존재하는 frame만 sampling해야한다. (즉, clip의 너무 끝 지점을 start frame으로 sampling하면 안된다!!)

위의 1), 2) 디테일은 논문에 정리되어 있는 내용이고, 이것만으로 구현을 하기엔 부족한 면이 있는 것 같아 github의 코드도 분석해보았다.

 

3) 코드 분석하면서 알게된 내용

코드 분석한 내용을 쓰기에 앞서, 내가 코드를 분석하면서 알게된 내용은 다음과 같다. 아래의 내용이 이해된다면 코드를 하나하나 분석한 내용은 굳이 읽을 필요가 없을듯하다.

* 적당한 L step size를 정해서, random하게 frame을 sample하고 나면 L step만큼은 연속적인 pose update를 해나간다. (p값에 상관없이 진행하는 것같다.) L step이 끝나면 또 sampler가 새롭게 frame들을 sampling한다.

* 모델의 input으로 과거 History 정보도 들어가게 될텐데, 이 History 정보는 dataset의 ground truth 값으로 초기화한다. 만약 use_student가 true인 경우(Model의 결과를 사용하는 경우) L step동안 model의 연산결과로 나오는 값들을 사용해 구성한 pose로 history를 점차 채워간다.

 

4) 코드 분석 내용

* 각 Epoch에서 사용할 확률 p 세팅

sample_schedule = torch.cat(
    (
    # First part is pure teacher forcing
    torch.zeros(teacher_epochs),
    # Second part with schedule sampling
    torch.linspace(0.0, 1.0, ramping_epochs),
    # last part is pure student
    torch.ones(student_epochs),
    )
)

첫번째 단계에서는 p를 모두 0으로 세팅하였고, 두번째 단계에서는 p를 0에서 1로 선형적으로 증가하도록 세팅하였고, 세번째 단계에서는 p를 모두 1로 세팅하였다.

 

* 학습 코드

shape = (args.mini_batch_size, args.num_condition_frames, frame_size)
history = torch.empty(shape).to(args.device)
for ep in range(1, args.num_epochs + 1):
        sampler = BatchSampler(
            SubsetRandomSampler(selectable_indices),
            args.mini_batch_size,
            drop_last=True,
        )

Sampler를 다음과 같이 세팅해주었다. selectable_indices는 end-of-clip 부분을 제외한 index들을 담고 있는 list인듯 하다.

 

        ep_recon_loss = 0
        ep_kl_loss = 0
        ep_perplexity = 0

        update_linear_schedule(
            vae_optimizer, ep - 1, args.num_epochs, args.initial_lr, args.final_lr
        )

update_linear_schedule 함수는 optimizer의 leraning rate를 감소시킨다.

 

        num_mini_batch = 1
        for num_mini_batch, indices in enumerate(sampler):
            t_indices = torch.LongTensor(indices)

            # condition is from newest...oldest, i.e. (t-1, t-2, ... t-n)
            condition_range = (
                t_indices.repeat((args.num_condition_frames, 1)).t()
                + torch.arange(args.num_condition_frames - 1, -1, -1).long()
            )

t_indices는 sampler에 의해 sampling된 index들을 의미한다.

그리고 num_condition_frames는 과거 frame을 몇개나 참조하는지 의미하는듯하다.  (Character Controller using Motion VAE 논문에서는 과거 frame은 하나만 참조하므로 num_condition_frames는 1로 세팅하였다.)

만약 indices가 [5, 10, 15]로 세팅되었고, args.num_condition_frames가 3이라 가정하면,

t_indices

[ 5, 10, 15]

t_indices.repeat((args.num_condition_frames, 1)).t()

[[ 5,  5,  5],
 [10, 10, 10],
 [15, 15, 15]]

torch.arange(args.num_doncition_frames, -1, -1, -1).long()

[2, 1, 0]

condition_range의 결과는

[[ 5+2,  5+1,  5+0],
 [10+2, 10+1, 10+0],
 [15+2, 15+1, 15+0]]
[[ 7,  6,  5],
 [12, 11, 10],
 [17, 16, 15]]

 

            t_indices += args.num_condition_frames
            history[:, : args.num_condition_frames].copy_(mocap_data[condition_range])

위의 예시에 이어 결과값들을 예상해보면 t_indices에 args.num_condition_frame(3) 를 더해주었으므로 t_indices는 [8, 13, 18]이 된다.

현재의 timestep은 각각 8, 13, 18이고 과거 3 step의 data가 history에 저장된다.

  • 즉 8번째 data의 args.num_condition_frame(3) 만큼의 과거 data인 7번째 data, 6번째 data, 5번째 data가 history에 저장되고
  • 13번째 data의 args.num_condition_frame(3) 만큼의 과거 data인 12,11,10번째 data가 history에 저장되고
  • 18번째 data의 과거 data인 17, 16, 15번째 data가 저장된다.

 

            for offset in range(args.num_steps_per_rollout):
                # dims: (num_parallel, num_window, feature_size)
                use_student = torch.rand(1) < sample_schedule[ep - 1]

                prediction_range = (
                    t_indices.repeat((args.num_future_predictions, 1)).t()
                    + torch.arange(offset, offset + args.num_future_predictions).long()
                )
                ground_truth = mocap_data[prediction_range]
                condition = history[:, : args.num_condition_frames]

여기서 말하는 num_step_per_rollout이 논문에서 설명한 L이다.

use_student는 model의 결과를 사용할지 여부를 의미한다.

prediction_range는 현재 구성하고자 하는 frame의 index를 의미하며 첫번째 loop에서는 [[8], [13], [18]], 두번째 loop에서는 [[9], [14], [19]] ... 로 점차 증가해나간다.

 

                if isinstance(pose_vae, PoseVQVAE):
                    (vae_output, perplexity), (recon_loss, kl_loss) = feed_vae(
                        pose_vae, ground_truth, condition, future_weights
                    )
                    ep_perplexity += float(perplexity) / args.num_steps_per_rollout
                else:
                    # PoseVAE, PoseMixtureVAE, PoseMixtureSpecialistVAE
                    (vae_output, _, _), (recon_loss, kl_loss) = feed_vae(
                        pose_vae, ground_truth, condition, future_weights
                    )

이렇게 구성한 Input들을 사용해 model의 연산을 진행한다.

 

                history = history.roll(1, dims=1)
                next_frame = vae_output[:, 0] if use_student else ground_truth[:, 0]
                history[:, 0].copy_(next_frame.detach())

model의 연산을 사용하는 경우(use_student가 True인 경우) history에 model의 결과로 나온 pose를 넣어서 update하고, model의 연산을 사용하지 않는 경우(use student가 False인 경우) history에 ground_truth의 pose를 넣어서 update한다.

 

                vae_optimizer.zero_grad()
                (recon_loss + args.kl_beta * kl_loss).backward()
                vae_optimizer.step()

                ep_recon_loss += float(recon_loss) / args.num_steps_per_rollout
                ep_kl_loss += float(kl_loss) / args.num_steps_per_rollout

        avg_ep_recon_loss = ep_recon_loss / num_mini_batch
        avg_ep_kl_loss = ep_kl_loss / num_mini_batch
        avg_ep_perplexity = ep_perplexity / num_mini_batch

        logger.log_stats(
            {
                "epoch": ep,
                "ep_recon_loss": avg_ep_recon_loss,
                "ep_kl_loss": avg_ep_kl_loss,
                "ep_perplexity": avg_ep_perplexity,
            }
        )

나머지 코드 부분은 계산한 Loss를 사용해서 model을 update하는 부분이다.


코드를 쪼개서 하나씩 살펴봤는데 전체 코드는 다음과 같다.

for ep in range(1, args.num_epochs + 1):
        sampler = BatchSampler(
            SubsetRandomSampler(selectable_indices),
            args.mini_batch_size,
            drop_last=True,
        )
        ep_recon_loss = 0
        ep_kl_loss = 0
        ep_perplexity = 0

        update_linear_schedule(
            vae_optimizer, ep - 1, args.num_epochs, args.initial_lr, args.final_lr
        )

        num_mini_batch = 1
        for num_mini_batch, indices in enumerate(sampler):
            t_indices = torch.LongTensor(indices)

            # condition is from newest...oldest, i.e. (t-1, t-2, ... t-n)
            condition_range = (
                t_indices.repeat((args.num_condition_frames, 1)).t()
                + torch.arange(args.num_condition_frames - 1, -1, -1).long()
            )

            t_indices += args.num_condition_frames
            history[:, : args.num_condition_frames].copy_(mocap_data[condition_range])

            for offset in range(args.num_steps_per_rollout):
                # dims: (num_parallel, num_window, feature_size)
                use_student = torch.rand(1) < sample_schedule[ep - 1]

                prediction_range = (
                    t_indices.repeat((args.num_future_predictions, 1)).t()
                    + torch.arange(offset, offset + args.num_future_predictions).long()
                )
                ground_truth = mocap_data[prediction_range]
                condition = history[:, : args.num_condition_frames]

                if isinstance(pose_vae, PoseVQVAE):
                    (vae_output, perplexity), (recon_loss, kl_loss) = feed_vae(
                        pose_vae, ground_truth, condition, future_weights
                    )
                    ep_perplexity += float(perplexity) / args.num_steps_per_rollout
                else:
                    # PoseVAE, PoseMixtureVAE, PoseMixtureSpecialistVAE
                    (vae_output, _, _), (recon_loss, kl_loss) = feed_vae(
                        pose_vae, ground_truth, condition, future_weights
                    )

                history = history.roll(1, dims=1)
                next_frame = vae_output[:, 0] if use_student else ground_truth[:, 0]
                history[:, 0].copy_(next_frame.detach())

                vae_optimizer.zero_grad()
                (recon_loss + args.kl_beta * kl_loss).backward()
                vae_optimizer.step()

                ep_recon_loss += float(recon_loss) / args.num_steps_per_rollout
                ep_kl_loss += float(kl_loss) / args.num_steps_per_rollout

        avg_ep_recon_loss = ep_recon_loss / num_mini_batch
        avg_ep_kl_loss = ep_kl_loss / num_mini_batch
        avg_ep_perplexity = ep_perplexity / num_mini_batch

        logger.log_stats(
            {
                "epoch": ep,
                "ep_recon_loss": avg_ep_recon_loss,
                "ep_kl_loss": avg_ep_kl_loss,
                "ep_perplexity": avg_ep_perplexity,
            }
        )
728x90
728x90