파이토치 튜토리얼의 quickstart를 번역했습니다.
2024. 10. 21 최초작성
2024. 10. 26
다음 문서를 기반으로 작성되었습니다.
https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html
https://pytorch.org/tutorials/beginner/basics/transforms_tutorial.html
https://pytorch.org/tutorials/beginner/basics/optimization_tutorial.html
데이터와 함께 작업하기
파이토치(PyTorch)에는 데이터 작업을 위한 두 가지 핵심 요소가 있습니다. torch.utils.data.DataLoader와 torch.utils.data.Dataset입니다.
Dataset은 전체 데이터셋의 구조를 정의하며 개별 데이터 샘플과 해당 레이블에 접근하는 방법을 제공합니다. 실제 데이터를 보관하거나 데이터 소스를 참조합니다. 필요한 경우 개별 샘플에 대한 기본적인 전처리를 수행하기도 합니다.
DataLoader는 Dataset을 감싸서 이터러블(iterable)로 만듭니다. 이터러블은 for문을 사용하여 순차적으로 요소에 접근이 가능한 객체입니다. Dataset으로부터 데이터를 효율적으로 불러올 수 있게 됩니다. 데이터로더는 여러 샘플을 묶어 일정 개수의 샘플로 구성된 미니배치 형태로 데이터를 제공하며 필요에 따라 데이터를 뒤섞거나 정렬할 수 있습니다. 다중 워커를 통해 데이터 로딩을 병렬처리 할 수 있습니다. 필요한 데이터만 메모리에 로드할 수 있어 메모리 사용량을 줄이면서 대규모 데이터셋 처리를 가능하게 합니다.
import torch from torch import nn from torch.utils.data import DataLoader from torchvision import datasets from torchvision.transforms import ToTensor |
파이토치는 TorchText, TorchVision, TorchAudio와 같은 분야별 라이브러리를 통해 데이터셋을 제공합니다. 본 튜토리얼에서는 TorchVision의 데이터셋을 사용합니다.
torchvision.datasets 모듈에는 CIFAR, COCO와 같은 다양한 실제 비전 데이터의 Dataset 객체가 포함되어 있습니다. 본 튜토리얼에서는 FashionMNIST 데이터셋을 사용합니다. Fashion-MNIST는 60,000개의 훈련 예제와 10,000개의 테스트 예제로 구성된 Zalando의 상품 이미지 데이터 세트입니다. 각 예제는 28×28 그레이스케일 이미지와 10개 클래스 중 하나의 관련 레이블로 구성됩니다.
모든 TorchVision 데이터셋에는 두 개의 중요한 매개변수가 있습니다: transform과 target_transform입니다. 이들은 각각 샘플과 레이블을 수정하는 데 사용됩니다.
transform: 이 매개변수는 입력 데이터 샘플(예: 이미지)을 변형하는 데 사용됩니다. 주로 다음과 같은 작업을 수행합니다:
이미지 크기 조정 텐서로 변환 정규화 데이터 증강(예: 회전, 반전, 색상 변경 등)
target_transform: 이 매개변수는 레이블(타겟) 데이터를 변형하는 데 사용됩니다. 주로 다음과 같은 작업을 수행합니다:
레이블 인코딩 변경(예: 원-핫 인코딩) 레이블 형식 변환(예: 문자열에서 정수로 변환)
이 두 매개변수를 사용하면 데이터셋의 샘플과 레이블을 모델 학습에 적합한 형태로 전처리할 수 있습니다. 변환은 데이터를 로드할 때 자동으로 적용되므로, 일관된 전처리를 보장하고 코드의 재사용성을 높일 수 있습니다.
데이터가 항상 머신러닝 알고리즘 학습에 필요한 최종 가공된 형태로 제공되는 것은 아닙니다. 트랜스폼을 사용하여 데이터를 일부 조작하여 학습에 적합하게 만드는 것이 필요합니다.
# 훈련 데이터를 다운로드합니다. training_data = datasets.FashionMNIST( root="data", # 훈련 데이터가 저장된 경로입니다, train=True, # 훈련 데이터임을 명시합니다. download=True, # root에서 데이터를 사용할 수 없는 경우 인터넷에서 데이터를 다운로드합니다. transform=ToTensor(), # PIL 이미지 또는 NumPy ndarray를 FloatTensor로 변환하고 이미지의 픽셀 밝기 값을 [0., 1.] 범위에서 스케일링합니다. ) # 테스트 데이터를 다운로드합니다. test_data = datasets.FashionMNIST( root="data", # 테스트 데이터가 저장된 경로입니다, train=False, # 테스트 데이터임을 명시합니다. download=True, # root에서 데이터를 사용할 수 없는 경우 인터넷에서 데이터를 다운로드합니다. transform=ToTensor(), # PIL 이미지 또는 NumPy ndarray를 FloatTensor로 변환하고 이미지의 픽셀 밝기 값을 [0., 1.] 범위에서 스케일링합니다. ) |
데어터셋을 데이터 로더의 아규먼트로 전달합니다. 데이터로더는 파이토치에서 데이터를 효율적으로 로드하고 전처리합니다.
- 데이터셋을 이터러블로 래핑하여 데이터를 쉽게 순회할 수 있게 합니다.
- 개별 샘플들을 지정된 크기의 배치로 그룹화합니다.
- 데이터셋에서 샘플을 추출하는 방식을 조정할 수 있습니다.
- 각 에포크마다 데이터 순서를 무작위로 섞어(셔플링) 모델의 일반화 성능을 향상시킵니다.
- 데이터 로딩시 여러 CPU 코어를 사용하여(멀티프로세스) 데이터 로딩 속도를 높입니다.
본 예제에서는 배치 크기를 64로 정의합니다. 이는 데이터로더의 이터러블이 64개의 샘플과 64개의 레이블로 구성된 배치를 반환한다는 의미입니다.
- 각 배치의 특징(feature) 텐서 형상: [64, C, H, W] (C, H, W는 각각 채널, 높이, 너비)
- 각 배치의 레이블 텐서 형상: [64]
batch_size = 64 # 데이터 로더를 생성시 배치 크기를 지정합니다. train_dataloader = DataLoader(training_data, batch_size=batch_size) test_dataloader = DataLoader(test_data, batch_size=batch_size) for X, y in test_dataloader: print(f"Shape of X [N, C, H, W]: {X.shape}") print(f"Shape of y: {y.shape} {y.dtype}") break |
Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64
모델 만들기
파이토치에서 신경망을 정의하려면 nn.Module에서 상속하는 클래스를 생성해야 합니다.
# 훈련을 하기 위해 gpu(cuda), mps, cpu 장치 중 하나를 사용하도록 device를 설정합니다. device = ( "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" ) print(f"Using {device} device") # 모델을 정의합니다. class NeuralNetwork(nn.Module): # 네트워크의 레이어를 정의합니다. def __init__(self): super().__init__() # 2차원 28x28 이미지를 784픽셀 값의 1차원 배열로 변환하기 위해 nn.Flatten 레이어를 초기화합니다. 1차원 배열로 바뀐 후에도 배치 형태는 유지됩니다. self.flatten = nn.Flatten() # 모듈의 순차 컨테이너입니다. 데이터는 정의된 것과 동일한 순서로 모든 모듈을 통해 전달됩니다. self.linear_relu_stack = nn.Sequential( # 선형 레이어는 저장된 가중치와 바이어스를 사용하여 입력에 선형 변환을 적용하는 모듈입니다. nn.Linear(28*28, 512), # 비선형 활성화는 모델의 입력과 출력 사이에 복잡한 매핑을 생성하는 것입니다. # 비선형 활성화는 선형 변환 후에 적용되어 비선형성을 도입함으로써 신경망이 다양한 현상을 학습할 수 있도록 도와줍니다. nn.ReLU(), nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10) ) # 입력 데이터가 네트워크를 거치는 방법을 정의합니다.(연산이 이루어집니다.) def forward(self, x): x = self.flatten(x) logits = self.linear_relu_stack(x) return logits # 신경망의 연산을 가속화하기 위해 가능한 경우 신경망을 앞에서 설정한 장치로 이동합니다. model = NeuralNetwork().to(device) print(model) |
Using cuda device
NeuralNetwork(
(flatten): Flatten(start_dim=1, end_dim=-1)
(linear_relu_stack): Sequential(
(0): Linear(in_features=784, out_features=512, bias=True)
(1): ReLU()
(2): Linear(in_features=512, out_features=512, bias=True)
(3): ReLU()
(4): Linear(in_features=512, out_features=10, bias=True)
)
)
모델 파라미터 최적화
단일 훈련 루프에서 모델은 훈련 데이터 세트(배치로 제공)를 예측하고 예측 오류를 역전파하여 모델의 매개 변수를 조정하는 train 함수를 정의합니다.
def train(dataloader, model, loss_fn, optimizer): size = len(dataloader.dataset) # 훈련을 진행합니다. model.train() for batch, (X, y) in enumerate(dataloader): # 데이터를 장치로 옮깁니다. X, y = X.to(device), y.to(device) # 입력 X로부터 출력을 예측합니다. pred = model(X) # 예측 오류(prediction error)를 계산합니다. loss = loss_fn(pred, y) # 손실에 대한 역전파(Backpropagation)를 진행하여 각 파라미터에 대한 손실의 기울기를 저장합니다. loss.backward() # 기울기를 사용하여 파라미터를 조정합니다. optimizer.step() optimizer.zero_grad() # 손실(loss)를 출력합니다. if batch % 100 == 0: loss, current = loss.item(), (batch + 1) * len(X) print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]") |
테스트 데이터 세트에서 추론을 진행하여 모델의 성능을 확인하여 학습이 제대로 이루어지고 있는지 확인하는 test 함수를 정의합니다.
def test(dataloader, model, loss_fn): size = len(dataloader.dataset) num_batches = len(dataloader) # 신경망 모델을 평가 모드로 전환합니다. model.eval() test_loss, correct = 0, 0 # 모델을 평가하기 전에 그래디언트 계산을 비활성화합니다. # 추론시에는 순방향 패스만 수행하기 때문에 그래디언트 계산을 비활성화하면 계산 속도를 높이고 메모리를 절약할 수 있습니다. with torch.no_grad(): for X, y in dataloader: # 데이터를 장치로 옮깁니다. X, y = X.to(device), y.to(device) # 입력 X로부터 출력을 예측합니다. pred = model(X) # 예측 오류(prediction error)를 계산합니다. test_loss += loss_fn(pred, y).item() # 정답을 맞춘 개수를 누적합니다. correct += (pred.argmax(1) == y).type(torch.float).sum().item() test_loss /= num_batches correct /= size # 정확도(accuracy, 전체 중 정답 맞춘개수)와 손실(loss)를 출력합니다. print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n") |
모델을 훈련하기 위해 필요한 Loss와 옵티마이저를 초기화하고 train 함수와 test 함수에 전달합니다.
optimization 루프를 통해 모델을 훈련하고 최적화할 수 있습니다. optimization 루프의 각 반복을 에포크라고 합니다. 각 에포크는 크게 두 부분으로 구성됩니다: 훈련 루프 - 훈련 데이터 세트를 반복하여 최적의 파라미터로 수렴을 시도합니다. - 검증/테스트 루프 - 테스트 데이터 세트를 반복하여 모델 성능이 개선되고 있는지 확인합니다
학습 데이터가 주어지면 학습되지 않은 네트워크는 정답을 제시하지 못할 가능성이 높습니다. 손실 함수는 얻은 결과와 목표 값의 불일치 정도를 측정하는 것으로, 학습 과정에서 최소화하고자 하는 것이 바로 손실 함수입니다. 손실 함수를 계산하기 위해 주어진 데이터 샘플의 입력을 사용하여 예측을 하고 이를 실제 데이터 레이블 값과 비교합니다. 모델의 출력 로짓을 nn.CrossEntropyLoss로 전달하면 로짓을 정규화하고 예측 오차를 계산합니다.
Optimization는 각 학습 단계에서 모델 오류를 줄이기 위해 모델 파라미터를 조정하는 프로세스입니다. Optimization 알고리즘은 이 프로세스가 수행되는 방식을 정의합니다(이 예에서는 확률적 경사 하강을 사용합니다). 모든 Optimization 로직은 optimizer 객체에 캡슐화되어 있습니다. 여기서는 SGD 옵티마이저를 사용하지만, 다양한 종류의 모델과 데이터에 더 잘 작동하는 ADAM 및 RMSProp과 같은 다양한 옵티마이저를 PyTorch에서 사용할 수 있습니다. 학습해야 하는 모델의 파라미터를 등록하고 학습 속도 하이퍼파라미터를 전달하여 optimizer를 초기화합니다.
loss_fn = nn.CrossEntropyLoss() optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) |
훈련 과정은 여러 번의 반복(에포크)에 걸쳐 진행됩니다. 각 에포크 동안 모델은 더 나은 예측을 위해 파라미터를 학습합니다. 아래 코드 실행결과 각 에포크에서 모델의 정확도와 손실을 출력하고 있습니다. 학습이 잘 이루어졌다면 학습이 진행됨에 따라 정확도는 증가하고 손실은 감소해야 합니다.
epochs = 20 for t in range(epochs): print(f"Epoch {t+1}\n-------------------------------") train(train_dataloader, model, loss_fn, optimizer) test(test_dataloader, model, loss_fn) print("Done!") |
Epoch 1
-------------------------------
loss: 2.297730 [ 64/60000]
loss: 2.297795 [ 6464/60000]
loss: 2.267863 [12864/60000]
loss: 2.264452 [19264/60000]
loss: 2.256921 [25664/60000]
loss: 2.216290 [32064/60000]
loss: 2.230775 [38464/60000]
loss: 2.189008 [44864/60000]
loss: 2.187047 [51264/60000]
loss: 2.153267 [57664/60000]
Test Error:
Accuracy: 22.9%, Avg loss: 2.152832
Epoch 2
-------------------------------
loss: 2.160404 [ 64/60000]
loss: 2.161666 [ 6464/60000]
loss: 2.101424 [12864/60000]
loss: 2.121828 [19264/60000]
loss: 2.077269 [25664/60000]
loss: 2.012324 [32064/60000]
loss: 2.043198 [38464/60000]
loss: 1.961496 [44864/60000]
loss: 1.971484 [51264/60000]
loss: 1.898289 [57664/60000]
Test Error:
Accuracy: 56.3%, Avg loss: 1.903687
Epoch 3
-------------------------------
loss: 1.932652 [ 64/60000]
loss: 1.911246 [ 6464/60000]
loss: 1.803332 [12864/60000]
loss: 1.840628 [19264/60000]
loss: 1.729639 [25664/60000]
loss: 1.684043 [32064/60000]
loss: 1.698764 [38464/60000]
loss: 1.595764 [44864/60000]
loss: 1.625371 [51264/60000]
loss: 1.514052 [57664/60000]
Test Error:
Accuracy: 60.3%, Avg loss: 1.537346
Epoch 4
-------------------------------
loss: 1.602479 [ 64/60000]
loss: 1.567376 [ 6464/60000]
loss: 1.427360 [12864/60000]
loss: 1.488943 [19264/60000]
loss: 1.366105 [25664/60000]
loss: 1.369099 [32064/60000]
loss: 1.369008 [38464/60000]
loss: 1.288888 [44864/60000]
loss: 1.337163 [51264/60000]
loss: 1.231352 [57664/60000]
Test Error:
Accuracy: 62.5%, Avg loss: 1.260026
Epoch 5
-------------------------------
loss: 1.338295 [ 64/60000]
loss: 1.315792 [ 6464/60000]
loss: 1.160664 [12864/60000]
loss: 1.255250 [19264/60000]
loss: 1.134436 [25664/60000]
loss: 1.165547 [32064/60000]
loss: 1.169178 [38464/60000]
loss: 1.102997 [44864/60000]
loss: 1.159231 [51264/60000]
loss: 1.069592 [57664/60000]
Test Error:
Accuracy: 63.9%, Avg loss: 1.091402
Epoch 6
-------------------------------
loss: 1.163400 [ 64/60000]
loss: 1.159861 [ 6464/60000]
loss: 0.988388 [12864/60000]
loss: 1.112314 [19264/60000]
loss: 0.996365 [25664/60000]
loss: 1.032441 [32064/60000]
loss: 1.048413 [38464/60000]
loss: 0.987903 [44864/60000]
loss: 1.045163 [51264/60000]
loss: 0.969288 [57664/60000]
Test Error:
Accuracy: 65.1%, Avg loss: 0.984817
Epoch 7
-------------------------------
loss: 1.043626 [ 64/60000]
loss: 1.060755 [ 6464/60000]
loss: 0.872681 [12864/60000]
loss: 1.019313 [19264/60000]
loss: 0.912454 [25664/60000]
loss: 0.940746 [32064/60000]
loss: 0.970484 [38464/60000]
loss: 0.914418 [44864/60000]
loss: 0.966572 [51264/60000]
loss: 0.902520 [57664/60000]
Test Error:
Accuracy: 66.4%, Avg loss: 0.913051
Epoch 8
-------------------------------
loss: 0.956732 [ 64/60000]
loss: 0.993266 [ 6464/60000]
loss: 0.790762 [12864/60000]
loss: 0.954502 [19264/60000]
loss: 0.857460 [25664/60000]
loss: 0.874211 [32064/60000]
loss: 0.916228 [38464/60000]
loss: 0.865656 [44864/60000]
loss: 0.909292 [51264/60000]
loss: 0.854611 [57664/60000]
Test Error:
Accuracy: 67.5%, Avg loss: 0.861493
Epoch 9
-------------------------------
loss: 0.890377 [ 64/60000]
loss: 0.943816 [ 6464/60000]
loss: 0.729730 [12864/60000]
loss: 0.906799 [19264/60000]
loss: 0.818221 [25664/60000]
loss: 0.823825 [32064/60000]
loss: 0.875575 [38464/60000]
loss: 0.832020 [44864/60000]
loss: 0.865904 [51264/60000]
loss: 0.818182 [57664/60000]
Test Error:
Accuracy: 68.6%, Avg loss: 0.822378
Epoch 10
-------------------------------
loss: 0.837673 [ 64/60000]
loss: 0.904727 [ 6464/60000]
loss: 0.682005 [12864/60000]
loss: 0.870227 [19264/60000]
loss: 0.788101 [25664/60000]
loss: 0.784589 [32064/60000]
loss: 0.843065 [38464/60000]
loss: 0.807395 [44864/60000]
loss: 0.831879 [51264/60000]
loss: 0.789202 [57664/60000]
Test Error:
Accuracy: 70.0%, Avg loss: 0.791315
Epoch 11
-------------------------------
loss: 0.794326 [ 64/60000]
loss: 0.871852 [ 6464/60000]
loss: 0.643375 [12864/60000]
loss: 0.841336 [19264/60000]
loss: 0.763481 [25664/60000]
loss: 0.753298 [32064/60000]
loss: 0.815625 [38464/60000]
loss: 0.788349 [44864/60000]
loss: 0.804235 [51264/60000]
loss: 0.764852 [57664/60000]
Test Error:
Accuracy: 71.3%, Avg loss: 0.765528
Epoch 12
-------------------------------
loss: 0.757347 [ 64/60000]
loss: 0.843171 [ 6464/60000]
loss: 0.610964 [12864/60000]
loss: 0.817892 [19264/60000]
loss: 0.742816 [25664/60000]
loss: 0.727812 [32064/60000]
loss: 0.791630 [38464/60000]
loss: 0.772294 [44864/60000]
loss: 0.781459 [51264/60000]
loss: 0.744154 [57664/60000]
Test Error:
Accuracy: 72.6%, Avg loss: 0.743515
Epoch 13
-------------------------------
loss: 0.725574 [ 64/60000]
loss: 0.817593 [ 6464/60000]
loss: 0.583425 [12864/60000]
loss: 0.798092 [19264/60000]
loss: 0.724644 [25664/60000]
loss: 0.706677 [32064/60000]
loss: 0.769809 [38464/60000]
loss: 0.758199 [44864/60000]
loss: 0.762001 [51264/60000]
loss: 0.725761 [57664/60000]
Test Error:
Accuracy: 73.5%, Avg loss: 0.724076
Epoch 14
-------------------------------
loss: 0.697702 [ 64/60000]
loss: 0.794294 [ 6464/60000]
loss: 0.559490 [12864/60000]
loss: 0.780967 [19264/60000]
loss: 0.708408 [25664/60000]
loss: 0.688571 [32064/60000]
loss: 0.749708 [38464/60000]
loss: 0.745548 [44864/60000]
loss: 0.745098 [51264/60000]
loss: 0.708950 [57664/60000]
Test Error:
Accuracy: 74.5%, Avg loss: 0.706515
Epoch 15
-------------------------------
loss: 0.672823 [ 64/60000]
loss: 0.772744 [ 6464/60000]
loss: 0.538398 [12864/60000]
loss: 0.765633 [19264/60000]
loss: 0.693701 [25664/60000]
loss: 0.672851 [32064/60000]
loss: 0.730966 [38464/60000]
loss: 0.734069 [44864/60000]
loss: 0.730207 [51264/60000]
loss: 0.693524 [57664/60000]
Test Error:
Accuracy: 75.3%, Avg loss: 0.690455
Epoch 16
-------------------------------
loss: 0.650471 [ 64/60000]
loss: 0.752743 [ 6464/60000]
loss: 0.519615 [12864/60000]
loss: 0.751554 [19264/60000]
loss: 0.680561 [25664/60000]
loss: 0.659143 [32064/60000]
loss: 0.713213 [38464/60000]
loss: 0.723637 [44864/60000]
loss: 0.717018 [51264/60000]
loss: 0.679103 [57664/60000]
Test Error:
Accuracy: 76.3%, Avg loss: 0.675680
Epoch 17
-------------------------------
loss: 0.630207 [ 64/60000]
loss: 0.734156 [ 6464/60000]
loss: 0.502918 [12864/60000]
loss: 0.738628 [19264/60000]
loss: 0.668589 [25664/60000]
loss: 0.647090 [32064/60000]
loss: 0.696467 [38464/60000]
loss: 0.714000 [44864/60000]
loss: 0.705222 [51264/60000]
loss: 0.665747 [57664/60000]
Test Error:
Accuracy: 76.8%, Avg loss: 0.662032
Epoch 18
-------------------------------
loss: 0.611751 [ 64/60000]
loss: 0.716917 [ 6464/60000]
loss: 0.487931 [12864/60000]
loss: 0.726682 [19264/60000]
loss: 0.657745 [25664/60000]
loss: 0.636389 [32064/60000]
loss: 0.680725 [38464/60000]
loss: 0.705224 [44864/60000]
loss: 0.694848 [51264/60000]
loss: 0.653201 [57664/60000]
Test Error:
Accuracy: 77.3%, Avg loss: 0.649433
Epoch 19
-------------------------------
loss: 0.594887 [ 64/60000]
loss: 0.700946 [ 6464/60000]
loss: 0.474346 [12864/60000]
loss: 0.715512 [19264/60000]
loss: 0.647860 [25664/60000]
loss: 0.626816 [32064/60000]
loss: 0.665957 [38464/60000]
loss: 0.697415 [44864/60000]
loss: 0.685761 [51264/60000]
loss: 0.641332 [57664/60000]
Test Error:
Accuracy: 77.7%, Avg loss: 0.637767
Epoch 20
-------------------------------
loss: 0.579392 [ 64/60000]
loss: 0.686109 [ 6464/60000]
loss: 0.461985 [12864/60000]
loss: 0.705048 [19264/60000]
loss: 0.638846 [25664/60000]
loss: 0.618231 [32064/60000]
loss: 0.652100 [38464/60000]
loss: 0.690476 [44864/60000]
loss: 0.677830 [51264/60000]
loss: 0.630103 [57664/60000]
Test Error:
Accuracy: 78.1%, Avg loss: 0.626963
Done!
모델 저장하기
모델을 저장하는 방법은 내부 상태 딕셔너리(모델 매개변수 포함)를 직렬화하는 것입니다.
torch.save(model.state_dict(), "model.pth") print("Saved PyTorch Model State to model.pth") |
Saved PyTorch Model State to model.pth
모델 로드하기
모델을 로드하려면 모델 구조를 다시 만들고 상태 딕셔너리를 로드하면 됩니다.
model = NeuralNetwork().to(device) model.load_state_dict(torch.load("model.pth", weights_only=True)) |
<All keys matched successfully>
이제 이 모델을 사용하여 예측을 해봅니다.
# 모델은 입력이 주어지면 다음 클래스 중 하나라고 예측합니다. classes = [ "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot", ] # 신경망 모델을 평가 모드로 전환합니다. model.eval() x, y = test_data[0][0], test_data[0][1] # 모델을 평가하기 전에 그래디언트 계산을 비활성화합니다. with torch.no_grad(): # 입력 x를 장치로 옮깁니다. x = x.to(device) # 입력 x로부터 출력을 예측합니다. pred = model(x) # 예측 결과에서 가장 높은 확률의 클래스가 무엇인지 확인합니다. predicted, actual = classes[pred[0].argmax(0)], classes[y] print(f'Predicted: "{predicted}", Actual: "{actual}"') |
Predicted: "Ankle boot", Actual: "Ankle boot"
Hyperparameters
하이퍼파라미터는 모델 최적화 프로세스를 제어할 수 있는 조정 가능한 파라미터입니다. 하이퍼파라미터 값이 다르면 모델 학습 및 수렴 속도에 영향을 미칠 수 있습니다. 학습을 위해 다음과 같은 하이퍼파라미터를 정의합니다:
에포크 수 - 데이터 세트에 대해 반복할 횟수입니다.
배치 크기- 매개변수가 업데이트되기 전에 네트워크를 통해 전파되는 데이터 샘플의 수입니다.
학습 속도 - 각 배치/에포크에서 모델 파라미터를 업데이트할 양입니다. 값이 작을수록 학습 속도가 느려지고, 값이 클수록 학습 중에 예측할 수 없는 동작이 발생할 수 있습니다.
'Deep Learning & Machine Learning > PyTorch' 카테고리의 다른 글
간단한 Pytorch 예제 설명 (0) | 2023.10.18 |
---|
시간날때마다 틈틈이 이것저것 해보며 블로그에 글을 남깁니다.
블로그의 문서는 종종 최신 버전으로 업데이트됩니다.
여유 시간이 날때 진행하는 거라 언제 진행될지는 알 수 없습니다.
영화,책, 생각등을 올리는 블로그도 운영하고 있습니다.
https://freewriting2024.tistory.com
제가 쓴 책도 한번 검토해보세요 ^^
그렇게 천천히 걸으면서도 그렇게 빨리 앞으로 나갈 수 있다는 건.
포스팅이 좋았다면 "좋아요❤️" 또는 "구독👍🏻" 해주세요!