In a previous post, we walked through the theory and implementation of the variational autoencoder, which is a probabilistic generative model that combines variational inference and neural networks to model and sample from complex distributions. In this post, we will walk through another such model: the **denoising diffusion probabilistic model**. Diffusion models were originally proposed by Sohl-Dickstein et al. (2015) and later extended by Ho, Jain, and Abbeel (2020).

At the time of this writing, diffusion models are state-of-the-art models used for image generation and have achieved what are, in my opinion, breathtaking results in generating incredibly detailed, realistic images. Below is an example image generated by DALL·E 3 (via OpenAI’s ChatGPT), which, as far as I understand, uses diffusion models as part of its image-generation machinery.

Diffusion models are also being explored in biomedical research. For example, RFDiffusion and Chroma are two methods that use diffusion models to generate novel protein structures. Diffusion models are also being explored for synthetic biomedical data generation.

Because of these models’ incredible performance in image generation, and their burgeoning use-cases in computational biology, I was curious to understand how they work. While I have a relatively good understanding of the theory behind the variational autoencoder, diffusion models presented a bigger challenge as the mathematics is more involved. In this post, I will step through my newfound understanding of diffusion models regarding both their mathematical theory and practical implementation.

Specifically, I will walk through the denoising diffusion probabilistic model (DDPM) as presented by Ho, Jain, and Abbeel (2020). The mathematical derivations are somewhat lengthy and I present them in the Appendix to the post so that they do not distract from the core ideas behind the model. We will conclude by walking through an implementation of a simple diffusion model in PyTorch and apply it to the MNIST dataset of hand-written digits. Hopefully, this post will serve others who are learning this material as well. Please let me know if you find any errors!

Like all probabilistic generative models, diffusion models can be understood as models that specify a probability distribution, $p(\boldsymbol{x})$, over some set of objects of interest where $\boldsymbol{x}$ is a vector representation of one such object. For example, these objects might be images, text documents, or protein sequences. Generating an image via a diffusion model can be viewed as *sampling* from $p(\boldsymbol{x})$:

In training a diffusion model, we fit $p(\boldsymbol{x})$ by fitting a diffusion process. This diffusion process goes as follows: Given a vector $\boldsymbol{x}$ representing an object (e.g., an image), we iteratively add Gaussian noise to $\boldsymbol{x}$ over a series of $T$ timesteps. Let $\boldsymbol{x}_t$ be the object at time step $t$ and let $\boldsymbol{x}_0$ be the original object before noise was added to it. If $\boldsymbol{x}_0$ is an image of my dog Korra, this diffusion process would look like the following:

Here, $q(\boldsymbol{x})$ represents the hypothetical “real world distribution” of objects (which is distinct from the model’s distribution $p(\boldsymbol{x})$, though our goal is to train the model so that $p(\boldsymbol{x})$ resembles $q(\boldsymbol{x})$). Furthermore, if the total number of timesteps $T$ is large enough, then the corrupted object approaches a sample from a standard normal distribution $N(\boldsymbol{0}, \boldsymbol{I})$ – that is, it approaches pure white noise.

Now, the goal of training a diffusion model is to learn how to reverse this diffusion process by iteratively removing noise in the reverse order it was added:

The main idea behind diffusion models is that if our model can remove noise successfully, then we have a ready-made method for generating new objects. Specifically, we can generate a new object by first sampling noise from $N(\boldsymbol{0}, \boldsymbol{I})$, and then applying our model iteratively, removing noise step-by-step until a new object is formed:

In a sense, the model is “sculpting” an object out of noise bit by bit. It is like a sculptor who starts from an amorphous block of granite and slowly chips away at the rock until a form appears!

Now that we have some high-level intuition, let’s make this more mathematically rigorous. First, the forward diffusion process works as follows: For each timestep, $t$, we will sample noise, $\epsilon$, from a standard normal distribution, and then add it to $\boldsymbol{x}_t$ in order to form the next, noisier object $\boldsymbol{x}_{t+1}$:

\[\begin{align*}\epsilon &\sim N(\boldsymbol{0}, \boldsymbol{I}) \\ \boldsymbol{x}_{t+1} &:= c_1\boldsymbol{x}_t + c_2\epsilon\end{align*}\]where $c_1$ and $c_2$ are two constants (to be defined in more detail later in the post). Note that the above process can also be described as sampling from a normal distribution with a mean specified by $\boldsymbol{x}_t$:

\[\boldsymbol{x}_{t+1} \sim N\left(c_1\boldsymbol{x}_t, c_2^2 \boldsymbol{I}\right)\]Thus, we can view the formation of $\boldsymbol{x}_{t+1}$ as the act of *sampling* from a normal distribution that is conditioned on $\boldsymbol{x}_t$. We will use the notation $q(\boldsymbol{x}_{t+1} \mid \boldsymbol{x}_t)$ to refer to this conditional distribution.
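To make the "adding noise is sampling" view concrete, here is a minimal sketch of a single forward step in plain Python. The function name `forward_step` and the particular values of `c1` and `c2` are illustrative assumptions of mine (the constants are only defined later in the post), and the object is represented as a simple list of floats rather than a tensor.

```python
import math
import random

def forward_step(x_t, c1, c2, rng):
    """Sample x_{t+1} ~ N(c1 * x_t, c2^2 * I) for a vector given as a list of floats."""
    # Each coordinate receives its own independent standard normal draw.
    return [c1 * x_i + c2 * rng.gauss(0.0, 1.0) for x_i in x_t]

rng = random.Random(0)          # seeded so the example is reproducible
x0 = [0.5, -1.0, 2.0]           # a toy three-dimensional "object"
x1 = forward_step(x0, c1=math.sqrt(0.98), c2=math.sqrt(0.02), rng=rng)
print(x1)
```

Setting `c2 = 0` removes the randomness entirely and simply rescales the input, which is a quick way to sanity-check the function.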

In a similar manner, we can also view the process of removing noise (i.e., reversing a diffusion step) as sampling. Specifically, we can view it as sampling from the *posterior* distribution, $q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t)$. To reverse the diffusion process, we start from pure noise and iteratively sample from these posteriors:

Now, how do we derive these posterior distributions? One idea is to use Bayes’ theorem:

\[q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t+1}) = \frac{q(\boldsymbol{x}_{t+1} \mid \boldsymbol{x}_t)q(\boldsymbol{x}_{t})}{q(\boldsymbol{x}_{t+1})}\]Unfortunately, this posterior is intractable to compute. Why? First note that in order to compute $q(\boldsymbol{x}_t)$, we have to marginalize over all of the time steps prior to $t$:

\[\begin{align*} q(\boldsymbol{x}_t) &= \int_{\boldsymbol{x}_{t-1},\dots,\boldsymbol{x}_0} q(\boldsymbol{x}_t, \boldsymbol{x}_{t-1}, \dots, \boldsymbol{x}_0) \ d\boldsymbol{x}_{t-1}\dots d\boldsymbol{x}_{0} \\ &= \int_{\boldsymbol{x}_{t-1},\dots,\boldsymbol{x}_0} q(\boldsymbol{x}_0)\prod_{i=1}^{t} q(\boldsymbol{x}_i \mid \boldsymbol{x}_{i-1}) \ d\boldsymbol{x}_{t-1}\dots d\boldsymbol{x}_{0} \end{align*}\]Notice that this marginalization requires that we define a distribution $q(\boldsymbol{x}_0)$, which is a distribution over noiseless objects (e.g., a distribution over noiseless images). Unfortunately, we don’t know what this is – learning it is the whole purpose of developing a diffusion model!

To get around this problem, we will employ a strategy similar to the one used in variational inference: We will *approximate* $q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t+1})$ with a surrogate distribution $p_{\theta}(\boldsymbol{x}_t \mid \boldsymbol{x}_{t+1})$. Here, $\theta$ represents a set of learnable parameters that we will use to fit these distributions as closely to each $q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t+1})$ as possible. As we will see later in the post, $p_{\theta}(\boldsymbol{x}_t \mid \boldsymbol{x}_{t+1})$ can incorporate a neural network so that it can represent a distribution complex enough to successfully remove noise.

To be more specific, our goal will be to approximate the full diffusion process, which can be represented as a joint distribution over all intermediate noisy objects:

\[q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0) = \prod_{t=1}^T q(\boldsymbol{x}_{t} \mid \boldsymbol{x}_{t-1})\]where $\boldsymbol{x}_{1:T} = \boldsymbol{x}_1, \boldsymbol{x}_2, \dots, \boldsymbol{x}_T$. We will approximate this joint distribution using another joint distribution that is instead factored by the posterior distributions (i.e., the reverse diffusion steps):

\[p_\theta(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0) = \prod_{t=1}^T p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t)\]From this approximation, we will obtain our approximate posterior distributions given by each $p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t)$. In short, by fitting $p_\theta(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0)$ to $q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0)$, we will obtain a set of posterior distributions, $p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t)$.

Once we have these approximate posterior distributions in hand, we can generate an object by first sampling white noise $\boldsymbol{x}_T$ from a standard normal distribution $N(\boldsymbol{0}, \boldsymbol{I})$, and then iteratively sampling $\boldsymbol{x}_{t-1}$ from each learned $p_{\theta}(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_{t})$ distribution. At the end of this process we will have “transformed” the random white noise into an object. More specifically, we will have “sampled” an object!
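The sampling procedure just described can be sketched as a short loop. This is only an illustration under assumptions: `sample` and `toy_reverse_step` are hypothetical names, and `toy_reverse_step` is a hand-written stand-in for a learned $p_{\theta}(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_{t})$, not a trained model.

```python
import random

def sample(reverse_step, num_steps, dim, rng):
    """Draw x_T ~ N(0, I), then apply reverse_step for t = T, T-1, ..., 1."""
    x = [rng.gauss(0.0, 1.0) for _ in range(dim)]  # x_T: pure white noise
    for t in range(num_steps, 0, -1):
        x = reverse_step(x, t, rng)                # one draw of x_{t-1} given x_t
    return x

def toy_reverse_step(x_t, t, rng):
    """Hypothetical stand-in for a learned step: shrink toward 0, add noise unless t == 1."""
    noise_scale = 0.1 if t > 1 else 0.0
    return [0.9 * v + noise_scale * rng.gauss(0.0, 1.0) for v in x_t]

x0 = sample(toy_reverse_step, num_steps=100, dim=4, rng=random.Random(0))
print(x0)
```

In a real model, `reverse_step` would call the neural network; the loop structure around it stays the same.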

Let’s come back to the idea that a diffusion model represents a probability distribution over some objects of interest. Here we see that the distribution defined by our diffusion model, $p_{\theta}(\boldsymbol{x}_0)$, is obtained by marginalizing over all of the intermediate, noisy objects, $\boldsymbol{x}_t$, at each time step $t$ of the diffusion process:

\[\begin{align*}p_{\theta}(\boldsymbol{x}_0) = \int_{\boldsymbol{x}_1, \dots, \boldsymbol{x}_T} p_{\theta}(\boldsymbol{x}_T) \prod_{t=1}^T p_{\theta}(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_{t}) \ d\boldsymbol{x}_1 \dots d\boldsymbol{x}_T\end{align*}\]Note this integral is hard to calculate; however, despite this fact, we can still sample from the distribution, and that sampling process is performed via the iterative denoising process we just described.

As stated previously, the forward model is defined as

\[q(\boldsymbol{x}_{t+1} \mid \boldsymbol{x}_t) := N\left(\boldsymbol{x}_{t+1} ; c_1\boldsymbol{x}_t, c_2^2 \boldsymbol{I}\right)\]where $c_1$ and $c_2$ are constants. Let us now define these constants. First, let us define values $\beta_1, \beta_2, \dots, \beta_T \in [0, 1]$. These are $T$ values between zero and one, each corresponding to a timestep. The constants $c_1$ and $c_2$ are simply:

\[\begin{align*}c_1 &:= \sqrt{1-\beta_t} \\ c_2 &:= \sqrt{\beta_t}\end{align*}\]Then, the fully-defined forward model at timestep $t$ is:

\[q(\boldsymbol{x}_{t+1} \mid \boldsymbol{x}_t) := N\left(\boldsymbol{x}_{t+1}; \sqrt{1-\beta_t}\boldsymbol{x}_t, \beta_t \boldsymbol{I}\right)\]Here we see that $c_2^2 = \beta_t$ sets the variance of the noise at timestep $t$. In diffusion models, it is common to predefine a function that returns $\beta_t$ at each timestep. This function is called the **variance schedule**. For example, one might use a linear variance schedule defined as:

where $\text{max}, \text{min} \in [0,1]$ and $\text{min} < \text{max}$ are two small constants. The function above will compute a sequence of $\beta_1, \dots, \beta_T$ that interpolates linearly between $\text{min}$ and $\text{max}$. Note that the specific variance schedule that one uses is a modeling design choice. Instead of a linear variance schedule, one may opt for another one. For example, Nichol and Dhariwal (2021) suggest replacing a linear variance schedule with a cosine variance schedule (which we won’t discuss here).
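As a rough sketch of what a linear variance schedule can look like in code, the function below returns $T$ evenly spaced values from $\text{min}$ to $\text{max}$. The name `linear_beta_schedule`, the choice to hit both endpoints exactly at $t=1$ and $t=T$, and the default values (which follow the settings reported by Ho, Jain, and Abbeel (2020), not anything stated in this post) are assumptions.

```python
def linear_beta_schedule(num_steps, beta_min=1e-4, beta_max=0.02):
    """Return [beta_1, ..., beta_T], evenly spaced from beta_min to beta_max."""
    if num_steps < 1:
        raise ValueError("num_steps must be at least 1")
    if num_steps == 1:
        return [beta_min]
    step = (beta_max - beta_min) / (num_steps - 1)
    return [beta_min + i * step for i in range(num_steps)]

betas = linear_beta_schedule(1000)
print(betas[0], betas[-1])  # smallest and largest noise variances in the schedule
```

Other indexing conventions (for example, interpolating with $t/T$) differ only slightly for large $T$.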

This raises the question: Why use a different value of $\beta_t$ at each time step? Why not set $\beta_t$ constant across timesteps? The answer is that, empirically, if $\beta_t$ is large, then the object will turn to noise too quickly, washing away the structure of the object too early in the process and thereby making it challenging to learn $p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t)$. In contrast, if $\beta_t$ is very small, then each step only adds a very small amount of noise, and thus, to turn an object $\boldsymbol{x}_0$ into white noise (and back via reverse diffusion), we would require many timesteps (which, as we will see, would lead to inefficient training of the model).

A solution that balances the need to maintain the object’s structure while keeping the number of timesteps relatively short, is to increase the variance at each timestep according to a set schedule so that at the beginning of the diffusion process, only a little bit of noise is added at a time, but towards the end of the process, more noise is added at a time to ensure that $\boldsymbol{x}_T$ approaches a sample from $N(\boldsymbol{0}, \boldsymbol{I})$ (i.e., it becomes pure noise). This is illustrated in the figure below:

Now that we have a better understanding of the second constant (i.e., $c_2 := \sqrt{\beta_t}$), which sets the scale of the added noise (its variance is $c_2^2 = \beta_t$), let’s turn our attention to the first constant, $c_1 := \sqrt{1-\beta_t}$, which scales the mean. Why are we scaling the mean with this constant? Doesn’t it make more sense to simply center the mean of the forward noise distribution at $\boldsymbol{x}_t$?

The reason for this term is that it ensures that the variance of the noisy object does not grow from one timestep to the next, but rather stays at one. That is, $\sqrt{1-\beta_t}$ is precisely the value required to scale the mean of the forward diffusion process distribution, $q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1})$, such that if $\text{Var}(\boldsymbol{x}_{t-1}) = 1$, then $\text{Var}(\boldsymbol{x}_t) = (1-\beta_t)\text{Var}(\boldsymbol{x}_{t-1}) + \beta_t = 1$. See Derivation 3 in the Appendix to this post for a proof. Recall, our goal is to transform $\boldsymbol{x}_0$ into white noise distributed by a standard normal distribution (which has a variance of 1), and thus, we cannot have the variance grow at each timestep. Below we depict a forward diffusion process on 1-dimensional data using two strategies: the first does not scale the mean and the second does. Notice that the variance continues to grow when we don’t scale the mean, but it remains fixed when we scale the mean by $\sqrt{1-\beta_t}$:
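The effect of scaling the mean can also be checked with a few lines of arithmetic rather than a simulation. The sketch below tracks the marginal variance of a 1-dimensional $x_t$ through the recursion $\text{Var}(x_t) = c_1^2 \text{Var}(x_{t-1}) + \beta_t$, assuming the data start with unit variance; the function name and the constant value $\beta_t = 0.05$ are illustrative choices of mine.

```python
def marginal_variances(betas, scale_mean, initial_variance=1.0):
    """Variance of x_t for t = 0..T under x_t = c1 * x_{t-1} + sqrt(beta_t) * eps."""
    variances = [initial_variance]
    for beta in betas:
        # c1^2 is (1 - beta) when the mean is scaled, and 1 when it is not.
        c1_squared = (1.0 - beta) if scale_mean else 1.0
        variances.append(c1_squared * variances[-1] + beta)
    return variances

betas = [0.05] * 50
scaled = marginal_variances(betas, scale_mean=True)
unscaled = marginal_variances(betas, scale_mean=False)
print(scaled[-1], unscaled[-1])
```

With scaling, the variance stays at one for every timestep; without it, each step adds $\beta_t$ to the variance, so it grows linearly with the number of steps.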

Before we conclude this section, we will also prove a few convenient properties of the forward model that will be useful for deriving the final objective function used to train diffusion models:

1. **$q(\boldsymbol{x}_t \mid \boldsymbol{x}_0)$ has a closed form.** That is, the distribution over a noisy object at timestep $t$ of the diffusion process has a closed form solution. That solution is specifically the following normal distribution (See Derivation 4 in the Appendix to this post):

where $\alpha_t := 1-\beta_t$ and $\bar{\alpha}_t := \prod_{i=1}^t \alpha_i$ (this notation is used in the original paper by Ho, Jain, and Abbeel (2020) and makes the equations going forward easier to read). This is depicted schematically below:

Note that because $q(\boldsymbol{x}_t \mid \boldsymbol{x}_0)$ is simply a normal distribution, this enables us to sample noisy images at any arbitrary timestep $t$ without having to run the full diffusion process for $t$ timesteps. That is, instead of having to sample from $t$ normal distributions, which is what would be required to run the forward diffusion process to timestep $t$, we can instead sample from one distribution. As we will show, this will enable us to speed up the training of the model.
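Here is a small sketch of this one-shot sampling. It assumes the closed form takes the standard shape given by Ho, Jain, and Abbeel (2020), namely $q(\boldsymbol{x}_t \mid \boldsymbol{x}_0) = N\left(\boldsymbol{x}_t; \sqrt{\bar{\alpha}_t}\boldsymbol{x}_0, (1-\bar{\alpha}_t)\boldsymbol{I}\right)$, and that `betas[t - 1]` holds $\beta_t$. The function names are my own, and `iterated_mean_and_variance` exists only to check that composing $t$ single steps gives the same mean and variance.

```python
import math
import random

def alpha_bars(betas):
    """Cumulative products alpha_bar_t = prod_{i <= t} (1 - beta_i), for t = 1..T."""
    out, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        out.append(running)
    return out

def sample_x_t(x0, t, abar, rng):
    """Draw x_t ~ N(sqrt(abar_t) * x0, (1 - abar_t) * I) in one shot (t is 1-indexed)."""
    a = abar[t - 1]
    return [math.sqrt(a) * v + math.sqrt(1.0 - a) * rng.gauss(0.0, 1.0) for v in x0]

def iterated_mean_and_variance(betas, t):
    """Mean coefficient and variance of scalar x_t given x_0, by composing t single steps."""
    mean_coef, var = 1.0, 0.0
    for beta in betas[:t]:
        mean_coef *= math.sqrt(1.0 - beta)
        var = (1.0 - beta) * var + beta
    return mean_coef, var

betas = [0.01 * (i + 1) for i in range(10)]
abar = alpha_bars(betas)
x7 = sample_x_t([0.5, -1.0], t=7, abar=abar, rng=random.Random(0))
print(x7)
```

During training, this is what lets each minibatch element jump directly to its own randomly chosen timestep.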

2. **$q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0)$ has a closed form.** Note that we previously discussed how the conditional distribution, $q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t)$, was intractable to compute.

However, it turns out that if instead of only conditioning on $\boldsymbol{x}_t$, we also condition on the original, noiseless object, $\boldsymbol{x}_0$, we *can* derive a closed form for this posterior distribution. That distribution is a normal distribution (See Derivations 5 and 6 in the Appendix to this post):

Depicted schematically:

The fact that this posterior has a closed form when conditioning on $\boldsymbol{x}_0$ makes intuitive sense: as we talked about previously, the posterior distribution $q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t)$ requires knowing $q(\boldsymbol{x}_0)$ – that is, in order to turn noise into an object, we need to know what real, noiseless objects look like. However, if we condition on $\boldsymbol{x}_0$, this means we are assuming we *know* what $\boldsymbol{x}_0$ looks like and the modified posterior, $q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0)$, needs only to take into account subtraction of noise towards this noiseless object.
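For readers who prefer code to formulas, the sketch below computes the mean and variance of this posterior for scalar inputs. It assumes the standard closed form from Ho, Jain, and Abbeel (2020), with mean $\frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t}\boldsymbol{x}_0 + \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\boldsymbol{x}_t$ and variance $\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\beta_t$, and that `betas[t - 1]` holds $\beta_t$. The function names and the example schedule are hypothetical.

```python
import math

def alpha_bars(betas):
    """Cumulative products alpha_bar_t = prod_{i <= t} (1 - beta_i), for t = 1..T."""
    out, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        out.append(running)
    return out

def posterior_params(x_t, x0, t, betas, abar):
    """Mean and variance of q(x_{t-1} | x_t, x_0) for scalar inputs; requires t >= 2."""
    beta_t = betas[t - 1]
    alpha_t = 1.0 - beta_t
    abar_t, abar_prev = abar[t - 1], abar[t - 2]
    coef_x0 = math.sqrt(abar_prev) * beta_t / (1.0 - abar_t)         # weight on x_0
    coef_xt = math.sqrt(alpha_t) * (1.0 - abar_prev) / (1.0 - abar_t)  # weight on x_t
    mean = coef_x0 * x0 + coef_xt * x_t
    variance = (1.0 - abar_prev) / (1.0 - abar_t) * beta_t
    return mean, variance

betas = [0.01 * (i + 1) for i in range(10)]
abar = alpha_bars(betas)
x_t, x0 = 0.7, -0.3
mean, variance = posterior_params(x_t, x0, t=5, betas=betas, abar=abar)
print(mean, variance)
```

The posterior mean is a weighted average of the noiseless object and the current noisy object, which matches the intuition of subtracting noise in the direction of $\boldsymbol{x}_0$.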

Diffusion models use variational inference to fit $p_\theta(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0)$ to $q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0)$. Recall, in variational inference, our goal is to approximate some unknown distribution $q$, with an approximate distribution $p$ by minimizing the KL-divergence from $p$ to $q$:

\[\hat{p} := \text{arg min}_p \ KL(q \ \vert\vert \ p)\]*Note, in accordance with the literature, we use $p$ to denote the approximate distribution and $q$ to denote the exact distribution. However, in my prior blog post on variational inference, I use $q$ to denote the approximate distribution and $p$ to denote the exact distribution. My apologies for this confusion!*

In our case, we wish to learn the reverse diffusion process from the forward diffusion process, so we start with the following objective function:

\[\hat{\theta} := \text{arg min}_\theta \ KL( q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0) \ \vert\vert \ p_\theta(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0))\]Now, let’s derive a more intuitive form of this objective function. Recall from our discussion on variational inference that minimizing this KL-divergence objective can be accomplished by maximizing another quantity called the evidence lower bound (ELBO), which is a function of the parameters $\theta$. For a more in-depth discussion of the ELBO, see my previous blog post. In the case of diffusion models, this ELBO looks as follows (See Derivation 1 in the Appendix to this post):

\[KL( q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0) \ \vert\vert \ p_\theta(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0)) = \log p_\theta(\boldsymbol{x}) - \underbrace{E_{\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0 \sim q}\left[ \log\frac{p_\theta (\boldsymbol{x}_{0:T}) }{q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0) } \right]}_{\text{ELBO}}\]Thus, we seek:

\[\begin{align*}\hat{\theta} &:= \text{arg max}_\theta \ \text{ELBO}(\theta) \\ &= \text{arg max}_\theta \ E_{\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0 \sim q}\left[ \log\frac{p_\theta (\boldsymbol{x}_{0:T}) }{q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0) } \right]\end{align*}\]Let’s now examine the ELBO more closely. It turns out that this ELBO can be further manipulated into a form that has a term for each step of the diffusion process (See Derivation 2 in the Appendix to this post):

\[\begin{align*}\text{ELBO}(\theta) &= E_{\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0 \sim q}\left[ \log \frac{ p_\theta (\boldsymbol{x}_{0:T}) }{q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0)} \right] \\ &= \underbrace{E_{\boldsymbol{x}_1 \mid \boldsymbol{x}_0 \sim q} \left[ \log p_\theta(\boldsymbol{x}_0 \mid \boldsymbol{x}_1) \right]}_{L_0} - \underbrace{\sum_{t=2}^T \left[ E_{\boldsymbol{x}_t \mid \boldsymbol{x}_0 \sim q} KL \left( q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0) \ \vert\vert \ p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t) \right) \right]}_{L_1, L_2, \dots, L_{T-1}} - \underbrace{KL\left( q(\boldsymbol{x}_T \mid \boldsymbol{x}_0) \ \vert\vert \ p_\theta(\boldsymbol{x}_T) \right)}_{L_T}\end{align*}\]These terms are broken into three categories:

- $L_0$ is the expected log-probability that the model assigns to the data conditioned on the result of the very first diffusion step. In the reverse diffusion process, this is the last step required to transform the noise into the original image. This term is called the **reconstruction term** because it is high when the model can successfully predict the original noiseless image $\boldsymbol{x}_0$ from $\boldsymbol{x}_1$, which is the result of the first iteration of the diffusion process.
- $L_1, \dots, L_{T-1}$ are terms that measure how well the model performs reverse diffusion. That is, they ask how well the posterior distributions specified by the model, $p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t)$, match the posterior distributions $q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0)$.
- $L_T$ simply measures how well the result of the noisy diffusion process, which theoretically approaches a normal distribution, matches the noise distribution from which we seed the reverse diffusion process, which in our case, we define to be a normal distribution.

By breaking up the ELBO into these terms, we can simplify it into a closed form expression. Let’s start with the last term $L_T$. Recall that we define $p_\theta(\boldsymbol{x}_T)$ to be a standard normal distribution that does not incorporate the model parameters. That is,

\[p_\theta(\boldsymbol{x}_T) := N(\boldsymbol{x}_T; \boldsymbol{0}, \boldsymbol{I})\]Because the last term, $L_T$, does not depend on the model parameters, we can ignore it when maximizing the ELBO. Thus, our task will be to find:

\[\hat{\theta} := \text{arg max}_\theta \ \underbrace{E_{\boldsymbol{x}_1 \mid \boldsymbol{x}_0 \sim q} \left[ \log p_\theta(\boldsymbol{x}_0 \mid \boldsymbol{x}_1) \right]}_{L_0} - \underbrace{\sum_{t=2}^T \left[ E_{\boldsymbol{x}_t \mid \boldsymbol{x}_0 \sim q} KL \left( q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0) \ \vert\vert \ p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t) \right) \right]}_{L_1, L_2, \dots, L_{T-1}}\]Now, let’s turn to the middle terms $L_1, \dots, L_{T-1}$. Here we see that these terms require calculating KL-divergences from $p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t)$ to $q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0)$. Recall from the previous sections that both of these distributions are normal distributions. That is,

\[\begin{align*}L_t := KL\left(N(A; B, C) \ \vert\vert \ N(A; B, C)\right)\end{align*}\]We now use the following fact: Given two normal distributions

\[\begin{align*}P &:= N(\mu_1, \sigma^2_1) \\ Q &:= N(\mu_2, \sigma^2_2) \end{align*}\]it follows that

\[KL(P \ \vert\vert \ Q) = \log \frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} - \frac{1}{2}\]Applying this fact to $L_t$, we see that,

\[\begin{align*}L_t := KL\left(N(A; B, C) \ \vert\vert \ N(A; B, C)\right)\end{align*}\]

While this idea of learning a denoising model that reverses a diffusion process may be intuitive at a high level, one may want a more rigorous theoretical justification for this framework. That is, what is the justification for fitting $p_\theta(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0)$ to $q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0)$? And moreover, if we are interested in generating realistic objects – that is, samples of $\boldsymbol{x}_0$ – then why do these two distributions *condition* on $\boldsymbol{x}_0$?

I’ve found four perspectives from which to understand the theoretical justification behind these models:

- As implicitly learning to fit $q(\boldsymbol{x}_0)$
- As breaking up a difficult problem into many easier problems
- As maximum-likelihood estimation
- As score-matching

The first two of these perspectives are less rigorous, but provide some high-level intuition. The latter two are more rigorous. Let’s dig in.

We can gain some high-level intuition into why this method of learning to reverse diffusion will lead us to a distribution $p_\theta(\boldsymbol{x}_0)$ that resembles $q(\boldsymbol{x}_0)$ by looking again at the posterior distribution:

\[\begin{align*}q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t) &= \frac{q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1})q(\boldsymbol{x}_{t-1})}{q(\boldsymbol{x}_t)} \\ &= \frac{q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1})q(\boldsymbol{x}_{t-1})}{\int_{\boldsymbol{x}_{t-1},\dots,\boldsymbol{x}_0} q(\boldsymbol{x}_0)\prod_{i=1}^{t} q(\boldsymbol{x}_i \mid \boldsymbol{x}_{i-1}) \ d\boldsymbol{x}_{t-1}\dots d\boldsymbol{x}_{0}}\end{align*}\]Again, notice how this distribution requires knowing $q(\boldsymbol{x}_0)$. This makes intuitive sense: in order to transform pure noise, $\boldsymbol{x}_T$, into a “sharp”, noiseless object $\boldsymbol{x}_0$, we need to know what real objects look like! Now, in an attempt to fit $p_\theta(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0)$ to $q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0)$, it follows that $p_{\theta}(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_{t})$ will need to match $q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_{t})$. This very act of learning to approximate $q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_{t})$ using a surrogate distribution $p_{\theta}(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_{t})$ will, in an implicit way, learn about the distribution $q(\boldsymbol{x}_0)$! Said differently, $p_{\theta}(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_{t})$ *must* learn about $q(\boldsymbol{x}_0)$ in order to approximate $q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_{t})$ effectively.

Another reason why diffusion models tend to perform better than other methods, such as variational autoencoders, is that diffusion models break up a difficult problem into a series of easier problems. That is, unlike variational autoencoders, where we train a model to produce an object all at once, in diffusion models, we train the model to produce the object step-by-step. Intuitively, we train a model to “sculpt” an object out of noise in a step-wise fashion rather than generate the object in one fell swoop.

This step-wise approach is advantageous because it enables the model to learn features of objects at different levels of resolution. At the beginning of the reverse diffusion process (i.e., the sampling process), the model identifies broad, vague features of an object within the noise. At later steps of the reverse diffusion process, it fills in smaller details of the object by removing the last remaining noise.

Recall our goal was to fit $p_\theta(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0)$ to $q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0)$ by minimizing their KL-divergence, which as we showed, could be accomplished implicitly by maximizing the ELBO:

\[\begin{align*}\hat{\theta} &:= \text{arg max}_\theta \ \text{ELBO}(\theta) \\ &= \text{arg max}_\theta \ E_{\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0 \sim q}\left[ \log\frac{p_\theta (\boldsymbol{x}_{0:T}) }{q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0) } \right]\end{align*}\]Notice too that if we maximize the ELBO, we not only minimize the KL-divergence, but we also implicitly maximize a lower bound of the log-likelihood, $\log p_\theta(\boldsymbol{x})$. That is, we see that

\[\begin{align*} \log p_\theta(\boldsymbol{x}) &= KL( q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0) \ \vert\vert \ p_\theta(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0)) + \underbrace{E_{\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0 \sim q} \left[ \log\frac{p_\theta (\boldsymbol{x}_{0:T}) }{q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0) } \right]}_{\text{ELBO}} \\ &\geq \underbrace{E_{\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0 \sim q}\left[ \log\frac{p_\theta (\boldsymbol{x}_{0:T}) }{q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0) } \right]}_{\text{ELBO}} \ \ \text{Because KL-divergence is non-negative} \end{align*}\]This idea is depicted schematically below (this figure is adapted from this blog post by Jakub Tomczak):

Here, $\theta^*$ represents the maximum likelihood estimate of $\theta$ and $\hat{\theta}$ represents the value for $\theta$ that maximizes the ELBO. If this lower bound is tight, $\hat{\theta}$ will be close to $\theta^*$. Although in most cases it is difficult to know with certainty how tight this lower bound is, in practice, this strategy of maximizing the ELBO leads to good results at estimating $\theta^*$.

Another motivation for diffusion models lies in their connection to score matching models. We will not go into great depth in this blog post (we will merely touch upon it), but as it turns out, the ELBO can be worked into a form that can be viewed as an objective function that estimates the *score function* of the true, real-world distribution $q(\boldsymbol{x}_0)$.

As a brief review, the *score function*, $s_q(\boldsymbol{x})$, of the distribution $q(\boldsymbol{x})$ is simply,

$s_q(\boldsymbol{x}) := \nabla_{\boldsymbol{x}} \log q(\boldsymbol{x})$

That is, it is the gradient of the log-density function, $\log q(\boldsymbol{x})$, with respect to the data. Below, we depict a hypothetical density function, $q(\boldsymbol{x})$, and the vector field defined by $\nabla_{\boldsymbol{x}} \log q(\boldsymbol{x})$ below it:

Stated more succinctly, by maximizing the ELBO with respect to $\theta$ (that is, a lower bound of the log-likelihood), we are also implicitly fitting an estimated score function $s_\theta(\boldsymbol{x})$ to the real score function $s_q(\boldsymbol{x})$. We will make this connection more explicit later in the blog post.
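To ground the definition of the score function, here is a tiny worked example for a case where it is known exactly: a 1-dimensional Gaussian $N(\mu, \sigma^2)$, whose score is $-(x-\mu)/\sigma^2$. The function names are my own, and the finite-difference version is included only as a numerical check of the analytic formula.

```python
import math

def gaussian_log_density(x, mu, sigma):
    """log N(x; mu, sigma^2) for scalar x."""
    return -0.5 * math.log(2.0 * math.pi * sigma**2) - (x - mu) ** 2 / (2.0 * sigma**2)

def gaussian_score(x, mu, sigma):
    """Analytic score: d/dx log N(x; mu, sigma^2) = -(x - mu) / sigma^2."""
    return -(x - mu) / sigma**2

def finite_difference_score(x, mu, sigma, h=1e-5):
    """Central-difference approximation to the derivative of the log-density."""
    upper = gaussian_log_density(x + h, mu, sigma)
    lower = gaussian_log_density(x - h, mu, sigma)
    return (upper - lower) / (2.0 * h)

analytic = gaussian_score(0.3, mu=1.0, sigma=2.0)
numeric = finite_difference_score(0.3, mu=1.0, sigma=2.0)
print(analytic, numeric)
```

Note that the score points toward the mode: it is positive to the left of $\mu$, negative to the right, and zero at $\mu$ itself.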

Finally, it will turn out that we can view the process of reversing the diffusion process to sample from $p_\theta(\boldsymbol{x}_0)$ as a variant of sampling via Langevin dynamics – a stochastic method that enables one to sample from an arbitrary distribution by following the gradients defined by the score function.
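As a rough illustration of Langevin dynamics itself (not of the DDPM sampler, which is only a variant of it), the sketch below runs the unadjusted Langevin update $x \leftarrow x + \eta \, s(x) + \sqrt{2\eta}\, z$ with $z \sim N(0, 1)$ on a toy target whose score is known exactly. The target $N(2, 1)$, the step size $\eta = 0.05$, and all names here are assumptions chosen for the example.

```python
import math
import random

def langevin_sample(score, x_init, step_size, num_steps, rng):
    """Run unadjusted Langevin dynamics from x_init and return the final state."""
    x = x_init
    for _ in range(num_steps):
        noise = rng.gauss(0.0, 1.0)
        # Drift along the score, plus injected Gaussian noise.
        x = x + step_size * score(x) + math.sqrt(2.0 * step_size) * noise
    return x

MU, SIGMA = 2.0, 1.0  # toy target distribution N(MU, SIGMA^2)

def target_score(x):
    """Score of the toy Gaussian target."""
    return -(x - MU) / SIGMA**2

rng = random.Random(0)
samples = [langevin_sample(target_score, 0.0, 0.05, 200, rng) for _ in range(1000)]
sample_mean = sum(samples) / len(samples)
sample_var = sum((s - sample_mean) ** 2 for s in samples) / len(samples)
print(sample_mean, sample_var)
```

Even though every chain starts at zero, the samples end up concentrated around the target mean. Because the step size is finite, the sample variance is slightly inflated relative to the target.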

Now that we have previewed the theoretical foundation behind diffusion models, let’s now dig into the specifics of the model and see how diffusion models implement these various strategies of estimation.

In this section, we will walk through a relatively simple implementation of a diffusion model in PyTorch and apply it to the MNIST dataset of hand-written digits. I used the following GitHub repositories as guides:

- https://github.com/hojonathanho/diffusion
- https://github.com/cloneofsimo/minDiffusion
- https://github.com/bot66/MNISTDiffusion/tree/main
- https://github.com/usuyama/pytorch-unet

My goal was to implement a small model (both small in complexity and size) that would generate realistic digits. In the following sections, I will detail each component and show some of the model’s outputs. All code implementing the model can be found on Google Colab.

**Using a U-Net with ResNet blocks to predict the noise**

For the noise-model, I used a U-Net with ResNet-like convolutional blocks – that is, convolutional layers with skip connections between them. This architecture is depicted below:

Code for my U-Net implementation is found in the Appendix to this blog post as well as on Google Colab.

**Representing the timestep using a time-embedding**

As we discussed, the noise model conditions on the timestep, $t$. Thus, we need a way for the neural network to take as input, and utilize, the timestep. To do this, Ho, Jain, and Abbeel (2020) borrowed an idea from the transformer model originally conceived by Vaswani *et al.* (2023). Specifically, each timestep is mapped to a specific, sinusoidal *embedding* vector and this vector is added element-wise to certain layers of the neural network. The code for generating these embeddings is presented in the Appendix to this post. A heatmap depicting these embeddings is shown below:

Recall that at every iteration of the training loop, we sample some objects in the training set (a minibatch) and sample a timestep for each object. Below, we depict a single timestep embedding for a given timestep $t$. The U-Net implementation takes this time embedding, passes it through a feed-forward neural network, re-shapes the vector into a tensor, and then adds it to the input of the up-sampling blocks. This process is depicted below:
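For a sense of what a sinusoidal timestep embedding looks like in code, here is a minimal plain-Python sketch. It is not the implementation from the Appendix: the function name, the layout (all sines followed by all cosines), and the `max_period` value of 10000 are assumptions in the spirit of the transformer positional encoding.

```python
import math

def sinusoidal_embedding(t, dim, max_period=10000.0):
    """Map an integer timestep t to a length-`dim` vector of sines and cosines."""
    if dim % 2 != 0:
        raise ValueError("dim must be even")
    half = dim // 2
    # Frequencies decay geometrically from 1 toward 1 / max_period.
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    return [math.sin(t * f) for f in freqs] + [math.cos(t * f) for f in freqs]

emb = sinusoidal_embedding(t=25, dim=8)
print(emb)
```

Each timestep gets a distinct, bounded vector. In the model described above, this vector would be passed through a feed-forward network before being added to the feature maps.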

**Example outputs from the model**

Once we’ve trained the model and implemented the sampling algorithm, we can generate new MNIST digits! (See Appendix for the code used to generate new images). Below is an example of the model generating a “3”. As we examine the image across timesteps of the reverse diffusion process, we see it being successfully transformed from noise into a clear image!

Here is a sample of hand-selected images of digits output by the model:

The model also output many nonsensical images. While this may not be desirable, I find it interesting that the model homed in on patterns that are “digit-like”. These not-quite digits look like symbols from an alien language:

A better model may output fewer of these nonsensical “digits”; however, I think this demonstrates how these generative models can be used for creative tasks. That is, the model successfully modeled “digit-like patterns”, which in some cases led it to produce nonsensical digits that still look visually interesting (well, interesting to me at least). It did this by assembling these digit-like patterns in new, interesting ways.

Much of my understanding of this material came from the following resources:

- These lecture notes by David I. Inouye
- This blog post by Lilian Weng
- This blog post by Param Hanji
- This blog post by Angus Turner
- This blog post by Yang Song
- This YouTube lecture at UC, Berkeley
- This YouTube lecture by Dominic Rampas

**Note 1:** By the Markov property of the forward diffusion process, it holds that $q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1}) = q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1}, \boldsymbol{x}_0)$.

**Note 2:** Apply Bayes theorem:

**Note 3:**

**Note 4:**
\(\begin{align*}E_{\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0 \sim q}\left[\log \frac{p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t) }{ q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0)} \right] &= E_{\boldsymbol{x}_{t}, \boldsymbol{x}_{t-1} \mid \boldsymbol{x}_0 \sim q}\left[\log \frac{p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t) }{ q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0)} \right] \\ &= \iint q(\boldsymbol{x}_{t-1}, \boldsymbol{x}_t \mid \boldsymbol{x}_0) \log \frac{p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t) }{ q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0)} \ d\boldsymbol{x}_{t-1} d\boldsymbol{x}_t \\ &= \iint q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0) q(\boldsymbol{x}_t \mid \boldsymbol{x}_0) \log \frac{p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t) }{ q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0)} \ d\boldsymbol{x}_{t-1} d\boldsymbol{x}_t \\ &= \int q(\boldsymbol{x}_t \mid \boldsymbol{x}_0) \left[ \int q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0) \log \frac{p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t) }{ q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0)} \ d\boldsymbol{x}_{t-1} \right] d\boldsymbol{x}_t \\ &= -\int q(\boldsymbol{x}_t \mid \boldsymbol{x}_0) KL(q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0) \ \vert\vert \ p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t)) d\boldsymbol{x}_t \\ &= -E_{\boldsymbol{x}_t \mid \boldsymbol{x}_0 \sim q} \left[ KL(q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0) \ \vert\vert \ p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t)) \right]\end{align*}\)

Let,

\[\boldsymbol{x}_t \sim N(\boldsymbol{\mu}, \boldsymbol{I})\]for some mean $\boldsymbol{\mu}$. For the next timestep, we have

\[\boldsymbol{x}_{t+1} \sim N(a\boldsymbol{x}_t, \beta \boldsymbol{I})\]where $a$ is some constant that scales the mean given by $\boldsymbol{x}_t$. We seek a value of $a$ such that $\text{Var}(\boldsymbol{x}_{t+1}) = 1$. To find this value, we use the law of total variance:

\[\begin{align*}\text{Var}(\boldsymbol{x}_{t+1}) &= E\left[\text{Var}(\boldsymbol{x}_{t+1} \mid \boldsymbol{x}_t ) \right] + \text{Var}\left( E\left[\boldsymbol{x}_{t+1} \mid \boldsymbol{x}_t \right]\right) \\ &= E[\beta] + \text{Var}(a\boldsymbol{x}_t) \\ &= \beta + a^2\text{Var}(\boldsymbol{x}_t) \\ &= \beta + a^2\end{align*}\]where the last step uses the fact that $\boldsymbol{x}_t$ has unit variance. Now, if we fix $\text{Var}(\boldsymbol{x}_{t+1}) = 1$, it follows that:

\[\begin{align*}&1 = \beta + a^2 \\ \implies &a = \sqrt{1-\beta}\end{align*}\]We start with $q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1})$. Recall it is given by,

\[q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1}) := N(\boldsymbol{x}_t; \sqrt{1-\beta_t}\boldsymbol{x}_{t-1}, \beta_t \boldsymbol{I})\]Because this is a normal distribution, we can generate a sample

\[\boldsymbol{x}_t \sim q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1})\]by first sampling $\epsilon_{t-1}$ from a standard normal, $N(\boldsymbol{0}, \boldsymbol{I})$, and then transforming it into $\boldsymbol{x}_t$ via,

\[\begin{align*}\boldsymbol{x}_t &= \sqrt{1-\beta_t}\boldsymbol{x}_{t-1} + \sqrt{\beta_t}\epsilon_{t-1} \\ &= \sqrt{\alpha_t}\boldsymbol{x}_{t-1} + \sqrt{1-\alpha_t}\epsilon_{t-1} \end{align*}\]where $\alpha_t := 1 - \beta_t$ (which will make the notation easier going forward).
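As a quick numerical sanity check of the scaling derived above, here is a minimal sketch that applies the reparameterized step to scalar samples using only Python's standard library. The value `beta = 0.02`, the starting mean of 3.0, and the sample count are arbitrary choices for illustration, and `forward_step` and `variance` are helper names introduced here rather than part of the model code.

```python
import math
import random

def forward_step(x_prev, beta, rng):
    """Apply one forward diffusion step to a scalar via reparameterization."""
    eps = rng.gauss(0.0, 1.0)  # eps ~ N(0, 1)
    return math.sqrt(1.0 - beta) * x_prev + math.sqrt(beta) * eps

def variance(values):
    """Population variance of a list of floats."""
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values) / len(values)

rng = random.Random(0)
beta = 0.02  # illustrative value, not taken from the model

# Unit-variance samples with an arbitrary, nonzero mean
x_prev = [rng.gauss(3.0, 1.0) for _ in range(100_000)]
x_next = [forward_step(x, beta, rng) for x in x_prev]

var_prev = variance(x_prev)
var_next = variance(x_next)
print(f"Var before step: {var_prev:.3f}, Var after step: {var_next:.3f}")
```

Both variances should come out close to 1, which is what scaling the mean by $\sqrt{1-\beta}$ is designed to achieve.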

Notice that this transformation relies on $\boldsymbol{x}_{t-1}$, which is a sample from $q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_{t-2})$. From this observation, we realize there is a way to sample $\boldsymbol{x}_t$ not from $q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1})$, but rather from $q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-2})$. Specifically, we can generate *two* samples from a standard normal distribution,

\[\epsilon_{t-1}, \epsilon_{t-2} \sim N(\boldsymbol{0}, \boldsymbol{I})\]Then, we can generate a sample

\[\boldsymbol{x}_t \sim q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-2})\]via the following transformation of $\epsilon_{t-1}$ and $\epsilon_{t-2}$:

\[\begin{align*}\boldsymbol{x}_t &= \sqrt{\alpha_t}\boldsymbol{x}_{t-1} + \sqrt{1-\alpha_t}\epsilon_{t-1} \\ &= \sqrt{\alpha_t}\left(\sqrt{\alpha_{t-1}}\boldsymbol{x}_{t-2} + \sqrt{1-\alpha_{t-1}}\epsilon_{t-2}\right) + \sqrt{1-\alpha_t}\epsilon_{t-1} \\ &= \sqrt{\alpha_t}\sqrt{\alpha_{t-1}}\boldsymbol{x}_{t-2} + \sqrt{\alpha_t}\sqrt{1-\alpha_{t-1}} \epsilon_{t-2} + \sqrt{1-\alpha_{t}}\epsilon_{t-1} \\ &=\sqrt{\alpha_t}\sqrt{\alpha_{t-1}}\boldsymbol{x}_{t-2} + \sqrt{\alpha_t(1-\alpha_{t-1}) + (1-\alpha_{t})}\ \epsilon_{t, t-1} \\ &= \sqrt{\alpha_t\alpha_{t-1}}\boldsymbol{x}_{t-2} + \sqrt{1-\alpha_t \alpha_{t-1}}\ \epsilon_{t, t-1}\end{align*}\]where $\epsilon_{t, t-1}$ is a sample of $N(\boldsymbol{0}, \boldsymbol{I})$. Here, we used the fact that if we have two independent random variables $X$ and $Y$ such that,

\[\begin{align*}X &\sim N(0, \sigma_X^2) \\ Y &\sim N(0, \sigma_Y^2) \end{align*}\]Then it follows that,

\[X + Y \sim N(0, \sigma_X^2 + \sigma_Y^2)\]though we won’t prove this fact here.
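Although we skip the proof, the fact is easy to check numerically. The sketch below, using only the standard library, combines the two noise terms from the derivation and compares the resulting standard deviation with $\sqrt{1-\alpha_t\alpha_{t-1}}$. The values of `alpha_t` and `alpha_prev` are arbitrary and chosen only for illustration.

```python
import math
import random

alpha_t, alpha_prev = 0.98, 0.99  # illustrative values, not from the model

# Standard deviations of the two independent noise terms
std_a = math.sqrt(alpha_t) * math.sqrt(1.0 - alpha_prev)  # multiplies eps_{t-2}
std_b = math.sqrt(1.0 - alpha_t)                          # multiplies eps_{t-1}

# Variances of independent zero-mean Gaussians add
combined_std = math.sqrt(std_a ** 2 + std_b ** 2)
closed_form_std = math.sqrt(1.0 - alpha_t * alpha_prev)

# Empirical check by drawing the two noise terms and summing them
rng = random.Random(0)
n = 100_000
samples = [std_a * rng.gauss(0.0, 1.0) + std_b * rng.gauss(0.0, 1.0) for _ in range(n)]
mean = sum(samples) / n
empirical_std = math.sqrt(sum((s - mean) ** 2 for s in samples) / n)

print(f"combined: {combined_std:.5f}, closed form: {closed_form_std:.5f}, "
      f"empirical: {empirical_std:.5f}")
```

Note that it is the variances that add, not the standard deviations, which is why the combined coefficient is a square root of a sum rather than a sum of square roots.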

Now, following the same logic above, we can generate a sample,

\[\boldsymbol{x}_t \sim q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-3})\]via

\[\boldsymbol{x}_t = \sqrt{\alpha_t\alpha_{t-1}\alpha_{t-2}}\boldsymbol{x}_{t-3} + (\sqrt{1-\alpha_t \alpha_{t-1} \alpha_{t-2}})\epsilon_{t, t-1, t-2}\]where $\epsilon_{t, t-1, t-2} \sim N(\boldsymbol{0}, \boldsymbol{I})$. If we follow this pattern all the way down to $t=0$, we see that we can generate a sample,

\[\boldsymbol{x}_t \sim q(\boldsymbol{x}_t \mid \boldsymbol{x}_0)\]via

\[\begin{align*}\boldsymbol{x}_t &= \sqrt{\prod_{i=1}^t \alpha_i}\boldsymbol{x}_0 + \sqrt{1-\prod_{i=1}^t \alpha_i}\ \epsilon_{t, t-1, \dots, 0} \\ &= \sqrt{\bar{\alpha}_t}\boldsymbol{x}_0 + \sqrt{1-\bar{\alpha}_t}\ \epsilon_{t, t-1, \dots, 0}\end{align*}\]where $\bar{\alpha}_t := \prod_{i=1}^t \alpha_i$. Thus, we see that,

\[q(\boldsymbol{x}_t \mid \boldsymbol{x}_0) = N\left(\boldsymbol{x}_t; \sqrt{\bar{\alpha}_t}\boldsymbol{x}_0, \left(1-\bar{\alpha}_t \right) \boldsymbol{I}\right)\]This is the functional form of the density function of a normal distribution. Thus, we see that,

\[q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0) = N\left(\boldsymbol{x}_{t-1}; \frac{\sqrt{\alpha_t} \left( 1 - \bar{\alpha}_{t-1} \right) }{1-\bar{\alpha}_t} \boldsymbol{x}_t + \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_{t}}\boldsymbol{x}_0, \frac{\beta_t\left(1-\bar{\alpha}_{t-1}\right)}{1-\bar{\alpha}_t}\boldsymbol{I} \right)\]**Note 1:** Apply Bayes Theorem

**Note 2:** Throughout this derivation, we will only consider terms that contain $\boldsymbol{x}_{t-1}$.

**Note 3:** Here we remove terms that do not involve $\boldsymbol{x}_{t-1}$ by using the following fact: given a term, $f(\boldsymbol{x}_{t-1})$, and a constant term, $C$, it follows that:

**Note 4:** Here we complete the square and use the fact that:

In our case,

\[\begin{align*}a &:= \frac{\alpha_t}{\beta_t} + \frac{1}{1-\bar{\alpha}_{t-1}} \\ b &:= \frac{2\sqrt{\alpha_t}\boldsymbol{x}_t }{\beta_t} + \frac{2 \sqrt{\bar{\alpha}_{t-1}}\boldsymbol{x}_0 }{1 - \bar{\alpha}_{t-1}} \end{align*}\]Note that we can disregard the term, $\left(c - \frac{b^2}{4a}\right)$, since this is a constant with respect to $\boldsymbol{x}_{t-1}$ and it gets “swallowed” by the $\propto$ as described in Note 3.

Moreover, after completing the square, we see that this is the functional form of a normal distribution where we have annotated the mean, $\mu$, and reciprocal of the variance, $1 / \sigma^2$.
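The result of completing the square can be checked numerically. The sketch below assumes the exponent has the form $-\frac{1}{2}(a x^2 - b x + c)$ with $a$ and $b$ as defined above, so that the mean is $b / (2a)$ and the variance is $1/a$, and compares these with the posterior mean and variance as given by Ho, Jain, and Abbeel (2020). All numeric values are arbitrary scalars chosen for illustration.

```python
import math

# Illustrative scalar values (assumptions for this check only)
alpha_t = 0.98
alpha_bar_prev = 0.60                   # \bar{alpha}_{t-1}
alpha_bar_t = alpha_t * alpha_bar_prev  # \bar{alpha}_t
beta_t = 1.0 - alpha_t
x_t, x_0 = 0.7, -0.3

# Coefficients obtained when completing the square
a = alpha_t / beta_t + 1.0 / (1.0 - alpha_bar_prev)
b = (2.0 * math.sqrt(alpha_t) * x_t / beta_t
     + 2.0 * math.sqrt(alpha_bar_prev) * x_0 / (1.0 - alpha_bar_prev))

mean_from_square = b / (2.0 * a)
var_from_square = 1.0 / a

# Closed-form posterior mean and variance
mean_closed = (math.sqrt(alpha_t) * (1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t) * x_t
               + math.sqrt(alpha_bar_prev) * beta_t / (1.0 - alpha_bar_t) * x_0)
var_closed = beta_t * (1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t)

print(mean_from_square, mean_closed)
print(var_from_square, var_closed)
```

The two means and the two variances agree to floating-point precision, which relies on the identity $\alpha_t(1-\bar{\alpha}_{t-1}) + \beta_t = 1 - \bar{\alpha}_t$.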

**Note 5:**

**Note 6:**

In this section, we will walk through all of the code used to implement a diffusion model. The full code can be run on Google Colab. We will start with importing the required packages:

```
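import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
import torch
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
# Needed further below for loading and batching the MNIST data
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST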
```

Next, we implement the U-Net neural network, which serves as the noise model, that is, the model used to predict the noise in an image that has undergone diffusion. To implement the U-Net, we define three classes: a `UNetDownBlock` class, which represents a set of layers on the downward portion of the U-Net, a `UNetUpBlock` class, which represents a set of layers on the upward portion of the U-Net, and a `UNet` class, which represents the full neural network:

```
class UNetDownBlock(nn.Module):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=3,
            pad_maxpool=0,
            normgroups=4
        ):
        super(UNetDownBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            padding=1
        )
        self.groupnorm1 = nn.GroupNorm(
            normgroups,
            out_channels
        )
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1
        )
        self.groupnorm2 = nn.GroupNorm(
            normgroups,
            out_channels
        )
        self.relu2 = nn.ReLU()
        self.conv3 = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1
        )
        self.groupnorm3 = nn.GroupNorm(
            normgroups,
            out_channels
        )
        self.relu3 = nn.ReLU()
        self.maxpool = nn.MaxPool2d(
            2, padding=pad_maxpool
        )

    def forward(self, x):
        # First convolution
        x = self.conv1(x)
        x = self.groupnorm1(x)
        x_for_skip = self.relu1(x)
        # Second convolution
        x = self.conv2(x_for_skip)
        x = self.groupnorm2(x)
        x = self.relu2(x)
        # Third convolution
        x = self.conv3(x)
        x = self.groupnorm3(x)
        # Skip connection
        x = x + x_for_skip
        x = self.relu3(x)
        # Halve the spatial resolution
        x = self.maxpool(x)
        return x


class UNetUpBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_dim, normgroups=4):
        super(UNetUpBlock, self).__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        # Convolution 1
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, padding=1
        )
        self.groupnorm1 = nn.GroupNorm(normgroups, out_channels)
        self.relu1 = nn.ReLU()
        # Convolution 2
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, padding=1
        )
        self.groupnorm2 = nn.GroupNorm(normgroups, out_channels)
        self.relu2 = nn.ReLU()
        # Convolution 3
        self.conv3 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, padding=1
        )
        self.groupnorm3 = nn.GroupNorm(normgroups, out_channels)
        self.relu3 = nn.ReLU()
        # Parameters to scale and shift the time embedding
        self.time_mlp = nn.Linear(time_dim, time_dim)
        self.time_relu = nn.ReLU()

    def forward(self, x, x_down, t_embed):
        x_up = self.upsample(x)
        # Concatenate along the channel dimension
        x = torch.cat([x_down, x_up], dim=1)
        # Cut embedding to be the size of the current channels
        t_embed = t_embed[:, :x.shape[1]]
        # Enable the neural network to modify the time-embedding
        # as it needs to
        t_embed = self.time_mlp(t_embed)
        t_embed = self.time_relu(t_embed)
        t_embed = t_embed[:, :, None, None].expand(x.shape)
        # Add time-embedding to input.
        x = x + t_embed
        # Convolution 1
        x = self.conv1(x)
        x = self.groupnorm1(x)
        x_for_skip = self.relu1(x)
        # Convolution 2
        x = self.conv2(x_for_skip)
        x = self.groupnorm2(x)
        x = self.relu2(x)
        # Convolution 3
        x = self.conv3(x)
        x = self.groupnorm3(x)
        # Skip connection
        x = x + x_for_skip
        x = self.relu3(x)
        return x


class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        # Down blocks
        self.down1 = UNetDownBlock(1, 4, normgroups=1)
        self.down2 = UNetDownBlock(4, 10, normgroups=1)
        self.down3 = UNetDownBlock(10, 20, normgroups=2)
        self.down4 = UNetDownBlock(20, 40, normgroups=4)
        # Convolutional layer at the bottom of the U-Net
        self.bottom_conv = nn.Conv2d(
            40, 40, kernel_size=3, padding=1
        )
        self.bottom_groupnorm = nn.GroupNorm(4, 40)
        self.bottom_relu = nn.ReLU()
        # Up blocks
        self.up1 = UNetUpBlock(60, 20, 60, normgroups=2) # down4 channels + down3 channels
        self.up2 = UNetUpBlock(30, 10, 30, normgroups=1) # down2 channels + up1 channels
        self.up3 = UNetUpBlock(14, 5, 14, normgroups=1)  # down1 channels + up2 channels
        self.up4 = UNetUpBlock(6, 5, 6, normgroups=1)    # input channels + up3 channels
        # Final convolution to produce output. This layer injects negative
        # values into the output.
        self.final_conv = nn.Conv2d(
            5, 1, kernel_size=3, padding=1
        )

    def forward(self, x, t_emb):
        """
        Parameters
        ----------
        x: Input tensor representing an image
        t_emb: The time-embedding vector for the current timestep
        """
        # Pad the input so that it is 32x32. This enables downsampling to
        # 16x16, then to 8x8, then to 4x4, and finally to 2x2 at the bottom
        # of the "U"
        x = F.pad(x, (2,2,2,2), 'constant', 0)
        # Down-blocks of the U-Net compress the image down to a smaller
        # representation
        x_d1 = self.down1(x)
        x_d2 = self.down2(x_d1)
        x_d3 = self.down3(x_d2)
        x_d4 = self.down4(x_d3)
        # Bottom layer performs a final transformation on the compressed
        # representation before re-inflation. Each step is applied to the
        # output of the previous one.
        x_bottom = self.bottom_conv(x_d4)
        x_bottom = self.bottom_groupnorm(x_bottom)
        x_bottom = self.bottom_relu(x_bottom)
        # Up-blocks re-inflate the compressed representation back to the original
        # image size while taking as input various representations produced in the
        # down-sampling steps
        x_u1 = self.up1(x_bottom, x_d3, t_emb)
        x_u2 = self.up2(x_u1, x_d2, t_emb)
        x_u3 = self.up3(x_u2, x_d1, t_emb)
        x_u4 = self.up4(x_u3, x, t_emb)
        # Final convolutional layer. Introduces negative values.
        x_u4 = self.final_conv(x_u4)
        # Remove initial pads to produce a 28x28 MNIST digit
        x_u4 = x_u4[:,:,2:-2,2:-2]
        return x_u4
```

Next, we will implement a function that will generate the timestep embeddings. Below is an adaptation of the time embedding function by Ho, Jain, and Abbeel from their GitHub repository, https://github.com/hojonathanho/diffusion, translated from TensorFlow to PyTorch:

```
def get_timestep_embedding(timesteps, embedding_dim):
    """
    Translated from TensorFlow to PyTorch from the original Diffusion implementation
    by Ho et al. in https://github.com/hojonathanho/diffusion
    """
    assert len(timesteps.shape) == 1  # and timesteps.dtype == torch.int32
    half_dim = embedding_dim // 2
    emb = np.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = timesteps[:, None].to(torch.float32) * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1))
    assert emb.shape == (timesteps.shape[0], embedding_dim)
    return emb
```

The function accepts a one-dimensional tensor of timesteps (here, the integers from 0 to $T-1$) and the embedding dimension. Similar to Ho, Jain, and Abbeel, I used 1,000 timesteps (as we will see in the code that follows). In my model, the largest number of channels that the embedding is added to is 60 (the number of channels entering the first up-sampling block, just above the bottom of the U-Net), so the embedding dimension is 60. This function returns a matrix with one row per timestep and one column per dimension of the embedding.
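To make the formula concrete, here is a pure-Python sketch that computes a single row of the embedding matrix for an even embedding dimension. It is a re-expression of the same arithmetic for illustration only; `sinusoidal_embedding` is a name introduced here and is not used by the model code.

```python
import math

def sinusoidal_embedding(t, embedding_dim):
    """Embedding for one timestep t, assuming embedding_dim is even."""
    half_dim = embedding_dim // 2
    scale = math.log(10000) / (half_dim - 1)
    # Geometrically spaced frequencies, from 1 down to 1/10000
    freqs = [math.exp(-scale * i) for i in range(half_dim)]
    # First half holds sines, second half holds cosines
    return [math.sin(t * f) for f in freqs] + [math.cos(t * f) for f in freqs]

row0 = sinusoidal_embedding(0, 60)
row5 = sinusoidal_embedding(5, 60)
print(len(row0))                 # number of embedding dimensions
print(row0[:3], row0[30:33])     # sines and cosines at t = 0
print(row5[:3])
```

At $t = 0$ every sine entry is zero and every cosine entry is one; as $t$ grows, the low-index (high-frequency) entries oscillate quickly while the high-index entries change slowly, which is what gives each timestep a distinct vector.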

Next, we will write a function that will produce a linear variance schedule. Given a minimum variance, maximum variance, and number of timesteps, it will create a linear interpolation between the max and min over the given number of timesteps:

```
def linear_variance_schedule(min: float, max: float, T: int):
    """
    min: minimum value for beta
    max: maximum value for beta
    T: number of timesteps
    """
    betas = torch.arange(0, T) / T
    betas *= max - min
    betas += min
    return betas
```
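To get a feel for what this schedule does to an image, the sketch below recomputes the same interpolation with plain Python floats and tracks $\bar{\alpha}_t$, the running product of $1 - \beta_i$. The endpoint values 1e-4 and 0.02 and $T = 1000$ match the training configuration used in this post; `linear_betas` is a helper name introduced here for illustration.

```python
import math

def linear_betas(beta_min, beta_max, T):
    """Same linear interpolation as the schedule function, with plain floats."""
    return [beta_min + (beta_max - beta_min) * (t / T) for t in range(T)]

betas = linear_betas(1e-4, 0.02, 1000)

# alpha_bar_t is the running product of (1 - beta_i)
alpha_bar = []
running = 1.0
for beta in betas:
    running *= 1.0 - beta
    alpha_bar.append(running)

# Weight on the original image vs. weight on the noise at a few timesteps
for t in (1, 250, 500, 1000):
    signal = math.sqrt(alpha_bar[t - 1])
    noise = math.sqrt(1.0 - alpha_bar[t - 1])
    print(f"t={t:4d}  sqrt(alpha_bar)={signal:.4f}  sqrt(1 - alpha_bar)={noise:.4f}")
```

The weight on the original image shrinks toward zero by the final timestep, so that $\boldsymbol{x}_T$ is very nearly a sample of pure standard normal noise.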

Now that we have defined our UNet model and functions for generating timestep embeddings and the variance schedule, let’s begin to construct and train the model. We will start by setting our parameters for the training process. We train the model for 300 epochs using a minibatch size of 128. We use a linear variance schedule spanning from a minimum variance of 1e-4 to a maximum variance of 0.02 as per https://github.com/cloneofsimo/minDiffusion. Specifically, the variables for storing these parameters are shown below:

```
# Parameters
EPOCHS = 300
T = 1000
LEARNING_RATE = 1e-4
BATCH_SIZE = 128
MIN_VARIANCE = 1e-4
MAX_VARIANCE = 0.02
DEVICE = 'cuda'
```

Next, let’s load the data. We will use PyTorch’s built-in functionality for loading the MNIST digits data. Note, this implementation *centers* the pixel values around zero (via the `transforms.Normalize((0.5), (0.5))` transformation applied when the dataset is loaded). That is, the raw MNIST data provides pixel values spanning from 0 to 1; however, this code rescales the data so that it spans -1 to 1 and is centered at zero. This follows the implementation provided by https://github.com/cloneofsimo/minDiffusion.

```
# Load dataset
dataset = MNIST(
    "./data",
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5), (0.5)) # By subtracting 0.5, we center the data
    ])
)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=1
)
```
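The effect of the normalization is easy to see on single pixel values. The sketch below mirrors the per-channel arithmetic of `transforms.Normalize`, which computes `(x - mean) / std`, together with the inverse map used later when displaying generated images. The function names `normalize` and `denormalize` are introduced here for illustration only.

```python
def normalize(pixel, mean=0.5, std=0.5):
    """Scalar version of the per-channel normalization: (x - mean) / std."""
    return (pixel - mean) / std

def denormalize(value):
    """Inverse map from [-1, 1] back to [0, 1], for display."""
    return (value + 1.0) / 2.0

for pixel in (0.0, 0.25, 0.5, 1.0):
    centered = normalize(pixel)
    print(f"{pixel:.2f} -> {centered:+.2f} -> {denormalize(centered):.2f}")
```

Black pixels (0) map to -1, white pixels (1) map to +1, and mid-gray maps to 0, which puts the data on the same scale as the standard normal noise it will be mixed with.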

Finally, let’s put this all together and train a model. The code below instantiates the variance schedule, time embeddings, and UNet model and then implements the training loop. The code is heavily commented for pedagogical purposes:

```
# Compute variance schedule
betas = linear_variance_schedule(MIN_VARIANCE, MAX_VARIANCE, T).to(DEVICE)
# Compute constants based on variance schedule
alphas = 1 - betas
onemalphas = 1 - alphas
alpha_bar = torch.exp(torch.cumsum(torch.log(alphas), dim=0))
sqrt_alphabar = torch.sqrt(alpha_bar)
onemalphabar = 1-alpha_bar
sqrt_1malphabar = torch.sqrt(1-alpha_bar)
# Instantiate the noise model, loss function, and optimizer
noise_model = UNet().to(DEVICE)
optimizer = optim.Adam(noise_model.parameters(), lr=LEARNING_RATE)
mse_loss = nn.MSELoss().to(DEVICE)
# Generate timestep embeddings. Note, the embedding dimension is hardcoded
# and based on the largest number of channels the embedding is added to in
# the U-Net noise model
time_embeddings = get_timestep_embedding(
    torch.arange(0,T),
    embedding_dim=60
).to(DEVICE)
# The training loop
epoch_losses = []
for epoch in range(EPOCHS):
    loss_sum = 0
    n_batchs = 0
    for b_i, (X_batch, _) in enumerate(dataloader):
        n_batchs += 1
        # Move batch to device
        X_batch = X_batch.to(DEVICE)
        # Sample noise for each pixel and image in this batch
        # B x 1 x M x N tensor where B is minibatch size, M is number
        # of rows in each image and N is number of columns in
        # each image
        eps = torch.randn_like(X_batch).to(DEVICE)
        # Get a random timepoint (between 1 and T) for each item in this batch
        # Length-B vector
        ts = torch.randint(
            1, T+1, size=(X_batch.shape[0],)
        ).to(DEVICE)
        # Grab the time-embeddings for each of these sampled timesteps
        # B x D matrix where B is minibatch size and D is time embedding
        # dimension
        t_embs = time_embeddings[ts-1].to(DEVICE)
        # Compute X_batch after adding noise via the diffusion process for each of
        # the items in the batch (at the sampled per-item timepoints, `ts`)
        # B x 1 x M x N tensor
        sqrt_alphabar_ts = sqrt_alphabar[ts-1]
        sqrt_1malphabar_ts = sqrt_1malphabar[ts-1]
        X_t = sqrt_alphabar_ts[:, None, None, None] * X_batch \
            + sqrt_1malphabar_ts[:, None, None, None] * eps
        # Predict the noise from our sample using the UNet
        # B x 1 x M x N tensor
        pred_eps = noise_model(X_t, t_embs)
        # Compute the loss between the real noise and predicted noise
        loss = mse_loss(eps, pred_eps)
        loss_sum += float(loss)
        # Update the weights in the U-Net via a step of gradient descent
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"Epoch: {epoch}. Mean loss: {loss_sum/n_batchs}")
    epoch_losses.append(loss_sum/n_batchs)
```

After this process finishes (it took a couple of hours to train in Google Colab running on an NVIDIA T4 GPU), we will have a trained model that we can use to generate new MNIST digits. To generate a new MNIST digit, we first sample white noise and then run the reverse diffusion process by iteratively applying our trained model. A function for generating images in this manner is shown below:

```
@torch.no_grad()  # gradients are not needed when sampling
def sample_from_model(T=999, show_img_mod=None, cmap='viridis'):
    # Initialize the image to white noise
    X_t = torch.randn(1, 1, 28, 28).to(DEVICE)
    # This samples according to Algorithm 2. Note that `t` is a zero-based
    # index here, so T=999 is the last of the 1,000 timesteps.
    for t in range(T, -1, -1):
        # Sample noise at every step except the final one (index 0)
        if t > 0:
            Z = torch.randn(1, 1, 28, 28).to(DEVICE)
        else:
            Z = torch.zeros(1, 1, 28, 28).to(DEVICE)
        # Get current time embedding
        t_emb = time_embeddings[t][None,:]
        # Predict the noise from the current image
        pred_eps = noise_model(X_t, t_emb)
        # Compute constants
        one_over_sqrt_alpha_t = 1 / torch.sqrt(alphas[t])
        pred_noise_scale = betas[t] / sqrt_1malphabar[t]
        sqrt_beta_t = torch.sqrt(betas[t])
        # Generate next image in the Markov chain
        X_t = (one_over_sqrt_alpha_t * (X_t - (pred_eps * pred_noise_scale))) \
            + (sqrt_beta_t * Z)
        # Show current image
        if show_img_mod is not None:
            if t % show_img_mod == 0:
                print(f"t = {t}")
                plt.imshow(
                    (X_t.detach().cpu().numpy().squeeze() + 1.) / 2.,
                    cmap=cmap
                )
                plt.xticks([])
                plt.yticks([])
                plt.show()
        if t == 0:
            print(f"t = {t}")
            plt.imshow(
                (X_t.detach().cpu().numpy().squeeze() + 1.) / 2.,
                cmap=cmap
            )
            plt.xticks([])
            plt.yticks([])
            plt.show()
    return X_t
```
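The update inside the loop uses the predicted noise rather than a predicted image. As a sanity check on why that works, the scalar sketch below shows that when the predicted noise equals the true noise, the mean used in the sampling step coincides with the mean of the posterior $q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0)$ as given by Ho, Jain, and Abbeel (2020). All numeric values are arbitrary and chosen only for illustration.

```python
import math

# Illustrative scalar values (assumptions for this check only)
alpha_t = 0.98
alpha_bar_prev = 0.60                   # \bar{alpha}_{t-1}
alpha_bar_t = alpha_t * alpha_bar_prev  # \bar{alpha}_t
beta_t = 1.0 - alpha_t

# Build x_t from a known x_0 and a known noise value
x_0, eps = -0.3, 1.2
x_t = math.sqrt(alpha_bar_t) * x_0 + math.sqrt(1.0 - alpha_bar_t) * eps

# Mean used in the sampling update, written in terms of the noise
mean_from_eps = (x_t - beta_t / math.sqrt(1.0 - alpha_bar_t) * eps) / math.sqrt(alpha_t)

# Posterior mean, written in terms of x_t and x_0
mean_posterior = (math.sqrt(alpha_t) * (1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t) * x_t
                  + math.sqrt(alpha_bar_prev) * beta_t / (1.0 - alpha_bar_t) * x_0)

print(mean_from_eps, mean_posterior)
```

The two values agree to floating-point precision, so a model that predicts the noise well implicitly supplies the mean of the reverse step.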

The advancement of technology has brought with it the ability to generate ever larger and more complex collections of data. This is especially true in biomedical research, where new technologies can produce thousands, or even millions, of biomolecular measurements at a time. Because we human beings use our vision as our chief sense for understanding the world, when we are confronted with data, we try to understand that data through visualization. Moreover, because we evolved in a three-dimensional world, we can only ever visualize up to three dimensions of an object at a time. This limitation poses a fundamental problem when it comes to high-dimensional data; high-dimensional data cannot, without loss of information, be visualized in their totality at once. But this does not mean we have not tried! The field of dimensionality reduction studies and develops algorithms that map high-dimensional data to two or three dimensions, where we can visualize it with minimal loss of information. For example, classical principal components analysis (PCA) uses a linear mapping to project data down to a space that preserves as much variance as possible. More recently, the t-SNE and UMAP algorithms use nonlinear mappings that attempt to preserve the “topology” of the data, that is, they attempt to preserve neighborhoods of nearby data points while preventing overlapping dense regions of data in the output figure. An example of single-cell RNA-seq data from peripheral blood mononuclear cells (PBMCs) visualized by PCA, t-SNE, and UMAP is shown below. (Data was downloaded via Scanpy’s pbmc3k function. Code to reproduce this figure can be found on Google Colab. Note, t-SNE and UMAP are being used to visualize the top 50 principal components from PCA.)

Unfortunately, because it is mathematically impossible to avoid losing information when mapping data from high to low dimensions, these algorithms inevitably lose some aspect of the data, either by distortion or omission, when plotting it in lower dimensions. This limitation makes the figures generated by these methods easy to misinterpret. Because of this, dimensionality reduction algorithms, especially t-SNE and UMAP, are facing new scrutiny by those who argue that nonlinear dimension reduction algorithms distort the data so heavily that their output is at best useless and at worst harmful. On the other hand, proponents of these methods argue that although distortion is inevitable, these methods can and do reveal aspects of the data’s structure that would be difficult to convey by other means.

In this blog post, I will attempt to provide my views on the matter, which lie somewhere between those held by the critics and proponents. I will start with a review of dimensionality reduction and describe how it inevitably entails a loss of information. I will then argue that dimensionality reduction methods require a different kind of mentality to use them correctly than traditional data visualizations (i.e., those that do not compress high-dimensional data into few dimensions). As a brief preview, I will argue that dimensionality reduction requires a “probabilistic” framework of interpretation rather than a “deterministic” one, wherein conclusions one draws from a dimensionality reduction plot have some probability of not actually being true of the data. I will propose that this does not mean these plots are not useful! To evaluate their utility, I will argue that empirical user studies of these methods are required. That is, we must empirically assess whether or not the conclusions practitioners draw from these figures are more often true than not, and when they are not true, how consequential the errors are.

For much of this blog, I will use data generated by single-cell RNA-sequencing (scRNA-seq) as the primary example of high-dimensional data which I will use in a case study addressing the risks and merits of using dimension reduction for data visualization. As a brief review, scRNA-seq data is structured as a data table/matrix where rows represent individual cells and columns represent genes. Each entry of the matrix stores a measurement of the relative abundance of mRNA molecules transcribed from a given gene in a given cell. scRNA-seq studies routinely generate data for hundreds of thousands of cells and provide gene expression measurements for tens of thousands of genomic features such as genes or isoforms. Thus these data are very high-dimensional. For a comprehensive review on RNA-seq, please see my previous blog post.

In this section we will review the task of dimensionality reduction and describe why it inevitably entails a loss of information. Before moving forward, let’s formalize what we mean by the “dimensionality” of data. For the purposes of our discussion, we will refer to data as being $d$-dimensional if that data can be represented as a set of coordinate vectors in $\mathbb{R}^d$. That is, the dataset can be represented as $\boldsymbol{x}_1, \dots, \boldsymbol{x}_n \in \mathbb{R}^d$. Collectively, we can represent the data as a matrix $\boldsymbol{X} \in \mathbb{R}^{n \times d}$ where each row represents a datapoint. This description thus covers all tabular data. (For a more philosophical treatment of the notion of “dimensionality”, see my previous blog post.)

The task of dimensionality reduction is to find a new set of vectors $\boldsymbol{x}'_1, \dots, \boldsymbol{x}'_n$ in a $d'$-dimensional space where $d' < d$ such that these lower dimensional points preserve some aspect of the original data’s structure. Said more succinctly, the task is to convert the high dimensional data $\boldsymbol{X} \in \mathbb{R}^{n \times d}$ to $\boldsymbol{X}' \in \mathbb{R}^{n \times d'}$ where $d' < d$. This is often cast as an optimization problem of the form:

\[\max_{\boldsymbol{X}' \in \mathbb{R}^{n \times d'}} \text{Similarity}(\boldsymbol{X}, \boldsymbol{X}')\]where the function $\text{Similarity}(\boldsymbol{X}, \boldsymbol{X}')$ outputs a value that tells us “how well” the pairwise relationships between data points in $\boldsymbol{X}'$ reflect those in $\boldsymbol{X}$. The exact form of $\text{Similarity}(\boldsymbol{X}, \boldsymbol{X}')$ depends on the dimensionality reduction method.

Note that if $d > 3$ then we cannot easily visualize our data as a scatterplot to see the global structure between datapoints. Thus, to visualize data it is common to set $d'$ to either 2 or 3, thereby mapping each datapoint $\boldsymbol{x}_i$ to a new, 2- or 3-dimensional data point $\boldsymbol{x}'_i$ that can be visualized in a scatterplot.

However, there is a crucial problem with visualizing data in this manner: it is not possible (in general) to compress data down to a lower dimensional space and preserve all of the relative pairwise distances between data points. As an illustrative example, consider three points in 2-dimensional space that are equidistant from one another. If we were to compress these down into one dimension then *inevitably* at least one pair of data points will have a larger distance from one another than the other two pairs. This is shown below:

Notice how the distance between the blue and green data point is neither equal to the distance between the blue and red data points nor to the distance between the red and green data points as was the case in the original two dimensional space. Thus, the distances between this set of three data points have been distorted!
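This distortion can be verified with a few lines of code. The sketch below places three points at the corners of an equilateral triangle with side length 1, projects them onto a line at many different angles, and records how close the three projected pairwise distances ever come to being equal. It checks linear projections only, and the helper names are introduced here for illustration.

```python
import math

# Corners of an equilateral triangle with side length 1
points = [(0.0, 0.0), (1.0, 0.0), (0.5, math.sqrt(3) / 2)]

def project(point, angle):
    """Coordinate of a 2-D point along a line through the origin at `angle`."""
    return point[0] * math.cos(angle) + point[1] * math.sin(angle)

def pairwise_distances_1d(coords):
    """All three pairwise distances between three 1-D coordinates."""
    return [abs(coords[i] - coords[j]) for i in range(3) for j in range(i + 1, 3)]

# Ratio of smallest to largest distance; it would be 1.0 if all were equal
best_ratio = 0.0
for step in range(1800):
    angle = math.pi * step / 1800
    distances = pairwise_distances_1d([project(p, angle) for p in points])
    best_ratio = max(best_ratio, min(distances) / max(distances))

print(f"Best smallest-to-largest distance ratio: {best_ratio:.3f}")
```

No projection angle brings the ratio above one half. The same holds for any placement of three distinct points on a line, since the distance between the two outer points is always the sum of the other two distances.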

This problem presents itself in the more familiar territory of creating maps. It is mathematically impossible to project the globe onto a 2D surface without distorting distances and/or shapes of continents. Different map projections have been devised to reflect certain aspects of the 3D configuration of the Earth’s surface (e.g., shapes of continents), but that comes at the expense of some other aspect of that 3D configuration (e.g., distances between continents). A few examples of map projections are illustrated below (These images were created by Daniel R. Strebe and were pulled from Wikipedia):

In fact, you can demonstrate this problem for yourself in your kitchen! Just peel an orange and attempt to lay the pieces of the orange peel on the table to reconstruct the original surface of the orange… it’s impossible!

It is almost always the case that some information is lost following dimensionality reduction. To visualize high dimensional data in two or three dimensions, one must either throw away dimensions and plot the remaining two/three or devise a more sophisticated approach that maps high dimensional data points to low dimensional data points to preserve *some aspect* of the high dimensional data’s structure (with respect to the Euclidean distances between data points) at the expense of other aspects. Exactly which aspect of the data’s structure you wish to preserve depends on how you define your $\text{Similarity}(\boldsymbol{X}, \boldsymbol{X}')$ function described above! (Note, there are scenarios where it is possible to preserve pairwise distances between data points following dimensionality reduction, but those scenarios tend to be uninteresting. An uninteresting example would be a case in which all of your data points lie in a 2-dimensional plane embedded in a higher-dimensional space.)

Now, I am going to build a framework for thinking about data visualization that will draw what I view as an important distinction between “traditional” data visualizations, such as heatmaps or barcharts, and data visualizations based on dimensionality reduction such as UMAP scatterplots. As a very brief preview, I will argue that traditional data visualizations enable one to make claims about the underlying data with 100% certainty whereas dimensionality reduction visualizations do not provide the same certainty.

Before we get going, I am going to make a statement that may appear obvious, but is important for laying the foundation for my views on dimensionality reduction: **The primary outputs of a data visualization are a set of statements about the data**. Take the following table and associated bar chart as an example to describe what I mean by this:

One statement about the data being conveyed by this plot is that Label A is associated with the value 9. Another statement about the data is that Label A is associated with a larger value than Label B. Below are a set of example statements about the data being conveyed by this plot:

Note, I am describing statements *about the data*, not statements about the *world*. That is, the barchart enables one to make claims about the literal values stored in the data table from which this figure was generated. (The task of drawing conclusions about the world based on data and/or data visualizations is the task of science and statistics). Data visualizations make facts about the data easier to understand than the raw data by itself (i.e., large tables of numbers) because we human beings are visual animals.

For traditional data visualizations, statements about the data being described by the visualization are 100% certain to be true. For example, when one looks at the barchart above they *know* that Label A is associated with a larger value than Label B’s value (unless of course, there was an error in the generation of the visualization, but we will assume no errors were made here). That is because in traditional data visualizations, there is either a one-to-one or linear mapping between some aspect of the data and some visual or spatial element to the visualization. In the bar chart above, the mappings are as follows:

- Magnitude of Value $\rightarrow$ Height of bar (linear mapping)
- Distinct Label $\rightarrow$ Distinct bar (one-to-one mapping)

Because these mappings are invertible and deterministic, we can draw conclusions about the raw data with 100% certainty based on the visual elements in the figure.

Now, let’s turn our attention to data visualizations produced by dimensionality reduction methods. Let’s use a UMAP plot of the PBMCs shown above, but now let’s not color the cells by their respective cell type and let’s pretend we don’t know much of anything about these cells.

What are some (possibly incorrect) statements we might make about the data from this figure? Below are a few examples:

The problem is that many of the above statements are probably not true! For UMAP in particular, statements regarding distances are especially unlikely to be true. This is because, as we discussed before, dimensionality reduction methods distort the data. Statements we make about the data from this figure do not have the same guarantee of truth as statements we made from the bar chart!

This brings us to a probabilistic framework for thinking about data visualizations. Any statement about a dataset, $S$, drawn from a traditional data visualization, like the bar chart, has the following property:

\[P(S = \text{True}) = 1\]

On the other hand, this is not the case for dimensionality reduction plots. Rather,

\[P(S = \text{True}) < 1\]

Of course, the probability $P(S = \text{True})$ cannot really be defined in a frequentist sense, since $S$ either is or is not true. Rather, this is more of a Bayesian probability (i.e., a degree of certainty) that can be informed by looking across many plots with a similar feature as the feature being described by $S$ and asking: for what fraction of the datasets described by those plots is the statement $S$ true?

I concede that this “probabilistic framework” is a bit hand-wavey and not very rigorous, but to me illustrates an important distinction between plots based on dimensionality reduction and “traditional” data visualizations. Specifically, for dimensionality reduction plots, the statements one might draw from them about the data may be wrong. Said differently, dimensionality reduction plots help users draw *uncertain inferences* about the structure of the data whereas there is no uncertainty in a traditional data visualization.

Expert users of dimensionality reduction plots know this, of course, but use them anyway. Why? Well, because they claim that certain “classes” of statements are more often true than not and because of this fact, these methods are useful. For example, when it comes to UMAP, it is *often* true (but not always) that when you see distinct clusters in the scatterplot, those clusters are real characteristics of the data. Thus, one might say that a statement on clusters, $S_{\text{cluster}}$, is associated with a probability, $P(S_{\text{cluster}} = \text{True})$, that is high enough to be useful. On the other hand, a statement on distances between data points (especially long distances), $S_{\text{distance}}$, is associated with a probability, $P(S_{\text{distance}} = \text{True})$, that is far too low to be useful. The problem then is to determine which statements are more likely to be true than others for a given dimensionality reduction method.

As I mentioned in the introduction, dimensionality reduction methods are receiving new scrutiny. Given my “probabilistic mindset” for interpreting data visualizations, I will attempt to summarize my understanding of certain criticisms of dimensionality reduction methods, especially non-linear methods like t-SNE and UMAP:

**Concern 1: Distortion caused by popular methods like t-SNE and UMAP is too severe to be useful:** That is, distortion of the data is so severe that practically any interesting statement $S$ that one might make from such a visualization is associated with too low of a probability of actually being true about the data to be useful for anything.

To this concern, I am undecided and I will address it more thoroughly in the following section. To preview, I believe that the best way to assess dimensionality reduction methods is to employ user studies. Do these plots lead to new insights *in practice* despite the fact that they distort the data? Do they lead to more good than harm?

**Concern 2: We don’t know what inferences can be drawn reliably from these plots:** That is, we do not have a deep enough understanding into the classes of statements that have high or low probability of being true. Without this understanding, we cannot use these plots effectively.

I mostly agree with this statement. The objective function of t-SNE, for example, is mostly built upon heuristics designed to generate a figure with certain properties. For example, the use of a t-distribution in the underlying model is motivated by the fact that it pushes data points apart in the resultant low-dimensional space and avoids overcrowding of data points. In my opinion, this is not really based on sound statistical theory. Similarly, UMAP assumes certain characteristics of the high-dimensional data that, as far as I know, are difficult or impossible to test. For example, UMAP assumes that “The data is uniformly distributed on Riemannian manifold”. I’m not sure how, in general, one can know this without a very sophisticated understanding of the underlying data-generating process.

All of that said, there is ongoing research to either develop new dimensionality reduction methods that are easier to interpret or to help users more accurately interpret plots generated by existing methods. For example, the recent method, PHATE, claims to better preserve continuums of data points in the high-dimensional space. DensMAP claims to better preserve regions of high or low density. Surprisal Components Analysis (SCA) claims to better preserve small clusters. scDEED identifies features in dimensionality reduction plots that are misleading.

**Concern 3: Any visualization in which there is uncertainty around what it says about the data should be avoided:** The argument here is that it is too easy to misinterpret and misuse *any* data visualization technique in which one can make a reasonable statement about the data, $S$, but that $P(S = \text{True}) < 1$.

The concern here is that if a plot does not provide certainty into the data that it describes, then it is too easy to fall victim to confirmation bias when interpreting that figure. I admit I fell prey to this myself. In a paper I led presenting a cell type classification algorithm, called CellO, I made the following statement based on a UMAP plot (referencing Figure 7): “CellO annotated many of these cells as pancreatic A cells (a.k.a. pancreatic alpha cells), which is plausible owing to both their close position to annotated A cells according to UMAP, which is known to preserve some level of global structure in high dimensional data (Becht et al., 2018)…” Granted, this statement is not very strong, but I nonetheless ask myself whether what I saw in that UMAP plot was simply what I wanted to see. Indeed, because these figures may make us more prone to confirmation bias, I am sympathetic to the argument that we should avoid them altogether. At the very least, one should use extreme caution when using them and make sure to confirm any hypotheses generated from these figures using orthogonal techniques. I know I will pay more attention going forward.

While recent studies, such as those by Chari and Pachter (2023) and Huang *et al.* (2022), evaluate dimensionality reduction algorithms quantitatively (and are very valuable studies), I argue that these studies don’t directly address the fundamental question regarding whether these methods lead to more harm or benefit. Because *statements* about data are the primary output of a data visualization, it is those statements that we should be evaluating. That is, even though it is established that dimensionality reduction methods distort the data, the question remains (in my mind) whether or not the statements that practitioners in the field draw from these plots have a high or low probability of being true. Do these plots lead to new insights *in practice* despite the fact that they distort the data? What alternative visualizations would provide the same insights with more certainty?

I am not an expert in how to conduct these kinds of studies and I am not sure what the best strategy would be, but I envision something like the following: Gather a group of scientist volunteers in some specific field and present them with a dimensionality reduction plot (e.g., a UMAP plot) for a dataset from a domain that they are unfamiliar with. Next, ask each volunteer to list statements/hypotheses they have about the data from that figure. Finally, evaluate how many of those statements were actually true of the data and how many were not. What alternative visualization methods would have led the user to the same correct hypotheses, but avoided the incorrect ones? Were there certain categories of statements (e.g., related to clusters) that tended to be true and others (e.g., related to distances) that tended not to be true? Of course, this would be a fairly qualitative study, but perhaps it would shed light on how these plots are being used in the field.

I propose that if one seeks to visualize their data with dimensionality reduction, they should use multiple methods in parallel. Because any statement, $S$, that one draws from these figures has a probability of not being true, it helps to assess whether other dimensionality reduction methods lead to the same statement. If many different methods all support $S$, then perhaps it is more likely to be true than if only one method supports $S$. That is because, as long as the methods are “orthogonal” to one another (i.e., are grounded in different theory or approach), then it would be quite a coincidence that $S$ is supported by multiple methods, but not actually true. Viewing these plots requires one to have a “probabilistic mindset” that is not needed for traditional data visualizations.

As an example, let’s look at another single-cell dataset from differentiating myeloid cells published by Paul *et al.* (2015). Below, I visualize these cells using six different dimensionality reduction methods: PCA, t-SNE, UMAP, Force-directed layout of the k-nearest neighbors graph, PHATE, and Surprisal Components Analysis (SCA) (Data was downloaded via Scanpy’s paul15 function. Code to reproduce this figure can be found on Google Colab. Note, t-SNE, UMAP, and force-directed layout are being used to visualize the top 50 principal components from PCA.):

Note that all of the figures here present a continuum of cells originating at megakaryocyte/erythrocyte progenitors (MEP) and extend outward along two “branches”. Because this is featured by *all* of the plots, I think it is a reasonable hypothesis that there is indeed a continuum of cells starting from this cell type in the high-dimensional gene expression space. But of course, this may not be true. In my analysis, t-SNE, UMAP, and force-directed layout are all operating on the top 50 principal components from PCA, so they are not perfectly orthogonal. Similarly, UMAP, PHATE, and force-directed layout are all operating on a k-nearest neighbors graph. While t-SNE does not explicitly operate on a k-nearest neighbors graph, its use of centering a unimodal distribution around each point to capture a certain density of neighbors is effectively operating on a k-nearest neighbors graph. Thus, these methods in particular are even more similar to one another.

In conclusion, no statement can be made with absolute certainty from dimensionality reduction plots. We must be diligent in confirming any hypotheses generated by these methods using alternative, statistically grounded approaches. Lastly, when using these methods, we must remain self-aware enough to avoid the confirmation bias that these methods may promote.

**Final note:** Please let me know if I mischaracterized any work cited above.

Throughout many of my prior linear algebra posts, we have seen theorems proving various properties of invertible matrices. In this post, we will bring these theorems into one location and form a set of equivalent statements that all can be used to define an invertible matrix. These statements form what is called the **invertible matrix theorem**.

Importantly, any single one of the statements listed in the invertible matrix theorem implies all of the rest of the statements and really, any single statement can be used as the fundamental definition of an invertible matrix. Thus, this theorem not only provides a multi-angled perspective into the nature of invertible matrices, it is also practically useful because if one has some matrix, $\boldsymbol{A}$, then one needs only to prove a single one of the statements in the invertible matrix theorem in order to learn that *all of the remaining statements* of the invertible matrix theorem are also true of the matrix.

The invertible matrix theorem is stated as follows:

**Theorem 1 (invertible matrix theorem)**: For a given square matrix $\boldsymbol{A} \in \mathbb{R}^{n \times n}$, if any one of the following statements is true of that matrix, then all the remaining statements are also true.

1. There exists a square matrix $\boldsymbol{C} \in \mathbb{R}^{n \times n}$ such that $\boldsymbol{AC} = \boldsymbol{CA} = \boldsymbol{I}$

2. The columns of $\boldsymbol{A}$ are linearly independent

3. The rows of $\boldsymbol{A}$ are linearly independent

4. The columns of $\boldsymbol{A}$ span all of $\mathbb{R}^n$

5. The rows of $\boldsymbol{A}$ span all of $\mathbb{R}^n$

6. The rank of $\boldsymbol{A}$ is $n$

7. The nullity of $\boldsymbol{A}$ is $0$

8. The linear transformation $T(\boldsymbol{x}) := \boldsymbol{Ax}$ is one-to-one and onto.

9. The equation $\boldsymbol{Ax} = \boldsymbol{0}$ has only the trivial solution $\boldsymbol{x} = \boldsymbol{0}$

10. It is possible to use the row reduction algorithm to transform $\boldsymbol{A}$ into $\boldsymbol{I}$

11. There exists a sequence of elementary matrices $\boldsymbol{E}_1, \boldsymbol{E}_2, \dots, \boldsymbol{E}_m$ such that $\boldsymbol{E}_1\boldsymbol{E}_2 \dots \boldsymbol{E}_m\boldsymbol{A} = \boldsymbol{I}$

12. $\vert \text{Det}(\boldsymbol{A}) \vert > 0$

In different texts, the invertible matrix theorem can be written somewhat differently with some texts including some statements that others don’t. The *essence* of the invertible matrix theorem is that there are many seemingly different statements that all define an invertible matrix. Any one of these statements implies all of the rest.
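As a quick numerical illustration (not a proof), here is a minimal sketch that evaluates a few of the statements above with NumPy for one invertible and one singular matrix. The helper name `check_statements`, the tolerance `tol`, and the two example matrices are assumptions made purely for illustration. Floating-point checks like these can only approximate the exact statements of the theorem.

```python
import numpy as np

def check_statements(A, tol=1e-10):
    """Numerically evaluate a few statements of the theorem for a square matrix."""
    n = A.shape[0]
    singular_values = np.linalg.svd(A, compute_uv=False)
    rank = int(np.linalg.matrix_rank(A, tol=tol))
    # Nullity counted as the number of (numerically) zero singular values.
    nullity = int((singular_values <= tol).sum())
    return {
        "6: rank is n": rank == n,
        "7: nullity is 0": nullity == 0,
        "12: |Det(A)| > 0": bool(abs(np.linalg.det(A)) > tol),
    }

invertible = np.array([[2.0, 1.0], [1.0, 3.0]])
singular = np.array([[1.0, 2.0], [2.0, 4.0]])  # second row is twice the first

print(check_statements(invertible))
print(check_statements(singular))

# Statement 1 for the invertible matrix: an explicit C with AC = CA = I.
C = np.linalg.inv(invertible)
print(np.allclose(invertible @ C, np.eye(2)), np.allclose(C @ invertible, np.eye(2)))
```

For the first matrix every check passes together, and for the second every check fails together, which is the behavior the theorem predicts.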

To prove the invertible matrix theorem, we will prove the following implications between these statements:

Notice that there is a path from every statement to every other statement through these implications. Given two statements $X$ and $Y$ from the invertible matrix theorem, it holds that “$X$ if and only if $Y$”. Note that the specific implications proven here are somewhat arbitrary; other texts might prove a different set of direct implications. The important point is that there exists an “implication path” between every statement and every other statement.

The proofs of each of these implications are described below:

**1 $\implies$ 2**: This was proven by Theorem 4 from my post on invertible matrices.

**1 $\implies$ 8**: This was proven by Theorems 2 and 3 from my post on invertible matrices

**1 $\implies$ 12**: By Theorem 2 in the Appendix to this post.

**2 $\iff$ 4**: By Theorem 3 in the Appendix to this post.

**2 $\implies$ 8**: By Theorem 4 in the Appendix to this post.

**2 $\iff$ 6**: By Definition 3 from my post on spaces induced by matrices, the column rank of a matrix is defined to be the maximum number of linearly independent vectors that span the column space of the matrix. By Theorem 2 (row rank equals column rank) from this same post the column rank of a matrix equals the row rank and we refer to either as simply the “rank”. Thus, the $n$ columns of $\boldsymbol{A}$ are linearly independent if and only if the rank of $\boldsymbol{A}$ is $n$.

**3 $\iff$ 6**: By Definition 3 from my post on spaces induced by matrices, the row rank of a matrix is defined to be the maximum number of linearly independent vectors that span the row space of the matrix. By Theorem 2 (row rank equals column rank) from this same post the row rank of a matrix equals the column rank and we refer to either as simply the “rank”. Thus, the $n$ rows of $\boldsymbol{A}$ are linearly independent if and only if the rank of $\boldsymbol{A}$ is $n$.

**3 $\iff$ 5**: This follows by the same logic described in Theorem 3 in the Appendix to this post.

**6 $\iff$ 7**: By Theorem 3 (Rank-Nullity Theorem) from my post on spaces induced by matrices.

**7 $\iff$ 9**: By Definition 5 (nullity) from my post on spaces induced by matrices.

**8 $\iff$ 10**: By the discussion presented in my post on row reduction.

**10 $\iff$ 11**: By the discussion presented in my post on row reduction.

**11 $\implies$ 1**: By Theorem 5 in the Appendix to this post.

**12 $\implies$ 2**: By Theorem 6 in the Appendix to this post.

Recall in our first post on invertible matrices, we defined an invertible matrix as follows:

**Definition 1 (Inverse matrix):** Given a square matrix $\boldsymbol{A} \in \mathbb{R}^{n \times n}$, its **inverse matrix** is the matrix $\boldsymbol{C}$ that when either left or right multiplied by $\boldsymbol{A}$, yields the identity matrix. That is, if for a matrix $\boldsymbol{C}$ it holds that \(\boldsymbol{AC} = \boldsymbol{CA} = \boldsymbol{I}\), then $\boldsymbol{C}$ is the inverse of $\boldsymbol{A}$. This inverse matrix, $\boldsymbol{C}$, is commonly denoted as $\boldsymbol{A}^{-1}$.

This definition follows Statement 1 of the invertible matrix theorem. However, in light of the invertible matrix theorem, *any of the statements* about invertible matrices could have been chosen as the definition of an invertible matrix. While we chose Statement 1, we could have chosen another and the rest of the statements would then follow from that definition.

**Theorem 2**: Given a square matrix $\boldsymbol{A} \in \mathbb{R}^{n \times n}$, if there exists an inverse matrix $\boldsymbol{A}^{-1} \in \mathbb{R}^{n \times n}$ such that $\boldsymbol{A}\boldsymbol{A}^{-1} = \boldsymbol{A}^{-1}\boldsymbol{A} = \boldsymbol{I}$, then this implies that $\vert \text{Det}(\boldsymbol{A}) \vert > 0$.

**Proof:**

We will use a proof by contradiction. Assume for the sake of contradiction that $\vert \text{Det}(\boldsymbol{A})\vert = 0$. Then we see that

\[\begin{align*} &\boldsymbol{A}\boldsymbol{A}^{-1} = \boldsymbol{I} \\ \implies & \text{Det}(\boldsymbol{A}\boldsymbol{A}^{-1}) = \text{Det}(\boldsymbol{I}) \\ \implies & \text{Det}(\boldsymbol{A}) \text{Det}(\boldsymbol{A}^{-1}) = \text{Det}(\boldsymbol{I}) \\ \implies & 0 \cdot \text{Det}(\boldsymbol{A}^{-1}) = 1 \\ \implies & 0 = 1 \end{align*}\]

Clearly zero does not equal one. Thus, our assumption is wrong. It must be the case that if $\boldsymbol{A}\boldsymbol{A}^{-1} = \boldsymbol{I}$, then this implies that $\vert \text{Det}(\boldsymbol{A}) \vert > 0$. This proof can be repeated trivially by flipping the order of $\boldsymbol{A}$ and $\boldsymbol{A}^{-1}$ in the matrix product. Note, lines 3 and 4 above follow from Theorem 8 and Axiom 1 from my post on determinants, respectively.

$\square$

**Theorem 3**: Given a square matrix $\boldsymbol{A} \in \mathbb{R}^{n \times n}$, the columns of $\boldsymbol{A}$ are linearly independent if and only if they span all of $\mathbb{R}^n$.

Let us prove the $\implies$ direction: If $\boldsymbol{A}$’s columns are linearly independent, then they span all of $\mathbb{R}^n$.

We will apply a proof by contradiction. Let us assume that there exists a vector $\boldsymbol{b} \in \mathbb{R}^n$ that does not lie in the column space of $\boldsymbol{A}$. This would imply that we could form a matrix by “appending” $\boldsymbol{b}$ to $\boldsymbol{A}$ by making $\boldsymbol{b}$ the last column of $\boldsymbol{A}$:

\[\boldsymbol{A}' := \begin{bmatrix} \boldsymbol{a}_{*, 1} & \dots & \boldsymbol{a}_{*, n} & \boldsymbol{b} \end{bmatrix}\]

Because the columns of $\boldsymbol{A}$ are linearly independent and $\boldsymbol{b}$ is not a linear combination of them, all of the columns of this new matrix are linearly independent, and so its column rank is $n+1$. However, the matrix still only has $n$ rows and thus, the maximum possible row rank of this matrix is $n$. This is in contradiction to Theorem 2 (row rank equals column rank) from my post on spaces induced by matrices, which states that the row rank is equal to the column rank. Thus, it must be the case that our assumption is wrong. There does not exist a vector $\boldsymbol{b} \in \mathbb{R}^n$ that lies outside $\boldsymbol{A}$’s column space. Thus, $\boldsymbol{A}$’s column space is all of $\mathbb{R}^n$.

Let us prove the $\impliedby$ direction: If $\boldsymbol{A}$’s columns span all of $\mathbb{R}^n$, then they are linearly independent.

We will use the Steinitz Exchange Lemma. The Steinitz Exchange Lemma states the following: Given a vector space $\mathcal{V}$ and two finite sets of vectors $U$ and $W$ such that $U$ is linearly independent and $W$ spans $\mathcal{V}$, it must be the case that $\vert U \vert \leq \vert W \vert$.

Now, for the sake of contradiction, let us assume that $\boldsymbol{A}$’s columns are not linearly independent. This implies that there exists at least one column in $\boldsymbol{A}$ that can be written as a linear combination of the remaining columns, such that if we removed this column, the remaining columns of $\boldsymbol{A}$ would still span $\mathbb{R}^n$. Let $S$ be the set of columns of $\boldsymbol{A}$ after removing this vector. Note that $\vert S \vert = n-1$. Now, let $I := \{ \boldsymbol{e}_1, \dots, \boldsymbol{e}_n \}$ be the set of standard basis vectors in $\mathbb{R}^n$. Note that $\vert I \vert = n$. Now $I$ is a set of linearly independent vectors and $S$ spans $\mathbb{R}^n$; however, $\vert I \vert > \vert S \vert$. This contradicts the Steinitz Exchange Lemma. Thus, it must be the case that the columns of $\boldsymbol{A}$ are linearly independent.

**Theorem 4**: Given a matrix $\boldsymbol{A} \in \mathbb{R}^{n \times n}$ whose columns are linearly independent, the linear transformation defined as $T(\boldsymbol{x}) := \boldsymbol{Ax}$ is onto and one-to-one.

**Proof:**

We first prove that $T(\boldsymbol{x})$ is onto. To do so, we must prove that every vector in $\mathbb{R}^n$ is in the range of $T(\boldsymbol{x})$. Recall from our previous discussion on column spaces that the range of $T(\boldsymbol{x})$ is the column space of $\boldsymbol{A}$. By Theorem 3 above, since the columns of $\boldsymbol{A}$ are linearly independent, the column space spans all of $\mathbb{R}^n$. Thus, $T(\boldsymbol{x})$ is capable of mapping vectors to *every* vector in $\mathbb{R}^n$ and is thus onto.

Now we prove that $T(\boldsymbol{x})$ is one-to-one. First, because the columns of $\boldsymbol{A}$ are linearly independent, then by Theorem 2 (row rank equals column rank) from my post on spaces induced by matrices, the rank of $\boldsymbol{A}$ is $n$. By Theorem 3 (Rank-Nullity Theorem) from my post on spaces induced by matrices, the nullity of $\boldsymbol{A}$ is 0. This means that the only vector in $\boldsymbol{A}$’s null space is the zero vector $\boldsymbol{0}$. This means that the only solution to $T(\boldsymbol{x}) = \boldsymbol{0}$ is $\boldsymbol{x} = \boldsymbol{0}$.

With this in mind, we employ a nearly identical proof to that used in Theorem 3 (Invertible matrices characterize one-to-one functions) from my post on invertible matrices. For the sake of contradiction, assume that there exist two vectors $\boldsymbol{x}$ and $\boldsymbol{x}'$ such that $\boldsymbol{x} \neq \boldsymbol{x}'$ and that

\[T(\boldsymbol{x}) = \boldsymbol{Ax} = \boldsymbol{b}\]

and

\[T(\boldsymbol{x}') = \boldsymbol{Ax}' = \boldsymbol{b}\]

where $\boldsymbol{b} \neq \boldsymbol{0}$. Then,

\[\begin{align*} & \boldsymbol{Ax} - \boldsymbol{Ax}' = \boldsymbol{0} \\ \implies & \boldsymbol{A}(\boldsymbol{x} - \boldsymbol{x}') = \boldsymbol{0}\end{align*}\]

Because the null space of $\boldsymbol{A}$ contains only the zero vector, as shown above, it must hold that

\[\boldsymbol{x} - \boldsymbol{x}' = \boldsymbol{0}\]

which implies that $\boldsymbol{x} = \boldsymbol{x}'$. This contradicts our original assumption. Therefore, it must hold that there do not exist two distinct vectors $\boldsymbol{x}$ and $\boldsymbol{x}'$ that map to the same vector via the matrix $\boldsymbol{A}$. Therefore, $T(\boldsymbol{x})$ is a one-to-one function.

$\square$

**Theorem 5**: If there exists a sequence of elementary matrices $\boldsymbol{E}_1, \boldsymbol{E}_2, \dots, \boldsymbol{E}_m$ such that $\boldsymbol{E}_1\boldsymbol{E}_2 \dots \boldsymbol{E}_m\boldsymbol{A} = \boldsymbol{I}$, then there exists a square matrix $\boldsymbol{C} \in \mathbb{R}^{n \times n}$ such that $\boldsymbol{AC} = \boldsymbol{CA} = \boldsymbol{I}$

**Proof:**

As evident by the premise of the theorem, if $\boldsymbol{E}_1\boldsymbol{E}_2 \dots \boldsymbol{E}_m\boldsymbol{A} = \boldsymbol{I}$, then clearly $\boldsymbol{E}_1\boldsymbol{E}_2 \dots \boldsymbol{E}_m$ is the matrix $\boldsymbol{C}$ for which $\boldsymbol{CA} = \boldsymbol{I}$. So all there is left to prove is that it is also the case that $\boldsymbol{E}_1\boldsymbol{E}_2 \dots \boldsymbol{E}_m$ is the matrix $\boldsymbol{C}$ for which $\boldsymbol{AC} = \boldsymbol{I}$.

First, though not proven formally, it is evident that elementary row matrices are invertible. That is, you can always “undo” the transformation imposed by an elementary row matrix (e.g. for an elementary row matrix that swaps rows, you can always swap them back). Furthermore, since the product of invertible matrices is also invertible, we know that $(\boldsymbol{E}_1\dots\boldsymbol{E}_m)$ is invertible. Thus,

\[\begin{align*} & (\boldsymbol{E}_1\dots\boldsymbol{E}_m)\boldsymbol{A} = \boldsymbol{I} \\ \implies & (\boldsymbol{E}_1\dots\boldsymbol{E}_m)^{-1} (\boldsymbol{E}_1 \dots \boldsymbol{E}_m)\boldsymbol{A} = (\boldsymbol{E}_1 \dots \boldsymbol{E}_m)^{-1}\boldsymbol{I} \\ \implies & \boldsymbol{A} = (\boldsymbol{E}_1 \dots \boldsymbol{E}_m)^{-1} \boldsymbol{I} \\ \implies & \boldsymbol{A} = \boldsymbol{I}(\boldsymbol{E}_1 \dots \boldsymbol{E}_m)^{-1} \\ \implies & \boldsymbol{A}(\boldsymbol{E}_1 \dots \boldsymbol{E}_m) = \boldsymbol{I}(\boldsymbol{E}_1 \dots \boldsymbol{E}_m)^{-1}(\boldsymbol{E}_1 \dots \boldsymbol{E}_m) \\ \implies & \boldsymbol{A}(\boldsymbol{E}_1 \dots \boldsymbol{E}_m) = \boldsymbol{I}\end{align*}\]

Thus, $\boldsymbol{C} := (\boldsymbol{E}_1 \dots \boldsymbol{E}_m)$ is the matrix for which $\boldsymbol{AC} = \boldsymbol{CA} = \boldsymbol{I}$ and is thus $\boldsymbol{A}$’s inverse.

$\square$
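To see Theorem 5 in action, here is a small numerical sketch with NumPy. The $2 \times 2$ matrix and the four elementary matrices below are chosen by hand for illustration, and the variable names are hypothetical. The product of the elementary matrices that row-reduce $\boldsymbol{A}$ to $\boldsymbol{I}$ acts as both a left and a right inverse of $\boldsymbol{A}$.

```python
import numpy as np

A = np.array([[2.0, 1.0],
              [1.0, 3.0]])

# Elementary matrices, named in the order they are applied to A.
scale_row0 = np.array([[0.5, 0.0], [0.0, 1.0]])    # row0 <- row0 / 2
eliminate_below = np.array([[1.0, 0.0], [-1.0, 1.0]])  # row1 <- row1 - row0
scale_row1 = np.array([[1.0, 0.0], [0.0, 0.4]])    # row1 <- row1 / 2.5
eliminate_above = np.array([[1.0, -0.5], [0.0, 1.0]])  # row0 <- row0 - 0.5 * row1

# The first operation applied sits right-most in the product,
# matching the form E_1 E_2 ... E_m A = I.
C = eliminate_above @ scale_row1 @ eliminate_below @ scale_row0

print(np.allclose(C @ A, np.eye(2)))  # left inverse
print(np.allclose(A @ C, np.eye(2)))  # right inverse
print(C)
```

Here `C` agrees with the inverse one gets from the usual $2 \times 2$ formula, $\frac{1}{5}\begin{bmatrix} 3 & -1 \\ -1 & 2 \end{bmatrix}$.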

**Theorem 6**: Given a square matrix $\boldsymbol{A} \in \mathbb{R}^{n \times n}$, if $\vert \text{Det}(\boldsymbol{A}) \vert > 0$, this implies that the columns are linearly independent.

**Proof:**

We will use a proof by contrapositive. We know from Theorem 2 in my post on determinants that if $\boldsymbol{A}$’s columns are linearly dependent, then its determinant is zero. Using the contrapositive, it holds that if the determinant is *not* zero, then the columns are *not* linearly dependent. The statement “the determinant is *not* zero” implies that $\vert \text{Det}(\boldsymbol{A}) \vert > 0$. Moreover, if the columns are not linearly dependent, then they can only be independent. Thus, it follows that if $\vert \text{Det}(\boldsymbol{A}) \vert > 0$, the columns of $\boldsymbol{A}$ are linearly independent.

$\square$

In my previous post, we discussed the basics of RNA-sequencing and how to intuit the units of gene expression that this technology generates. To do so, we described a mathematical/statistical abstraction of the protocol that involves sampling mRNA transcripts from the pool of transcripts in the sample and then sampling locations along those selected transcripts. Through this process, we obtain a set of *reads*, which are short sequences from these sampled locations. To obtain a measure of gene expression, we count the number of reads that align to each gene/isoform in the genome.

As we discussed, RNA-seq provides *relative* expression values between genes. This is because we lose the information on how much total RNA was in each sample. The size of the pool of reads that we sequence is a technical parameter that we can adjust in our protocol, and thus, the more reads we sequence, the more counts we will obtain, on average, for each gene. Therefore, after accounting for the length of each gene/isoform, one can obtain an estimate for the relative amounts of each gene within a sample. This is what the unit *transcripts per million* (TPM) describes. If you sample a million transcripts from the sample, the TPM value for a given gene will tell you how many transcripts on average will have originated from the given gene of interest.
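As a reminder of how this works in practice, here is a minimal sketch of the TPM calculation in plain Python. The function name `tpm` and the toy counts and gene lengths are made up for illustration. Each count is divided by its gene's length, and the resulting rates are rescaled to sum to one million.

```python
def tpm(counts, lengths):
    """Convert read counts to TPM, given each gene's length."""
    # Length-normalized rates: reads per unit of gene length.
    rates = [count / length for count, length in zip(counts, lengths)]
    total = sum(rates)
    # Rescale so that the values sum to one million.
    return [1e6 * rate / total for rate in rates]

counts = [100, 200, 300]      # toy read counts for three genes
lengths = [1000, 2000, 1000]  # toy gene lengths in bases

shallow = tpm(counts, lengths)
# Sequencing twice as deeply doubles every count on average...
deep = tpm([2 * c for c in counts], lengths)

print(shallow)
print(deep)  # ...but leaves the TPM values unchanged.
```

Because the library size cancels in the rescaling step, TPM values only carry information about the relative amounts of each gene within a sample.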

Of course, this raises the question: how can we compare the expression of a given gene *between* samples? One approach is to simply compare the TPMs of that gene with the understanding that one is not really comparing the absolute amount of expression of that gene, but rather is comparing two relative amounts of expression. That is, if you see that your gene’s TPM is higher in one condition compared to another, it might not be the case that there is actually more mRNA from that gene in the first condition; it might just be that the amount of mRNA from the *other* genes in the sample is lower.

Is it possible to compare the absolute expression of a given gene between two samples? In this post, we will describe one normalization strategy that seeks to enable such analyses: median-ratio normalization. This approach was first introduced by Anders and Huber (2010) as a preprocessing step by the DESeq method for estimating differential expression. Median-ratio normalization is also the normalization approach used by the popular DESeq2 method. If you have ever used DESeq2, median-ratio normalization is the approach that the tool uses by default to calculate the “size factors” corresponding to each sample.

In this post, we will discuss the intuition behind median-ratio normalization and the key assumptions that this method makes about the data. We will also discuss why this method only applies to bulk RNA-seq data, but is not appropriate for most single-cell RNA-seq datasets.

RNA-seq does not provide measurements of absolute expression of each gene because in the RNA-seq protocol, we lose the key information telling us how much total RNA was in each sample to begin with. Thus, it should come as no surprise that in order to compare absolute expression values between samples, we need to make some strong assumptions about our samples. For median-ratio normalization this key assumption is as follows: *most genes are expressed at equal levels across all the samples in our dataset*. That is, each sample has a small number of genes that are expressed differentially from other samples (these genes may be the genes that are biologically interesting), but most genes are expressed at the same absolute level.

With this key assumption, median-ratio normalization uses all the samples in the dataset to compute a “reference sample.” This reference sample represents a baseline level of expression for each gene. We depict this schematically using a toy scenario where we have just two samples and three genes. The reference sample is depicted on the right:

Given our previously stated assumption, for each sample, *most* of its genes are expressed at the baseline level described by the reference sample. To make the samples comparable, we must identify one gene that is expressed at the baseline level. If that gene deviates from this baseline level, then this means that there is a library-size effect that must be accounted for by scaling *all* of the read counts in that sample to ensure this identified gene matches the reference sample. In the toy example below, for Sample 1, we identify Gene C as being a gene that should match the reference sample’s expression. For Sample 2, we identify Gene B as being a gene that should match the reference sample’s expression:

To normalize the samples, we scale the read counts in each sample so that the identified genes that should match the reference sample do match the reference sample. This process is depicted below:

In the end, the differences in read counts between the samples better reflect differences in absolute expression between them.

Note, this procedure will only work if we know that most genes between the different samples should show similar levels of absolute expression. This assumption *may* be met in scenarios where similar biological specimens are being compared. For example, if we are comparing blood samples between patients, then one may assume that most genes are not differentially expressed if we assume a similar composition of cell types between the samples and similar conditions between the patients. In contrast, if one is comparing drastically different biological samples (say different tissue types), then this may not be a safe assumption.

Let us start by defining some notation. Let,

\[\begin{align*}n &:= \text{Total number of samples} \\ g &:= \text{Total number of genes} \\ c_{i,j} &:= \text{Count of reads from gene $j$ in sample $i$}\end{align*}\]We start by calculating the “reference sample” expression values which represent baseline expression for each gene. We do so by computing the geometric mean of each gene’s counts across all samples. The geometric mean is used instead of the arithmetic mean because it is more robust to outlier values. For gene $j$, this is computed as:

\[m_j := \left(\prod_{i=1}^n c_{i,j} \right)^{\frac{1}{n}}\]Now, we must identify which gene in each sample should match the reference sample’s expression. For each sample $i$, for each gene $j$, we compute the ratio of the counts of gene $j$ in sample $i$ (i.e., $c_{i,j}$), to the baseline expression value for gene $j$:

\[r_{i,j} := \frac{c_{i,j}}{m_j}\]Intuitively, $r_{i,j}$ describes the deviation (more specifically the fold-change) between $c_{i,j}$ and the reference sample’s expression for this gene.

As we stated previously, most of any given sample’s genes should not be over or under expressed relative to the other samples in the dataset. Thus, *most* genes’ expression values in each sample should match the reference sample’s expression. With this assumption in mind, we *rank* all of the ratios for all the genes in a given sample, $r_{i,1}, r_{i,2}, \dots, r_{i,g}$. Intuitively, if most genes are not changing significantly from baseline, then the genes that fall in the middle of this ranking represent those genes that are unchanging. An idealized scenario is illustrated in the schematic below where only a few genes are higher than baseline (red), a few genes are lower than baseline (blue), but most are unchanged (grey):

Any deviation we see from baseline within the middle of this ranking is assumed to be driven by the library size. Thus, we can treat the median ratio as the “size factor” that we can use to re-scale the counts in the sample so that the ratios in the middle of the list are closer to the baseline. That is, we define the size factor for sample $i$ as,

\[s_i := \text{median}\left(r_{i,1}, r_{i,2}, \dots, r_{i,g}\right)\]and then we re-scale all of the counts in this sample by dividing by $s_i$. That is, the normalized count for gene $j$ in sample $i$ would be computed as,

\[\tilde{c}_{i,j} := \frac{c_{i,j}}{s_i}\]This step is illustrated in the schematic below:

Note, in practice, when computing the median ratio, we use only those genes whose expression is non-zero across all samples. This is because we want to only use genes whose expression is high enough to be reliably detected. Intuitively, if a gene failed to be detected (had zero counts) in some sample, then it cannot tell us about the effect of library size, so it is excluded. A single zero count would also make the geometric mean $m_j$ equal to zero, leaving the ratios for that gene undefined. Stated mathematically, $s_i$ is computed as

\[s_i := \text{median}\left(\left\{ r_{i,j} \mid \forall k \in [n], \ \ c_{k,j} > 0 \right\}\right)\]where $[n]$ is the set of integers from 1 to $n$.
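The steps above can be sketched in a few lines of NumPy. This is a minimal illustration of the formulas in this post, not DESeq2's actual implementation; the function name `median_ratio_size_factors` and the toy count matrix are made up for the example:

```python
import numpy as np

def median_ratio_size_factors(counts):
    """Return one size factor per sample for a (samples x genes) count matrix."""
    counts = np.asarray(counts, dtype=float)
    # Keep only genes with non-zero counts in every sample.
    detected = (counts > 0).all(axis=0)
    log_kept = np.log(counts[:, detected])
    # Reference sample: per-gene geometric mean m_j, computed in log space.
    log_reference = log_kept.mean(axis=0)
    # Ratios r_{i,j} = c_{i,j} / m_j, followed by the per-sample median.
    ratios = np.exp(log_kept - log_reference)
    return np.median(ratios, axis=1)

# Toy data (rows are samples, columns are genes). Sample 2 is sample 1
# sequenced twice as deeply, except that the third gene is up-regulated.
# The fifth gene is undetected in sample 1 and is therefore ignored.
counts = np.array([
    [10, 20, 30, 40, 0],
    [20, 40, 300, 80, 7],
])
size_factors = median_ratio_size_factors(counts)
normalized = counts / size_factors[:, None]
print(size_factors)
print(normalized)
```

After normalization, the three unchanged genes have matching values in both samples, while the up-regulated gene still stands out.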

Let’s examine the effect of median-ratio normalization on a publicly available RNA-seq dataset. We will look at a dataset from PBMC samples taken from patients hospitalized with COVID-19 published by Overmyer et al. (2021). Below, we look at the ranked ratios in two patients and see that for Patient C1 (left), the median ratio of the raw counts (blue) is above the baseline (grey dotted line) indicating a larger library size. In contrast, for Patient C23, we observe the median ratio below the baseline indicating a lower library size. After dividing each sample’s gene counts by the median ratio, the two plots (orange) become more centered about the baseline.

We can observe that running median-ratio normalization helps to normalize this data by examining the distribution of the $\log_2$ read counts in each sample. We observe that running median-ratio normalization effectively centers the data so that the medians are more closely matched between samples:

Median-ratio normalization is only appropriate for bulk RNA-seq data and is almost never appropriate for single-cell RNA-seq data. There are two reasons why median-ratio normalization is inappropriate for single-cell RNA-seq data. The first reason is theoretical and the second is practical:

- There are usually diverse cell types in single-cell RNA-seq datasets and the key assumption that most genes are expressed at equal levels across all cells likely does not hold because the cellular states between cells are often quite different. Differences between individual cells get averaged out in bulk RNA-seq when many cells are aggregated together.
- Current single-cell RNA-seq technologies do not sequence each cell’s transcriptome as deeply as can be sequenced in bulk RNA-seq. Thus, for the vast majority of genes, there is at least one cell with zero read counts originating from that gene. In median-ratio normalization, we only use genes that have non-zero expression across all samples. For single-cell RNA-seq this would end up throwing away most of the genes!

When aggregating single-cells together to form “pseudo-bulk” samples, median-ratio normalization *may* be appropriate; however, I am currently unaware of guidelines around how many cells should be aggregated for median-ratio normalization to be an appropriate procedure.

The **Binomial Theorem** is a basic theorem in mathematics that states that a binomial expression of the form $(x+y)^n$ can be represented as a summation involving the binomial coefficient:

\[(x + y)^n = \sum_{k=0}^n {n \choose k} x^k y^{n-k}\]

This statement always appeared mysterious to me. What is the intuition for the summation? Moreover, why does the binomial coefficient, “$n$ choose $k$”, appear? In this post I will provide a proof that helped me better intuit this theorem.

Let’s start with a polynomial of the form

\[(a + b)(c + d)(e + f)\]We can apply the distributive property as follows:

\[(a + b)(c + d)(e + f) = a(c+d)(e+f) + b(c+d)(e+f)\]Notice that the first and second terms are two separate polynomials involving $a$ and $b$ where the first *only* involves $a$, but not $b$, whereas the second polynomial *only* involves $b$, but not $a$. In fact, for each binomial factor (i.e. for each of $(a + b)$, $(c + d)$, and $(e + f)$), only one of the terms of the binomial will appear in a given term of the fully expanded polynomial.

To form a term in the fully expanded polynomial, we imagine the process of iterating over each binomial factor and choosing *one* of the two terms to include. For example, from $(a+b)$, we choose either $a$ or $b$ to include in the term, but never both. This is because of how the distributive property works: as we expanded the expression, we separated the two terms in each binomial factor so that they never could appear in the same term of the expansion. This process is depicted in the following diagram:

We also see that *every* combination of terms from each binomial factor will be used to form a term in the expanded polynomial. Again, this occurs from the process of carrying out the distributive property repeatedly when carrying out the expansion: every time we perform the distributive property, we create two batches of terms in the expansion that will include either the first or second term from that binomial factor. Each batch is created in a split in the tree diagram above.

With this insight, let’s look at the following polynomial:

\[\begin{align*}(x + y)^3 &= (x + y)(x + y)(x + y) \\ &= x(x+y)(x+y) + y(x+y)(x+y) \\ &= x(x(x+y) + y(x+y)) + y(x(x+y) + y(x+y)) \\ &= xxx + xyx + yxx + yyx + xxy + xyy + yxy + yyy \\ &= x^3 + 3x^2y + 3y^2x + y^3\end{align*}\]More generally, consider $(x+y)^n$, which has $n$ binomial factors (above, $n = 3$). Let’s say we’re interested in all terms in the expanded polynomial that have $k$ of the $x$ terms. By the previous observation, any term that has $k$ of the $x$ terms must have $n - k$ of the $y$ terms because we only pick a single term from each of the original binomial factors. Now, how many of the terms in the expanded polynomial will have $k$ of the $x$ terms and $n-k$ of the $y$ terms? Recall that every way of picking $k$ of the $x$ terms from the binomial factors will result in a term in the expansion of the form $x^ky^{n-k}$. We can think of this as computing all possible ways of choosing $k$ $x$ terms from the $n$ binomial factors. Thus, there will be ${n \choose k}$ such terms.

Finally, there are terms in the polynomial with $k$ $x$ terms for every value of $k$ between $0$ and $n$. Again, this follows from the fact that every combination of terms from each of the binomial factors will be used to form a term in the expanded polynomial, and thus, there *must* be at least one term with $k$ of the $x$ terms and $n-k$ of the $y$ terms. This final observation leads to the Binomial Theorem:

\[(x + y)^n = \sum_{k=0}^n {n \choose k} x^k y^{n-k}\]

$\square$
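As a sanity check on the counting argument, the following sketch expands $(x+y)^n$ by brute force: it enumerates every way of choosing one term from each of the $n$ factors and tallies how many choices contain $k$ copies of $x$. The function name `count_terms_by_x_power` is made up for this illustration:

```python
from itertools import product
from math import comb

def count_terms_by_x_power(n):
    """Tally the expanded terms of (x + y)^n by how many x's they contain."""
    tally = [0] * (n + 1)
    # Each expanded term picks either 'x' or 'y' from each of the n factors.
    for choice in product("xy", repeat=n):
        tally[choice.count("x")] += 1
    return tally

n = 5
tally = count_terms_by_x_power(n)
print(tally)  # [1, 5, 10, 10, 5, 1]

# The tallies match the binomial coefficients "n choose k" for k = 0..n.
assert tally == [comb(n, k) for k in range(n + 1)]
```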

Graphs are ubiquitous mathematical objects that describe a set of relationships between entities; however, they are challenging to model with traditional machine learning methods, which require that the input be represented as a tensor. Graphs break this paradigm because the order of edges and nodes is arbitrary and the model must be capable of accommodating this feature. In this post, we will discuss graph convolutional networks (GCNs) as presented by Kipf and Welling (2017): a class of neural network designed to operate on graphs. As their name suggests, graph convolutional neural networks can be understood as performing a convolution in the same way that traditional convolutional neural networks (CNNs) perform a convolution-like operation (i.e., cross correlation) when operating on images. This analogy is depicted below:

In this post, we will discuss the intuition behind the GCN and how it is similar to and different from the CNN. We will conclude by presenting a case-study training a GCN to classify molecule toxicity.

Fundamentally, a GCN takes as input a graph together with a set of feature vectors where each node is associated with its own feature vector. The GCN is then composed of a series of graph convolutional layers (to be discussed in the next section) that iteratively transform the feature vectors at each node. The output is then the graph associated with output vectors associated with each node. These output vectors can be (and often are) of different dimension than the input vectors. This is depicted below:

If the task at hand is a “node-level” task, such as performing classification on the nodes, then these per-node vectors can be treated as the model’s final outputs. For node-level classification, these output vectors could, for example, encode the probabilities that each node is associated with each class.

Alternatively, we may be interested in performing a “graph-level” task, where instead of building a model that produces an output per node, we are interested in a task that requires an output over the graph as a whole. For example, we may be interested in classifying whole graphs rather than individual nodes. In this scenario, the per-node vectors could be fed, collectively, into another neural network (such as a simple multilayer perceptron), that operates on all of them to produce a single output vector. This scenario is depicted below:

Note, GCNs can also perform “edge-level” tasks, but we will not discuss this here. See this article by Sanchez-Lengeling *et al*. (2021) for a discussion on how GCNs can perform various types of tasks with graphs.

In the next sections we will dig deeper into the graph convolutional layer.

GCNs are composed of stacked **graph convolutional layers** in a similar way that traditional CNNs are composed of convolutional layers. Each convolutional layer takes as input the nodes’ vectors from the previous layer (for the first layer this would be the input feature vectors) and produces corresponding output vectors for each node. To do so, the graph convolutional layer pools the vectors from each node’s neighbors as depicted below:

In the schematic above, node A’s vector, denoted $\boldsymbol{x}_A$, is pooled/aggregated with the vectors of its neighbors, $\boldsymbol{x}_B$ and $\boldsymbol{x}_C$. This pooled vector is then transformed/updated to form node A’s vector in the next layer, denoted $\boldsymbol{h}_A$. This same procedure is carried out over every node. Below we show this same procedure, but performed on node D, which entails aggregating $\boldsymbol{x}_D$ with its neighbor’s vector $\boldsymbol{x}_B$:

This procedure is often called **message passing** since each node is “passing” its vector to its neighbors in order to update their vectors. Each node’s “message” is the vector associated with it.

Now, how exactly does a GCN perform the aggregation and update? To answer this, we will dig into the mathematics of the graph convolutional layer. Let $\boldsymbol{X} \in \mathbb{R}^{n \times d}$ be the features corresponding to the nodes where $n$ is the number of nodes and $d$ is the number of features. That is, row $i$ of $\boldsymbol{X}$ stores the features of node $i$. Let $\boldsymbol{A}$ be the adjacency matrix of this graph where

\[\boldsymbol{A}_{i,j} := \begin{cases} 1,& \text{if there is an edge between node} \ i \ \text{and} \ j \\ 0, & \text{otherwise}\end{cases}\]Note the matrices $\boldsymbol{X}$ and $\boldsymbol{A}$ are the two pieces of data required as input to a GCN for a given graph. The graph convolutional layer can thus be expressed as a function that accepts these two inputs and outputs a matrix representing the updated vectors associated with each node. This function is given by:

\[f(\boldsymbol{X}, \boldsymbol{A}) := \sigma\left(\boldsymbol{D}^{-1/2}(\boldsymbol{A}+\boldsymbol{I})\boldsymbol{D}^{-1/2} \boldsymbol{X}\boldsymbol{W}\right)\]where,

\[\begin{align*}\boldsymbol{A} \in \mathbb{R}^{n \times n} &:= \text{The adjacency matrix} \\ \boldsymbol{I} \in \mathbb{R}^{n \times n} &:= \text{The identity matrix} \\ \boldsymbol{D} \in \mathbb{R}^{n \times n} &:= \text{The degree matrix of } \ \boldsymbol{A}+\boldsymbol{I} \\ \boldsymbol{X} \in \mathbb{R}^{n \times d} &:= \text{The input data (i.e., the per-node feature vectors)} \\ \boldsymbol{W} \in \mathbb{R}^{d \times w} &:= \text{The layer's weights} \\ \sigma(.) &:= \text{The activation function (e.g., ReLU)}\end{align*}\]When I first saw this equation I found it to be quite confusing. To break it down, here is what each matrix multiplication is doing in this function:

Let’s examine the first operation: $\boldsymbol{A}+\boldsymbol{I}$. This operation is simply adding ones along the diagonal entries of the adjacency matrix. This is the equivalent of adding self-loops to the graph where each node has an edge pointing to itself. The reason we need this is because when we perform message passing, each node should pass its vector to itself (since each node aggregates its own vector together with its neighbors).

The matrix $\boldsymbol{D}$ is the degree matrix of $\boldsymbol{A}+\boldsymbol{I}$. This is a diagonal matrix where element $i,i$ stores the total number of neighboring nodes to node $i$ (including itself). That is,

\[\boldsymbol{D} := \begin{bmatrix}d_{1,1} & 0 & 0 & \dots & 0 \\ 0 & d_{2,2} & 0 & \dots & 0 \\ 0 & 0 & d_{3,3} & \dots & 0 \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & 0 & \dots & d_{n,n}\end{bmatrix}\]where $d_{i,i}$ is the number of adjacent nodes (i.e., direct neighbors) to node $i$, counting node $i$ itself because of the self-loop.

The matrix $\boldsymbol{D}^{-1/2}$ is the matrix formed by taking the reciprocal of the square root of each entry in $\boldsymbol{D}$. That is,

\[\boldsymbol{D}^{-1/2} := \begin{bmatrix}\frac{1}{\sqrt{d_{1,1}}} & 0 & 0 & \dots & 0 \\ 0 & \frac{1}{\sqrt{d_{2,2}}} & 0 & \dots & 0 \\ 0 & 0 & \frac{1}{\sqrt{d_{3,3}}} & \dots & 0 \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & 0 & \dots & \frac{1}{\sqrt{d_{n,n}}}\end{bmatrix}\]As we will discuss in the next section, left and right multiplying $\boldsymbol{A}+\boldsymbol{I}$ by $\boldsymbol{D}^{-1/2}$ can be viewed as “normalizing” the adjacency matrix. We will discuss what we mean by “normalizing” in the next section and why this is an important step; however, for now, the important point to realize is that, like $\boldsymbol{A}+\boldsymbol{I}$, the matrix $\boldsymbol{D}^{-1/2}(\boldsymbol{A}+\boldsymbol{I})\boldsymbol{D}^{-1/2}$ will have a non-zero entry at element $i,j$ only if nodes $i$ and $j$ are adjacent (or $i = j$). For ease of notation, let’s let $\tilde{\boldsymbol{A}}$ denote this normalized matrix. That is,

\[\tilde{\boldsymbol{A}} := \boldsymbol{D}^{-1/2}(\boldsymbol{A}+\boldsymbol{I})\boldsymbol{D}^{-1/2}\]Then,

\[\tilde{\boldsymbol{A}}_{i,j} := \begin{cases} \frac{1}{\sqrt{d_{i,i} d_{j,j}}} ,& \text{if there is an edge between node} \ i \ \text{and} \ j \\ 0, & \text{otherwise}\end{cases}\]With this notation, we can simplify the graph convolutional layer function as follows:

\[f(\boldsymbol{X}, \boldsymbol{A}) := \sigma\left(\tilde{\boldsymbol{A}}\boldsymbol{X}\boldsymbol{W}\right)\]Next, let’s turn to the matrix $\tilde{\boldsymbol{A}}\boldsymbol{X}$. This matrix-product is performing the aggregation function/message passing that we described previously. That is, for every feature, we take a weighted sum of the features of the adjacent nodes where the weights are determined by $\tilde{\boldsymbol{A}}$.

Let $\bar{\boldsymbol{x}}_i$ denote the vector at node $i$ representing the aggregated features. We see that this vector is given by:

\[\begin{align*}\bar{\boldsymbol{x}}_i &= \sum_{j=1}^n \tilde{a}_{i,j} \boldsymbol{x}_j \\ &= \sum_{j \in \text{Neigh}(i)} \tilde{a}_{i,j} \boldsymbol{x}_j \\ &= \sum_{j \in \text{Neigh}(i)} \frac{1}{\sqrt{d_{i,i} d_{j,j}}} \boldsymbol{x}_j\end{align*}\]That is, it is simply computed by taking a weighted sum of the neighboring vectors where the weights are stored in the normalized adjacency matrix. We will discuss these neighbor-weights in more detail in the next section, but it is important to note that these weights are *not learned* weights – that is, they are not parameters to the model. Rather they are determined based only on the input graph itself.

So where are the learned weights/parameters to the model? They are stored in the matrix $\boldsymbol{W}$. In the next matrix multiplication, we “update” the aggregated feature vectors according to these weights via $\left(\tilde{\boldsymbol{A}}\boldsymbol{X}\right)\boldsymbol{W}$:

These vectors are then passed to the activation function, $\sigma$, before being output by the layer. This activation function injects non-linearity into the model.

One key point to note is that the dimensionality of the weight matrix, $\boldsymbol{W}$, does not depend on the number of nodes in the graph. Thus, we see that the graph convolutional layer can operate on graphs of any size so long as the feature vectors at each node are of the same dimension!
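To tie the pieces together, here is a minimal NumPy sketch of a single graph convolutional layer following the formula above. It uses NumPy rather than the PyTorch of the Appendix to keep it short; the function names, the random (untrained) weights, and the two toy graphs are assumptions made for illustration only:

```python
import numpy as np

def normalize_adjacency(A):
    """Compute D^{-1/2} (A + I) D^{-1/2} for a dense adjacency matrix A."""
    A_hat = A + np.eye(A.shape[0])      # add self-loops
    degrees = A_hat.sum(axis=1)         # diagonal entries of D
    d_inv_sqrt = 1.0 / np.sqrt(degrees)
    # Scaling rows and columns is equivalent to the two diagonal products.
    return A_hat * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]

def gcn_layer(X, A, W):
    """One graph convolutional layer with a ReLU activation."""
    aggregated = normalize_adjacency(A) @ X   # message passing, no learned weights
    updated = aggregated @ W                  # learned linear update
    return np.maximum(updated, 0.0)           # ReLU

rng = np.random.default_rng(0)
d, w = 4, 3
W = rng.normal(size=(d, w))   # hypothetical, untrained weights

# Two graphs of different sizes: a 3-node path and a 5-node cycle.
A_small = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)
A_large = np.zeros((5, 5))
for i in range(5):
    A_large[i, (i + 1) % 5] = A_large[(i + 1) % 5, i] = 1.0

H_small = gcn_layer(rng.normal(size=(3, d)), A_small, W)
H_large = gcn_layer(rng.normal(size=(5, d)), A_large, W)
print(H_small.shape, H_large.shape)
```

The same weight matrix `W` is applied to both graphs; only the number of rows in the output changes with the number of nodes.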

We can visualize the graph convolutional layer at a given node using a network diagram highlighting the neural network architecture:

Now, so far we have discussed only a single graph convolutional layer. We can create a multi-layer GCN by stacking graph convolutional layers together where the output of one layer is fed as input to the next layer! That is, the embedded vector at each node, $\boldsymbol{h}_i$, that is output by a graph convolutional layer can be treated as input to the next layer! Mathematically, this would be described as

\[\begin{align*} \boldsymbol{H}_1 &:= f_{\boldsymbol{W}_1}(\boldsymbol{X}, \boldsymbol{A}) \\ \boldsymbol{H}_2 &:= f_{\boldsymbol{W}_2}(\boldsymbol{H}_1, \boldsymbol{A}) \\ \boldsymbol{H}_3 &:= f_{\boldsymbol{W}_3}(\boldsymbol{H}_2, \boldsymbol{A})\end{align*}\]where $\boldsymbol{H}_1$, $\boldsymbol{H}_2$, and $\boldsymbol{H}_3$ are the embedded node vectors at layers 1, 2 and 3 respectively. The matrices $\boldsymbol{W}_1$, $\boldsymbol{W}_2$, and $\boldsymbol{W}_3$ are the weight matrices that parameterize each layer. A schematic illustration of stacked graph convolutional layers is depicted below:

Let’s take a closer look at the normalized adjacency matrix $\tilde{\boldsymbol{A}} := \boldsymbol{D}^{-1/2}(\boldsymbol{A}+\boldsymbol{I})\boldsymbol{D}^{-1/2}$. What is the intuition behind this matrix and what do we mean by “normalized”?

To understand this normalized matrix, let us first consider what happens in the convolutional layer if we don’t perform any normalization and instead naively use the raw adjacency matrix (with ones along the diagonal), $\boldsymbol{A}+\boldsymbol{I}$. For ease of notation let

\[\hat{\boldsymbol{A}} := \boldsymbol{A} + \boldsymbol{I}\]Then, the graph convolutional layer function without normalization would be:

\[f_{\text{unnormalized}}(\boldsymbol{X}, \boldsymbol{A}) := \sigma(\hat{\boldsymbol{A}}\boldsymbol{X}\boldsymbol{W})\]In the aggregation step, the aggregated features for node $i$, again denoted as $\bar{\boldsymbol{x}}_i$, will be given by

\[\begin{align*}\bar{\boldsymbol{x}}_i &= \sum_{j=1}^n \hat{a}_{i,j} \boldsymbol{x}_j \\ &= \sum_{j=1}^n \mathbb{I}(j \in \text{Neigh}(i)) \boldsymbol{x}_j \\ &= \sum_{j \in \text{Neigh}(i)} \boldsymbol{x}_j\end{align*}\]where $\mathbb{I}$ is the indicator function and $\text{Neigh}(i)$ is the set of neighbors of node $i$.

We see that this aggregation step simply adds together all of the feature vectors of $i$’s adjacent nodes with its own feature vector. A problem becomes apparent: for nodes that have many neighbors, this sum will be large and we will get vectors with large magnitudes. Conversely, for nodes with few neighbors, this sum will result in vectors with small magnitudes. This is not a desirable property! When attempting to train our neural network, each node’s vector will be highly dependent on the number of neighbors that surround it and it will be challenging to optimize weights that look for signals in the neighboring nodes that are independent of the number of neighbors. Another problem is that if we have multiple layers, the vector associated with a given node may blow up in magnitude the deeper into the layers they go, which can lead to numerical stability issues. Thus, we need a way to perform this aggregation step so that the aggregated vector for each node is of similar magnitude and is not dependent on each node’s number of neighbors.

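One idea to mitigate this issue would be to take the *mean* of the neighboring vectors rather than the sum. That is, to compute,

\[\bar{\boldsymbol{x}}_i = \frac{1}{\left\vert \text{Neigh}(i) \right\vert}\sum_{j \in \text{Neigh}(i)} \boldsymbol{x}_j\]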

Here, for node $i$, we simply divide the sum by the number of neighbors of node $i$. We can accomplish this averaging operation across all nodes in the graph at once if we normalize the adjacency matrix as follows:

\[\boldsymbol{D}^{-1}\hat{\boldsymbol{A}}\]Using this version of a normalized matrix for our convolutional layer, we would have:

\[f_{\text{mean}}(\boldsymbol{X}, \boldsymbol{A}) := \sigma(\boldsymbol{D}^{-1}\hat{\boldsymbol{A}}\boldsymbol{X}\boldsymbol{W})\]We can confirm that the aggregation step at node $i$ would be taking a mean of the vectors of the neighboring nodes:

\[\begin{align*}\bar{\boldsymbol{x}}_i &= \sum_{j=1}^n \frac{\hat{a}_{i,j}}{d_{i,i}} \boldsymbol{x}_j \\ &= \frac{1}{d_{i,i}}\sum_{j=1}^n \mathbb{I}(j \in \text{Neigh}(i)) \boldsymbol{x}_j \\ &= \frac{1}{\left\vert \text{Neigh}(i) \right\vert}\sum_{j \in \text{Neigh}(i)} \boldsymbol{x}_j \end{align*}\]Here, the $i,j$ element of $\boldsymbol{D}^{-1}\hat{\boldsymbol{A}}$ is $\hat{a}_{i,j} / d_{i,i}$, and $d_{i,i}$ equals the number of neighbors of node $i$ (including itself). This normalization is a reasonable approach, but Kipf and Welling (2017) propose a slightly different normalization method that goes a step further than simple averaging. Their normalization is given by

\[\tilde{\boldsymbol{A}} := \boldsymbol{D}^{-1/2}\hat{\boldsymbol{A}}\boldsymbol{D}^{-1/2}\]which results in each element of this normalized matrix being

\[\tilde{\boldsymbol{A}}_{i,j} := \begin{cases} \frac{1}{\sqrt{d_{i,i} d_{j,j}}} ,& \text{if there is an edge between node} \ i \ \text{and} \ j \\ 0, & \text{otherwise}\end{cases}\]We note that this normalization is performing a similar correction as mean normalization (i.e., $\boldsymbol{D}^{-1}\hat{\boldsymbol{A}}$) because the edge weight between adjacent nodes $i$ and $j$ will be smaller if node $i$ is connected to many nodes, and larger if it is connected to few nodes.

However, this raises the question: why use this alternative normalization approach rather than the more straightforward mean normalization? It turns out that this alternative normalization approach normalizes for something beyond how many neighbors each node has: *it also normalizes for how many neighbors each neighbor has*.

Let us say we have some node, Node $i$, with two neighbors: Neighbor 1 and Neighbor 2. Neighbor 1’s *only neighbor* is $i$. In contrast, Neighbor 2 is neighbors with many nodes in the graph (including $i$). Intuitively, because Neighbor 2 has so many neighbors, it has the opportunity to pass its message to more nodes in the graph. In contrast, Neighbor 1 can only pass its message to Node $i$ and thus, its influence on the rest of the graph is dictated by how Node $i$ is passing along its message. We see that Neighbor 1 is sort of “disempowered” relative to Neighbor 2 just based on its location in the graph. Is there some way to compensate for this imbalance? It turns out that the alternative normalization approach (i.e., $\boldsymbol{D}^{-1/2}\hat{\boldsymbol{A}}\boldsymbol{D}^{-1/2}$), does just that!

Recall that the $i, j$ element of $\boldsymbol{D}^{-1/2}\hat{\boldsymbol{A}}\boldsymbol{D}^{-1/2}$ is given by $\frac{1}{\sqrt{d_{i,i} d_{j,j}}}$. We see that this value will not only be lower if node $i$ has many neighbors, it will also be lower if node $j$, its neighbor, has many neighbors! This helps to boost the signal propagated by nodes with few neighbors relative to nodes with many neighbors.

We illustrate how this works in the schematic below where we color the nodes in a graph according to the weights associated with neighbors of Node 4 according to both mean normalization, $\boldsymbol{D}^{-1}\hat{\boldsymbol{A}}$, and the alternative normalization, $\boldsymbol{D}^{-1/2}\hat{\boldsymbol{A}}\boldsymbol{D}^{-1/2}$ (i.e., the 4th row of these two matrices):

We see that the mean normalization provides equal weight to all of the neighbors of Node 4 (including itself). In contrast, the alternative normalization gives the highest weight to Node 5, because it has few neighbors, and gives a lower weight to Node 1 because it has so many neighbors.
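The same comparison can be reproduced numerically. The sketch below uses a small hypothetical graph (not the one in the schematic): node 0 has a "leaf" neighbor, node 1, whose only neighbor is node 0, and a "hub" neighbor, node 2, that is also connected to nodes 3, 4, and 5:

```python
import numpy as np

# Hypothetical graph: node 1 is a leaf attached to node 0; node 2 is a hub.
edges = [(0, 1), (0, 2), (2, 3), (2, 4), (2, 5)]
n = 6
A = np.zeros((n, n))
for i, j in edges:
    A[i, j] = A[j, i] = 1.0

A_hat = A + np.eye(n)          # add self-loops
deg = A_hat.sum(axis=1)        # number of neighbors, including the node itself

mean_norm = A_hat / deg[:, None]                # D^{-1} A_hat
sym_norm = A_hat / np.sqrt(np.outer(deg, deg))  # D^{-1/2} A_hat D^{-1/2}

# Weights that node 0 assigns to itself, the leaf, and the hub.
print("mean:     ", mean_norm[0, :3])
print("symmetric:", sym_norm[0, :3])
```

Under mean normalization, node 0 weights itself, the leaf, and the hub equally (each gets 1/3). Under the symmetric normalization, the leaf receives a larger weight than the hub because the hub's weight is divided by the square root of its larger degree.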

In the prior section, we described the graph convolutional layer as performing a message passing procedure. However, there is another perspective from which we can view this process: as performing a convolution operation similar to the convolution-like operation that is performed by CNNs on images (hence the name graph *convolutional* neural network). Specifically, we can view the message passing procedure instead as the process of passing a **filter** (also called a **kernel**) over each node such that when the filter is centered over a given node, it combines data from the nearby nodes to produce the output vector for that node.

Let’s start with a CNN on images and recall how the filter is passed over each pixel and the values of the neighboring pixels are combined to form the output value at the next layer:

In a similar manner, for GCNs, a filter is passed over each node and the values of the neighboring nodes are combined to form the output value at the next layer:

To perform a graph-level task, such as classifying graphs, we need to aggregate information across nodes. A simple way to do this is to perform a simple aggregation step where we aggregate all of the vectors associated with each node into a single vector associated with the entire graph. This pooling can be done by taking the mean of each feature (mean pooling) or the maximum of each feature (max pooling). This aggregated vector can then be used as input to a fully connected, multi-layer perceptron that produces the final output:
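As a minimal sketch of the pooling step (with made-up node vectors and hypothetical helper names `mean_pool` and `max_pool`), both operations reduce an $n \times w$ matrix of per-node vectors to a single length-$w$ vector, regardless of how many nodes the graph has:

```python
import numpy as np

def mean_pool(H):
    """Average each feature over all nodes (rows) of H."""
    return H.mean(axis=0)

def max_pool(H):
    """Take the maximum of each feature over all nodes (rows) of H."""
    return H.max(axis=0)

# Made-up output vectors for a graph with 3 nodes and 2 features per node.
H = np.array([
    [1.0, 0.0],
    [3.0, 2.0],
    [2.0, 4.0],
])
graph_mean = mean_pool(H)   # [2., 2.]
graph_max = max_pool(H)     # [3., 4.]

# Both poolings give the same result if the nodes are re-ordered.
H_shuffled = H[[2, 0, 1]]
assert np.allclose(mean_pool(H_shuffled), graph_mean)
assert np.allclose(max_pool(H_shuffled), graph_max)
```

Because neither pooling depends on node order, the graph-level vector is unaffected by the arbitrary ordering of nodes in the input.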

In computational biology, graph neural networks are commonly applied to computational tasks operating on molecular structures. A graph is a natural data structure for encoding a molecule; each node represents an atom and each edge connects two atoms that are bonded together. An example is depicted below:

As a case-study for applying GCNs to a real task, we will implement and train a GCN to classify molecular toxicity. This is a binary-classification task where we are provided a molecule and our goal is to classify it as either toxic or not toxic. More specifically, we will use the Tox21 dataset downloaded from MoleculeNet. Applying GCNs to this dataset has been explored by the scientific community (e.g., Chen *et al*. (2021)), but we will reproduce such efforts here. Specifically, we focus on the task of predicting aryl hydrocarbon receptor activation encoded by the “NR-AhR” column of the Tox21 dataset.

Each molecule is represented as a SMILES string. To decode these molecules into graphs, we use the pysmiles Python package, which converts each string to a NetworkX graph. We then split the molecules/graphs into a random training set (85% of the data) and test set (the remaining 15%).

For node features, we use 1) the element of the atom, 2) the number of implicit hydrogens, 3) the charge of the atom, and lastly, 4) the aromaticity of each atom. After normalizing each adjacency matrix, we train the model using binary cross-entropy loss.

Finally, we then apply the model to the test set and generate an ROC curve and precision-recall curve:

The AUROC is 0.86, which isn’t too bad!

The code for this analysis is shown in the Appendix to this post and can be executed on Google Colab.

Note, this is not a very efficient implementation. Notably, we are storing each graph’s adjacency matrix and node feature vectors in a list rather than a tensor. A more efficient strategy would be to store each training batch as a tensor in order to vectorize the forward and backward passes. Advanced graph neural network software packages have tricks for overcoming such inefficiencies such as those described here.

- **A Gentle Introduction to Graph Neural Networks**: https://distill.pub/2021/gnn-intro/
- **Graph Convolutional Networks**: https://tkipf.github.io/graph-convolutional-networks/

The full code for the toxicity analysis is shown below (and can be run on Google Colab).

First, download the dataset via the following commands:

```
curl -O https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/tox21.csv.gz
gunzip -f tox21.csv.gz
```

The Python code is then:

```
import random

import networkx as nx
import numpy as np
import pandas as pd
import torch
import torch.optim as optim
import torch.nn.functional as F
from pysmiles import read_smiles
from sklearn.metrics import (
    PrecisionRecallDisplay,
    RocCurveDisplay,
    auc,
    precision_recall_curve,
    roc_curve,
)
from sklearn.preprocessing import OneHotEncoder
from sklearn.utils import shuffle

# Load the dataset
df_tox21 = pd.read_csv('tox21.csv').set_index('mol_id')

# The column of the table encoding the toxicity labels
TASK = 'NR-AhR'

# Keep only labeled molecules and the columns relevant to the task
df_tox21_task = df_tox21.loc[df_tox21[TASK].isin([1., 0.])].loc[:, [TASK, 'smiles']]

# Train-test split ratio
TRAIN_FRACTION = 0.85

# Shuffle the data (assign back so that the split below is random)
df_tox21_task = shuffle(df_tox21_task, random_state=123)

# Make the split
split_ind = int(len(df_tox21_task) * TRAIN_FRACTION)
df_tox21_train = df_tox21_task[:split_ind]
df_tox21_test = df_tox21_task[split_ind:]

# Draw the first training molecule as a sanity check
net = read_smiles(df_tox21_train.iloc[0]['smiles'])
nx.draw(net, labels=dict(net.nodes(data='element')))

# Normalize the adjacency matrices
def normalize_adj(A):
    # Fill diagonal with one (i.e., add self-edge)
    A_mod = A + torch.eye(A.shape[1])
    # Create degree matrix for each graph
    diag = torch.sum(A_mod, axis=1)
    D = torch.diag(diag)
    # Create the normalizing matrix
    # (i.e., the inverse square root of the degree matrix)
    D_mod = torch.linalg.inv(torch.sqrt(D))
    # Create the normalized adjacency matrix
    A_hat = torch.matmul(D_mod, torch.matmul(A_mod, D_mod))
    # Cast to float64 to match the dtype of the model's weights
    A_hat = A_hat.to(torch.float64)
    return A_hat

print("Loading network graph models...")
mol_nets_train = [
    read_smiles(s)
    for s in df_tox21_train['smiles']
]
A_train = [
    torch.tensor(nx.adjacency_matrix(net).todense())
    for net in mol_nets_train
]
A_train = [
    normalize_adj(adj)
    for adj in A_train
]
mol_nets_test = [
    read_smiles(s)
    for s in df_tox21_test['smiles']
]
A_test = [
    torch.tensor(nx.adjacency_matrix(net).todense())
    for net in mol_nets_test
]
A_test = [
    normalize_adj(adj)
    for adj in A_test
]

# Generate node-level features
def generate_features(net, attrs, one_hot_encoders):
    feat = None
    # For each attribute, compute the features for this attribute
    # per node
    for attr, enc in zip(attrs, one_hot_encoders):
        node_to_val = nx.get_node_attributes(net, attr)
        vals = []
        for node in sorted(node_to_val.keys()):
            val = node_to_val[node]
            vals.append([val])
        # Encode the values for this attribute as a feature vector
        if enc is not None:
            attr_feat = torch.tensor(enc.transform(vals).toarray(), dtype=torch.float64)
        else:
            attr_feat = torch.tensor(vals, dtype=torch.float64)
        # Concatenate the feature vector for this feature to the full
        # feature vector for each node
        if feat is None:
            feat = attr_feat
        else:
            feat = torch.cat((feat, attr_feat), dim=1)
    return feat

# Features to encode
ATTRIBUTES = ['element', 'charge', 'aromatic', 'hcount']
IS_ONE_HOT = [True, False, True, False]

# Create a one-hot encoder for each categorical attribute
encoders = []
for attr, is_one_hot in zip(ATTRIBUTES, IS_ONE_HOT):
    if is_one_hot:
        all_vals = set()
        for net in mol_nets_train:
            all_vals.update(nx.get_node_attributes(net, attr).values())
        all_vals = sorted(all_vals)
        print(f"All values of '{attr}' in training set: ", all_vals)
        enc = OneHotEncoder(handle_unknown='ignore')
        enc.fit([[x] for x in all_vals])
        encoders.append(enc)
    else:
        encoders.append(None)

# Build training tensors
X_train = []
for net in mol_nets_train:
    feats = generate_features(net, ATTRIBUTES, encoders)
    X_train.append(feats)
y_train = torch.tensor(df_tox21_train[TASK].values)

# Build test tensors
X_test = []
for net in mol_nets_test:
    feats = generate_features(net, ATTRIBUTES, encoders)
    X_test.append(feats)
y_test = torch.tensor(df_tox21_test[TASK].values)

# Implement the GCN
class GCNLayer(torch.nn.Module):
    def __init__(self, dim, hidden_dim):
        super(GCNLayer, self).__init__()
        self.W = torch.zeros(
            dim,
            hidden_dim,
            requires_grad=True,
            dtype=torch.float64
        )
        torch.nn.init.xavier_uniform_(self.W, gain=1.0)

    def forward(self, A_hat, h):
        # Aggregate
        h = torch.matmul(A_hat, h)
        # Update
        h = torch.matmul(h, self.W)
        h = F.relu(h)
        return h

    def parameters(self):
        return [self.W]


class GCN(torch.nn.Module):
    def __init__(self, x_dim, hidden_dim1, hidden_dim2, hidden_dim3):
        super(GCN, self).__init__()
        # Convolutional layers
        self.layer1 = GCNLayer(x_dim, hidden_dim1)
        self.layer2 = GCNLayer(hidden_dim1, hidden_dim2)
        self.layer3 = GCNLayer(hidden_dim2, hidden_dim3)
        # Output linear layer
        self.linear = torch.nn.Linear(hidden_dim3, 1, dtype=torch.float64)

    def forward(self, A_hat, x):
        # GCN layers
        x = self.layer1(A_hat, x)
        x = self.layer2(A_hat, x)
        x = self.layer3(A_hat, x)
        # Global average pooling over the nodes of the graph
        x = torch.mean(x, axis=0)
        x = self.linear(x)
        return torch.sigmoid(x)

    def parameters(self):
        params = self.layer1.parameters() \
            + self.layer2.parameters() \
            + self.layer3.parameters() \
            + list(self.linear.parameters())
        return params


def train_gcn(A, X, y, batch_size=100, n_epochs=10, lr=0.1):
    # Input validation
    assert len(A) == len(X)
    assert len(X) == len(y)
    # Instantiate model, optimizer, and loss function
    model = GCN(
        X[0].shape[1],  # Input dimensions
        20,             # Layer 1 dimensions
        20,             # Layer 2 dimensions
        5               # Final layer dimensions
    )
    optimizer = optim.Adam(model.parameters(), lr=lr)
    bce = torch.nn.BCELoss()
    # Training loop
    for epoch in range(n_epochs):
        # Shuffle the dataset upon each epoch
        inds = list(range(len(X)))
        random.shuffle(inds)
        X = [X[i] for i in inds]
        A = [A[i] for i in inds]
        y = y[inds]
        loss_sum = 0
        n_batches = 0
        for start in range(0, len(A), batch_size):
            # Compute the start and end indices for the batch
            end = start + min(batch_size, len(X) - start)
            # Forward pass (one graph at a time)
            pred = torch.concat([
                model(A[i], X[i])
                for i in range(start, end)
            ])
            # Compute loss on the batch
            loss = bce(pred, y[start:end])
            loss_sum = loss_sum + float(loss)
            n_batches += 1
            # Take gradient step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # BCELoss averages within a batch, so average over batches here
        print(f"Epoch: {epoch}. Mean loss: {loss_sum / n_batches}")
    return model

# Train the model
model = train_gcn(
    A_train,
    X_train,
    y_train,
    batch_size=200,
    n_epochs=100,
    lr=0.001
)

# Run the model on the test set
y_pred = torch.concat([
    model(A_i, X_i)
    for A_i, X_i in zip(A_test, X_test)
])

# Create and display the ROC curve
fpr, tpr, _ = roc_curve(
    y_test.numpy(),
    y_pred.detach().numpy(),
    pos_label=1
)
roc_display = RocCurveDisplay(fpr=fpr, tpr=tpr).plot()

# Create and display the PR-curve
precision, recall, thresholds = precision_recall_curve(
    y_test.numpy(), y_pred.detach().numpy(), pos_label=1
)
PrecisionRecallDisplay(
    recall=recall,
    precision=precision
).plot()

# Compute AUROC
print(auc(fpr, tpr))
```

In the previous post, we derived the formula for the determinant by showing that the determinant describes the geometric volume of the high dimensional parallelepiped formed by the columns of a matrix. But that is not the full story!

As with most topics, it helps to view determinants from multiple perspectives, which we will attempt to do here. To understand determinants from multiple perspectives, we will also need to view matrices from multiple perspectives. Recall from a previous post that there are three perspectives for viewing matrices:

- **Perspective 1:** As a table of values
- **Perspective 2:** As a list of column vectors (or row vectors)
- **Perspective 3:** As a linear transformation between vector spaces

In our last post, we explored the determinant by viewing matrices from Perspective 2. That is, the determinant of a matrix describes the geometric volume of the parallelepiped formed by the column vectors of the matrix. In this post, we will explore determinants by viewing matrices from Perspective 3 and explore what the determinant tells us about the linear transformation characterized by a given matrix. To preview, the determinant tells us two things about the linear transformation:

- How much a matrix’s linear transformation grows or shrinks space
- Whether the matrix’s linear transformation inverts space

We’ll conclude by putting these two pieces together and describing how the determinant can be thought of as a “signed volume”.

To review, one can view a matrix as characterizing a linear transformation between vector spaces. That is, given a matrix $\boldsymbol{A} \in \mathbb{R}^{m \times n}$, we can form a function $T$ that maps vectors in $\mathbb{R}^n$ to $\mathbb{R}^m$ using matrix-vector multiplication:

\[T(\boldsymbol{x}) := \boldsymbol{Ax}\]With this in mind, let’s think about what a matrix, $\boldsymbol{A} \in \mathbb{R}^{2 \times 2}$, will do to the standard basis vectors in $\mathbb{R}^{2}$. Specifically, we see that the first standard basis vector will be transformed to the first column-vector of $\boldsymbol{A}$:

\[\begin{bmatrix}a & b \\ c & d\end{bmatrix}\begin{bmatrix}1 \\ 0\end{bmatrix} = \begin{bmatrix}a \\ c\end{bmatrix}\]Similarly, the second standard basis vector will be transformed to the second column of $\boldsymbol{A}$:

\[\begin{bmatrix}a & b \\ c & d\end{bmatrix}\begin{bmatrix}0 \\ 1\end{bmatrix} = \begin{bmatrix}b \\ d\end{bmatrix}\]Thus, if we multiply $\boldsymbol{A}$ by the matrix that is formed by using the two standard basis vectors as columns (which is just the identity matrix), we get back $\boldsymbol{A}$:

\[\boldsymbol{AI} = \boldsymbol{A}\]Here, we are viewing the matrix $\boldsymbol{A}$ as a function and are viewing $\boldsymbol{I}$ as a list of vectors. We see that $\boldsymbol{A}$ transforms the column vectors in $\boldsymbol{I}$ into the column vectors of $\boldsymbol{A}$ itself. Moreover, the column vectors of $\boldsymbol{I}$ form the unit square, and $\boldsymbol{A}$ transforms this unit square into a parallelogram with an area equal to $\lvert \text{Det}(\boldsymbol{A}) \rvert$. Thus, the matrix $\boldsymbol{A}$ has, in a sense, scaled the area of the original unit square (which is one) to $\lvert \text{Det}(\boldsymbol{A}) \rvert$.

This pattern does not hold for the unit square alone, nor does it hold for just $\mathbb{R}^2$. In fact, any hypercube in $m$-dimensional space that is transformed by some matrix $\boldsymbol{A} \in \mathbb{R}^{m \times m}$ will become a parallelepiped with a volume that is grown or shrunk by a factor equal to $\lvert \text{Det}(\boldsymbol{A}) \rvert$. To see why, examine what happens to a hypercube with sides of length $c$, which we can represent as the matrix $c\boldsymbol{I}$:

\[c\boldsymbol{I} = \begin{bmatrix}c & 0 & \dots & 0 \\ 0 & c & \dots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \dots & c \end{bmatrix}\]where

\[\text{Volume}(c\boldsymbol{I}) = c^{m}\]because each of the $m$ sides is of length $c$. If we transform this hypercube by $\boldsymbol{A}$, we get a parallelepiped represented by the matrix $\boldsymbol{A}c\boldsymbol{I} = c\boldsymbol{A}$. Its volume is given by $\lvert\text{Det}(c\boldsymbol{A})\rvert$:

\[\text{Volume}(c\boldsymbol{A}) = \lvert \text{Det}(c\boldsymbol{a}_{*,1}, \dots, c\boldsymbol{a}_{*,m})\rvert\]where $\boldsymbol{a}_{*,1}, \dots, \boldsymbol{a}_{*,m}$ are the column-vectors of $\boldsymbol{A}$. Notice that $c$ is multiplying each column-vector. From the previous post, recall that the determinant is linear with respect to each column-vector so we can “pull out” each $c$ coefficient:

\[\begin{align*}\text{Volume}(c\boldsymbol{A}) &= \lvert \text{Det}(c\boldsymbol{a}_{*,1}, \dots, c\boldsymbol{a}_{*,m}) \rvert \\ &= \lvert c^m \rvert \lvert \text{Det}(\boldsymbol{a}_{*,1}, \dots, \boldsymbol{a}_{*,m}) \rvert \\ &= \text{Volume}(c\boldsymbol{I}) \lvert \text{Det}(\boldsymbol{A}) \rvert \end{align*}\]Thus we see that the volume of our cube was scaled by a factor $\lvert \text{Det}(\boldsymbol{A}) \rvert$.
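As a small worked example (the specific matrix and scale factor are chosen here purely for illustration), take $m = 2$, $c = 3$, and $\boldsymbol{A} = \begin{bmatrix}2 & 1 \\ 1 & 3\end{bmatrix}$, so that $\text{Det}(\boldsymbol{A}) = 2 \cdot 3 - 1 \cdot 1 = 5$ and $\text{Volume}(3\boldsymbol{I}) = 3^2 = 9$:

\[\begin{align*}\lvert \text{Det}(3\boldsymbol{A}) \rvert &= \left\lvert \text{Det}\left(\begin{bmatrix}6 & 3 \\ 3 & 9\end{bmatrix}\right) \right\rvert \\ &= \lvert 6 \cdot 9 - 3 \cdot 3 \rvert \\ &= 45 \\ &= 9 \cdot 5 \\ &= \text{Volume}(3\boldsymbol{I}) \lvert \text{Det}(\boldsymbol{A}) \rvert\end{align*}\]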

Without proving it formally here, we can now intuitively see that *any* area/object’s volume will be scaled by the factor $\lvert \text{Det}(\boldsymbol{A}) \rvert$ when transformed by $\boldsymbol{A}$. This is because we can always approximate the volume of an object by filling the object with small hypercubes and summing the volumes of those hypercubes together. As we shrink the hypercubes ever smaller, we get a more accurate approximation of the volume. Under transformation by a matrix, $\boldsymbol{A}$, all of those tiny hypercubes will be scaled by $\lvert \text{Det}(\boldsymbol{A})\rvert$ and thus, the full volume of the object will be scaled by this value as well. This idea can be visualized below where we see the volume of a circle scaled under transformation of a matrix $\boldsymbol{A}$:
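This scaling argument can be checked numerically. The sketch below approximates a unit circle by a polygon, transforms every vertex by a $2 \times 2$ matrix, and compares the areas before and after using the shoelace formula. The helper names (`shoelace_area`, `apply_matrix`) and the example matrix are assumptions introduced for illustration.

```python
import math

def shoelace_area(points):
    """Unsigned area of a simple polygon given as a list of (x, y) vertices."""
    total = 0.0
    n = len(points)
    for k in range(n):
        x1, y1 = points[k]
        x2, y2 = points[(k + 1) % n]
        total += x1 * y2 - x2 * y1
    return abs(total) / 2

def apply_matrix(a, points):
    """Apply the 2x2 matrix `a` (nested lists) to every point."""
    return [(a[0][0] * x + a[0][1] * y, a[1][0] * x + a[1][1] * y) for x, y in points]

# Approximate the unit circle by a polygon with many vertices
N = 1000
circle = [(math.cos(2 * math.pi * k / N), math.sin(2 * math.pi * k / N)) for k in range(N)]

A = [[2.0, 1.0],
     [0.5, 1.5]]
det_A = A[0][0] * A[1][1] - A[0][1] * A[1][0]

area_before = shoelace_area(circle)
area_after = shoelace_area(apply_matrix(A, circle))
print("area before:", area_before)
print("area after: ", area_after)
print("ratio:", area_after / area_before, "|Det(A)|:", abs(det_A))
```

The ratio of the two areas agrees with $\lvert \text{Det}(\boldsymbol{A}) \rvert$ up to floating-point error, regardless of how finely the circle is approximated.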

**Quick aside: Intuiting why $\text{Det}(\boldsymbol{AB}) = \text{Det}(\boldsymbol{A})\text{Det}(\boldsymbol{B})$**

We can now gain a much better intuition for Theorem 8 presented in the previous post, which states that for two square matrices, $\boldsymbol{A}, \boldsymbol{B} \in \mathbb{R}^{n \times n}$ it holds that

\[\text{Det}(\boldsymbol{AB}) = \text{Det}(\boldsymbol{A})\text{Det}(\boldsymbol{B})\]First, recall that a matrix product $\boldsymbol{AB}$ can be interpreted as a composition of linear transformations. That is, the transformation carried out by $\boldsymbol{AB}$ is equivalent to the transformation carried out by $\boldsymbol{B}$ followed consecutively by $\boldsymbol{A}$. Let’s now think about how the area of an object will change as we first transform it by $\boldsymbol{B}$ followed by $\boldsymbol{A}$. First, transforming it by $\boldsymbol{B}$ will scale its area by a factor of $\lvert \text{Det}(\boldsymbol{B}) \rvert$. Then, transforming it by $\boldsymbol{A}$ will scale its area by a factor of $\lvert \text{Det}(\boldsymbol{A}) \rvert$. The total change of its area is thus $\lvert \text{Det}(\boldsymbol{B}) \rvert \lvert \text{Det}(\boldsymbol{A}) \rvert$. This can be visualized below:

Above we see a unit cube first transformed into a parallelogram by $\boldsymbol{B}$. Its area grows by a factor of $\lvert \text{Det}(\boldsymbol{B}) \rvert$. This parallelogram is then transformed into another parallelogram by $\boldsymbol{A}$. Its area grows by an additional factor of $\lvert \text{Det}(\boldsymbol{A}) \rvert$. Thus, the final scaling factor of the unit cube’s area under $\boldsymbol{AB}$ is $\lvert \text{Det}(\boldsymbol{A}) \rvert \lvert \text{Det}(\boldsymbol{B})\rvert$. Equivalently, because the unit cube was transformed by $\boldsymbol{AB}$, its area grew by a factor of $\lvert \text{Det}(\boldsymbol{AB}) \rvert$.
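For a concrete check of the product rule, here is a minimal sketch with two small integer matrices. The helper names `det2` and `matmul2` and the example matrices are illustrative assumptions, not part of the previous post.

```python
def det2(m):
    """Determinant of a 2x2 matrix given as nested lists."""
    return m[0][0] * m[1][1] - m[0][1] * m[1][0]

def matmul2(a, b):
    """Product of two 2x2 matrices given as nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(2)) for j in range(2)]
            for i in range(2)]

A = [[2, 1],
     [0, 3]]
B = [[1, 2],
     [3, 4]]
AB = matmul2(A, B)

print(det2(A), det2(B), det2(AB))  # 6 -2 -12
```

Note that the identity holds for the signed determinants, not only for their absolute values: here one of the two factors is negative, and so is the determinant of the product.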

So far, our discussion of the determinant has focused on volume, but we have glossed over the fact that this interpretation of the determinant requires taking its absolute value. What does the sign of the determinant capture? If the determinant captures volume, then how can it be negative (intuitively, volume is only a positive quantity)?

It turns out that the sign of the determinant captures something else about a matrix’s linear transformation other than how much it grows or shrinks space: it captures whether or not a matrix “inverts” space. That is, a matrix with a positive determinant will maintain the orientation of vectors in the original space relative to one another, but a matrix with a negative determinant will invert their orientation.

As an example, let us consider the matrix:

\[\boldsymbol{A} := \begin{bmatrix}0 & 1 \\ 1 & 0\end{bmatrix}\]This matrix represents the identity matrix, but with its two columns flipped. The determinant of $\boldsymbol{A}$ is -1. Why? By Axiom 1, the determinant of the identity matrix is 1. By Theorem 1 in my previous post, swapping two columns flips the sign of the determinant. Thus, the determinant of $\boldsymbol{A}$ is -1. (Note, if you perform *two* swaps, the matrix no longer inverts space, though this is a bit hard to visualize in high dimensions.)

Below is an illustration of what happens to a set of vectors that form the outline of a hand when transformed by the matrix $\boldsymbol{A}$.

Here we see that this matrix simply flipped the orientation of vectors across the thick dotted line (you can see this by tracing the location of the thumb outlined by the thin dotted lines).
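The same flip can be seen numerically by tracking the signed area of a shape, which is positive when its vertices are listed counter-clockwise and negative when they are listed clockwise. The sketch below is a minimal illustration; the triangle and the helper names `signed_area` and `apply_matrix` are assumptions introduced here.

```python
def signed_area(points):
    """Signed shoelace area: positive for counter-clockwise vertex order."""
    total = 0
    n = len(points)
    for k in range(n):
        x1, y1 = points[k]
        x2, y2 = points[(k + 1) % n]
        total += x1 * y2 - x2 * y1
    return total / 2

def apply_matrix(a, points):
    """Apply the 2x2 matrix `a` (nested lists) to every point."""
    return [(a[0][0] * x + a[0][1] * y, a[1][0] * x + a[1][1] * y) for x, y in points]

# A triangle whose vertices are listed counter-clockwise
triangle = [(0, 0), (1, 0), (0, 1)]

# The identity matrix with its two columns swapped (determinant -1)
swap = [[0, 1],
        [1, 0]]
flipped = apply_matrix(swap, triangle)

print(signed_area(triangle), signed_area(flipped))  # 0.5 -0.5
```

The area is unchanged in magnitude, since $\lvert \text{Det}(\boldsymbol{A}) \rvert = 1$, but the vertex order is reversed, which is exactly the inversion of orientation described above.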

This same phenomenon occurs in higher dimensions too. Here is an example in three dimensions where a 3D hand is transformed by a matrix $\boldsymbol{A}$ that again represents the identity matrix, but with the first and third columns flipped. Notice how the hand went from being a right hand to a left hand by the transformation:

In some explanations, the determinant is explained as describing a “signed volume”. What is meant by signed volume? For me, it helps to think about determinants in a similar way that we think about integrals. Integrals express the “signed” area under a curve where the sign tells you whether there is more area above versus below zero. Consider a sequence of univariate functions where each function’s curve approaches zero until it crosses zero and becomes more negative:

We see that the integral starts out as positive, shrinks to zero, and then becomes more negative.

Analogously, we can see that as two vectors are rotated towards one another, the determinant is positive but decreases until the vectors are aligned. When the vectors are aligned, the determinant is zero. As they cross one another further, the determinant becomes more and more negative. This is visualized below:

Thus, the “sign” of the determinant can be thought about in a similar way as the sign of an integral. A negative integral tells you that the function has more area below zero than above zero. A negative determinant tells you that two column vectors, in a sense, “crossed” one another, thus inverting space across those two column vectors.
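To make the “crossing” picture concrete, the sketch below fixes one unit vector and rotates a second unit vector through it, printing the determinant of the matrix whose columns are the two vectors. The function name `det2_columns` and the chosen angles are illustrative assumptions.

```python
import math

def det2_columns(u, v):
    """Determinant of the 2x2 matrix whose columns are u and v."""
    return u[0] * v[1] - v[0] * u[1]

u = (1.0, 0.0)
dets = {}
for degrees in (90, 45, 0, -45, -90):
    theta = math.radians(degrees)
    v = (math.cos(theta), math.sin(theta))
    dets[degrees] = det2_columns(u, v)
    print(f"angle from u to v: {degrees:4d} degrees, determinant: {dets[degrees]:+.3f}")
```

With this setup the determinant equals the sine of the angle between the two vectors: it shrinks toward zero as the second vector approaches the first, hits zero when they are aligned, and turns negative once the second vector has crossed to the other side.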

The **determinant** is a function that maps square matrices to real numbers:

\[\text{Det} : \mathbb{R}^{m \times m} \rightarrow \mathbb{R}\]where the absolute value of the determinant describes the volume of the parallelepiped formed by the matrix’s columns. This is illustrated below:

While this idea is fairly straightforward conceptually, the formula for the determinant is quite confusing:

\[\text{Det}(\boldsymbol{A}) := \begin{cases} a_{1,1}a_{2,2} - a_{1,2}a_{2,1} & \text{if $m = 2$} \\ \sum_{i=1}^m (-1)^{i+1} a_{i,1} \text{Det}(\boldsymbol{A}_{-i, -1}) & \text{if $m > 2$}\end{cases}\]Here, $\boldsymbol{A}_{-i, -1}$ denotes the matrix formed by deleting the $i$th row and first column of $\boldsymbol{A}$. Note that this is a recursive definition where the base case is a $2 \times 2$ matrix.
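To see the recursion in action, here is a minimal Python sketch of the definition above, expanding along the first column with a $2 \times 2$ base case. The function name `det` and the nested-list representation are assumptions made for illustration; this direct recursion is far slower than what numerical libraries use and is meant only to mirror the formula.

```python
def det(m):
    """Determinant by cofactor expansion along the first column.

    `m` is a square matrix given as a list of rows, of size 2x2 or larger.
    """
    n = len(m)
    if n == 2:
        # Base case: a_{1,1} a_{2,2} - a_{1,2} a_{2,1}
        return m[0][0] * m[1][1] - m[0][1] * m[1][0]
    total = 0
    for i in range(n):
        # Delete row i and the first column to form the minor
        minor = [row[1:] for k, row in enumerate(m) if k != i]
        # (-1)^(i+1) with 1-based i is (-1)^i with 0-based i
        total += (-1) ** i * m[i][0] * det(minor)
    return total

identity = [[1, 0, 0],
            [0, 1, 0],
            [0, 0, 1]]
upper = [[2, 1, 3],
         [0, 4, 5],
         [0, 0, 6]]
print(det(identity))  # 1
print(det(upper))     # 48
```

Each recursive call strips one row and one column, so an $m \times m$ determinant is eventually reduced to a weighted sum of $2 \times 2$ determinants.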

When one is first taught determinants, they are usually expected to take it as a given that this formula calculates the volume of an $m$-dimensional parallelepiped; however, if you’re like me, this is not at all obvious. How on earth does this formula calculate volume? Moreover, why is it recursive?

In this post, I am going to attempt to derive this formula from first principles. We will start with the base case of a 2×2 matrix, verify that it indeed computes the volume of the parallelogram formed by the columns of the matrix, and then move on to the determinant for larger matrices.

Let’s first only look at the $m = 2$ case and verify that this equation computes the area of the parallelogram formed by the matrix’s columns. Let’s say we have a matrix

\[\boldsymbol{A} := \begin{bmatrix}a & b \\ c & d\end{bmatrix}\]Then we see that the area can be obtained by computing the area of the rectangle that encompasses the parallelogram and subtracting the areas of the triangles around it:

Simplifying the equation above we get

\[\text{Det}(\boldsymbol{A}) = ad - bc\]This is exactly the definition for the $2 \times 2$ determinant. So far so good.

Moving on to $m > 2$, the definition of the determinant is

\[\text{Det}(\boldsymbol{A}) := \sum_{i=1}^m (-1)^{i+1} a_{i,1} \text{Det}(\boldsymbol{A}_{-i,-1})\]Before understanding this equation, we must first ask ourselves what we really mean by “volume” in $m$-dimensional space. In fact, it is through the process of answering this very question that we arrive at the equation above. Specifically, we will formulate a set of three axioms that attempt to capture the notion of “geometric volume” in a very abstract way that applies to higher dimensions. Then, we will show that the only formula that satisfies these axioms is the formula for the determinant shown above!

These axioms are as follows:

**1. The determinant of the identity matrix is one**

The first axiom states that

\[\text{Det}(\boldsymbol{I}) := 1\]Why do we want this to be an axiom? First, we note that the parallelepiped formed by the columns of the identity matrix, $\boldsymbol{I}$, is a hypercube in $m$-dimensional space:

We would like our notion of “geometric volume” to match our common intuition that the volume of a cube is simply the product of the sides of the cube. In this case, they’re all of length one so the volume, and thus the determinant, should be one.

**2. If two columns of a matrix are equal, then its determinant is zero**

For a given matrix $\boldsymbol{A}$, if any two columns $\boldsymbol{a}_{*,i}$ and $\boldsymbol{a}_{*,j}$ are equal, then the determinant of $\boldsymbol{A}$ should be zero.

Why do we want this to be an axiom? We first note that if two columns of a matrix are equal, then the parallelepiped formed by its columns is flat. For example, here’s a depiction of a parallelepiped formed by the columns of a $3 \times 3$ matrix with two columns that are equal:

We see that the parallelepiped is flat and lies within a hyperplane.

We would like our notion of “geometric volume” to match our common intuition that the volume of a flat object is zero. Thus, when any two columns of a matrix are equal, we would like the determinant to be zero.

**3. The determinant of a matrix is linear with respect to each column vector**

Before digging into this final axiom, let us define some notation to make our discussion easier. Specifically, for the remainder of this post, we will often represent the determinant of a matrix as a function with either a single matrix argument, $\text{Det}(\boldsymbol{A})$, or with multiple vector arguments $\text{Det}(\boldsymbol{a}_{*,1}, \dots, \boldsymbol{a}_{*,n})$ where $\boldsymbol{a}_{*,1}, \dots, \boldsymbol{a}_{*,n}$ are the $n$ columns of $\boldsymbol{A}$.

Now, the final axiom for the determinant is that $\text{Det}$ is a linear function with respect to each argument vector. For $\text{Det}$ to be linear with respect to each argument is to imply two conditions. First, for a given constant $k$, it holds that,

\[\forall j \in [n], \ \text{Det}(\boldsymbol{a}_{*,1}, \dots, k\boldsymbol{a}_{*,j}, \dots \boldsymbol{a}_{*,n}) = k\text{Det}(\boldsymbol{a}_{*,1}, \dots, \boldsymbol{a}_{*,j}, \dots \boldsymbol{a}_{*,n})\]and second, that

\[\forall j \in [n], \ \text{Det}(\boldsymbol{a}_{*,1}, \dots, \boldsymbol{a}_{*,j} + \boldsymbol{v}, \dots \boldsymbol{a}_{*,n}) = \text{Det}(\boldsymbol{a}_{*,1}, \dots, \boldsymbol{a}_{*,j}, \dots \boldsymbol{a}_{*,n}) + \text{Det}(\boldsymbol{a}_{*,1}, \dots, \boldsymbol{v}, \dots \boldsymbol{a}_{*,n})\]Why do we wish the linearity of $\text{Det}$ to be an axiom? Because it turns out that the volume of a two-dimensional parallelogram is linear with respect to the vectors that form its sides. We can prove this both algebraically as well as geometrically. Let’s start with the algebraic proof starting with a parallelogram defined by the columns of the following matrix:

\[\boldsymbol{A} := \begin{bmatrix}a & b \\ c & d\end{bmatrix}\]Let’s say we multiply the first column vector by $k$ to form $\boldsymbol{A}'$:

\[\boldsymbol{A}' := \begin{bmatrix}ka & b \\ kc & d\end{bmatrix}\]Its determinant is

\[\begin{align*}\text{Det}(\boldsymbol{A}') &= kad - bkc \\ &= k(ad - bc) \\ &= k\text{Det}(\boldsymbol{A})\end{align*}\]Now let’s consider another matrix formed by taking the first column of $\boldsymbol{A}$ and adding a vector $\boldsymbol{v}$:

\[\boldsymbol{A}' := \begin{bmatrix}a + v_1 & b \\ c + v_2 & d\end{bmatrix}\]Its determinant is

\[\begin{align*}\text{Det}(\boldsymbol{A}') &= (a + v_1)d - b(c + v_2) \\ &= ad + v_1d - bc - bv_2 \\ &= (ad - bc) + (v_1d - bv_2) \\ &= \text{Det}\left(\begin{bmatrix}a & b \\ c & d\end{bmatrix}\right) + \text{Det}\left(\begin{bmatrix}v_1 & b \\ v_2 & d\end{bmatrix} \right) \end{align*}\]To provide more intuition about why this linearity property holds, let’s look at it geometrically. As a preliminary observation, notice how if we skew one of the edges of a parallelogram along the axis of the other edge, then the area remains the same. We can see this in the figure below by noticing that the area of the yellow triangle is subtracted from the first parallelogram, but is added to the second:

With this observation in mind, we can now show why the determinant is linear from a geometric perspective. Let’s start with the first condition of linearity, which says that if we scale one of the sides of a parallelogram by $k$, then the area of the parallelogram is scaled by $k$. Below, we show a parallelogram where we scale one of the vectors, $\boldsymbol{v}$, by $k$.

We see that we can skew both sides of the parallelogram to be orthogonal to one another, forming a rectangle that preserves the area of the parallelogram. This rectangle has sides of length $a$ and $b$ and thus an area of $ab$. When we skew the enlarged parallelogram in the same way, we form a rectangle with sides of length $a$ and $kb$ and thus an area of $kab$. We know the sides of the enlarged parallelogram are of length $a$ and $kb$ by observing that the two shaded triangles shown below are similar:

The second condition of linearity states that if we break apart one of the vectors that forms an edge of the parallelogram into two vectors, they form two “sub-parallelograms” whose total area equals the area of the original parallelogram. This is shown in the following “visual proof”:
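Both linearity conditions, along with the skewing observation, can also be spot-checked numerically in the $2 \times 2$ case. The sketch below is only a sanity check on a few integer vectors, not a proof; the function name `det2_columns` and the example vectors are assumptions introduced here.

```python
def det2_columns(u, v):
    """Determinant of the 2x2 matrix whose columns are u and v (ad - bc)."""
    return u[0] * v[1] - v[0] * u[1]

u = (3, 1)   # first column
v = (1, 2)   # second column
w = (2, 5)   # vector added to the first column
k = 4        # scaling constant
t = 7        # amount of skew along v

# Condition 1: scaling one column by k scales the determinant by k
scaled = det2_columns((k * u[0], k * u[1]), v)

# Condition 2: adding w to one column adds the corresponding determinants
summed = det2_columns((u[0] + w[0], u[1] + w[1]), v)

# Skewing: sliding u along the direction of v leaves the area unchanged
skewed = det2_columns((u[0] + t * v[0], u[1] + t * v[1]), v)

print(scaled, k * det2_columns(u, v))
print(summed, det2_columns(u, v) + det2_columns(w, v))
print(skewed, det2_columns(u, v))
```

Each printed pair should match, which is what the algebraic and geometric arguments above predict.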

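In the previous section, we outlined three axioms that define fundamental ways in which the volume of a parallelepiped is related to the vectors that form its sides. It turns out that the *only* formula that satisfies these axioms is the following:

\[\text{Det}(\boldsymbol{A}) := \begin{cases} a_{1,1}a_{2,2} - a_{1,2}a_{2,1} & \text{if $m = 2$} \\ \sum_{i=1}^m (-1)^{i+1} a_{i,1} \text{Det}(\boldsymbol{A}_{-i, -1}) & \text{if $m > 2$}\end{cases}\]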

We will start by assuming that there exists a function $\text{Det}: \mathbb{R}^{m \times m} \rightarrow \mathbb{R}$ that satisfies our three axioms and will subsequently prove a series of theorems that will build up to this final formula. Many of these theorems make heavy use of the fact that invertible matrices can be decomposed into the product of elementary matrices. For an in-depth discussion of elementary matrices, see my previous post.

The theorems required to derive this formula are outlined below, and their proofs are given in the Appendix to this post.

**Theorem 1:** Given a matrix $\boldsymbol{A} \in \mathbb{R}^{m \times m}$, if we exchange any two column-vectors of $\boldsymbol{A}$ to form a new matrix $\boldsymbol{A}'$, then $\text{Det}(\boldsymbol{A}') = -\text{Det}(\boldsymbol{A})$

**Theorem 2:** Given a matrix $\boldsymbol{A} \in \mathbb{R}^{m \times m}$, if its column-vectors are linearly dependent, then its determinant is zero.

**Theorem 3:** Given a triangular matrix, $\boldsymbol{A} \in \mathbb{R}^{m \times m}$, its determinant can be computed by multiplying its diagonal entries.

**Theorem 4:** Given a matrix $\boldsymbol{A} \in \mathbb{R}^{m \times m}$, adding a multiple of one column-vector of $\boldsymbol{A}$ to another column-vector does not change the determinant of $\boldsymbol{A}$.

**Theorem 5:** Given an elementary matrix that represents row-scaling, $\boldsymbol{E} \in \mathbb{R}^{m \times m}$, where $\boldsymbol{E}$ scales the $j$th row of a system of linear equations by $k$, its determinant is simply $k$.

**Theorem 6:** Given an elementary matrix that represents row-swapping, $\boldsymbol{E} \in \mathbb{R}^{m \times m}$ that swaps the $i$th and $j$th rows of a system of linear equations, its determinant is simply -1.

**Theorem 7:** Given an elementary matrix that represents a row-sum, $\boldsymbol{E} \in \mathbb{R}^{m \times m}$, that adds $k$ times row $i$ to row $j$, its determinant is simply 1.

**Theorem 8:** Given matrices $\boldsymbol{A}, \boldsymbol{B} \in \mathbb{R}^{m \times m}$, it holds that $\text{Det}(\boldsymbol{AB}) = \text{Det}(\boldsymbol{A})\text{Det}(\boldsymbol{B})$

**Theorem 9:** Given a square matrix $\boldsymbol{A}$, it holds that $\text{Det}(\boldsymbol{A}) = \text{Det}(\boldsymbol{A}^T)$.

**Theorem 10:** The determinant of a matrix is linear with respect to the row vectors of the matrix.
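Before turning to the proofs, several of these theorems can be spot-checked on small integer matrices. The sketch below is a sanity check under the assumption that the determinant is computed by the cofactor expansion from the introduction; the helper names (`det`, `transpose`, `swap_columns`) and the example matrices are illustrative, and passing checks on a few examples is of course not a proof.

```python
def det(m):
    """Determinant by cofactor expansion along the first column (2x2 base case)."""
    n = len(m)
    if n == 2:
        return m[0][0] * m[1][1] - m[0][1] * m[1][0]
    return sum(
        (-1) ** i * m[i][0] * det([row[1:] for k, row in enumerate(m) if k != i])
        for i in range(n)
    )

def transpose(m):
    return [list(col) for col in zip(*m)]

def swap_columns(m, i, j):
    out = [row[:] for row in m]
    for row in out:
        row[i], row[j] = row[j], row[i]
    return out

A = [[2, 1, 0],
     [0, 3, 1],
     [1, 0, 1]]

# Theorem 1: swapping two columns negates the determinant
assert det(swap_columns(A, 0, 2)) == -det(A)

# Theorem 2: linearly dependent columns (third = first + second) give zero
dependent = [[2, 1, 3],
             [0, 3, 3],
             [1, 0, 1]]
assert det(dependent) == 0

# Theorem 3: a triangular matrix's determinant is the product of its diagonal
upper = [[2, 1, 3],
         [0, 4, 5],
         [0, 0, 6]]
assert det(upper) == 2 * 4 * 6

# Theorem 4: adding a multiple of one column to another changes nothing
sheared = [[row[0], row[1], row[2] + 5 * row[0]] for row in A]
assert det(sheared) == det(A)

# Theorem 9: the determinant of the transpose is unchanged
assert det(transpose(A)) == det(A)

print("All checks passed; Det(A) =", det(A))
```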

With these theorems in hand we can derive the final formula for the determinant:

**Theorem 11:** Let $\text{Det} : \mathbb{R}^{m \times m} \rightarrow \mathbb{R}$ be a function that satisfies the following three properties:

1. $\text{Det}(\boldsymbol{I}) = 1$

2. Given $\boldsymbol{A} \in \mathbb{R}^{m \times m}$, if any two columns of $\boldsymbol{A}$ are equal, then $\text{Det}(\boldsymbol{A}) = 0$

3. $\text{Det}$ is linear with respect to the column-vectors of its input.

Then $\text{Det}$ is given by

\(\text{Det}(\boldsymbol{A}) := \begin{cases} a_{1,1}a_{2,2} - a_{1,2}a_{2,1} & \text{if $m = 2$} \\ \sum_{i=1}^m (-1)^{i+1} a_{i,1} \text{Det}(\boldsymbol{A}_{-i,-1}) & \text{if $m > 2$}\end{cases}\)

Below is an illustration of how each theorem depends on the other theorems. Note, they all flow downward until we can prove the final formula for the determinant in Theorem 11:

All of the proofs are left to the Appendix below this blog post.

- Lecture notes by Mark Demers: http://faculty.fairfield.edu/mdemers/linearalgebra/documents/2019.03.25.detalt.pdf
- Lecture notes by Dan Margalit, Joseph Rabinoff, Ben Williams: https://personal.math.ubc.ca/~tbjw/ila/determinants-volumes.html
- Explanation by 3Blue1Brown: https://www.3blue1brown.com/lessons/determinant

**Theorem 1:** Given a matrix $\boldsymbol{A} \in \mathbb{R}^{m \times m}$, if we exchange any two column-vectors of $\boldsymbol{A}$ to form a new matrix $\boldsymbol{A}'$, then $\text{Det}(\boldsymbol{A}') = -\text{Det}(\boldsymbol{A})$

**Proof:**

Let columns $i$ and $j$ be the columns that we exchange within $\boldsymbol{A}$. For ease of notation, let us define

\[\text{Det}_{i,j}(\boldsymbol{a}_{*,i}, \boldsymbol{a}_{*,j}) := \text{Det}(\boldsymbol{a}_{*,1}, \dots, \boldsymbol{a}_{*,i}, \dots, \boldsymbol{a}_{*,j}, \dots, \boldsymbol{a}_{*,m})\]to be the determinant of $\boldsymbol{A}$ as a function of only the $i$th and $j$th column-vectors of $\boldsymbol{A}$ where the other column-vectors are held fixed. Then, we see that

\[\begin{align*} \text{Det}_{i,j}(\boldsymbol{a}_{*,i}, \boldsymbol{a}_{*,j}) &= \text{Det}_{i,j}(\boldsymbol{a}_{*,i}, \boldsymbol{a}_{*,j}) + \text{Det}_{i,j}(\boldsymbol{a}_{*,i}, \boldsymbol{a}_{*,i}) && \text{Axiom 2} \\ &= \text{Det}_{i,j}(\boldsymbol{a}_{*,i}, \boldsymbol{a}_{*,i} + \boldsymbol{a}_{*,j}) && \text{Axiom 3} \\ &= \text{Det}_{i,j}(\boldsymbol{a}_{*,i}, \boldsymbol{a}_{*,i} + \boldsymbol{a}_{*,j}) - \text{Det}_{i,j}(\boldsymbol{a}_{*,i} + \boldsymbol{a}_{*,j}, \boldsymbol{a}_{*,i} + \boldsymbol{a}_{*,j}) && \text{Axiom 2} \\ &= \text{Det}_{i,j}(-\boldsymbol{a}_{*,j}, \boldsymbol{a}_{*,i} + \boldsymbol{a}_{*,j}) && \text{Axiom 3} \\ &= -\text{Det}_{i,j}(\boldsymbol{a}_{*,j}, \boldsymbol{a}_{*,i} + \boldsymbol{a}_{*,j}) && \text{Axiom 3} \\ &= -(\text{Det}_{i,j}(\boldsymbol{a}_{*,j}, \boldsymbol{a}_{*,i}) + \text{Det}_{i,j}(\boldsymbol{a}_{*,j},\boldsymbol{a}_{*,j})) && \text{Axiom 3} \\ &= -\text{Det}_{i,j}(\boldsymbol{a}_{*,j}, \boldsymbol{a}_{*,i}) && \text{Axiom 2}\end{align*}\]$\square$

**Theorem 2:** Given a matrix $\boldsymbol{A} \in \mathbb{R}^{m \times m}$, if its column-vectors are linearly dependent, then its determinant is zero.

**Proof:**

Given a matrix $\boldsymbol{A}$ with columns,

\[\boldsymbol{A} := \begin{bmatrix}\boldsymbol{a}_{*,1}, \boldsymbol{a}_{*,2}, \dots, \boldsymbol{a}_{*,m}\end{bmatrix}\]if the column-vectors of $\boldsymbol{A}$ are linearly dependent, then there exists a vector $\boldsymbol{a}_{*,j}$ that can be expressed as a linear combination of the remaining vectors:

\[\boldsymbol{a}_{*,j} = \sum_{i \neq j} c_i\boldsymbol{a}_{*,i}\]for some set of constants $c_i$. Thus we can write the determinant as:

\[\begin{align*}\text{Det}(\boldsymbol{A}) &= \text{Det}(\boldsymbol{a}_{*,1}, \boldsymbol{a}_{*,2}, \dots, \boldsymbol{a}_{*,j}, \dots, \boldsymbol{a}_{*,m}) \\ &= \text{Det}\left(\boldsymbol{a}_{*,1}, \boldsymbol{a}_{*,2}, \dots, \sum_{i \neq j} c_i\boldsymbol{a}_{*,i}, \dots, \boldsymbol{a}_{*,m}\right) \\ &= \sum_{i \neq j} \text{Det}\left(\boldsymbol{a}_{*,1}, \boldsymbol{a}_{*,2}, \dots, c_i\boldsymbol{a}_{*,i}, \dots, \boldsymbol{a}_{*,m}\right) && \text{Axiom 3} \\ &= \sum_{i \neq j} c_i \text{Det}\left(\boldsymbol{a}_{*,1}, \boldsymbol{a}_{*,2}, \dots, \boldsymbol{a}_{*,i}, \dots, \boldsymbol{a}_{*,m}\right) && \text{Axiom 3} \\ &= 0 && \text{Axiom 2} \end{align*}\]In the last line, we see that all of the determinants in the summation are zero because each term is the determinant of a matrix that has a duplicate column vector.

$\square$
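To make Theorem 2 concrete, here is a minimal sketch that builds a matrix whose third column is a linear combination of the first two and checks that its determinant vanishes. The vectors and coefficients are made-up values for illustration, and `np.linalg.det` is used in place of the axiomatic $\text{Det}$.

```python
import numpy as np

# Two linearly independent columns (illustrative values only).
a1 = np.array([1.0, 0.0, 2.0])
a2 = np.array([0.0, 1.0, 1.0])

# The third column is a linear combination of the first two,
# so the three columns are linearly dependent.
a3 = 2.0 * a1 - 3.0 * a2

A = np.column_stack([a1, a2, a3])
det_A = np.linalg.det(A)

# Zero up to floating-point error.
print(det_A)
```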

**Theorem 3:** Given a triangular matrix, $\boldsymbol{A} \in \mathbb{R}^{m \times m}$, its determinant can be computed by multiplying its diagonal entries.

**Proof:**

We will start with an upper triangular $3 \times 3$ matrix:

\[\boldsymbol{A} := \begin{bmatrix} a_{1,1} & a_{1,2} & a_{1,3} \\ 0 & a_{2,2} & a_{2,3} \\ 0 & 0 & a_{3,3}\end{bmatrix}\]Now, we will take the second column-vector and decompose it into the sum of two vectors. Because the determinant is linear by Axiom 3, we can rewrite the determinant as follows:

\[\text{Det}(\boldsymbol{A}) = \text{Det}\left(\begin{bmatrix} a_{1,1} & a_{1,2} & a_{1,3} \\ 0 & 0 & a_{2,3} \\ 0 & 0 & a_{3,3}\end{bmatrix} \right) + \text{Det}\left(\begin{bmatrix} a_{1,1} & 0 & a_{1,3} \\ 0 & a_{2,2} & a_{2,3} \\ 0 & 0 & a_{3,3}\end{bmatrix}\right)\]Note that in the first term of this sum, the column vectors are *linearly dependent* because the second column-vector can be re-written as a multiple of the first. Thus, according to Theorem 2, its determinant is zero. Hence, the entire first term is zero. Thus, we have:

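\[\text{Det}(\boldsymbol{A}) = \text{Det}\left(\begin{bmatrix} a_{1,1} & 0 & a_{1,3} \\ 0 & a_{2,2} & a_{2,3} \\ 0 & 0 & a_{3,3}\end{bmatrix}\right)\]We can repeat this process with the third column vector by decomposing it into the sum of three vectors and then utilizing the fact that the determinant is linear: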

\[\text{Det}(\boldsymbol{A}) = \text{Det}\left(\begin{bmatrix} a_{1,1} & 0 & a_{1,3} \\ 0 & a_{2,2} & 0 \\ 0 & 0 & 0\end{bmatrix}\right) + \text{Det}\left(\begin{bmatrix} a_{1,1} & 0 & 0 \\ 0 & a_{2,2} & a_{2,3} \\ 0 & 0 & 0\end{bmatrix}\right) + \text{Det}\left(\begin{bmatrix} a_{1,1} & 0 & 0 \\ 0 & a_{2,2} & 0 \\ 0 & 0 & a_{3,3}\end{bmatrix}\right)\]Again, the first and second terms are zero because the columns of each matrix are linearly dependent. This leaves only the third term, which is the determinant of a diagonal matrix. Finally, we see that

\[\begin{align*}\text{Det}(\boldsymbol{A}) &= \text{Det}\left(\begin{bmatrix} a_{1,1} & 0 & 0 \\ 0 & a_{2,2} & 0 \\ 0 & 0 & a_{3,3}\end{bmatrix}\right) \\ &= a_{1,1}a_{2,2}a_{3,3}\text{Det}\left(\begin{bmatrix} 1 & 0 & 0 \\ 0 & 1 & 0 \\ 0 & 0 & 1\end{bmatrix}\right) && \text{Axiom 3} \\ &= a_{1,1}a_{2,2}a_{3,3} && \text{Axiom 1}\end{align*}\]Thus, we see that the determinant of the triangular matrix can be computed by multiplying the entries along its diagonal. The same argument extends to triangular matrices of any size, and to lower triangular matrices by working through the columns from right to left.

$\square$
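The statement of Theorem 3 is easy to check numerically. The sketch below compares `np.linalg.det` against the product of the diagonal for one upper triangular matrix; the entries are arbitrary illustrative values.

```python
import numpy as np

# An upper triangular matrix (illustrative values only).
U = np.array([[2.0, 5.0, -1.0],
              [0.0, 3.0, 4.0],
              [0.0, 0.0, 0.5]])

det_U = np.linalg.det(U)
diag_product = np.prod(np.diag(U))

# Both quantities should agree up to floating-point error.
print(det_U, diag_product)
```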

**Theorem 4:** Given a matrix $\boldsymbol{A} \in \mathbb{R}^{m \times m}$, adding a multiple of one column-vector of $\boldsymbol{A}$ to another column-vector does not change the determinant of $\boldsymbol{A}$.

**Proof:**

Say we add $k$ times column $j$ to column $i$. First,

\[\begin{align*}\text{Det}(\boldsymbol{a}_{*,1} \dots, k\boldsymbol{a}_{*,j} + \boldsymbol{a}_{*,i}, \dots, \boldsymbol{a}_{*,j}, \dots, \boldsymbol{a}_{*,m}) &= \text{Det}(\boldsymbol{a}_{*,1} \dots, k\boldsymbol{a}_{*,j}, \dots, \boldsymbol{a}_{*,j}, \dots, \boldsymbol{a}_{*,m}) + \text{Det}(\boldsymbol{a}_{*,1} \dots, \boldsymbol{a}_{*,i}, \dots, \boldsymbol{a}_{*,j}, \dots, \boldsymbol{a}_{*,m}) && \text{Axiom 3} \\ &= k\text{Det}(\boldsymbol{a}_{*,1} \dots, \boldsymbol{a}_{*,j}, \dots, \boldsymbol{a}_{*,j}, \dots, \boldsymbol{a}_{*,m}) + \text{Det}(\boldsymbol{a}_{*,1} \dots, \boldsymbol{a}_{*,i}, \dots, \boldsymbol{a}_{*,j}, \dots, \boldsymbol{a}_{*,m}) && \text{Axiom 3} \\ &= \text{Det}(\boldsymbol{a}_{*,1} \dots, \boldsymbol{a}_{*,i}, \dots, \boldsymbol{a}_{*,j}, \dots, \boldsymbol{a}_{*,m}) && \text{Axiom 2}\end{align*}\]The last line follows from the fact that the first term is computing the determinant of a matrix that has duplicate column-vectors. By Axiom 2, its determinant is zero.

$\square$
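Theorem 4 can also be checked numerically. In this sketch we add $k$ times one column to another and confirm that the determinant is unchanged; the matrix, the value of `k`, and the column indices are assumptions chosen only for illustration.

```python
import numpy as np

# Arbitrary example matrix and multiplier (illustrative values only).
A = np.array([[1.0, 4.0, 2.0],
              [0.0, 2.0, 5.0],
              [3.0, 1.0, 1.0]])
k = 2.5

# Add k times the third column to the first column.
B = A.copy()
B[:, 0] += k * A[:, 2]

det_A = np.linalg.det(A)
det_B = np.linalg.det(B)
print(det_A, det_B)
```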

**Theorem 5:** Given an elementary matrix that represents row-scaling, $\boldsymbol{E} \in \mathbb{R}^{m \times m}$, where $\boldsymbol{E}$ scales the $j$th row of a system of linear equations by $k$, its determinant is simply $k$.

**Proof:**

Such a matrix would be a diagonal matrix with all ones along the diagonal except for the $j$th entry, which would be $k$. For example, a $4 \times 4$ row-scaling matrix that scales the second row by $k$ would look as follows:

\[\boldsymbol{E} := \begin{bmatrix}1 & 0 & 0 & 0 \\ 0 & k & 0 & 0 \\ 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1\end{bmatrix}\]Note that this is a triangular matrix and thus, by Theorem 3, its determinant is given by the product of its diagonal entries, which is simply $k$.

$\square$

**Theorem 6:** Given an elementary matrix that represents row-swapping, $\boldsymbol{E} \in \mathbb{R}^{m \times m}$ that swaps the $i$th and $j$th rows of a system of linear equations, its determinant is simply -1.

**Proof:**

A row-swapping matrix that swaps the $i$th and $j$th rows of a system of linear equations can be formed by simply swapping the $i$th and $j$th column vectors of the identity matrix. For example, a $4 \times 4$ row-swapping matrix that swaps the second and third rows would look as follows:

\[\boldsymbol{E} := \begin{bmatrix}1 & 0 & 0 & 0 \\ 0 & 0 & 1 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 0 & 1\end{bmatrix}\]Axiom 1 for the definition of the determinant states that the determinant of the identity matrix is 1. According to Theorem 1, if we swap two column-vectors of a matrix, its determinant is multiplied by -1. Here we are swapping two column-vectors of the identity matrix, yielding a determinant of -1.

$\square$

**Theorem 7:** Given an elementary matrix that represents a row-sum, $\boldsymbol{E} \in \mathbb{R}^{m \times m}$, that adds row $j$ multiplied by $k$ to row $i$, its determinant is simply 1.

**Proof:**

An elementary matrix representing a row-sum, $\boldsymbol{E} \in \mathbb{R}^{m \times m}$, that adds row $j$ multiplied by $k$ to row $i$ is simply the identity matrix, but with element $(i, j)$ equal to $k$. For example, a $4 \times 4$ row-sum matrix that adds three times the first row to the third would be given by:

\[\boldsymbol{E} := \begin{bmatrix}1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 3 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1\end{bmatrix}\]This matrix is a lower-triangular matrix. By Theorem 3, the determinant of a triangular matrix is the product of the diagonal entries. In this case, all of the diagonal entries are 1. Thus, the determinant is 1.

$\square$
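Theorems 5, 6, and 7 can be illustrated together by constructing one elementary matrix of each type and computing its determinant. The sketch below mirrors the $4 \times 4$ examples above; the scale factor of 4 is an arbitrary stand-in for $k$, and `np.linalg.det` is used in place of $\text{Det}$.

```python
import numpy as np

I = np.eye(4)

# Row-scaling: scale the second row by k = 4 (arbitrary choice of k).
E_scale = I.copy()
E_scale[1, 1] = 4.0

# Row-swapping: swap the second and third columns of the identity.
E_swap = I[:, [0, 2, 1, 3]]

# Row-sum: add three times the first row to the third row.
E_sum = I.copy()
E_sum[2, 0] = 3.0

# Expected by Theorems 5, 6, and 7: k, -1, and 1 respectively.
dets = [np.linalg.det(E) for E in (E_scale, E_swap, E_sum)]
print(dets)
```

Left-multiplying any $4 \times 4$ matrix by one of these matrices performs the corresponding row operation on it.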

**Theorem 8:** Given matrices $\boldsymbol{A}, \boldsymbol{B} \in \mathbb{R}^{m \times m}$, it holds that $\text{Det}(\boldsymbol{AB}) = \text{Det}(\boldsymbol{A})\text{Det}(\boldsymbol{B})$.

**Proof:**

First, if $\boldsymbol{A}$ is singular, then $\boldsymbol{AB}$ is also singular. By Theorem 2, we know that the determinant of a singular matrix is zero and thus it trivially holds that $\text{Det}(\boldsymbol{AB}) = \text{Det}(\boldsymbol{A})\text{Det}(\boldsymbol{B})$ since both $\text{Det}(\boldsymbol{AB}) = 0$ and also $\text{Det}(\boldsymbol{A})\text{Det}(\boldsymbol{B}) = 0$ (since $\text{Det}(\boldsymbol{A})=0$). The same is true if $\boldsymbol{B}$ is singular. Thus, our proof will focus only on the case in which $\boldsymbol{A}$ and $\boldsymbol{B}$ are both invertible.

We first begin by proving that given an elementary matrix $\boldsymbol{E}$, it holds that

\[\text{Det}(\boldsymbol{AE}) = \text{Det}(\boldsymbol{A})\text{Det}(\boldsymbol{E})\]To show this, let us consider each of the three types of elementary matrices individually. First, if $\boldsymbol{E}$ is a scaling matrix where the $i$th diagonal entry is a scalar $k$, then $\boldsymbol{AE}$ will scale the $i$th column of $\boldsymbol{A}$ by $k$. By Axiom 3 of the determinant, scaling a single column by $k$ scales the full determinant by $k$. Moreover, by Theorem 5, the determinant of $\boldsymbol{E}$ is $k$. Thus,

\[\begin{align*}\text{Det}(\boldsymbol{AE}) &= \text{Det}(\boldsymbol{A})k && \text{Axiom 3} \\ &= \text{Det}(\boldsymbol{A})\text{Det}(\boldsymbol{E}) && \text{Theorem 5}\end{align*}\]Now let’s consider the case where $\boldsymbol{E}$ is a row-swapping matrix that swaps rows $i$ and $j$. Here, $\boldsymbol{AE}$ will swap the $i$th and $j$th columns of $\boldsymbol{A}$. By Theorem 1, swapping any two columns of a matrix flips the sign of the determinant. Moreover, by Theorem 6, the determinant of $\boldsymbol{E}$ is $-1$. Thus,

\[\begin{align*}\text{Det}(\boldsymbol{AE}) &= -\text{Det}(\boldsymbol{A}) && \text{Theorem 1} \\ &= \text{Det}(\boldsymbol{A})\text{Det}(\boldsymbol{E}) && \text{Theorem 6}\end{align*}\]Finally let’s consider the case where $\boldsymbol{E}$ is an elementary matrix that adds a multiple of one row to another. Here, $\boldsymbol{AE}$ will add a multiple of one column of $\boldsymbol{A}$ to another column. By Theorem 4, this does not change the determinant. Moreover, by Theorem 7, the determinant of $\boldsymbol{E}$ is simply 1. Thus,

\[\begin{align*}\text{Det}(\boldsymbol{AE}) &= \text{Det}(\boldsymbol{A}) && \text{Theorem 4} \\ &= \text{Det}(\boldsymbol{A})\text{Det}(\boldsymbol{E}) && \text{Theorem 7}\end{align*}\]Now we have proven that for an invertible matrix $\boldsymbol{A}$ and elementary matrix $\boldsymbol{E}$ it holds that $\text{Det}(\boldsymbol{AE}) = \text{Det}(\boldsymbol{A})\text{Det}(\boldsymbol{E})$. Let’s now turn to the general case where we are multiplying two invertible matrices $\boldsymbol{A}$ and $\boldsymbol{B}$.

As we have shown, any invertible matrix can be decomposed as the product of some sequence of elementary matrices. Thus, we can write,

\[\boldsymbol{AB} = \boldsymbol{A}\boldsymbol{E}_1\boldsymbol{E}_2 \dots \boldsymbol{E}_k\]Then we apply our newly proven fact that for an elementary matrix $\boldsymbol{E}$ it holds that $\text{Det}(\boldsymbol{AE}) = \text{Det}(\boldsymbol{A})\text{Det}(\boldsymbol{E})$. We apply this fact in an iterative way from right to left as shown below:

\[\begin{align*}\text{Det}(\boldsymbol{AB}) &= \text{Det}(\boldsymbol{A}\boldsymbol{E}_1\boldsymbol{E}_2 \dots \boldsymbol{E}_{k-1}\boldsymbol{E}_k) \\ &= \text{Det}(\boldsymbol{A}\boldsymbol{E}_1\boldsymbol{E}_2 \dots \boldsymbol{E}_{k-1})\text{Det}(\boldsymbol{E}_k) \\ &= \text{Det}(\boldsymbol{A}\boldsymbol{E}_1\boldsymbol{E}_2 \dots \boldsymbol{E}_{k-2})\text{Det}(\boldsymbol{E}_{k-1})\text{Det}(\boldsymbol{E}_k) \\ &= \text{Det}(\boldsymbol{A})\prod_{i=1}^k \text{Det}(\boldsymbol{E}_i) \end{align*}\]Now we apply the rule $\text{Det}(\boldsymbol{AE}) = \text{Det}(\boldsymbol{A})\text{Det}(\boldsymbol{E})$ in reverse, again moving from right to left:

\[\begin{align*}\text{Det}(\boldsymbol{AB}) &= \text{Det}(\boldsymbol{A})\prod_{i=1}^k \text{Det}(\boldsymbol{E}_i) \\ &= \text{Det}(\boldsymbol{A})\left( \prod_{i=1}^{k-2}\text{Det}(\boldsymbol{E}_i)\right) \text{Det}(\boldsymbol{E}_{k-1}\boldsymbol{E}_{k}) \\ &= \text{Det}(\boldsymbol{A})\left( \prod_{i=1}^{k-3}\text{Det}(\boldsymbol{E}_i)\right) \text{Det}(\boldsymbol{E}_{k-2} \boldsymbol{E}_{k-1}\boldsymbol{E}_{k}) \\ &= \text{Det}(\boldsymbol{A}) \text{Det}(\boldsymbol{E}_1\boldsymbol{E}_2 \dots \boldsymbol{E}_{k-1}\boldsymbol{E}_k) \\ &= \text{Det}(\boldsymbol{A})\text{Det}(\boldsymbol{B})\end{align*}\]$\square$
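A small numerical experiment can build confidence in Theorem 8. The sketch below draws two random $4 \times 4$ matrices (the seed and size are arbitrary assumptions) and compares both sides of the identity using `np.linalg.det`.

```python
import numpy as np

# Seeded generator so the experiment is reproducible (seed is arbitrary).
rng = np.random.default_rng(0)
A = rng.normal(size=(4, 4))
B = rng.normal(size=(4, 4))

lhs = np.linalg.det(A @ B)
rhs = np.linalg.det(A) * np.linalg.det(B)

# The two sides should agree up to floating-point error.
print(lhs, rhs)
```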

**Theorem 9:** Given a square matrix $\boldsymbol{A}$, it holds that $\text{Det}(\boldsymbol{A}) = \text{Det}(\boldsymbol{A}^T)$.

**Proof:**

First, if $\boldsymbol{A}$ is singular, then $\boldsymbol{A}^T$ is also singular. By Theorem 2, the determinant of a singular matrix is zero and thus, $\text{Det}(\boldsymbol{A}) = \text{Det}(\boldsymbol{A}^T) = 0$.

If $\boldsymbol{A}$ is invertible, then we can express $\boldsymbol{A}$ as the product of some sequence of elementary matrices:

\[\boldsymbol{A} = \boldsymbol{E}_1\boldsymbol{E}_2 \dots \boldsymbol{E}_k\]Then,

\[\boldsymbol{A}^T = \boldsymbol{E}_k^T \boldsymbol{E}_{k-1}^T \dots \boldsymbol{E}_1^T\]We note that the determinant of every elementary matrix is equal to the determinant of its transpose. Both scaling elementary matrices and row-swapping matrices are symmetric and thus, their transposes are equal to themselves. Thus the determinants of their transposes are equal to their own determinants. For a row-sum elementary matrix, the transpose is still a triangular matrix with ones along the diagonal and thus, its determinant also equals the determinant of its transpose (since by Theorem 3, the determinant of a triangular matrix can be computed by multiplying the diagonal entries).

Then, we apply Theorem 8 and see that

\[\begin{align*}\text{Det}(\boldsymbol{A}) &= \text{Det}(\boldsymbol{E}_1\boldsymbol{E}_2 \dots \boldsymbol{E}_k) \\ &= \text{Det}(\boldsymbol{E}_1) \text{Det}(\boldsymbol{E}_2) \dots \text{Det}(\boldsymbol{E}_k) \\ &= \text{Det}(\boldsymbol{E}_1^T) \text{Det}(\boldsymbol{E}_2^T) \dots \text{Det}(\boldsymbol{E}_k^T) \\ &= \text{Det}(\boldsymbol{E}_{k}^T) \text{Det}(\boldsymbol{E}_{k-1}^T) \dots \text{Det}(\boldsymbol{E}_1^T) \\ &= \text{Det}(\boldsymbol{E}_k^T \boldsymbol{E}_{k-1}^T \dots \boldsymbol{E}_1^T) \\ &= \text{Det}(\boldsymbol{A}^T)\end{align*}\]$\square$
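Theorem 9 is similarly easy to spot-check. This sketch compares the determinant of a random square matrix with that of its transpose; the seed and the $5 \times 5$ size are arbitrary assumptions.

```python
import numpy as np

# Seeded generator so the experiment is reproducible (seed is arbitrary).
rng = np.random.default_rng(1)
A = rng.normal(size=(5, 5))

det_A = np.linalg.det(A)
det_At = np.linalg.det(A.T)

# The two values should agree up to floating-point error.
print(det_A, det_At)
```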

**Theorem 10:** The determinant of a matrix is linear with respect to the row vectors of the matrix.

**Proof:**

This proof follows from Theorem 9 and Axiom 3 of the determinant. Specifically, consider a square matrix $\boldsymbol{A} \in \mathbb{R}^{m \times m}$ and let $\boldsymbol{a}_{1,*}, \dots, \boldsymbol{a}_{m,*}$ be the row vectors of $\boldsymbol{A}$. Also, let the $j$th row be represented as the sum of two vectors, $\boldsymbol{a}_{j,*} = \boldsymbol{u} + \boldsymbol{v}$:

\[\boldsymbol{A} := \begin{bmatrix}\boldsymbol{a}_{1, *} \\ \vdots \\ \boldsymbol{u} + \boldsymbol{v} \\ \vdots \\ \boldsymbol{a}_{m, *}\end{bmatrix}\]The determinant of $\boldsymbol{A}$ is then:

\[\begin{align*}\text{Det}(\boldsymbol{A}) &= \text{Det}(\boldsymbol{A}^T) && \text{By Theorem 9} \\ &= \text{Det}\left(\boldsymbol{a}_{1,*}, \dots, \boldsymbol{v} + \boldsymbol{u}, \dots, \boldsymbol{a}_{m,*} \right) \\ &= \text{Det}\left(\boldsymbol{a}_{1,*}, \dots, \boldsymbol{v}, \dots, \boldsymbol{a}_{m,*} \right) + \text{Det}\left(\boldsymbol{a}_{1,*}, \dots, \boldsymbol{u}, \dots, \boldsymbol{a}_{m,*} \right) && \text{By Axiom 3} \end{align*}\]Next, let the $j$th row be scaled by some constant $c$. That is,

\[\boldsymbol{A} := \begin{bmatrix}\boldsymbol{a}_{1, *} \\ \vdots \\ c\boldsymbol{a}_{j,*} \\ \vdots \\ \boldsymbol{a}_{m, *}\end{bmatrix}\]Then,

\[\begin{align*} \text{Det}(\boldsymbol{A}) &= \text{Det}(\boldsymbol{A}^T) && \text{By Theorem 9} \\ &= \text{Det}\left( \boldsymbol{a}_{1,*}, \dots, c\boldsymbol{a}_{j,*}, \dots, \boldsymbol{a}_{m,*} \right) \\ &= c\text{Det}\left(\boldsymbol{a}_{1,*}, \dots, \boldsymbol{a}_{j,*}, \dots, \boldsymbol{a}_{m,*} \right) && \text{By Axiom 3} \end{align*}\]$\square$

**Theorem 11:** Let $\text{Det} : \mathbb{R}^{m \times m} \rightarrow \mathbb{R}$ be a function that satisfies the following three properties:

1. $\text{Det}(\boldsymbol{I}) = 1$

2. Given $\boldsymbol{A} \in \mathbb{R}^{m \times m}$, if any two columns of $\boldsymbol{A}$ are equal, then $\text{Det}(\boldsymbol{A}) = 0$

3. $\text{Det}$ is linear with respect to the column-vectors of its input.

Then $\text{Det}$ is given by

\[\text{Det}(\boldsymbol{A}) := \begin{cases} a_{1,1}a_{2,2} - a_{1,2}a_{2,1} & \text{if $m = 2$} \\ \sum_{i=1}^m (-1)^{i+1} a_{i,1} \text{Det}(\boldsymbol{A}_{-i,-1}) & \text{if $m > 2$}\end{cases}\]

**Proof:**

Given a matrix $\boldsymbol{A} \in \mathbb{R}^{m \times m}$, let $\boldsymbol{A}_{-i,-j}$ be the sub-matrix of $\boldsymbol{A}$ where the $i$th row and $j$th column are deleted. For example, for a $3 \times 3$ matrix

\[\boldsymbol{A} := \begin{bmatrix} a & b & c \\ d & e & f \\ g & h & i \end{bmatrix}\]$\boldsymbol{A}_{-1,-1}$ would be

\[\boldsymbol{A}_{-1,-1} = \begin{bmatrix} e & f \\ h & i \end{bmatrix}\]Now, consider an elementary matrix $\boldsymbol{E} \in \mathbb{R}^{m \times m}$. Let us define $\boldsymbol{E}'$ to be an elementary matrix in $\mathbb{R}^{(m+1) \times (m+1)}$ that is formed by taking $\boldsymbol{E}$ and adding a new first row and first column whose only non-zero entry is a 1 in the top-left corner. That is,

\[\boldsymbol{E}' := \begin{bmatrix}1 & 0 & \dots & 0 \\ 0 & & & & \\ \vdots & & \boldsymbol{E} & & \\ 0 & & & &\end{bmatrix}\]Notice that $\boldsymbol{E}'$ is an elementary row matrix that represents the same operation as $\boldsymbol{E}$, but performs this operation on a matrix in $\mathbb{R}^{(m+1) \times (m+1)}$ instead of a matrix in $\mathbb{R}^{m \times m}$ and leaves the first row alone. Thus, by Theorems 5, 6, and 7 it follows that:

\[\text{Det}(\boldsymbol{E}') = \text{Det}(\boldsymbol{E})\]Let’s keep this fact in the back of our mind, but now turn our attention towards $\boldsymbol{A}$. Let us say that $\boldsymbol{A}$ is a matrix where the first column-vector only has a non-zero entry in the first row. That is, let’s say $\boldsymbol{A}$ looks as follows

\[\boldsymbol{A} = \begin{bmatrix}a_{1,1} & a_{1,2} & \dots & a_{1,m} \\ 0 & & & & \\ \vdots & & \boldsymbol{A}_{-1,-1} & & \\ 0 & & & &\end{bmatrix}\]Then we can show that $\text{Det}(\boldsymbol{A}) = a_{1,1}\text{Det}(\boldsymbol{A}_{-1,-1})$ via the following (see notes below the derivation for more details on some of the key steps):

\[\begin{align*}\text{Det}(\boldsymbol{A}) &= \text{Det}\left( \begin{bmatrix}a_{1,1} & 0 & \dots & 0 \\ 0 & & & & \\ \vdots & & \boldsymbol{A}_{-1,-1} & & \\ 0 & & & &\end{bmatrix} \right) + \text{Det}\left( \begin{bmatrix}0 & a_{1,2} & \dots & 0 \\ 0 & & & & \\ \vdots & & \boldsymbol{A}_{-1,-1} & & \\ 0 & & & &\end{bmatrix} \right) + \dots + \text{Det}\left( \begin{bmatrix}0 & 0 & \dots & a_{1,m} \\ 0 & & & & \\ \vdots & & \boldsymbol{A}_{-1,-1} & & \\ 0 & & & &\end{bmatrix} \right) \ \text{Theorem 10} \\ &= \text{Det}\left( \begin{bmatrix}a_{1,1} & 0 & \dots & 0 \\ 0 & & & & \\ \vdots & & \boldsymbol{A}_{-1,-1} & & \\ 0 & & & &\end{bmatrix} \right) \ \text{by Theorem 2 (see Note 1)} \\ &= a_{1,1}\text{Det}\left( \begin{bmatrix}1 & 0 & \dots & 0 \\ 0 & & & & \\ \vdots & & \boldsymbol{A}_{-1,-1} & & \\ 0 & & & &\end{bmatrix}\right) \ \text{by Axiom 3} \\ &= a_{1,1}\text{Det}\left( \begin{bmatrix}1 & 0 & \dots & 0 \\ 0 & & & & \\ \vdots & & \boldsymbol{E}_1 \boldsymbol{E}_2 \dots \boldsymbol{E}_k & & \\ 0 & & & &\end{bmatrix}\right) \ \text{see Note 2} \\ &= a_{1,1}\text{Det}(\boldsymbol{E}'_1 \boldsymbol{E}'_2 \dots \boldsymbol{E}'_k) \ \text{see Note 3} \\ &= a_{1,1}\text{Det}(\boldsymbol{E}'_1) \text{Det}(\boldsymbol{E}'_2) \dots \text{Det}(\boldsymbol{E}'_k) \ \text{Theorem 8} \\ &= a_{1,1}\text{Det}(\boldsymbol{E}_1) \text{Det}(\boldsymbol{E}_2) \dots \text{Det}(\boldsymbol{E}_k) \ \text{since } \text{Det}(\boldsymbol{E}') = \text{Det}(\boldsymbol{E}) \\ &= a_{1,1}\text{Det}(\boldsymbol{E}_1\boldsymbol{E}_2 \dots \boldsymbol{E}_k) \ \text{Theorem 8} \\ &= a_{1,1}\text{Det}(\boldsymbol{A}_{-1,-1}) \end{align*}\]**Note 1:** Notice in the previous line, all of the determinants except the first are zero since the first column vector of each of their matrix arguments is the zero vector. Thus, these are all singular matrices and by Theorem 2, their determinants are zero.

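**Note 2:** If $\boldsymbol{A}_{-1,-1}$ is invertible, then we can factor it into a product of elementary matrices:

\[\boldsymbol{A}_{-1,-1} = \boldsymbol{E}_1\boldsymbol{E}_2 \dots \boldsymbol{E}_k\]where $k$ is some positive integer. (If $\boldsymbol{A}_{-1,-1}$ is singular, then so is the full block matrix, and both sides of the claimed equality are zero by Theorem 2.)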

**Note 3:** Here, we use the fact that for some elementary matrix $\boldsymbol{E}$ and some matrix $\boldsymbol{B} \in \mathbb{R}^{m \times m}$, it holds that

\[\begin{bmatrix}1 & 0 & \dots & 0 \\ 0 & & & & \\ \vdots & & \boldsymbol{EB} & & \\ 0 & & & &\end{bmatrix} = \boldsymbol{E}'\begin{bmatrix}1 & 0 & \dots & 0 \\ 0 & & & & \\ \vdots & & \boldsymbol{B} & & \\ 0 & & & &\end{bmatrix}\]We then apply this iteratively to

\[\begin{bmatrix}1 & 0 & \dots & 0 \\ 0 & & & & \\ \vdots & & \boldsymbol{E}_1 \boldsymbol{E}_2 \dots \boldsymbol{E}_k & & \\ 0 & & & &\end{bmatrix}\]Finally, at long last, we can derive the formula for the determinant. Let us consider a general matrix $\boldsymbol{A}$:

\[\boldsymbol{A} = \begin{bmatrix} a_{1,1} & a_{1,2} & a_{1,3} & \dots & a_{1,m} \\ a_{2,1} & & & & \\ a_{3,1} & & & & \\ \vdots & & \boldsymbol{A}_{-1,-1} & & \\ a_{m,1} & & & &\end{bmatrix}\]Then,

\[\begin{align*} \text{Det}(\boldsymbol{A}) &= \text{Det}\left(\begin{bmatrix} a_{1,1} & a_{1,2} & a_{1,3} & \dots & a_{1,m} \\ a_{2,1} & & & & \\ a_{3,1} & & & & \\ \vdots & & \boldsymbol{A}_{-1,-1} & & \\ a_{m,1} & & & &\end{bmatrix}\right) \\ &= \text{Det}\left(\begin{bmatrix} a_{1,1} & a_{1,2} & a_{1,3} & \dots & a_{1,m} \\ 0 & & & & \\ 0 & & & & \\ \vdots & & \boldsymbol{A}_{-1,-1} & & \\ 0 & & & &\end{bmatrix}\right)+ \text{Det}\left(\begin{bmatrix}0 & a_{1,2} & a_{1,3} & \dots & a_{1,m} \\ a_{2,1} & & & & \\ 0 & & & & \\ \vdots & & \boldsymbol{A}_{-1,-1} & & \\ 0 & & & &\end{bmatrix}\right) + \text{Det}\left(\begin{bmatrix} 0 & a_{1,2} & a_{1,3} & \dots & a_{1,m} \\ 0 & & & & \\ a_{3,1} & & & & \\ \vdots & & \boldsymbol{A}_{-1,-1} & & \\ 0 & & & &\end{bmatrix}\right) + \dots + \text{Det}\left(\begin{bmatrix} 0 & a_{1,2} & a_{1,3} & \dots & a_{1,m} \\ 0 & & & & \\ 0 & & & & \\ \vdots & & \boldsymbol{A}_{-1,-1} & & \\ a_{m,1} & & & &\end{bmatrix}\right)\end{align*}\]Here we have used Axiom 3 to decompose the first column into a sum of $m$ vectors, each with at most one non-zero entry. For each term, we can move the row with a non-zero element in the first column to the top row and maintain the relative order of the remaining $m-1$ rows. Each row swap flips the sign of the determinant (by Theorems 1 and 9), so performing this operation on each term in the summation will result in an alternation of addition and subtraction. The reason for this is that if the row we are moving to the first row is even-numbered, this procedure will require an odd number of row swaps. On the other hand, if the row is odd-numbered, this procedure will require an even number of swaps. This is illustrated by the following schematic:

Thus, we have

\[\begin{align*}\text{Det}(\boldsymbol{A}) \\ \ \ &= \text{Det}\left(\begin{bmatrix} a_{1,1} & a_{1,2} & a_{1,3} & \dots & a_{1,m} \\ 0 & & & & \\ 0 & & & & \\ \vdots & & \boldsymbol{A}_{-1,-1} & & \\ 0 & & & &\end{bmatrix}\right) - \text{Det}\left(\begin{bmatrix}a_{2,1} & a_{2,2} & a_{2,3} & \dots & a_{2,m} \\ 0 & & & & \\ 0 & & & & \\ \vdots & & \boldsymbol{A}_{-2,-1} & & \\ 0 & & & &\end{bmatrix}\right) + \text{Det}\left(\begin{bmatrix} a_{3,1} & a_{3,2} & a_{3,3} & \dots & a_{3,m} \\ 0 & & & & \\ 0 & & & & \\ \vdots & & \boldsymbol{A}_{-3,-1} & & \\ 0 & & & &\end{bmatrix}\right) - \dots \pm \text{Det}\left(\begin{bmatrix} a_{m,1} & a_{m,2} & a_{m,3} & \dots & a_{m,m} \\ 0 & & & & \\ 0 & & & & \\ \vdots & & \boldsymbol{A}_{-m,-1} & & \\ 0 & & & &\end{bmatrix}\right) \\ &= \text{Det}\left(\begin{bmatrix} a_{1,1} & 0 & 0 & \dots & 0 \\ 0 & & & & \\ 0 & & & & \\ \vdots & & \boldsymbol{A}_{-1,-1} & & \\ 0 & & & &\end{bmatrix}\right) - \text{Det}\left(\begin{bmatrix}a_{2,1} & 0 & 0 & \dots & 0 \\ 0 & & & & \\ 0 & & & & \\ \vdots & & \boldsymbol{A}_{-2,-1} & & \\ 0 & & & &\end{bmatrix}\right) + \text{Det}\left(\begin{bmatrix} a_{3,1} & 0 & 0 & \dots & 0 \\ 0 & & & & \\ 0 & & & & \\ \vdots & & \boldsymbol{A}_{-3,-1} & & \\ 0 & & & &\end{bmatrix}\right) - \dots \pm \ \text{Det}\left(\begin{bmatrix} a_{m,1} & 0 & 0 & \dots & 0 \\ 0 & & & & \\ 0 & & & & \\ \vdots & & \boldsymbol{A}_{-m,-1} & & \\ 0 & & & &\end{bmatrix}\right) \\ &= a_{1,1}\text{Det}(\boldsymbol{A}_{-1,-1}) - a_{2,1}\text{Det}(\boldsymbol{A}_{-2,-1}) + a_{3,1}\text{Det}(\boldsymbol{A}_{-3,-1}) - \dots \pm a_{m,1}\text{Det}(\boldsymbol{A}_{-m,-1})\end{align*}\]Thus we have arrived at our recursive formula where, for each term (corresponding to each row), we compute the determinant of a sub-matrix. This proceeds all the way down until we reach a $2 \times 2$ matrix, whose determinant is $a_{1,1}a_{2,2} - a_{1,2}a_{2,1}$. That is, putting it all together we arrive at the formula for the determinant:

\[\text{Det}(\boldsymbol{A}) := \begin{cases} a_{1,1}a_{2,2} - a_{1,2}a_{2,1} & \text{if $m = 2$} \\ \sum_{i=1}^m (-1)^{i+1} a_{i,1} \text{Det}(\boldsymbol{A}_{-i,-1}) & \text{if $m > 2$}\end{cases}\]$\square$
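The recursive formula translates almost directly into code. Below is a minimal pure-Python sketch of cofactor expansion along the first column; the function name `det`, the list-of-rows representation, and the extra $1 \times 1$ base case are conveniences assumed here rather than part of the theorem.

```python
def det(A):
    """Determinant by cofactor expansion along the first column.

    `A` is a square matrix given as a list of rows.
    """
    m = len(A)
    if m == 1:
        # Convenience base case, not part of the formula above.
        return A[0][0]
    if m == 2:
        return A[0][0] * A[1][1] - A[0][1] * A[1][0]
    total = 0
    for i in range(m):
        # Sub-matrix with row i and the first column deleted.
        minor = [row[1:] for r, row in enumerate(A) if r != i]
        # With 0-based i, (-1)**i equals (-1)^(i+1) for the 1-based row index.
        total += (-1) ** i * A[i][0] * det(minor)
    return total


# Illustrative 3x3 example.
A = [[2, 1, 0],
     [1, 3, 1],
     [0, 1, 4]]
print(det(A))
```

Note that this recursion visits on the order of $m!$ terms, so it is useful for understanding the formula rather than for computing determinants of large matrices.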

Matrices are one of the fundamental objects studied in linear algebra. While on their surface they appear like simple tables of numbers, as we have previously described, this simplicity hides deeper mathematical structures that they contain. In this post, we will dive into the deeper structures within matrices by showing three vector spaces that are implicitly defined by every matrix:

- A column space
- A row space
- A null space

Not only will we discuss the definition for these spaces and how they relate to one another, we will also discuss how to best intuit these spaces and what their properties tell us about the matrix itself.

To understand these spaces, we will need to look at matrices from different perspectives. In a previous discussion on matrices, we discussed how there are three complementary perspectives for viewing matrices:

- **Perspective 1:** A matrix as a table of numbers
- **Perspective 2:** A matrix as a list of vectors (both row and column vectors)
- **Perspective 3:** A matrix as a function mapping vectors from one space to another

By viewing matrices through these perspectives we can gain a better intuition for the vector spaces induced by matrices. Let’s get started.

The **column space** of a matrix is simply the vector space spanned by its column-vectors:

**Definition 1 (column space):** Given a matrix $\boldsymbol{A}$, the **column space** of $\boldsymbol{A}$ is the vector space spanned by the column-vectors of $\boldsymbol{A}$.

To understand the column space of a matrix $\boldsymbol{A}$, we will consider the matrix from Perspectives 2 and 3 – that is, $\boldsymbol{A}$ as a list of column vectors and as a function mapping vectors from one space to another.

**Understanding the column space when viewing matrices as lists of column vectors**

The least abstract way to view the column space of a matrix is when considering a matrix to be a simple list of column-vectors. For example:

The column space is then the vector space that is spanned by these three vectors. We see that in the example above, the column space is all of $\mathbb{R}^2$ since we can form *any* two-dimensional vector using a linear combination of these three vectors:

**Understanding the column space when viewing matrices as functions**

To gain a deeper understanding into the significance of the column space of a matrix, we will now consider matrices from the perspective of seeing them as functions between vector spaces. That is, recall for a given matrix $\boldsymbol{A} \in \mathbb{R}^{m \times n}$, we can view this matrix as a function that maps vectors from $\mathbb{R}^n$ to vectors in $\mathbb{R}^m$. This mapping is implemented by matrix-vector multiplication. A vector $\boldsymbol{x} \in \mathbb{R}^n$ is mapped to vector $\boldsymbol{b} \in \mathbb{R}^m$ via

\[\boldsymbol{Ax} = \boldsymbol{b}\]Stated more explicitly, we can define a function $T: \mathbb{R}^n \rightarrow \mathbb{R}^m$ as:

\[T(\boldsymbol{x}) := \boldsymbol{Ax}\]It turns out that the column space is simply the range of this function $T$! That is, it is the set of all vectors that $\boldsymbol{A}$ is capable of mapping to. To see why this is the case, recall that we can view matrix-vector multiplication between $\boldsymbol{A}$ and $\boldsymbol{x}$ as the act of taking a linear combination of the columns of $\boldsymbol{A}$ using the elements of $\boldsymbol{x}$ as coefficients:

\[\boldsymbol{Ax} = x_1\boldsymbol{a}_{*,1} + x_2\boldsymbol{a}_{*,2} + \dots + x_n\boldsymbol{a}_{*,n}\]Here we see that the output of this matrix-defined function will always be contained in the span of the column vectors of $\boldsymbol{A}$.
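To make the link between matrix-vector multiplication and the column space concrete, here is a small sketch. It assumes the $2 \times 3$ example matrix that appears later in this post; the input `x` and target `b` are arbitrary illustrative values. The sketch checks that $\boldsymbol{Ax}$ equals the corresponding linear combination of columns, and that an arbitrary target in $\mathbb{R}^2$ can be reached by some choice of coefficients (found here with `np.linalg.lstsq`).

```python
import numpy as np

# Example matrix from later in the post (assumed to be the running example).
A = np.array([[1.0, 2.0, 1.0],
              [0.0, 1.0, -1.0]])
x = np.array([3.0, -1.0, 2.0])  # arbitrary input vector

# Matrix-vector product versus an explicit linear combination of columns.
Ax = A @ x
combo = sum(x[j] * A[:, j] for j in range(A.shape[1]))

# Any target b in R^2 lies in the column space: solve for one set of coefficients.
b = np.array([5.0, -7.0])
coeffs, *_ = np.linalg.lstsq(A, b, rcond=None)

print(Ax, combo)
print(A @ coeffs, b)
```

The coefficients are not unique here because the matrix has three columns but its column space is only two-dimensional; `lstsq` returns the solution with the smallest norm.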

The **row space** of a matrix is the vector space spanned by its row-vectors:

**Definition 2 (row space):** Given a matrix $\boldsymbol{A}$, the **row space** of $\boldsymbol{A}$ is the vector space spanned by the row-vectors of $\boldsymbol{A}$.

To understand the row space of a matrix $\boldsymbol{A}$, we will consider the matrix from Perspective 2 – that is, we will view $\boldsymbol{A}$ as a list of row vectors. For example:

The row space is then the vector space that is spanned by these vectors. We see that in the example above, the row space is a hyperplane:

Unlike the column space, the row space cannot be interpreted as either the domain or range of the function defined by the matrix. So what is the geometric significance of the row space in the context of Perspective 3 (viewing matrices as functions)? Unfortunately, this does not become evident until we discuss the *null space*, which we will discuss in the next section!

The **null space** of a matrix is the third vector space that is induced by matrices. To understand the null space, we will need to view matrices from Perspective 3: matrices as functions between vector spaces.

Specifically, the null space of a matrix $\boldsymbol{A}$ is the set of all vectors that $\boldsymbol{A}$ maps to the zero vector. That is, the null space is all vectors, $\boldsymbol{x} \in \mathbb{R}^n$ for which $\boldsymbol{Ax} = \boldsymbol{0}$:

**Definition 3 (null space):** Given a matrix $\boldsymbol{A} \in \mathbb{R}^{m \times n}$, the **null space** of $\boldsymbol{A}$ is the set of vectors, $\{\boldsymbol{x} \in \mathbb{R}^n \mid \boldsymbol{Ax} = \boldsymbol{0}\}$.

It turns out that there is a key relationship between the null space and the row space of a matrix: the null space is the **orthogonal complement** to the row space (Theorem 1 in the Appendix to this post). Before going further, let us define the orthogonal complement. Given a vector space $(\mathcal{V}, \mathcal{F})$, the orthogonal complement to this vector space is another vector space, $(\mathcal{V}', \mathcal{F})$, where all vectors in $\mathcal{V}'$ are orthogonal to all vectors in $\mathcal{V}$:

**Definition 4 (orthogonal complement):** Given two vector spaces $(\mathcal{V}, \mathcal{F})$ and $(\mathcal{V}', \mathcal{F})$ that share the same scalar field, each is an **orthogonal complement** to the other if $\forall \boldsymbol{v} \in \mathcal{V}, \ \forall \boldsymbol{v}' \in \mathcal{V}' \ \langle \boldsymbol{v}, \boldsymbol{v}' \rangle = 0$

Stated more formally:

**Theorem 1 (null space is orthogonal complement of row space):** Given a matrix $\boldsymbol{A}$, the null space of $\boldsymbol{A}$ is the orthogonal complement to the row space of $\boldsymbol{A}$.

To see why the null space and row space are orthogonal complements, recall that we can view matrix-vector multiplication between a matrix $\boldsymbol{A}$ and a vector $\boldsymbol{x}$ as the process of taking a dot product of each row of $\boldsymbol{A}$ with $\boldsymbol{x}$:

\[\boldsymbol{Ax} := \begin{bmatrix} \boldsymbol{a}_{1,*} \cdot \boldsymbol{x} \\ \boldsymbol{a}_{2,*} \cdot \boldsymbol{x} \\ \vdots \\ \boldsymbol{a}_{m,*} \cdot \boldsymbol{x} \end{bmatrix}\]If $\boldsymbol{x}$ is in the null space of $\boldsymbol{A}$ then this means that $\boldsymbol{Ax} = \boldsymbol{0}$, which means that every dot product shown above is zero. That is,

\[\begin{align*}\boldsymbol{Ax} &= \begin{bmatrix} \boldsymbol{a}_{1,*} \cdot \boldsymbol{x} \\ \boldsymbol{a}_{2,*} \cdot \boldsymbol{x} \\ \vdots \\ \boldsymbol{a}_{m,*} \cdot \boldsymbol{x} \end{bmatrix} \\ &= \begin{bmatrix} 0 \\ 0 \\ \vdots \\ 0 \end{bmatrix} \\ &= \boldsymbol{0} \end{align*}\]Recall, if the dot product between a pair of vectors is zero, then the two vectors are orthogonal. Thus we see that if $\boldsymbol{x}$ is in the null space of $\boldsymbol{A}$ it *has* to be orthogonal to every row-vector of $\boldsymbol{A}$. This means that the null space is the orthogonal complement to the row space!

We can visualize the relationship between the row space and null space using our example matrix:

\[\begin{bmatrix}1 & 2 & 1 \\ 0 & 1 & -1\end{bmatrix}\]The null space for this matrix consists of all of the vectors that point along the red vector shown below:

Notice that this red vector is orthogonal to the hyperplane that represents the row space of $\boldsymbol{A}$.
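The orthogonality described above can also be checked numerically. The following is a minimal sketch using NumPy and the example matrix; the null space direction $(-3, 1, 1)$ is derived by hand from the two row equations, and the coefficients used to build a row space vector are arbitrary values chosen for illustration.

```python
import numpy as np

# Example matrix from the text: its rows span a plane in R^3.
A = np.array([[1.0, 2.0, 1.0],
              [0.0, 1.0, -1.0]])

# Hand-derived null space direction: solving x + 2y + z = 0 and y - z = 0
# gives vectors proportional to (-3, 1, 1).
null_vec = np.array([-3.0, 1.0, 1.0])

# Each entry of A @ null_vec is the dot product of one row of A with null_vec.
row_dots = A @ null_vec

# Any vector in the row space is a linear combination of the rows.
coeffs = np.array([2.0, -5.0])      # arbitrary (hypothetical) coefficients
row_space_vec = coeffs @ A          # 2 * row_1 - 5 * row_2
dot_with_row_space_vec = row_space_vec @ null_vec

print(row_dots, dot_with_row_space_vec)
```

Both the per-row dot products and the dot product with the combined row space vector come out as zero, which is what the orthogonal complement relationship requires.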

The intrinsic dimensionalities of the row space and column space are also related to one another and tell us a lot about the matrix itself. Recall, the intrinsic dimensionality of a set of vectors is given by the maximal number of linearly independent vectors in the set. With this in mind, we can form the following definitions that describe the intrinsic dimensionalities of the row space and column space:

**Definition 3 (column rank):** Given a matrix $\boldsymbol{A} \in \mathbb{R}^{m \times n}$, the **column rank** of $\boldsymbol{A}$ is the size of the largest subset of the columns of $\boldsymbol{A}$ that is linearly independent.

**Definition 4 (row rank):** Given a matrix $\boldsymbol{A} \in \mathbb{R}^{m \times n}$, the **row rank** of $\boldsymbol{A}$ is the size of the largest subset of the rows of $\boldsymbol{A}$ that is linearly independent.

It turns out that the intrinsic dimensionalities of the row space and column space are always equal and thus the column rank will always equal the row rank:

**Theorem 2 (row rank equals column rank):** Given a matrix $\boldsymbol{A} \in \mathbb{R}^{m \times n}$, its row rank equals its column rank.

Because the row rank and column rank are equal, one can simply talk about the **rank** of a matrix without the need to specify whether we mean the row rank or the column rank.

Moreover, because the row rank equals the column rank of a matrix, a matrix of shape $m \times n$ can *at most* have a rank that is the minimum of $m$ and $n$. For example, a matrix with 3 rows and 5 columns can *at most* be of rank 3 (but it might be less!). In fact, we observed this phenomenon in our previous example matrix, which has a rank of 2:

As we can see, the column space spans all of $\mathbb{R}^2$ and thus, its intrinsic dimensionality is two. The row space spans a hyperplane in $\mathbb{R}^3$ and thus, its intrinsic dimensionality is also two.
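As a quick numerical sanity check of these statements, here is a small sketch using NumPy. Note that `np.linalg.matrix_rank` returns a single rank and does not distinguish rows from columns; applying it to both $\boldsymbol{A}$ and its transpose is simply one way to look at the row and column views separately. The $3 \times 5$ matrix is a made-up example whose third row is the sum of the first two.

```python
import numpy as np

# Example matrix from the text (2 rows, 3 columns).
A = np.array([[1.0, 2.0, 1.0],
              [0.0, 1.0, -1.0]])

rank_from_columns = np.linalg.matrix_rank(A)    # dimension of the column space
rank_from_rows = np.linalg.matrix_rank(A.T)     # columns of A.T are the rows of A
upper_bound = min(A.shape)                      # rank can be at most min(m, n)

# Hypothetical 3 x 5 matrix: the third row equals row 1 + row 2,
# so the rank falls below the upper bound of 3.
C = np.array([[1.0, 0.0, 2.0, 0.0, 1.0],
              [0.0, 1.0, 1.0, 3.0, 0.0],
              [1.0, 1.0, 3.0, 3.0, 1.0]])
rank_C = np.linalg.matrix_rank(C)

print(rank_from_columns, rank_from_rows, upper_bound, rank_C)
```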

Where the rank of a matrix describes the intrinsic dimensionality of the row and column spaces of a matrix, the **nullity** describes the intrinsic dimensionality of the null space:

**Definition 5 (nullity):** Given a matrix $\boldsymbol{A} \in \mathbb{R}^{m \times n}$, the **nullity** of $\boldsymbol{A}$ is the maximum number of linearly independent vectors that span the null space of $\boldsymbol{A}$.

There is a key relationship between nullity and rank: they sum to the number of columns of $\boldsymbol{A}$! This is proven in the rank-nullity theorem (proof provided in the Appendix to this post):

**Theorem 3 (rank-nullity theorem):** Given a matrix $\boldsymbol{A} \in \mathbb{R}^{m \times n}$, it holds that $\text{rank} + \text{nullity} = n$.

Below we illustrate this theorem with two examples:

On the left, we have a matrix whose rows span a hyperplane in $\mathbb{R}^3$, which is of dimension 2. The null space is thus a line, which has dimension 1. In contrast, on the right we have a matrix whose rows span a line in $\mathbb{R}^3$, which is of dimension 1. The null space here is a hyperplane that is orthogonal to this line. In both examples, the dimensionality of the row space and null space sum to 3, which is the number of columns of both matrices!
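The two cases above can be reproduced numerically. The sketch below assumes NumPy and uses the singular value decomposition to obtain both the rank and a basis of the null space; the helper name `rank_and_null_basis`, the tolerance, and the second matrix (whose rows span a line) are choices made for this illustration. Because the nullity is read off from the same decomposition as the rank, the informative check is that the returned basis vectors really are mapped to zero.

```python
import numpy as np

def rank_and_null_basis(A, tol=1e-10):
    """Return (rank, basis); the columns of basis span the null space of A."""
    # Full SVD: the rows of Vt form an orthonormal basis of R^n.
    _, singular_values, Vt = np.linalg.svd(A, full_matrices=True)
    rank = int(np.sum(singular_values > tol))
    # Right singular vectors beyond the first `rank` are mapped to zero by A.
    basis = Vt[rank:].T
    return rank, basis

plane_rows = np.array([[1.0, 2.0, 1.0],
                       [0.0, 1.0, -1.0]])   # rows span a plane in R^3
line_rows = np.array([[1.0, 2.0, 1.0],
                      [2.0, 4.0, 2.0]])     # rows span a line in R^3

results = {}
for name, A in [("plane", plane_rows), ("line", line_rows)]:
    rank, basis = rank_and_null_basis(A)
    nullity = basis.shape[1]
    # Every basis vector must satisfy A x = 0.
    assert np.allclose(A @ basis, 0.0)
    results[name] = (rank, nullity)
    print(name, "rank:", rank, "nullity:", nullity, "sum:", rank + nullity)
```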

We can summarize the properties of the column space, row space, and null space with the following table organized around Perspective 2 (matrices as lists of vectors) and Perspective 3 (matrices as functions):

Moreover, we can summarize the relationships between these spaces with the following figure:

We conclude by discussing the vector spaces induced by invertible matrices. Recall, that a square matrix $\boldsymbol{A} \in \mathbb{R}^{n \times n}$ is invertible if and only if its columns are linearly independent (see Theorem 4 in the Appendix to my previous blog post). This implies that for invertible matrices, it holds that:

- The column space spans all of $\mathbb{R}^n$ since the columns are linearly independent. This implies that the column rank is $n$
- The row space spans all of $\mathbb{R}^n$, since by Theorem 2 the row rank equals the column rank
- The nullity is zero, since by Theorem 3 the nullity plus the rank must equal the number of columns

We call an invertible matrix **full rank** since the rank equals the number of rows and columns. The rank is “full” because it cannot be increased any further past the number of its columns/rows!

Moreover, we see that there is only *one* vector in the null space of an invertible matrix since its nullity is zero (a dimensionality of zero corresponds to a single point). If we think back on our discussion of invertible matrices as characterizing invertible functions, then this fact makes sense. For a function to be invertible, it must be one-to-one and onto. So if we use an invertible matrix $\boldsymbol{A}$ to define the function

\[T(\boldsymbol{x}) := \boldsymbol{Ax}\]Then it holds that every vector, $\boldsymbol{b}$, in the range of the function $T$ has exactly one vector, $\boldsymbol{x}$, in the domain of $T$ for which $T(\boldsymbol{x}) = \boldsymbol{b}$. This must also hold for the zero vector. Thus, there must be only one vector, $\boldsymbol{x}$, for which $\boldsymbol{Ax} = \boldsymbol{0}$. Hence, the null space comprises just a single vector.

Now we may ask: what vector is this single member of the null space? It turns out, it’s the zero vector! We see this by applying Theorem 1 from this previous blog post.
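These properties of invertible matrices are easy to confirm on a concrete case. The following sketch assumes NumPy and uses a made-up $3 \times 3$ matrix with nonzero determinant; it checks that the matrix is full rank and that the only solution to $\boldsymbol{Ax} = \boldsymbol{0}$ is the zero vector.

```python
import numpy as np

# Hypothetical invertible matrix (its determinant is 5).
A = np.array([[2.0, 1.0, 0.0],
              [0.0, 1.0, 3.0],
              [1.0, 0.0, 1.0]])

n = A.shape[0]
determinant = np.linalg.det(A)
rank = np.linalg.matrix_rank(A)
nullity = n - rank                       # rank-nullity theorem

# For an invertible matrix, A x = 0 has exactly one solution.
x = np.linalg.solve(A, np.zeros(n))

print(determinant, rank, nullity, x)
```

Since `np.linalg.solve` raises an error for singular matrices, the fact that it returns a solution at all is itself evidence that the matrix is invertible.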

**Theorem 1 (null space is orthogonal complement of row space):** Given a matrix $\boldsymbol{A}$, the null space of $\boldsymbol{A}$ is the orthogonal complement to the row space of $\boldsymbol{A}$.

**Proof:**

To prove that the null space of $\boldsymbol{A}$ is the orthogonal complement of the row space, we must show that every vector in the null space is orthogonal to every vector in the row space. Consider vector $\boldsymbol{x}$ in the null space of $\boldsymbol{A}$. By the definition of the null space (Definition 5), this means that $\boldsymbol{Ax} = \boldsymbol{0}$. That is,

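\[\begin{align*}\boldsymbol{Ax} &= \begin{bmatrix} \boldsymbol{a}_{1,*} \cdot \boldsymbol{x} \\ \boldsymbol{a}_{2,*} \cdot \boldsymbol{x} \\ \vdots \\ \boldsymbol{a}_{m,*} \cdot \boldsymbol{x} \end{bmatrix} \\ &= \begin{bmatrix} 0 \\ 0 \\ \vdots \\ 0 \end{bmatrix} \\ &= \boldsymbol{0} \end{align*}\]Because $\boldsymbol{a}_{i,*} \cdot \boldsymbol{x} = 0$ for each row $i$, every row vector of $\boldsymbol{A}$ is orthogonal to $\boldsymbol{x}$. Any vector in the row space is a linear combination of the rows, $\sum_{i=1}^m c_i \boldsymbol{a}_{i,*}$, and its dot product with $\boldsymbol{x}$ is $\sum_{i=1}^m c_i (\boldsymbol{a}_{i,*} \cdot \boldsymbol{x}) = 0$. Thus, $\boldsymbol{x}$ is orthogonal to every vector in the row space.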

$\square$

**Theorem 2 (row rank equals column rank):** Given a matrix $\boldsymbol{A} \in \mathbb{R}^{m \times n}$, the row rank equals the column rank

**Proof:**

This proof is described on Wikipedia, provided here in my own words.

Let $r$ be the row rank of $\boldsymbol{A}$ and let $\boldsymbol{b}_1, \dots, \boldsymbol{b}_r \in \mathbb{R}^n$ be a set of basis vectors for the row space of $\boldsymbol{A}$. Now, let $c_1, c_2, \dots, c_r$ be coefficients such that

\[\sum_{i=1}^r c_i \boldsymbol{Ab}_i = \boldsymbol{0}\]Furthermore, let

\[\boldsymbol{v} := \sum_{i=1}^r c_i\boldsymbol{b}_i\]We see that

\[\begin{align*} \sum_{i=1}^r c_i \boldsymbol{Ab}_i &= \boldsymbol{0} \\ \implies \sum_{i=1}^r \boldsymbol{A}c_i \boldsymbol{b}_i &= \boldsymbol{0} \\ \implies \boldsymbol{A} \sum_{i=1}^r c_i\boldsymbol{b}_i &= \boldsymbol{0} \\ \implies \boldsymbol{Av} &= \boldsymbol{0} \end{align*}\]With this in mind, we can prove that $\boldsymbol{v}$ must be the zero vector. To do so, we first note that $\boldsymbol{v}$ is in both the row space of $\boldsymbol{A}$ and the null space of $\boldsymbol{A}$. It is in the row space of $\boldsymbol{A}$ because it is a linear combination of the basis vectors of the row space of $\boldsymbol{A}$. It is in the null space of $\boldsymbol{A}$ because $\boldsymbol{Av} = \boldsymbol{0}$. From Theorem 1, $\boldsymbol{v}$ must be orthogonal to all vectors in the row space of $\boldsymbol{A}$, which includes itself. The only vector that is orthogonal to itself is the zero vector and thus, $\boldsymbol{v}$ must be the zero vector.

This in turn implies that $c_1, \dots, c_r$ must all be zero. We know this because $\boldsymbol{b}_1, \dots, \boldsymbol{b}_r \in \mathbb{R}^n$ are basis vectors and are therefore linearly independent, so the only linear combination of them that equals the zero vector is the one in which every coefficient is zero. Thus we have proven that the only assignment of values for $c_1, \dots, c_r$ for which $\sum_{i=1}^r c_i \boldsymbol{Ab}_i = \boldsymbol{0}$ is the assignment for which they are all zero. By Theorem 1 in a previous post, this implies that $\boldsymbol{Ab}_1, \dots, \boldsymbol{Ab}_r$ must be linearly independent.

Moreover, by the definition of matrix-vector multiplication, we know that $\boldsymbol{Ab}_1, \dots, \boldsymbol{Ab}_r$ are in the column space of $\boldsymbol{A}$. Thus, we have proven that there exist *at least* $r$ independent vectors in the column space of $\boldsymbol{A}$. This means that the column rank of $\boldsymbol{A}$ is *at least* $r$. That is,

\[\text{column rank of} \ \boldsymbol{A} \geq \text{row rank of} \ \boldsymbol{A}\]We can repeat this exercise on the transpose of $\boldsymbol{A}$, which tells us that

\[\text{row rank of} \ \boldsymbol{A} \geq \text{column rank of} \ \boldsymbol{A}\]These statements together imply that the column rank and row rank of $\boldsymbol{A}$ are equal.

$\square$

**Theorem 3 (rank-nullity theorem):** Given a matrix $\boldsymbol{A} \in \mathbb{R}^{m \times n}$, it holds that $\text{rank} + \text{nullity} = n$.

**Proof:**

This proof is described on Wikipedia, provided here in my own words along with supplemental schematics of the matrices used in the proof.

Let $r$ be the rank of the matrix. This means that there are $r$ linearly independent column vectors in $\boldsymbol{A}$. Without loss of generality, we can arrange $\boldsymbol{A}$ so that the first $r$ columns are linearly independent, and the remaining $n - r$ columns can be written as a linear combination of the first $r$ columns. That is, we can write:

\[\boldsymbol{A} = \begin{pmatrix} \boldsymbol{A}_1 & \boldsymbol{A}_2 \end{pmatrix}\]where $\boldsymbol{A}_1$ and $\boldsymbol{A}_2$ are the two partitions of the matrix as shown below:

Because the columns of $\boldsymbol{A}_2$ are linear combinations of the columns of $\boldsymbol{A}_1$, there exists a matrix $\boldsymbol{B} \in \mathbb{R}^{r \times (n-r)}$ for which

\[\boldsymbol{A}_2 = \boldsymbol{A}_1 \boldsymbol{B}\]This is depicted below:

Now, consider a matrix

\[\boldsymbol{X} := \begin{pmatrix} -\boldsymbol{B} \\ \boldsymbol{I}_{n-r} \end{pmatrix}\]That is, $\boldsymbol{X}$ is formed by concatenating the $(n-r) \times (n-r)$ identity matrix below the $-\boldsymbol{B}$ matrix. Now, we see that $\boldsymbol{AX} = \boldsymbol{0}$:

\[\begin{align*}\boldsymbol{AX} &= \begin{pmatrix} \boldsymbol{A}_1 & \boldsymbol{A}_1\boldsymbol{B} \end{pmatrix} \begin{pmatrix} -\boldsymbol{B} \\ \boldsymbol{I}_{n-r} \end{pmatrix} \\ &= -\boldsymbol{A}_1\boldsymbol{B} + \boldsymbol{A}_1\boldsymbol{B} \\ &= \boldsymbol{0} \end{align*}\]Depicted schematically,

Thus, we see that every column of $\boldsymbol{X}$ is in the null space of $\boldsymbol{A}$.

We now show that these column vectors are linearly independent. To do so, we will consider a vector $\boldsymbol{u} \in \mathbb{R}^{n-r}$ such that

\[\boldsymbol{Xu} = \boldsymbol{0}\]For this to hold, we see that $\boldsymbol{u}$ must be zero:

\[\begin{align*}\boldsymbol{Xu} &= \boldsymbol{0} \\ \implies \begin{pmatrix} -\boldsymbol{B} \\ \boldsymbol{I}_{n-r} \end{pmatrix}\boldsymbol{u} &= \begin{pmatrix} \boldsymbol{0}_r \\ \boldsymbol{0}_{n-r} \end{pmatrix} \\ \implies \begin{pmatrix} -\boldsymbol{Bu} \\ \boldsymbol{u} \end{pmatrix} &= \begin{pmatrix} \boldsymbol{0}_r \\ \boldsymbol{0}_{n-r} \end{pmatrix} \end{align*}\]The bottom block of this equation states that $\boldsymbol{u} = \boldsymbol{0}_{n-r}$, so the only solution is the zero vector. By Theorem 1 in a previous post, this proves that the columns of $\boldsymbol{X}$ are linearly independent. So we have shown that there exist $n-r$ linearly independent vectors in the null space of $\boldsymbol{A}$, which means the nullity is *at least* $n-r$.

We now show that *any* other vector in the null space of $\boldsymbol{A}$ that is not a column of $\boldsymbol{X}$ can be written as a linear combination of the columns of $\boldsymbol{X}$. If we can prove this fact, we will have proven that the nullity is exactly equal to $n-r$ and is not greater.

We start by again considering a vector $\boldsymbol{u} \in \mathbb{R}^n$ that we assume is in the null space of $\boldsymbol{A}$. We partition this vector into two segments: one segment, $\boldsymbol{u}_1$, comprising the first $r$ elements and a second segment, $\boldsymbol{u}_2$, comprising the remaining $n-r$ elements:

\[\boldsymbol{u} = \begin{pmatrix}\boldsymbol{u}_1 \\ \boldsymbol{u}_2 \end{pmatrix}\]Because we assume that $\boldsymbol{u}$ is in the null space, it must hold that $\boldsymbol{Au} = \boldsymbol{0}$. Depicted schematically:

Solving for $\boldsymbol{u}$, we see that

\[\begin{align*} \boldsymbol{Au} &= \boldsymbol{0} \\ \begin{pmatrix} \boldsymbol{A}_1 & \boldsymbol{A}_2 \end{pmatrix} \begin{pmatrix}\boldsymbol{u}_1 \\ \boldsymbol{u}_2 \end{pmatrix} &= \boldsymbol{0} \\ \begin{pmatrix} \boldsymbol{A}_1 & \boldsymbol{A}_1\boldsymbol{B} \end{pmatrix} \begin{pmatrix}\boldsymbol{u}_1 \\ \boldsymbol{u}_2 \end{pmatrix} &= \boldsymbol{0} \\ \implies \boldsymbol{A}_1\boldsymbol{u}_1 + \boldsymbol{A}_1\boldsymbol{B}\boldsymbol{u}_2 &= \boldsymbol{0} \\ \implies \boldsymbol{A}_1 (\boldsymbol{u}_1 + \boldsymbol{Bu}_2) &= \boldsymbol{0} \\ \implies \boldsymbol{u}_1 + \boldsymbol{Bu}_2 &= \boldsymbol{0} \\ \implies \boldsymbol{u}_1 &= -\boldsymbol{Bu}_2 \end{align*}\]The step from $\boldsymbol{A}_1 (\boldsymbol{u}_1 + \boldsymbol{Bu}_2) = \boldsymbol{0}$ to $\boldsymbol{u}_1 + \boldsymbol{Bu}_2 = \boldsymbol{0}$ holds because the columns of $\boldsymbol{A}_1$ are linearly independent, so the only vector that $\boldsymbol{A}_1$ maps to zero is the zero vector. Thus,

\[\begin{align*}\boldsymbol{u} &= \begin{pmatrix}\boldsymbol{u}_1 \\ \boldsymbol{u}_2 \end{pmatrix} \\ &= \begin{pmatrix} -\boldsymbol{Bu}_2 \\ \boldsymbol{u}_2 \end{pmatrix} \\ &= \begin{pmatrix} -\boldsymbol{B} \\ \boldsymbol{I}_{n-r} \end{pmatrix}\boldsymbol{u}_2 \\ &= \boldsymbol{X}\boldsymbol{u}_2 \end{align*}\]We see that $\boldsymbol{u}$ must be a linear combination of the columns of $\boldsymbol{X}$! We have therefore shown that:

- There exist $n-r$ linearly independent vectors in the null space of $\boldsymbol{A}$
- Any vector in the null space can be expressed as a linear combination of these linearly independent vectors

This proves that the nullity is $n-r$, and thus, the nullity $n-r$ plus the rank $r$, equals $n$.

$\square$
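The construction used in this proof can be traced on a small concrete case. The sketch below assumes NumPy; the matrices `A1` and `B` are made-up values with $r = 2$ and $n = 5$, chosen only so that the partition $\boldsymbol{A} = (\boldsymbol{A}_1 \ \ \boldsymbol{A}_1\boldsymbol{B})$ holds by construction.

```python
import numpy as np

# Hypothetical A1 with r = 2 linearly independent columns (shape 3 x 2).
A1 = np.array([[1.0, 0.0],
               [0.0, 1.0],
               [1.0, 1.0]])
# Hypothetical B (shape r x (n - r) = 2 x 3), so that A2 = A1 @ B.
B = np.array([[2.0, -1.0, 0.0],
              [1.0, 3.0, 4.0]])

A = np.hstack([A1, A1 @ B])            # A = (A1  A2), shape 3 x 5
r = np.linalg.matrix_rank(A)
n = A.shape[1]

# X stacks -B on top of the (n - r) x (n - r) identity matrix.
X = np.vstack([-B, np.eye(n - r)])

AX = A @ X                             # should be the zero matrix
rank_X = np.linalg.matrix_rank(X)      # columns of X are linearly independent

# Any null space vector can be written as X @ u2 for some u2.
u2 = np.array([1.0, -2.0, 0.5])        # arbitrary coefficients
u = X @ u2

print(r, n - r, rank_X, np.allclose(AX, 0.0), np.allclose(A @ u, 0.0))
```

Here the rank is 2 and the columns of `X` supply $n - r = 3$ linearly independent null space vectors, so rank plus nullity equals 5, the number of columns.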

Variational autoencoders (VAEs), introduced by Kingma and Welling (2013), are a class of probabilistic models that find latent, low-dimensional representations of data. VAEs are thus a method for performing dimensionality reduction to reduce data down to their intrinsic dimensionality.

As their name suggests, VAEs are a type of **autoencoder**. An autoencoder is a model that takes a vector, $\boldsymbol{x}$, compresses it into a lower-dimensional vector, $\boldsymbol{z}$, and then decompresses $\boldsymbol{z}$ back into an approximation of $\boldsymbol{x}$. The architecture of an autoencoder can be visualized as follows:

Here we see one function (usually a neural network), $h_\phi$, compresses $\boldsymbol{x}$ into a low-dimensional data point, $\boldsymbol{z}$, and then another function (also a neural network), $f_\theta$, decompresses it back into an approximation of $\boldsymbol{x}$, here denoted as $\boldsymbol{x}'$. The variables $\phi$ and $\theta$ denote the parameters of the two neural networks.

VAEs can be understood as a type of autoencoder like the one shown above, but with some important differences: unlike standard autoencoders, VAEs are probabilistic models and, as we will see in this post, their “autoencoding” ability emerges from how the probabilistic model is defined and fit.

In summary, it helps to view VAEs from two angles:

- **Probabilistic generative model:** VAEs are probabilistic generative models of independent, identically distributed samples, $\boldsymbol{x}_1, \dots, \boldsymbol{x}_n$. In this model, each sample, $\boldsymbol{x}_i$, is associated with a latent (i.e. unobserved), lower-dimensional variable $\boldsymbol{z}_i$. Variational autoencoders are generative models in that they describe a joint distribution over samples and their associated latent variables, $p(\boldsymbol{x}, \boldsymbol{z})$.
- **Autoencoder:** VAEs are a form of autoencoder. Unlike traditional autoencoders, VAEs can be viewed as *probabilistic* rather than deterministic: given an input sample, $\boldsymbol{x}_i$, the compressed representation of $\boldsymbol{x}_i$, $\boldsymbol{z}_i$, is randomly generated.

In this blog post we will show how VAEs can be viewed through both of these lenses. We will then provide an example implementation of a VAE and apply it to the MNIST dataset of hand-written digits.

At their foundation, a VAE defines a probabilistic generative process for “generating” data points that reside in some $D$-dimensional vector space. This generative process goes as follows: we first sample a latent variable $\boldsymbol{z} \in \mathbb{R}^{J}$ where $J < D$ from some distribution such as a standard normal distribution:

\[\boldsymbol{z} \sim N(\boldsymbol{0}, \boldsymbol{I})\]Then, we use a deterministic function to map $\boldsymbol{z}$ to the parameters, $\boldsymbol{\psi}$, of another distribution used to sample $\boldsymbol{x} \in \mathbb{R}^D$. Most commonly, we construct $\boldsymbol{\psi}$ from $\boldsymbol{z}$ using neural networks:

\[\begin{align*} \boldsymbol{\psi} &:= f_{\theta}(\boldsymbol{z}) \\ \boldsymbol{x} &\sim \mathcal{D}(\boldsymbol{\psi}) \end{align*}\]where $\mathcal{D}$ is a parametric distribution and $f$ is a neural network parameterized by a set of parameters $\theta$. Here’s a schematic illustration of the generative process:

This generative process can be visualized graphically below:

Interestingly, this model enables us to fit very complicated distributions. That’s because although the distribution of $\boldsymbol{z}$ and the conditional distribution of $\boldsymbol{x}$ given $\boldsymbol{z}$ may both be simple (e.g., both normal distributions), the non-linear mapping between $\boldsymbol{z}$ and $\boldsymbol{\psi}$ via the neural network leads to the marginal distribution of $\boldsymbol{x}$ becoming complex:
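To make the generative process concrete, here is a minimal sampling sketch in NumPy. It is not a trained VAE: the function `f_theta` is a hand-written, hypothetical stand-in for the decoder network (it maps $z$ to the mean $(z, z^2)$), the conditional distribution is assumed to be a normal with a fixed standard deviation, and the dimensions $J = 1$ and $D = 2$ are chosen only to keep the example small.

```python
import numpy as np

rng = np.random.default_rng(0)

def f_theta(z):
    """Hypothetical stand-in for the decoder: maps z of shape (n, 1) to means of shape (n, 2)."""
    return np.concatenate([z, z**2], axis=1)

n_samples = 10_000
noise_std = 0.1   # assumed fixed standard deviation of p(x | z)

z = rng.standard_normal(size=(n_samples, 1))            # z ~ N(0, I), with J = 1
psi = f_theta(z)                                        # parameters (here, the mean) of p(x | z)
x = psi + noise_std * rng.standard_normal(psi.shape)    # x ~ N(psi, noise_std^2 I), with D = 2

def skewness(v):
    """Sample skewness; a normal distribution has skewness zero."""
    centered = v - v.mean()
    return np.mean(centered**3) / np.std(v)**3

skew_second_dim = skewness(x[:, 1])
print(x.shape, skew_second_dim)
```

Even though both $p(\boldsymbol{z})$ and $p(\boldsymbol{x} \mid \boldsymbol{z})$ are normal in this sketch, the second coordinate of $\boldsymbol{x}$ is strongly right-skewed, so its marginal distribution is clearly not normal.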

Now, let’s say we are given a dataset consisting of data points $\boldsymbol{x}_1, \dots, \boldsymbol{x}_n \in \mathbb{R}^D$ that were generated by a VAE. We may be interested in two tasks:

- For fixed $\theta$, for each $\boldsymbol{x}_i$, compute the posterior distribution $p_{\theta}(\boldsymbol{z}_i \mid \boldsymbol{x}_i)$
- Find the maximum likelihood estimates of $\theta$

Unfortunately, for a fixed $\theta$, solving for the posterior $p_{\theta}(\boldsymbol{z}_i \mid \boldsymbol{x}_i)$ using Bayes’ theorem is intractable because the denominator in Bayes’ theorem requires marginalizing over $\boldsymbol{z}_i$:

\[p_\theta(\boldsymbol{z}_i \mid \boldsymbol{x}_i) = \frac{p_\theta(\boldsymbol{x}_i \mid \boldsymbol{z}_i)p(\boldsymbol{z}_i)}{\int p_\theta(\boldsymbol{x}_i \mid \boldsymbol{z}_i)p(\boldsymbol{z}_i) \ d\boldsymbol{z}_i }\]This marginalization requires solving an integral over all of the dimensions of the latent space! This is not feasible to calculate. Estimating $\theta$ via maximum likelihood estimation also requires solving this integral:

\[\begin{align*}\hat{\theta} &:= \text{argmax}_\theta \prod_{i=1}^n p_\theta(\boldsymbol{x}_i) \\ &= \text{argmax}_\theta \prod_{i=1}^n \int p_\theta(\boldsymbol{x}_i \mid \boldsymbol{z}_i)p(\boldsymbol{z}_i) \ d\boldsymbol{z}_i \end{align*}\]Variational autoencoders find approximate solutions to both of these intractable inference problems using variational inference. First, let’s assume that $\theta$ is fixed and attempt to approximate $p_\theta(\boldsymbol{z}_i \mid \boldsymbol{x}_i)$. Variational inference is a method for performing such approximations by first choosing a set of probability distributions, $\mathcal{Q}$, called the *variational family*, and then finding the distribution $q(\boldsymbol{z}_i) \in \mathcal{Q}$ that is “closest to” $p_\theta(\boldsymbol{z}_i \mid \boldsymbol{x}_i)$.

Variational inference uses the KL-divergence between $q(\boldsymbol{z}_i)$ and $p_\theta(\boldsymbol{z}_i \mid \boldsymbol{x}_i)$ as its measure of “closeness”. Thus, the goal of variational inference is to minimize the KL-divergence. It turns out that the task of minimizing the KL-divergence is equivalent to the task of maximizing a quantity called the evidence lower bound (ELBO), which is defined as

\[\begin{align*} \text{ELBO}(q) &:= E_{\boldsymbol{z}_1, \dots, \boldsymbol{z}_n \overset{\text{i.i.d.}}{\sim} q}\left[ \sum_{i=1}^n \log p_\theta(\boldsymbol{x}_i, \boldsymbol{z}_i) - \sum_{i=1}^n \log q(\boldsymbol{z}_i) \right] \\ &= \sum_{i=1}^n E_{z_i \sim q} \left[\log p_\theta(\boldsymbol{x}_i, \boldsymbol{z}_i) - \log q(\boldsymbol{z}_i) \right] \end{align*}\]Thus, variational inference entails finding

\[\hat{q} := \text{arg max}_{q \in \mathcal{Q}} \ \text{ELBO}(q)\]Now, so far we have assumed that $\theta$ is fixed. Is it possible to find both $q$ and $\theta$ jointly? As we discuss in a previous post on variational inference, it is perfectly reasonable to define the ELBO as a function of *both* $q$ and $\theta$ and then to maximize the ELBO jointly with respect to both of these parameters:

\[\hat{q}, \hat{\theta} := \text{arg max}_{q \in \mathcal{Q}, \theta} \ \text{ELBO}(q, \theta)\]Why is this a reasonable thing to do? Recall the ELBO is a *lower bound* on the marginal log-likelihood $\log p_\theta(\boldsymbol{x}_1, \dots, \boldsymbol{x}_n)$. Thus, optimizing the ELBO with respect to $\theta$ raises a lower bound on the log-likelihood. Below, we depict this process where $\hat{\theta}$ maximizes the ELBO and $\theta^*$ is the true maximum of the log-likelihood (this figure is adapted from this blog post by Jakub Tomczak):
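The lower-bound property can be verified exactly on a toy model in which the latent variable is discrete, so that every sum can be computed in closed form. The sketch below assumes NumPy; the joint probabilities and the variational distribution `q` are made-up numbers for a single observation $x$ and a latent $z$ with three possible values.

```python
import numpy as np

# Hypothetical joint probabilities p(x, z = k) for one fixed observation x.
joint = np.array([0.10, 0.25, 0.05])
log_evidence = np.log(joint.sum())      # log p(x), the marginal log-likelihood
posterior = joint / joint.sum()         # p(z | x), by Bayes' theorem

def elbo(q):
    """ELBO(q) = E_q[log p(x, z) - log q(z)] for a discrete latent variable."""
    return np.sum(q * (np.log(joint) - np.log(q)))

def kl(q, p):
    """KL(q || p) for two discrete distributions with full support."""
    return np.sum(q * (np.log(q) - np.log(p)))

q = np.array([0.5, 0.3, 0.2])           # an arbitrary variational distribution
gap = log_evidence - elbo(q)

print(log_evidence, elbo(q), gap, kl(q, posterior))
```

The gap between the log-likelihood and the ELBO equals the KL-divergence between `q` and the true posterior, which is why maximizing the ELBO over $q$ is the same as minimizing that KL-divergence, and why the bound becomes tight when `q` equals the posterior.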

VAEs use a variational family with the following form:

\[\mathcal{Q} := \left\{ N(h^{(1)}_\phi(\boldsymbol{x}), \text{diag}(\exp(h^{(2)}_\phi(\boldsymbol{x})))) \mid \phi \in \mathbb{R}^R \right\}\]where $h^{(1)}_\phi$ and $h^{(2)}_\phi$ are two neural networks that map the original object, $\boldsymbol{x}$, to the mean, $\boldsymbol{\mu}$, and the logarithm of the variance, $\log \boldsymbol{\sigma}^2$, of the approximate posterior distribution. $R$ is the number of parameters of these neural networks.

Said a different way, we define $q_\phi(\boldsymbol{z} \mid \boldsymbol{x})$ as

\[q_\phi(\boldsymbol{z} \mid \boldsymbol{x}) := N(h^{(1)}_\phi(\boldsymbol{x}), \text{diag}(\exp(h^{(2)}_\phi(\boldsymbol{x}))))\]Said a third way, the approximate posterior distribution can be sampled via the following process:

\[\begin{align*}\boldsymbol{\mu} &:= h^{(1)}_\phi(\boldsymbol{x}) \\ \log \boldsymbol{\sigma}^2 &:= h^{(2)}_\phi(\boldsymbol{x}) \\ \boldsymbol{z} &\sim N(\boldsymbol{\mu}, \text{diag}(\boldsymbol{\sigma^2})) \end{align*}\]This can be visualized as follows:

Note that $h^{(1)}_\phi$ and $h^{(2)}_\phi$ may either be two entirely separate neural networks or may share some subset of parameters. We use $h_\phi$ to refer to the full neural network (or union of two separate neural networks) comprising both $h^{(1)}_\phi$ and $h^{(2)}_\phi$ as shown below:

Thus, maximizing the ELBO over $\mathcal{Q}$ reduces to maximizing the ELBO over the neural network parameters $\phi$ (in addition to $\theta$ as discussed previously):

\[\begin{align*}\hat{\phi}, \hat{\theta} &= \text{arg max}_{\phi, \theta} \ \text{ELBO}(\phi, \theta) \\ &:= \text{arg max}_{\phi, \theta} \ \sum_{i=1}^n E_{\boldsymbol{z}_i \sim q_\phi(\boldsymbol{z}_i \mid \boldsymbol{x}_i)}\left[ \log p_\theta(\boldsymbol{x}_i, \boldsymbol{z}_i) - \log q_\phi(\boldsymbol{z}_i \mid \boldsymbol{x}_i) \right] \end{align*}\]One detail to point out here is that the approximation of the posterior over each $\boldsymbol{z}_i$ is defined by a set of parameters $\phi$ that are shared across all latent variables $\boldsymbol{z}_1, \dots, \boldsymbol{z}_n$. That is, we use a single set of neural network parameters $\phi$ to encode every posterior distribution $q_\phi(\boldsymbol{z}_i \mid \boldsymbol{x}_i)$. Note, we *could* have gone a different route and defined a *separate* variational distribution $q_i$ for each $\boldsymbol{z}_i$ that is not conditioned on $\boldsymbol{x}_i$. That is, we could define the variational posterior as $q_{\phi_i}(\boldsymbol{z}_i)$, where each $\boldsymbol{z}_i$ has its own set of parameters $\phi_i$. Why don’t we do this instead? The answer is that for extremely large datasets it’s easier to perform VI when $\phi$ is shared across all data points because it reduces the number of parameters we need to search over in our optimization. This act of defining a common set of parameters shared across all of the independent posteriors is called **amortized variational inference**.

Now that we’ve set up the optimization problem, we need to solve it. Unfortunately, the expectation present in the ELBO makes this difficult as it requires integrating over all possible values for $\boldsymbol{z}_i$:

\[\begin{align*}\text{ELBO}(\phi, \theta) &= \sum_{i=1}^n E_{\boldsymbol{z}_i \sim q_\phi(\boldsymbol{z}_i \mid \boldsymbol{x}_i)}\left[ \log p_\theta(\boldsymbol{x}_i, \boldsymbol{z}_i) - \log q_\phi(\boldsymbol{z}_i \mid \boldsymbol{x}_i) \right] \\ &= \sum_{i=1}^n \int_{\boldsymbol{z}_i} q_\phi(\boldsymbol{z}_i \mid \boldsymbol{x}_i) \left[ \log p_\theta(\boldsymbol{x}_i, \boldsymbol{z}_i) - \log q_\phi(\boldsymbol{z}_i \mid \boldsymbol{x}_i) \right] \ d\boldsymbol{z}_i \end{align*}\]We address this challenge by using the **reparameterization gradient** method, which we discussed in a previous blog post. We will review this method here; however, see my previous post for a more detailed explanation.

In brief, the reparameterization method maximizes the ELBO via stochastic gradient ascent in which stochastic gradients are formulated by first performing the **reparameterization trick** followed by Monte Carlo sampling. The reparameterization trick works as follows: we “reparameterize” the distribution $q_\phi(\boldsymbol{z}_i \mid \boldsymbol{x}_i)$ in terms of a surrogate random variable $\boldsymbol{\epsilon}_i \sim \mathcal{D}$ and a deterministic function $g_\phi$ in such a way that sampling $\boldsymbol{z}_i$ from $q_\phi(\boldsymbol{z}_i \mid \boldsymbol{x}_i)$ is performed as follows:

\[\begin{align*}\boldsymbol{\epsilon}_i &\sim \mathcal{D} \\ \boldsymbol{z}_i &:= g_\phi(\boldsymbol{\epsilon}_i, \boldsymbol{x}_i)\end{align*}\]One way to think about this is that instead of sampling $\boldsymbol{z}_i$ directly from our variational posterior $q_\phi(\boldsymbol{z}_i \mid \boldsymbol{x}_i)$, we “re-design” the generative process of $\boldsymbol{z}_i$ such that we first sample a surrogate random variable $\boldsymbol{\epsilon}_i$ and then transform $\boldsymbol{\epsilon}_i$ into $\boldsymbol{z}_i$ all while ensuring that in the end, the distribution of $\boldsymbol{z}_i$ still follows $q_\phi(\boldsymbol{z}_i \mid \boldsymbol{x}_i)$. Following the reparameterization trick, we can re-write the ELBO as follows:

\[\text{ELBO}(\phi, \theta) := \sum_{i=1}^n E_{\epsilon_i \sim \mathcal{D}}\left[ \log p_\theta(\boldsymbol{x}_i, g_\phi(\boldsymbol{\epsilon}_i, \boldsymbol{x}_i)) - \log q_\phi(g_\phi(\boldsymbol{\epsilon}_i, \boldsymbol{x}_i) \mid \boldsymbol{x}_i) \right]\]We then approximate the ELBO via Monte Carlo sampling. That is, for each sample, $i$, we first sample random variables from our surrogate distribution $\mathcal{D}$:

\[\boldsymbol{\epsilon}'_{i,1}, \dots, \boldsymbol{\epsilon}'_{i,L} \sim \mathcal{D}\]Then we can compute a Monte Carlo approximation to the ELBO:

\[\tilde{\text{ELBO}}(\phi, \theta) := \frac{1}{n} \sum_{i=1}^n \frac{1}{L} \sum_{l=1}^L \left[ \log p_\theta(\boldsymbol{x}_i, g_\phi(\boldsymbol{\epsilon}'_{i,l}, \boldsymbol{x}_i)) - \log q_\phi(g_\phi(\boldsymbol{\epsilon}'_{i,l}, \boldsymbol{x}_i) \mid \boldsymbol{x}_i) \right]\]Now the question becomes, what reparameterization can we use? Recall that for the VAEs discussed here $q_\phi(\boldsymbol{z}_i \mid \boldsymbol{x}_i)$ is a normal distribution:

\[q_\phi(\boldsymbol{z}_i \mid \boldsymbol{x}_i) := N(h^{(1)}_\phi(\boldsymbol{x}_i), \text{diag}(\exp(h^{(2)}_\phi(\boldsymbol{x}_i))))\]This naturally can be reparameterized as:

\[\begin{align*}\boldsymbol{\epsilon}_i &\sim N(\boldsymbol{0}, \boldsymbol{I}) \\ \boldsymbol{z}_i &:= h^{(1)}_\phi(\boldsymbol{x}_i) + \sqrt{\exp(h^{(2)}_\phi(\boldsymbol{x}_i))}\boldsymbol{\epsilon}_i \end{align*}\]Thus, our function $g_\phi$ is simply the function that scales $\boldsymbol{\epsilon}_i$ elementwise by $\sqrt{\exp(h^{(2)}_\phi(\boldsymbol{x}_i))}$ and shifts it by $h^{(1)}_\phi(\boldsymbol{x}_i)$. That is,

\[g_\phi(\boldsymbol{\epsilon}_i, \boldsymbol{x}_i) := h^{(1)}_\phi(\boldsymbol{x}_i) + \sqrt{\exp(h^{(2)}_\phi(\boldsymbol{x}_i))}\boldsymbol{\epsilon}_i\]Because $\tilde{\text{ELBO}}(\phi, \theta)$ is differentiable with respect to both $\phi$ and $\theta$ (notice that $f_\theta$ and $h_\phi$ are neural networks, which are differentiable), we can form the gradient:

\[\nabla_{\phi, \theta} \tilde{\text{ELBO}}(\phi, \theta) = \frac{1}{n} \sum_{i=1}^n \frac{1}{L} \sum_{l=1}^L \nabla_{\phi, \theta} \left[ \log p_\theta(\boldsymbol{x}_i, g_\phi(\boldsymbol{\epsilon}'_{i,l}, \boldsymbol{x}_i)) - \log q_\phi(g_\phi(\boldsymbol{\epsilon}'_{i,l}, \boldsymbol{x}_i) \mid \boldsymbol{x}_i) \right]\]To compute this gradient, we can apply automatic differentiation. We can then use it to perform gradient ascent on the ELBO (or, equivalently, gradient descent on the negative ELBO). Thus, we can utilize the extensive toolkit developed for training deep learning models!
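The sampling step of the reparameterization trick can be sketched in a few lines. This example uses plain NumPy rather than an automatic differentiation framework, so it only illustrates how $\boldsymbol{z}$ is produced from $\boldsymbol{\epsilon}$, not the gradient computation. The values of `mu` and `log_var` are made-up stand-ins for the encoder outputs $h^{(1)}_\phi(\boldsymbol{x})$ and $h^{(2)}_\phi(\boldsymbol{x})$ at a single data point.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical encoder outputs for one data point (J = 3 latent dimensions).
mu = np.array([1.0, -2.0, 0.5])        # stands in for h1(x), the mean
log_var = np.array([0.0, -1.0, 1.0])   # stands in for h2(x), the log-variance

def g(eps, mu, log_var):
    """Reparameterization: shift by the mean and scale by the standard deviation."""
    return mu + np.sqrt(np.exp(log_var)) * eps

# Sample the surrogate variable from N(0, I) and transform it into z.
eps = rng.standard_normal(size=(200_000, 3))
z = g(eps, mu, log_var)

empirical_mean = z.mean(axis=0)
empirical_var = z.var(axis=0)
print(empirical_mean, empirical_var, np.exp(log_var))
```

The empirical mean and variance of the transformed samples match `mu` and `exp(log_var)`, so $\boldsymbol{z}$ follows the intended normal distribution even though all of the randomness enters through $\boldsymbol{\epsilon}$.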

For the VAE model there is a modification that we can make to reduce the variance of the Monte Carlo gradients. We first re-write the original ELBO in a different form:

\[\begin{align*}\text{ELBO}(\phi, \theta) &= \sum_{i=1}^n E_{\boldsymbol{z}_i \sim q} \left[ \log p_\theta(\boldsymbol{x}_i, \boldsymbol{z}_i) - \log q( \boldsymbol{z}_i \mid \boldsymbol{x}_i) \right] \\ &= \sum_{i=1}^n \int q(\boldsymbol{z}_i \mid \boldsymbol{x}_i) \left[\log p_\theta(\boldsymbol{x}_i, \boldsymbol{z}_i) - \log q(\boldsymbol{z}_i \mid \boldsymbol{x}_i) \right] \ d\boldsymbol{z}_i \\ &= \sum_{i=1}^n \int q(\boldsymbol{z}_i \mid \boldsymbol{x}_i) \left[\log p_\theta(\boldsymbol{x}_i \mid \boldsymbol{z}_i) + \log p(\boldsymbol{z}_i) - \log q(\boldsymbol{z}_i \mid \boldsymbol{x}_i) \right] \ d\boldsymbol{z}_i \\ &= \sum_{i=1}^n E_{\boldsymbol{z}_i \sim q} \left[ \log p_\theta(\boldsymbol{x}_i \mid \boldsymbol{z}_i) \right] + \sum_{i=1}^n \int q(\boldsymbol{z}_i \mid \boldsymbol{x}_i) \left[\log p(\boldsymbol{z}_i) - \log q(\boldsymbol{z}_i \mid \boldsymbol{x}_i) \right] \ d\boldsymbol{z}_i \\ &= \sum_{i=1}^n E_{\boldsymbol{z}_i \sim q} \left[ \log p_\theta(\boldsymbol{x}_i \mid \boldsymbol{z}_i) \right] + \sum_{i=1}^n E_{\boldsymbol{z}_i \sim q}\left[ \log \frac{ p(\boldsymbol{z}_i)}{q(\boldsymbol{z}_i \mid \boldsymbol{x}_i)} \right] \\ &= \sum_{i=1}^n E_{\boldsymbol{z}_i \sim q} \left[ \log p_\theta(\boldsymbol{x}_i \mid \boldsymbol{z}_i) \right] - \sum_{i=1}^n KL(q(\boldsymbol{z}_i \mid \boldsymbol{x}_i) \ || \ p(\boldsymbol{z}_i)) \end{align*}\]Recall the VAEs we have considered in this blog post have defined $p(\boldsymbol{z})$ to be the standard normal distribution $N(\boldsymbol{0}, \boldsymbol{I})$. In this particular case, it turns out that the KL-divergence term above can be expressed analytically (see the Appendix to this post):

\[KL(q_\phi(\boldsymbol{z}_i \mid \boldsymbol{x}_i) \mid\mid p(\boldsymbol{z}_i)) = -\frac{1}{2} \sum_{j=1}^J \left(1 + h^{(2)}_\phi(\boldsymbol{x}_i)_j - \left(h^{(1)}_\phi(\boldsymbol{x}_i)\right)_j^2 - \exp(h^{(2)}_\phi(\boldsymbol{x}_i)_j) \right)\]Note above the KL-divergence is calculated by summing over each dimension in the latent space. The full ELBO is:

\[\begin{align*} \text{ELBO}(\phi, \theta) &= \frac{1}{n} \sum_{i=1}^n \left[\frac{1}{2} \sum_{j=1}^J \left(1 + h^{(2)}_\phi(\boldsymbol{x}_i)_j - \left(h^{(1)}_\phi(\boldsymbol{x}_i)\right)_j^2 - \exp(h^{(2)}_\phi(\boldsymbol{x}_i)_j) \right) + E_{\boldsymbol{z}_i \sim q_\phi(\boldsymbol{z}_i \mid \boldsymbol{x}_i)} \left[\log p_\theta(\boldsymbol{x}_i \mid \boldsymbol{z}_i) \right]\right] \end{align*}\]Then, we can apply the reparameterization trick to this formulation of the ELBO and derive the following Monte Carlo approximation:

\[\begin{align*}\text{ELBO}(\phi, \theta) &\approx \frac{1}{n} \sum_{i=1}^n \left[\frac{1}{2} \sum_{j=1}^J \left(1 + h^{(2)}_\phi(\boldsymbol{x}_i)_j - \left(h^{(1)}_\phi(\boldsymbol{x}_i)\right)_j^2 - \exp(h^{(2)}_\phi(\boldsymbol{x}_i)_j) \right) + \frac{1}{L} \sum_{l=1}^L \left[\log p_\theta(\boldsymbol{x}_i \mid h^{(1)}_\phi(\boldsymbol{x}_i) + \sqrt{\exp(h^{(2)}_\phi(\boldsymbol{x}_i))}\boldsymbol{\epsilon}_{i,l}) \right] \right]\end{align*}\]Though this equation looks daunting, the feature to notice is that it is differentiable with respect to both $\phi$ and $\theta$. Therefore, we can apply automatic differentiation to derive the gradients that are needed to perform stochastic gradient ascent!
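As a quick sanity check on the analytical KL-divergence term that appears in the ELBO above, the following sketch compares the closed-form expression to a plain Monte Carlo estimate of the same quantity. It uses NumPy rather than PyTorch, and the vectors `mu` and `logvar` are made-up values standing in for the encoder outputs $h^{(1)}_\phi(\boldsymbol{x}_i)$ and $h^{(2)}_\phi(\boldsymbol{x}_i)$; the function names are illustrative only.

```python
import numpy as np

def gaussian_kl_closed_form(mu, logvar):
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ), summed over latent dimensions."""
    return -0.5 * np.sum(1.0 + logvar - mu**2 - np.exp(logvar))

def gaussian_kl_monte_carlo(mu, logvar, num_samples, rng):
    """Estimate the same KL as the average of log q(z) - log p(z) over z ~ q."""
    std = np.exp(0.5 * logvar)
    z = mu + std * rng.standard_normal((num_samples, mu.size))
    log_q = -0.5 * np.sum(np.log(2 * np.pi) + logvar + ((z - mu) / std) ** 2, axis=1)
    log_p = -0.5 * np.sum(np.log(2 * np.pi) + z**2, axis=1)
    return np.mean(log_q - log_p)

rng = np.random.default_rng(0)
mu = np.array([0.5, -1.0, 0.0])       # stands in for h^(1)_phi(x_i)
logvar = np.array([0.0, -0.5, 0.3])   # stands in for h^(2)_phi(x_i)

kl_exact = gaussian_kl_closed_form(mu, logvar)
kl_mc = gaussian_kl_monte_carlo(mu, logvar, num_samples=200_000, rng=rng)
print(f"closed form: {kl_exact:.4f}, Monte Carlo: {kl_mc:.4f}")
```

The two numbers should agree up to sampling noise, which shrinks as `num_samples` grows.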

Lastly, one may ask: why does this stochastic gradient have lower variance than the version discussed previously? Intuitively, terms within the ELBO’s expectation are being “pulled out” and computed analytically (i.e., the KL-divergence). Since these terms are analytical, less of the estimate is determined by the variability from sampling each $\boldsymbol{\epsilon}_{i,l}$ and thus, there will be less overall variability.
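To illustrate this intuition numerically, here is a small sketch on a one-dimensional toy model. Everything in it is an assumption made for the example: the decoder is a fixed linear map with weight `w`, the values of `x`, `mu` and `logvar` are arbitrary, and the comparison is on the ELBO estimate itself rather than its gradient. It shows the effect in this particular configuration and should not be read as a general guarantee.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy one-dimensional model (all values are illustrative assumptions):
# prior p(z) = N(0, 1), decoder p(x | z) = N(w * z, sigma^2),
# variational posterior q(z | x) = N(mu, exp(logvar)).
x, w, sigma = 1.0, 1.0, 1.0
mu, logvar = 0.8, np.log(0.36)
std = np.exp(0.5 * logvar)

def log_normal(value, mean, var):
    """Log-density of a univariate normal distribution."""
    return -0.5 * (np.log(2 * np.pi * var) + (value - mean) ** 2 / var)

# One reparameterized sample z = mu + std * eps per estimate (L = 1)
eps = rng.standard_normal(100_000)
z = mu + std * eps

# Estimator A: every term inside the expectation is sampled
elbo_sampled = (
    log_normal(x, w * z, sigma**2) + log_normal(z, 0.0, 1.0) - log_normal(z, mu, std**2)
)

# Estimator B: the KL term is computed analytically; only log p(x | z) is sampled
kl = -0.5 * (1.0 + logvar - mu**2 - np.exp(logvar))
elbo_analytic_kl = log_normal(x, w * z, sigma**2) - kl

print(f"mean:     {elbo_sampled.mean():.3f} vs {elbo_analytic_kl.mean():.3f}")
print(f"variance: {elbo_sampled.var():.3f} vs {elbo_analytic_kl.var():.3f}")
```

Both estimators target the same expectation, so their means agree up to sampling noise, while the spread of the single-sample estimates differs.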

So far, we have described VAEs in the context of probabilistic modeling. That is, we have described how the VAE is a probabilistic model that describes each high-dimensional datapoint, $\boldsymbol{x}_i$, as being “generated” from a lower dimensional data point $\boldsymbol{z}_i$. This generating procedure utilizes a neural network to map $\boldsymbol{z}_i$ to the parameters of the distribution $\mathcal{D}$ required to sample $\boldsymbol{x}_i$. Moreover, we can infer the parameters and latent variables of this model via VI. To do so, we solve a sort of inverse problem in which we use a neural network to map each $\boldsymbol{x}_i$ into the parameters of the variational posterior distribution $q$ required to sample $\boldsymbol{z}_i$.

Now, what happens if we tie the variational posterior $q_\phi(\boldsymbol{z} \mid \boldsymbol{x})$ to the data generating distribution $p_\theta(\boldsymbol{x} \mid \boldsymbol{z})$? That is, given a data point $\boldsymbol{x}$, we first sample $\boldsymbol{z}$ from the variational posterior distribution,

\[\boldsymbol{z} \sim q_\phi(\boldsymbol{z} \mid \boldsymbol{x})\]then we generate a new data point, $\boldsymbol{x}'$, from $p_\theta(\boldsymbol{x} \mid \boldsymbol{z})$:

\[\boldsymbol{x}' \sim p_\theta(\boldsymbol{x} \mid \boldsymbol{z})\]We can visualize this process schematically below:

Notice the similarity of the above process to the standard autoencoder:

We see that a VAE performs the same sort of compression and decompression as a standard autoencoder! One can view the VAE as a “probabilistic” autoencoder. Instead of mapping each $\boldsymbol{x}_i$ directly to $\boldsymbol{z}_i$, the VAE maps $\boldsymbol{x}_i$ to a *distribution* over $\boldsymbol{z}_i$ from which $\boldsymbol{z}_i$ is *sampled*. This randomly sampled $\boldsymbol{z}_i$ is then used to parameterize the distribution from which $\boldsymbol{x}_i$ is sampled.
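The contrast between the two kinds of forward pass can be sketched in a few lines of NumPy. This is only an illustration under loud assumptions: randomly initialized linear maps (`W_mu`, `W_logvar`, `W_dec`, all hypothetical names) stand in for trained neural networks, and the decoder returns the mean of $p_\theta(\boldsymbol{x} \mid \boldsymbol{z})$ rather than a sample from it.

```python
import numpy as np

rng = np.random.default_rng(0)
x_dim, z_dim = 6, 2

# Randomly initialized linear maps stand in for trained networks (assumption)
W_mu = rng.standard_normal((z_dim, x_dim))             # plays the role of h^(1)_phi
W_logvar = 0.1 * rng.standard_normal((z_dim, x_dim))   # plays the role of h^(2)_phi
W_dec = rng.standard_normal((x_dim, z_dim))            # plays the role of f_theta

def autoencoder_pass(x):
    """Standard autoencoder: x is mapped to a single point z, then decoded."""
    z = W_mu @ x
    return W_dec @ z

def vae_pass(x, rng):
    """VAE: x is mapped to a distribution over z, z is sampled, then decoded."""
    mu, logvar = W_mu @ x, W_logvar @ x
    z = mu + np.exp(0.5 * logvar) * rng.standard_normal(z_dim)
    return W_dec @ z  # mean of p_theta(x | z)

x = rng.standard_normal(x_dim)
# Two autoencoder passes on the same x give identical outputs
print(np.allclose(autoencoder_pass(x), autoencoder_pass(x)))
# Two VAE passes on the same x differ because z is resampled each time
print(np.allclose(vae_pass(x, rng), vae_pass(x, rng)))
```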

Let’s take a closer look at the loss function for the VAE with our new perspective of VAEs as being probabilistic autoencoders. The (exact) loss function is the negative of the ELBO:

\[\begin{align*} \text{loss}_{\text{VAE}}(\phi, \theta) &= -\sum_{i=1}^n \left( E_{\boldsymbol{z}_i \sim q_\phi(\boldsymbol{z}_i \mid \boldsymbol{x}_i)} \left[\log p_\theta(\boldsymbol{x}_i \mid \boldsymbol{z}_i) \right] - KL(q_\phi(\boldsymbol{z}_i \mid \boldsymbol{x}_i) \ || \ p(\boldsymbol{z}_i)) \right) \end{align*}\]Notice there are two terms with opposite signs. The first term, which involves $\log p_\theta(\boldsymbol{x}_i \mid \boldsymbol{z}_i)$, can be seen as a **reconstruction loss** because it will push the model towards reconstructing the original $\boldsymbol{x}_i$ from its compressed representation, $\boldsymbol{z}_i$.

This can be made especially evident if our model assumes that $p_\theta(\boldsymbol{x}_i \mid \boldsymbol{z}_i)$ is a normal distribution. That is,

\[p_\theta(\boldsymbol{x}_i \mid \boldsymbol{z}_i) := N(f_\theta(\boldsymbol{z}_i), \sigma^2_{\text{decoder}}\boldsymbol{I})\]where $\sigma^2_{\text{decoder}}$ is the variance of the Gaussian noise around $f_\theta(\boldsymbol{z}_i)$. In this scenario, the VAE will attempt to minimize the squared error between the decoded $\boldsymbol{x}'_i$ and the original input data point $\boldsymbol{x}_i$. We can see this by writing out the analytical form of $\log p_\theta(\boldsymbol{x}_i \mid \boldsymbol{z}_i)$ and highlighting the squared error in red:

\[\begin{align*} \text{loss}_{\text{VAE}}(\phi, \theta) &= -\sum_{i=1}^n \left( E_{\boldsymbol{z}_i \sim q_\phi(\boldsymbol{z}_i \mid \boldsymbol{x}_i)} \left[\log p_\theta(\boldsymbol{x}_i \mid \boldsymbol{z}_i) \right] - KL(q_\phi(\boldsymbol{z}_i \mid \boldsymbol{x}_i) \ || \ p(\boldsymbol{z}_i)) \right) \\ &= -\sum_{i=1}^n \left( E_{\boldsymbol{z}_i \sim q_\phi(\boldsymbol{z}_i \mid \boldsymbol{x}_i)} \left[\log \frac{1}{\sqrt{2 \pi \sigma_{\text{decoder}}^2}} - \frac{ \color{red}{||\boldsymbol{x}_i - f_\theta(\boldsymbol{z}_i) ||_2^2}}{2 \sigma_{\text{decoder}}^2} \right] - KL(q_\phi(\boldsymbol{z}_i \mid \boldsymbol{x}_i) \ || \ p(\boldsymbol{z}_i)) \right) \end{align*}\]Recall that in the loss function for the standard autoencoder, we are also minimizing this squared error, as seen below:

\[\text{loss}_{AE} := \frac{1}{n} \sum_{i=1}^n \color{red}{||\boldsymbol{x}_i - f_\theta(h_\phi(\boldsymbol{x}_i)) ||_2^2}\]where $h_\phi$ is the encoding neural network and $f_\theta$ is the decoding neural network.

Thus, both the VAE and standard autoencoder will seek to minimize the squared error between the decoded data point $\boldsymbol{x}'_i$ and the original data point $\boldsymbol{x}_i$. In this regard, the two models are quite similar!
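The link between the Gaussian log-likelihood and the squared error can be checked directly. The sketch below is a NumPy illustration with made-up vectors: `close_recon` and `poor_recon` stand in for two possible decoder outputs $f_\theta(\boldsymbol{z}_i)$, and `gaussian_log_likelihood` is a helper written for this example. Note that for a vector-valued $\boldsymbol{x}_i$ the normalizing constant appears once per dimension, which does not affect the comparison because it is the same for every reconstruction.

```python
import numpy as np

def gaussian_log_likelihood(x, mean, sigma):
    """log N(x; mean, sigma^2 I) for a vector x."""
    d = x.size
    return -0.5 * d * np.log(2 * np.pi * sigma**2) - np.sum((x - mean) ** 2) / (2 * sigma**2)

rng = np.random.default_rng(0)
sigma = 1.0
x = rng.random(5)                                  # original data point
close_recon = x + 0.05 * rng.standard_normal(5)    # reconstruction with small perturbation
poor_recon = x + 0.50 * rng.standard_normal(5)     # reconstruction with large perturbation

sq_err_close = np.sum((x - close_recon) ** 2)
sq_err_poor = np.sum((x - poor_recon) ** 2)
nll_close = -gaussian_log_likelihood(x, close_recon, sigma)
nll_poor = -gaussian_log_likelihood(x, poor_recon, sigma)

# The two negative log-likelihoods differ only through the squared error,
# scaled by 1 / (2 * sigma^2)
print(nll_poor - nll_close)
print((sq_err_poor - sq_err_close) / (2 * sigma**2))
```

The two printed values coincide, which is the sense in which maximizing this log-likelihood and minimizing the squared error are the same objective for a fixed $\sigma_{\text{decoder}}$.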

In contrast to standard autoencoders, the VAE also has a KL-divergence term with opposite sign to the reconstruction loss term. Notice how this term will push the model to generate latent variables from $q_\phi(\boldsymbol{z}_i \mid \boldsymbol{x}_i)$ that follow the prior distribution, $p(\boldsymbol{z}_i)$, which in our case is a standard normal. We can think of this KL-term as a **regularization term** on the reconstruction loss. That is, the model seeks to reconstruct each $\boldsymbol{x}_i$; however, it also seeks to ensure that the latent $\boldsymbol{z}_i$’s are distributed according to a standard normal distribution!

If our generative model assumes that $p_\theta(\boldsymbol{x}_i \mid \boldsymbol{z}_i)$ is the normal distribution $N(f_\theta(\boldsymbol{z}_i), \sigma^2_{\text{decoder}}\boldsymbol{I})$, then the implementations of a standard autoencoder and a VAE are quite similar. To see this similarity, let’s examine the computation graph of the loss function that we would use to train each model. For the standard autoencoder, the computation graph looks like:

Below, we show Python code that defines and trains a simple autoencoder using PyTorch. This autoencoder has one fully connected hidden layer in the encoder and decoder. The function `train_model` accepts a numpy array, `X`, that stores the data matrix $X \in \mathbb{R}^{n \times J}$:

```
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset


# Define autoencoder model
class autoencoder(nn.Module):
    def __init__(
        self,
        x_dim,
        hidden_dim,
        z_dim=10
    ):
        super(autoencoder, self).__init__()
        # Define encoding layers
        self.enc_layer1 = nn.Linear(x_dim, hidden_dim)
        self.enc_layer2 = nn.Linear(hidden_dim, z_dim)
        # Define decoding layers
        self.dec_layer1 = nn.Linear(z_dim, hidden_dim)
        self.dec_layer2 = nn.Linear(hidden_dim, x_dim)

    def encoder(self, x):
        # Define encoder network
        x = F.relu(self.enc_layer1(x))
        z = F.relu(self.enc_layer2(x))
        return z

    def decoder(self, z):
        # Define decoder network
        output = F.relu(self.dec_layer1(z))
        output = F.relu(self.dec_layer2(output))
        return output

    def forward(self, x):
        # Define the full network
        z = self.encoder(x)
        output = self.decoder(z)
        return output


def train_model(X, learning_rate=1e-3, batch_size=128, num_epochs=15):
    # Create DataLoader object to generate minibatches
    X = torch.tensor(X).float()
    dataset = TensorDataset(X)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Instantiate model and optimizer
    model = autoencoder(x_dim=X.shape[1], hidden_dim=256, z_dim=50)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Define the loss function (squared error summed over the batch)
    def loss_function(output, x):
        recon_loss = F.mse_loss(output, x, reduction='sum')
        return recon_loss

    # Train the model
    for epoch in range(num_epochs):
        epoch_loss = 0
        for batch in dataloader:
            # Zero the gradients
            optimizer.zero_grad()
            # Get batch
            x = batch[0]
            # Forward pass
            output = model(x)
            # Calculate loss
            loss = loss_function(output, x)
            # Backward pass
            loss.backward()
            # Update parameters
            optimizer.step()
            # Add batch loss to epoch loss
            epoch_loss += loss.item()
        # Print the average loss per data point for this epoch
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss/len(X)}")
    return model
```

For a VAE that assumes $p_\theta(\boldsymbol{x}_i \mid \boldsymbol{z}_i)$ to be a normal distribution, the computation graph looks like:

One note: we’ve made a slight change to the notation used up to this point. Here, the output of the decoder, $\boldsymbol{x}'$, can be interpreted as the *mean* of the normal distribution, $p_\theta(\boldsymbol{x}_i \mid \boldsymbol{z}_i)$, rather than as a sample from this distribution.

To turn the autoencoder into a VAE, the PyTorch code is altered in the following ways:

- The loss function is modified to use the approximated ELBO rather than the mean-squared error alone
- In the forward pass for the VAE, there is an added step that randomly samples $\boldsymbol{\epsilon}_i$ in order to generate $\boldsymbol{z}_i$
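The sampling step in the second bullet can be checked in isolation. The sketch below is a NumPy stand-in for that step, with made-up values for the encoder outputs `mu` and `logvar`: it draws many reparameterized samples and confirms that their empirical mean and variance match $\boldsymbol{\mu}$ and $\exp(\log \boldsymbol{\sigma}^2)$.

```python
import numpy as np

rng = np.random.default_rng(0)

mu = np.array([1.0, -2.0])       # encoder output for the mean (illustrative values)
logvar = np.array([0.0, -1.0])   # encoder output for the log-variance (illustrative values)

def reparameterize(mu, logvar, rng, num_samples=1):
    """Sample z = mu + sigma * eps with eps ~ N(0, I) and sigma = exp(logvar / 2)."""
    eps = rng.standard_normal((num_samples, mu.size))
    return mu + np.exp(0.5 * logvar) * eps

z = reparameterize(mu, logvar, rng, num_samples=100_000)
print("sample mean:    ", z.mean(axis=0))
print("sample variance:", z.var(axis=0))
print("target variance:", np.exp(logvar))
```

Because the randomness enters only through `eps`, the sample is a deterministic, differentiable function of `mu` and `logvar`, which is what allows gradients to flow back into the encoder.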

Aside from those two differences, the two implementations are quite similar. Below, we show code implementing a simple VAE. Note that here we sample only one value of $\epsilon_{i}$ per data point (that is, $L := 1$). In practice, if the training set is large enough, a single Monte Carlo sample per data sample often suffices to achieve good performance.

```
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset


class VAE(nn.Module):
    def __init__(
        self,
        x_dim,
        hidden_dim,
        z_dim=10
    ):
        super(VAE, self).__init__()
        # Define encoding layers
        self.enc_layer1 = nn.Linear(x_dim, hidden_dim)
        self.enc_layer2_mu = nn.Linear(hidden_dim, z_dim)
        self.enc_layer2_logvar = nn.Linear(hidden_dim, z_dim)
        # Define decoding layers
        self.dec_layer1 = nn.Linear(z_dim, hidden_dim)
        self.dec_layer2 = nn.Linear(hidden_dim, x_dim)

    def encoder(self, x):
        # Map x to the mean and log-variance of q(z | x)
        x = F.relu(self.enc_layer1(x))
        # No activation on the outputs: the mean and log-variance
        # should be free to take any real value
        mu = self.enc_layer2_mu(x)
        logvar = self.enc_layer2_logvar(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        # Sample z = mu + sigma * eps, where eps ~ N(0, I)
        std = torch.exp(logvar / 2)
        eps = torch.randn_like(std)
        z = mu + std * eps
        return z

    def decoder(self, z):
        # Define decoder network
        output = F.relu(self.dec_layer1(z))
        output = F.relu(self.dec_layer2(output))
        return output

    def forward(self, x):
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        output = self.decoder(z)
        return output, z, mu, logvar


# Define the loss function
def loss_function(output, x, mu, logvar):
    # Squared error, divided by the number of examples in the batch
    recon_loss = F.mse_loss(output, x, reduction='sum') / x.shape[0]
    # Analytical KL-divergence between q(z | x) and N(0, I), summed over the batch
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + 0.002 * kl_loss


def train_model(
    X,
    learning_rate=1e-4,
    batch_size=128,
    num_epochs=15,
    hidden_dim=256,
    latent_dim=50
):
    # Define the VAE model
    model = VAE(x_dim=X.shape[1], hidden_dim=hidden_dim, z_dim=latent_dim)
    # Define the optimizer
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # Convert X to a PyTorch tensor
    X = torch.tensor(X).float()
    # Create DataLoader object to generate minibatches
    dataset = TensorDataset(X)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Train the model
    for epoch in range(num_epochs):
        epoch_loss = 0
        for batch in dataloader:
            # Zero the gradients
            optimizer.zero_grad()
            # Get batch
            x = batch[0]
            # Forward pass
            output, z, mu, logvar = model(x)
            # Calculate loss
            loss = loss_function(output, x, mu, logvar)
            # Backward pass
            loss.backward()
            # Update parameters
            optimizer.step()
            # Add batch loss to epoch loss
            epoch_loss += loss.item()
        # Print epoch loss
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss/len(X)}")
    return model
```

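Another quick point to note about the above implementation is that we weighted the KL-divergence term in the loss function by 0.002. This weighting can be interpreted as hard-coding the variance of the normal distribution $p_\theta(\boldsymbol{x}_i \mid \boldsymbol{z}_i)$: the variance sets the relative scale of the squared-error term and the KL-divergence term, so choosing a weight for the KL term amounts to choosing a value for the decoder's variance.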
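As a sketch of where this interpretation comes from, assume the Gaussian decoder with variance $\sigma^2_{\text{decoder}}$ and write $\beta$ for the weight on the KL term. Here $\beta$ is a symbol introduced only for this illustration, and the sketch ignores the code's division of the reconstruction term by the batch size. For a single data point and a single sampled $\boldsymbol{z}_i$, the negative ELBO is, up to an additive constant $C$ that does not depend on $\phi$ or $\theta$:

\[\begin{align*} -\text{ELBO}_i &\approx \frac{||\boldsymbol{x}_i - f_\theta(\boldsymbol{z}_i)||_2^2}{2\sigma^2_{\text{decoder}}} + KL(q_\phi(\boldsymbol{z}_i \mid \boldsymbol{x}_i) \ || \ p(\boldsymbol{z}_i)) + C \\ 2\sigma^2_{\text{decoder}} \left(-\text{ELBO}_i\right) &\approx ||\boldsymbol{x}_i - f_\theta(\boldsymbol{z}_i)||_2^2 + 2\sigma^2_{\text{decoder}} \, KL(q_\phi(\boldsymbol{z}_i \mid \boldsymbol{x}_i) \ || \ p(\boldsymbol{z}_i)) + 2\sigma^2_{\text{decoder}} C \end{align*}\]Multiplying a loss by a positive constant does not change its minimizer, so minimizing the squared error plus $\beta$ times the KL term plays the same role as fixing $\beta = 2\sigma^2_{\text{decoder}}$.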

Let’s run the code shown above on MNIST. MNIST is a dataset consisting of 28x28 pixel images of hand-written digits. We will use a latent representation of length 50 (that is, $\boldsymbol{z} \in \mathbb{R}^{50}$). Note that the code displayed previously implements a model that flattens each image into a vector and uses fully connected layers in both the encoder and decoder. For improved performance, one may instead want to use a convolutional neural network architecture, which has been shown to better model imaging data. Regardless, after enough training, the model was able to reconstruct images that were not included in the training data. Here’s an example of the VAE reconstructing an image of the digit “7” that it was not trained on:

As VAEs are generative models, we can use them to generate new data! To do so, we first sample $\boldsymbol{z} \sim N(\boldsymbol{0}, \boldsymbol{I})$ and then output $f_\theta(\boldsymbol{z})$. Here are a few examples of generated samples:

Lastly, let’s explore the latent space learned by the model. First, let’s take the images of a “3” and “7”, and encode them into the latent space by sampling from $q_\phi(\boldsymbol{z} \mid \boldsymbol{x})$. Let’s let $\boldsymbol{z}_1$ and $\boldsymbol{z}_2$ be the latent vectors for “3” and “7” respectively:

Then, let’s interpolate between $\boldsymbol{z}_1$ and $\boldsymbol{z}_2$ and for each interpolated vector, $\boldsymbol{z}'$, we’ll compute $f_\theta(\boldsymbol{z}')$. Interestingly, we see a smooth transition between these digits as the 3 sort of morphs into the 7:
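The interpolation step itself is simple to write down. The following is a minimal NumPy sketch that assumes a straight-line (linear) interpolation between the two latent vectors; the random vectors `z_1` and `z_2` are placeholders for the actual encodings of the “3” and the “7”, and `interpolate` is a helper name chosen for this example.

```python
import numpy as np

def interpolate(z_start, z_end, num_steps):
    """Return num_steps latent vectors evenly spaced on the line from z_start to z_end."""
    weights = np.linspace(0.0, 1.0, num_steps)[:, None]
    return (1.0 - weights) * z_start + weights * z_end

rng = np.random.default_rng(0)
z_1 = rng.standard_normal(50)   # placeholder for the encoding of the "3"
z_2 = rng.standard_normal(50)   # placeholder for the encoding of the "7"

path = interpolate(z_1, z_2, num_steps=8)
print(path.shape)  # one row per interpolated latent vector
```

Each row of `path` would then be passed through the trained decoder $f_\theta$ to produce one frame of the transition.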

There are three key advantages that VAEs provide over standard autoencoders that make VAEs the better choice for certain types of problems:

1. **Data generation:** As a generative model, a VAE provides a method for *generating* new samples. To generate a new sample, one first samples a latent variable from the prior: $\boldsymbol{z} \sim p(\boldsymbol{z})$. Then one samples a new data sample, $\boldsymbol{x}$, from $p_\theta(\boldsymbol{x} \mid \boldsymbol{z})$. We demonstrated this ability in the prior section when we used our VAE to generate new MNIST digits.

2. **Control of the latent space:** VAEs enable tighter control over the structure of the latent space. As we saw previously, the ELBO’s KL-divergence term will push the model’s encoder towards encoding samples such that their latent random variables are distributed like the prior distribution. In this post, we discussed models that use a standard normal distribution as a prior. In this case, the latent random variables will tend to be distributed like a standard normal distribution and thus, will group in a sort of spherical pattern in the latent space. Below, we show the distribution of the first latent variable for each of 1,000 MNIST test digits using the VAE described in the previous section (right). The orange line shows the density function of the standard normal distribution. We also show the joint distribution of the first and second latent variables (left) for each of these 1,000 MNIST test digits:

3. **Modeling complex distributions:** VAEs provide a principled way to learn low dimensional representations of data that are distributed according to more complicated distributions. That is, they can be used when $p_\theta(\boldsymbol{x}_i \mid \boldsymbol{z}_i)$ is more complicated than the Gaussian distribution-based model we have discussed so far. One example of an application of VAEs using non-Gaussian distributions comes from single-cell RNA-seq analysis in the field of genomics. As described by Lopez et al. (2018), scVI is a tool that uses a VAE to model vectors of counts that are assumed to be distributed according to a zero-inflated negative binomial distribution. That is, the distribution $p_\theta(\boldsymbol{x}_i \mid \boldsymbol{z}_i)$ is a zero-inflated negative binomial.

Although we will not go into depth on this model here (perhaps in a future blog post), it provides an example of how VAEs can be easily extended to model data with specific distributional assumptions. In a sense, VAEs are “modular”: you can pick and choose your distributions. As long as the likelihood function is differentiable with respect to $\theta$ and the variational distribution is differentiable with respect to $\phi$, you can fit the model by stochastic gradient ascent on the ELBO using the reparameterization trick!
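To make the “modular” idea concrete, here is a small NumPy sketch in which the reconstruction term of a single-sample ELBO estimate is a pluggable log-likelihood function. It is an illustration under stated assumptions, not how scVI is implemented: a Poisson likelihood is used as a simpler stand-in for count data instead of the zero-inflated negative binomial, the function names are hypothetical, and the encoder and decoder outputs are made-up numbers rather than the outputs of trained networks.

```python
import math
import numpy as np

def gaussian_log_lik(x, mean, sigma=1.0):
    """log N(x; mean, sigma^2 I): real-valued data."""
    return np.sum(-0.5 * np.log(2 * np.pi * sigma**2) - (x - mean) ** 2 / (2 * sigma**2))

def bernoulli_log_lik(x, prob):
    """Independent Bernoulli log-likelihood: binary data."""
    return np.sum(x * np.log(prob) + (1 - x) * np.log(1 - prob))

def poisson_log_lik(x, rate):
    """Independent Poisson log-likelihood: count data."""
    log_factorial = np.array([math.lgamma(k + 1) for k in x])
    return np.sum(x * np.log(rate) - rate - log_factorial)

def single_sample_elbo(x, decoder_output, log_lik, mu, logvar):
    """ELBO estimate for one data point: log-likelihood minus the analytic KL term."""
    kl = -0.5 * np.sum(1 + logvar - mu**2 - np.exp(logvar))
    return log_lik(x, decoder_output) - kl

# Illustrative encoder and decoder outputs (not from a trained model)
mu, logvar = np.array([0.2, -0.1]), np.array([-0.3, 0.1])
counts = np.array([0, 3, 1])
rates = np.array([0.5, 2.5, 1.0])   # decoder output interpreted as Poisson rates

print(single_sample_elbo(counts, rates, poisson_log_lik, mu, logvar))
```

Swapping `poisson_log_lik` for `gaussian_log_lik` or `bernoulli_log_lik` changes the distributional assumption on the data without touching the KL term.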

If we choose the variational distribution $q_\phi(\boldsymbol{z} \mid \boldsymbol{x})$ to be defined by the following generative process:

\[\begin{align*}\boldsymbol{\mu} &:= h^{(1)}_\phi(\boldsymbol{x}) \\ \log \boldsymbol{\sigma}^2 &:= h^{(2)}_\phi(\boldsymbol{x}) \\ \boldsymbol{z} &\sim N(\boldsymbol{\mu}, \text{diag}(\boldsymbol{\sigma}^2)) \end{align*}\]where $h^{(1)}$ and $h^{(2)}$ are the two encoding neural networks, and we choose the prior distribution over $\boldsymbol{z}$ to be a standard normal distribution:

\[p(\boldsymbol{z}) := N(\boldsymbol{0}, \boldsymbol{I})\]then the KL-divergence from $p(\boldsymbol{z})$ to $q_\phi(\boldsymbol{z} \mid \boldsymbol{x})$ is given by the following formula:

\[KL(q_\phi(\boldsymbol{z} \mid \boldsymbol{x}) \mid\mid p(\boldsymbol{z})) = -\frac{1}{2} \sum_{j=1}^J \left(1 + h^{(2)}_\phi(\boldsymbol{x})_j - \left(h^{(1)}_\phi(\boldsymbol{x})\right)_j^2 - \exp(h^{(2)}_\phi(\boldsymbol{x})_j) \right)\]where $J$ is the dimensionality of $\boldsymbol{z}$.

**Proof:**

First, let’s re-write the KL-divergence as follows:

\[\begin{align*}KL(q_\phi(\boldsymbol{z} \mid \boldsymbol{x}) || p(\boldsymbol{z})) &= \int q_\phi(\boldsymbol{z} \mid \boldsymbol{x}) \log \frac{q_\phi(\boldsymbol{z} \mid \boldsymbol{x})}{ p(\boldsymbol{z})} \ d\boldsymbol{z} \\ &= \int q_\phi(\boldsymbol{z} \mid \boldsymbol{x}) \log q_\phi(\boldsymbol{z} \mid \boldsymbol{x}) \ d\boldsymbol{z} - \int q_\phi(\boldsymbol{z} \mid \boldsymbol{x}) \log p(\boldsymbol{z}) \ d\boldsymbol{z}\end{align*}\]Let’s begin by computing the first term, $\int q_\phi(\boldsymbol{z} \mid \boldsymbol{x}) \log q_\phi(\boldsymbol{z} \mid \boldsymbol{x}) \ d\boldsymbol{z}$:

\[\begin{align*}\int q_\phi(\boldsymbol{z} \mid \boldsymbol{x}) \log q_\phi(\boldsymbol{z} \mid \boldsymbol{x}) \ d\boldsymbol{z} &= \int N(\boldsymbol{z}; \boldsymbol{\mu}, \text{diag}(\boldsymbol{\sigma}^2)) \log N(\boldsymbol{z}; \boldsymbol{\mu}, \text{diag}(\boldsymbol{\sigma}^2)) \ d\boldsymbol{z} \\ &= \int N(\boldsymbol{z}; \boldsymbol{\mu}, \text{diag}(\boldsymbol{\sigma}^2)) \sum_{j=1}^J \log N(z_j; \mu_j, \sigma^2_j) \ d\boldsymbol{z} \\ &= \int N(\boldsymbol{z}; \boldsymbol{\mu}, \text{diag}(\boldsymbol{\sigma}^2)) \sum_{j=1}^J \left[\log\left(\frac{1}{\sqrt{2 \pi \sigma_j^2}}\right) - \frac{1}{2} \frac{(z_j - \mu_j)^2}{\sigma_j^2} \right] \ d\boldsymbol{z} \\ &= -\frac{J}{2} \log(2 \pi) - \frac{1}{2}\sum_{j=1}^J \log \sigma_j^2 - \frac{1}{2} \sum_{j=1}^J \int N(\boldsymbol{z}; \boldsymbol{\mu}, \text{diag}(\boldsymbol{\sigma}^2))\frac{(z_j - \mu_j)^2}{\sigma_j^2} \ d\boldsymbol{z} \\ &= -\frac{J}{2} \log(2 \pi) - \frac{1}{2}\sum_{j=1}^J \log \sigma_j^2 - \frac{1}{2} \sum_{j=1}^J \int_{z_j} N(z_j; \mu_j, \sigma^2_j)\frac{(z_j - \mu_j)^2}{\sigma_j^2} \int_{\boldsymbol{z}_{i \neq j}} \prod_{i \neq j} N(z_i; \mu_i, \sigma^2_i) \ d\boldsymbol{z}_{i \neq j} \ dz_j \\ &= -\frac{J}{2} \log(2 \pi) - \frac{1}{2}\sum_{j=1}^J \log \sigma_j^2 - \frac{1}{2} \sum_{j=1}^J \int N(z_j; \mu_j, \sigma^2_j)\frac{(z_j - \mu_j)^2}{\sigma_j^2} \ dz_j && \text{Note 1} \\ &= -\frac{J}{2} \log(2 \pi) - \frac{1}{2}\sum_{j=1}^J \log \sigma_j^2 - \frac{1}{2} \sum_{j=1}^J \frac{1}{\sigma_j^2} \int N(z_j; \mu_j, \sigma^2_j)(z_j^2 - 2z_j\mu_j + \mu_j^2) \ dz_j \\ &= -\frac{J}{2} \log(2 \pi) - \frac{1}{2}\sum_{j=1}^J \log \sigma_j^2 - \frac{1}{2} \sum_{j=1}^J \frac{1}{\sigma_j^2} \left(E[z_j^2] - E[2z_j\mu_j] + \mu_j^2 \right) && \text{Note 2} \\ &= -\frac{J}{2} \log(2 \pi) - \frac{1}{2}\sum_{j=1}^J \log \sigma_j^2 - \frac{1}{2} \sum_{j=1}^J \frac{1}{\sigma_j^2} \left(\mu_j^2 + \sigma_j^2 - 2\mu_j^2 + \mu_j^2 \right) && \text{Note 3} \\ &= -\frac{J}{2} \log(2 \pi) - \frac{1}{2}\sum_{j=1}^J \log \sigma_j^2 - \frac{1}{2} \sum_{j=1}^J 1 \\ &= -\frac{J}{2} \log(2 \pi) - \frac{1}{2} \sum_{j=1}^J (1 + \log \sigma_j^2) \end{align*}\]Now, let us compute the second term, $\int q_\phi(\boldsymbol{z} \mid \boldsymbol{x}) \log p(\boldsymbol{z}) \ d\boldsymbol{z}$:

\[\begin{align*}\int q_\phi(\boldsymbol{z} \mid \boldsymbol{x}) \log p(\boldsymbol{z}) \ d\boldsymbol{z} &= \int N(\boldsymbol{z}; \boldsymbol{\mu}, \text{diag}(\boldsymbol{\sigma}^2)) \log N(\boldsymbol{z}; \boldsymbol{0}, \boldsymbol{I}) \ d\boldsymbol{z} \\ &= \int N(\boldsymbol{z}; \boldsymbol{\mu}, \text{diag}(\boldsymbol{\sigma}^2)) \sum_{j=1}^J \log N(z_j; 0, 1) \ d\boldsymbol{z} \\ &= \int N(\boldsymbol{z}; \boldsymbol{\mu}, \text{diag}(\boldsymbol{\sigma}^2)) \sum_{j=1}^J \left[\log \frac{1}{\sqrt{2\pi}} - \frac{1}{2} z_j^2\right] \ d\boldsymbol{z} \\ &= J \log \frac{1}{\sqrt{2\pi}} \int N(\boldsymbol{z}; \boldsymbol{\mu}, \text{diag}(\boldsymbol{\sigma}^2)) \ d\boldsymbol{z} - \frac{1}{2}\int N(\boldsymbol{z}; \boldsymbol{\mu}, \text{diag}(\boldsymbol{\sigma}^2)) \sum_{j=1}^J z_j^2 \ d\boldsymbol{z} \\ &= -\frac{J}{2} \log (2 \pi) - \frac{1}{2}\int N(\boldsymbol{z}; \boldsymbol{\mu}, \text{diag}(\boldsymbol{\sigma}^2)) \sum_{j=1}^J z_j^2 \ d\boldsymbol{z} \\ &= -\frac{J}{2} \log (2 \pi) - \frac{1}{2} \int \sum_{j=1}^J z_j^2 \prod_{j'=1}^J N(z_{j'}; \mu_{j'}, \sigma^2_{j'}) \ d\boldsymbol{z} \\ &= -\frac{J}{2} \log (2 \pi) - \frac{1}{2} \sum_{j=1}^J \int_{z_j} z_j^2 N(z_j; \mu_j, \sigma^2_j) \int_{\boldsymbol{z}_{i \neq j}} \prod_{i \neq j} N(z_i; \mu_i, \sigma^2_i) \ d\boldsymbol{z}_{i \neq j} \ dz_j \\ &= -\frac{J}{2} \log (2 \pi) - \frac{1}{2} \sum_{j=1}^J \int z_j^2 N(z_j; \mu_j, \sigma^2_j) \ dz_j && \text{Note 4} \\ &= -\frac{J}{2} \log (2 \pi) - \frac{1}{2} \sum_{j=1}^J \left( \mu_j^2 + \sigma_j^2 \right) && \text{Note 5} \end{align*}\]Subtracting the second term from the first, we arrive at the formula:

\[KL(q_\phi(\boldsymbol{z} \mid \boldsymbol{x}) || p(\boldsymbol{z})) = -\frac{1}{2} \sum_{j=1}^J \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma^2_j\right)\]**Note 1:** We see that $\int_{\boldsymbol{z}_{i \neq j}} \prod_{i \neq j} N(z_i; \mu_i, \sigma^2_i) \ d\boldsymbol{z}_{i \neq j}$ integrates to 1 since this is simply integrating the density function of a multivariate normal distribution.

**Note 2:** We see that $\int z_j^2 N(z_j; \mu_j, \sigma^2_j) \ dz_j$ is simply $E[z_j^2]$ where $z_j \sim N(z_j; \mu_j, \sigma^2_j)$. Similarly, $\int 2z_j \mu_j N(z_j; \mu_j, \sigma^2_j) \ dz_j$ is simply $E[2 z_j \mu_j]$.

**Note 3:** We use the equation for the variance of a random variable $X$:

\[\text{Var}(X) = E[X^2] - E[X]^2\]to see that

\[E[z_j^2] = \mu_j^2 + \sigma_j^2\]**Note 4:** See Note 1.

**Note 5:** See Notes 2 and 3.

$\square$
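For readers who would like to double-check the two intermediate results of the proof numerically, here is a short Monte Carlo sketch in NumPy. The values of `mu` and `sigma2` are arbitrary choices made for the example; the sketch estimates each integral by averaging over samples from $q_\phi(\boldsymbol{z} \mid \boldsymbol{x})$ and compares the estimates to the closed forms derived above.

```python
import numpy as np

rng = np.random.default_rng(0)
mu = np.array([0.3, -0.7])      # arbitrary example means
sigma2 = np.array([0.5, 1.5])   # arbitrary example variances
J = mu.size

# Samples from q(z | x) = N(mu, diag(sigma2))
z = mu + np.sqrt(sigma2) * rng.standard_normal((200_000, J))

# Monte Carlo estimates of the two integrals
log_q = -0.5 * np.sum(np.log(2 * np.pi * sigma2) + (z - mu) ** 2 / sigma2, axis=1)
log_p = -0.5 * np.sum(np.log(2 * np.pi) + z**2, axis=1)

# Closed forms from the proof
first_term = -0.5 * J * np.log(2 * np.pi) - 0.5 * np.sum(1 + np.log(sigma2))
second_term = -0.5 * J * np.log(2 * np.pi) - 0.5 * np.sum(mu**2 + sigma2)

print(f"first term:  {log_q.mean():.4f} (Monte Carlo) vs {first_term:.4f} (closed form)")
print(f"second term: {log_p.mean():.4f} (Monte Carlo) vs {second_term:.4f} (closed form)")
```

The difference `first_term - second_term` reproduces the KL-divergence formula stated at the start of this appendix.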

]]>