MachineLearning

[PyTorch] nn.Linear

SweetDev 2022. 2. 7. 13:43

nn.Linear은 행렬과 같다. 

nn.Linear을 통해서 행렬을 곱해서 tensor의 size()를 바꿀 수 있다. 

 

import torch
from torch import nn

X = torch.Tensor([[1, 2],
                  [3, 4]])

# TODO : tensor X의 크기는 (2, 2)입니다
#        nn.Linear를 사용하여서 (2, 5)로 크기를 바꾸고 이 크기를 출력하세요!

m = torch.nn.Linear(2, 5)
output = m(X)
print(output.size())