diff --git a/C2W3_FreezeD_(Optional).ipynb b/C2W3_FreezeD_(Optional).ipynb new file mode 100644 index 0000000..9d45575 --- /dev/null +++ b/C2W3_FreezeD_(Optional).ipynb @@ -0,0 +1,1232 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "C2W3: FreezeD (Optional).ipynb", + "provenance": [], + "collapsed_sections": [], + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kz7GMf9fruXG" + }, + "source": [ + "# Freeze the Discriminator (FreezeD)\n", + "*Please note that this is an optional notebook meant to introduce more advanced concepts. If you’re up for a challenge, take a look and don’t worry if you can’t follow everything. There is no code to implement—only some cool code for you to learn and run!*\n", + "\n", + "### Goals\n", + "In this notebook, you'll learn about and implement the fine-tuning approach proposed in [Freeze the Discriminator: a Simple Baseline for Fine-Tuning GANs](https://arxiv.org/abs/2002.10964) (Mo et al. 2020), which introduces the concept of freezing the upper layers of the discriminator in fine-tuning. Specifically, you'll fine-tune a pretrained StyleGAN to generate anime faces from human faces.\n", + "\n", + "### Background\n", + "\n", + "What's attractive about this newly proposed baseline is that it is much simpler than existing methods for fine-tuning GANs, which try to circumvent the issues of overfitting and low-fidelity samples. There have been numerous approaches to fine-tuning GANs on data drawn from a different distribution than the original training data. Some of these include:\n", + "\n", + "1. Fine-tuning the generator and discriminator without freezing any layers. This shows decent performance but is also prone to a significant amount of overfitting, especially since most fine-tuning datasets are small.\n", + "\n", + "2. Fine-tuning only scale and shift parameters (i.e. batch normalization layers) while freezing all other weights (sketched briefly at the start of the Getting Started section below). The idea is to accommodate the different statistics of the fine-tuning data, but this restriction reduces the ability of the models to fine-tune, especially when the data distribution is significantly different.\n", + "\n", + "3. Generative latent optimization (GLO), or optimizing L1 image and perceptual losses for the generator. Because the discriminator can be unreliable for limited data, this method removes the discriminator for fine-tuning. However, L1 loss tends to produce blurry images, which is a major drawback for GLO." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wgRdZW0VtZvx" + }, + "source": [ + "## Getting Started\n", + "\n", + "You will begin by importing some packages from PyTorch and defining a visualization function which will be useful later.\n",
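+ "\n", + "First, to make the contrast in the Background list concrete, here is a minimal, hypothetical sketch (not taken from the FreezeD code) of approach 2 next to the FreezeD idea: both simply toggle `requires_grad`, and they differ only in which parameters stay trainable. The helper names and the generic `model` argument are illustrative; the actual freezing function used for fine-tuning appears later in this notebook.\n", + "\n", + "```python\n", + "import torch.nn as nn\n", + "\n", + "def freeze_all_but_scale_shift(model):\n", + "    # Approach 2: train only the scale/shift (affine) parameters of normalization layers.\n", + "    for module in model.modules():\n", + "        is_norm = isinstance(module, (nn.BatchNorm2d, nn.InstanceNorm2d))\n", + "        for p in module.parameters(recurse=False):\n", + "            p.requires_grad = is_norm\n", + "\n", + "def freeze_named_layers(model, frozen_names):\n", + "    # FreezeD-style: freeze whole discriminator layers by name and train the rest.\n", + "    for name, p in model.named_parameters():\n", + "        p.requires_grad = not any(n in name for n in frozen_names)\n", + "```"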
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "WjkUpQFKtgQz" + }, + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import matplotlib.pyplot as plt\n", + "\n", + "from torchvision.utils import make_grid\n", + "\n", + "def show_tensor_images(image_tensor, num_images=16, size=(3, 64, 64)):\n", + " '''\n", + " Function for visualizing images: Given a tensor of images, the number of images to show,\n", + " and the size per image, plots and prints the images in a uniform grid with 4 images per row.\n", + " '''\n", + " image_tensor = (image_tensor + 1) / 2\n", + " image_unflat = image_tensor.detach().cpu().clamp_(0, 1)\n", + " image_grid = make_grid(image_unflat[:num_images], nrow=4, padding=0)\n", + " plt.imshow(image_grid.permute(1, 2, 0).squeeze())\n", + " plt.axis('off')\n", + " plt.show()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7dDgcuDayiwL" + }, + "source": [ + "Now let's take a look at the code to preprocess and prepare the data." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "g8XbTIsYy_vb" + }, + "source": [ + "import os\n", + "\n", + "import torchvision.transforms as transforms\n", + "from PIL import Image\n", + "import numpy as np\n", + "\n", + "class Dataset(torch.utils.data.Dataset):\n", + "\n", + " def __init__(self, root, n_classes=10, resolution=256):\n", + " super().__init__()\n", + "\n", + " self.n_classes = n_classes\n", + "\n", + " # List of paths to training examples\n", + " self.examples = []\n", + " self.load_examples_from_dir(root)\n", + "\n", + " # Initialize transforms\n", + " self.transforms = transforms.Compose([\n", + " transforms.Resize((resolution, resolution), Image.LANCZOS),\n", + " transforms.RandomHorizontalFlip(),\n", + " transforms.Lambda(lambda x: np.array(x)),\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n", + " ])\n", + "\n", + " def load_examples_from_dir(self, abs_path):\n", + " '''\n", + " Given the dataset root folder, this function collects the image paths from the\n", + " first n_classes directories visited by os.walk into self.examples.\n", + " '''\n", + " assert os.path.isdir(abs_path)\n", + "\n", + " img_suffix = '.png'\n", + "\n", + " n_classes = 0\n", + " for root, _, files in os.walk(abs_path):\n", + " if n_classes == self.n_classes:\n", + " break\n", + " for f in files:\n", + " if f.endswith(img_suffix):\n", + " self.examples.append(root + '/' + f)\n", + " n_classes += 1\n", + "\n", + " def __getitem__(self, idx):\n", + " example = self.examples[idx]\n", + " img = Image.open(example).convert('RGB')\n", + " return self.transforms(img)\n", + "\n", + " def __len__(self):\n", + " return len(self.examples)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "wF5qWtsFzbqo" + }, + "source": [ + "# Download Anime Face dataset to `data` folder\n", + "if not os.path.isdir('data/animeface-character-dataset'):\n", + " !wget http://www.nurs.or.jp/~nagadomi/animeface-character-dataset/data/animeface-character-dataset.zip\n", + " !unzip animeface-character-dataset.zip -d data" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WcBmpGpmqtF-" + }, + "source": [ + "## StyleGAN\n", + "\n", + "You should already be familiar with StyleGAN and its implementation. As in the FreezeD paper, the source code below is taken from [this](https://github.com/rosinality/style-based-gan-pytorch) repository, which also provides a pretrained checkpoint.\n",
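+ "\n", + "One interface detail worth keeping in mind before reading the code: this implementation keeps the progressive-growing convention, so the generator and discriminator both take a `step` argument (the working resolution is `4 * 2**step`) and an `alpha` fade-in coefficient. In this notebook the 256x256 model is always used with `step = 6` and `alpha = 1` (fully faded in). As a rough sketch of the call pattern (the `gen` and `dis` objects are only built further below, so their calls are shown as comments):\n", + "\n", + "```python\n", + "import math\n", + "\n", + "resolution = 256\n", + "step = int(math.log2(resolution / 4))  # 6, since 4 * 2**6 = 256\n", + "\n", + "# fake = gen(torch.randn(batch_size, 512), step=step, alpha=1)  # -> (batch_size, 3, 256, 256)\n", + "# score = dis(fake, step=step, alpha=1)                         # -> (batch_size, 1)\n", + "```"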
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "9ZnNqci5cH_2" + }, + "source": [ + "#@title Code (double click to view) { form-width: \"200px\", display-mode: \"code\" }\n", + "\n", + "''' Taken from https://github.com/rosinality/style-based-gan-pytorch/blob/master/model.py '''\n", + "# MIT License\n", + "#\n", + "# Copyright (c) 2019 Kim Seonghyeon\n", + "#\n", + "# Permission is hereby granted, free of charge, to any person obtaining a copy\n", + "# of this software and associated documentation files (the \"Software\"), to deal\n", + "# in the Software without restriction, including without limitation the rights\n", + "# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n", + "# copies of the Software, and to permit persons to whom the Software is\n", + "# furnished to do so, subject to the following conditions:\n", + "#\n", + "# The above copyright notice and this permission notice shall be included in all\n", + "# copies or substantial portions of the Software.\n", + "#\n", + "# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", + "# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", + "# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n", + "# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", + "# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n", + "# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n", + "# SOFTWARE.\n", + "\n", + "from torch.nn import init\n", + "from torch.autograd import Function\n", + "\n", + "from math import sqrt\n", + "\n", + "import random\n", + "\n", + "\n", + "def init_linear(linear):\n", + " init.xavier_normal(linear.weight)\n", + " linear.bias.data.zero_()\n", + "\n", + "\n", + "def init_conv(conv, glu=True):\n", + " init.kaiming_normal(conv.weight)\n", + " if conv.bias is not None:\n", + " conv.bias.data.zero_()\n", + "\n", + "\n", + "class EqualLR:\n", + " def __init__(self, name):\n", + " self.name = name\n", + "\n", + " def compute_weight(self, module):\n", + " weight = getattr(module, self.name + '_orig')\n", + " fan_in = weight.data.size(1) * weight.data[0][0].numel()\n", + "\n", + " return weight * sqrt(2 / fan_in)\n", + "\n", + " @staticmethod\n", + " def apply(module, name):\n", + " fn = EqualLR(name)\n", + "\n", + " weight = getattr(module, name)\n", + " del module._parameters[name]\n", + " module.register_parameter(name + '_orig', nn.Parameter(weight.data))\n", + " module.register_forward_pre_hook(fn)\n", + "\n", + " return fn\n", + "\n", + " def __call__(self, module, input):\n", + " weight = self.compute_weight(module)\n", + " setattr(module, self.name, weight)\n", + "\n", + "\n", + "def equal_lr(module, name='weight'):\n", + " EqualLR.apply(module, name)\n", + "\n", + " return module\n", + "\n", + "\n", + "class FusedUpsample(nn.Module):\n", + " def __init__(self, in_channel, out_channel, kernel_size, padding=0):\n", + " super().__init__()\n", + "\n", + " weight = torch.randn(in_channel, out_channel, kernel_size, kernel_size)\n", + " bias = torch.zeros(out_channel)\n", + "\n", + " fan_in = in_channel * kernel_size * kernel_size\n", + " self.multiplier = sqrt(2 / fan_in)\n", + "\n", + " self.weight = nn.Parameter(weight)\n", + " self.bias = nn.Parameter(bias)\n", + "\n", + " self.pad = padding\n", + "\n", + " def forward(self, input):\n", + " weight = F.pad(self.weight * self.multiplier, [1, 1, 1, 1])\n", + " weight = (\n", + " weight[:, 
:, 1:, 1:]\n", + " + weight[:, :, :-1, 1:]\n", + " + weight[:, :, 1:, :-1]\n", + " + weight[:, :, :-1, :-1]\n", + " ) / 4\n", + "\n", + " out = F.conv_transpose2d(input, weight, self.bias, stride=2, padding=self.pad)\n", + "\n", + " return out\n", + "\n", + "\n", + "class FusedDownsample(nn.Module):\n", + " def __init__(self, in_channel, out_channel, kernel_size, padding=0):\n", + " super().__init__()\n", + "\n", + " weight = torch.randn(out_channel, in_channel, kernel_size, kernel_size)\n", + " bias = torch.zeros(out_channel)\n", + "\n", + " fan_in = in_channel * kernel_size * kernel_size\n", + " self.multiplier = sqrt(2 / fan_in)\n", + "\n", + " self.weight = nn.Parameter(weight)\n", + " self.bias = nn.Parameter(bias)\n", + "\n", + " self.pad = padding\n", + "\n", + " def forward(self, input):\n", + " weight = F.pad(self.weight * self.multiplier, [1, 1, 1, 1])\n", + " weight = (\n", + " weight[:, :, 1:, 1:]\n", + " + weight[:, :, :-1, 1:]\n", + " + weight[:, :, 1:, :-1]\n", + " + weight[:, :, :-1, :-1]\n", + " ) / 4\n", + "\n", + " out = F.conv2d(input, weight, self.bias, stride=2, padding=self.pad)\n", + "\n", + " return out\n", + "\n", + "\n", + "class PixelNorm(nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + "\n", + " def forward(self, input):\n", + " return input / torch.sqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)\n", + "\n", + "\n", + "class BlurFunctionBackward(Function):\n", + " @staticmethod\n", + " def forward(ctx, grad_output, kernel, kernel_flip):\n", + " ctx.save_for_backward(kernel, kernel_flip)\n", + "\n", + " grad_input = F.conv2d(\n", + " grad_output, kernel_flip, padding=1, groups=grad_output.shape[1]\n", + " )\n", + "\n", + " return grad_input\n", + "\n", + " @staticmethod\n", + " def backward(ctx, gradgrad_output):\n", + " kernel, kernel_flip = ctx.saved_tensors\n", + "\n", + " grad_input = F.conv2d(\n", + " gradgrad_output, kernel, padding=1, groups=gradgrad_output.shape[1]\n", + " )\n", + "\n", + " return grad_input, None, None\n", + "\n", + "\n", + "class BlurFunction(Function):\n", + " @staticmethod\n", + " def forward(ctx, input, kernel, kernel_flip):\n", + " ctx.save_for_backward(kernel, kernel_flip)\n", + "\n", + " output = F.conv2d(input, kernel, padding=1, groups=input.shape[1])\n", + "\n", + " return output\n", + "\n", + " @staticmethod\n", + " def backward(ctx, grad_output):\n", + " kernel, kernel_flip = ctx.saved_tensors\n", + "\n", + " grad_input = BlurFunctionBackward.apply(grad_output, kernel, kernel_flip)\n", + "\n", + " return grad_input, None, None\n", + "\n", + "\n", + "blur = BlurFunction.apply\n", + "\n", + "\n", + "class Blur(nn.Module):\n", + " def __init__(self, channel):\n", + " super().__init__()\n", + "\n", + " weight = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=torch.float32)\n", + " weight = weight.view(1, 1, 3, 3)\n", + " weight = weight / weight.sum()\n", + " weight_flip = torch.flip(weight, [2, 3])\n", + "\n", + " self.register_buffer('weight', weight.repeat(channel, 1, 1, 1))\n", + " self.register_buffer('weight_flip', weight_flip.repeat(channel, 1, 1, 1))\n", + "\n", + " def forward(self, input):\n", + " return blur(input, self.weight, self.weight_flip)\n", + " # return F.conv2d(input, self.weight, padding=1, groups=input.shape[1])\n", + "\n", + "\n", + "class EqualConv2d(nn.Module):\n", + " def __init__(self, *args, **kwargs):\n", + " super().__init__()\n", + "\n", + " conv = nn.Conv2d(*args, **kwargs)\n", + " conv.weight.data.normal_()\n", + " conv.bias.data.zero_()\n", + " self.conv = 
equal_lr(conv)\n", + "\n", + " def forward(self, input):\n", + " return self.conv(input)\n", + "\n", + "\n", + "class EqualLinear(nn.Module):\n", + " def __init__(self, in_dim, out_dim):\n", + " super().__init__()\n", + "\n", + " linear = nn.Linear(in_dim, out_dim)\n", + " linear.weight.data.normal_()\n", + " linear.bias.data.zero_()\n", + "\n", + " self.linear = equal_lr(linear)\n", + "\n", + " def forward(self, input):\n", + " return self.linear(input)\n", + "\n", + "\n", + "class ConvBlock(nn.Module):\n", + " def __init__(\n", + " self,\n", + " in_channel,\n", + " out_channel,\n", + " kernel_size,\n", + " padding,\n", + " kernel_size2=None,\n", + " padding2=None,\n", + " downsample=False,\n", + " fused=False,\n", + " ):\n", + " super().__init__()\n", + "\n", + " pad1 = padding\n", + " pad2 = padding\n", + " if padding2 is not None:\n", + " pad2 = padding2\n", + "\n", + " kernel1 = kernel_size\n", + " kernel2 = kernel_size\n", + " if kernel_size2 is not None:\n", + " kernel2 = kernel_size2\n", + "\n", + " self.conv1 = nn.Sequential(\n", + " EqualConv2d(in_channel, out_channel, kernel1, padding=pad1),\n", + " nn.LeakyReLU(0.2),\n", + " )\n", + "\n", + " if downsample:\n", + " if fused:\n", + " self.conv2 = nn.Sequential(\n", + " Blur(out_channel),\n", + " FusedDownsample(out_channel, out_channel, kernel2, padding=pad2),\n", + " nn.LeakyReLU(0.2),\n", + " )\n", + "\n", + " else:\n", + " self.conv2 = nn.Sequential(\n", + " Blur(out_channel),\n", + " EqualConv2d(out_channel, out_channel, kernel2, padding=pad2),\n", + " nn.AvgPool2d(2),\n", + " nn.LeakyReLU(0.2),\n", + " )\n", + "\n", + " else:\n", + " self.conv2 = nn.Sequential(\n", + " EqualConv2d(out_channel, out_channel, kernel2, padding=pad2),\n", + " nn.LeakyReLU(0.2),\n", + " )\n", + "\n", + " def forward(self, input):\n", + " out = self.conv1(input)\n", + " out = self.conv2(out)\n", + "\n", + " return out\n", + "\n", + "\n", + "class AdaptiveInstanceNorm(nn.Module):\n", + " def __init__(self, in_channel, style_dim):\n", + " super().__init__()\n", + "\n", + " self.norm = nn.InstanceNorm2d(in_channel)\n", + " self.style = EqualLinear(style_dim, in_channel * 2)\n", + "\n", + " self.style.linear.bias.data[:in_channel] = 1\n", + " self.style.linear.bias.data[in_channel:] = 0\n", + "\n", + " def forward(self, input, style):\n", + " style = self.style(style).unsqueeze(2).unsqueeze(3)\n", + " gamma, beta = style.chunk(2, 1)\n", + "\n", + " out = self.norm(input)\n", + " out = gamma * out + beta\n", + "\n", + " return out\n", + "\n", + "\n", + "class NoiseInjection(nn.Module):\n", + " def __init__(self, channel):\n", + " super().__init__()\n", + "\n", + " self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))\n", + "\n", + " def forward(self, image, noise):\n", + " return image + self.weight * noise\n", + "\n", + "\n", + "class ConstantInput(nn.Module):\n", + " def __init__(self, channel, size=4):\n", + " super().__init__()\n", + "\n", + " self.input = nn.Parameter(torch.randn(1, channel, size, size))\n", + "\n", + " def forward(self, input):\n", + " batch = input.shape[0]\n", + " out = self.input.repeat(batch, 1, 1, 1)\n", + "\n", + " return out\n", + "\n", + "\n", + "class StyledConvBlock(nn.Module):\n", + " def __init__(\n", + " self,\n", + " in_channel,\n", + " out_channel,\n", + " kernel_size=3,\n", + " padding=1,\n", + " style_dim=512,\n", + " initial=False,\n", + " upsample=False,\n", + " fused=False,\n", + " ):\n", + " super().__init__()\n", + "\n", + " if initial:\n", + " self.conv1 = ConstantInput(in_channel)\n", + "\n", + " 
else:\n", + " if upsample:\n", + " if fused:\n", + " self.conv1 = nn.Sequential(\n", + " FusedUpsample(\n", + " in_channel, out_channel, kernel_size, padding=padding\n", + " ),\n", + " Blur(out_channel),\n", + " )\n", + "\n", + " else:\n", + " self.conv1 = nn.Sequential(\n", + " nn.Upsample(scale_factor=2, mode='nearest'),\n", + " EqualConv2d(\n", + " in_channel, out_channel, kernel_size, padding=padding\n", + " ),\n", + " Blur(out_channel),\n", + " )\n", + "\n", + " else:\n", + " self.conv1 = EqualConv2d(\n", + " in_channel, out_channel, kernel_size, padding=padding\n", + " )\n", + "\n", + " self.noise1 = equal_lr(NoiseInjection(out_channel))\n", + " self.adain1 = AdaptiveInstanceNorm(out_channel, style_dim)\n", + " self.lrelu1 = nn.LeakyReLU(0.2)\n", + "\n", + " self.conv2 = EqualConv2d(out_channel, out_channel, kernel_size, padding=padding)\n", + " self.noise2 = equal_lr(NoiseInjection(out_channel))\n", + " self.adain2 = AdaptiveInstanceNorm(out_channel, style_dim)\n", + " self.lrelu2 = nn.LeakyReLU(0.2)\n", + "\n", + " def forward(self, input, style, noise):\n", + " out = self.conv1(input)\n", + " out = self.noise1(out, noise)\n", + " out = self.lrelu1(out)\n", + " out = self.adain1(out, style)\n", + "\n", + " out = self.conv2(out)\n", + " out = self.noise2(out, noise)\n", + " out = self.lrelu2(out)\n", + " out = self.adain2(out, style)\n", + "\n", + " return out\n", + "\n", + "\n", + "class Generator(nn.Module):\n", + " def __init__(self, code_dim, fused=True):\n", + " super().__init__()\n", + "\n", + " self.progression = nn.ModuleList(\n", + " [\n", + " StyledConvBlock(512, 512, 3, 1, initial=True), # 4\n", + " StyledConvBlock(512, 512, 3, 1, upsample=True), # 8\n", + " StyledConvBlock(512, 512, 3, 1, upsample=True), # 16\n", + " StyledConvBlock(512, 512, 3, 1, upsample=True), # 32\n", + " StyledConvBlock(512, 256, 3, 1, upsample=True), # 64\n", + " StyledConvBlock(256, 128, 3, 1, upsample=True, fused=fused), # 128\n", + " StyledConvBlock(128, 64, 3, 1, upsample=True, fused=fused), # 256\n", + " StyledConvBlock(64, 32, 3, 1, upsample=True, fused=fused), # 512\n", + " StyledConvBlock(32, 16, 3, 1, upsample=True, fused=fused), # 1024\n", + " ]\n", + " )\n", + "\n", + " self.to_rgb = nn.ModuleList(\n", + " [\n", + " EqualConv2d(512, 3, 1),\n", + " EqualConv2d(512, 3, 1),\n", + " EqualConv2d(512, 3, 1),\n", + " EqualConv2d(512, 3, 1),\n", + " EqualConv2d(256, 3, 1),\n", + " EqualConv2d(128, 3, 1),\n", + " EqualConv2d(64, 3, 1),\n", + " EqualConv2d(32, 3, 1),\n", + " EqualConv2d(16, 3, 1),\n", + " ]\n", + " )\n", + "\n", + " # self.blur = Blur()\n", + "\n", + " def forward(self, style, noise, step=0, alpha=-1, mixing_range=(-1, -1)):\n", + " out = noise[0]\n", + "\n", + " if len(style) < 2:\n", + " inject_index = [len(self.progression) + 1]\n", + "\n", + " else:\n", + " inject_index = sorted(random.sample(list(range(step)), len(style) - 1))\n", + "\n", + " crossover = 0\n", + "\n", + " for i, (conv, to_rgb) in enumerate(zip(self.progression, self.to_rgb)):\n", + " if mixing_range == (-1, -1):\n", + " if crossover < len(inject_index) and i > inject_index[crossover]:\n", + " crossover = min(crossover + 1, len(style))\n", + "\n", + " style_step = style[crossover]\n", + "\n", + " else:\n", + " if mixing_range[0] <= i <= mixing_range[1]:\n", + " style_step = style[1]\n", + "\n", + " else:\n", + " style_step = style[0]\n", + "\n", + " if i > 0 and step > 0:\n", + " out_prev = out\n", + " \n", + " out = conv(out, style_step, noise[i])\n", + "\n", + " if i == step:\n", + " out = to_rgb(out)\n", + 
"\n", + " if i > 0 and 0 <= alpha < 1:\n", + " skip_rgb = self.to_rgb[i - 1](out_prev)\n", + " skip_rgb = F.interpolate(skip_rgb, scale_factor=2, mode='nearest')\n", + " out = (1 - alpha) * skip_rgb + alpha * out\n", + "\n", + " break\n", + "\n", + " return out\n", + "\n", + "\n", + "class StyledGenerator(nn.Module):\n", + " def __init__(self, code_dim=512, n_mlp=8):\n", + " super().__init__()\n", + "\n", + " self.generator = Generator(code_dim)\n", + "\n", + " layers = [PixelNorm()]\n", + " for i in range(n_mlp):\n", + " layers.append(EqualLinear(code_dim, code_dim))\n", + " layers.append(nn.LeakyReLU(0.2))\n", + "\n", + " self.style = nn.Sequential(*layers)\n", + "\n", + " def forward(\n", + " self,\n", + " input,\n", + " noise=None,\n", + " step=0,\n", + " alpha=-1,\n", + " mean_style=None,\n", + " style_weight=0,\n", + " mixing_range=(-1, -1),\n", + " ):\n", + " styles = []\n", + " if type(input) not in (list, tuple):\n", + " input = [input]\n", + "\n", + " for i in input:\n", + " styles.append(self.style(i))\n", + "\n", + " batch = input[0].shape[0]\n", + "\n", + " if noise is None:\n", + " noise = []\n", + "\n", + " for i in range(step + 1):\n", + " size = 4 * 2 ** i\n", + " noise.append(torch.randn(batch, 1, size, size, device=input[0].device))\n", + "\n", + " if mean_style is not None:\n", + " styles_norm = []\n", + "\n", + " for style in styles:\n", + " styles_norm.append(mean_style + style_weight * (style - mean_style))\n", + "\n", + " styles = styles_norm\n", + "\n", + " return self.generator(styles, noise, step, alpha, mixing_range=mixing_range)\n", + "\n", + " def mean_style(self, input):\n", + " style = self.style(input).mean(0, keepdim=True)\n", + "\n", + " return style\n", + "\n", + "\n", + "class Discriminator(nn.Module):\n", + " def __init__(self, fused=True, from_rgb_activate=False):\n", + " super().__init__()\n", + "\n", + " self.progression = nn.ModuleList(\n", + " [\n", + " ConvBlock(16, 32, 3, 1, downsample=True, fused=fused), # 512\n", + " ConvBlock(32, 64, 3, 1, downsample=True, fused=fused), # 256\n", + " ConvBlock(64, 128, 3, 1, downsample=True, fused=fused), # 128\n", + " ConvBlock(128, 256, 3, 1, downsample=True, fused=fused), # 64\n", + " ConvBlock(256, 512, 3, 1, downsample=True), # 32\n", + " ConvBlock(512, 512, 3, 1, downsample=True), # 16\n", + " ConvBlock(512, 512, 3, 1, downsample=True), # 8\n", + " ConvBlock(512, 512, 3, 1, downsample=True), # 4\n", + " ConvBlock(513, 512, 3, 1, 4, 0),\n", + " ]\n", + " )\n", + "\n", + " def make_from_rgb(out_channel):\n", + " if from_rgb_activate:\n", + " return nn.Sequential(EqualConv2d(3, out_channel, 1), nn.LeakyReLU(0.2))\n", + "\n", + " else:\n", + " return EqualConv2d(3, out_channel, 1)\n", + "\n", + " self.from_rgb = nn.ModuleList(\n", + " [\n", + " make_from_rgb(16),\n", + " make_from_rgb(32),\n", + " make_from_rgb(64),\n", + " make_from_rgb(128),\n", + " make_from_rgb(256),\n", + " make_from_rgb(512),\n", + " make_from_rgb(512),\n", + " make_from_rgb(512),\n", + " make_from_rgb(512),\n", + " ]\n", + " )\n", + "\n", + " # self.blur = Blur()\n", + "\n", + " self.n_layer = len(self.progression)\n", + "\n", + " self.linear = EqualLinear(512, 1)\n", + "\n", + " def forward(self, input, step=0, alpha=-1):\n", + " for i in range(step, -1, -1):\n", + " index = self.n_layer - i - 1\n", + "\n", + " if i == step:\n", + " out = self.from_rgb[index](input)\n", + "\n", + " if i == 0:\n", + " out_std = torch.sqrt(out.var(0, unbiased=False) + 1e-8)\n", + " mean_std = out_std.mean()\n", + " mean_std = 
mean_std.expand(out.size(0), 1, 4, 4)\n", + " out = torch.cat([out, mean_std], 1)\n", + "\n", + " out = self.progression[index](out)\n", + "\n", + " if i > 0:\n", + " if i == step and 0 <= alpha < 1:\n", + " skip_rgb = F.avg_pool2d(input, 2)\n", + " skip_rgb = self.from_rgb[index + 1](skip_rgb)\n", + "\n", + " out = (1 - alpha) * skip_rgb + alpha * out\n", + "\n", + " out = out.squeeze(2).squeeze(2)\n", + " # print(input.size(), out.size(), step)\n", + " out = self.linear(out)\n", + "\n", + " return out" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mfTgDaB-rhW5" + }, + "source": [ + "### Checkpoint\n", + "\n", + "Run the cell below to download the StyleGAN checkpoint." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "4qLS71xrbyQm", + "outputId": "e606cbac-a9d0-4cda-8c20-75c320027ade", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 481 + } + }, + "source": [ + "# Download public StyleGAN 256x256 resolution model to `stylegan-256px.pt`\n", + "import os\n", + "if 'stylegan-256px.pt' not in os.listdir(os.getcwd()):\n", + " !wget --load-cookies /tmp/cookies.txt \"https://drive.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://drive.google.com/uc?export=download&id=1QlXFPIOFzsJyjZ1AtfpnVhqW4Z0r8GLZ' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\\1\\n/p')&id=1QlXFPIOFzsJyjZ1AtfpnVhqW4Z0r8GLZ\" -O stylegan-256px.pt && rm -rf /tmp/cookies.txt" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "--2020-10-18 19:57:19-- https://drive.google.com/uc?export=download&confirm=rrR8&id=1QlXFPIOFzsJyjZ1AtfpnVhqW4Z0r8GLZ\n", + "Resolving drive.google.com (drive.google.com)... 108.177.125.101, 108.177.125.138, 108.177.125.113, ...\n", + "Connecting to drive.google.com (drive.google.com)|108.177.125.101|:443... connected.\n", + "HTTP request sent, awaiting response... 302 Moved Temporarily\n", + "Location: https://doc-04-70-docs.googleusercontent.com/docs/securesc/v6t3ak7laoki5c26pj521oo38p229egn/3d2c3glguhtc1hc3tkplkkh7kkh41dt5/1603050975000/04891893430065212435/16332041203202668392Z/1QlXFPIOFzsJyjZ1AtfpnVhqW4Z0r8GLZ?e=download [following]\n", + "--2020-10-18 19:57:19-- https://doc-04-70-docs.googleusercontent.com/docs/securesc/v6t3ak7laoki5c26pj521oo38p229egn/3d2c3glguhtc1hc3tkplkkh7kkh41dt5/1603050975000/04891893430065212435/16332041203202668392Z/1QlXFPIOFzsJyjZ1AtfpnVhqW4Z0r8GLZ?e=download\n", + "Resolving doc-04-70-docs.googleusercontent.com (doc-04-70-docs.googleusercontent.com)... 108.177.125.132, 2404:6800:4008:c01::84\n", + "Connecting to doc-04-70-docs.googleusercontent.com (doc-04-70-docs.googleusercontent.com)|108.177.125.132|:443... connected.\n", + "HTTP request sent, awaiting response... 
302 Found\n", + "Location: https://docs.google.com/nonceSigner?nonce=6a4gjtfm9t1l6&continue=https://doc-04-70-docs.googleusercontent.com/docs/securesc/v6t3ak7laoki5c26pj521oo38p229egn/3d2c3glguhtc1hc3tkplkkh7kkh41dt5/1603050975000/04891893430065212435/16332041203202668392Z/1QlXFPIOFzsJyjZ1AtfpnVhqW4Z0r8GLZ?e%3Ddownload&hash=31aca3m4gqe3vsv1naoafgn1k314u6hu [following]\n", + "--2020-10-18 19:57:19-- https://docs.google.com/nonceSigner?nonce=6a4gjtfm9t1l6&continue=https://doc-04-70-docs.googleusercontent.com/docs/securesc/v6t3ak7laoki5c26pj521oo38p229egn/3d2c3glguhtc1hc3tkplkkh7kkh41dt5/1603050975000/04891893430065212435/16332041203202668392Z/1QlXFPIOFzsJyjZ1AtfpnVhqW4Z0r8GLZ?e%3Ddownload&hash=31aca3m4gqe3vsv1naoafgn1k314u6hu\n", + "Resolving docs.google.com (docs.google.com)... 74.125.203.139, 74.125.203.113, 74.125.203.102, ...\n", + "Connecting to docs.google.com (docs.google.com)|74.125.203.139|:443... connected.\n", + "HTTP request sent, awaiting response... 302 Found\n", + "Location: https://doc-04-70-docs.googleusercontent.com/docs/securesc/v6t3ak7laoki5c26pj521oo38p229egn/3d2c3glguhtc1hc3tkplkkh7kkh41dt5/1603050975000/04891893430065212435/16332041203202668392Z/1QlXFPIOFzsJyjZ1AtfpnVhqW4Z0r8GLZ?e=download&nonce=6a4gjtfm9t1l6&user=16332041203202668392Z&hash=l11l6ep34pm046032mqa11e850igtcj2 [following]\n", + "--2020-10-18 19:57:20-- https://doc-04-70-docs.googleusercontent.com/docs/securesc/v6t3ak7laoki5c26pj521oo38p229egn/3d2c3glguhtc1hc3tkplkkh7kkh41dt5/1603050975000/04891893430065212435/16332041203202668392Z/1QlXFPIOFzsJyjZ1AtfpnVhqW4Z0r8GLZ?e=download&nonce=6a4gjtfm9t1l6&user=16332041203202668392Z&hash=l11l6ep34pm046032mqa11e850igtcj2\n", + "Connecting to doc-04-70-docs.googleusercontent.com (doc-04-70-docs.googleusercontent.com)|108.177.125.132|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: unspecified [application/octet-stream]\n", + "Saving to: ‘stylegan-256px.pt’\n", + "\n", + "stylegan-256px.pt [<=> ] 691.85M 40.7MB/s in 17s \n", + "\n", + "2020-10-18 19:57:37 (40.2 MB/s) - ‘stylegan-256px.pt’ saved [725461173]\n", + "\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "t5uJfA2eQQn2" + }, + "source": [ + "## Some Useful Functions\n", + "\n", + "Before jumping into the fine-tuning code, let's first implement some useful functions" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZyK7KaxR61ZX" + }, + "source": [ + "### Exponential Moving Average (EMA) of Model Weights\n", + "\n", + "In fine-tuning, the authors keep an exponential moving average of the weights for the generator for inference.\n", + "\n", + "This is relatively common in deep learning since it helps stabilize the converged model (similar to how optimizers like Adam and layers like batch normalization keep exponential moving averages to stabilize their statistics across multiple updates)." 
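+ "\n", + "Concretely, with decay $\\beta = 0.999$ (the value used below), every training step updates the averaged weights as\n", + "\n", + "$$\\theta_{EMA} \\leftarrow \\beta \\theta_{EMA} + (1 - \\beta) \\theta,$$\n", + "\n", + "where $\\theta$ denotes the current generator weights, so the averaged copy drifts slowly and smooths out the noise of individual gradient steps."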
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "IxlE5CEcNRv5" + }, + "source": [ + "def init_ema(model):\n", + " '''\n", + " Function that initializes a static model to store exponential moving average.\n", + " '''\n", + " model.eval()\n", + " for p in model.parameters():\n", + " p.requires_grad = False\n", + "\n", + "def accumulate_ema_weights(model_ema, model_tgt, decay=0.999):\n", + " '''\n", + " Function for updating exponential moving average of weights.\n", + " '''\n", + " ema = dict(model_ema.named_parameters())\n", + " tgt = dict(model_tgt.named_parameters())\n", + "\n", + " for p_ema, p_tgt in zip(ema.values(), tgt.values()):\n", + " p_ema.data.mul_(decay).add_(p_tgt.data, alpha=1-decay)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vUh-9tNEQmiG" + }, + "source": [ + "### Mixing Regularization\n", + "\n", + "You've already learned about mixing regularization in the StyleGAN notebook, so this section will just implement the function to be used in fine-tuning." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "CeFbJ60URd_5" + }, + "source": [ + "def sample_noise(batch_size, code_dim, device, p=0.9):\n", + " '''\n", + " Function that samples noise with mixing regularization with probability p.\n", + " '''\n", + " if random.random() < p:\n", + " z11, z12, z21, z22 = torch.randn(4, batch_size, code_dim, device=device).unbind(0)\n", + " z1 = [z11, z12]\n", + " z2 = [z21, z22]\n", + "\n", + " else:\n", + " z1, z2 = torch.randn(2, batch_size, code_dim, device=device).unbind(0)\n", + "\n", + " return z1, z2" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RoMS8idUUNWR" + }, + "source": [ + "### Gradient Penalty\n", + "\n", + "This should also be something that you're already familiar with. The original code repository only applies gradient penalty to the real image that's passed through the discriminator. The implementation is below." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "9c_Wj0uPUThK" + }, + "source": [ + "from torch.autograd import grad\n", + "\n", + "def gradient_penalty(inputs, outputs):\n", + " '''\n", + " Function that computes gradient penalty given inputs and outputs.\n", + " '''\n", + " g = grad(outputs=outputs.sum(), inputs=inputs, create_graph=True)[0]\n", + " gp = (g.flatten(1).norm(2, dim=1) ** 2).mean()\n", + " gp = 10 / 2 * gp\n", + " return gp" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "y5HEmZFeW5lz" + }, + "source": [ + "### Freezing Discriminator Layers\n", + "\n", + "The final helper function you'll need is one that'll freeze the first four layers in the discriminator (and one that unfreezes them). Check it out below." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "sQcIJhPVXFsK" + }, + "source": [ + "def freeze_discriminator_layers(d):\n", + " '''\n", + " Function that freezes the first four discriminator layers.\n", + " '''\n", + " # Naming patterns taken from official code repo\n", + " ls = ['progression.{}'.format(8 - i) for i in range(3)] + ['linear']\n", + "\n", + " for name, p in d.named_parameters():\n", + " if any(l in name for l in ls):\n", + " p.requires_grad = False\n", + "\n", + "def unfreeze_discriminator_layers(d):\n", + " '''\n", + " Function that unfreezes the discriminator layers.\n", + " '''\n", + " for p in d.parameters():\n", + " p.requires_grad = True" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3h5fAoro5VqD" + }, + "source": [ + "## Fine-tuning StyleGAN on Anime Faces\n", + "\n", + "You're now ready to fine-tune StyleGAN on the Anime Faces dataset! The authors fine-tune on 10 classes, but feel free to adjust this number.\n", + "In fine-tuning, the first 4 layers of the discriminator are frozen." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "8kWZdTfq5-2m" + }, + "source": [ + "from tqdm import tqdm\n", + "import math\n", + "import re\n", + "\n", + "# Some training parameters\n", + "finetune_steps = 50000\n", + "display_step = 50\n", + "\n", + "resolution = 256\n", + "step = int(math.log2(resolution / 4))\n", + "lr = 0.002\n", + "betas = (0.0, 0.99)\n", + "\n", + "def finetune(generators, dis, dataloader, code_dim, device):\n", + " gen_ema, gen = generators\n", + " \n", + " gen_optim = torch.optim.Adam([\n", + " {\n", + " 'params': gen.generator.parameters(),\n", + " 'sched': 1.0,\n", + " },\n", + " {\n", + " 'params': gen.style.parameters(),\n", + " 'sched': 0.01,\n", + " },\n", + " ], betas=betas)\n", + " dis_optim = torch.optim.Adam(\n", + " [p for p in dis.parameters() if p.requires_grad],\n", + " lr=lr, betas=betas,\n", + " )\n", + "\n", + " cur_step = 0\n", + " mean_gen_loss, mean_dis_loss = 0.0, 0.0\n", + "\n", + " while cur_step < finetune_steps:\n", + " for x in tqdm(dataloader):\n", + " with torch.cuda.amp.autocast((device=='cuda')):\n", + " # Prep inputs\n", + " x = x.to(device)\n", + " z1, z2 = sample_noise(x.size(0), code_dim, device)\n", + "\n", + " dis.zero_grad()\n", + " gen.zero_grad()\n", + "\n", + " # Forward pass through generator\n", + " x_fake1 = gen(z1, step=step, alpha=1)\n", + "\n", + " # Unfreeze discriminator\n", + " unfreeze_discriminator_layers(dis)\n", + "\n", + " # Update discriminator\n", + " x.requires_grad = True\n", + " fake_pred = dis(x_fake1.detach(), step=step, alpha=1)\n", + " real_pred = dis(x, step=step, alpha=1)\n", + " real_gp = gradient_penalty(x, real_pred)\n", + "\n", + " dis_loss = real_gp + F.softplus(-real_pred).mean() + F.softplus(fake_pred).mean()\n", + " mean_dis_loss += dis_loss.item() / display_step\n", + " dis_optim.zero_grad()\n", + " dis_loss.backward()\n", + " dis_optim.step()\n", + "\n", + " # Freeze discriminator\n", + " freeze_discriminator_layers(dis)\n", + "\n", + " # Update generator\n", + " x_fake2 = gen(z2, step=step, alpha=1)\n", + " fake_pred = dis(x_fake2, step=step, alpha=1)\n", + "\n", + " gen_loss = F.softplus(-fake_pred).mean()\n", + " mean_gen_loss += gen_loss.item() / display_step\n", + " gen_optim.zero_grad()\n", + " gen_loss.backward()\n", + " gen_optim.step()\n", + "\n", + " # Update EMA\n", + " accumulate_ema_weights(gen_ema, gen, decay=0.999)\n", + "\n", + " # Schedule learning rate\n", + " for param_group in 
gen_optim.param_groups:\n", + " param_group['lr'] = lr * param_group['sched']  # apply each group's constant multiplier to the base lr\n", + "\n", + " cur_step += 1\n", + " if cur_step % display_step == 0:\n", + " show_tensor_images(x_fake1.to(x.dtype))\n", + " show_tensor_images(x_fake2.to(x.dtype))\n", + " show_tensor_images(x)\n", + "\n", + " print('Step {}. G loss: {:.5f}. \\t D loss: {:.5f}.'.format(cur_step, mean_gen_loss, mean_dis_loss))\n", + " mean_gen_loss = 0.0\n", + " mean_dis_loss = 0.0\n", + "\n", + " # Delete the previous checkpoint to save disk space\n", + " if cur_step - display_step > 0:\n", + " os.remove('stylegan-step={}.pt'.format(cur_step - display_step))\n", + " torch.save({\n", + " 'generator': gen.state_dict(),\n", + " 'g_running': gen_ema.state_dict(),\n", + " 'discriminator': dis.state_dict(),\n", + " 'g_optim': gen_optim.state_dict(),\n", + " 'd_optim': dis_optim.state_dict(),\n", + " 'step': cur_step,\n", + " }, 'stylegan-step={}.pt'.format(cur_step))\n", + "\n", + " # End training if reached enough steps\n", + " if cur_step == finetune_steps:\n", + " break" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "fZOvJcTSgzzz" + }, + "source": [ + "code_dim = 512\n", + "n_classes = 10\n", + "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", + "\n", + "checkpoint = torch.load('stylegan-256px.pt')\n", + "\n", + "gen = StyledGenerator(code_dim=code_dim).to(device)\n", + "gen.load_state_dict(checkpoint['generator'])\n", + "gen_ema = StyledGenerator(code_dim=code_dim).to(device)\n", + "init_ema(gen_ema)\n", + "gen_ema.load_state_dict(checkpoint['g_running'])\n", + "\n", + "dis = Discriminator(from_rgb_activate=True).to(device)\n", + "dis.load_state_dict(checkpoint['discriminator'])\n", + "\n", + "dataloader = torch.utils.data.DataLoader(\n", + " Dataset('data/animeface-character-dataset/thumb', n_classes=n_classes),\n", + " batch_size=16, pin_memory=True, shuffle=True, drop_last=True,\n", + ")\n", + "\n", + "finetune(\n", + " [gen_ema, gen],\n", + " dis,\n", + " dataloader,\n", + " code_dim,\n", + " device,\n", + ")" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "l6pbmacaGhq9" + }, + "source": [ + "### Inference\n", + "\n", + "After you're happy with the quality of your FreezeD'd StyleGAN, you can run inference. Simply load the checkpoint from training and feed it some random noise to generate some cool anime faces!\n",
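+ "\n", + "One argument worth calling out: `style_weight=0.7` below is a form of StyleGAN's truncation trick. Each sampled style vector $w$ is pulled toward the mean style $\\bar{w}$,\n", + "\n", + "$$w' = \\bar{w} + 0.7 (w - \\bar{w}),$$\n", + "\n", + "trading a little diversity for more consistent sample quality. This corresponds exactly to the `mean_style + style_weight * (style - mean_style)` line in `StyledGenerator.forward` above."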
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "fflAgXjAG0zU" + }, + "source": [ + "# Define some helper functions for inference\n", + "@torch.no_grad()\n", + "def mean_style(generator, code_dim, device, batch=1024, batches=10):\n", + " '''\n", + " Function that computes the mean style vector.\n", + " '''\n", + " mean_style = []\n", + " for i in range(batches):\n", + " z = torch.randn(batch, code_dim).to(device)\n", + " style = generator.mean_style(z)\n", + " mean_style += [style]\n", + "\n", + " mean_style = torch.stack(mean_style, dim=0).mean(0)\n", + " return mean_style\n", + "\n", + "@torch.no_grad()\n", + "def synthesize(generator, n_samples, code_dim, device):\n", + " '''\n", + " Function that samples random noise and generates a batch of fake images.\n", + " '''\n", + " generator.eval()\n", + " z = torch.randn(n_samples, code_dim, device=device)\n", + " avg_style = mean_style(generator, code_dim, device)  # avoid shadowing the mean_style helper above\n", + " x = generator(z, step=step, alpha=1, mean_style=avg_style, style_weight=0.7)\n", + " return x\n", + "\n", + "n_samples = 16\n", + "\n", + "# Load checkpoint\n", + "checkpoint = torch.load('stylegan-step=50000.pt')\n", + "generator = StyledGenerator(code_dim=code_dim).to(device)\n", + "generator.load_state_dict(checkpoint['g_running'])  # use the EMA weights for inference\n", + "\n", + "# Run inference\n", + "x = synthesize(generator, n_samples, code_dim, device)\n", + "show_tensor_images(x)" + ], + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file