In [1]:
!pip install torch torchvision
Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (1.1.0)
Requirement already satisfied: torchvision in /usr/local/lib/python3.6/dist-packages (0.3.0)
Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torch) (1.16.4)
Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from torchvision) (1.12.0)
Requirement already satisfied: pillow>=4.1.1 in /usr/local/lib/python3.6/dist-packages (from torchvision) (4.3.0)
Requirement already satisfied: olefile in /usr/local/lib/python3.6/dist-packages (from pillow>=4.1.1->torchvision) (0.46)
In [2]:
# 필요한 이미지들을 다운받습니다.

!rm -r images
import os 

try:
  os.mkdir("images")
  os.mkdir("images/content")
  os.mkdir("images/style")
except:
  pass

!wget https://upload.wikimedia.org/wikipedia/commons/0/00/Tuebingen_Neckarfront.jpg -P images/content
!wget https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg/1280px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg -P images/style
--2019-06-03 12:33:20--  https://upload.wikimedia.org/wikipedia/commons/0/00/Tuebingen_Neckarfront.jpg
Resolving upload.wikimedia.org (upload.wikimedia.org)... 103.102.166.240, 2001:df2:e500:ed1a::2:b
Connecting to upload.wikimedia.org (upload.wikimedia.org)|103.102.166.240|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 406531 (397K) [image/jpeg]
Saving to: ‘images/content/Tuebingen_Neckarfront.jpg’

Tuebingen_Neckarfro 100%[===================>] 397.00K  --.-KB/s    in 0.01s   

2019-06-03 12:33:20 (29.3 MB/s) - ‘images/content/Tuebingen_Neckarfront.jpg’ saved [406531/406531]

--2019-06-03 12:33:24--  https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg/1280px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg
Resolving upload.wikimedia.org (upload.wikimedia.org)... 103.102.166.240, 2001:df2:e500:ed1a::2:b
Connecting to upload.wikimedia.org (upload.wikimedia.org)|103.102.166.240|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 613563 (599K) [image/jpeg]
Saving to: ‘images/style/1280px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg’

1280px-Van_Gogh_-_S 100%[===================>] 599.18K  --.-KB/s    in 0.02s   

2019-06-03 12:33:25 (29.3 MB/s) - ‘images/style/1280px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg’ saved [613563/613563]

1. Settings

1) Import required libraries

In [0]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils as utils
import torch.utils.data as data
import torchvision.models as models
import torchvision.utils as v_utils
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image
%matplotlib inline

2) Hyperparameter

In [0]:
# 컨텐츠 손실을 어느 지점에서 맞출것인지 지정해놓습니다.
content_layer_num = 1
image_size = 512
epoch = 5000

2. Data

1) Directory

In [0]:
content_dir = "./images/content/Tuebingen_Neckarfront.jpg"
style_dir = "./images/style/1280px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg"

2) Prepocessing Function

  • 전처리 함수
In [0]:
# 이미 학습된 ResNet 모델이 이미지넷으로 학습된 모델이기 때문에 이에 따라 정규화해줍니다.

def image_preprocess(img_dir):
    img = Image.open(img_dir)
    transform = transforms.Compose([
                    transforms.Resize(image_size),
                    transforms.CenterCrop(image_size),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.40760392, 0.45795686, 0.48501961], 
                                         std=[1,1,1]),
                ])
    img = transform(img).view((-1,3,image_size,image_size))
    return img

3) Post processing Function

  • 후처리 함수
In [0]:
# 정규화 된 상태로 연산을 진행하고 다시 이미지화 해서 보기위해 뺐던 값들을 다시 더해줍니다.
# 또한 이미지가 0에서 1사이의 값을 가지게 해줍니다.

def image_postprocess(tensor):
    transform = transforms.Normalize(mean=[-0.40760392, -0.45795686, -0.48501961], 
                                     std=[1,1,1])
    img = transform(tensor.clone())
    img = img.clamp(0,1)
    img = torch.transpose(img,0,1)
    img = torch.transpose(img,1,2)
    return img

3. Model & Loss Function

1) Resnet

In [8]:
# 미리 학습된 resnet50를 사용합니다.
resnet = models.resnet50(pretrained=True)
for name,module in resnet.named_children():
    print(name)
conv1
bn1
relu
maxpool
layer1
layer2
layer3
layer4
avgpool
fc

2) Delete Fully Connected Layer

In [0]:
# 레이어마다 결과값을 가져올 수 있게 forward를 정의합니다.

class Resnet(nn.Module):
    def __init__(self):
        super(Resnet,self).__init__()
        self.layer0 = nn.Sequential(*list(resnet.children())[0:1])
        self.layer1 = nn.Sequential(*list(resnet.children())[1:4])
        self.layer2 = nn.Sequential(*list(resnet.children())[4:5])
        self.layer3 = nn.Sequential(*list(resnet.children())[5:6])
        self.layer4 = nn.Sequential(*list(resnet.children())[6:7])
        self.layer5 = nn.Sequential(*list(resnet.children())[7:8])

    def forward(self,x):
        out_0 = self.layer0(x)
        out_1 = self.layer1(out_0)
        out_2 = self.layer2(out_1)
        out_3 = self.layer3(out_2)
        out_4 = self.layer4(out_3)
        out_5 = self.layer5(out_4)
        return out_0, out_1, out_2, out_3, out_4, out_5

3) Gram Matrix Function

In [0]:
# 그람 행렬을 생성하는 클래스 및 함수를 정의합니다. 
# [batch,channel,height,width] -> [b,c,h*w]
# [b,c,h*w] x [b,h*w,c] = [b,c,c]

class GramMatrix(nn.Module):
    def forward(self, input):
        b,c,h,w = input.size()
        F = input.view(b, c, h*w)
        G = torch.bmm(F, F.transpose(1,2)) 
        return G

4) Model on GPU

In [11]:
# 모델을 학습의 대상이 아니기 때문에 requires_grad를 False로 설정합니다.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

resnet = Resnet().to(device)
for param in resnet.parameters():
    param.requires_grad = False
cuda:0

5) Gram Matrix Loss

In [0]:
# 그람행렬간의 손실을 계산하는 클래스 및 함수를 정의합니다.

class GramMSELoss(nn.Module):
    def forward(self, input, target):
        out = nn.MSELoss()(GramMatrix()(input), target)
        return out

4. Train

1) Prepare Images

In [13]:
# 컨텐츠 이미지, 스타일 이미지, 학습의 대상이 되는 이미지를 정의합니다.

content = image_preprocess(content_dir).to(device)
style = image_preprocess(style_dir).to(device)
generated = content.clone().requires_grad_().to(device)

print(content.requires_grad,style.requires_grad,generated.requires_grad)

# 각각을 시각화 합니다.

plt.imshow(image_postprocess(content[0].cpu()))
plt.show()

plt.imshow(image_postprocess(style[0].cpu()))
plt.show()

gen_img = image_postprocess(generated[0].cpu()).data.numpy()
plt.imshow(gen_img)
plt.show()
False False True

2) Set Targets & Style Weights

In [0]:
# 목표값을 설정하고 행렬의 크기에 따른 가중치도 함께 정의해놓습니다

style_target = list(GramMatrix().to(device)(i) for i in resnet(style))
content_target = resnet(content)[content_layer_num]
style_weight = [1/n**2 for n in [64,64,256,512,1024,2048]]

3) Train

In [15]:
# LBFGS 최적화 함수를 사용합니다.
# 이때 학습의 대상은 모델의 가중치가 아닌 이미지 자체입니다.
# for more info about LBFGS -> http://pytorch.org/docs/optim.html?highlight=lbfgs#torch.optim.LBFGS

optimizer = optim.LBFGS([generated])

iteration = [0]
while iteration[0] < epoch:
    def closure():
        optimizer.zero_grad()
        out = resnet(generated)
        
        # 스타일 손실을 각각의 목표값에 따라 계산하고 이를 리스트로 저장합니다.
        style_loss = [GramMSELoss().to(device)(out[i],style_target[i])*style_weight[i] for i in range(len(style_target))]
        
        # 컨텐츠 손실은 지정한 위치에서만 계산되므로 하나의 수치로 저장됩니다.
        content_loss = nn.MSELoss().to(device)(out[content_layer_num],content_target)
        
        # 스타일:컨텐츠 = 1000:1의 비중으로 총 손실을 계산합니다.
        total_loss = 1000 * sum(style_loss) + torch.sum(content_loss)
        total_loss.backward()

        if iteration[0] % 100 == 0:
            print(total_loss)
        iteration[0] += 1
        return total_loss

    optimizer.step(closure)
tensor(1507036.8750, device='cuda:0', grad_fn=<AddBackward0>)
tensor(43.4819, device='cuda:0', grad_fn=<AddBackward0>)
tensor(7.3298, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2.6613, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.4060, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.9018, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.6482, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.4989, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.4054, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.3458, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.3017, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.2678, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.2421, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.2226, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.2070, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1936, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1823, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1735, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1657, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1588, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1530, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1480, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1437, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1396, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1361, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1329, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1300, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1273, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1250, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1227, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1206, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1187, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1169, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1153, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1138, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1123, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1110, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1098, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1086, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1075, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1064, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1054, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1045, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1035, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1026, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1018, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1010, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1003, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.0995, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.0988, device='cuda:0', grad_fn=<AddBackward0>)

5. Check Results

In [16]:
# 학습된 결과 이미지를 확인합니다.

gen_img = image_postprocess(generated[0].cpu()).data.numpy()

plt.figure(figsize=(10,10))
plt.imshow(gen_img)
plt.show()