diff --git a/C3W2_Pix2PixHD_(Optional).ipynb b/C3W2_Pix2PixHD_(Optional).ipynb new file mode 100644 index 0000000..ab56994 --- /dev/null +++ b/C3W2_Pix2PixHD_(Optional).ipynb @@ -0,0 +1,1273 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "C3W2: Pix2PixHD (Optional).ipynb", + "provenance": [], + "collapsed_sections": [], + "toc_visible": true, + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1czVdIlqnImH" + }, + "source": [ + "# Pix2PixHD\n", + "\n", + "*Please note that this is an optional notebook, meant to introduce more advanced concepts if you're up for a challenge, so don't worry if you don't completely follow!*\n", + "\n", + "You should already be familiar with:\n", + " - Residual blocks, from [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) (He et al. 2015)\n", + " - Perceptual loss, from [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](https://arxiv.org/abs/1603.08155) (Johnson et al. 2016)\n", + " - VGG architecture, from [Very Deep Convolutional Networks for Large-Scale Image Recognition](https://arxiv.org/abs/1409.1556) (Simonyan et al. 2015)\n", + " - Instance normalization (which you should know from StyleGAN), from [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022) (Ulyanov et al. 2017)\n", + " - Reflection padding, which PyTorch has implemented in [torch.nn.ReflectionPad2d](https://pytorch.org/docs/stable/generated/torch.nn.ReflectionPad2d.html)\n", + "\n", + "**Goals**\n", + "\n", + "In this notebook, you will learn about Pix2PixHD, which synthesizes high-resolution images from semantic label maps. 
Proposed in [High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs](https://arxiv.org/abs/1711.11585) (Wang et al. 2018), Pix2PixHD improves upon Pix2Pix via multiscale architecture, improved adversarial loss, and instance maps." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fB1Vq8ps7Bfd" + }, + "source": [ + "## Residual Blocks\n", + "\n", + "The residual block, which is relevant in many state-of-the-art computer vision models, is used in all parts of Pix2PixHD. If you're not familiar with residual blocks, please take a look [here](https://paperswithcode.com/method/residual-block). Now, you'll start by first implementing a basic residual block." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "GHD_wif07f4b" + }, + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "class ResidualBlock(nn.Module):\n", + " '''\n", + " ResidualBlock Class\n", + " Values\n", + " channels: the number of channels throughout the residual block, a scalar\n", + " '''\n", + "\n", + " def __init__(self, channels):\n", + " super().__init__()\n", + "\n", + " self.layers = nn.Sequential(\n", + " nn.ReflectionPad2d(1),\n", + " nn.Conv2d(channels, channels, kernel_size=3, padding=0),\n", + " nn.InstanceNorm2d(channels, affine=False),\n", + "\n", + " nn.ReLU(inplace=True),\n", + "\n", + " nn.ReflectionPad2d(1),\n", + " nn.Conv2d(channels, channels, kernel_size=3, padding=0),\n", + " nn.InstanceNorm2d(channels, affine=False),\n", + " )\n", + "\n", + " def forward(self, x):\n", + " return x + self.layers(x)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Sqet5X_Sf-KZ" + }, + "source": [ + "## Multiscale Generator: Generating at multiple scales (resolutions)\n", + "\n", + "The Pix2PixHD generator is comprised of two separate subcomponent generators: $G_1$ is called the global generator and operates at low resolution (1024 x 
512) to transfer styles. $G_2$ is the local enhancer and operates at high resolution (2048 x 1024) to refine the output at full resolution.\n", + "\n", + "The architecture for each network is adapted from [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](https://arxiv.org/abs/1603.08155) (Johnson et al. 2016) and is comprised of\n", + "\n", + "\\begin{align*}\n", + " G = \\left[G^{(F)}, G^{(R)}, G^{(B)}\\right],\n", + "\\end{align*}\n", + "\n", + "where $G^{(F)}$ is a frontend of convolutional blocks (downsampling), $G^{(R)}$ is a set of residual blocks, and $G^{(B)}$ is a backend of transposed convolutional blocks (upsampling). This is just a type of encoder-decoder generator that you learned about with Pix2Pix!\n", + "\n", + "$G_1$ is trained first on low-resolution images. Then, $G_2$ is added to the pre-trained $G_1$ and both are trained jointly on high-resolution images. Specifically, $G_2^{(F)}$ encodes a high-resolution image, $G_1$ encodes a downsampled, low-resolution image, and the outputs from both are summed and passed sequentially to $G_2^{(R)}$ and $G_2^{(B)}$. This pre-training and fine-tuning scheme works well because learning high-fidelity representations directly is generally a hard task: the model first learns accurate coarse representations, then uses them to touch up its refined representations.\n", + "\n", + "> ![Pix2PixHD Generator](https://drive.google.com/uc?export=view&id=1HDGPKupDxD52JSgnuH9pANV7hhHF3BDm)\n", + "*Pix2PixHD Generator, taken from Figure 3 of [High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs](https://arxiv.org/abs/1711.11585) (Wang et al. 2018). 
Following our notation, $G = \\left[G_2^{(F)}, G_1^{(F)}, G_1^{(R)}, G_1^{(B)}, G_2^{(R)}, G_2^{(B)}\\right]$ from left to right.*" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9QzAlLDpuhZ5" + }, + "source": [ + "### Global Subgenerator ($G_1$)\n", + "\n", + "Let's first start by building the global generator ($G_1$). Even though the global generator is nested inside the local enhancer, you'll still need a separate module for training $G_1$ on its own first." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "rarPcc8xutN5" + }, + "source": [ + "class GlobalGenerator(nn.Module):\n", + " '''\n", + " GlobalGenerator Class:\n", + " Implements the global subgenerator (G1) for transferring styles at lower resolutions.\n", + " Values:\n", + " in_channels: the number of input channels, a scalar\n", + " out_channels: the number of output channels, a scalar\n", + " base_channels: the number of channels in first convolutional layer, a scalar\n", + " fb_blocks: the number of frontend / backend blocks, a scalar\n", + " res_blocks: the number of residual blocks, a scalar\n", + " '''\n", + "\n", + " def __init__(self, in_channels, out_channels,\n", + " base_channels=64, fb_blocks=3, res_blocks=9):\n", + " super().__init__()\n", + "\n", + " # Initial convolutional layer\n", + " g1 = [\n", + " nn.ReflectionPad2d(3),\n", + " nn.Conv2d(in_channels, base_channels, kernel_size=7, padding=0),\n", + " nn.InstanceNorm2d(base_channels, affine=False),\n", + " nn.ReLU(inplace=True),\n", + " ]\n", + "\n", + " channels = base_channels\n", + " # Frontend blocks\n", + " for _ in range(fb_blocks):\n", + " g1 += [\n", + " nn.Conv2d(channels, 2 * channels, kernel_size=3, stride=2, padding=1),\n", + " nn.InstanceNorm2d(2 * channels, affine=False),\n", + " nn.ReLU(inplace=True),\n", + " ]\n", + " channels *= 2\n", + "\n", + " # Residual blocks\n", + " for _ in range(res_blocks):\n", + " g1 += [ResidualBlock(channels)]\n", + "\n", + " # Backend blocks\n", + " for _ in 
range(fb_blocks):\n", + " g1 += [\n", + " nn.ConvTranspose2d(channels, channels // 2, kernel_size=3, stride=2, padding=1, output_padding=1),\n", + " nn.InstanceNorm2d(channels // 2, affine=False),\n", + " nn.ReLU(inplace=True),\n", + " ]\n", + " channels //= 2\n", + "\n", + " # Output convolutional layer as its own nn.Sequential since it will be omitted in second training phase\n", + " self.out_layers = nn.Sequential(\n", + " nn.ReflectionPad2d(3),\n", + " nn.Conv2d(base_channels, out_channels, kernel_size=7, padding=0),\n", + " nn.Tanh(),\n", + " )\n", + "\n", + " self.g1 = nn.Sequential(*g1)\n", + "\n", + " def forward(self, x):\n", + " x = self.g1(x)\n", + " x = self.out_layers(x)\n", + " return x" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "esIlNlgE7s3h" + }, + "source": [ + "### Local Enhancer Subgenerator ($G_2$)\n", + "\n", + "And now onto the local enhancer ($G_2$)! Recall that the local enhancer uses (a pretrained) $G_1$ as part of its architecture. Following our earlier notation, the outputs of the last layers of $G_2^{(F)}$ and $G_1^{(B)}$ are summed and passed through $G_2^{(R)}$ and $G_2^{(B)}$ to synthesize a high-resolution image. Because of this, you should reuse the $G_1$ implementation so that the weights are consistent for the second training phase." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "lLkFsvt28T9Z" + }, + "source": [ + "class LocalEnhancer(nn.Module):\n", + " '''\n", + " LocalEnhancer Class: \n", + " Implements the local enhancer subgenerator (G2) for handling larger scale images.\n", + " Values:\n", + " in_channels: the number of input channels, a scalar\n", + " out_channels: the number of output channels, a scalar\n", + " base_channels: the number of channels in first convolutional layer, a scalar\n", + " global_fb_blocks: the number of global generator frontend / backend blocks, a scalar\n", + " global_res_blocks: the number of global generator residual blocks, a scalar\n", + " local_res_blocks: the number of local enhancer residual blocks, a scalar\n", + " '''\n", + "\n", + " def __init__(self, in_channels, out_channels, base_channels=32, global_fb_blocks=3, global_res_blocks=9, local_res_blocks=3):\n", + " super().__init__()\n", + "\n", + " global_base_channels = 2 * base_channels\n", + "\n", + " # Downsampling layer for high-res -> low-res input to g1\n", + " self.downsample = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)\n", + "\n", + " # Initialize global generator without its output layers\n", + " self.g1 = GlobalGenerator(\n", + " in_channels, out_channels, base_channels=global_base_channels, fb_blocks=global_fb_blocks, res_blocks=global_res_blocks,\n", + " ).g1\n", + "\n", + " self.g2 = nn.ModuleList()\n", + "\n", + " # Initialize local frontend block\n", + " self.g2.append(\n", + " nn.Sequential(\n", + " # Initial convolutional layer\n", + " nn.ReflectionPad2d(3),\n", + " nn.Conv2d(in_channels, base_channels, kernel_size=7, padding=0), \n", + " nn.InstanceNorm2d(base_channels, affine=False),\n", + " nn.ReLU(inplace=True),\n", + "\n", + " # Frontend block\n", + " nn.Conv2d(base_channels, 2 * base_channels, kernel_size=3, stride=2, padding=1), \n", + " nn.InstanceNorm2d(2 * base_channels, affine=False),\n", + " nn.ReLU(inplace=True),\n", + " )\n", + " )\n", + 
"\n", + " # Initialize local residual and backend blocks\n", + " self.g2.append(\n", + " nn.Sequential(\n", + " # Residual blocks\n", + " *[ResidualBlock(2 * base_channels) for _ in range(local_res_blocks)],\n", + "\n", + " # Backend blocks\n", + " nn.ConvTranspose2d(2 * base_channels, base_channels, kernel_size=3, stride=2, padding=1, output_padding=1), \n", + " nn.InstanceNorm2d(base_channels, affine=False),\n", + " nn.ReLU(inplace=True),\n", + "\n", + " # Output convolutional layer\n", + " nn.ReflectionPad2d(3),\n", + " nn.Conv2d(base_channels, out_channels, kernel_size=7, padding=0),\n", + " nn.Tanh(),\n", + " )\n", + " )\n", + "\n", + " def forward(self, x):\n", + " # Get output from g1_B\n", + " x_g1 = self.downsample(x)\n", + " x_g1 = self.g1(x_g1)\n", + "\n", + " # Get output from g2_F\n", + " x_g2 = self.g2[0](x)\n", + "\n", + " # Get final output from g2_B\n", + " return self.g2[1](x_g1 + x_g2)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "r6g5OLPHB-e0" + }, + "source": [ + "And voilà! You now have modules for both the global subgenerator and local enhancer subgenerator!" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8HoyIbXajoMC" + }, + "source": [ + "## Multiscale Discriminator: Discriminating at different scales too!\n", + "\n", + "Pix2PixHD uses 3 separate subcomponents (subdiscriminators $D_1$, $D_2$, and $D_3$) to generate predictions. They all have the same architectures but $D_2$ and $D_3$ operate on inputs downsampled by 2x and 4x, respectively. 
The GAN objective is now modified as\n", + "\n", + "\\begin{align*}\n", + " \\min_G \\max_{D_1,D_2,D_3}\\sum_{k=1,2,3}\\mathcal{L}_{\\text{GAN}}(G, D_k)\n", + "\\end{align*}\n", + "\n", + "Each subdiscriminator is a PatchGAN, which you should be familiar with from Pix2Pix!\n", + "\n", + "Let's first implement a single PatchGAN - this implementation will be slightly different than the one you saw in Pix2Pix since the intermediate feature maps will be needed for computing loss." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "fYgX2B_hDxkA" + }, + "source": [ + "class Discriminator(nn.Module):\n", + " '''\n", + " Discriminator Class\n", + " Implements the discriminator class for a subdiscriminator, \n", + " which can be used for all the different scales, just with different argument values.\n", + " Values:\n", + " in_channels: the number of channels in input, a scalar\n", + " base_channels: the number of channels in first convolutional layer, a scalar\n", + " n_layers: the number of convolutional layers, a scalar\n", + " '''\n", + "\n", + " def __init__(self, in_channels, base_channels=64, n_layers=3):\n", + " super().__init__()\n", + "\n", + " # Use nn.ModuleList so we can output intermediate values for loss.\n", + " self.layers = nn.ModuleList()\n", + "\n", + " # Initial convolutional layer\n", + " self.layers.append(\n", + " nn.Sequential(\n", + " nn.Conv2d(in_channels, base_channels, kernel_size=4, stride=2, padding=2),\n", + " nn.LeakyReLU(0.2, inplace=True),\n", + " )\n", + " )\n", + "\n", + " # Downsampling convolutional layers\n", + " channels = base_channels\n", + " for _ in range(1, n_layers):\n", + " prev_channels = channels\n", + " channels = min(2 * channels, 512)\n", + " self.layers.append(\n", + " nn.Sequential(\n", + " nn.Conv2d(prev_channels, channels, kernel_size=4, stride=2, padding=2),\n", + " nn.InstanceNorm2d(channels, affine=False),\n", + " nn.LeakyReLU(0.2, inplace=True),\n", + " )\n", + " )\n", + "\n", + " # Output convolutional 
layer\n", + " prev_channels = channels\n", + " channels = min(2 * channels, 512)\n", + " self.layers.append(\n", + " nn.Sequential(\n", + " nn.Conv2d(prev_channels, channels, kernel_size=4, stride=1, padding=2),\n", + " nn.InstanceNorm2d(channels, affine=False),\n", + " nn.LeakyReLU(0.2, inplace=True),\n", + " nn.Conv2d(channels, 1, kernel_size=4, stride=1, padding=2),\n", + " )\n", + " )\n", + "\n", + " def forward(self, x):\n", + " outputs = [] # for feature matching loss\n", + " for layer in self.layers:\n", + " x = layer(x)\n", + " outputs.append(x)\n", + "\n", + " return outputs" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZchSdgY1Jrd5" + }, + "source": [ + "Now you're ready to implement the multiscale discriminator in full! This puts together the different subdiscriminator scales." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "3_kqcsh4Jwjz" + }, + "source": [ + "class MultiscaleDiscriminator(nn.Module):\n", + " '''\n", + " MultiscaleDiscriminator Class\n", + " Values:\n", + " in_channels: number of input channels to each discriminator, a scalar\n", + " base_channels: number of channels in first convolutional layer, a scalar\n", + " n_layers: number of downsampling layers in each discriminator, a scalar\n", + " n_discriminators: number of discriminators at different scales, a scalar\n", + " '''\n", + "\n", + " def __init__(self, in_channels, base_channels=64, n_layers=3, n_discriminators=3):\n", + " super().__init__()\n", + "\n", + " # Initialize all discriminators\n", + " self.discriminators = nn.ModuleList()\n", + " for _ in range(n_discriminators):\n", + " self.discriminators.append(\n", + " Discriminator(in_channels, base_channels=base_channels, n_layers=n_layers)\n", + " )\n", + "\n", + " # Downsampling layer to pass inputs between discriminators at different scales\n", + " self.downsample = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)\n", + "\n", + " def 
forward(self, x):\n", + " outputs = []\n", + "\n", + " for i, discriminator in enumerate(self.discriminators):\n", + " # Downsample input for subsequent discriminators\n", + " if i != 0:\n", + " x = self.downsample(x)\n", + "\n", + " outputs.append(discriminator(x))\n", + "\n", + " # Return list of multiscale discriminator outputs\n", + " return outputs\n", + "\n", + " @property\n", + " def n_discriminators(self):\n", + " return len(self.discriminators)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "oSFpoKLlLdpb" + }, + "source": [ + "## Instance Boundary Map: Learning boundaries between instances\n", + "\n", + "Here's a new method that adds additional information as conditional input!\n", + "\n", + "The authors observed that previous approaches have typically taken in a label map (a.k.a. segmentation map) that labels all the pixels to be of a certain class (e.g. car) but doesn't differentiate between two instances of the same class (e.g. two cars in the image). This is the difference between *semantic label maps*, which have class labels but not instance labels, and *instance label maps*, which represent unique instances with unique numbers.\n", + "\n", + "The authors found that the most important information in the instance label map is actually the boundaries between instances (e.g. the outline of each car). You can create boundary maps by mapping each pixel to 1 if its instance differs from that of any of its 4 neighbors, and to 0 otherwise.\n", + "\n", + "To include this information, the authors concatenate the boundary map with the semantic label map as input. 
From the figure below, you can see that including both as input results in much sharper generated images (right) than only inputting the semantic label map (left).\n", + "\n", + "> ![Semantic label map input vs instance boundary map input](https://drive.google.com/uc?export=view&id=18J9HN-_TJMYRHPWhygAbc7EVgxpgrl8H)\n", + "![Semantic label map vs instance boundary map](https://drive.google.com/uc?export=view&id=13_lT1DPUxEwWWyjf-aXGzom_D4BFze2E)\n", + "*Semantic label map input (top left) and its blurry output between instances (bottom left) vs. instance boundary map (top right) and the much clearer output between instances from inputting both the semantic label map and the instance boundary map (bottom right). Taken from Figures 4 and 5 of [High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs](https://arxiv.org/abs/1711.11585) (Wang et al. 2018).*" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "JzRC73GwEjY-" + }, + "source": [ + "## Instance-level Feature Encoder: Adding controllable diversity\n", + "\n", + "As you already know, the task of generation has more than one possible realistic output. For example, an object of class `road` could be concrete, cobblestone, dirt, etc. To learn this diversity, the authors introduce an encoder $E$, which takes the original image as input and outputs a feature map (like the feature extractor from Course 2, Week 1). They apply *instance-wise averaging*, averaging the feature vectors across all occurrences of each instance (so that every pixel corresponding to the same instance has the same feature vector). They then concatenate this instance-level feature embedding with the semantic label and instance boundary maps as input to the generator.\n", + "\n", + "What's cool is that the encoder $E$ is trained jointly with $G_1$. One huge backprop! 
When training $G_2$, $E$ is fed a downsampled image and the corresponding output is upsampled to pass into $G_2$.\n", + "\n", + "To allow for control over different features (e.g. concrete, cobblestone, and dirt) for inference, the authors first use K-means clustering to cluster all the feature vectors for each object class in the training set. You can think of this as a dictionary, mapping each class label to a set of feature vectors (so $K$ centroids, each representing different clusters of features). Now during inference, you can perform a random lookup from this dictionary for each class (e.g. road) in the semantic label map to generate one type of feature (e.g. dirt). To provide greater control, you can select among different feature types for each class to generate diverse feature types and, as a result, multi-modal outputs from the same input. \n", + "\n", + "Higher values of $K$ increase diversity and potentially decrease fidelity. You've seen this tradeoff between diversity and fidelity before with the truncation trick, and this is just another way to trade off between them.\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "dtxfYI8zR5MF" + }, + "source": [ + "class Encoder(nn.Module):\n", + " '''\n", + " Encoder Class\n", + " Values:\n", + " in_channels: number of input channels to the encoder, a scalar\n", + " out_channels: number of channels in output feature map, a scalar\n", + " base_channels: number of channels in first convolutional layer, a scalar\n", + " n_layers: number of downsampling layers, a scalar\n", + " '''\n", + "\n", + " def __init__(self, in_channels, out_channels, base_channels=16, n_layers=4):\n", + " super().__init__()\n", + "\n", + " self.out_channels = out_channels\n", + " channels = base_channels\n", + "\n", + " layers = [\n", + " nn.ReflectionPad2d(3),\n", + " nn.Conv2d(in_channels, base_channels, kernel_size=7, padding=0), \n", + " nn.InstanceNorm2d(base_channels),\n", + " nn.ReLU(inplace=True),\n", + " ]\n", 
+ "\n", + " # Downsampling layers\n", + " for i in range(n_layers):\n", + " layers += [\n", + " nn.Conv2d(channels, 2 * channels, kernel_size=3, stride=2, padding=1),\n", + " nn.InstanceNorm2d(2 * channels),\n", + " nn.ReLU(inplace=True),\n", + " ]\n", + " channels *= 2\n", + " \n", + " # Upsampling layers\n", + " for i in range(n_layers):\n", + " layers += [\n", + " nn.ConvTranspose2d(channels, channels // 2, kernel_size=3, stride=2, padding=1, output_padding=1),\n", + " nn.InstanceNorm2d(channels // 2),\n", + " nn.ReLU(inplace=True),\n", + " ]\n", + " channels //= 2\n", + "\n", + " layers += [\n", + " nn.ReflectionPad2d(3),\n", + " nn.Conv2d(base_channels, out_channels, kernel_size=7, padding=0),\n", + " nn.Tanh(),\n", + " ]\n", + "\n", + " self.layers = nn.Sequential(*layers)\n", + "\n", + " def instancewise_average_pooling(self, x, inst):\n", + " '''\n", + " Applies instance-wise average pooling.\n", + "\n", + " Given a feature map of size (b, c, h, w), the mean is computed for each b, c\n", + " across all h, w of the same instance\n", + " '''\n", + " x_mean = torch.zeros_like(x)\n", + " classes = torch.unique(inst, return_inverse=False, return_counts=False) # gather all unique classes present\n", + "\n", + " for i in classes:\n", + " for b in range(x.size(0)):\n", + " indices = torch.nonzero(inst[b:b+1] == i, as_tuple=False) # get indices of all positions equal to class i\n", + " for j in range(self.out_channels):\n", + " x_ins = x[indices[:, 0] + b, indices[:, 1] + j, indices[:, 2], indices[:, 3]]\n", + " mean_feat = torch.mean(x_ins).expand_as(x_ins)\n", + " x_mean[indices[:, 0] + b, indices[:, 1] + j, indices[:, 2], indices[:, 3]] = mean_feat\n", + "\n", + " return x_mean\n", + "\n", + " def forward(self, x, inst):\n", + " x = self.layers(x)\n", + " x = self.instancewise_average_pooling(x, inst)\n", + " return x" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eQHD_S0pYx7Z" + }, + "source": [ + 
"## Additional Loss Functions\n", + "\n", + "In addition to the architectural and feature-map enhancements, the authors also incorporate a feature matching loss based on the discriminator. Essentially, they output intermediate feature maps at different resolutions from the discriminator and try to minimize the difference between the real and fake image features.\n", + "\n", + "The authors found this to stabilize training. In this case, this forces the generator to produce natural statistics at multiple scales. This feature-matching loss is similar to StyleGAN's perceptual loss. For some semantic label map $s$ and corresponding image $x$,\n", + "\n", + "\\begin{align*}\n", + " \\mathcal{L}_{\\text{FM}} = \\mathbb{E}_{s,x}\\left[\\sum_{i=1}^T\\dfrac{1}{N_i}\\left|\\left|D^{(i)}_k(s, x) - D^{(i)}_k(s, G(s))\\right|\\right|_1\\right]\n", + "\\end{align*}\n", + "\n", + "where $T$ is the total number of layers, $N_i$ is the number of elements at layer $i$, and $D^{(i)}_k$ denotes the $i$th layer in discriminator $k$.\n", + "\n", + "The authors also report minor improvements in performance when adding perceptual loss, formulated as\n", + "\n", + "\\begin{align*}\n", + " \\mathcal{L}_{\\text{VGG}} = \\mathbb{E}_{s,x}\\left[\\sum_{i=1}^N\\dfrac{1}{M_i}\\left|\\left|F^i(x) - F^i(G(s))\\right|\\right|_1\\right]\n", + "\\end{align*}\n", + "\n", + "where $F^i$ denotes the $i$th layer with $M_i$ elements of the VGG19 network. `torchvision` provides a pretrained VGG19 network, so you'll just need a simple wrapper for it to get the intermediate outputs.\n", + "\n", + "The overall loss looks like this:\n", + "\n", + "\\begin{align*}\n", + " \\mathcal{L} = \\mathcal{L}_{\\text{GAN}} + \\lambda_1\\mathcal{L}_{\\text{FM}} + \\lambda_2\\mathcal{L}_{\\text{VGG}}\n", + "\\end{align*}\n", + "\n", + "where $\\lambda_1 = \\lambda_2 = 10$." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "QCXhrrsudbB2" + }, + "source": [ + "import torchvision.models as models\n", + "\n", + "class VGG19(nn.Module):\n", + " '''\n", + " VGG19 Class\n", + " Wrapper for pretrained torchvision.models.vgg19 to output intermediate feature maps\n", + " '''\n", + "\n", + " def __init__(self):\n", + " super().__init__()\n", + " vgg_features = models.vgg19(pretrained=True).features\n", + "\n", + " self.f1 = nn.Sequential(*[vgg_features[x] for x in range(2)])\n", + " self.f2 = nn.Sequential(*[vgg_features[x] for x in range(2, 7)])\n", + " self.f3 = nn.Sequential(*[vgg_features[x] for x in range(7, 12)])\n", + " self.f4 = nn.Sequential(*[vgg_features[x] for x in range(12, 21)])\n", + " self.f5 = nn.Sequential(*[vgg_features[x] for x in range(21, 30)])\n", + "\n", + " for param in self.parameters():\n", + " param.requires_grad = False\n", + "\n", + " def forward(self, x):\n", + " h1 = self.f1(x)\n", + " h2 = self.f2(h1)\n", + " h3 = self.f3(h2)\n", + " h4 = self.f4(h3)\n", + " h5 = self.f5(h4)\n", + " return [h1, h2, h3, h4, h5]\n", + "\n", + "class Loss(nn.Module):\n", + " '''\n", + " Loss Class\n", + " Implements composite loss for Pix2PixHD\n", + " Values:\n", + " lambda1: weight for feature matching loss, a float\n", + " lambda2: weight for vgg perceptual loss, a float\n", + " device: 'cuda' or 'cpu' for hardware to use\n", + " norm_weight_to_one: whether to normalize weights to (0, 1], a bool\n", + " '''\n", + "\n", + " def __init__(self, lambda1=10., lambda2=10., device='cuda', norm_weight_to_one=True):\n", + " super().__init__()\n", + " self.vgg = VGG19().to(device)\n", + " self.vgg_weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]\n", + "\n", + " lambda0 = 1.0\n", + " # Keep ratio of composite loss, but scale down max to 1.0\n", + " scale = max(lambda0, lambda1, lambda2) if norm_weight_to_one else 1.0\n", + "\n", + " self.lambda0 = lambda0 / scale\n", + " self.lambda1 = lambda1 / scale\n", + " self.lambda2 = lambda2 / 
scale\n", + "\n", + " def adv_loss(self, discriminator_preds, is_real):\n", + " '''\n", + " Computes adversarial loss from nested list of fakes outputs from discriminator.\n", + " '''\n", + " target = torch.ones_like if is_real else torch.zeros_like\n", + "\n", + " adv_loss = 0.0\n", + " for preds in discriminator_preds:\n", + " pred = preds[-1]\n", + " adv_loss += F.mse_loss(pred, target(pred))\n", + " return adv_loss\n", + "\n", + " def fm_loss(self, real_preds, fake_preds):\n", + " '''\n", + " Computes feature matching loss from nested lists of fake and real outputs from discriminator.\n", + " '''\n", + " fm_loss = 0.0\n", + " for real_features, fake_features in zip(real_preds, fake_preds):\n", + " for real_feature, fake_feature in zip(real_features, fake_features):\n", + " fm_loss += F.l1_loss(real_feature.detach(), fake_feature)\n", + " return fm_loss\n", + "\n", + " def vgg_loss(self, x_real, x_fake):\n", + " '''\n", + " Computes perceptual loss with VGG network from real and fake images.\n", + " '''\n", + " vgg_real = self.vgg(x_real)\n", + " vgg_fake = self.vgg(x_fake)\n", + "\n", + " vgg_loss = 0.0\n", + " for real, fake, weight in zip(vgg_real, vgg_fake, self.vgg_weights):\n", + " vgg_loss += weight * F.l1_loss(real.detach(), fake)\n", + " return vgg_loss\n", + "\n", + " def forward(self, x_real, label_map, instance_map, boundary_map, encoder, generator, discriminator):\n", + " '''\n", + " Function that computes the forward pass and total loss for generator and discriminator.\n", + " '''\n", + " feature_map = encoder(x_real, instance_map)\n", + " x_fake = generator(torch.cat((label_map, boundary_map, feature_map), dim=1))\n", + "\n", + " # Get necessary outputs for loss/backprop for both generator and discriminator\n", + " fake_preds_for_g = discriminator(torch.cat((label_map, boundary_map, x_fake), dim=1))\n", + " fake_preds_for_d = discriminator(torch.cat((label_map, boundary_map, x_fake.detach()), dim=1))\n", + " real_preds_for_d = 
discriminator(torch.cat((label_map, boundary_map, x_real.detach()), dim=1))\n", + "\n", + " g_loss = (\n", + " self.lambda0 * self.adv_loss(fake_preds_for_g, True) + \\\n", + " self.lambda1 * self.fm_loss(real_preds_for_d, fake_preds_for_g) / discriminator.n_discriminators + \\\n", + " self.lambda2 * self.vgg_loss(x_real, x_fake)\n", + " )\n", + " d_loss = 0.5 * (\n", + " self.adv_loss(real_preds_for_d, True) + \\\n", + " self.adv_loss(fake_preds_for_d, False)\n", + " )\n", + "\n", + " return g_loss, d_loss, x_fake.detach()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EcyavR5oLmau" + }, + "source": [ + "## Training Pix2PixHD\n", + "\n", + "You now have the Pix2PixHD model coded up! All you have to do now is prepare your dataset. Pix2PixHD is trained on the Cityscapes dataset, which unfortunately requires registration. You'll have to download the dataset and put it in your `data` folder to initialize the dataset code below.\n", + "\n", + "Specifically, you should download `gtFine_trainvaltest` and `leftImg8bit_trainvaltest` and pass the corresponding data splits to the dataloader." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "mII5L2cZLlpO" + }, + "source": [ + "import os\n", + "\n", + "import numpy as np\n", + "import torchvision.transforms as transforms\n", + "from PIL import Image\n", + "\n", + "def scale_width(img, target_width, method):\n", + " '''\n", + " Function that scales an image to target_width while retaining aspect ratio.\n", + " '''\n", + " w, h = img.size\n", + " if w == target_width: return img\n", + " target_height = target_width * h // w\n", + " return img.resize((target_width, target_height), method)\n", + "\n", + "class CityscapesDataset(torch.utils.data.Dataset):\n", + " '''\n", + " CityscapesDataset Class\n", + " Values:\n", + " paths: (a list of) paths to load examples from, a list or string\n", + " target_width: the size of image widths for resizing, a scalar\n", + " n_classes: the number of object classes, a scalar\n", + " '''\n", + "\n", + " def __init__(self, paths, target_width=1024, n_classes=35):\n", + " super().__init__()\n", + "\n", + " self.n_classes = n_classes\n", + "\n", + " # Collect list of examples\n", + " self.examples = {}\n", + " if type(paths) == str:\n", + " self.load_examples_from_dir(paths)\n", + " elif type(paths) == list:\n", + " for path in paths:\n", + " self.load_examples_from_dir(path)\n", + " else:\n", + " raise ValueError('`paths` should be a single path or list of paths')\n", + "\n", + " self.examples = list(self.examples.values())\n", + " assert all(len(example) == 3 for example in self.examples)\n", + "\n", + " # Initialize transforms for the real color image\n", + " self.img_transforms = transforms.Compose([\n", + " transforms.Lambda(lambda img: scale_width(img, target_width, Image.BICUBIC)),\n", + " transforms.Lambda(lambda img: np.array(img)),\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n", + " ])\n", + "\n", + " # Initialize transforms for semantic label and instance maps\n", + " self.map_transforms = 
transforms.Compose([\n", + " transforms.Lambda(lambda img: scale_width(img, target_width, Image.NEAREST)),\n", + " transforms.Lambda(lambda img: np.array(img)),\n", + " transforms.ToTensor(),\n", + " ])\n", + "\n", + " def load_examples_from_dir(self, abs_path):\n", + " '''\n", + " Given a folder of examples, this function collects paired examples\n", + " (color image, label map, instance map) into self.examples, keyed by filename prefix.\n", + " '''\n", + " assert os.path.isdir(abs_path)\n", + "\n", + " img_suffix = '_leftImg8bit.png'\n", + " label_suffix = '_gtFine_labelIds.png'\n", + " inst_suffix = '_gtFine_instanceIds.png'\n", + "\n", + " for root, _, files in os.walk(abs_path):\n", + " for f in files:\n", + " if f.endswith(img_suffix):\n", + " prefix = f[:-len(img_suffix)]\n", + " attr = 'orig_img'\n", + " elif f.endswith(label_suffix):\n", + " prefix = f[:-len(label_suffix)]\n", + " attr = 'label_map'\n", + " elif f.endswith(inst_suffix):\n", + " prefix = f[:-len(inst_suffix)]\n", + " attr = 'inst_map'\n", + " else:\n", + " continue\n", + "\n", + " if prefix not in self.examples:\n", + " self.examples[prefix] = {}\n", + " self.examples[prefix][attr] = os.path.join(root, f)\n", + "\n", + " def __getitem__(self, idx):\n", + " example = self.examples[idx]\n", + "\n", + " # Load image and maps\n", + " img = Image.open(example['orig_img']).convert('RGB') # color image: (3, 512, 1024)\n", + " inst = Image.open(example['inst_map']) # instance map: (512, 1024)\n", + " label = Image.open(example['label_map']) # semantic label map: (512, 1024)\n", + "\n", + " # Apply corresponding transforms\n", + " img = self.img_transforms(img)\n", + " inst = self.map_transforms(inst)\n", + " # ToTensor scales the uint8 label map to [0, 1], so rescale to class ids before casting to integers\n", + " label = (self.map_transforms(label) * 255).round().long()\n", + "\n", + " # Convert labels to one-hot vectors\n", + " label = torch.zeros(self.n_classes, img.shape[1], img.shape[2]).scatter_(0, label, 1.0).to(img.dtype)\n", + "\n", + " # Convert instance map to instance boundary map\n", + " bound = torch.ByteTensor(inst.shape).zero_()\n", + " bound[:, :, 1:] = bound[:, :, 1:] | 
(inst[:, :, 1:] != inst[:, :, :-1])\n", + " bound[:, :, :-1] = bound[:, :, :-1] | (inst[:, :, 1:] != inst[:, :, :-1])\n", + " bound[:, 1:, :] = bound[:, 1:, :] | (inst[:, 1:, :] != inst[:, :-1, :])\n", + " bound[:, :-1, :] = bound[:, :-1, :] | (inst[:, 1:, :] != inst[:, :-1, :])\n", + " bound = bound.to(img.dtype)\n", + "\n", + " return (img, label, inst, bound)\n", + "\n", + " def __len__(self):\n", + " return len(self.examples)\n", + "\n", + " @staticmethod\n", + " def collate_fn(batch):\n", + " imgs, labels, insts, bounds = [], [], [], []\n", + " for (x, l, i, b) in batch:\n", + " imgs.append(x)\n", + " labels.append(l)\n", + " insts.append(i)\n", + " bounds.append(b)\n", + " return (\n", + " torch.stack(imgs, dim=0),\n", + " torch.stack(labels, dim=0),\n", + " torch.stack(insts, dim=0),\n", + " torch.stack(bounds, dim=0),\n", + " )" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "51StN80TCpgE" + }, + "source": [ + "Now initialize everything you'll need for training. Don't worry if this looks like a lot of unrelated code; it's all stuff you've seen before!" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "KG1ZIAI8DgJX" + }, + "source": [ + "from tqdm import tqdm\n", + "from torch.utils.data import DataLoader\n", + "\n", + "n_classes = 35 # total number of object classes\n", + "rgb_channels = n_features = 3\n", + "device = 'cuda'\n", + "train_dir = ['data']\n", + "epochs = 200 # total number of train epochs\n", + "decay_after = 100 # number of epochs with constant lr\n", + "lr = 0.0002\n", + "betas = (0.5, 0.999)\n", + "\n", + "def lr_lambda(epoch):\n", + " ''' Function for scheduling the learning rate: constant for decay_after epochs, then linear decay to zero '''\n", + " return 1. 
if epoch < decay_after else 1 - float(epoch - decay_after) / (epochs - decay_after)\n", + "\n", + "def weights_init(m):\n", + " ''' Function for initializing all model weights '''\n", + " if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):\n", + " nn.init.normal_(m.weight, 0., 0.02)\n", + "\n", + "loss_fn = Loss(device=device)\n", + "\n", + "## Phase 1: Low Resolution (1024 x 512)\n", + "dataloader1 = DataLoader(\n", + " CityscapesDataset(train_dir, target_width=1024, n_classes=n_classes),\n", + " collate_fn=CityscapesDataset.collate_fn, batch_size=1, shuffle=True, drop_last=False, pin_memory=True,\n", + ")\n", + "encoder = Encoder(rgb_channels, n_features).to(device).apply(weights_init)\n", + "generator1 = GlobalGenerator(n_classes + n_features + 1, rgb_channels).to(device).apply(weights_init)\n", + "discriminator1 = MultiscaleDiscriminator(n_classes + 1 + rgb_channels, n_discriminators=2).to(device).apply(weights_init)\n", + "\n", + "g1_optimizer = torch.optim.Adam(list(generator1.parameters()) + list(encoder.parameters()), lr=lr, betas=betas)\n", + "d1_optimizer = torch.optim.Adam(list(discriminator1.parameters()), lr=lr, betas=betas)\n", + "g1_scheduler = torch.optim.lr_scheduler.LambdaLR(g1_optimizer, lr_lambda)\n", + "d1_scheduler = torch.optim.lr_scheduler.LambdaLR(d1_optimizer, lr_lambda)\n", + "\n", + "\n", + "## Phase 2: High Resolution (2048 x 1024)\n", + "dataloader2 = DataLoader(\n", + " CityscapesDataset(train_dir, target_width=2048, n_classes=n_classes),\n", + " collate_fn=CityscapesDataset.collate_fn, batch_size=1, shuffle=True, drop_last=False, pin_memory=True,\n", + ")\n", + "generator2 = LocalEnhancer(n_classes + n_features + 1, rgb_channels).to(device).apply(weights_init)\n", + "discriminator2 = MultiscaleDiscriminator(n_classes + 1 + rgb_channels).to(device).apply(weights_init)\n", + "\n", + "g2_optimizer = torch.optim.Adam(list(generator2.parameters()) + list(encoder.parameters()), lr=lr, betas=betas)\n", + "d2_optimizer = 
torch.optim.Adam(list(discriminator2.parameters()), lr=lr, betas=betas)\n", + "g2_scheduler = torch.optim.lr_scheduler.LambdaLR(g2_optimizer, lr_lambda)\n", + "d2_scheduler = torch.optim.lr_scheduler.LambdaLR(d2_optimizer, lr_lambda)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VwyO_padKxW5" + }, + "source": [ + "And now the training loop, which is essentially the same for both phases:" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "w-2lU4TnK5ik" + }, + "source": [ + "from torchvision.utils import make_grid\n", + "import matplotlib.pyplot as plt\n", + "\n", + "# Parse torch version for autocast\n", + "# ######################################################\n", + "version = torch.__version__\n", + "version = tuple(int(n) for n in version.split('.')[:-1])\n", + "has_autocast = version >= (1, 6)\n", + "# ######################################################\n", + "\n", + "def show_tensor_images(image_tensor):\n", + " '''\n", + " Function for visualizing images: Given a batch of image tensors with values in [-1, 1],\n", + " rescales them to [0, 1] and plots the first image of the batch.\n", + " '''\n", + " image_tensor = (image_tensor + 1) / 2\n", + " image_unflat = image_tensor.detach().cpu()\n", + " image_grid = make_grid(image_unflat[:1], nrow=1)\n", + " plt.imshow(image_grid.permute(1, 2, 0).squeeze())\n", + " plt.show()\n", + "\n", + "def train(dataloader, models, optimizers, schedulers, device):\n", + " encoder, generator, discriminator = models\n", + " g_optimizer, d_optimizer = optimizers\n", + " g_scheduler, d_scheduler = schedulers\n", + "\n", + " cur_step = 0\n", + " display_step = 100\n", + "\n", + " mean_g_loss = 0.0\n", + " mean_d_loss = 0.0\n", + "\n", + " for epoch in range(epochs):\n", + " # Training epoch\n", + " for (x_real, labels, insts, bounds) in tqdm(dataloader, position=0):\n", + " x_real = x_real.to(device)\n", + " labels = labels.to(device)\n", + " insts = 
insts.to(device)\n", + " bounds = bounds.to(device)\n", + "\n", + " # Enable autocast to FP16 tensors (new feature since torch==1.6.0)\n", + " # If you're running older versions of torch, comment this out\n", + " # and use NVIDIA apex for mixed/half precision training\n", + " if has_autocast:\n", + " with torch.cuda.amp.autocast(enabled=(device=='cuda')):\n", + " g_loss, d_loss, x_fake = loss_fn(\n", + " x_real, labels, insts, bounds, encoder, generator, discriminator\n", + " )\n", + " else:\n", + " g_loss, d_loss, x_fake = loss_fn(\n", + " x_real, labels, insts, bounds, encoder, generator, discriminator\n", + " )\n", + "\n", + " g_optimizer.zero_grad()\n", + " g_loss.backward()\n", + " g_optimizer.step()\n", + "\n", + " d_optimizer.zero_grad()\n", + " d_loss.backward()\n", + " d_optimizer.step()\n", + "\n", + " mean_g_loss += g_loss.item() / display_step\n", + " mean_d_loss += d_loss.item() / display_step\n", + "\n", + " if cur_step % display_step == 0 and cur_step > 0:\n", + " print('Step {}: Generator loss: {:.5f}, Discriminator loss: {:.5f}'\n", + " .format(cur_step, mean_g_loss, mean_d_loss))\n", + " show_tensor_images(x_fake.to(x_real.dtype))\n", + " show_tensor_images(x_real)\n", + " mean_g_loss = 0.0\n", + " mean_d_loss = 0.0\n", + " cur_step += 1\n", + "\n", + " # Update the learning rates once per epoch\n", + " g_scheduler.step()\n", + " d_scheduler.step()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wCcV0ttk6_Mv" + }, + "source": [ + "And now you can train your models! Remember to replace the global subgenerator inside the local enhancer with the one you trained in the first phase.\n", + "\n", + "In their official repository, the authors don't continue to train the encoder. Instead, they precompute all feature maps, upsample them, and concatenate the result to the input of the local enhancer subgenerator (they also leave an option to retrain the encoder). For simplicity, the script below just downsamples the high-resolution inputs, encodes them, and upsamples the resulting feature maps."
 + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "1-41crv17Pcn" + }, + "source": [ + "# Phase 1: Low Resolution\n", + "#######################################################################\n", + "train(\n", + " dataloader1,\n", + " [encoder, generator1, discriminator1],\n", + " [g1_optimizer, d1_optimizer],\n", + " [g1_scheduler, d1_scheduler],\n", + " device,\n", + ")\n", + "\n", + "\n", + "# Phase 2: High Resolution\n", + "#######################################################################\n", + "# Replace the global generator in the local enhancer with the trained one\n", + "generator2.g1 = generator1.g1\n", + "\n", + "# Freeze encoder and wrap to support high-resolution inputs/outputs\n", + "def freeze(encoder):\n", + " encoder.eval()\n", + " for p in encoder.parameters():\n", + " p.requires_grad = False\n", + "\n", + " # Plain Python closure: a scripted function cannot call the captured encoder module\n", + " def forward(x, inst):\n", + " x = F.interpolate(x, scale_factor=0.5, recompute_scale_factor=True)\n", + " inst = F.interpolate(inst.float(), scale_factor=0.5, recompute_scale_factor=True)\n", + " feat = encoder(x, inst.int())\n", + " return F.interpolate(feat, scale_factor=2.0, recompute_scale_factor=True)\n", + " return forward\n", + "\n", + "train(\n", + " dataloader2,\n", + " [freeze(encoder), generator2, discriminator2],\n", + " [g2_optimizer, d2_optimizer],\n", + " [g2_scheduler, d2_scheduler],\n", + " device,\n", + ")" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "37egKCs2YqEO" + }, + "source": [ + "## Inference with Pix2PixHD\n", + "\n", + "Recall that at inference time, the encoder feature maps from training are saved and clustered with K-means by object class. Again, you'll have to download the Cityscapes dataset into your `data` folder and then run these functions."
 + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "69wfl6vre_cM", + "outputId": "cb08aea5-a28a-421b-e5dd-6441d8ee04c7", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 147 + } + }, + "source": [ + "from sklearn.cluster import KMeans\n", + "\n", + "# Encode features by class label\n", + "features = {}\n", + "for (x, _, inst, _) in tqdm(dataloader2):\n", + " x = x.to(device)\n", + " inst = inst.to(device)\n", + " area = inst.size(2) * inst.size(3)\n", + "\n", + " # Get pooled feature map\n", + " with torch.no_grad():\n", + " feature_map = encoder(x, inst)\n", + "\n", + " for i in torch.unique(inst):\n", + " label = i if i < 1000 else i // 1000\n", + " label = int(label.flatten(0).item())\n", + "\n", + " # All pixels of an instance share the same feature after pooling, so read the first one\n", + " idx = torch.nonzero(inst == i, as_tuple=False)\n", + " n_inst = idx.size(0)\n", + " idx = idx[0, :]\n", + "\n", + " # Retrieve corresponding encoded feature\n", + " feature = feature_map[idx[0], :, idx[2], idx[3]].unsqueeze(0)\n", + "\n", + " # Compute rate of feature appearance (in official code, they compute per block)\n", + " block_size = 32\n", + " rate_per_block = block_size * n_inst / area\n", + " rate = torch.ones((1, 1), device=device).to(feature.dtype) * rate_per_block\n", + "\n", + " feature = torch.cat((feature, rate), dim=1)\n", + " if label in features:\n", + " features[label] = torch.cat((features[label], feature), dim=0)\n", + " else:\n", + " features[label] = feature\n", + "\n", + "\n", + "# Cluster features by class label\n", + "k = 10\n", + "centroids = {}\n", + "for label in range(n_classes):\n", + " if label not in features:\n", + " continue\n", + " feature = features[label]\n", + "\n", + " # Thresholding by 0.5 isn't mentioned in the paper, but is present in the\n", + " # official code repository, probably so that only frequent features are clustered\n", + " feature = feature[feature[:, -1] > 0.5, :-1].cpu().numpy()\n", + "\n", + " if 
feature.shape[0]:\n", + " n_clusters = min(feature.shape[0], k)\n", + " kmeans = KMeans(n_clusters=n_clusters).fit(feature)\n", + " centroids[label] = kmeans.cluster_centers_" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "100%|██████████| 174/174 [02:07<00:00, 1.36it/s]\n" + ], + "name": "stderr" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "f1LVqMkTjn06" + }, + "source": [ + "After getting the encoded feature centroids per class, you can now run inference! Remember that the generator is trained to take in a concatenation of the semantic label map, instance boundary map, and encoded feature map.\n", + "\n", + "Congrats on making it to the end of this complex notebook! 
Have fun with this powerful model, and of course, use it responsibly ;)" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "3BpqE3ngk_3s", + "outputId": "939172eb-c90f-4050-90d4-22175fa5073a", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 421 + } + }, + "source": [ + "import random\n", + "\n", + "def infer(label_map, instance_map, boundary_map):\n", + " # Sample feature vector centroids\n", + " b, _, h, w = label_map.shape\n", + " feature_map = torch.zeros((b, n_features, h, w), device=device).to(label_map.dtype)\n", + "\n", + " for i in torch.unique(instance_map):\n", + " label = i if i < 1000 else i // 1000\n", + " label = int(label.flatten(0).item())\n", + "\n", + " if label in centroids:\n", + " # Pick one of this class's K-means centroids at random\n", + " centroid_idx = random.randint(0, centroids[label].shape[0] - 1)\n", + " idx = torch.nonzero(instance_map == int(i), as_tuple=False)\n", + "\n", + " # Match the feature map's device and dtype before writing the centroid into it\n", + " feature = torch.from_numpy(centroids[label][centroid_idx, :]).to(device=device, dtype=feature_map.dtype)\n", + " feature_map[idx[:, 0], :, idx[:, 2], idx[:, 3]] = feature\n", + "\n", + " with torch.no_grad():\n", + " x_fake = generator2(torch.cat((label_map, boundary_map, feature_map), dim=1))\n", + " return x_fake\n", + "\n", + "for x, labels, insts, bounds in dataloader2:\n", + " x_fake = infer(labels.to(device), insts.to(device), bounds.to(device))\n", + " show_tensor_images(x_fake.to(x.dtype))\n", + " show_tensor_images(x)\n", + " break" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [], + "needs_background": "light" + } + }, + { + "output_type": "display_data", + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [], + "needs_background": "light" + } + } + ] + } + ] +} \ No newline at end of file