PyTorch Style Transfer Sample Code

July 26th, 2023

PyTorch Style Transfer Sample Code

PyTorch Style Transfer Sample Code

code

Pytorch Example of Style Transfer

Here's an example of a Python implementation for style transfer in generative AI art using the PyTorch framework:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from PIL import Image
import matplotlib.pyplot as plt

# Load content and style images
content_image = Image.open('content.jpg')
style_image = Image.open('style.jpg')

# Define image transformations
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Preprocess images
content_tensor = transform(content_image).unsqueeze(0)
style_tensor = transform(style_image).unsqueeze(0)

# The VGG19 model with adjusted layers
class VGG19(nn.Module):
    def __init__(self):
        super().__init__()
        vgg = models.vgg19(pretrained=True).features
        self.slice1 = vgg[:2]
        self.slice2 = vgg[2:9]
        self.slice3 = vgg[9:16]
        self.slice4 = vgg[16:23]
        self.slice5 = vgg[23:30]

    def forward(self, x):
        x1 = self.slice1(x)
        x2 = self.slice2(x1)
        x3 = self.slice3(x2)
        x4 = self.slice4(x3)
        x5 = self.slice5(x4)
        return [x1, x2, x3, x4, x5]

# Loss Functions
def gram_matrix(x):
    _, c, h, w = x.size()
    features = x.view(c, h * w)
    G = torch.mm(features, features.t())
    return G
    
def content_loss(gen_features, content_features):
    return torch.mean((gen_features - content_features)**2)

def style_loss(gen_gram, style_gram):
    return torch.mean((gen_gram - style_gram)**2)

# Training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
content_tensor = content_tensor.to(device)
style_tensor = style_tensor.to(device)

vgg = VGG19().to(device)
for param in vgg.parameters():
    param.requires_grad_(False)
    
generated_img = content_tensor.clone().requires_grad_(True)

optimizer = optim.Adam([generated_img], lr=0.01)

n_epochs = 5000

for epoch in range(n_epochs):
    gen_features = vgg(generated_img)
    content_features = vgg(content_tensor)
    style_features = vgg(style_tensor)
    
    content_l = content_loss(gen_features[3], content_features[3])
    style_l = 0
    
    for gf, sf in zip(gen_features, style_features):
        gm_gen = gram_matrix(gf)
        gm_style = gram_matrix(sf)
        
        style_l += style_loss(gm_gen, gm_style)
        
    total_loss = content_l + 1e5 * style_l
    
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if epoch % 100 == 0:
        print('Epoch: [%d/%d], Loss: %.4f' % (epoch, n_epochs, total_loss))

# Postprocess the generated image
generated_np = generated_img.squeeze(0).cpu().detach().numpy().transpose(1, 2, 0)
generated_np = (generated_np * (0.229, 0.224, 0.225) + (0.485, 0.456, 0.406)) * 255

# Show the generated image
plt.imshow(generated_np.clip(0, 255).astype('uint8'))
plt.show()

creata ai pytorch sample code for style transfer

Conclusion

In this example, we use the VGG19 model for the style transfer. The input images are styled by minimizing the content loss between the generated image and the content image, and the style loss between the generated image and the style image. The generated image is then displayed.

the complete code is available on github:

style transfer with pytorch

Other articles

September 23rd, 2023

Unveiling Voice Cloning: The Process and AI Involvement

ificial Intelligence, and its diverse applications and implications. read more...

September 5th, 2023

This Lora makes REAL ART – XL_more_art-full for SDXL

rained on all kinds of … read more...

September 26th, 2023

Guide to Using GPT-4 API for Task Management: Creating Weekly Schedules

through creating a weekly schedule using AI. read more...