Super Kawaii Cute Cat Kaoani Supervised Contrastive Learning 코드 분석

수업정리/딥러닝 이론

Supervised Contrastive Learning 코드 분석

치킨고양이짱아 2024. 5. 28. 19:41
728x90
728x90

supervised contrastive learning을 하기 위해 공개된 코드를 사용해야할 일이 생겼다. 아무리 가져온 코드라 하더라도 이해하지 않고 사용하는건 말이 안되는 것 같아, 해당 코드를 분석하는 포스트를 작성하고자 한다.

🍀 코드 출처: https://ffighting.net/deep-learning-paper-review/self-supervised-learning/supervised-contrastive-learning/

 

Supervised Contrastive Learning - 딥러닝 논문 리뷰

Supervised Contrastive Learning 논문의 핵심 내용을 리뷰합니다. Supervised Contrastive Learning의 제안 방법을 살펴봅니다. 마지막으로 성능 비교 실험을 통해 효과를 확인합니다.

ffighting.net

 

기존에 사용하던 Self-supervised contrastive learning에서 사용하던 loss식은 다음과 같다.

Self-supervised Contrastive Loss는 주어진 앵커 샘플 z_i와 양성 샘플 z_j 사이의 유사도를 최대화하면서, 앵커 샘플 z_i와 다른 모든 샘플 (주로 음성 샘플) 사이의 유사도를 최소한다. 이를 통해 모델은 양성 샘플을 가까이, 음성 샘플을 멀리 두도록 학습한다.

log 안쪽의 식을 보면, 양성 샘플의 유사도 지수 함수 값을 모든 샘플의 유사도 지수 함수 합으로 나누어, 양성 샘플이 전체 샘플 중에서 상대적으로 얼마나 유사한지를 측정한다. 온도 파라미터로 유사도를 조정하는 역할을 한다. 값이 작을수록 유사도 차이가 더 뚜렷하게 나타난다.

이때, 나의 Augmentation 된 data만 Positive Pair로 취급하기 때문에, 나와 다른 data는 같은 클래스에 속하더라도 Negative Pair로 인식하게 되는 한계가 있었다.

 

Supervised contrasive learning에서 사용하는 loss식은 다음과 같다.

식이 거의 self-supervised contrasitve loss와 유사하면서도 조금씩 다르다. 차이를 말하자면, self-supervised contrastive loss가 자신의 augmentation data만 postive pair로 보고 해당 data와의 유사도만 높이고 나머지와의 유사도는 낮추는 식이였다면, 위의 식은 같은 클래스에 속하는 data는 모두 postive pair로 보고, 해당 data와의 유사도는 높이고, 다른 클래스에 속하는 data와의 유사도는 낮추는 식이다.

 

loss식을 어떻게 사용하면 되는지 알아보았으니, 이제 코드를 살펴보자.

전체 코드는 다음과 같다.

class SupervisedContrastiveLoss(torch.nn.Module):
    def __init__(self, temperature=0.1):
        super(SupervisedContrastiveLoss, self).__init__()
        self.temperature = temperature

    def forward(self, embeddings, labels):
        # Normalize the embeddings
        embeddings = F.normalize(embeddings, p=2, dim=-1)

        # Compute the similarity matrix
        sim_matrix = torch.matmul(embeddings, embeddings.T)

        # Create the positive mask
        labels = labels.unsqueeze(1)
        positive_mask = torch.eq(labels, labels.T).float()

        # Create the negative mask
        negative_mask = torch.ne(labels, labels.T).float()

        # Compute the positive and negative similarity
        sim_ij = torch.exp(sim_matrix / self.temperature)
        exp_sim_matrix = torch.exp(sim_matrix / self.temperature)

        # Remove diagonal elements
        ind = torch.eye(labels.size(0)).bool().to(device)
        sim_ij.masked_fill_(ind, 0)
        exp_sim_matrix.masked_fill_(ind, 0)

        # Compute the loss
        sum_exp_sim_matrix = torch.sum(exp_sim_matrix, dim=1)
        pos_exp_sim_matrix = torch.sum(sim_ij * positive_mask, dim=1)
        loss = -torch.log(pos_exp_sim_matrix / sum_exp_sim_matrix)

        return torch.Mean(loss)

 

forward 함수의 코드를 하나씩 살펴보자.

def forward(self, embeddings, labels):
    # Normalize the embeddings
    embeddings = F.normalize(embeddings, p=2, dim=-1)
    
    # Compute the similarity matrix
    sim_matrix = torch.matmul(embeddings, embeddings.T)

forward 함수는 각 data에 대한 embdding 값들과 해당 data들에 대한 label 값들을 input으로 받는다.

sim_matrix는 embedding 값들끼리의 similarity를 계산해놓은 similarity matrix를 의미한다.

    # Create the positive mask
    labels = labels.unsqueeze(1)
    positive_mask = torch.eq(labels, labels.T).float()
    
    # Create the negative mask
    negative_mask = torch.ne(labels, labels.T).float()

여기서는 서로의 label이 같은지를 표현해주는 postive_mask와 negative_mask를 계산하고 있다.

    # Compute the positive and negative similarity
    sim_ij = torch.exp(sim_matrix / self.temperature)
    exp_sim_matrix = torch.exp(sim_matrix / self.temperature)

여기서는 label에 상관없이 모든 쌍들끼리의 similarity를 측정하고 있다.

    # Remove diagonal elements
    ind = torch.eye(labels.size(0)).bool().to(device)
    sim_ij.masked_fill_(ind, 0)
    exp_sim_matrix.masked_fill_(ind, 0)

sim_ij와 exp_sim_matrix의 대각 부분은 자기 자신과의 유사도를 의미하기 때문에 항상 높게 나타나므로 해당 부분을 제거해야한다. 여기서는 대각 부분만 False로 세팅된 matrix인 ind를 사용해서 sim_ij의 exp_sim_matrix의 대각 부분을 0으로 만들어주고 있다.

    # Compute the loss
    sum_exp_sim_matrix = torch.sum(exp_sim_matrix, dim=1)
    pos_exp_sim_matrix = torch.sum(sim_ij * positive_mask, dim=1)
    loss = -torch.log(pos_exp_sim_matrix / sum_exp_sim_matrix)
    
    return torch.Mean(loss)

sum_exp_sim_matrix는 모든 similarity의 합을 계산하고 있다.

pos_exp_sim_matrix는 서로 같은 class에 속하는 샘플끼리의 similarity끼리만 합을 계산하고 있다.

이렇게 계산한 값들을 사용해서, loss식에서 본 것처럼 최종 loss값을 계산하게 된다.

728x90
728x90