카테고리 없음
[PyTorch] index_select, gather
SweetDev
2022. 1. 24. 15:06
index_select
>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.1427, 0.0231, -0.5414, -1.0009],
[-0.4664, 0.2647, -0.1228, -1.1068],
[-1.1734, -0.6571, 0.7230, -0.6004]])
>>> indices = torch.tensor([0, 2])
>>> torch.index_select(x, 0, indices)
tensor([[ 0.1427, 0.0231, -0.5414, -1.0009],
[-1.1734, -0.6571, 0.7230, -0.6004]])
>>> torch.index_select(x, 1, indices)
tensor([[ 0.1427, -0.5414],
[-0.4664, -0.1228],
[-1.1734, 0.7230]])
tensor을 넣어야하지만 꼭 범위여야 하는건 아니다.(스칼라 값도 가능)
축에 관해서는
0 = 가로
1 = 세로
그 이상은 직접!
gather
https://runebook.dev/ko/docs/pytorch/generated/torch.index_select
https://pytorch.org/docs/stable/generated/torch.index_select.html?highlight=index#torch.index_select