728x90
◾ squeeze 함수
squeeze 함수는 Tensor의 차원을 줄이는 함수로, 설정한 차원을 제거 해 준다.
따로 차원을 설정하지 않으면 1인 차원을 모두 제거한다. 1인 차원이 여러개 있어도 여러개 전부 다! 제거한다.
import torch
x = torch.rand(1,7,46,46)
print(x.shape) # torch.Size([1, 7, 46, 46])
x = x.squeeze(dim=1)
print(x.shape) # torch.Size([7, 46, 46])
한가지 조심해야 할 것은 batch size가 1일 때 squeeze 함수를 사용하게 되면 batch 차원을 없애버려 validation 시 오류가 발생하게 된다. 이걸 간과하고 있어서 학습할 때 validation에서 계속 오류가 났는데.. 이게 원인이었다..
squeeze(dim='제거하고자 하는 차원') 함수 사용 시 dim 뒤에 제거하고자 하는 차원을 지정 해 주면 된다. dim=을 적지 않은 경우 1인 차원을 제거하게 된다.
◾ unsqueeze 함수
반대로 unsqueeze 함수는 squeeze 함수와 반대로 차원을 늘려주는 함수인데, 1인 차원을 생성한다. unsqueeze 함수를 사용할 땐 어느 차원에 1인 차원을 생성을 할 것인지를 지정 해 주어야만 한다.
import torch
x = torch.rand(7,46,46)
print(x.shape) # torch.Size([7, 46, 46])
x = x.unsqueeze(dim=2)
print(x.shape) # torch.Size([7, 46, 1, 46])
728x90