MachineLearning
[PyTorch] nn.Linear
SweetDev
2022. 2. 7. 13:43
nn.Linear은 행렬과 같다.
nn.Linear을 통해서 행렬을 곱해서 tensor의 size()를 바꿀 수 있다.
import torch
from torch import nn
X = torch.Tensor([[1, 2],
[3, 4]])
# TODO : tensor X의 크기는 (2, 2)입니다
# nn.Linear를 사용하여서 (2, 5)로 크기를 바꾸고 이 크기를 출력하세요!
m = torch.nn.Linear(2, 5)
output = m(X)
print(output.size())