{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "91eWH3ZiMtxH" }, "source": [ "# TP02 - VAE for MNIST clustering and generation\n", "\n", "In this practical, we will explore Variational Auto-Encoders (VAE) with the MNIST digit recognition dataset.\n", "The purpose of a VAE is to generate new samples \"looking like\" the training samples.\n", "To this end, we will build a two-block model: an encoder, which maps a sample to a latent space, and a decoder, which maps a latent point back to the sample space. If the latent distribution is roughly Gaussian, this gives a way to get new samples: draw a latent point at random,\n", "and map it to the sample space with the decoder." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZuhVxmqVsV5W" }, "outputs": [], "source": [ "import os\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torchvision\n", "from torchvision import transforms\n", "from torchvision.utils import save_image\n", "\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from IPython.display import clear_output\n", "%matplotlib inline\n", "\n", "from sklearn.metrics.cluster import normalized_mutual_info_score" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "axr0X7JWsV5c" }, "outputs": [], "source": [ "!pip install --quiet \"git+https://gitlab.com/robindar/dl-scaman_checker.git\"\n", "from dl_scaman_checker import TP09\n", "TP09.check_install()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "r7tfgdYCsV5d" }, "outputs": [], "source": [ "N_CLASSES = 10\n", "IMAGE_SIZE = 28 * 28\n", "\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "device" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "onVZXPp6MtxP" }, "outputs": [], 
"source": [ "def show(img):\n", "    plt.imshow(np.transpose(img.numpy(), (1,2,0)), interpolation='nearest')\n", "\n", "def plot_reconstruction(model, n=24):\n", "    x,_ = next(iter(data_loader))\n", "    x = x[:n,:,:,:].to(device)\n", "    try:\n", "        out, _, _, log_p = model(x.view(-1, IMAGE_SIZE))\n", "    except ValueError:\n", "        # the model returns only three values (no categorical output)\n", "        out, _, _ = model(x.view(-1, IMAGE_SIZE))\n", "    x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)\n", "    out_grid = torchvision.utils.make_grid(x_concat).cpu().data\n", "    show(out_grid)\n", "\n", "@torch.no_grad()\n", "def plot_generation(model, n=24):\n", "    z = torch.randn(n, z_dim).to(device)\n", "    out = model.decode(z).view(-1, 1, 28, 28)\n", "    out_grid = torchvision.utils.make_grid(out).cpu()\n", "    show(out_grid)\n", "\n", "@torch.no_grad()\n", "def plot_conditional_generation(model, n=8, fix_number=None):\n", "    matrix = np.zeros((n, N_CLASSES))\n", "    matrix[:,0] = 1\n", "\n", "    if fix_number is None:\n", "        final = matrix[:]\n", "        for i in range(1, N_CLASSES):\n", "            final = np.vstack((final,np.roll(matrix,i)))\n", "        z = torch.randn(n, z_dim)\n", "        z = z.repeat(N_CLASSES,1).to(device)\n", "        y_onehot = torch.tensor(final).type(torch.FloatTensor).to(device)\n", "        out = model.decode(z,y_onehot).view(-1, 1, 28, 28)\n", "    else:\n", "        z = torch.randn(n, z_dim).to(device)\n", "        y_onehot = torch.tensor(np.roll(matrix, fix_number)).type(torch.FloatTensor).to(device)\n", "        out = model.decode(z,y_onehot).view(-1, 1, 28, 28)\n", "\n", "    out_grid = torchvision.utils.make_grid(out).cpu()\n", "    show(out_grid)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pgwvL5SvMtxm" }, "outputs": [], "source": [ "data_dir = './data'\n", "dataset = torchvision.datasets.MNIST(\n", "    root=data_dir,\n", "    train=True,\n", "    transform=transforms.ToTensor(),\n", "    download=True)\n", "\n", "data_loader = torch.utils.data.DataLoader(\n", "    dataset=dataset,\n", "    batch_size=128,\n", "    shuffle=True)\n", "\n", "test_loader = 
torch.utils.data.DataLoader(\n", "    torchvision.datasets.MNIST(data_dir, train=False, download=True, transform=transforms.ToTensor()),\n", "    batch_size=10, shuffle=False)" ] }, { "cell_type": "markdown", "metadata": { "id": "UESJllvCMtxu" }, "source": [ "# Part A - Variational Autoencoders\n", "\n", "## A.1 - Autoencoding theory\n", "\n", "Consider a latent variable model with a data variable $x\\in \\mathcal{X}$ and a latent variable $z\\in \\mathcal{Z}$, $p(z,x) = p(z)p_\\theta(x|z)$. Given the data $x_1,\\dots, x_n$, we want to train the model by maximizing the marginal log-likelihood:\n", "\\begin{eqnarray*}\n", "\\mathcal{L} = \\mathbf{E}_{p_d(x)}\\left[\\log p_\\theta(x)\\right]=\\mathbf{E}_{p_d(x)}\\left[\\log \\int_{\\mathcal{Z}}p_{\\theta}(x|z)p(z)dz\\right],\n", " \\end{eqnarray*}\n", " where $p_d$ denotes the empirical distribution of $X$: $p_d(x) =\\frac{1}{n}\\sum_{i=1}^n \\delta_{x_i}(x)$.\n", "\n", " To avoid the (often) difficult computation of the integral above, the idea behind variational methods is to instead maximize a lower bound on the log-likelihood:\n", " \\begin{eqnarray*}\n", "\\mathcal{L} \\geq L(p_\\theta(x|z),q(z|x)) =\\mathbf{E}_{p_d(x)}\\left[\\mathbf{E}_{q(z|x)}\\left[\\log p_\\theta(x|z)\\right]-\\mathrm{KL}\\left( q(z|x)||p(z)\\right)\\right].\n", " \\end{eqnarray*}\n", " Any choice of $q(z|x)$ gives a valid lower bound. Variational autoencoders replace the variational posterior $q(z|x)$ by an inference network $q_{\\phi}(z|x)$ that is trained together with $p_{\\theta}(x|z)$ to jointly maximize $L(p_\\theta,q_\\phi)$.\n", " \n", "The variational posterior $q_{\\phi}(z|x)$ is also called the **encoder** and the generative model $p_{\\theta}(x|z)$, the **decoder** or generator.\n", "\n", "The first term $\\mathbf{E}_{q(z|x)}\\left[\\log p_\\theta(x|z)\\right]$ is the negative reconstruction error. Indeed, under a Gaussian assumption, i.e. 
$p_{\\theta}(x|z) = \\mathcal{N}(\\mu_{\\theta}(z), I)$, the term $\\log p_\\theta(x|z)$ reduces, up to an additive constant, to $-\\frac{1}{2}\\|x-\\mu_\\theta(z)\\|^2$, i.e. minus the squared error that is often used in practice. The term $\\mathrm{KL}\\left( q(z|x)||p(z)\\right)$ can be seen as a regularization term, which pushes the variational posterior $q_\\phi(z|x)$ towards the prior $p(z)= \\mathcal{N}(0, I)$.\n", "\n", "Variational Autoencoders were introduced by [Kingma and Welling (2013)](https://arxiv.org/abs/1312.6114), see also [(Doersch, 2016)](https://arxiv.org/abs/1606.05908) for a tutorial.\n", "\n", "## A.2 - A simple autoencoder for MNIST\n", "\n", "There are various examples of VAE in PyTorch available [here](https://github.com/pytorch/examples/tree/master/vae) or [here](https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/03-advanced/variational_autoencoder/main.py#L38-L65). The code below is taken from this last source.\n", "\n", "![A variational autoencoder.](https://github.com/dataflowr/notebooks/blob/master/HW3/vae.png?raw=true)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "TWtQVnsAMtxw" }, "outputs": [], "source": [ "# VAE model\n", "class VAE(nn.Module):\n", "    def __init__(self, h_dim, z_dim):\n", "        super(VAE, self).__init__()\n", "        self.fc1 = nn.Linear(IMAGE_SIZE, h_dim)\n", "        self.fc2 = nn.Linear(h_dim, z_dim)\n", "        self.fc3 = nn.Linear(h_dim, z_dim)\n", "        self.fc4 = nn.Linear(z_dim, h_dim)\n", "        self.fc5 = nn.Linear(h_dim, IMAGE_SIZE)\n", "\n", "    def encode(self, x):\n", "        h = F.relu(self.fc1(x))\n", "        return self.fc2(h), self.fc3(h)\n", "\n", "    def reparameterize(self, mu, log_var):\n", "        std = torch.exp(log_var/2)\n", "        eps = torch.randn_like(std)\n", "        return mu + eps * std\n", "\n", "    def decode(self, z):\n", "        h = F.relu(self.fc4(z))\n", "        return torch.sigmoid(self.fc5(h))\n", "\n", "    def forward(self, x):\n", "        mu, log_var = self.encode(x)\n", "        z = self.reparameterize(mu, log_var)\n", "        x_reconst = self.decode(z)\n", "        return x_reconst, mu, 
log_var\n", "\n", "# Hyper-parameters\n", "h_dim = 400\n", "z_dim = 20\n", "num_epochs = 15\n", "learning_rate = 1e-3\n", "\n", "model = VAE(h_dim=h_dim, z_dim=z_dim).to(device)\n", "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)" ] }, { "cell_type": "markdown", "metadata": { "id": "2klHMA7TMtx2" }, "source": [ "For the reconstruction loss, we use binary cross-entropy instead of MSE. The code below is still from the PyTorch tutorial (with minor modifications to avoid warnings!)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Ldmo0pTtsV5o" }, "outputs": [], "source": [ "@torch.no_grad()\n", "def live_plot(data_dict, x_key=None, figsize=(7,5), title=''):\n", "    clear_output(wait=True)\n", "    plt.figure(figsize=figsize)\n", "    for label, data in data_dict.items():\n", "        if label == x_key or len(data) == 0:\n", "            continue\n", "        x = data_dict[x_key] if x_key is not None else np.arange(len(data))\n", "        plt.plot(x, data, label=label, linewidth=1)\n", "    plt.title(title)\n", "    plt.grid(alpha=.5, which='both')\n", "    plt.xlabel('epoch' if x_key is None else x_key)\n", "    plt.legend(loc='center left') # the plot evolves to the right\n", "    plt.show();" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_hs3Wnd8Mtx4" }, "outputs": [], "source": [ "verbose = False\n", "data_dict = { \"epoch\": [], \"total loss\": [], \"reconstruction\": [], \"kl_div\": [] }\n", "\n", "for epoch in range(num_epochs):\n", "    for i, (x, _) in enumerate(data_loader):\n", "        # Forward pass\n", "        x = x.to(device).view(-1, IMAGE_SIZE)\n", "        x_reconst, mu, log_var = model(x)\n", "\n", "        # Compute reconstruction loss and kl divergence\n", "        # For KL divergence between Gaussians, see Appendix B in VAE paper or (Doersch, 2016):\n", "        # https://arxiv.org/abs/1606.05908\n", "        reconst_loss = F.binary_cross_entropy(x_reconst, x, reduction='sum')\n", "        kl_div = - 0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())\n", "\n", 
"        # Backprop and optimize\n", "        loss = reconst_loss + kl_div\n", "        optimizer.zero_grad()\n", "        loss.backward()\n", "        optimizer.step()\n", "\n", "        # Bookkeeping\n", "        data_dict[\"total loss\"].append(loss.item() / len(x))\n", "        data_dict[\"reconstruction\"].append(reconst_loss.item() / len(x))\n", "        data_dict[\"kl_div\"].append(kl_div.item() / len(x))\n", "        data_dict[\"epoch\"].append(epoch + float(1+i) / len(data_loader))\n", "\n", "\n", "        if (i+1) % 100 == 0:\n", "            if verbose:\n", "                print (\"Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}\"\n", "                       .format(epoch+1, num_epochs, i+1, len(data_loader),\n", "                               reconst_loss.item()/len(x), kl_div.item()/len(x)))\n", "            else:\n", "                live_plot(data_dict, x_key=\"epoch\", title=\"VAE\")\n", "\n", "live_plot(data_dict, x_key=\"epoch\", title=\"VAE\")" ] }, { "cell_type": "markdown", "metadata": { "id": "XinKgE7AMtx-" }, "source": [ "## A.3 - Evaluating results\n", "\n", "Let us see how our network reconstructs a batch of training digits. We display pairs of original digits and their reconstructed versions side by side.\n", "\n", "Observe how most reconstructed digits are essentially identical to the original version.\n", "This means that the identity mapping has been learned well.\n", "You should still see some blurry digits very different from the original (resample a couple of times if needed)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YLTFvi0SMtyA" }, "outputs": [], "source": [ "plot_reconstruction(model)" ] }, { "cell_type": "markdown", "metadata": { "id": "ZC7nauVyMtyF" }, "source": [ "Let's see now how our network generates new samples." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ehuLHz4NMtyH" }, "outputs": [], "source": [ "plot_generation(model)" ] }, { "cell_type": "markdown", "metadata": { "id": "Hddb3q43MtyN" }, "source": [ "Not great, but we did not train our network for long... That being said, we have no control over the generated digits. 
In the rest of this notebook, we explore ways to generate zeroes, ones, twos and so on.\n", "\n", "\n", "As a by-product, we show how our VAE will allow us to do clustering thanks to the Gumbel VAE described below. But before that, we start by cheating a little bit..." ] }, { "cell_type": "markdown", "metadata": { "id": "tOj5--9VsV5t" }, "source": [ "# Part B - Cheating with the \"conditional\" VAE\n", "\n", "We would like to generate samples of a given number. We have so far a model for images $X \\in [0,1]^{P}$ with $P = 784$ pixels, generated as $X = f_\\theta(Z)$ where $Z \\sim \\mathcal{N}(0,I)$ is a latent with normal distribution. In order to sample a specific digit, we will use a model\n", "$g_\\theta : \\mathbb{R}^z \\times \\{0,\\ldots,9\\} \\to [0,1]^P$ and sample images according to $X = g_\\theta(Z, Y)$\n", "where $Y \\in \\{0, \\ldots, 9\\}$ is a class label.\n", "\n", "In the context of variational autoencoding, this is considered cheating,\n", "because it uses external information (the class label) instead of learning the modes from the data.\n", "\n", "To build the function $g$, we will simply concatenate the latent representation with\n", "a one-hot encoding of the class. This way, we can use the above architecture with very little modification.\n", "\n", "First, write a function transforming a label into its one-hot encoding. This function will be used in the training loop (not in the architecture of the neural network!)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5l5zV2vfsV5t" }, "outputs": [], "source": [ "def l_2_onehot(labels, nb_digits=N_CLASSES):\n", "    # take labels (from the dataloader) and return labels onehot-encoded\n", "    ### your code here ###" ] }, { "cell_type": "markdown", "metadata": { "id": "Y5Ean5qFsV5u" }, "source": [ "You can test it on a batch." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cNL--q4SsV5u" }, "outputs": [], "source": [ "(x,labels) = next(iter(data_loader))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "AirJGVqwsV5v" }, "outputs": [], "source": [ "labels" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "b8_vaVGzsV5w" }, "outputs": [], "source": [ "assert l_2_onehot(labels).shape == (*labels.shape, N_CLASSES)\n", "l_2_onehot(labels)" ] }, { "cell_type": "markdown", "metadata": { "id": "V4UNjRzJsV5w" }, "source": [ "Now modify the architecture of the VAE so that the decoder takes as input the latent code concatenated with the one-hot encoding of the label; you can use `torch.cat`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ioSc2-WCsV5x" }, "outputs": [], "source": [ "class VAE_Cond(nn.Module):\n", "    def __init__(self, h_dim, z_dim):\n", "        super(VAE_Cond, self).__init__()\n", "        self.fc1 = nn.Linear(IMAGE_SIZE, h_dim)\n", "        self.fc2 = nn.Linear(h_dim, z_dim)\n", "        self.fc3 = nn.Linear(h_dim, z_dim)\n", "        ### your code here ###\n", "\n", "    def encode(self, x):\n", "        h = F.relu(self.fc1(x))\n", "        return self.fc2(h), self.fc3(h)\n", "\n", "    def reparameterize(self, mu, log_var):\n", "        std = torch.exp(log_var/2)\n", "        eps = torch.randn_like(std)\n", "        return mu + eps * std\n", "\n", "    def decode(self, z, l_onehot):\n", "        ### your code here ###\n", "\n", "    def forward(self, x, l_onehot):\n", "        ### your code here ###" ] }, { "cell_type": "markdown", "metadata": { "id": "8aeUE3qVsV5x" }, "source": [ "Test your new model on a batch:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Tac5aWqssV5y" }, "outputs": [], "source": [ "model_C = VAE_Cond(h_dim=h_dim, z_dim=z_dim).to(device)\n", "x = x.to(device).view(-1, IMAGE_SIZE)\n", "l_onehot = l_2_onehot(labels).to(device)\n", "x_reconst, mu, log_var = model_C(x, l_onehot)\n", "assert x_reconst.shape == x.shape\n", "assert 
mu.shape == log_var.shape and mu.shape == (x.shape[0], z_dim)\n", "x_reconst.shape, mu.shape, log_var.shape" ] }, { "cell_type": "markdown", "metadata": { "id": "0-8d19dcsV5y" }, "source": [ "Now you can modify the training loop of your network. The parameter $\\beta$ will allow you to scale the KL term in your loss, as explained in the [$\\beta$-VAE paper](https://openreview.net/forum?id=Sy2fzU9gl) (see formula (4) in the paper)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IClystPvsV5z" }, "outputs": [], "source": [ "def train_C(model, data_loader=data_loader,num_epochs=num_epochs, beta=10., verbose=True):\n", "    nmi_scores = []\n", "    model.train(True)\n", "    data_dict = { \"epoch\": [], \"total loss\": [], \"reconstruction\": [], \"kl_div\": [] }\n", "    title = \"Conditional VAE\"\n", "\n", "    for epoch in range(num_epochs):\n", "        for i, (x, labels) in enumerate(data_loader):\n", "            x = x.to(device).view(-1, IMAGE_SIZE)\n", "\n", "            ### your forward code here ###\n", "\n", "            reconst_loss = F.binary_cross_entropy(x_reconst, x, reduction='sum')\n", "            kl_div = - 0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())\n", "\n", "            # Backprop and optimize\n", "            loss = reconst_loss + beta * kl_div\n", "            optimizer.zero_grad()\n", "            loss.backward()\n", "            optimizer.step()\n", "\n", "            # Bookkeeping\n", "            data_dict[\"total loss\"].append(loss.item() / len(x))\n", "            data_dict[\"reconstruction\"].append(reconst_loss.item() / len(x))\n", "            data_dict[\"kl_div\"].append(kl_div.item() / len(x))\n", "            data_dict[\"epoch\"].append(epoch + float(1+i) / len(data_loader))\n", "\n", "\n", "            if (i+1) % 100 == 0:\n", "                if verbose:\n", "                    print(\"Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}\"\n", "                          .format(epoch+1, num_epochs, i+1, len(data_loader),\n", "                                  reconst_loss.item()/len(x),\n", "                                  kl_div.item()/len(x)))\n", "                else:\n", "                    live_plot(data_dict, x_key=\"epoch\", title=title)\n", "\n", "    live_plot(data_dict, x_key=\"epoch\", title=title)" ] }, { 
"cell_type": "code", "execution_count": null, "metadata": { "id": "Eq5rzx-EsV5z" }, "outputs": [], "source": [ "model_C = VAE_Cond(h_dim=h_dim, z_dim=z_dim).to(device)\n", "optimizer = torch.optim.Adam(model_C.parameters(), lr=learning_rate)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2eISU2zTsV50" }, "outputs": [], "source": [ "train_C(model_C, num_epochs=15, verbose=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5jGgt2UhsV50" }, "outputs": [], "source": [ "plot_conditional_generation(model_C)" ] }, { "cell_type": "markdown", "metadata": { "id": "wZxXSDH9sV51" }, "source": [ "Here you should get nice results. Now we will avoid the use of the labels..." ] }, { "cell_type": "markdown", "metadata": { "id": "ePpthCZXMtyS" }, "source": [ "# Part C - No cheating with Gumbel VAE\n", "\n", "Implement a VAE where you add a categorical variable $c\\in \\{0,\\dots,9\\}$ so that your latent variable model is $p(c,z,x) = p(c)p(z)p_{\\theta}(x|c,z)$ and your variational posterior is $q_{\\phi}(c|x)q_{\\phi}(z|x)$, as described in this NeurIPS paper: [(Dupont, 2018)](https://arxiv.org/abs/1804.00104). Try to make only minimal modifications to the previous architecture.\n", "\n", "The idea is to incorporate a categorical variable in your latent space, without cheating by forcing it to take the value of the image's label. The hope is that this categorical variable will encode the class of the digit, so that your network can use it for a better reconstruction, but you cannot force it to. Moreover, if things work as planned, you will then be able to generate digits conditionally on the class, i.e. 
you can choose the class through the latent categorical variable $c$ and then generate digits from this class.\n", "\n", "## C.1 - Gumbel trick for discrete latent variables\n", "\n", "As noticed above, in order to sample random variables while still being able to use backpropagation, we need to use the reparameterization trick, which is easy for Gaussian random variables. For categorical random variables, the reparameterization trick is explained in [(Jang et al., 2016)](https://arxiv.org/abs/1611.01144). This is implemented in PyTorch thanks to [F.gumbel_softmax](https://pytorch.org/docs/stable/generated/torch.nn.functional.gumbel_softmax.html).\n", "\n", "Note: there is an instability in the PyTorch `F.gumbel_softmax` implementation. We provide instead `TP09.gumbel_softmax` which takes the same arguments and ensures that the output is not NaN. If you encounter NaNs in your training or errors in the binary cross-entropy, make sure you are using the stabilized version." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sBIKV-68MtyW" }, "outputs": [], "source": [ "class VAE_Gumbel(nn.Module):\n", "    def __init__(self, h_dim, z_dim):\n", "        super(VAE_Gumbel, self).__init__()\n", "        ### your code here ###\n", "\n", "    def encode(self, x):\n", "        ### your code here ###\n", "\n", "    def reparameterize(self, mu, log_var):\n", "        std = torch.exp(log_var/2)\n", "        eps = torch.randn_like(std)\n", "        return mu + eps * std\n", "\n", "    def decode(self, z, y_onehot):\n", "        ### your code here ###\n", "\n", "    def forward(self, x):\n", "        ### your code here ###" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "J2gLw45WsV52" }, "outputs": [], "source": [ "model_G = VAE_Gumbel(h_dim=h_dim, z_dim=z_dim).to(device)\n", "optimizer = torch.optim.Adam(model_G.parameters(), lr=learning_rate)" ] }, { "cell_type": "markdown", "metadata": { "id": "KykEyDKDMtyc" }, "source": [ "You need to modify the loss to take into account the categorical random 
variable with a uniform prior on $\\{0,\\dots,9\\}$, see Appendix A.2 in [(Dupont, 2018)](https://arxiv.org/abs/1804.00104)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JBV8OdB9Mtyg" }, "outputs": [], "source": [ "def train_G(model, data_loader=data_loader,num_epochs=num_epochs, beta = 1., verbose=True):\n", "    nmi_scores = []\n", "    model.train(True)\n", "    data_dict = { \"epoch\": [], \"total loss\": [], \"reconstruction\": [], \"kl_div\": [], \"entropy\": [] }\n", "    title = \"Gumbel VAE\"\n", "\n", "    for epoch in range(num_epochs):\n", "        all_labels = []\n", "        all_labels_est = []\n", "        for i, (x, labels) in enumerate(data_loader):\n", "            x = x.to(device).view(-1, IMAGE_SIZE)\n", "            ### your forward code here ###\n", "\n", "            # Backprop and optimize\n", "            loss = None ### your loss code here ###\n", "            optimizer.zero_grad()\n", "            loss.backward()\n", "            optimizer.step()\n", "\n", "            # Bookkeeping\n", "            data_dict[\"total loss\"].append(loss.item() / len(x))\n", "            ### your bookkeeping code here ###\n", "            data_dict[\"epoch\"].append(epoch + float(1+i) / len(data_loader))\n", "\n", "            if (i+1) % 100 == 0:\n", "                if verbose:\n", "                    print (\"Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}, Entropy: {:.4f}\"\n", "                           .format(epoch+1, num_epochs, i+1, len(data_loader),\n", "                                   reconst_loss.item()/len(x),\n", "                                   kl_div.item()/len(x),\n", "                                   H_cat.item()/len(x)))\n", "                else:\n", "                    live_plot(data_dict, x_key=\"epoch\", title=title)\n", "\n", "    live_plot(data_dict, x_key=\"epoch\", title=title)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "KL5mR20KMtyp" }, "outputs": [], "source": [ "train_G(model_G, num_epochs=20, verbose=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NyZCzWt5Mtyw" }, "outputs": [], "source": [ "plot_reconstruction(model_G)" ] }, { "cell_type": "markdown", "metadata": { "id": "9nMSU21iMty7" }, "source": [ "The reconstruction is good, but we care more about the 
generation. For each category, we generate 8 samples with the `plot_conditional_generation()` function.\n", "As before, each row should consist of the same digit sampled 8 times, and each row should correspond to a distinct digit." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XWb_mYgsMty9" }, "outputs": [], "source": [ "plot_conditional_generation(model_G)" ] }, { "cell_type": "markdown", "metadata": { "id": "PilQVNR3MtzL" }, "source": [ "It does not look like our original idea is working...\n", "\n", "What is happening is that our network is not using the categorical variable at all (all \"digits\" look the same; the variation comes from the re-sampling of $Z$). We can track the [normalized mutual information](https://en.wikipedia.org/wiki/Mutual_information#Normalized_variants) (see [this method in scikit-learn](http://scikit-learn.org/stable/modules/generated/sklearn.metrics.normalized_mutual_info_score.html)) between the true labels and the labels predicted by our network (just by taking the category with maximal probability).\n", "\n", "Change your training loop to return the normalized mutual information (NMI) for each epoch. Plot the curve to check that the NMI is actually decreasing." ] }, { "cell_type": "markdown", "metadata": { "id": "GV8yl33sMtzN" }, "source": [ "## C.2 - Robust disentangling with controlled capacity increase\n", "\n", "This problem is explained in [(Burgess et al., 2018)](https://arxiv.org/abs/1804.03599) and a solution is proposed in Section 5.\n", "\n", "In order to force our network to use the categorical variable, we will change the loss according to [(Dupont, 2018)](https://arxiv.org/abs/1804.00104), Section 3 Equation (7).\n", "\n", "Implement this change in the training loop and plot the new NMI curve. 
Increase the $C_z$ and $C_c$ constants by a fixed amount every epoch until they reach the values passed as arguments.\n", "For $\\beta = 20, C_z=100, C_c=100$, you should see that NMI increases." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "TxVJLo86MtzO" }, "outputs": [], "source": [ "model_G = VAE_Gumbel(h_dim=h_dim, z_dim=z_dim).to(device)\n", "optimizer = torch.optim.Adam(model_G.parameters(), lr=learning_rate)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1R70kiWVMtzU" }, "outputs": [], "source": [ "def train_G_modified_loss(model, data_loader=data_loader,num_epochs=num_epochs, beta=1. , C_z_fin=0, C_c_fin=0, verbose=True):\n", "    ### your code here ###\n", "    return nmi_scores" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tPNGg4uwSEGY" }, "outputs": [], "source": [ "# Hyper-parameters\n", "num_epochs = 40\n", "learning_rate = 1e-3\n", "beta = 20\n", "C_z_fin = 100\n", "C_c_fin = 100\n", "\n", "model_G = VAE_Gumbel(h_dim=h_dim, z_dim = z_dim).to(device)\n", "optimizer = torch.optim.Adam(model_G.parameters(), lr=learning_rate)\n", "\n", "nmi = train_G_modified_loss(\n", "    model_G, data_loader,\n", "    num_epochs=num_epochs,\n", "    beta=beta,\n", "    C_z_fin=C_z_fin,\n", "    C_c_fin=C_c_fin,\n", "    verbose=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "G7jIqNvNsV6L" }, "outputs": [], "source": [ "plt.plot(nmi, '-o')\n", "plt.xlabel(\"epoch\")\n", "plt.ylabel(\"Normalized Mutual Information\")\n", "plt.grid(alpha=.5, which='both')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hMF3LMn1Mtza" }, "outputs": [], "source": [ "plot_reconstruction(model_G)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tR2sn2l7Mtzg" }, "outputs": [], "source": [ "plot_conditional_generation(model_G, fix_number=None)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "KoF28npFMtzt" }, "outputs": [], 
"source": [ "plot_conditional_generation(model_G, fix_number=2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UZzTEAgTsV6N" }, "outputs": [], "source": [ "plot_conditional_generation(model_G, fix_number=8)" ] }, { "cell_type": "markdown", "metadata": { "id": "4g6uOsTwsV6O" }, "source": [ "## C.3 - Interpretation of learned results\n", "\n", "Compare the generated digits from the Conditional VAE model and the modified Gumbel-VAE.\n", "The conditional VAE produces all digits in order, which makes for a nice picture; on the other hand, the Gumbel VAE has no way of knowing the order of digits, it must just guess modes of pictures that look the same, so the order is meaningless.\n", "Note however the variations of straight/inclined ones: are they present on the same row?\n", "What about shapes from the same row morphing from one digit to another: does it make sense that this was considered a single mode by the model? Are there distinct modes corresponding to the same digit, and if so, is it a reasonable distinction?" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "VfIal_ECsV6O" }, "outputs": [], "source": [ "plot_conditional_generation(model_C, fix_number=None)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2lr0g2WUsV6P" }, "outputs": [], "source": [ "plot_conditional_generation(model_G, fix_number=None)" ] }, { "cell_type": "markdown", "metadata": { "id": "ggILRCoisV6Q" }, "source": [ "---" ] } ], "metadata": { "colab": { "include_colab_link": true, "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" } }, "nbformat": 4, "nbformat_minor": 4 }