In [1]:
!pip install torch torchvision
Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (1.1.0)
Requirement already satisfied: torchvision in /usr/local/lib/python3.6/dist-packages (0.3.0)
Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torch) (1.16.4)
Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from torchvision) (1.12.0)
Requirement already satisfied: pillow>=4.1.1 in /usr/local/lib/python3.6/dist-packages (from torchvision) (4.3.0)
Requirement already satisfied: olefile in /usr/local/lib/python3.6/dist-packages (from pillow>=4.1.1->torchvision) (0.46)
In [2]:
# ÇÊ¿äÇÑ À̹ÌÁöµéÀ» ´Ù¿î¹Þ½À´Ï´Ù.

!rm -r images
import os 

try:
  os.mkdir("images")
  os.mkdir("images/content")
  os.mkdir("images/style")
except:
  pass

!wget https://upload.wikimedia.org/wikipedia/commons/0/00/Tuebingen_Neckarfront.jpg -P images/content
!wget https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg/1280px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg -P images/style
--2019-06-03 12:33:20--  https://upload.wikimedia.org/wikipedia/commons/0/00/Tuebingen_Neckarfront.jpg
Resolving upload.wikimedia.org (upload.wikimedia.org)... 103.102.166.240, 2001:df2:e500:ed1a::2:b
Connecting to upload.wikimedia.org (upload.wikimedia.org)|103.102.166.240|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 406531 (397K) [image/jpeg]
Saving to: ¡®images/content/Tuebingen_Neckarfront.jpg¡¯

Tuebingen_Neckarfro 100%[===================>] 397.00K  --.-KB/s    in 0.01s   

2019-06-03 12:33:20 (29.3 MB/s) - ¡®images/content/Tuebingen_Neckarfront.jpg¡¯ saved [406531/406531]

--2019-06-03 12:33:24--  https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg/1280px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg
Resolving upload.wikimedia.org (upload.wikimedia.org)... 103.102.166.240, 2001:df2:e500:ed1a::2:b
Connecting to upload.wikimedia.org (upload.wikimedia.org)|103.102.166.240|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 613563 (599K) [image/jpeg]
Saving to: ¡®images/style/1280px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg¡¯

1280px-Van_Gogh_-_S 100%[===================>] 599.18K  --.-KB/s    in 0.02s   

2019-06-03 12:33:25 (29.3 MB/s) - ¡®images/style/1280px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg¡¯ saved [613563/613563]

1. Settings

1) Import required libraries

In [0]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils as utils
import torch.utils.data as data
import torchvision.models as models
import torchvision.utils as v_utils
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image
%matplotlib inline

2) Hyperparameter

In [0]:
# ÄÁÅÙÃ÷ ¼Õ½ÇÀ» ¾î´À ÁöÁ¡¿¡¼­ ¸ÂÃâ°ÍÀÎÁö ÁöÁ¤Çسõ½À´Ï´Ù.
content_layer_num = 1
image_size = 512
epoch = 5000

2. Data

1) Directory

In [0]:
content_dir = "./images/content/Tuebingen_Neckarfront.jpg"
style_dir = "./images/style/1280px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg"

2) Prepocessing Function

  • Àüó¸® ÇÔ¼ö
In [0]:
# ÀÌ¹Ì ÇнÀµÈ ResNet ¸ðµ¨ÀÌ À̹ÌÁö³ÝÀ¸·Î ÇнÀµÈ ¸ðµ¨À̱⠶§¹®¿¡ ÀÌ¿¡ µû¶ó Á¤±ÔÈ­ÇØÁÝ´Ï´Ù.

def image_preprocess(img_dir):
    img = Image.open(img_dir)
    transform = transforms.Compose([
                    transforms.Resize(image_size),
                    transforms.CenterCrop(image_size),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.40760392, 0.45795686, 0.48501961], 
                                         std=[1,1,1]),
                ])
    img = transform(img).view((-1,3,image_size,image_size))
    return img

3) Post processing Function

  • ÈÄó¸® ÇÔ¼ö
In [0]:
# Á¤±ÔÈ­ µÈ »óÅ·Π¿¬»êÀ» ÁøÇàÇÏ°í ´Ù½Ã À̹ÌÁöÈ­ Çؼ­ º¸±âÀ§ÇØ »°´ø °ªµéÀ» ´Ù½Ã ´õÇØÁÝ´Ï´Ù.
# ¶ÇÇÑ À̹ÌÁö°¡ 0¿¡¼­ 1»çÀÌÀÇ °ªÀ» °¡Áö°Ô ÇØÁÝ´Ï´Ù.

def image_postprocess(tensor):
    transform = transforms.Normalize(mean=[-0.40760392, -0.45795686, -0.48501961], 
                                     std=[1,1,1])
    img = transform(tensor.clone())
    img = img.clamp(0,1)
    img = torch.transpose(img,0,1)
    img = torch.transpose(img,1,2)
    return img

3. Model & Loss Function

1) Resnet

In [8]:
# ¹Ì¸® ÇнÀµÈ resnet50¸¦ »ç¿ëÇÕ´Ï´Ù.
resnet = models.resnet50(pretrained=True)
for name,module in resnet.named_children():
    print(name)
conv1
bn1
relu
maxpool
layer1
layer2
layer3
layer4
avgpool
fc

2) Delete Fully Connected Layer

In [0]:
# ·¹À̾´Ù °á°ú°ªÀ» °¡Á®¿Ã ¼ö ÀÖ°Ô forward¸¦ Á¤ÀÇÇÕ´Ï´Ù.

class Resnet(nn.Module):
    def __init__(self):
        super(Resnet,self).__init__()
        self.layer0 = nn.Sequential(*list(resnet.children())[0:1])
        self.layer1 = nn.Sequential(*list(resnet.children())[1:4])
        self.layer2 = nn.Sequential(*list(resnet.children())[4:5])
        self.layer3 = nn.Sequential(*list(resnet.children())[5:6])
        self.layer4 = nn.Sequential(*list(resnet.children())[6:7])
        self.layer5 = nn.Sequential(*list(resnet.children())[7:8])

    def forward(self,x):
        out_0 = self.layer0(x)
        out_1 = self.layer1(out_0)
        out_2 = self.layer2(out_1)
        out_3 = self.layer3(out_2)
        out_4 = self.layer4(out_3)
        out_5 = self.layer5(out_4)
        return out_0, out_1, out_2, out_3, out_4, out_5

3) Gram Matrix Function

In [0]:
# ±×¶÷ Çà·ÄÀ» »ý¼ºÇϴ Ŭ·¡½º ¹× ÇÔ¼ö¸¦ Á¤ÀÇÇÕ´Ï´Ù. 
# [batch,channel,height,width] -> [b,c,h*w]
# [b,c,h*w] x [b,h*w,c] = [b,c,c]

class GramMatrix(nn.Module):
    def forward(self, input):
        b,c,h,w = input.size()
        F = input.view(b, c, h*w)
        G = torch.bmm(F, F.transpose(1,2)) 
        return G

4) Model on GPU

In [11]:
# ¸ðµ¨À» ÇнÀÀÇ ´ë»óÀÌ ¾Æ´Ï±â ¶§¹®¿¡ requires_grad¸¦ False·Î ¼³Á¤ÇÕ´Ï´Ù.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

resnet = Resnet().to(device)
for param in resnet.parameters():
    param.requires_grad = False
cuda:0

5) Gram Matrix Loss

In [0]:
# ±×¶÷Çà·Ä°£ÀÇ ¼Õ½ÇÀ» °è»êÇϴ Ŭ·¡½º ¹× ÇÔ¼ö¸¦ Á¤ÀÇÇÕ´Ï´Ù.

class GramMSELoss(nn.Module):
    def forward(self, input, target):
        out = nn.MSELoss()(GramMatrix()(input), target)
        return out

4. Train

1) Prepare Images

In [13]:
# ÄÁÅÙÃ÷ À̹ÌÁö, ½ºÅ¸ÀÏ À̹ÌÁö, ÇнÀÀÇ ´ë»óÀÌ µÇ´Â À̹ÌÁö¸¦ Á¤ÀÇÇÕ´Ï´Ù.

content = image_preprocess(content_dir).to(device)
style = image_preprocess(style_dir).to(device)
generated = content.clone().requires_grad_().to(device)

print(content.requires_grad,style.requires_grad,generated.requires_grad)

# °¢°¢À» ½Ã°¢È­ ÇÕ´Ï´Ù.

plt.imshow(image_postprocess(content[0].cpu()))
plt.show()

plt.imshow(image_postprocess(style[0].cpu()))
plt.show()

gen_img = image_postprocess(generated[0].cpu()).data.numpy()
plt.imshow(gen_img)
plt.show()
False False True

2) Set Targets & Style Weights

In [0]:
# ¸ñÇ¥°ªÀ» ¼³Á¤ÇÏ°í Çà·ÄÀÇ Å©±â¿¡ µû¸¥ °¡ÁßÄ¡µµ ÇÔ²² Á¤ÀÇÇسõ½À´Ï´Ù

style_target = list(GramMatrix().to(device)(i) for i in resnet(style))
content_target = resnet(content)[content_layer_num]
style_weight = [1/n**2 for n in [64,64,256,512,1024,2048]]

3) Train

In [15]:
# LBFGS ÃÖÀûÈ­ ÇÔ¼ö¸¦ »ç¿ëÇÕ´Ï´Ù.
# À̶§ ÇнÀÀÇ ´ë»óÀº ¸ðµ¨ÀÇ °¡ÁßÄ¡°¡ ¾Æ´Ñ À̹ÌÁö ÀÚüÀÔ´Ï´Ù.
# for more info about LBFGS -> http://pytorch.org/docs/optim.html?highlight=lbfgs#torch.optim.LBFGS

optimizer = optim.LBFGS([generated])

iteration = [0]
while iteration[0] < epoch:
    def closure():
        optimizer.zero_grad()
        out = resnet(generated)
        
        # ½ºÅ¸ÀÏ ¼Õ½ÇÀ» °¢°¢ÀÇ ¸ñÇ¥°ª¿¡ µû¶ó °è»êÇÏ°í À̸¦ ¸®½ºÆ®·Î ÀúÀåÇÕ´Ï´Ù.
        style_loss = [GramMSELoss().to(device)(out[i],style_target[i])*style_weight[i] for i in range(len(style_target))]
        
        # ÄÁÅÙÃ÷ ¼Õ½ÇÀº ÁöÁ¤ÇÑ À§Ä¡¿¡¼­¸¸ °è»êµÇ¹Ç·Î ÇϳªÀÇ ¼öÄ¡·Î ÀúÀåµË´Ï´Ù.
        content_loss = nn.MSELoss().to(device)(out[content_layer_num],content_target)
        
        # ½ºÅ¸ÀÏ:ÄÁÅÙÃ÷ = 1000:1ÀÇ ºñÁßÀ¸·Î ÃÑ ¼Õ½ÇÀ» °è»êÇÕ´Ï´Ù.
        total_loss = 1000 * sum(style_loss) + torch.sum(content_loss)
        total_loss.backward()

        if iteration[0] % 100 == 0:
            print(total_loss)
        iteration[0] += 1
        return total_loss

    optimizer.step(closure)
tensor(1507036.8750, device='cuda:0', grad_fn=<AddBackward0>)
tensor(43.4819, device='cuda:0', grad_fn=<AddBackward0>)
tensor(7.3298, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2.6613, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.4060, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.9018, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.6482, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.4989, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.4054, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.3458, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.3017, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.2678, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.2421, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.2226, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.2070, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1936, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1823, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1735, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1657, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1588, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1530, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1480, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1437, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1396, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1361, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1329, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1300, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1273, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1250, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1227, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1206, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1187, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1169, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1153, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1138, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1123, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1110, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1098, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1086, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1075, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1064, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1054, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1045, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1035, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1026, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1018, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1010, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1003, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.0995, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.0988, device='cuda:0', grad_fn=<AddBackward0>)

5. Check Results

In [16]:
# ÇнÀµÈ °á°ú À̹ÌÁö¸¦ È®ÀÎÇÕ´Ï´Ù.

gen_img = image_postprocess(generated[0].cpu()).data.numpy()

plt.figure(figsize=(10,10))
plt.imshow(gen_img)
plt.show()
In [0]: