Super Kawaii Cute Cat Kaoani [코드분석] Taming Diffusion Probabilistic Models for Character Control

연구/논문 리뷰

[코드분석] Taming Diffusion Probabilistic Models for Character Control

치킨고양이짱아 2024. 9. 3. 16:55
728x90
728x90

1. Motion Diffusion class

모델 코드는 생각보다 간단해서 이해가 쉽다.

class MotionDiffusion(nn.Module):
    def __init__(self, input_feats, nstyles, njoints, nfeats, rot_req, clip_len,
                 latent_dim=256, ff_size=1024, num_layers=8, num_heads=4, dropout=0.2,
                 ablation=None, activation="gelu", legacy=False, 
                 arch='trans_enc', cond_mask_prob=0, device=None):
        super().__init__()

        self.legacy = legacy
        self.training = True
        
        self.rot_req = rot_req
        self.nfeats = nfeats
        self.njoints = njoints
        self.clip_len = clip_len
        self.input_feats = input_feats

        self.latent_dim = latent_dim
        self.ff_size = ff_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout
        self.ablation = ablation
        self.activation = activation
        self.cond_mask_prob = cond_mask_prob
        self.arch = arch
        
        self.future_motion_process = MotionProcess(self.input_feats, self.latent_dim)
        self.past_motion_process = MotionProcess(self.input_feats, self.latent_dim)
        self.traj_trans_process = TrajProcess(2, self.latent_dim)
        self.traj_pose_process = TrajProcess(6, self.latent_dim)
        self.style_feature_process = nn.Linear(512, self.latent_dim)
        self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout)
        self.embed_style = EmbedStyle(nstyles, self.latent_dim)
        self.embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder)

        if self.arch == 'trans_enc':
            print("TRANS_ENC init")
            
            seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim,
                                                              nhead=self.num_heads,
                                                              dim_feedforward=self.ff_size,
                                                              dropout=self.dropout,
                                                              activation=self.activation)

            self.seqEncoder = nn.TransformerEncoder(seqTransEncoderLayer,
                                                         num_layers=self.num_layers)
        elif self.arch == 'trans_dec':
            print("TRANS_DEC init")
            seqTransDecoderLayer = nn.TransformerDecoderLayer(d_model=self.latent_dim,
                                                              nhead=self.num_heads,
                                                              dim_feedforward=self.ff_size,
                                                              dropout=self.dropout,
                                                              activation=activation)
            self.seqEncoder = nn.TransformerDecoder(seqTransDecoderLayer,
                                                         num_layers=self.num_layers)

        elif self.arch == 'gru':
            print("GRU init")
            self.seqEncoder = nn.GRU(self.latent_dim, self.latent_dim, num_layers=self.num_layers, batch_first=True)
        else:
            raise ValueError('Please choose correct architecture [trans_enc, trans_dec, gru]')
      
        self.output_process = OutputProcess(self.input_feats, self.latent_dim, self.njoints, self.nfeats)

핵심 모듈:

  • MotionProcess: 모션 데이터를 잠재 공간에 임베딩.
  • TrajProcess: 경로 데이터를 임베딩 처리.
  • PositionalEncoding: 시퀀스 데이터에 위치 정보를 추가, 이는 모델이 시퀀스 내 프레임의 순서를 이해하는 데 도움을 줌.
  • EmbedStyle: 모션 스타일을 잠재 공간에 임베딩.
  • TimestepEmbedder: 타임스텝을 잠재 공간에 임베딩.
  • seqEncoder: 아키텍처에 따라 Transformer 인코더, 디코더 또는 GRU를 선택하여 시퀀스를 처리.
  • OutputProcess: 처리된 잠재 표현을 다시 모션 데이터로 변환.
    def forward(self, x, timesteps, past_motion, traj_pose, traj_trans, style_idx):
        bs, njoints, nfeats, nframes = x.shape
        
        time_emb = self.embed_timestep(timesteps)  # [1, bs, L]
        style_emb = self.embed_style(style_idx).unsqueeze(0)  # [1, bs, L]
        traj_trans_emb = self.traj_trans_process(traj_trans) # [N/2, bs, L] 
        traj_pose_emb = self.traj_pose_process(traj_pose) # [N/2, bs, L] 
        past_motion_emb = self.past_motion_process(past_motion)  # [past_frames, bs, L] 
        
        future_motion_emb = self.future_motion_process(x) 
        
        xseq = torch.cat((time_emb, style_emb, 
                          traj_trans_emb, traj_pose_emb,
                          past_motion_emb, future_motion_emb), axis=0)
        
        xseq = self.sequence_pos_encoder(xseq)
        output = self.seqEncoder(xseq)[-nframes:] 
        output = self.output_process(output)  
        return output

 

  • 모션 데이터(x), 타임스텝, 과거 모션, 경로 포즈, 경로 변환, 스타일 인덱스가 입력으로 제공된다. 여기서 타임스텝은 디퓨전의 몇번째 step인지를 의미하는 값이다.
  • 각 입력은 처리되어 잠재 공간에 임베딩되고
  • 임베딩된 시퀀스는 지정된 시퀀스 모델(Transformer 또는 GRU)에 의해 처리(실제로 논문에서 사용하는건 Transformer Encoder인듯)

 

2. TrainingPortal

Model 코드보다 Training 시키는 코드가 좀 더 복잡한데... 자세히 살펴보자.

2-1. BaseTrainingPortal

class BaseTrainingPortal:
    def __init__(self, config, model, diffusion, dataloader, logger, tb_writer, prior_loader=None):
        
        self.model = model
        self.diffusion = diffusion
        self.dataloader = dataloader
        self.logger = logger
        self.tb_writer = tb_writer
        self.config = config
        self.batch_size = config.trainer.batch_size
        self.lr = config.trainer.lr
        self.lr_anneal_steps = config.trainer.lr_anneal_steps

        self.epoch = 0
        self.num_epochs = config.trainer.epoch
        self.save_freq = config.trainer.save_freq
        self.best_loss = 1e10
        
        print('Train with %d epoches, %d batches by %d batch_size' % (self.num_epochs, len(self.dataloader), self.batch_size))

        self.save_dir = config.save

        self.opt = AdamW(self.model.parameters(), lr=self.lr, weight_decay=config.trainer.weight_decay)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.opt, T_max=self.num_epochs, eta_min=self.lr * 0.1)
        
        if config.trainer.ema:
            self.ema = ExponentialMovingAverage(self.model.parameters(), decay=0.995)
        
        self.device = config.device

        self.schedule_sampler_type = 'uniform'
        self.schedule_sampler = create_named_schedule_sampler(self.schedule_sampler_type, diffusion)
        self.use_ddp = False
        
        self.prior_loader = prior_loader
        
        
    def diffuse(self, x_start, t, cond, noise=None, return_loss=False):
        raise NotImplementedError('diffuse function must be implemented')

    def evaluate_sampling(self, dataloader, save_folder_name):
        raise NotImplementedError('evaluate_sampling function must be implemented')
    
        
    def run_loop(self):
        sampling_num = 16
        sampling_idx = np.random.randint(0, len(self.dataloader.dataset), sampling_num)
        sampling_subset = DataLoader(Subset(self.dataloader.dataset, sampling_idx), batch_size=sampling_num)
        self.evaluate_sampling(sampling_subset, save_folder_name='init_samples')
        
        epoch_process_bar = tqdm(range(self.epoch, self.num_epochs), desc=f'Epoch {self.epoch}')
        for epoch_idx in epoch_process_bar:
            self.model.train()
            self.model.training = True
            self.epoch = epoch_idx
            epoch_losses = {}
            
            data_len = len(self.dataloader)
            
            for datas in self.dataloader:
                datas = {key: val.to(self.device) if torch.is_tensor(val) else val for key, val in datas.items()}
                cond = {key: val.to(self.device) if torch.is_tensor(val) else val for key, val in datas['conditions'].items()}
                x_start = datas['data']

                self.opt.zero_grad()
                t, weights = self.schedule_sampler.sample(x_start.shape[0], self.device)
                
                _, losses = self.diffuse(x_start, t, cond, noise=None, return_loss=True)
                total_loss = (losses["loss"] * weights).mean()
                total_loss.backward()
                self.opt.step()
            
                if self.config.trainer.ema:
                    self.ema.update()
                
                for key_name in losses.keys():
                    if 'loss' in key_name:
                        if key_name not in epoch_losses.keys():
                            epoch_losses[key_name] = []
                        epoch_losses[key_name].append(losses[key_name].mean().item())
            
            if self.prior_loader is not None:
                for prior_datas in itertools.islice(self.prior_loader, data_len):
                    prior_datas = {key: val.to(self.device) if torch.is_tensor(val) else val for key, val in prior_datas.items()}
                    prior_cond = {key: val.to(self.device) if torch.is_tensor(val) else val for key, val in prior_datas['conditions'].items()}
                    prior_x_start = prior_datas['data']
                    
                    self.opt.zero_grad()
                    t, weights = self.schedule_sampler.sample(prior_x_start.shape[0], self.device)
                    
                    _, prior_losses = self.diffuse(prior_x_start, t, prior_cond, noise=None, return_loss=True)
                    total_loss = (prior_losses["loss"] * weights).mean()
                    total_loss.backward()
                    self.opt.step()
                    
                    for key_name in prior_losses.keys():
                        if 'loss' in key_name:
                            if key_name not in epoch_losses.keys():
                                epoch_losses[key_name] = []
                            epoch_losses[key_name].append(prior_losses[key_name].mean().item())
            
            loss_str = ''
            for key in epoch_losses.keys():
                loss_str += f'{key}: {np.mean(epoch_losses[key]):.6f}, '
            
            epoch_avg_loss = np.mean(epoch_losses['loss'])
            
            if self.epoch > 10 and epoch_avg_loss < self.best_loss:                
                self.save_checkpoint(filename='best')
            
            if epoch_avg_loss < self.best_loss:
                self.best_loss = epoch_avg_loss
            
            epoch_process_bar.set_description(f'Epoch {epoch_idx}/{self.config.trainer.epoch} | loss: {epoch_avg_loss:.6f} | best_loss: {self.best_loss:.6f}')
            self.logger.info(f'Epoch {epoch_idx}/{self.config.trainer.epoch} | {loss_str} | best_loss: {self.best_loss:.6f}')
                        
            if epoch_idx > 0 and epoch_idx % self.config.trainer.save_freq == 0:
                self.save_checkpoint(filename=f'weights_{epoch_idx}')
                self.evaluate_sampling(sampling_subset, save_folder_name='train_samples')
            
            for key_name in epoch_losses.keys():
                if 'loss' in key_name:
                    self.tb_writer.add_scalar(f'train/{key_name}', np.mean(epoch_losses[key_name]), epoch_idx)

            self.scheduler.step()
        
        best_path = '%s/best.pt' % (self.config.save)
        self.load_checkpoint(best_path)
        self.evaluate_sampling(sampling_subset, save_folder_name='best')


    def state_dict(self):
        model_state = self.model.state_dict()
        opt_state = self.opt.state_dict()
            
        return {
            'epoch': self.epoch,
            'state_dict': model_state,
            'opt_state_dict': opt_state,
            'config': self.config,
            'loss': self.best_loss,
        }

    def save_checkpoint(self, filename='weights'):
        save_path = '%s/%s.pt' % (self.config.save, filename)
        with bf.BlobFile(bf.join(save_path), "wb") as f:
            torch.save(self.state_dict(), f)
        self.logger.info(f'Saved checkpoint: {save_path}')


    def load_checkpoint(self, resume_checkpoint, load_hyper=True):
        if bf.exists(resume_checkpoint):
            checkpoint = torch.load(resume_checkpoint)
            self.model.load_state_dict(checkpoint['state_dict'])
            if load_hyper:
                self.epoch = checkpoint['epoch'] + 1
                self.best_loss = checkpoint['loss']
                self.opt.load_state_dict(checkpoint['opt_state_dict'])
            self.logger.info('\nLoad checkpoint from %s, start at epoch %d, loss: %.4f' % (resume_checkpoint, self.epoch, checkpoint['loss']))
        else:
            raise FileNotFoundError(f'No checkpoint found at {resume_checkpoint}')

 

  • 이 클래스는 학습과 관련된 기본 기능을 제공
  • 초기화(__init__): 모델, 디퓨전 과정, 데이터로더, 로깅 도구 등을 초기화합니다. 옵티마이저, 학습률 스케줄러, EMA(지수 이동 평균) 등 학습에 필요한 다양한 도구도 초기화
  • run_loop: 학습 루프를 실행하는 메서드로, 에포크(epoch)별로 데이터를 학습시킨다. 데이터로더에서 배치 단위로 데이터를 가져와 모델을 학습시키며, 학습 손실(loss)을 계산하고 모델 파라미터를 업데이트한다. 좀 중요한 부분이라 아래에서 코드를 조금 더 자세히 살펴보자.
  • diffuseevaluate_sampling: 추상 메서드로, 자식 클래스에서 구체적으로 구현해야 한다. diffuse는 디퓨전 과정을 구현하고, evaluate_sampling은 모델의 샘플링 성능을 평가하는 메서드이다.
  • save_checkpointload_checkpoint: 학습된 모델의 체크포인트를 저장하고 불러오는 기능을 한다.

아래의 함수는 run_loop 의 코드이다.

    def run_loop(self):
        sampling_num = 16
        sampling_idx = np.random.randint(0, len(self.dataloader.dataset), sampling_num)
        sampling_subset = DataLoader(Subset(self.dataloader.dataset, sampling_idx), batch_size=sampling_num)
        self.evaluate_sampling(sampling_subset, save_folder_name='init_samples')

위의 코드에서 sampling num과 smapling_index는 데이터셋의 서브셋 sampling_subset을 생성하여 학습 평가시에 사용하기 위해 사용되는 변수이다. evaluate_sampling 함수는 초기화된 모델에 대해 sampling_subset으로 샘플링을 수행하고, 결과를 저장하는 역할을 한다.

	epoch_process_bar = tqdm(range(self.epoch, self.num_epochs), desc=f'Epoch {self.epoch}')
        for epoch_idx in epoch_process_bar:
            self.model.train()
            self.model.training = True
            self.epoch = epoch_idx
            epoch_losses = {}
            
            data_len = len(self.dataloader)
            
            for datas in self.dataloader:
                datas = {key: val.to(self.device) if torch.is_tensor(val) else val for key, val in datas.items()}
                cond = {key: val.to(self.device) if torch.is_tensor(val) else val for key, val in datas['conditions'].items()}
                x_start = datas['data']

                self.opt.zero_grad()
                t, weights = self.schedule_sampler.sample(x_start.shape[0], self.device)
                
                _, losses = self.diffuse(x_start, t, cond, noise=None, return_loss=True)
                total_loss = (losses["loss"] * weights).mean()
                total_loss.backward()
                self.opt.step()

epoch를 반복하며 학습을 진행하게 되는데 이때 self.dataloader는 학습에 사용할 데이터셋을 배치 단위로 제공한다. datas는 데이터셋에서 추출된 현재 배치 데이터로, 입력 데이터(x_start), 조건부 정보(cond) 그리고 기타 추가 정보들이 포함된다.

그리고 self.schedule_sampler.sample 함수를 통해 디퓨전 타임스텝을 샘플링한다. x_start.shape[0]은 배치 사이즈로, 이 크기만큼의 타임스텝을 샘플링한다. 해당 함수가 timestep t와 함께 return하는 weight는 샘플링된 타임스텝에 대해 손실을 계산할 때 적용되는 가중치입니다. 이는 각 타임스텝이 학습에서 차지하는 중요도를 반영한다.

그런다음 self.diffuse 메서드를 호출하여 디퓨전 모델에서 타임스텝 t에 대응되는 노이즈가 추가된 데이터 x_t를 생성하고 모델의 출력을 얻은 뒤 이를 사용해 loss를 계산하는 역할을 한다. 여기서  return_loss 옵션은 diffuse메서드가 손실 값을 반환할지를 결정하는 메서드인데 여기서는 True로 세팅하여 학습에 loss를 활용하였다.

또한 샘플링된 타입스텝의 가중치를 loss에 곱한 후, 배치 전체에 대한 평균 손실을 계산한다. 이렇게 계싼된 손실은 모델 업데이트에 사용된다.

 

2-2. MotionTrainingPortal

다음음 BaseTrainingPortal 클래스를 상속받는 MotionTrainingPortal 클래스의 코드이다.

class MotionTrainingPortal(BaseTrainingPortal):
    def __init__(self, config, model, diffusion, dataloader, logger, tb_writer, finetune_loader=None):
        super().__init__(config, model, diffusion, dataloader, logger, tb_writer, finetune_loader)
        self.skel_offset = torch.from_numpy(self.dataloader.dataset.T_pose.offsets).to(self.device)
        self.skel_parents = self.dataloader.dataset.T_pose.parents
        

    def diffuse(self, x_start, t, cond, noise=None, return_loss=False):
        batch_size, frame_num, joint_num, joint_feat = x_start.shape
        x_start = x_start.permute(0, 2, 3, 1)
        
        if noise is None:
            noise = th.randn_like(x_start)
        
        x_t = self.diffusion.q_sample(x_start, t, noise=noise)
        
        # [bs, joint_num, joint_feat, future_frames]
        cond['past_motion'] = cond['past_motion'].permute(0, 2, 3, 1) # [bs, joint_num, joint_feat, past_frames]
        cond['traj_pose'] = cond['traj_pose'].permute(0, 2, 1) # [bs, 6, frame_num//2]
        cond['traj_trans'] = cond['traj_trans'].permute(0, 2, 1) # [bs, 2, frame_num//2]
        
        model_output = self.model.interface(x_t, self.diffusion._scale_timesteps(t), cond)
        
        if return_loss:
            loss_terms = {}
            
            if self.diffusion.model_var_type in [ModelVarType.LEARNED,  ModelVarType.LEARNED_RANGE]:
                B, C = x_t.shape[:2]
                assert model_output.shape == (B, C * 2, *x_t.shape[2:])
                model_output, model_var_values = torch.split(model_output, C, dim=1)
                frozen_out = torch.cat([model_output.detach(), model_var_values], dim=1)
                loss_terms["vb"] = self.diffusion._vb_terms_bpd(model=lambda *args, r=frozen_out: r, x_start=x_start, x_t=x_t, t=t, clip_denoised=False)["output"]
                if self.loss_type == LossType.RESCALED_MSE:
                    loss_terms["vb"] *= self.diffusion.num_timesteps / 1000.0
            target = {
                ModelMeanType.PREVIOUS_X: self.diffusion.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0],
                ModelMeanType.START_X: x_start,
                ModelMeanType.EPSILON: noise,
            }[self.diffusion.model_mean_type]
            assert model_output.shape == target.shape == x_start.shape
            mask = cond['mask'].view(batch_size, 1, 1, -1)
            
            if self.config.trainer.use_loss_mse:
                loss_terms['loss_data'] = self.diffusion.masked_l2(target, model_output, mask) # mean_flat(rot_mse)
                
            if self.config.trainer.use_loss_vel:
                model_output_vel = model_output[..., 1:] - model_output[..., :-1]
                target_vel = target[..., 1:] - target[..., :-1]
                loss_terms['loss_data_vel'] = self.diffusion.masked_l2(target_vel[:, :-1], model_output_vel[:, :-1], mask[..., 1:])
                  
            if self.config.trainer.use_loss_3d or self.config.use_loss_contact:
                target_rot, pred_rot, past_rot = target.permute(0, 3, 1, 2), model_output.permute(0, 3, 1, 2), cond['past_motion'].permute(0, 3, 1, 2)
                target_root_pos, pred_root_pos, past_root_pos = target_rot[:, :, -1, :3], pred_rot[:, :, -1, :3], past_rot[:, :, -1, :3]
                skeletons = self.skel_offset.unsqueeze(0).expand(batch_size, -1, -1)
                parents = self.skel_parents[None]
                
                target_xyz = neural_FK(target_rot[:, :, :-1], skeletons, target_root_pos, parents, rotation_type=self.config.arch.rot_req)
                pred_xyz = neural_FK(pred_rot[:, :, :-1], skeletons, pred_root_pos, parents, rotation_type=self.config.arch.rot_req)
                
                if self.config.trainer.use_loss_3d:
                    loss_terms["loss_geo_xyz"] = self.diffusion.masked_l2(target_xyz.permute(0, 2, 3, 1), pred_xyz.permute(0, 2, 3, 1), mask)
                
                if self.config.trainer.use_loss_vel and self.config.trainer.use_loss_3d:
                    target_xyz_vel = target_xyz[:, 1:] - target_xyz[:, :-1]
                    pred_xyz_vel = pred_xyz[:, 1:] - pred_xyz[:, :-1]
                    loss_terms["loss_geo_xyz_vel"] = self.diffusion.masked_l2(target_xyz_vel.permute(0, 2, 3, 1), pred_xyz_vel.permute(0, 2, 3, 1), mask[..., 1:])
                
                if self.config.trainer.use_loss_contact:
                    l_foot_idx, r_foot_idx = 24, 19
                    relevant_joints = [l_foot_idx, r_foot_idx]
                    target_xyz_reshape = target_xyz.permute(0, 2, 3, 1)  
                    pred_xyz_reshape = pred_xyz.permute(0, 2, 3, 1)
                    gt_joint_xyz = target_xyz_reshape[:, relevant_joints, :, :]  # [BatchSize, 2, 3, Frames]
                    gt_joint_vel = torch.linalg.norm(gt_joint_xyz[:, :, :, 1:] - gt_joint_xyz[:, :, :, :-1], axis=2)  # [BatchSize, 4, Frames]
                    fc_mask = torch.unsqueeze((gt_joint_vel <= 0.01), dim=2).repeat(1, 1, 3, 1)
                    pred_joint_xyz = pred_xyz_reshape[:, relevant_joints, :, :]  # [BatchSize, 2, 3, Frames]
                    pred_vel = pred_joint_xyz[:, :, :, 1:] - pred_joint_xyz[:, :, :, :-1]
                    pred_vel[~fc_mask] = 0
                    loss_terms["loss_foot_contact"] = self.diffusion.masked_l2(pred_vel,
                                                torch.zeros(pred_vel.shape, device=pred_vel.device),
                                                mask[:, :, :, 1:])
            
            loss_terms["loss"] = loss_terms.get('vb', 0.) + \
                            loss_terms.get('loss_data', 0.) + \
                            loss_terms.get('loss_data_vel', 0.) + \
                            loss_terms.get('loss_geo_xyz', 0) + \
                            loss_terms.get('loss_geo_xyz_vel', 0) + \
                            loss_terms.get('loss_foot_contact', 0)
            
            return model_output.permute(0, 3, 1, 2), loss_terms
        
        return model_output.permute(0, 3, 1, 2)
        
    
    def evaluate_sampling(self, dataloader, save_folder_name):
        self.model.eval()
        self.model.training = False
        common.mkdir('%s/%s' % (self.save_dir, save_folder_name))
        
        datas = next(iter(dataloader)) 
        datas = {key: val.to(self.device) if torch.is_tensor(val) else val for key, val in datas.items()}
        cond = {key: val.to(self.device) if torch.is_tensor(val) else val for key, val in datas['conditions'].items()}
        x_start = datas['data']
        t, _ = self.schedule_sampler.sample(dataloader.batch_size, self.device)
        with torch.no_grad():
            model_output = self.diffuse(x_start, t, cond, noise=None, return_loss=False)
        
        common_past_motion = cond['past_motion'].permute(0, 3, 1, 2)
        self.export_samples(x_start, common_past_motion, '%s/%s/' % (self.save_dir, save_folder_name), 'gt')
        self.export_samples(model_output, common_past_motion, '%s/%s/' % (self.save_dir, save_folder_name), 'pred')
        
        self.logger.info(f'Evaluate the sampling {save_folder_name} at epoch {self.epoch}')
        

    def export_samples(self, future_motion_feature, past_motion_feature, save_path, prefix):
        motion_feature = torch.cat((past_motion_feature, future_motion_feature), dim=1)
        rotations = nn_transforms.repr6d2quat(motion_feature[:, :, :-1]).cpu().numpy()
        root_pos = motion_feature[:, :, -1, :3].cpu().numpy()
        
        for samplie_idx in range(future_motion_feature.shape[0]):
            T_pose_template = self.dataloader.dataset.T_pose.copy()
            T_pose_template.rotations = rotations[samplie_idx]
            T_pose_template.positions = np.zeros((rotations[samplie_idx].shape[0], T_pose_template.positions.shape[1], T_pose_template.positions.shape[2]))
            T_pose_template.positions[:, 0] = root_pos[samplie_idx]
            T_pose_template.export(f'{save_path}/motion_{samplie_idx}.{prefix}.bvh', save_ori_scal=True)

 

 

가장 중요한 diffuse 매서드에 대해 조금 더 자세히 살펴보자.

diffuse 매서드의 역할은 모션 데이터를 받아 디퓨전 과정을 통해 데이터를 변형시키고, 모델의 출력을 생성하여 손실을 계산하는 역할을 한다. 해당 함수가 입력으로 받는 x_start는 입력 모션 데이터이며, t는 디퓨전 스텝을 나타내는 타입스텝, cond는 조건부 정보 등이다. (과거 모션, 경로 등..)

def diffuse(self, x_start, t, cond, noise=None, return_loss=False):
        batch_size, frame_num, joint_num, joint_feat = x_start.shape
        x_start = x_start.permute(0, 2, 3, 1)
        
        if noise is None:
            noise = th.randn_like(x_start)
        
        x_t = self.diffusion.q_sample(x_start, t, noise=noise)
        
        # [bs, joint_num, joint_feat, future_frames]
        cond['past_motion'] = cond['past_motion'].permute(0, 2, 3, 1) # [bs, joint_num, joint_feat, past_frames]
        cond['traj_pose'] = cond['traj_pose'].permute(0, 2, 1) # [bs, 6, frame_num//2]
        cond['traj_trans'] = cond['traj_trans'].permute(0, 2, 1) # [bs, 2, frame_num//2]
        
        model_output = self.model.interface(x_t, self.diffusion._scale_timesteps(t), cond)

 

디퓨전 과정을 살펴보면, 모션 데이터에 노이즈를 더해 x_t를 생성한다. 위의 코드에서 self.diffusion은 디퓨전 모델에서 사용하는 디퓨전 과정을 의미하며 q_sample함수는 data에 노이즈를 더 해주는 함수이다.

노이즈가 더해진 모션 데이터 x_t와 조건부 정보를 모델에 입력하여 출력을 얻을 수 있다. 이 출력을 사용해 손실을 계산하게 되는데 손실은 다양한 방법으로 계산되며, 이 과정에서 3D 위치, 속도, 발 접촉 정보 등을 고려한다.

 

 

728x90
728x90