!pip install torch torchvision
Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (1.1.0) Requirement already satisfied: torchvision in /usr/local/lib/python3.6/dist-packages (0.3.0) Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torch) (1.16.4) Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from torchvision) (1.12.0) Requirement already satisfied: pillow>=4.1.1 in /usr/local/lib/python3.6/dist-packages (from torchvision) (4.3.0) Requirement already satisfied: olefile in /usr/local/lib/python3.6/dist-packages (from pillow>=4.1.1->torchvision) (0.46)
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.init as init
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# Training hyperparameters shared by the rest of the script.
batch_size = 256       # number of samples per mini-batch
learning_rate = 2e-4   # optimizer step size
num_epoch = 10         # number of passes over the training set
# MNIST train / test splits, stored under the current directory and
# downloaded on first use. ToTensor() turns each image into a tensor.
mnist_train = dset.MNIST(
    "./",
    train=True,
    transform=transforms.ToTensor(),
    target_transform=None,
    download=True,
)
mnist_test = dset.MNIST(
    "./",
    train=False,
    transform=transforms.ToTensor(),
    target_transform=None,
    download=True,
)

# Sanity check: shape of the first image and number of samples per split.
print(mnist_train[0][0].size(), len(mnist_train))
# Bare expression on purpose: a notebook displays it as the cell's result.
mnist_test[0][0].size(), len(mnist_test)
torch.Size([1, 28, 28]) 60000
(torch.Size([1, 28, 28]), 10000)
# Mini-batch iterators over the two splits. drop_last=True discards the final
# incomplete batch, so every batch holds exactly batch_size samples.
train_loader = DataLoader(
    mnist_train,
    batch_size=batch_size,
    shuffle=True,       # reshuffle the training data every epoch
    num_workers=2,
    drop_last=True,
)
test_loader = DataLoader(
    mnist_test,
    batch_size=batch_size,
    shuffle=False,      # keep evaluation order fixed
    num_workers=2,
    drop_last=True,
)
class CNN(nn.Module):
    """Small convolutional classifier for single-channel 28 x 28 images.

    Three 3x3 convolutions (with two 2x2 max-pools) produce a 64 x 7 x 7
    feature map, which two fully connected layers map to 10 class scores.
    Conv2d and Linear weights use Kaiming (He) normal initialization and
    their biases are set to zero.
    """

    def __init__(self):
        super(CNN, self).__init__()

        # Feature extractor. Trailing comments give the spatial size of the
        # feature map after each stage for a 28 x 28 input.
        self.layer = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1),   # 28 x 28
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, padding=1),  # 28 x 28
            nn.ReLU(),
            nn.MaxPool2d(2, 2),               # 14 x 14
            nn.Conv2d(32, 64, 3, padding=1),  # 14 x 14
            nn.ReLU(),
            nn.MaxPool2d(2, 2),               # 7 x 7
        )

        # Classifier head: flattened 64*7*7 features -> 100 -> 10 scores.
        self.fc_layer = nn.Sequential(
            nn.Linear(64 * 7 * 7, 100),
            nn.ReLU(),
            nn.Linear(100, 10),
        )

        # Weight initialization: visit every sub-module in order and
        # re-initialize the ones that carry weights.
        #
        # Alternatives to the Kaiming scheme used below:
        #   small random values (mean 0, std 0.02):
        #       m.weight.data.normal_(0.0, 0.02)
        #   Xavier normal:
        #       init.xavier_normal_(m.weight.data)
        # In every variant the bias is filled with zeros.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                # Kaiming (He) normal initialization for the weights.
                init.kaiming_normal_(m.weight.data)
                # Biases start at zero.
                m.bias.data.fill_(0)

    def forward(self, x):
        """Return class scores of shape (N, 10) for a batch ``x`` of N images."""
        out = self.layer(x)
        # Flatten per sample using the actual batch dimension rather than a
        # global batch size, so partial or differently sized batches work.
        out = out.view(out.size(0), -1)
        out = self.fc_layer(out)
        return out
# Run on the first GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

# Network, loss and optimizer used for training.
model = CNN()
model = model.to(device)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
cuda:0
# Train the model for num_epoch passes over train_loader.
model.train()
for i in range(num_epoch):
    for image, label in train_loader:
        # Move the mini-batch to the same device as the model.
        x = image.to(device)
        y_ = label.to(device)

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()
        # Call the module itself rather than .forward() so that any
        # registered hooks are run.
        output = model(x)
        loss = loss_func(output, y_)
        loss.backward()
        optimizer.step()

    # Every 10th epoch, print the loss of the last mini-batch of that epoch.
    if i % 10 == 0:
        print(loss)
tensor(1.8737, device='cuda:0', grad_fn=<NllLossBackward>)
#param_list = list(model.parameters())
#print(param_list)
# Measure classification accuracy on the batches yielded by test_loader.
correct = 0
total = 0

# Put the model in evaluation mode for inference.
model.eval()

# No gradients are needed for evaluation.
with torch.no_grad():
    for image, label in test_loader:
        x = image.to(device)
        y_ = label.to(device)

        # Call the module itself rather than .forward() so that any
        # registered hooks are run.
        output = model(x)
        # Index of the highest score along dim 1 is the predicted class.
        _, output_index = torch.max(output, 1)

        total += y_.size(0)
        correct += (output_index == y_).sum().float()

print("Accuracy of Test Data: {}".format(100*correct/total))
Accuracy of Test Data: 86.56851196289062