!pip install torch torchvision
Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (1.1.0)
Requirement already satisfied: torchvision in /usr/local/lib/python3.6/dist-packages (0.3.0)
Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torch) (1.16.4)
Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from torchvision) (1.12.0)
Requirement already satisfied: pillow>=4.1.1 in /usr/local/lib/python3.6/dist-packages (from torchvision) (4.3.0)
Requirement already satisfied: olefile in /usr/local/lib/python3.6/dist-packages (from pillow>=4.1.1->torchvision) (0.46)
# Download the images we need.
!rm -r images
import os
try:
    os.mkdir("images")
    os.mkdir("images/content")
    os.mkdir("images/style")
except:
    pass
!wget https://upload.wikimedia.org/wikipedia/commons/0/00/Tuebingen_Neckarfront.jpg -P images/content
!wget https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg/1280px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg -P images/style
--2019-06-03 12:33:20--  https://upload.wikimedia.org/wikipedia/commons/0/00/Tuebingen_Neckarfront.jpg
Resolving upload.wikimedia.org (upload.wikimedia.org)... 103.102.166.240, 2001:df2:e500:ed1a::2:b
Connecting to upload.wikimedia.org (upload.wikimedia.org)|103.102.166.240|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 406531 (397K) [image/jpeg]
Saving to: 'images/content/Tuebingen_Neckarfront.jpg'

Tuebingen_Neckarfro 100%[===================>] 397.00K  --.-KB/s    in 0.01s

2019-06-03 12:33:20 (29.3 MB/s) - 'images/content/Tuebingen_Neckarfront.jpg' saved [406531/406531]

--2019-06-03 12:33:24--  https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg/1280px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg
Resolving upload.wikimedia.org (upload.wikimedia.org)... 103.102.166.240, 2001:df2:e500:ed1a::2:b
Connecting to upload.wikimedia.org (upload.wikimedia.org)|103.102.166.240|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 613563 (599K) [image/jpeg]
Saving to: 'images/style/1280px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg'

1280px-Van_Gogh_-_S 100%[===================>] 599.18K  --.-KB/s    in 0.02s

2019-06-03 12:33:25 (29.3 MB/s) - 'images/style/1280px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg' saved [613563/613563]
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils as utils
import torch.utils.data as data
import torchvision.models as models
import torchvision.utils as v_utils
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image
%matplotlib inline
# Specify at which layer the content loss will be matched.
content_layer_num = 1
image_size = 512
epoch = 5000
content_dir = "./images/content/Tuebingen_Neckarfront.jpg"
style_dir = "./images/style/1280px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg"
# Because the pretrained ResNet model was trained on ImageNet, normalize the input accordingly.
def image_preprocess(img_dir):
    img = Image.open(img_dir)
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.40760392, 0.45795686, 0.48501961],
                             std=[1,1,1]),
    ])
    img = transform(img).view((-1,3,image_size,image_size))
    return img
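# (Optional) quick shape check -- a minimal sketch, assuming the downloads above succeeded:
# the preprocessed tensor should be a single 3-channel image of size image_size x image_size.
print(image_preprocess(content_dir).shape)  # expected: torch.Size([1, 3, 512, 512])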
# The computation runs on normalized tensors, so to view the result as an image we add back the mean values that were subtracted.
# We also clamp the image so its values lie between 0 and 1.
def image_postprocess(tensor):
    transform = transforms.Normalize(mean=[-0.40760392, -0.45795686, -0.48501961],
                                     std=[1,1,1])
    img = transform(tensor.clone())
    img = img.clamp(0,1)
    img = torch.transpose(img,0,1)
    img = torch.transpose(img,1,2)
    return img
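# (Optional) round-trip check, illustrative only: postprocessing converts a normalized
# [3, H, W] tensor back into an [H, W, 3] image in the 0-1 range, which is what plt.imshow expects.
print(image_postprocess(image_preprocess(content_dir)[0]).shape)  # expected: torch.Size([512, 512, 3])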
# Use a pretrained resnet50.
resnet = models.resnet50(pretrained=True)
for name,module in resnet.named_children():
    print(name)
conv1
bn1
relu
maxpool
layer1
layer2
layer3
layer4
avgpool
fc
# Define forward so that the output of each layer can be retrieved.
class Resnet(nn.Module):
    def __init__(self):
        super(Resnet,self).__init__()
        # Slices of resnet.children(): [0:1] is conv1, [1:4] is bn1/relu/maxpool,
        # and [4:8] are the four residual stages; avgpool and fc are not used.
        self.layer0 = nn.Sequential(*list(resnet.children())[0:1])
        self.layer1 = nn.Sequential(*list(resnet.children())[1:4])
        self.layer2 = nn.Sequential(*list(resnet.children())[4:5])
        self.layer3 = nn.Sequential(*list(resnet.children())[5:6])
        self.layer4 = nn.Sequential(*list(resnet.children())[6:7])
        self.layer5 = nn.Sequential(*list(resnet.children())[7:8])

    def forward(self,x):
        out_0 = self.layer0(x)
        out_1 = self.layer1(out_0)
        out_2 = self.layer2(out_1)
        out_3 = self.layer3(out_2)
        out_4 = self.layer4(out_3)
        out_5 = self.layer5(out_4)
        return out_0, out_1, out_2, out_3, out_4, out_5
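# (Optional) a minimal sketch to see which feature maps each slice returns.
# With a 512 x 512 input the channel counts come out as 64, 64, 256, 512, 1024 and 2048,
# which is what the style weights defined further below are based on.
with torch.no_grad():
    for feat in Resnet()(torch.zeros(1, 3, image_size, image_size)):
        print(feat.shape)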
# Define a class that computes the Gram matrix.
# [batch,channel,height,width] -> [b,c,h*w]
# [b,c,h*w] x [b,h*w,c] = [b,c,c]
class GramMatrix(nn.Module):
    def forward(self, input):
        b,c,h,w = input.size()
        F = input.view(b, c, h*w)
        G = torch.bmm(F, F.transpose(1,2))
        return G
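# Illustrative only: for a feature map with c channels the Gram matrix is c x c,
# e.g. a random [1, 64, 128, 128] input yields a [1, 64, 64] result.
print(GramMatrix()(torch.randn(1, 64, 128, 128)).shape)  # torch.Size([1, 64, 64])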
# Since the model itself is not being trained, set requires_grad to False for its parameters.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
resnet = Resnet().to(device)
for param in resnet.parameters():
    param.requires_grad = False
cuda:0
# Define a loss that measures the MSE between Gram matrices.
class GramMSELoss(nn.Module):
    def forward(self, input, target):
        out = nn.MSELoss()(GramMatrix()(input), target)
        return out
# Define the content image, the style image, and the image that will actually be optimized.
content = image_preprocess(content_dir).to(device)
style = image_preprocess(style_dir).to(device)
generated = content.clone().requires_grad_().to(device)
print(content.requires_grad,style.requires_grad,generated.requires_grad)
# Visualize each of them.
plt.imshow(image_postprocess(content[0].cpu()))
plt.show()
plt.imshow(image_postprocess(style[0].cpu()))
plt.show()
gen_img = image_postprocess(generated[0].cpu()).data.numpy()
plt.imshow(gen_img)
plt.show()
False False True
# Set the targets and also define a weight for each layer according to the size of its Gram matrix.
style_target = list(GramMatrix().to(device)(i) for i in resnet(style))
content_target = resnet(content)[content_layer_num]
style_weight = [1/n**2 for n in [64,64,256,512,1024,2048]]
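# Each Gram matrix has n x n entries, so weighting by 1/n**2 roughly puts the
# style losses from the different layers on a comparable scale.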
# Use the LBFGS optimizer.
# Note that the target of optimization is the image itself, not the model weights.
# LBFGS may re-evaluate the loss several times per step, which is why the loss and
# gradients are computed inside a closure passed to optimizer.step.
# for more info about LBFGS -> http://pytorch.org/docs/optim.html?highlight=lbfgs#torch.optim.LBFGS
optimizer = optim.LBFGS([generated])
iteration = [0]
while iteration[0] < epoch:
    def closure():
        optimizer.zero_grad()
        out = resnet(generated)

        # Compute the style loss against each target and store the results in a list.
        style_loss = [GramMSELoss().to(device)(out[i],style_target[i])*style_weight[i] for i in range(len(style_target))]

        # The content loss is computed only at the chosen layer, so it is a single value.
        content_loss = nn.MSELoss().to(device)(out[content_layer_num],content_target)

        # Compute the total loss with a style:content ratio of 1000:1.
        total_loss = 1000 * sum(style_loss) + torch.sum(content_loss)
        total_loss.backward()

        if iteration[0] % 100 == 0:
            print(total_loss)
        iteration[0] += 1
        return total_loss

    optimizer.step(closure)
tensor(1507036.8750, device='cuda:0', grad_fn=<AddBackward0>)
tensor(43.4819, device='cuda:0', grad_fn=<AddBackward0>)
tensor(7.3298, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2.6613, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.4060, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.9018, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.6482, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.4989, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.4054, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.3458, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.3017, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.2678, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.2421, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.2226, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.2070, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1936, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1823, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1735, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1657, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1588, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1530, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1480, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1437, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1396, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1361, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1329, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1300, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1273, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1250, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1227, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1206, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1187, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1169, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1153, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1138, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1123, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1110, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1098, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1086, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1075, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1064, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1054, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1045, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1035, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1026, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1018, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1010, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1003, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.0995, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.0988, device='cuda:0', grad_fn=<AddBackward0>)
# Check the resulting generated image.
gen_img = image_postprocess(generated[0].cpu()).data.numpy()
plt.figure(figsize=(10,10))
plt.imshow(gen_img)
plt.show()
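# (Optional) save the result to disk -- a sketch; the filename is arbitrary.
# gen_img is an [H, W, 3] float array in the 0-1 range, which plt.imsave accepts directly.
plt.imsave("style_transfer_result.png", gen_img)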