Super Kawaii Cute Cat Kaoani [PyTorch] tensor 확장하기: torch.expand vs torch.repeat vs torch.repeat_interleave

연구/PyTorch

[PyTorch] tensor 확장하기: torch.expand vs torch.repeat vs torch.repeat_interleave

치킨고양이짱아 2024. 2. 21. 16:28
728x90
728x90

pytorch tensor를 사용하다보면 dimension을 늘려줘야하는 상황이 많이 발생한다. 이때 많이 사용하는 함수가 torch.expand, torch.repeat, torch.repeat_interleave이다. 상황마다 쓰기 편리한 함수들이 있는데 매번 까먹어서;; 내가 보려고 작성하는 비교 글이다.

1) torch.expand(*size)

torch.expand 함수는 개수가 1인 차원에 대해서만 확장이 가능하며, desired size를 input으로 받는다. 

>>> x = torch.tensor([[1], [2], [3]])
>>> x.size()
torch.Size([3, 1])
>>> x.expand(3, 4)
tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3]])
>>> x.expand(-1, 4) # -1 means not changing the size of that dimension
tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3]])

위와 같이 개수가 1인 차원에 대해서는 잘 확장이 된다.

>>> x.expand(12, 1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: The expanded size of the tensor (12) must match the existing size (3) at non-singleton dimension 0.  Target sizes: [12, 1].  Tensor sizes: [3, 1]

만약 개수가 1이 아닌 차원에 대해서 확장을 하려고 하면 위와 같은 에러메세지가 뜨게 된다.

 

2) torch.repeat(*size)

특정 텐서를 sizes 차원만큼 반복한다.

>>> x = torch.tensor([[1], [2], [3]])
>>> x.size()
torch.Size([3, 1])
>>> x.repeat(4, 2)
tensor([[1, 1],
        [2, 2],
        [3, 3],
        [1, 1],
        [2, 2],
        [3, 3],
        [1, 1],
        [2, 2],
        [3, 3],
        [1, 1],
        [2, 2],
        [3, 3]])
        
>>> x.repeat(4, 2).size()
torch.Size([12, 2])

>>> y = torch.tensor([[1, 2], [2, 3], [3, 4]])
>>> y.size()
torch.Size([3, 2])
>>> y.repeat(1, 3)
tensor([[1, 2, 1, 2, 1, 2],
        [2, 3, 2, 3, 2, 3],
        [3, 4, 3, 4, 3, 4]])

x의 shape는 (3, 1)로, dim=0으로 4, dim=1으로 2만큼 반복하니 (12,2)의 shape를 가지는 tensor가 만들어진다. y의 shape는 (3, 2)로 dim = 0으로 1, dim =1으로 3만큼 반복하니 (3, 6)의 shape를 가지는 tensor가 만들어진다.

이때 1-d의 텐서에 대해 함수를 적용시키면, tensor의 shape를 [n]이 아닌 [1, n]으로 간주한다.

>>> x = torch.tensor([1, 2, 3])
>>> x.size()
torch.Size([3])
>>> x.repeat(3, 4)
tensor([[1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3]])
>>> x.repeat(3, 4).size()
torch.Size([3, 12])

위의 경우에서 x의 shape는 (3,)이지만 repeat 함수를 사용할 때는 (1, 3)으로 간주한다. 여기서 dim=0으로 3, dim=1으로 4만큼 반복한. (3, 12)의 shape를 가지는 tensor가 만들어진다.

 

3) torch.repeat_interleave(input, repeats, dim=None)

원하는 행별로 반복할 수 있는 함수. input텐서(input), 반복횟수(repeats), 확장을 원하는 행(dim)을 parameter로 받는다.

>>> x = torch.tensor([[1, 2], [2, 3], [3, 4]])
>>> x.size()
torch.Size([3, 2])
>>> y = torch.repeat_interleave(x, 2, dim = 0)
>>> y
tensor([[1, 2],
        [1, 2],
        [2, 3],
        [2, 3],
        [3, 4],
        [3, 4]])
>>> y.size()
torch.Size([6, 2])
>>> z = torch.repeat_interleave(x, 3, dim = 1)
>>> z
tensor([[1, 1, 1, 2, 2, 2],
        [2, 2, 2, 3, 3, 3],
        [3, 3, 3, 4, 4, 4]])

이렇게 원하는 축에 대해서 반복시키는 것이 가능하다.

**** 주의

repeat 함수와 repeat_interleave가 만들어내는 결과물이 좀 다르다. 주의해서 사용하도록 하자.

>>> y = torch.tensor([[1, 2], [2, 3], [3, 4]])
>>> y.repeat(1, 3)
tensor([[1, 2, 1, 2, 1, 2],
        [2, 3, 2, 3, 2, 3],
        [3, 4, 3, 4, 3, 4]])
        
>>> torch.repeat_interleave(y, 3, dim = 1)
tensor([[1, 1, 1, 2, 2, 2],
        [2, 2, 2, 3, 3, 3],
        [3, 3, 3, 4, 4, 4]])

이렇게 만들어지는 tensor의 모양이 다르다. repeat의 경우 1 2 3을 반복할 때 1 2 3 1 2 3 <- 이런식으로 반복한다면, repeat_interleave의 경우 1 2 3을 반복할 때 1 1 2 2 3 3 <- 이런식으로 반복한다.

위의 차이점을 숙지하여 의도한대로 tensor를 잘 확장하여 사용해보자~!!

728x90
728x90