pytorch tensor를 사용하다보면 dimension을 늘려줘야하는 상황이 많이 발생한다. 이때 많이 사용하는 함수가 torch.expand, torch.repeat, torch.repeat_interleave이다. 상황마다 쓰기 편리한 함수들이 있는데 매번 까먹어서;; 내가 보려고 작성하는 비교 글이다.
1) torch.expand(*size)
torch.expand 함수는 개수가 1인 차원에 대해서만 확장이 가능하며, desired size를 input으로 받는다.
>>> x = torch.tensor([[1], [2], [3]])
>>> x.size()
torch.Size([3, 1])
>>> x.expand(3, 4)
tensor([[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3]])
>>> x.expand(-1, 4) # -1 means not changing the size of that dimension
tensor([[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3]])
위와 같이 개수가 1인 차원에 대해서는 잘 확장이 된다.
>>> x.expand(12, 1)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: The expanded size of the tensor (12) must match the existing size (3) at non-singleton dimension 0. Target sizes: [12, 1]. Tensor sizes: [3, 1]
만약 개수가 1이 아닌 차원에 대해서 확장을 하려고 하면 위와 같은 에러메세지가 뜨게 된다.
2) torch.repeat(*size)
특정 텐서를 sizes 차원만큼 반복한다.
>>> x = torch.tensor([[1], [2], [3]])
>>> x.size()
torch.Size([3, 1])
>>> x.repeat(4, 2)
tensor([[1, 1],
[2, 2],
[3, 3],
[1, 1],
[2, 2],
[3, 3],
[1, 1],
[2, 2],
[3, 3],
[1, 1],
[2, 2],
[3, 3]])
>>> x.repeat(4, 2).size()
torch.Size([12, 2])
>>> y = torch.tensor([[1, 2], [2, 3], [3, 4]])
>>> y.size()
torch.Size([3, 2])
>>> y.repeat(1, 3)
tensor([[1, 2, 1, 2, 1, 2],
[2, 3, 2, 3, 2, 3],
[3, 4, 3, 4, 3, 4]])
x의 shape는 (3, 1)로, dim=0으로 4, dim=1으로 2만큼 반복하니 (12,2)의 shape를 가지는 tensor가 만들어진다. y의 shape는 (3, 2)로 dim = 0으로 1, dim =1으로 3만큼 반복하니 (3, 6)의 shape를 가지는 tensor가 만들어진다.
이때 1-d의 텐서에 대해 함수를 적용시키면, tensor의 shape를 [n]이 아닌 [1, n]으로 간주한다.
>>> x = torch.tensor([1, 2, 3])
>>> x.size()
torch.Size([3])
>>> x.repeat(3, 4)
tensor([[1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3]])
>>> x.repeat(3, 4).size()
torch.Size([3, 12])
위의 경우에서 x의 shape는 (3,)이지만 repeat 함수를 사용할 때는 (1, 3)으로 간주한다. 여기서 dim=0으로 3, dim=1으로 4만큼 반복한. (3, 12)의 shape를 가지는 tensor가 만들어진다.
3) torch.repeat_interleave(input, repeats, dim=None)
원하는 행별로 반복할 수 있는 함수. input텐서(input), 반복횟수(repeats), 확장을 원하는 행(dim)을 parameter로 받는다.
>>> x = torch.tensor([[1, 2], [2, 3], [3, 4]])
>>> x.size()
torch.Size([3, 2])
>>> y = torch.repeat_interleave(x, 2, dim = 0)
>>> y
tensor([[1, 2],
[1, 2],
[2, 3],
[2, 3],
[3, 4],
[3, 4]])
>>> y.size()
torch.Size([6, 2])
>>> z = torch.repeat_interleave(x, 3, dim = 1)
>>> z
tensor([[1, 1, 1, 2, 2, 2],
[2, 2, 2, 3, 3, 3],
[3, 3, 3, 4, 4, 4]])
이렇게 원하는 축에 대해서 반복시키는 것이 가능하다.
**** 주의
repeat 함수와 repeat_interleave가 만들어내는 결과물이 좀 다르다. 주의해서 사용하도록 하자.
>>> y = torch.tensor([[1, 2], [2, 3], [3, 4]])
>>> y.repeat(1, 3)
tensor([[1, 2, 1, 2, 1, 2],
[2, 3, 2, 3, 2, 3],
[3, 4, 3, 4, 3, 4]])
>>> torch.repeat_interleave(y, 3, dim = 1)
tensor([[1, 1, 1, 2, 2, 2],
[2, 2, 2, 3, 3, 3],
[3, 3, 3, 4, 4, 4]])
이렇게 만들어지는 tensor의 모양이 다르다. repeat의 경우 1 2 3을 반복할 때 1 2 3 1 2 3 <- 이런식으로 반복한다면, repeat_interleave의 경우 1 2 3을 반복할 때 1 1 2 2 3 3 <- 이런식으로 반복한다.
위의 차이점을 숙지하여 의도한대로 tensor를 잘 확장하여 사용해보자~!!
'연구 > PyTorch' 카테고리의 다른 글
[Pytorch] Dataloader와 Sampler (0) | 2023.10.07 |
---|---|
[PyTorch] Tensor list to Tensor (0) | 2023.09.21 |
[PyTorch] Tensor 조작법 기본) indexing, view, squeeze, unsqueeze (0) | 2023.09.17 |
PyTorch로 AutoEncoder 구현하기 (0) | 2023.08.16 |
PyTorch DataLoader 사용하기 & epoch, batch, iteration 개념 (0) | 2023.08.03 |