<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" ><generator uri="https://jekyllrb.com/" version="3.10.0">Jekyll</generator><link href="https://mbernste.github.io/feed.xml" rel="self" type="application/atom+xml" /><link href="https://mbernste.github.io/" rel="alternate" type="text/html" /><updated>2026-02-11T04:55:56-08:00</updated><id>https://mbernste.github.io/feed.xml</id><title type="html">Matthew N. Bernstein</title><subtitle>Personal website</subtitle><author><name>Matthew N. Bernstein</name></author><entry><title type="html">A framework for making sense of metrics in technical organizations</title><link href="https://mbernste.github.io/posts/metrics/" rel="alternate" type="text/html" title="A framework for making sense of metrics in technical organizations" /><published>2026-02-10T00:00:00-08:00</published><updated>2026-02-10T00:00:00-08:00</updated><id>https://mbernste.github.io/posts/metrics</id><content type="html" xml:base="https://mbernste.github.io/posts/metrics/"><![CDATA[<p><em>If you work in a quantitative or technical field, there is little doubt that you or your team has worked long and hard to define which metrics to measure and track. Using data-driven metrics is a critical practice for making rational decisions and deciphering truth in a complex and noisy world. However, as others have pointed out, an over-reliance on metrics can lead to poor outcomes. In an effort to better articulate the value and risks inherent in metrics, in this blog post, I will present a mental framework for thinking about metrics that has helped me reason about their value and risks with a bit more clarity.</em></p>

<h2 id="introduction">Introduction</h2>

<p>For good reason, the importance of data-driven reasoning is deeply ingrained in the culture of quantitative and technical disciplines. Moreover, the systems we both build and operate within are too complex to be fully grasped by the human mind. To understand them, we must measure them. Together, our secular cultural traditions and the simple need to understand complex systems lead technical organizations to fixate on <strong>metrics</strong>. This fixation shows up everywhere. Managers repeat quotes like, “You can’t manage what you can’t measure.” Software engineers create dashboards and databases to track metrics. Meetings often begin with an overview of where the project or business stands in terms of the metrics.</p>

<p>I would by no means be the first to point out that an over-reliance on metrics can lead to poor outcomes. For example, <a href="https://en.wikipedia.org/wiki/Goodhart%27s_law">Goodhart’s Law</a> states, “When a measure becomes a target, it ceases to be a good measure.” <a href="https://www.youtube.com/watch?v=8ij964FCQiw">In a recent interview</a>, Jeff Bezos describes how organizations often end up mistaking metrics meant to measure some ground truth for the actual thing. Treating such a “proxy for truth” as the truth itself leads organizations astray.</p>

<p>Metrics are therefore tricky: on the one hand, they are critical for navigating a complex world; on the other hand, they can lead organizations astray. In my experience, organizations struggle to wield metrics well and to think about them clearly. To better articulate the value and risks inherent in metrics, I will present in this blog post a mental framework that has helped me reason about them with a bit more clarity.</p>

<p>This framework entails first treating systems as high-dimensional objects in some abstract space. A metric, then, is a function that maps those systems to numbers. With this framing, we can categorize metrics into two groups: exploratory metrics and those that are used to approximate value. Because a value-approximating metric is merely an approximation, one must take care to clearly understand the behavior of that approximation. In some extreme cases, it might be better to admit that no metric can serve as a good approximator of value.</p>

<h2 id="systems-as-high-dimensional-objects">Systems as high-dimensional objects</h2>

<p>Before we get to discussing metrics, we will first generalize the system being measured as a <a href="https://mbernste.github.io/posts/intrinsic_dimensionality/">high-dimensional</a> object in some abstract space (akin to a <a href="https://mbernste.github.io/posts/vector_spaces/">vector space</a>). By “system”, I mean any complicated thing that a given technical organization seeks to understand or improve.</p>

<p>For example, such a system might be an entire business; businesses are complicated “high-dimensional” objects in that they have many components: employees, processes, capital, debt, revenue, and so on. A piece of technology is also such a system. For example, an algorithm for automating medical diagnoses has many aspects: model complexity, latency, lines of code, bias, etc.</p>

<p>In a very abstract way, one can imagine that any given system resides in a “space” comprising other similar systems. For simplicity, let’s take a business: We can summarize a business in terms of a large list of numbers like number of employees, sales per month, cost of goods sold, cash on hand, debt… (the list can go on and on). Given such a list (which could be extremely long), we can place the business in a coordinate vector space where each location in the vector space is some (possibly non-existent) business. A schematic with three dimensions is shown below:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/business_as_vectors_metrics.png" alt="drawing" width="550" /></center>

<p>We will denote this space of possible systems (e.g., businesses) as $\mathcal{X}$. A given system $x$ is a member of $\mathcal{X}$, denoted $x \in \mathcal{X}$.</p>

<h2 id="value-functions-tell-you-how-good-a-system-is">Value functions tell you how “good” a system is</h2>

<p>We consider cases in which the goal of an organization is either to improve some system under measurement, $x \in \mathcal{X}$, or, at the very least, to assess how “good” the system is in terms of some subjective or economic measure of value.</p>

<p>Let’s define a <strong>value function</strong> to be a function $V$ that maps systems in $\mathcal{X}$ to real numbers that quantify the value of those systems:</p>

\[V : \mathcal{X} \rightarrow \mathbb{R}\]

<p>That is, $V(x)$ tells us how much to value system $x$. If we have two systems, $x_1$ and $x_2$, then $V(x_1) &gt; V(x_2)$ tells us we should prefer $x_1$ to $x_2$. We can depict this schematically in a small, two-dimensional space of systems with a heatmap. Each dot in the figure below is a system. The color tells us how much value that system has:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/value_function_example.png" alt="drawing" width="475" /></center>

<p><br /></p>

<p>Organizations that seek to maximize $V$ are, in a sense, performing a form of “gradient ascent” over $V$:</p>

\[x' \leftarrow x + \eta \nabla V(x)\]

<p>where $\eta &gt; 0$ is a step size and the goal is to make iterative progress along $V$. Below is a toy example of a system undergoing iteration as it progresses along the surface of the value function:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/optimizing_value_function_metric.png" alt="drawing" width="400" /></center>

<p><br /></p>
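<p>The update rule above can be sketched in a few lines of Python. This is a toy illustration, not a model of any real organization: the two-dimensional value function <code>V</code>, its peak at $(1, 2)$, the starting point, and the step size <code>eta</code> are all hypothetical choices made so that the gradient can be written down by hand.</p>

```python
import numpy as np

# Hypothetical toy value function over a 2-D space of systems:
# V(x) = -(x_1 - 1)^2 - (x_2 - 2)^2, which is highest at x = (1, 2).
def V(x: np.ndarray) -> float:
    return -((x[0] - 1.0) ** 2) - ((x[1] - 2.0) ** 2)

def grad_V(x: np.ndarray) -> np.ndarray:
    # Analytic gradient of the toy value function above.
    return np.array([-2.0 * (x[0] - 1.0), -2.0 * (x[1] - 2.0)])

def gradient_ascent(x0, grad, eta=0.1, n_steps=100):
    """Repeatedly apply x' = x + eta * grad(x) and return every visited system."""
    x = np.asarray(x0, dtype=float)
    trajectory = [x.copy()]
    for _ in range(n_steps):
        x = x + eta * grad(x)
        trajectory.append(x.copy())
    return np.array(trajectory)

traj = gradient_ascent([-2.0, -1.0], grad_V)
print(traj[-1])                  # close to [1, 2], the peak of V
print(V(traj[0]), V(traj[-1]))   # value rises from -18 toward 0
```

<p>In practice, of course, an organization never has an analytic $\nabla V$; it only has guesses about which changes will help, which is exactly where metrics come in.</p>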

<p>In some situations, the value function is obvious. For example, when considering a business, the value function might simply be its <em>profitability</em>. However, in other cases, the value function is not so easy to define. A prime example of such a case is judging the aesthetic value of artistic works.</p>

<p>I would also argue that even in technical fields, where the system under study <em>should</em> admit an easy-to-define value function, it is often not as clear-cut as one would hope. Take for example an algorithm for detecting tumors in medical imaging data. As <a href="https://laurenoakdenrayner.com/2019/01/21/medical-ai-safety-doing-it-wrong/">some in this field would note</a>, metrics like accuracy are not the true value function. In truth, the value function should track something like, “expected benefit versus harm” when employing the algorithm in a real clinical context, but that is not trivial to define!</p>

<h2 id="metrics-are-functions-that-can-be-either-exploratory-or-value-approximating">Metrics are functions that can be either exploratory or value-approximating</h2>

<p>Using this same framework, we can define a <strong>metric</strong> to be some function that, like $V$, projects a system, $x \in \mathcal{X}$, to a number. That is, we can define a metric, $f$, to be a function,</p>

\[f: \mathcal{X} \rightarrow \mathbb{R}\]

<p>Metrics fall into two fundamentally different categories: those that are intended to approximate the value function $V$, and those that are intended purely to explore and describe the structure of $\mathcal{X}$. These two uses impose very different requirements. A value-approximating metric makes an implicit claim about the relationship between $f$ and $V$. To rely on such a metric, one must understand where this approximation holds. Exploratory metrics, by contrast, make no claim about value at all. I believe that making a clear distinction between these kinds of metrics can bring clarity to discussions around them. We describe these two categories in the following sections.</p>

<h3 id="exploratory-metrics-those-that-seek-to-describe-mathcalx">Exploratory metrics: Those that seek to describe $\mathcal{X}$</h3>

<p>In many cases, organizations do not necessarily seek to approximate the value function, but simply seek to understand their system. To gain this understanding, they will create a collection of metrics, $f_1, f_2, …, f_M$, each describing some specific aspect of the system. In this way, these metrics project the complex, high-dimensional system into a space of only a few dimensions. (In this sense, the metrics act as a form of <a href="https://mbernste.github.io/posts/dim_reduc/">dimensionality reduction</a>.)</p>
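<p>As a rough sketch of this projection, the snippet below describes a business by a handful of hypothetical attributes (the field names and numbers are invented for illustration) and applies three hypothetical exploratory metrics, each a function from the system to a single number. Note that some attributes are not used by any metric at all: that information is simply dropped by the projection.</p>

```python
# Hypothetical high-dimensional description of a business (in practice far longer).
business = {
    "employees": 120,
    "monthly_sales": 2_400_000.0,
    "cost_of_goods_sold": 1_500_000.0,
    "cash_on_hand": 3_000_000.0,
    "debt": 1_000_000.0,
    "monthly_operating_expenses": 600_000.0,
}

# Each exploratory metric is a function f_i : X -> R summarizing one aspect.
exploratory_metrics = {
    "gross_margin": lambda b: (b["monthly_sales"] - b["cost_of_goods_sold"]) / b["monthly_sales"],
    "sales_per_employee": lambda b: b["monthly_sales"] / b["employees"],
    "debt_to_cash": lambda b: b["debt"] / b["cash_on_hand"],
}

# Applying f_1, ..., f_M projects the system onto an M-dimensional summary.
summary = {name: f(business) for name, f in exploratory_metrics.items()}
print(summary)
```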

<p>The goal here is to gain a mechanistic, <a href="https://mbernste.github.io/posts/understanding_3d/">holistic understanding</a> of the system, which may lead to new insights into how to improve it downstream of these metrics. The goal is not to track whether the system is improving over time.</p>

<p>Exploratory metrics are almost always critical for understanding a system, though I would note that numbers alone may not be sufficient. Sometimes, one must also understand the geometric relationships between these metrics, which are better grasped via visualization. Said differently, tracking the metrics alone may not yield adequate understanding; that understanding may only come from synthesizing the metrics into a comprehensible visual format.</p>

<p>One final point on this topic: I believe it is important to clearly delineate whether a metric under consideration is an exploratory metric or a value-approximating metric. As soon as one modifies the system to optimize some exploratory metric, $f(x)$, that metric implicitly moves from the exploratory category to the value-approximating category.</p>

<h3 id="value-approximating-metrics-those-that-seek-to-approximate-v">Value-approximating metrics: Those that seek to approximate $V$</h3>

<p>A value-approximating metric is a metric, $f(x)$, that is treated as a proxy for the value function, $V(x)$. Unlike exploratory metrics, which are merely observed, value-approximating metrics are <em>acted upon</em>. More specifically, an organization <em>changes</em> the system, $x$, in the direction of the gradient of $f$. That is, the gradient of $f$ is used as a proxy for the gradient of $V$:</p>

\[\nabla f(x) \approx \nabla V(x)\]
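<p>One way to make "$\nabla f(x) \approx \nabla V(x)$" concrete, assuming both functions could be evaluated (which is rarely true of $V$ in practice), is to compare the directions of the two gradients with cosine similarity. The sketch below estimates gradients by central finite differences; the toy <code>V</code> and <code>f</code> are hypothetical, and the helper names are invented for illustration.</p>

```python
import numpy as np

def numerical_gradient(func, x, h=1e-6):
    """Central finite-difference estimate of the gradient of func at x."""
    x = np.asarray(x, dtype=float)
    grad = np.zeros_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = h
        grad[i] = (func(x + step) - func(x - step)) / (2.0 * h)
    return grad

def gradient_alignment(f, V, x):
    """Cosine similarity between grad f(x) and grad V(x).

    A value near 1 means following f also increases V at x; a value near 0 or
    below means following f no longer moves the system toward higher value.
    (Undefined if either gradient is exactly zero.)
    """
    gf = numerical_gradient(f, x)
    gv = numerical_gradient(V, x)
    return float(gf @ gv / (np.linalg.norm(gf) * np.linalg.norm(gv)))

# Hypothetical example: the true value peaks at (1, 2); the metric rewards x_1 only.
def V(x):
    return -((x[0] - 1.0) ** 2) - ((x[1] - 2.0) ** 2)

def f(x):
    return x[0]

print(gradient_alignment(f, V, [-2.0, 2.0]))  # approximately 1: f points the way V does
print(gradient_alignment(f, V, [3.0, 2.0]))   # approximately -1: f points away from V
```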

<p>Thus, it is important that one clearly knows whether or not $f$ is a good approximation for $V$. In the following sections, I will briefly describe two common ways in which $f$ may deviate from $V$.</p>

<p><strong>Locally accurate, but globally inaccurate</strong></p>

<p>In this scenario, the metric $f$ is very close to $V$ in some local neighborhood around $x$; however, as the organization optimizes for $f$, it pushes the system into regions where $f$ is no longer a good approximation. This is a form of Goodhart’s Law; once one pushes $x$ too far along $f$, then $f$ ceases to be a good measure for $V$.</p>

<p>This is illustrated schematically in the figure below. In the top left figure, we show the true value function, $V$, and a trajectory we would take if we were optimizing with respect to it. In the top right figure, we show a metric function, $f$, and the trajectory we would take if we were optimizing $f$ instead. In the bottom figure, we superimpose the two trajectories (blue = $f$ and orange = $V$). As you can see, the two trajectories start off very close to one another, but then diverge as $x$ is optimized using $f$:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/goodharts_law_plot_metrics.png" alt="drawing" width="550" /></center>
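<p>The same phenomenon can be reproduced in one dimension with a hypothetical pair of functions: a true value $V(x) = -(x - 1)^2$ that peaks at $x = 1$, and a metric $f(x) = x$ that agrees with $V$ on the direction of improvement for every $x$ below 1 but keeps rewarding larger $x$ forever. This is a minimal sketch of the idea, not a reconstruction of the functions used in the figure.</p>

```python
# Hypothetical 1-D illustration of "locally accurate, globally inaccurate".
def V(x):
    # True value: peaks at x = 1.
    return -(x - 1.0) ** 2

def dV(x):
    return -2.0 * (x - 1.0)

def df(x):
    # Derivative of the proxy metric f(x) = x: always says "increase x".
    return 1.0

def optimize(grad, x0=0.0, eta=0.1, n_steps=50):
    """Gradient ascent in 1-D; returns the list of visited points."""
    xs = [x0]
    for _ in range(n_steps):
        xs.append(xs[-1] + eta * grad(xs[-1]))
    return xs

xs_V = optimize(dV)  # follows the true value function
xs_f = optimize(df)  # follows the proxy metric
for step in (0, 5, 10, 25, 50):
    # Early on both trajectories raise V; later, chasing f drives V far down.
    print(step, round(V(xs_V[step]), 3), round(V(xs_f[step]), 3))
```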

<p>A prime example of this scenario occurred recently in the <a href="https://virtualcellchallenge.org">Virtual Cell Challenge</a> held by the <a href="https://arcinstitute.org">Arc Institute</a>. In this challenge, research groups competed to develop an AI model for predicting the effects of genetically or chemically perturbing biological cells; that is, the goal was to predict how a cell will react to a chemical or genetic alteration. Groups found that they could <a href="https://gmdbioinformatics.substack.com/p/arc-virtual-cell-challenge-has-the">game the metrics</a> by applying absurd transformations to the data. These metrics, while perhaps a good proxy for the value function within the region of $\mathcal{X}$ for which the Arc Institute designed the challenge, were a poor proxy in distal regions of the space.</p>

<p><strong>Monotonic with respect to $V$</strong></p>

<p>In this situation, the metric $f$ increases “monotonically” with respect to $V$. By this I mean that whenever $f$ increases, so too does $V$, and whenever $f$ decreases, so too does $V$. An example is depicted below:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/monotonic_metric_example1.png" alt="drawing" width="350" /></center>

<p><br /></p>

<p>Notably, however, the two functions may not have the same shape. Rather, the only thing that matters is whether $f$ follows the same direction as $V$. For example, one function may be flattening while the other is accelerating upward, but both are increasing:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/monotonic_metric.png" alt="drawing" width="350" /></center>

<p><br /></p>

<p>When one has such a metric $f$ that is monotonic with respect to $V$, one can use $f$ to reliably determine whether or not one is improving the system, $x$. That is, if one is considering two systems, $x_1$ and $x_2$, for which $f(x_1) &gt; f(x_2)$, then one can be sure that $V(x_1) &gt; V(x_2)$ and therefore one should prefer $x_1$.</p>

<p>However, critically, there are two tasks that one <strong>should not</strong> rely on $f$ for:</p>

<ul>
  <li>One should not use $f$ to determine <em>how much</em> the given system improved. It may be that $f(x_1) &gt; f(x_2)$, yet the difference $f(x_1) - f(x_2)$ may not closely track $V(x_1) - V(x_2)$. One may have improved the system a great deal with respect to $f$ while, when it comes to real value, $V$, the needle has barely moved.</li>
  <li>One should not use $f$ to determine success criteria. While $f$ tracks $V$ directionally, it does not tell you anything about the absolute magnitude of $V$. One may have improved the system quite a lot with respect to $f$, but it may still not be good enough to release into the world!</li>
</ul>
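<p>Both caveats can be seen in a small sketch that assumes, purely for illustration, that true value is a saturating but strictly increasing function of the metric, $V = 1 - e^{-f}$. The systems <code>x1</code>, <code>x2</code>, <code>x3</code> and their metric scores are hypothetical. Rankings by $f$ and by $V$ agree, but equal-looking or larger gains in $f$ do not translate into proportional gains in $V$.</p>

```python
import math

# Hypothetical monotonic relationship: value saturates as the metric grows.
# Any strictly increasing transform would preserve rankings in the same way.
def value_from_metric(f_score):
    return 1.0 - math.exp(-f_score)

metric = {"x1": 5.0, "x2": 1.0, "x3": 0.5}
value = {name: value_from_metric(s) for name, s in metric.items()}

# 1) Rankings agree: sorting by f gives the same order as sorting by V.
rank_by_f = sorted(metric, key=metric.get, reverse=True)
rank_by_V = sorted(value, key=value.get, reverse=True)
print(rank_by_f, rank_by_V)

# 2) Magnitudes do not: the x2 -> x1 gain in f is 8x the x3 -> x2 gain,
#    but the corresponding gain in V is only about 1.5x as large.
print(metric["x1"] - metric["x2"], round(value["x1"] - value["x2"], 3))
print(metric["x2"] - metric["x3"], round(value["x2"] - value["x3"], 3))
```

<p>Note also that in this toy setup $V$ never exceeds 1 no matter how large $f$ grows, so no threshold on $f$ alone says whether $V$ is "good enough" without knowing the mapping between them.</p>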

<p>I would argue that this regime of monotonicity is very common. Take our usual example of an algorithm for detecting tumors in medical imaging data. In this scenario, one might use simple <em>accuracy</em> as the metric of choice. Clearly, improving the algorithm in terms of accuracy will lead to a more valuable algorithm. However, it is not clear how well accuracy tracks real benefit in a clinical setting. The two, while correlated, may not be correlated <em>enough</em> for accuracy to be relied upon in determining whether the algorithm is ready for real-world use.</p>

<h2 id="it-is-important-to-sample-the-value-function">It is important to sample the value function</h2>

<p>Because a value-approximating metric is merely an approximation, I assert that it is important to sample the true value function regularly – that is, to evaluate $V(x)$ for the current system, $x$. In practice, sampling the value function entails making an earnest, holistic assessment of the system’s value independently of any metric.</p>

<p>Sampling the value function might entail having experts perform subjective assessments of the system under consideration. This can be expensive since subjective assessment requires time and energy and is difficult to scale. Moreover, subjective assessments are harder to track and communicate than simple numbers. Nonetheless, sometimes this is the only way.</p>

<p>In other cases, sampling the value function means employing the system in real environments to see how it performs. This can be extremely expensive as it may require conducting real-world studies. In our example of an algorithm for detecting tumors in medical images, <a href="https://www.nber.org/papers/w31422">one study</a> found that, “human-AI collaboration using an information experiment with professional radiologists… show that providing AI predictions does not improve performance on average.” There is little doubt that the original scientific articles reporting on these algorithms showed many metrics demonstrating good performance; however, in the context of real-world use, they demonstrated less benefit.</p>

<p>Lastly, with the rise of capable <a href="https://en.wikipedia.org/wiki/AI_agent">agentic AI systems</a>, one can possibly test systems, like new algorithms, in the hands of AI agents within sandboxed environments. For example, a research group at Stanford is exploring using AI agents to <a href="https://hai.stanford.edu/assets/files/hai-policy-brief-simulating-human-behavior-with-ai-agents.pdf">simulate human behavior</a> in order to serve as a tool for conducting things like market or policy research. While this line of thinking is only in its infancy, I think agentic AI will emerge as a new tool that can be creatively employed for more cheaply sampling value functions.</p>

<h2 id="not-all-systems-admit-a-value-approximating-metric-that-can-be-used-for-defining-success-criteria">Not all systems admit a value-approximating metric that can be used for defining success criteria</h2>

<p>Sometimes, because the value function is so complex, it may simply not be possible to develop a metric that can be relied upon for things like devising success criteria. That is, it is not possible to know, based on the metric alone, whether the system under consideration is ready for real-world deployment. I believe that this situation is not uncommon. After all, many “functions” or “mappings” in the real world are incredibly noisy, unintuitive, or non-linear; why would value functions be any different?</p>

<p>When one is confronted with an intractable value function, it is often wiser to admit this outright than to spend valuable time and energy chasing it. <strong>Sometimes, one just has to admit defeat: We can’t quantify “good” even though we know it when we see it.</strong></p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/metric_meme.png" alt="drawing" width="500" /></center>

<p><br /></p>

<p>This is indeed a challenging situation in which to find oneself. It means that one cannot rely upon an easy, automated way to assess the state of the system and to make decisions. One may feel lost at sea without a compass! I would argue, however, that this situation <em>can</em> be navigated; but to do so, one must first acknowledge that one is lost! Once the limitations are acknowledged, one can devise a plan and allocate a budget for sampling the value function.</p>]]></content><author><name>Matthew N. Bernstein</name></author><category term="metrics" /><category term="leadership" /><summary type="html"><![CDATA[If you work in a quantitative or technical field, there is little doubt that you or your team has worked long and hard to define which metrics to measure and track. Using data-driven metrics is a critical practice for making rational decisions and deciphering truth in a complex and noisy world. However, as others have pointed out, an over-reliance on metrics can lead to poor outcomes. In an effort to better articulate the value and risks inherent in metrics, in this blog post, I will present a mental framework for thinking about metrics that has helped me reason about their value and risks with a bit more clarity.]]></summary></entry><entry><title type="html">Understanding attention</title><link href="https://mbernste.github.io/posts/attention/" rel="alternate" type="text/html" title="Understanding attention" /><published>2025-12-21T00:00:00-08:00</published><updated>2025-12-21T00:00:00-08:00</updated><id>https://mbernste.github.io/posts/attention</id><content type="html" xml:base="https://mbernste.github.io/posts/attention/"><![CDATA[<p><em>Attention is a type of layer in a neural network that is widely regarded to be one of the most important breakthroughs that enabled the development of modern AI systems and large language models. At its heart, attention is a mechanism for explicitly drawing relationships between items in a set. 
In natural language processing, the set being processed consists of words (or tokens), and attention enables the model to relate those words to one another even when they lie far away from each other in the body of text. In this blog post, we will step through the attention mechanism both mathematically and intuitively. We then present a minimal example of a neural network that uses attention to perform binary classification in a task that is not solvable using a naïve bag-of-words model.</em></p>

<h2 id="introduction">Introduction</h2>

<p>Attention is a type of neural network layer that has powered the development of modern AI. It was introduced in its modern form by Vaswani <em>et al.</em> (2017) in their landmark paper, <em><a href="https://papers.nips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf">Attention Is All You Need</a></em>.</p>

<p>Attention was developed in the context of <a href="https://en.wikipedia.org/wiki/Large_language_model">language modeling</a> and is often introduced as a mechanism for a neural network to identify how different words of a sentence relate to one another. For example, take the sentence, “I like sushi because it makes me happy.” Attention may enable the model to explicitly and dynamically recognize that the word “it” in this sentence is referring to “sushi”. Similarly, it may enable the model to recognize that the words “me” and “I” are related in that they both refer to the same entity (i.e., the speaker of the sentence).</p>

<p>While attention is most often explained in the context of language modeling, the idea is far more general: It is simply a way to explicitly draw relationships between items in a set. In this blog post, we will step through the attention mechanism both mathematically and intuitively with a focus on how attention is, at its core, a way to relate items of a set together. We will then present a minimal example of a neural network that uses attention to perform binary classification in a task that is not solvable using a naïve <a href="https://en.wikipedia.org/wiki/Bag-of-words_model">bag-of-words</a> model.</p>

<h2 id="inputs-and-outputs-of-the-attention-layer">Inputs and outputs of the attention layer</h2>

<p>At its core, an attention layer is a layer of a neural network that transforms an input set of vectors to a set of output vectors. This contrasts with a traditional fully-connected neural layer which transforms a single input vector to an output vector:</p>

<p><br /></p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/attention_input_output.png" alt="drawing" width="700" /></center>

<p><br /></p>

<p>In most contexts in which attention is employed, the input vectors represent items in a sequence, such as words in natural language text or nucleic acids in a DNA sequence. Each element of the input set is often referred to as a <strong>token</strong>. In this post, we will use natural language text as the primary example; however, the input set of vectors can extend beyond sequences, since nothing in the attention layer assumes an ordering over the tokens.</p>

<p>A powerful feature of the attention layer is that the size of the input set of vectors does not need to be fixed; it can be variable! This enables the attention layer to operate on sets of arbitrary size. This is similar to how a <a href="https://mbernste.github.io/posts/gcn/">graph convolutional neural network</a> can operate on graphs of arbitrary size.</p>
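<p>Both properties (variable set size and no built-in notion of order) can be checked with a minimal NumPy sketch of a single-head self-attention layer. It assumes the scaled-softmax normalization described later in this post, no masking, query and key vectors of dimension $D_{\text{out}}$, and weight matrices of shape $D_{\text{out}} \times D_{\text{in}}$ so that each token's value is $\boldsymbol{W}_V\boldsymbol{x}$. The random weights stand in for learned ones; this is an illustrative sketch rather than a production implementation.</p>

```python
import numpy as np

def softmax(scores, axis=-1):
    # Subtract the row max before exponentiating for numerical stability.
    shifted = scores - scores.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

def self_attention(X, W_Q, W_K, W_V):
    """Single-head scaled dot-product self-attention.

    X holds one token per row (shape N x D_in); each W has shape D_out x D_in,
    so row t of X @ W.T equals W @ X[t].
    """
    Q = X @ W_Q.T                             # N x D_out queries
    K = X @ W_K.T                             # N x D_out keys
    V = X @ W_V.T                             # N x D_out values
    scores = Q @ K.T / np.sqrt(K.shape[1])    # N x N scaled attention scores
    A = softmax(scores, axis=1)               # each row of weights sums to 1
    return A @ V                              # N x D_out output vectors

rng = np.random.default_rng(0)
D_in, D_out = 4, 3
W_Q, W_K, W_V = (rng.normal(size=(D_out, D_in)) for _ in range(3))

# The same weights handle sets of any size.
print(self_attention(rng.normal(size=(3, D_in)), W_Q, W_K, W_V).shape)  # (3, 3)
print(self_attention(rng.normal(size=(7, D_in)), W_Q, W_K, W_V).shape)  # (7, 3)

# Reordering the input tokens just reorders the output vectors.
X = rng.normal(size=(5, D_in))
perm = rng.permutation(5)
print(np.allclose(self_attention(X, W_Q, W_K, W_V)[perm],
                  self_attention(X[perm], W_Q, W_K, W_V)))  # True
```

<p>The permutation check is why sequence models typically add positional information to the input vectors: the attention layer by itself treats its input as an unordered set.</p>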

<p>The idea behind attention is that when we consider the output vector associated with a given token, we intuitively want the model to pay greater “attention” to some input tokens and less attention to others (“attention” used here in the colloquial sense). For example, let’s say we are generating output vectors for input vectors associated with the sentence, “I like sushi because it makes me happy.” Let us consider the case in which we are generating the output vector for “it”. Intuitively, we know that “it” is referring to “sushi”. It makes sense that when the model is generating the output vector for “it”, it should weigh the word “sushi” more heavily than, say, “because”. The word “it” refers directly to “sushi”, whereas “because” is a conjunction playing a more complicated role in the sentence, joining multiple ideas together. This is depicted in the schematic below:</p>

<p><br /></p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/transformer_attention_sushi_example1.png" alt="drawing" width="600" /></center>

<p><br /></p>

<p>In contrast, when generating the output vector for “happy”, we might intuitively want to place more weight on the word “I”, because “happy” describes the subject, “I”. This is depicted in the schematic below:</p>

<p><br /></p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/transformer_attention_sushi_example2.png" alt="drawing" width="600" /></center>

<p><br /></p>

<p>In summary, when generating each output vector, the attention mechanism considers <em>all</em> of the input vectors and weights them according to how much “attention” to pay them when computing the output vector. We depict this process in the smaller example sentence, “I am hungry”, below:</p>

<p><br /></p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/transformers_input_output_attention.png" alt="drawing" width="800" /></center>

<h2 id="the-nuts-and-bolts-of-the-attention-layer">The nuts and bolts of the attention layer</h2>

<p>We will now dig into the details of the attention mechanism by building our understanding step-by-step. We will use the example sentence, “I am hungry”, going forward.</p>

<p>Let $\boldsymbol{x}_\text{I}$, $\boldsymbol{x}_\text{am}$, $\boldsymbol{x}_\text{hungry} \in \mathbb{R}^{D_{\text{in}}}$ denote our input vectors (of dimension $D_{\text{in}}$) associated with each token. In the first step, the model generates a vector associated with each input vector, called the <strong>values</strong> (or “value vectors”), by multiplying each input vector by a weights matrix $\boldsymbol{W}_V \in \mathbb{R}^{D_{\text{out}} \times D_{\text{in}}}$. Let $\boldsymbol{v}_\text{I}$, $\boldsymbol{v}_\text{am}$, $\boldsymbol{v}_\text{hungry} \in \mathbb{R}^{D_{\text{out}}}$ denote the value vectors. Then, the value vectors are generated via:</p>

\[\begin{align*}\boldsymbol{v}_\text{I} &amp;:= \boldsymbol{W}_V\boldsymbol{x}_\text{I}  \\ \boldsymbol{v}_\text{am} &amp;:= \boldsymbol{W}_V\boldsymbol{x}_\text{am} \\ \boldsymbol{v}_\text{hungry} &amp;:= \boldsymbol{W}_V\boldsymbol{x}_\text{hungry}\end{align*}\]
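<p>In code, these three products are usually computed at once by stacking the input vectors as rows of a matrix. The sketch below uses random stand-ins for the embeddings and for $\boldsymbol{W}_V$ (the dimensions $D_{\text{in}} = 4$ and $D_{\text{out}} = 3$ are arbitrary assumptions) and checks that the batched form matches the per-token equations.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
D_in, D_out = 4, 3
tokens = ["I", "am", "hungry"]

# Hypothetical input vectors x_I, x_am, x_hungry, stacked one per row.
X = rng.normal(size=(len(tokens), D_in))
# W_V maps a D_in-dimensional input to a D_out-dimensional value vector.
W_V = rng.normal(size=(D_out, D_in))

# v_t = W_V x_t for every token t, computed one token at a time...
values_loop = np.stack([W_V @ x for x in X])
# ...or for all tokens at once: row t of X @ W_V.T equals W_V @ X[t].
values = X @ W_V.T

print(values.shape)                      # (3, 3): one D_out-vector per token
print(np.allclose(values, values_loop))  # True
```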

<p>This is depicted schematically below:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/transformers_intermediate_vectors_only_values.png" alt="drawing" width="500" /></center>

<p><br /></p>

<p>To spoil the punchline, the output vector associated with each input vector will be constructed as a weighted sum of these value vectors. The weights represent the amount of attention we pay to each input vector (for now, take these weights as given; we will show how they are generated soon). For example, to generate the output vector for the word “I”, which we will denote as $\boldsymbol{h}_\text{I}$, we will take a weighted sum of the value vectors associated with all of the words in the sentence, including “I” itself:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/transformers_attention_weighted_sum_one_word.png" alt="drawing" width="400" /></center>

<p>Here, the weights $a_{\text{I},\text{I}}$, $a_{\text{I},\text{am}}$, and $a_{\text{I},\text{hungry}}$ are the attention weights! They weight each word in the sentence according to how much we should use that word’s information (i.e., its value vector) when constructing the output for “I”.</p>

<p>We repeat this for every token in the input sequence where, for each token, the attention weights are different and thus, we compute a different weighted sum of the value vectors:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/transformers_attention_weighted_sum_all_words.png" alt="drawing" width="800" /></center>
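<p>Taking the attention weights as given, the weighted sums above reduce to a single matrix product. In this sketch the value vectors and the weight matrix <code>A</code> are hypothetical numbers chosen by hand; row $i$ of <code>A</code> holds the weights used to build the output for token $i$, and each row sums to 1 as it will once the softmax is applied.</p>

```python
import numpy as np

# Hypothetical value vectors for "I", "am", "hungry" (one per row, D_out = 2).
values = np.array([[1.0, 0.0],
                   [0.0, 1.0],
                   [2.0, 2.0]])

# Hypothetical attention weights: row i holds a_{i,I}, a_{i,am}, a_{i,hungry}.
A = np.array([[0.7, 0.2, 0.1],
              [0.1, 0.8, 0.1],
              [0.5, 0.0, 0.5]])

# h_I = a_{I,I} v_I + a_{I,am} v_am + a_{I,hungry} v_hungry, spelled out:
h_I = A[0, 0] * values[0] + A[0, 1] * values[1] + A[0, 2] * values[2]

# All outputs at once: row i of A @ values is the weighted sum for token i.
H = A @ values
print(h_I, H[0])  # both approximately [0.9, 0.4]
```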

<p><br /></p>

<p>Now, how are these attention weights calculated? This is really the meat of the attention layer and can appear a bit complicated at first, as it requires additional vectors to be generated for each input token. That is, not only will we generate value vectors associated with each token, as described previously, but we will also generate two additional vectors associated with each input token: <strong>queries</strong> and <strong>keys</strong>. Like the value vectors, these will be generated using two matrices, denoted $\boldsymbol{W}_Q$ and $\boldsymbol{W}_K$ respectively. Let $\boldsymbol{q}_\text{I}$, $\boldsymbol{q}_\text{am}$, $\boldsymbol{q}_\text{hungry}$ denote the query vectors and $\boldsymbol{k}_\text{I}$, $\boldsymbol{k}_\text{am}$, $\boldsymbol{k}_\text{hungry}$ denote the key vectors. They are then generated as follows:</p>

\[\begin{align*}\boldsymbol{q}_\text{I} &amp;:= \boldsymbol{W}_Q\boldsymbol{x}_\text{I}  \\ \boldsymbol{q}_\text{am} &amp;:= \boldsymbol{W}_Q\boldsymbol{x}_\text{am} \\ \boldsymbol{q}_\text{hungry} &amp;:= \boldsymbol{W}_Q\boldsymbol{x}_\text{hungry}\end{align*}\]

\[\begin{align*}\boldsymbol{k}_\text{I} &amp;:= \boldsymbol{W}_K\boldsymbol{x}_\text{I}  \\ \boldsymbol{k}_\text{am} &amp;:= \boldsymbol{W}_K\boldsymbol{x}_\text{am} \\ \boldsymbol{k}_\text{hungry} &amp;:= \boldsymbol{W}_K\boldsymbol{x}_\text{hungry}\end{align*}\]
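<p>Queries and keys are computed exactly like the values, just with their own matrices. The sketch below assumes, as is common, that queries and keys share the dimension $D_{\text{out}}$ (they must match for the dot products that follow); the embeddings and matrices are random placeholders for learned parameters.</p>

```python
import numpy as np

rng = np.random.default_rng(1)
D_in, D_out = 4, 3
X = rng.normal(size=(3, D_in))   # hypothetical vectors for "I", "am", "hungry"

# Separate matrices for queries and keys (random stand-ins for learned weights).
W_Q = rng.normal(size=(D_out, D_in))
W_K = rng.normal(size=(D_out, D_in))

Q = X @ W_Q.T   # row t is q_t = W_Q x_t
K = X @ W_K.T   # row t is k_t = W_K x_t
print(Q.shape, K.shape)  # (3, 3) (3, 3)
```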

<p>This process is depicted below:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/transformers_intermediate_vectors_only_queries_keys.png" alt="drawing" width="500" /></center>

<p><br /></p>

<p>The queries and keys are then used to construct the attention weights. Let us start by computing a single attention weight, $a_{\text{I}, \text{am}}$, which tells the model how much to weight “am” when generating the output for “I”. We first take the <a href="https://en.wikipedia.org/wiki/Dot_product">dot product</a> between the query vector for “I”, $\boldsymbol{q}_{\text{I}}$, and the key vector for “am”, $\boldsymbol{k}_{\text{am}}$:</p>

\[s_{\text{I},\text{am}} := \boldsymbol{q}_{\text{I}}^T \boldsymbol{k}_{\text{am}}\]

<p>We’ll call this value the <strong>attention score</strong> between word “I” and word “am” and it will be used to form the attention weight. This is depicted schematically below:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/transformers_score_dot_product.png" alt="drawing" width="300" /></center>

<p><br /></p>

<p>Intuitively, if a given pair of words has a high score (i.e., a large dot product between the first word’s query and the second word’s key), then the query and key are similar vectors. Consequently, we should pay attention to the second word (associated with the key) when forming the output vector for the first word (associated with the query). The neural network learns how to map tokens to queries and keys such that the dot product is large when one word should pay attention to the other.</p>

<p>With this reasoning, it may seem that the score alone would serve as a good attention weight; however, there is a practical problem with using the score directly: the dot product has no upper bound, and thus, if we stack many self-attention layers together, we can encounter numerical instability as these values blow up. We therefore normalize the score in two steps. First, we divide it by a constant, usually $\sqrt{d}$, where $d$ is the dimensionality of the query and key vectors. Second, we apply a softmax over the scaled scores between the given word and <em>every</em> word in the input sequence. The resulting entry for the second word is the final attention weight. For example, for words “I” and “am”, the attention weight is the entry of the softmax’s output that corresponds to “am”:</p>

\[\begin{align*}a_{\text{I}, \text{am}} &amp;:= \text{softmax}\left( \frac{ s_{\text{I},\text{I}}}{\sqrt{d}}, \frac{s_{\text{I},\text{am}}}{\sqrt{d}}, \frac{s_{\text{I},\text{hungry}}}{\sqrt{d}} \right)_{\text{am}} \\ &amp;= \frac{\exp\left(\frac{s_{\text{I},\text{am}}}{\sqrt{d}}\right)}{\exp\left(\frac{s_{\text{I},\text{I}}}{\sqrt{d}}\right) + \exp\left(\frac{s_{\text{I},\text{am}}}{\sqrt{d}}\right) + \exp\left(\frac{s_{\text{I},\text{hungry}}}{\sqrt{d}}\right)}\end{align*}\]

<p>This is depicted in the schematic below for all of the attention weights when generating the output vector for “I”:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/transformer_generate_attention_weights.png" alt="drawing" width="450" /></center>

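<p>The intuition behind this normalization procedure is as follows. Dividing each score by $\sqrt{d}$ normalizes for the number of terms in the sum that makes up the dot product: if the entries of a query and a key are independent with mean zero and variance one, their dot product is a sum of $d$ such products and has variance $d$, so dividing by $\sqrt{d}$ brings the variance back to one. The softmax then performs a final normalization that forces the attention weights to be positive and to sum to one!</p>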
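<p>The variance argument can be checked numerically. The NumPy sketch below draws random Gaussian vectors as stand-ins for queries and keys (an assumption made purely for illustration; learned queries and keys need not have independent, unit-variance entries) and compares the spread of raw and scaled dot products.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
d = 64              # dimensionality of the query and key vectors
n_samples = 20_000  # number of random (query, key) pairs

# Random stand-ins for queries and keys: entries drawn i.i.d. from N(0, 1)
q = rng.standard_normal((n_samples, d))
k = rng.standard_normal((n_samples, d))

raw_scores = np.sum(q * k, axis=1)       # one dot product per (query, key) pair
scaled_scores = raw_scores / np.sqrt(d)  # the scaling used in attention

print(f"variance of raw scores:    {raw_scores.var():.1f}")     # close to d = 64
print(f"variance of scaled scores: {scaled_scores.var():.2f}")  # close to 1
```

Without the scaling, typical scores grow like $\sqrt{d}$, so for large $d$ the softmax would concentrate almost all of its weight on a single word.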

<p>Putting it all together: Given a set of tokens, associated with input feature vectors, we map each token to a value vector, query vector, and key vector via the matrices $W_V$, $W_Q$, and $W_K$, respectively:</p>

<p><br /></p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/transformers_intermediate_vectors.png" alt="drawing" width="500" /></center>

<p><br /></p>

<p>The query and key vectors are used to form attention weights. These attention weights are used to compute a weighted sum of the value-vectors that then form each output token’s final vector representation:</p>

<p><br /></p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/transformer_attention_mechanism.png" alt="drawing" width="850" /></center>

<p><br /></p>
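<p>The per-token procedure above can be sketched directly in NumPy. This is a minimal illustration rather than the code used for this post: the sentence “I am hungry”, the dimensions, and the random matrices <code>W_Q</code>, <code>W_K</code>, and <code>W_V</code> are placeholders (in a trained model the matrices are learned), and, as in the equations above, vectors are columns so that $\boldsymbol{q} = \boldsymbol{W}_Q\boldsymbol{x}$.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
tokens = ["I", "am", "hungry"]
d_in, d = 4, 3  # input and query/key/value dimensions (arbitrary choices)

# Placeholder input feature vectors, one per token
x = {t: rng.standard_normal(d_in) for t in tokens}

# Projection matrices: random here, learned in a real model
W_Q = rng.standard_normal((d, d_in))
W_K = rng.standard_normal((d, d_in))
W_V = rng.standard_normal((d, d_in))

# Query, key, and value vectors for every token
q = {t: W_Q @ x[t] for t in tokens}
k = {t: W_K @ x[t] for t in tokens}
v = {t: W_V @ x[t] for t in tokens}


def softmax(z):
    z = z - z.max()  # subtracting the max leaves the result unchanged but avoids overflow
    e = np.exp(z)
    return e / e.sum()


weights, outputs = {}, {}
for t in tokens:
    # Scaled scores s_{t,u} / sqrt(d) against every token u, including t itself
    scores = np.array([q[t] @ k[u] for u in tokens]) / np.sqrt(d)
    a = softmax(scores)  # attention weights a_{t,u}, which sum to one
    weights[t] = a
    # Output for token t: attention-weighted sum of all value vectors
    outputs[t] = sum(a_u * v[u] for a_u, u in zip(a, tokens))

print(weights["I"])  # how much "I" attends to "I", "am", and "hungry"
```

Each token gets its own set of weights, so each output is a different blend of the same three value vectors.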

<h2 id="computing-attention-via-matrix-multiplication">Computing attention via matrix multiplication</h2>

<p>The attention layer can be expressed and computed more succinctly using <a href="https://mbernste.github.io/posts/matrix_multiplication/">matrix multiplication</a>. First, let $X \in \mathbb{R}^{N \times D_{\text{in}}}$ represent the matrix of $N$ token vectors stored as rows, each of $D_{\text{in}}$ dimensions. Then, the queries, keys, and values can be computed by multiplying $X$ by $W_Q$, $W_K$, and $W_V$, which are now stored as $D_{\text{in}} \times D_{\text{out}}$ matrices (the transposes of the matrices in the per-token equations above, since the token vectors are now rows rather than columns). The results are matrices $Q, K, V \in \mathbb{R}^{N \times D_{\text{out}}}$ whose rows are the tokens’ queries, keys, and values:</p>

\[\begin{align*}Q &amp;:= XW_Q \\ K &amp;:= XW_K \\ V &amp;:= XW_V\end{align*}\]

<p>Represented schematically:</p>

<p><br /></p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/attention_matrix_mult_KQV.png" alt="drawing" width="650" /></center>

<p><br /></p>

<p>Then, the pairwise dot products between the tokens’ queries and keys can be computed via a single matrix multiplication between $Q$ and $K^T$:</p>

\[\text{Scores} := QK^T\]

<p>This produces an $N \times N$ matrix storing all of the pairwise attention scores:</p>

<p><br /></p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/attention_matrix_mult_scores.png" alt="drawing" width="450" /></center>

<p><br /></p>

<p>The final output matrix of $N \times D_{\text{out}}$ transformed token vectors is then computed by taking linear combinations of the value vectors weighted by the normalized scores (i.e., the attention weights). This can also be expressed as a matrix multiplication, where the softmax is applied to each row of the score matrix independently:</p>

\[H := \text{Softmax}\left( \text{Scores} / \sqrt{D_{\text{out}}}\right)V\]

<p>Depicted schematically,</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/attention_matrix_mult_final_out.png" alt="drawing" width="550" /></center>

<p><br /></p>

<p>Thus, the final form of the attention layer is,</p>

\[H := \text{Softmax}\left(\frac{QK^T}{\sqrt{D_{\text{out}}}}\right)V\]
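<p>The matrix form can be sketched in a few lines of NumPy. In this sketch, rows of <code>X</code> are tokens and the weight matrices are stored as $D_{\text{in}} \times D_{\text{out}}$ arrays so that <code>X @ W_Q</code> gives $Q$; the sizes and random weights are placeholders. The softmax is taken over each row of the score matrix, so each row of the attention-weight matrix sums to one.</p>

```python
import numpy as np


def softmax_rows(z):
    """Row-wise softmax, subtracting each row's max for numerical stability."""
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def attention(X, W_Q, W_K, W_V):
    """H = softmax(Q K^T / sqrt(D_out)) V for one sequence X of shape (N, D_in)."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V         # each (N, D_out)
    d_out = Q.shape[1]
    scores = Q @ K.T                            # (N, N) pairwise attention scores
    A = softmax_rows(scores / np.sqrt(d_out))   # (N, N) attention weights
    return A @ V, A                             # outputs (N, D_out) and weights


rng = np.random.default_rng(0)
N, D_in, D_out = 3, 4, 2
X = rng.standard_normal((N, D_in))
W_Q, W_K, W_V = (rng.standard_normal((D_in, D_out)) for _ in range(3))

H, A = attention(X, W_Q, W_K, W_V)
print(H.shape)        # (3, 2)
print(A.sum(axis=1))  # each row sums to 1
```

Row $i$ of <code>A</code> holds the attention weights used to build the output for token $i$, so the single product <code>A @ V</code> computes every token’s weighted sum of value vectors at once.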

<h2 id="the-fully-connected-layer">The fully connected layer</h2>

<p>The attention layer is usually followed by a fully connected layer. This layer is quite simple: we take the vectors produced by the attention layer and pass each one through a fully connected neural network, i.e., a <a href="https://en.wikipedia.org/wiki/Multilayer_perceptron">multilayer perceptron (MLP)</a>, that is shared across all tokens. The sequence of an attention layer followed by a fully connected layer is often referred to as a <strong>transformer layer</strong>, as it forms the basis of the <a href="https://en.wikipedia.org/wiki/Transformer_(deep_learning)">transformer neural network</a>. The transformer is an attention-based architecture for mapping sequences to sequences, proposed by Vaswani <em>et al.</em> (2017) in <em><a href="https://papers.nips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf">Attention Is All You Need</a></em>.</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/transformers_attention_and_fully_connected.png" alt="drawing" width="550" /></center>

<p><br /></p>

<p>In other words, the fully connected layer performs a non-linear transformation of the attention-derived vectors. This step injects more non-linearity into the model so that, when we stack transformer layers together, we can form very complex iterations of attention in which each subsequent layer computes attention between the tokens in different ways.</p>
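<p>Because the MLP is shared across tokens, applying it to the whole $N \times D$ matrix of attention outputs at once is the same as applying it to each token’s vector separately. The NumPy sketch below assumes a two-layer MLP with a ReLU activation (a common choice, used here for illustration) with placeholder sizes and random weights, and checks this equivalence.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
N, D, hidden = 3, 4, 16

H = rng.standard_normal((N, D))  # stand-in for the attention layer's outputs
W1, b1 = rng.standard_normal((D, hidden)), np.zeros(hidden)
W2, b2 = rng.standard_normal((hidden, D)), np.zeros(D)


def mlp(h):
    """Two-layer MLP with ReLU; accepts one vector (D,) or a matrix (N, D)."""
    return np.maximum(0.0, h @ W1 + b1) @ W2 + b2


# Apply the same MLP to all token vectors at once...
out_batched = mlp(H)
# ...or to each token vector one at a time: the results match
out_per_token = np.stack([mlp(H[i]) for i in range(N)])
print(np.allclose(out_batched, out_per_token))  # True
```

The MLP never mixes information between tokens; all cross-token interaction happens in the attention layer.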

<h2 id="keys-queries-values-a-note-on-terminology">“Keys”, “Queries”, “Values”? A note on terminology</h2>

<p>A natural question when first learning this topic is: Why are the $Q$, $K$, and $V$ matrices referred to as “queries”, “keys”, and “values”? Where do these terms come from? The answer is that these terms were introduced by Vaswani <em>et al.</em> (2017) in their original paper based on an analogy they made between the attention layer and database systems.</p>

<p>To make this analogy concrete, let’s say we have a database of music files (say .mp3 files) where each file is associated with a title encoded as a string. Here we’ll call the titles “keys” and the sound files “values”. Each key is associated with a value.</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/attention_database_key_values.png" alt="drawing" width="350" /></center>

<p>To retrieve a given song, we form a query, which is also a string, and attempt to match this query against all the existing titles (keys) in the database. If we find a match, the database will return the corresponding music file.</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/attention_database_retrieval.png" alt="drawing" width="700" /></center>

<p>This is very similar to the roles that the keys, queries, and values play in the attention layer; however, instead of each match being binary (we either match a key or we don’t), the queries in the attention layer are “soft”: a query may <em>partially</em> match <em>multiple</em> keys. This soft matching and retrieval is carried out by computing dot products between keys and queries and then using these dot products to perform a weighted sum over the value vectors. That is, each weight denotes how well the query matches each key, and the “retrieval” is carried out as a weighted sum!</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/attention_kqv_vectors_as_database.png" alt="drawing" width="450" /></center>
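<p>The contrast between a database lookup and the attention layer’s soft lookup can be sketched in a few lines of Python. The song names, the two-dimensional key and query vectors, and the scalar “values” are made up for illustration, and the $\sqrt{d}$ scaling is omitted for brevity.</p>

```python
import numpy as np

# Hard lookup: a query either matches a key exactly or it does not
database = {"song_a": 1.0, "song_b": 2.0, "song_c": 3.0}  # maps each key to its value
print(database["song_b"])  # 2.0

# Soft lookup: keys and queries are vectors, and similarity sets the weights
keys = np.array([[1.0, 0.0],   # key for song_a
                 [0.0, 1.0],   # key for song_b
                 [0.7, 0.7]])  # key for song_c
values = np.array([1.0, 2.0, 3.0])
query = np.array([0.0, 1.0])   # most similar to song_b's key

scores = keys @ query                            # dot product of the query with every key
weights = np.exp(scores) / np.exp(scores).sum()  # softmax: positive weights summing to one
retrieved = weights @ values                     # a blend of all values, not a single value
print(weights.round(3), retrieved.round(3))
```

The soft lookup returns a value dominated by the best-matching key (<code>song_b</code>) but still mixes in the others in proportion to how well they match.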

<h2 id="positional-encodings">Positional encodings</h2>

<p>As described above, the attention layer maps a set of vectors to a new set of vectors in such a way that each vector can “attend” to some set of other vectors within the set. Notably, nothing in the attention layer depends on the order of the input vectors: if we permute the input vectors, the output vectors are permuted in exactly the same way but are otherwise unchanged. That is, the attention layer operates on <em>unordered</em> sets.</p>
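<p>This order-independence (more precisely, <em>permutation equivariance</em>: shuffling the inputs shuffles the outputs in the same way) can be checked numerically. The NumPy sketch below uses single-head attention with random placeholder weights and compares attention applied to shuffled inputs against the shuffled outputs of attention applied to the original inputs.</p>

```python
import numpy as np


def attention(X, W_Q, W_K, W_V):
    """Single-head attention H = softmax(Q K^T / sqrt(d)) V; rows of X are tokens."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    scores = Q @ K.T / np.sqrt(Q.shape[1])
    scores -= scores.max(axis=1, keepdims=True)  # numerical stability
    A = np.exp(scores)
    A /= A.sum(axis=1, keepdims=True)            # row-wise softmax
    return A @ V


rng = np.random.default_rng(0)
N, D = 5, 4
X = rng.standard_normal((N, D))
W_Q, W_K, W_V = (rng.standard_normal((D, D)) for _ in range(3))

perm = rng.permutation(N)  # a random reordering of the tokens
H = attention(X, W_Q, W_K, W_V)
H_perm = attention(X[perm], W_Q, W_K, W_V)

# Shuffling the inputs only shuffles the outputs: token order carries no information
print(np.allclose(H_perm, H[perm]))  # True
```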

<p>However, we mentioned in this very post that one of the most common areas of application for attention is in modeling natural language text, which <em>is</em> intrinsically ordered. Moreover, we expect that a model would benefit from having access to this order. The sentence, “The shark bit the person” has quite a different meaning from the sentence, “The person bit the shark” even though both sentences use the same set of words.</p>

<p>The standard way to provide the model with information about the order, or <em>position</em>, of input tokens relative to one another is to use <strong>positional encodings</strong>. More specifically, we associate with each position, $1, 2, \dots, M$, a $D_{\text{in}}$-dimensional vector that encodes that position. That positional encoding vector is then <em>added</em> to the input token vector at that position. That is, the input vector at position $i$, denoted $\boldsymbol{x}_i$, is modified via</p>

\[\boldsymbol{x}_i' := \boldsymbol{x}_i + \boldsymbol{p}_i\]

<p>where $\boldsymbol{p}_i$ is the positional encoding vector for position $i$. The end result is that each modified input token vector contains both information regarding the token as well as the position of that token.</p>

<p>These positional encodings can either be learned during training (i.e., each position integer is mapped to a learned encoding vector) or, more commonly, fixed <em>a priori</em>. For example, <a href="https://papers.nips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf">Vaswani <em>et al.</em> (2017)</a> use positional encodings built from sine and cosine functions of different frequencies for each dimension. A heatmap of such positional encodings is shown below, where the rows are positions and the columns are dimensions of the input token vectors:</p>

<p><br /></p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/attention_positional_encoding.png" alt="drawing" width="450" /></center>

<p><br /></p>
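<p>The sinusoidal encodings of Vaswani <em>et al.</em> (2017) set, for position $\text{pos}$ and dimension index $i$, $p_{\text{pos},2i} = \sin\left(\text{pos} / 10000^{2i/D}\right)$ and $p_{\text{pos},2i+1} = \cos\left(\text{pos} / 10000^{2i/D}\right)$. The minimal NumPy sketch below assumes an even dimension $D$ and numbers positions from 0 rather than 1; it builds the full matrix of encodings and adds it to placeholder token vectors.</p>

```python
import numpy as np


def sinusoidal_positional_encodings(num_positions, d):
    """Return a (num_positions, d) matrix whose row p encodes position p."""
    assert d % 2 == 0, "this sketch assumes an even dimension"
    positions = np.arange(num_positions)[:, None]     # (M, 1)
    i = np.arange(d // 2)[None, :]                    # (1, d/2)
    angles = positions / np.power(10000.0, 2 * i / d) # (M, d/2), one frequency per column
    P = np.zeros((num_positions, d))
    P[:, 0::2] = np.sin(angles)  # even dimensions use sine
    P[:, 1::2] = np.cos(angles)  # odd dimensions use cosine
    return P


M, D_in = 10, 8
P = sinusoidal_positional_encodings(M, D_in)

# x_i' = x_i + p_i: add each position's encoding to the token vector at that position
X = np.random.default_rng(0).standard_normal((M, D_in))  # placeholder token vectors
X_with_pos = X + P
print(P[0])  # position 0: sin(0) = 0 in even dimensions, cos(0) = 1 in odd dimensions
```

Low-index dimensions oscillate quickly with position and high-index dimensions slowly, which produces the banded pattern in the heatmap above.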

<h2 id="multiheaded-attention">Multiheaded attention</h2>

<p>In order to expand the learning capacity of a model, one can also <em>parallelize</em> attention in each layer using an extension of attention called <strong>multiheaded attention</strong>. In multiheaded attention, one performs the attention operation <em>multiple times</em> using multiple sets of queries, keys, and values. That is, at a given layer, the neural network learns multiple $W_Q$, $W_K$, and $W_V$ matrices and performs attention once with each set. Each attention mechanism is called a “head”, and the full layer is called “multiheaded attention”. The final output vectors from multiheaded attention are formed by concatenating the outputs of the individual heads as shown below:</p>

<p><br /></p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/attention_multiheaded.png" alt="drawing" width="650" /></center>

<p><br /></p>

<p>If the multiheaded attention layer is followed by a fully connected layer, each of these concatenated vectors is fed into the downstream multilayer perceptron.</p>
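<p>A minimal NumPy sketch of multiheaded attention as described here: each head has its own placeholder $W_Q$, $W_K$, and $W_V$ matrices, runs ordinary single-head attention, and the head outputs are concatenated along the feature dimension. In Vaswani <em>et al.</em> (2017), the concatenated vectors are additionally multiplied by an output matrix; that projection is omitted here to match the description above.</p>

```python
import numpy as np


def attention(X, W_Q, W_K, W_V):
    """Single-head attention H = softmax(Q K^T / sqrt(d)) V; rows of X are tokens."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    scores = Q @ K.T / np.sqrt(Q.shape[1])
    scores -= scores.max(axis=1, keepdims=True)  # numerical stability
    A = np.exp(scores)
    A /= A.sum(axis=1, keepdims=True)            # row-wise softmax
    return A @ V


def multihead_attention(X, heads):
    """heads: list of (W_Q, W_K, W_V) tuples, one tuple per head."""
    outputs = [attention(X, W_Q, W_K, W_V) for W_Q, W_K, W_V in heads]
    return np.concatenate(outputs, axis=1)  # (N, num_heads * d_head)


rng = np.random.default_rng(0)
N, D_in, d_head, num_heads = 6, 8, 4, 3
X = rng.standard_normal((N, D_in))
heads = [tuple(rng.standard_normal((D_in, d_head)) for _ in range(3))
         for _ in range(num_heads)]

H = multihead_attention(X, heads)
print(H.shape)  # (6, 12): each token gets num_heads * d_head features
```

Each head has its own weights and therefore its own attention weights, which is what allows different heads to capture different relationships.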

<p>Multiheaded attention enables an attention layer to learn different kinds of relationships between entities. In text, for example, one head might learn possessive relationships: in the sentence, “Joe’s dog ran into Hannah’s yard”, this head might relate “Joe” to “dog” and “Hannah” to “yard”. Another head might learn relationships between objects and places; in the same sentence, it might relate “dog” to “yard” because the dog ran into the yard.</p>

<h2 id="what-makes-attention-so-powerful">What makes attention so powerful?</h2>

<p>As alluded to in the introduction to this post, the attention layer is widely regarded as one of the most important breakthroughs that enabled the development of modern AI systems and large language models. But what exactly makes attention so powerful?</p>

<p>Before attention, models had difficulty relating distant pieces of input data whose joint consideration was critical for the task at hand. For example, when trained on long sequences of text, models had trouble relating words that were far away from each other in the document: <a href="https://en.wikipedia.org/wiki/Recurrent_neural_network">recurrent neural networks</a> would often “forget” text that appeared early in the document. While <a href="https://en.wikipedia.org/wiki/Long_short-term_memory">long short-term memory networks</a> helped mitigate this problem, they did not fundamentally solve it. Attention addresses this challenge because it explicitly enables a neural network to relate any two elements of the input to each other, regardless of the distance between them.</p>

<p>Attention has also been applied to computer vision, where it has likewise been challenging for models to relate regions of an image that are far away from each other. In order for a <a href="https://en.wikipedia.org/wiki/Convolutional_neural_network">convolutional neural network (CNN)</a> to jointly “see” and operate over two distant regions of an image, the CNN must consist of many layers. This is because the size of the <a href="https://en.wikipedia.org/wiki/Receptive_field">receptive field</a> of a given neuron in a CNN grows with the number of layers that precede that neuron in the model’s architecture. The <a href="https://en.wikipedia.org/wiki/Vision_transformer">vision transformer</a> is a neural network architecture that uses attention over image patches to explicitly link regions of the image together no matter how distant they are.</p>

<p>Perhaps most importantly, <em>attention scales</em>. As researchers have scaled attention-based models to ever larger sizes, the models appear to continue to improve. In fact, in recent years, the community has discovered empirical <a href="https://en.wikipedia.org/wiki/Neural_scaling_law">scaling laws</a> relating performance to training set and model sizes: as models and datasets grow, performance tends to improve at a predictable pace. At the time of this writing, frontier large language models are reportedly built with <a href="https://www.cometapi.com/how-many-parameters-does-gpt-5-have/"><em>trillions</em> of parameters</a> and trained on nearly the entire internet’s worth of data.</p>

<h2 id="applying-attention-to-a-simple-classification-problem">Applying attention to a simple classification problem</h2>

<p>To demonstrate an implementation of attention, we train a simple attention-based model to perform binary classification on a task that is not solvable using a naïve <a href="https://en.wikipedia.org/wiki/Bag-of-words_model">bag of words</a> model. All code for this experiment can be found on <a href="https://colab.research.google.com/drive/1wEu844liSYWoPv6MiUpx8IwNwB8FddEM?usp=sharing">Google Colab</a> and in the Appendix to this blog post.</p>

<p>Specifically, we develop a problem setting in which our task is to classify sentences into one of two categories:</p>

<ol>
  <li>Sentences that describe a white car sitting to the left of a black car (positive class)</li>
  <li>Sentences that describe a black car sitting to the left of a white car (negative class)</li>
</ol>

<p>Critically, the dataset consists of pairs of sentences, each pair containing a positive and a negative example that use exactly the same words with the same frequencies. Thus, the two sentences in a pair are indistinguishable using bag-of-words representations alone. For example, two sentences in this training set are:</p>

<ul>
  <li><code class="language-plaintext highlighter-rouge">Within a white frame and black border the white car remains on the left while the black car remains on the right</code> (Positive)</li>
  <li><code class="language-plaintext highlighter-rouge">Within a white frame and black border the black car remains on the left while the white car remains on the right</code> (Negative)</li>
</ul>
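<p>This can be checked directly. The sketch below counts the words in the two example sentences using a simplified stand-in for the <code>tokenize</code> function in the Appendix (lowercase, remove punctuation, split on whitespace) and confirms that their bag-of-words representations are identical even though the sentences describe opposite arrangements.</p>

```python
import string
from collections import Counter


def bag_of_words(sentence):
    """Lowercase, strip punctuation, split on whitespace, and count the words."""
    sentence = sentence.lower().translate(str.maketrans("", "", string.punctuation))
    return Counter(sentence.split())


positive = ("Within a white frame and black border the white car remains "
            "on the left while the black car remains on the right")
negative = ("Within a white frame and black border the black car remains "
            "on the left while the white car remains on the right")

print(bag_of_words(positive) == bag_of_words(negative))  # True: same word counts
print(positive.split() == negative.split())              # False: word order differs
```

Only the order of the words distinguishes the two classes, which is exactly the information a bag-of-words model discards.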

<p>The training set consists of 528 sentences (264 pairs). The file can be found <a href="https://github.com/mbernste/blog_posts/blob/main/attention_dataset/training_set3.json">here on GitHub</a>.</p>

<p>We train two very simple models on this task:</p>
<ol>
  <li>An attention-based model consisting of four layers, each made up of single-head attention followed by a fully connected layer. To classify a sentence, the attention-based model averages the token representations coming out of the last layer and passes the result through one final linear layer that outputs a <a href="https://en.wikipedia.org/wiki/Logit">logit</a>. Both the token embeddings and positional embeddings are learned during training.</li>
  <li>A simple multilayer perceptron trained on <a href="https://en.wikipedia.org/wiki/Bag-of-words_model">bag of words</a> representations of each sentence</li>
</ol>

<p>The architecture of the attention-based classifier is shown below:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/attention_simple_classifier.png" alt="drawing" width="400" /></center>

<p>When training on 90% of the pairs and testing on the remaining 10%, the attention-based model achieves <strong>96%</strong> accuracy whereas the bag-of-words-based model achieves <strong>50%</strong> (as expected, since there is no signal in the word frequencies).</p>

<p>We can also investigate the attention scores between words. For example, below are the attention scores in the second layer of the network when given the sentence, “Listed left to right is a white car then black car”:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/attention_example_scores.png" alt="drawing" width="400" /></center>

<p>Here, element (i,j) denotes how much word i is “attending to” word j, that is, how much word j’s value is weighted in word i’s weighted sum. It is interesting to note in the above example how the word “left” attends most to “white”. While it is difficult to ascertain exactly how this affects the model’s predictions, it intuitively makes sense that the model is relating these key words to one another.</p>

<h2 id="further-reading">Further Reading</h2>

<ul>
  <li>Much of my understanding of this material came from the excellent blog post, <em><a href="https://jalammar.github.io/illustrated-transformer/">The Illustrated Transformer</a></em> by Jay Alammar.</li>
</ul>

<h2 id="appendix">Appendix</h2>

<p>Here, we display all code for training the small attention-based classifier.</p>

<p>We start with implementations of Python functions for tokenization, batch-padding, and loading the data.</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import re
import json
import numpy as np
import string
import math
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn


def tokenize(sentence):
    """
    Tokenize a sentence by:
    - lowercasing
    - removing punctuation
    - splitting on whitespace

    Example:
    "From left to right, the lineup is: a red cone"
    -&gt;
    ["from", "left", "to", "right", "the", "lineup", "is", "a", "red", "cone"]
    """
    sentence = sentence.lower()

    # Remove punctuation
    sentence = sentence.translate(str.maketrans('', '', string.punctuation))

    # Split on whitespace (str.split also collapses repeated spaces)
    tokens = sentence.split()

    return tokens


def build_vocab(sentences, min_freq=1):
    """
    Build a token -&gt; index vocabulary from a list of sentences.

    Returns:
    - token_to_id: dict
    - id_to_token: list
    """
    freq = {}

    for sent in sentences:
        for tok in tokenize(sent):
            freq[tok] = freq.get(tok, 0) + 1

    # Special tokens
    tokens = ["&lt;pad&gt;", "&lt;unk&gt;"]

    # Add real tokens
    for tok, count in freq.items():
        if count &gt;= min_freq:
            tokens.append(tok)

    token_to_id = {tok: i for i, tok in enumerate(tokens)}
    id_to_token = tokens

    return token_to_id, id_to_token


def encode(sentence, token_to_id):
    """
    Convert a sentence into a list of token IDs.
    """
    unk_id = token_to_id["&lt;unk&gt;"]

    return [
        token_to_id.get(tok, unk_id)
        for tok in tokenize(sentence)
    ]


def pad_batch(encoded_sentences, pad_id):
    """
    Pad a list of encoded sentences to the same length.

    Returns:
    - input_ids: LongTensor (batch_size, max_len)
    - attention_mask: BoolTensor (batch_size, max_len)
    """
    max_len = max(len(s) for s in encoded_sentences)
    batch_size = len(encoded_sentences)

    input_ids = torch.full(
        (batch_size, max_len),
        pad_id,
        dtype=torch.long
    )

    attention_mask = torch.zeros(
        (batch_size, max_len),
        dtype=torch.bool
    )

    for i, seq in enumerate(encoded_sentences):
        input_ids[i, :len(seq)] = torch.tensor(seq)
        attention_mask[i, :len(seq)] = True

    return input_ids, attention_mask


class SentenceBinaryDataset(Dataset):
    """
    Stores (sentence, label) pairs and encodes sentences into token IDs on-the-fly.
    """
    def __init__(self, data, token_to_id):
        """
        data: list of (sentence: str, label: int) tuples
        token_to_id: dict mapping tokens -&gt; ids
        """
        self.data = data
        self.token_to_id = token_to_id

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sentence, label = self.data[idx]
        token_ids = encode(sentence, self.token_to_id)  # list[int]
        return token_ids, int(label)


def make_collate_fn_attention_model(pad_id: int):
    """
    Returns a collate_fn that:
      - pads sequences in the batch
      - builds an attention_mask (True for real tokens)
      - returns tensors ready for the model
    """
    def collate_fn(batch):
        # batch is a list of (token_ids_list, label_int)
        token_id_lists, labels = zip(*batch)

        input_ids, attention_mask = pad_batch(token_id_lists, pad_id=pad_id)
        labels = torch.tensor(labels, dtype=torch.float32)  # for BCEWithLogitsLoss

        return {
            "input_ids": input_ids,               # (B, T) LongTensor
            "attention_mask": attention_mask,     # (B, T) BoolTensor
            "labels": labels                      # (B,) FloatTensor
        }

    return collate_fn


def create_dataloader_attention(
      data,
      token_to_id,
      batch_size=32,
      shuffle=True,
      num_workers=0
    ):
    """
    Convenience wrapper to create a DataLoader for (sentence, label) data.
    """
    pad_id = token_to_id["&lt;pad&gt;"]
    ds = SentenceBinaryDataset(data, token_to_id)
    dl = DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=make_collate_fn_attention_model(pad_id),
        drop_last=False
    )
    return dl


def train_pair_indices(num_rows, train_frac, seed=None):
    """
    Assumes that rows of the dataset are paired as positive, negative examples
    (e.g., row 0 is a positive example and row 1 is a paired negative example).
    This function outputs training indices from the dataset ensuring that
    pairs are either both included or excluded from the training set.
    """
    assert num_rows % 2 == 0, "Number of rows must be even (paired rows)."

    num_pairs = num_rows // 2
    rng = np.random.default_rng(seed)

    # shuffle pair indices
    pair_ids = rng.permutation(num_pairs)

    # number of training pairs
    n_train_pairs = int(train_frac * num_pairs)

    train_pairs = pair_ids[:n_train_pairs]

    # convert pair ids -&gt; row indices
    train_indices = np.concatenate([
        np.array([2*p, 2*p + 1]) for p in train_pairs
    ])

    return train_indices
</code></pre></div></div>

<p>Below is the code that implements the small attention-based classifier:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>class SingleHeadSelfAttention(nn.Module):
    """
    Minimal single-head self-attention:
      Q = X Wq, K = X Wk, V = X Wv
      Attn(X) = softmax(QK^T / sqrt(d)) V
    """
    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model
        self.Wq = nn.Linear(d_model, d_model, bias=False)
        self.Wk = nn.Linear(d_model, d_model, bias=False)
        self.Wv = nn.Linear(d_model, d_model, bias=False)

    def compute_attention_scores(self, x: torch.Tensor):
        """
        x: (B, T, D) -- i.e., batch-size, number of tokens, dimensions
        """
        B, T, D = x.shape
        q = self.Wq(x)
        k = self.Wk(x)

        # For each sentence in the batch, compute the attention score between
        # each pair of tokens.
        #
        # q: (B, T, D)
        # k.transpose(-2, -1): (B, D, T)
        # scores: (B, T, T)
        #
        # For a given batch, b, element i,j of scores[b] denotes how much
        # token i should weight token j when updating the representation of i
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(D)  # (B, T, T)

        return scores

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor | None = None):
        """
        x: (B, T, D) -- i.e., batch-size, number of tokens, dimensions
        attn_mask: (B, T) boolean, True for real tokens, False for padding
        returns: (B, T, D)
        """
        # Compute attention scores (B, T, T)
        scores = self.compute_attention_scores(x)

        # Compute value vectors
        v = self.Wv(x)

        # Mask out padding tokens
        # Wherever the mask is False, the scores are replaced with
        # -inf. When pushed through the softmax function, these then become
        # zero.
        if attn_mask is not None:
            key_mask = attn_mask.unsqueeze(1)  # (B, 1, T)
            scores = scores.masked_fill(~key_mask, float("-inf"))

        weights = torch.softmax(scores, dim=-1)
        out = weights @ v
        return out


class AttentionBlock(nn.Module):
    """
    A tiny Transformer-like block:
      - single-head self-attention
      - residual + layernorm
      - 2-layer MLP (feed-forward)
      - residual + layernorm

    Still minimal, but stacking these makes training much more stable than stacking
    raw attention layers without normalization/residuals.
    """
    def __init__(self, d_model: int, mlp_ratio: int = 4):
        super().__init__()
        self.attn = SingleHeadSelfAttention(d_model)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

        hidden = mlp_ratio * d_model
        self.mlp = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.ReLU(),
            nn.Linear(hidden, d_model),
        )

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor | None = None):
        # Attention sublayer
        x = self.ln1(x + self.attn(x, attn_mask=attn_mask))

        # Feed-forward sublayer
        x = self.ln2(x + self.mlp(x))
        return x


class TinyAttentionBinaryClassifier(nn.Module):
    """
    Configurable attention-based binary classifier.

    tokens -&gt; embedding (+ positional embedding)
          -&gt; N attention blocks (single-head)
          -&gt; mean pool over tokens (masked)
          -&gt; linear -&gt; logit
    """
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 64,
        max_len: int = 128,
        pad_id: int = 0,
        num_layers: int = 1,
        mlp_ratio: int = 4
    ):
        super().__init__()
        self.pad_id = pad_id
        self.max_len = max_len

        # Learnable token embeddings
        self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)

        # Learnable positional encodings
        self.pos_emb = nn.Embedding(max_len, d_model)

        # Sequence of self-attention blocks
        self.blocks = nn.ModuleList(
            [AttentionBlock(d_model, mlp_ratio=mlp_ratio) for _ in range(num_layers)]
        )

        # Map the mean of the embeddings from the last layer to a single
        # logit output
        self.fc = nn.Linear(d_model, 1)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None):
        """
        input_ids: (B, T)
        attention_mask: (B, T) boolean, True for non-padding tokens
        returns: logits (B,)
        """
        B, T = input_ids.shape
        if T &gt; self.max_len:
            raise ValueError(f"Sequence length {T} exceeds max_len={self.max_len}")

        # Token embeddings. Project to (B, T, D)
        tok = self.token_emb(input_ids)

        # Positional embeddings. Project integers to (B, T, D)
        pos_ids = torch.arange(T, device=input_ids.device).unsqueeze(0).expand(B, T)
        pos = self.pos_emb(pos_ids)

        # Final embedding is token embedding plus positional embedding
        x = tok + pos

        # Pass data through attention blocks
        for blk in self.blocks:
            x = blk(x, attn_mask=attention_mask)

        # Mean pooling over last layer's non-padding tokens
        if attention_mask is None:
            pooled = x.mean(dim=1)
        else:
            mask = attention_mask.unsqueeze(-1).float()  # (B, T, 1)
            pooled = (x * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

        # Map the mean of the embeddings from the last layer to a single
        # logit output
        logits = self.fc(pooled).squeeze(-1)
        return logits

    def compute_attention_scores(self, input_ids, attention_mask=None):
        """
        Compute attention scores at each layer.

        Each layer's scores are computed from that layer's actual input, i.e.,
        the output of the previous block, mirroring the computation in forward().
        """
        B, T = input_ids.shape

        # Token embeddings. Project to (B, T, D)
        tok = self.token_emb(input_ids)

        # Positional embeddings. Project integers to (B, T, D)
        pos_ids = torch.arange(T, device=input_ids.device).unsqueeze(0).expand(B, T)
        pos = self.pos_emb(pos_ids)

        # Final embedding is token embedding plus positional embedding
        x = tok + pos

        # Pass data through attention blocks, recording each block's scores
        # before advancing x to the next layer
        scores = []
        for blk in self.blocks:
            scores.append(blk.attn.compute_attention_scores(x))
            x = blk(x, attn_mask=attention_mask)
        return scores
</code></pre></div></div>

<p>Now a function for training the model:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>def train_attention_classifier(
    model,
    train_loader,
    epochs=10,
    lr=1e-3,
    weight_decay=0.0,
    device=None,
):
    """
    Super simple binary classification training loop for the attention model.

    Assumes each batch is a dict with:
      - batch["input_ids"]       LongTensor (B, T)
      - batch["attention_mask"]  BoolTensor (B, T)
      - batch["labels"]          FloatTensor (B,) with values 0/1

    Uses BCEWithLogitsLoss (so model should output logits, not probabilities).
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.BCEWithLogitsLoss()

    for epoch in range(1, epochs + 1):
        # ---- Train ----
        model.train()
        total_loss = 0.0
        correct = 0
        total = 0

        for batch in train_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)  # float 0/1

            optimizer.zero_grad()

            logits = model(input_ids, attention_mask)  # (B,)
            loss = criterion(logits, labels)

            loss.backward()
            optimizer.step()

            total_loss += loss.item() * labels.size(0)

            # Accuracy
            preds = (torch.sigmoid(logits) &gt;= 0.5).long()
            correct += (preds == labels.long()).sum().item()
            total += labels.size(0)

        train_loss = total_loss / max(total, 1)
        train_acc = correct / max(total, 1)

        print(
          f"Epoch {epoch:02d}/{epochs} | "
          f"train loss {train_loss:.4f} acc {train_acc:.3f}"
        )
    return model
</code></pre></div></div>

<p>Finally, below is the code that ties it all together by loading the training data, partitioning it into training and test sets, and training the model:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Load data
with open('./training_set3.json', 'r') as f:
  dataset_1 = json.load(f)['data']

# Map tokens to IDs
sentences = [pair[0] for pair in dataset_1]
token_to_id, id_to_token = build_vocab(sentences)
print("Number of tokens: ", len(token_to_id))

# Training and test set indices
train_indices = train_pair_indices(len(dataset_1), 0.8)
train_index_set = set(train_indices)  # constant-time membership checks
test_indices = [
    i for i in range(len(dataset_1))
    if i not in train_index_set
]
assert len(set(train_indices) &amp; set(test_indices)) == 0

# Partition dataset into training and test
train_data = [dataset_1[i] for i in train_indices]
test_data = [dataset_1[i] for i in test_indices]
print("Number of training samples: ", len(train_data))
print("Number of test samples: ", len(test_data))

# Data loaders
train_loader_attention = create_dataloader_attention(
    train_data, token_to_id, batch_size=32, shuffle=True
)
test_loader_attention  = create_dataloader_attention(
    test_data, token_to_id, batch_size=64, shuffle=False
)

# Construct model
attention_model = TinyAttentionBinaryClassifier(
    vocab_size=len(token_to_id),
    num_layers=4,
    d_model=64,
    max_len=128,
    pad_id=token_to_id["&lt;pad&gt;"]
)

# Train model
attention_model = train_attention_classifier(
  attention_model,
  train_loader_attention,
  epochs=50,
  lr=1e-4
)
</code></pre></div></div>

<p>Lastly, below is code to evaluate the model:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>attention_model.eval()
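test_correct = 0
test_total = 0
# Evaluate on whichever device the model was trained on (CPU or GPU)
device = next(attention_model.parameters()).device
with torch.no_grad():
    for batch in test_loader_attention:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)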

        # Generate logits from model
        logits = attention_model(input_ids, attention_mask)

        # Convert to predictions (i.e., probability &gt;= 0.5)
        preds = (torch.sigmoid(logits) &gt;= 0.5).long()
        test_correct += (preds == labels.long()).sum().item()
        test_total += labels.size(0)

# Compute accuracy
test_acc = test_correct / max(test_total, 1)
print("Accuracy: ", test_acc)
</code></pre></div></div>]]></content><author><name>Matthew N. Bernstein</name></author><category term="tutorial" /><category term="machine learning" /><category term="deep learning" /><category term="transformers" /><category term="attention" /><summary type="html"><![CDATA[Attention is a type of layer in a neural network that is widely regarded to be one of the most important breakthroughs that enabled the development of modern AI systems and large language models. At its heart, attention is a mechanism for explicitly drawing relationships between items in a set. In natural language processing, the set being processed consists of words (or tokens), and attention enables the model to relate those words to one another even when they lie far away from each other in the body of text. In this blog post, we will step through the attention mechanism both mathematically and intuitively. We then present a minimal example of a neural network that uses attention to perform binary classification in a task that is not solvable using a naïve bag-of-words model.]]></summary></entry><entry><title type="html">Demystifying Euler’s number</title><link href="https://mbernste.github.io/posts/eulers_number/" rel="alternate" type="text/html" title="Demystifying Euler’s number" /><published>2025-01-26T00:00:00-08:00</published><updated>2025-01-26T00:00:00-08:00</updated><id>https://mbernste.github.io/posts/eulers_number</id><content type="html" xml:base="https://mbernste.github.io/posts/eulers_number/"><![CDATA[<p><em>Euler’s number $e := 2.71828\dots$ has, to me, always been a semi-mysterious number. While I understood many facts about $e$, I never felt that I truly understood what it really was – its core essence, so to speak. 
I believe that part of the reason for my confusion is that $e$ is often taught from two seemingly different perspectives: either it is introduced in the context of compound interest, or it is introduced in the context of calculus as the base of the exponential function whose derivative is itself. Thanks to an excellent explanation in Grant Sanderson’s <a href="https://www.youtube.com/watch?v=m2MIpDrF7Es">3Blue1Brown video</a>, I now better understand this constant and how these two perspectives relate to one another. In this blog post, I will attempt to describe, in my own words, my understanding of Euler’s number and expound on Sanderson’s explanation.</em></p>

<h2 id="introduction">Introduction</h2>

<p>Euler’s number $e := 2.71828\dots$ has, to me, always been a semi-mysterious number. While I understood many facts about $e$, I never felt that I truly understood what it really was – its core essence, so to speak. I believe that part of the reason for my confusion is that $e$ is often taught from two seemingly different perspectives. Specifically, $e$ is introduced in one of two ways:</p>

<ol>
  <li>As a constant that arises from deriving the formula for <a href="https://en.wikipedia.org/wiki/Compound_interest#Continuous_compounding">continuously compounded interest</a></li>
  <li>As the base of an <a href="https://en.wikipedia.org/wiki/Exponential_function">exponential function</a> whose derivative is the value of the function itself. That is, it is the constant, $e$, such that $\frac{de^x}{dx} = e^x$.</li>
</ol>

<p>For a long time, I had trouble seeing how these two perspectives of $e$ were related to one another. Moreover, I didn’t understand why $e$ appears so often in equations across science and mathematics. In the context of my <a href="https://mbernste.github.io/posts/understanding_3d/">prior blog post on “seeing concepts in 3D”</a>, I didn’t have a “3D image” in my mind for how these two different “2D projections” of $e$ formed a cohesive “3D” concept.</p>

<p>With the help of Grant Sanderson’s excellent <a href="https://www.youtube.com/watch?v=m2MIpDrF7Es">3Blue1Brown video</a>, I now feel that I have a much better understanding of its essence. In this blog post, I will attempt to describe, in my own words, my understanding of Euler’s number and expound on his explanation by attempting to tie together the two aforementioned perspectives of $e$.</p>

<p>To spoil the punchline, Euler’s number, at its most abstract and fundamental level, is a number that describes all <a href="https://en.wikipedia.org/wiki/Exponential_function">exponential functions</a> – that is, functions of the form $f(x) := a^x$. In a very loose analogy, $e$ is to exponential functions as $\pi$ is to circles. Both are constants that describe some fundamental characteristic of a family of foundational mathematical structures (i.e., circles and exponential functions, respectively). With this fundamental understanding in hand, we can tie together the two aforementioned perspectives of $e$.</p>

<h2 id="e-describes-the-essence-of-exponential-functions">$e$ describes the essence of exponential functions</h2>

<p>At its most abstract and fundamental level, $e$ is a number that can be used to naturally represent all exponential functions. To arrive at this understanding, let’s first discuss what it means for a function to be an exponential function.</p>

<p>To review, exponential functions are functions of the form:</p>

\[f(x) := a^x\]

<p>for some positive constant $a$. When $a &gt; 1$, exponential functions not only grow, but their <em>growth</em> also grows. When plotted, they look like the following:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/exponential.png" alt="drawing" width="400" /></center>

<p>The key characteristic of exponential functions – the very characteristic that <em>defines</em> exponential functions – is that their rate of growth scales with the current value of the function. Stated more rigorously, <strong>an exponential’s derivative <em>is proportional</em> to the value of the function itself.</strong> That is:</p>

\[\frac{da^x}{dx} = Ka^x\]

<p>where $K$ is some constant that is determined by $a$.</p>

<p>That is, for all values of $x$, the derivative of $a^x$ is simply $a^x$ itself multiplied by some constant $K$. One can gain intuition for this fact by simply observing the function’s curve: the larger $a^x$ is, the steeper the curve becomes:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/exponential_proportional_to_value.png" alt="drawing" width="400" /></center>
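<p>As a quick numerical sanity check of this claim (a sketch; the base $a = 2$, the step size <code>h</code>, and the sample points are arbitrary choices, not part of the derivation), we can approximate the derivative of $2^x$ with a small forward difference and divide it by $2^x$. If the derivative is proportional to the function, this ratio should be the same at every $x$:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>def growth_ratio(a, x, h=1e-6):
    # Forward-difference estimate of the derivative of a**x at x,
    # divided by the value a**x itself
    derivative = (a ** (x + h) - a ** x) / h
    return derivative / a ** x


for x in [0.0, 1.0, 2.0, 5.0]:
    print(x, round(growth_ratio(2.0, x), 4))
</code></pre></div></div>

<p>All four ratios come out at about 0.6931, even though $2^x$ itself ranges from 1 to 32 over these points.</p>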

<p>Let’s prove this fact more rigorously by starting with the definition of the derivative of $a^x$:</p>

\[\begin{align*}\frac{da^x}{dx} &amp;= \lim_{h \rightarrow 0} \frac{a^{x+h} - a^x}{h} \\ &amp;= \lim_{h \rightarrow 0} \frac{a^xa^h - a^x}{h} \\ &amp;= a^x \underbrace{\lim_{h \rightarrow 0} \frac{a^h - 1}{h}}_{K} \end{align*}\]

<p>Note that the derivative of $a^x$ is simply $a^x$ scaled by a constant that does not depend on $x$:</p>

\[K := \lim_{h \rightarrow 0} \frac{a^{h} - 1}{h}\]

<p>Moreover, we see that this constant is determined by the value of $a$ – that is, it is a function of $a$. We can make the dependence of $K$ on $a$ explicit by writing this constant as a function of $a$:</p>

\[k(a) := \lim_{h \rightarrow 0} \frac{a^{h} - 1}{h}\]
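<p>To get a feel for $k(a)$, we can approximate the limit by plugging in a small but finite $h$. This is only a sketch: $h = 10^{-7}$ is an arbitrary choice, and any finite $h$ only approximates the limit.</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>def k(a, h=1e-7):
    # Finite-h stand-in for k(a), the limit of (a**h - 1) / h as h goes to 0
    return (a ** h - 1) / h


for a in [2.0, 2.5, 3.0]:
    print(a, round(k(a), 4))
</code></pre></div></div>

<p>The estimates grow with $a$ (about 0.6931, 0.9163, and 1.0986) and pass 1 somewhere between $a = 2.5$ and $a = 3$.</p>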

<h3 id="defining-eulers-number">Defining Euler’s number</h3>

<p>Given our newfound understanding of exponential functions as functions whose derivative is proportional to themselves, a natural question that follows is: What exponential function, $a^x$, yields a constant of 1? That is, for what value of $a$ do we have $k(a) = 1$? It may not surprise you to learn that it’s Euler’s number!</p>

<p>That is, $e$ is the value for $a$ that satisfies the following equation:</p>

\[1 = k(a) = \lim_{h \rightarrow 0} \frac{a^{h} - 1}{h}\]

<p>Said differently, Euler’s number is the base of the exponential function for which the derivative of that exponential function is the exponential function itself:</p>

\[\frac{de^x}{dx} = e^x\]

<p>In a sense, $e$ defines the “base” exponential function; by “base” I mean the exponential function whose rate of change is itself (i.e., whose constant of proportionality is 1).</p>

<p>Of course, this fact does not directly tell us $e$’s value. To calculate it, we can derive an alternative formula for $f(x) := e^x$ that does not contain $e$ and then plug $x := 1$ into this formula. We do so by first noting that $f'(x) = f(x)$ is a first-order differential equation. Coupling this differential equation with the fact that $f(0) = e^0 = 1$, we see that this is an <a href="https://en.wikipedia.org/wiki/Initial_value_problem#:~:text=In%20multivariable%20calculus%2C%20an%20initial,given%20point%20in%20the%20domain.">initial value problem</a> that can be solved using the <a href="https://en.wikipedia.org/wiki/Euler_method">Euler Method</a>. If we do so, and let the step size shrink to zero, we find that</p>

\[f(x) = e^x = \lim_{n \rightarrow \infty} \left(1 + \frac{x}{n}\right)^n\]

<p>See Theorem 1 in the Appendix to this post for the complete derivation.</p>

<p>Plugging in $x := 1$, we find that $e$ can be expressed as a limit that enables us to compute numerical approximations of its value:</p>

\[e = \lim_{n \rightarrow \infty} \left(1 + \frac{1}{n}\right)^n\]

<p>We can calculate ever closer approximations to $e$ by simply plugging in larger and larger values for $n$. If we do so, we find that $e \approx 2.71828$.</p>
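<p>A few lines of Python make this convergence concrete (a sketch; the particular values of $n$ are arbitrary). As a cross-check, we also plug the resulting approximation back into a finite-$h$ version of $k(a)$ to confirm that it gives a constant of proportionality close to 1:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import math


def e_approx(n):
    # (1 + 1/n)**n, which approaches e as n grows
    return (1 + 1 / n) ** n


def k(a, h=1e-7):
    # Finite-h stand-in for k(a), the limit of (a**h - 1) / h as h goes to 0
    return (a ** h - 1) / h


for n in [1, 10, 100, 1000, 10 ** 6]:
    approx = e_approx(n)
    print(f"n={n}: {approx:.8f} (error vs math.e: {math.e - approx:.2e})")

# Cross-check: the approximation should give a proportionality constant near 1
print(k(e_approx(10 ** 6)))
</code></pre></div></div>

<p>The error shrinks roughly in proportion to $1/n$, so this limit is a conceptually clean route to $e$ rather than a fast one.</p>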

<h3 id="all-exponential-functions-can-be-expressed-in-a-more-intuitive-way-with-e">All exponential functions can be expressed in a more intuitive way with $e$</h3>

<p>This fact about $e$ – that it is the base of the exponential function whose derivative is itself – does not quite explain why $e$ is so ubiquitous. Why do we see so many equations involving $e$?</p>

<p>The answer is that $e$ can be used to describe <em>all</em> exponential functions in a more intuitive and “natural” way. Thus, <em>any</em> equation that relates to exponential functions will likely involve $e$! It is sort of similar to how any equation that involves rotations or circles will usually involve $\pi$.</p>

<p>To see why, say we have some exponential function $a^x$. As we showed above, the derivative of $a^x$ is proportional to $a^x$ with some constant of proportionality given by $k(a)$. This value, $k(a)$, would need to be computed. Is there a way to express $a^x$ using $k(a)$ so that this value is obvious from the expression? Yes! It is simply,</p>

\[a^x = e^{k(a) x}\]

<p>To see why, first note that, writing $\log$ for the natural logarithm (i.e., the logarithm with base $e$),</p>

\[a^x = e^{\log \left(a^x\right)} = e^{x \log a}\]

<p>Here we see that $a^x$ has been re-written as $e^{x \log a}$ where $\log a$ is simply a constant. It turns out that this constant, $\log a$, is really just $k(a)$. We know this because if we take the derivative of $e^{x \log a}$, we get</p>

\[\begin{align*}\frac{d a^x}{dx} &amp;= \frac{d e^{x \log a}}{dx} \\ &amp;= (\log a) e^{x \log a} &amp;&amp; \text{Chain rule} \\ &amp;= (\log a) a^x   \end{align*}\]

<p>That is, $\log a$ is that very constant of proportionality that we defined $k(a)$ to be!</p>

<p>Because every value for $a$ is associated with a unique constant $k(a)$, we can express all exponential functions using the constant $k(a)$ instead of $a$. That is, by using $e^{k(a) x}$ instead of $a^x$. Arguably, this form makes the exponential easier to interpret: Whenever you come upon an exponential function, $f(x) := e^{Kx}$, the rate of change of $f(x)$ at $x$ is simply the value of this function scaled by the constant $K$.</p>
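<p>We can also check numerically that the finite-difference estimate of $k(a)$ matches $\log a$ (Python’s <code>math.log</code> computes the natural logarithm), and that rewriting $a^x$ as $e^{k(a) x}$ leaves its values unchanged. This is a sketch; the bases and the exponent are arbitrary test values:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import math


def k(a, h=1e-7):
    # Finite-h stand-in for k(a), the limit of (a**h - 1) / h as h goes to 0
    return (a ** h - 1) / h


x = 1.7  # arbitrary exponent for the rewrite check
for a in [0.5, 2.0, 10.0]:
    print(
        f"a={a}: k(a)~{k(a):.6f}, log(a)={math.log(a):.6f}, "
        f"a**x={a ** x:.6f}, e**(log(a) * x)={math.exp(math.log(a) * x):.6f}"
    )
</code></pre></div></div>

<p>For $a = 0.5$ the constant is negative, which corresponds to exponential decay rather than growth.</p>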

<p>In a very loose analogy, $e$ is to exponential curves as $\pi$ is to circles. Where $\pi$ is a constant that describes all circles, $e$ describes all exponential functions. Specifically, $e$ describes a sort of “origin exponential function” – that is, the exponential function in reference to which all other exponential functions can be described.</p>

<h2 id="e-arises-in-a-formula-for-continously-compounded-interest">$e$ arises in a formula for continuously compounded interest</h2>

<p>Euler’s number was actually first discovered by <a href="https://en.wikipedia.org/wiki/Jacob_Bernoulli">Jacob Bernoulli</a> in the context of <a href="https://en.wikipedia.org/wiki/Compound_interest">compound interest</a>. In fact, $e$ is often introduced to students in this way, and it is only discussed in relation to exponential functions in more advanced treatments of the topic. Let us now approach $e$ from the perspective of compound interest. We will later connect this perspective back to exponential functions.</p>

<p>Say we have $P$ dollars, as principal, that we lend out at an interest rate $r$ over some unit of time (e.g., one year). After this amount of time, we would have the following amount of money:</p>

\[\begin{align*}\text{Total} &amp;:= P + rP \\ &amp;= P(1+r)  \end{align*}\]

<p>Note that because the interest wasn’t paid out until the end of the prescribed time, the interest itself was not given the opportunity to earn any money. One way to fix this would be to have that interest “compound” – that is, be paid out at set increments and added to the principal.</p>

<p>Let’s now say that interest compounds every half unit of time (e.g., every six months instead of every year). Then at the halfway point, when the interest compounds for the first time, the $P$ dollars would earn $\frac{r}{2}P$, and this $\frac{r}{2}P$ could then earn interest for the remaining half of the time. We can derive the total amount of money we end up with as follows. Let $P_1$ be the money we have after interest compounds the first time. It grows at a rate of $r / 2$ because this first compounding period covers only half of the time interval:</p>

\[\begin{align*}P_1 &amp;:= P + \frac{r}{2}P \\ &amp;= P\left(1+\frac{r}{2}\right)  \end{align*}\]

<p>We can now calculate the total by treating $P_1$ as the new “principal” just after the interest compounds for the first time. Over the second half of the time interval, it earns interest at a rate of $r / 2$:</p>

\[\begin{align*}P_2 &amp;:= P_1 + \frac{r}{2}P_1 \\ &amp;= P_1\left(1+\frac{r}{2}\right)  \end{align*}\]

<p>Plugging in $P_1$, we get the final total as:</p>

\[\begin{align*} \text{Total} &amp;:= P_1\left( 1+\frac{r}{2} \right) \\ &amp;= P\left(1+\frac{r}{2}\right)\left(1+\frac{r}{2}\right)  \\ &amp;= P\left(1+ \frac{r}{2}\right)^{2} \end{align*}\]

<p>We can play this game again, this time compounding three times instead of twice. In fact, we can increase the number of times that the money compounds to any arbitrary number, $n$, and we will find that the total we end up with at the end will be:</p>

\[\text{Total} := P\left(1+\frac{r}{n}\right)^n\]

<p>Before we move on, let’s generalize this formula slightly and consider the total we would have after $t$ units of time. The derivation will be similar to how we derived the formula above that takes into account the number of times interest compounds per unit of time, $n$. We’ll start by considering $t = 2$. Let $P_1$ be the amount of money we have after $t = 1$:</p>

\[\begin{align*}P_1 := P\left(1+\frac{r}{n}\right)^n  \end{align*}\]

<p>Similar to before, we can calculate the total by treating $P_1$ as the “principal” starting after the first unit of time and plugging in $P_1$:</p>

\[\begin{align*}\text{Total} &amp;:= P_1 \left(1+\frac{r}{n}\right)^n \\ &amp;= P\left(1+\frac{r}{n}\right)^n \left(1+\frac{r}{n}\right)^n \\ &amp;= P\left(1+\frac{r}{n}\right)^{2n}\end{align*}\]

<p>We can perform this derivation for any value of $t$, and we would find that</p>

\[\text{Total} := P\left(1+\frac{r}{n}\right)^{nt}\]

<p>Now that we’ve generalized our formula to take into account $t$, let’s ask the question: what if we compound the interest <em>every possible instant</em>? That is, instead of compounding $n$ times per unit of time, it compounds an infinite number of times. We can derive this formula by taking the limit of the above formula as $n$ approaches infinity:</p>

\[\begin{align*}\text{Total} &amp;:= \lim_{n \rightarrow \infty} P\left(1+\frac{r}{n}\right)^{nt} \\ &amp;= P \left[\lim_{n \rightarrow \infty} \left(1+\frac{r}{n}\right)^{n}\right]^{t} \\ &amp;= Pe^{rt} &amp;&amp; \text{Theorem 1} \end{align*}\]

<p>This is simply an exponential function $e^{rt}$ scaled by $P$!</p>
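<p>To tie these formulas together numerically, the following sketch (with arbitrary values $P = 100$, $r = 0.05$, and $t = 3$; the helper name <code>compound_step_by_step</code> is just illustrative) compounds the balance one period at a time, compares the result to the closed form $P\left(1+\frac{r}{n}\right)^{nt}$, and shows the totals approaching $Pe^{rt}$ as $n$ grows. It assumes $t$ is a whole number of time units so that the number of periods, $nt$, is an integer:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import math


def compound_step_by_step(P, r, n, t):
    # Compound n times per unit of time for t (whole) units of time:
    # each period, the balance earns r / n on everything accumulated so far
    balance = P
    for _ in range(n * t):
        balance += (r / n) * balance
    return balance


P, r, t = 100.0, 0.05, 3
continuous = P * math.exp(r * t)

for n in [1, 2, 12, 365, 10000]:
    stepwise = compound_step_by_step(P, r, n, t)
    closed_form = P * (1 + r / n) ** (n * t)
    print(f"n={n:6d}  step-by-step={stepwise:.6f}  closed form={closed_form:.6f}")

print(f"continuous (P * e^(r t)) = {continuous:.6f}")
</code></pre></div></div>

<p>The step-by-step and closed-form totals agree, and they increase with $n$ toward the continuously compounded total from below.</p>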

<h2 id="connecting-the-dots">Connecting the dots</h2>

<p>So far, we have derived $e$ from two alternative perspectives:</p>

<ol>
  <li>As being the base of the exponential function whose derivative is itself</li>
  <li>As a constant used in a formula to compute continuously compounding interest</li>
</ol>

<p>Let’s wrap up this discussion by describing explicitly how these two ideas are connected. Let $F(t)$ be a function that tells us how much money we have at time $t$ during a loan whose interest compounds continuously. Furthermore, let’s pretend for a moment that we don’t know the equation $F(t) := Pe^{rt}$.</p>

<p>Intuitively, if interest is compounding continuously, then the instantaneous growth of $F(t)$ at $t$ should be given by the interest rate multiplied by however much money we currently have, which is given by $F(t)$. That is, we would naturally expect the instantaneous growth rate of our money to be $rF(t)$. Stated mathematically,</p>

\[\frac{dF(t)}{dt} = r F(t)\]

<p>And indeed we see that if $F(t) := Pe^{rt}$, then we can apply the <a href="https://en.wikipedia.org/wiki/Chain_rule">chain rule</a> to find that</p>

\[\begin{align*} \frac{dF(t)}{dt} &amp;= \frac{d\left(Pe^{rt}\right)}{dt} \\ &amp;=rPe^{rt} \\ &amp;= r F(t) \end{align*}\]

<p>This only works because of the fact that $\frac{de^x}{dx} = e^x$! The full picture is illustrated in the following diagram:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/Eulers_number_concept_map.png" alt="drawing" width="800" /></center>
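<p>One way to see this connection in code (a minimal sketch; $P$, $r$, $t$, and the step counts are arbitrary illustrative values): solving $\frac{dF(t)}{dt} = rF(t)$ with the Euler method, using a step size of $\Delta t = 1/n$, multiplies the balance by $1 + r/n$ at every step, which is exactly one period of compounding $n$ times per unit of time. Shrinking the step size therefore moves us from discrete compounding toward $Pe^{rt}$:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import math


def euler_solve_growth(P, r, t, steps_per_unit):
    # Euler's method for dF/dt = r * F with F(0) = P and step size 1 / steps_per_unit
    dt = 1.0 / steps_per_unit
    F = P
    for _ in range(int(t * steps_per_unit)):
        F = F + dt * r * F  # F(t + dt) is approximately F(t) + dt * F'(t)
    return F


P, r, t = 100.0, 0.05, 3
for n in [1, 12, 365, 100000]:
    print(n, euler_solve_growth(P, r, t, n), P * (1 + r / n) ** (n * t))
print("P * exp(r * t) =", P * math.exp(r * t))
</code></pre></div></div>

<p>With one step per unit of time, the Euler solution is just compounding once per unit of time; with many small steps, it approaches the continuously compounded total.</p>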

<h2 id="further-reading">Further Reading</h2>
<ul>
  <li><a href="https://www.youtube.com/watch?v=m2MIpDrF7Es">This YouTube video by 3Blue1Brown</a></li>
  <li><a href="https://betterexplained.com/articles/an-intuitive-guide-to-exponential-functions-e/">This article by Better Explained</a></li>
</ul>

<h2 id="appendix">Appendix</h2>

<p><span style="color:#0060C6"><strong>Theorem 1 (Value of $e$):</strong> The value for $e$ is given by the following limit:</span></p>

<center><span style="color:#0060C6">$$e = \lim_{n \rightarrow \infty} \left(1 + \frac{1}{n}\right)^n$$</span></center>

<p><strong>Proof:</strong></p>

<p>Let</p>

\[f(x) := e^{x}\]

<p>and</p>

\[f'(x) = \frac{df(x)}{dx}\]

<p>We know the following facts about $f(x)$:</p>

\[\begin{align*}f(0) &amp;= e^0 = 1 \\ f'(x) &amp;= f(x) \end{align*}\]

<p>Note that $f'(x) = f(x)$ is a first-order differential equation and $f(0) = 1$ is an initial condition. Thus, given an arbitrary value for $x$, we can solve for $f(x)$ by solving this <a href="https://en.wikipedia.org/wiki/Initial_value_problem#:~:text=In%20multivariable%20calculus%2C%20an%20initial,given%20point%20in%20the%20domain.">initial value problem</a>. We can do so using the <a href="https://en.wikipedia.org/wiki/Euler_method">Euler Method</a>.</p>

<p>In order to solve for $f(x)$, we will use Euler’s method with increments of $\Delta t := x/n$ for some number of increments $n$.</p>

<p>We first note that for an arbitrary value of $t$, we can approximate $f(t + \Delta t)$ via</p>

\[f(t + \Delta t) \approx f(t) + \Delta t f'(t)\]

<p>Because $f'(t) = f(t)$, we have</p>

\[\begin{align*}f(t + \Delta t) &amp;\approx f(t) + \Delta t f'(t)  \\ &amp;= f(t) + \Delta t f(t) \\ &amp;= f(t)(1 + \Delta t) \end{align*}\]

<p>Using this formula, we can approximate $f(x)$ by starting at $t := 0$ and stepping toward $t = x$ in increments of $\Delta t := \frac{x}{n}$. For the first step, where $t := 0$, we have</p>

\[\begin{align*}f(0 + \Delta t) &amp;\approx f(0) (1 + \Delta t) \\  &amp;= 1 + \frac{x}{n}\end{align*}\]

<p>Taking the second step, we have</p>

\[\begin{align*} f\left(\frac{x}{n} + \Delta t\right) &amp;\approx f\left(\frac{x}{n} \right) (1 + \Delta t) \\  &amp;\approx \left(1 + \frac{x}{n}\right) \left(1 + \frac{x}{n}\right) \\ &amp;= \left(1 + \frac{x}{n}\right)^2 \end{align*}\]

<p>Extrapolating this all the way to $n$ steps, arriving at $f(x)$, we see that</p>

\[\begin{align*}f(x) &amp;= f\left( n\frac{x}{n} \right) \\ &amp;= f\left( (n-1)\frac{x}{n} + \frac{x}{n} \right) \\ &amp;\approx  \left( 1 + \frac{x}{n} \right)^{n-1} \left( 1 + \frac{x}{n} \right) \\ &amp;=  \left( 1 + \frac{x}{n} \right)^{n}  \end{align*}\]
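<p>The stepping procedure above is short enough to write out directly. This sketch (the value $x = 1$ and the step counts are arbitrary choices) runs $n$ Euler steps for $f'(t) = f(t)$ with $f(0) = 1$, and prints the result next to the closed form $\left(1 + \frac{x}{n}\right)^n$ and Python’s <code>math.exp(x)</code>:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import math


def euler_exp(x, n):
    # Solve f'(t) = f(t), f(0) = 1 from t = 0 to t = x using n Euler steps
    dt = x / n
    f = 1.0
    for _ in range(n):
        f = f + dt * f  # f(t + dt) is approximately f(t) + dt * f'(t), with f'(t) = f(t)
    return f


x = 1.0
for n in [1, 2, 10, 1000]:
    print(n, euler_exp(x, n), (1 + x / n) ** n, math.exp(x))
</code></pre></div></div>

<p>The Euler result and $\left(1 + \frac{x}{n}\right)^n$ agree up to floating-point rounding for every $n$, and both approach $e^x$ as $n$ grows.</p>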

<p>Finally, to derive $e$ itself, we plug in $x = 1$ and see that</p>

\[f(1) = e^1 = e \approx \left( 1 + \frac{1}{n} \right)^{n}\]

<p>Note that, as $n \rightarrow \infty$, this approximation converges to the true value of $e$, and thus,</p>

\[e = \lim_{n \rightarrow \infty} \left( 1 + \frac{1}{n} \right)^{n}\]

<p>$\square$</p>]]></content><author><name>Matthew N. Bernstein</name></author><category term="tutorial" /><category term="mathematics" /><summary type="html"><![CDATA[Euler’s number $e := 2.71828\dots$ has, to me, always been a semi-mysterious number. While I understood many facts about $e$, I never felt that I truly understood what it really was – its core essence, so to speak. I believe that part of the reason for my confusion is that $e$ is often taught from two seemingly different perspectives: either it is introduced in the context of compound interest, or it is introduced in the context of calculus as the base of the exponential function whose derivative is itself. Thanks to an excellent explanation in Grant Sanderson’s 3Blue1Brown video, I now better understand this constant and how these two perspectives relate to one another. In this blog post, I will attempt to describe, in my own words, my understanding of Euler’s number and expound on Sanderson’s explanation.]]></summary></entry><entry><title type="html">Reproducing kernel Hilbert spaces and the kernel trick</title><link href="https://mbernste.github.io/posts/rkhs/" rel="alternate" type="text/html" title="Reproducing kernel Hilbert spaces and the kernel trick" /><published>2024-12-14T00:00:00-08:00</published><updated>2024-12-14T00:00:00-08:00</updated><id>https://mbernste.github.io/posts/rkhs</id><content type="html" xml:base="https://mbernste.github.io/posts/rkhs/"><![CDATA[<p><em>If you’re a practitioner of machine learning, then there is little doubt you have seen or used an algorithm that falls into the general category of kernel methods. The premier example of such methods is the support vector machine. When introduced to these algorithms, one is taught that one must provide the algorithm with a kernel function that, intuitively, computes a degree of “similarity” between the objects you are classifying. 
In practice, one can get pretty far with only this understanding; however, to understand these methods more deeply, one must understand a mathematical object called a reproducing kernel Hilbert space (RKHS). In this post, I will explain the definition of a RKHS and exactly how such spaces produce the kernels used in kernel methods, thereby laying a rigorous foundation for a deeper understanding of these methods.</em></p>

<h2 id="introduction">Introduction</h2>

<p>If you’re a practitioner of machine learning, then there is little doubt you have seen or used an algorithm that falls into the general category of <a href="https://en.wikipedia.org/wiki/Kernel_method">kernel methods</a>. The premier example of such methods is the <a href="https://en.wikipedia.org/wiki/Support-vector_machine">support vector machine</a>. When introduced to these algorithms, one is taught that one must provide the algorithm with a <strong>kernel function</strong> that, intuitively, computes a degree of “similarity” between the objects you are classifying (e.g., images, text documents, or <a href="https://mbernste.github.io/posts/rna_seq_basics/">gene expression profiles</a>).</p>

<p>A slightly deeper introduction will explain that what a kernel is really doing is projecting the given objects into some (possibly infinite-dimensional) <a href="https://mbernste.github.io/posts/vector_spaces/">vector space</a> and then computing an <a href="https://en.wikipedia.org/wiki/Inner_product_space">inner product</a> between those vectors. That is, if $\mathcal{X}$ is the set of objects we are classifying, then a kernel, $K$, is a function:</p>

\[K: \mathcal{X} \times \mathcal{X} \rightarrow \mathbb{R}\]

<p>for which</p>

\[K(x_1, x_2) = \langle \phi(x_1), \phi(x_2)\rangle\]

<p>where $\phi(x)$ is a vector associated with object $x$ in some <a href="https://mbernste.github.io/posts/vector_spaces/">vector space</a>. This implicit calculation of an inner product between $\phi(x_1)$ and $\phi(x_2)$ is known as the <strong>kernel trick</strong> and it lies at the core of kernel methods.</p>
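<p>For a concrete (and standard) example of the kernel trick, take $\mathcal{X} = \mathbb{R}^2$ and the kernel $K(\boldsymbol{u}, \boldsymbol{v}) := (\boldsymbol{u}^T \boldsymbol{v})^2$. Expanding the square shows that this kernel equals the dot product of the explicit feature vectors $\phi(\boldsymbol{u}) := (u_1^2, \sqrt{2}\, u_1 u_2, u_2^2)$, so it computes an inner product in $\mathbb{R}^3$ without ever forming $\phi$. The sketch below, with arbitrary example points, checks this numerically:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import math


def poly_kernel(u, v):
    # K(u, v) = (u . v)^2, computed directly in the original 2-D space
    return (u[0] * v[0] + u[1] * v[1]) ** 2


def phi(u):
    # Explicit feature map into R^3 for this particular kernel
    return (u[0] ** 2, math.sqrt(2) * u[0] * u[1], u[1] ** 2)


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


u, v = (1.0, 2.0), (3.0, 4.0)
print(poly_kernel(u, v), dot(phi(u), phi(v)))
</code></pre></div></div>

<p>Both numbers agree (up to floating-point rounding), but the kernel only ever touches the two-dimensional inputs.</p>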

<p>To use a kernel method in practice, one can get pretty far with only this understanding; however, it is incomplete and, to me, unsatisfying. What space are these objects projected into? How is a kernel function derived?</p>

<p>To answer these questions, one must understand a mathematical object called a <strong>reproducing kernel Hilbert space</strong> (RKHS). This object was a bit challenging for me to intuit, so in this post, I will explain the definition of a RKHS and exactly how such spaces produce the kernels used in kernel methods, thereby laying a rigorous foundation for a deeper understanding of the kernel trick.</p>

<h2 id="the-reproducing-kernel-hilbert-space">The reproducing kernel Hilbert space</h2>

<p>So what is a RKHS? First, we note from the name, “reproducing kernel Hilbert space,” that a RKHS is, obviously, some kind of Hilbert space. To review, a <a href="https://en.wikipedia.org/wiki/Hilbert_space">Hilbert space</a> is a vector space that 1) is equipped with an inner product, and 2) is a <a href="https://en.wikipedia.org/wiki/Complete_metric_space">complete space</a> (that is, every Cauchy sequence of vectors in the space converges to a vector that is also in the space; informally, the space has no “holes”). Specifically,</p>

<p><span style="color:#0060C6"><strong>Definition 1 (Hilbert space):</strong> A <strong>Hilbert space</strong> is a complete inner-product space $(\mathcal{H}, \mathcal{F}, \langle ., . \rangle)$ where $\mathcal{H}$ is a set of vectors, $\mathcal{F}$ is a field of scalars, and $\langle ., . \rangle$ is an inner-product function.</span></p>

<p>Euclidean vector spaces are an example of a Hilbert space. In the case of a Euclidean space, the vectors are the real-valued coordinate vectors in $\mathbb{R}^n$, the scalars are the real numbers $\mathbb{R}$, and an inner product can be defined to be the <a href="https://mbernste.github.io/posts/dot_product/">dot product</a> $\langle \boldsymbol{x}, \boldsymbol{y} \rangle := \boldsymbol{x}^T \boldsymbol{y}$. Thus, Hilbert spaces are <em>generalizations</em> of Euclidean vector spaces. One can think of a Hilbert space as “behaving” like the familiar Euclidean space.</p>

<p>Next, we note that the vectors that comprise an RKHS are <em>functions</em>. Recall that a <a href="https://mbernste.github.io/posts/vector_spaces/">vector space</a> is a very general concept that extends the usual coordinate vectors of a Euclidean space and can consist of functions. As a quick review, just as one can add two arrows representing Euclidean vectors “tip to tail” to form a new Euclidean vector, so too can one add two “function vectors” to form a third:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/RKHS_functions_as_vectors.png" alt="drawing" width="550" /></center>

<p><br /></p>

<p>Similarly, just as one can scale a Euclidean vector by “stretching” it, so too can one scale a function:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/RKHS_functions_as_vectors_scaling.png" alt="drawing" width="550" /></center>

<p><br /></p>
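<p>To make these two operations concrete, here is a minimal Python sketch (an illustration, not code from the original post) that represents functions as Python callables. The helper names <code>add</code> and <code>scale</code> are hypothetical, and evaluating on a grid only spot-checks the operations at finitely many points.</p>

```python
import numpy as np

# Represent "function vectors" as ordinary Python callables f: R -> R.
def add(f, g):
    """Vector addition of functions: (f + g)(x) = f(x) + g(x)."""
    return lambda x: f(x) + g(x)

def scale(c, f):
    """Scalar multiplication of a function: (c f)(x) = c * f(x)."""
    return lambda x: c * f(x)

f = np.sin
g = np.cos
h = add(f, g)        # a third "vector" in the same space
s = scale(2.0, f)    # f "stretched" by a factor of 2

xs = np.linspace(-np.pi, np.pi, 5)
print(h(xs))   # elementwise equal to np.sin(xs) + np.cos(xs)
print(s(xs))   # elementwise equal to 2 * np.sin(xs)
```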

<p>In fact, all of the <a href="https://mbernste.github.io/posts/vector_spaces/">axioms of a vector space</a> can be shown to hold for certain sets of functions. Moreover, one can even form a Hilbert space of functions so long as one can devise an inner product between the functions and the resulting space is complete. As its name suggests, an RKHS is a particular type of Hilbert space of functions.</p>

<p>Now, a RKHS is not just <em>any</em> Hilbert space of functions, but rather it is a Hilbert space of functions with a particular property: <strong>all evaluation functionals are continuous</strong>:</p>

<p><span style="color:#0060C6"><strong>Definition 2 (Reproducing kernel Hilbert space):</strong> A Hilbert space $(\mathcal{H}, \mathcal{F}, \langle ., . \rangle)$ of functions $f : \mathcal{X} \rightarrow \mathbb{R}$ is a <strong>reproducing kernel Hilbert space</strong> if, given any $x \in \mathcal{X}$, the evaluation functional for $x$, $\delta_x(f) := f(x)$ (where $f \in \mathcal{H}$), is continuous.</span></p>

<p>Evaluation functional? What on earth is that? Let’s break this definition down.</p>

<p>First, let $\mathcal{H}$ be the set of vectors in our Hilbert space, which in our case consists of functions. Each function $f \in \mathcal{H}$ maps elements in some set $\mathcal{X}$ (for example, a <a href="https://en.wikipedia.org/wiki/Metric_space">metric space</a>) to the real numbers $\mathbb{R}$. That is,</p>

\[f : \mathcal{X} \rightarrow \mathbb{R}\]

<p>Recall from a <a href="https://mbernste.github.io/posts/functionals/">previous blog post</a> that a <strong>functional</strong> is simply a function that accepts as input another function and returns a scalar. In our case, a functional, $\ell$, on $\mathcal{H}$ would map each function in $\mathcal{H}$ to a scalar:</p>

\[\ell: \mathcal{H} \rightarrow \mathbb{R}\]

<p>Now, what is the evaluation functional? For any given $x \in \mathcal{X}$, we define the <strong>evaluation functional</strong>, $\delta_x$, to be simply,</p>

\[\delta_x(f) := f(x)\]

<p>This is a very simple definition. The evaluation functional $\delta_x$ simply takes as input a function $f$, evaluates it at $x$, and returns the value $f(x)$!</p>
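<p>The evaluation functional is easy to express as a higher-order function. In this minimal Python sketch, <code>evaluation_functional</code> is a hypothetical helper that returns $\delta_x$ for a fixed $x$; the sketch also spot-checks that $\delta_x$ is linear, a fact proved in Theorem 2 in the appendix.</p>

```python
import numpy as np

def evaluation_functional(x):
    """Return delta_x, the functional that maps a function f to the scalar f(x)."""
    def delta_x(f):
        return f(x)
    return delta_x

delta = evaluation_functional(0.5)   # fix the point x = 0.5
f = lambda t: t ** 2
g = np.exp

print(delta(f))   # f(0.5) = 0.25

# delta_x is linear: delta_x(a f + b g) = a delta_x(f) + b delta_x(g)
a, b = 3.0, -2.0
combo = lambda t: a * f(t) + b * g(t)
assert np.isclose(delta(combo), a * delta(f) + b * delta(g))
```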

<p>All of these sets can get pretty confusing, so here is a schematic for keeping them straight. Note that we have a set $\mathcal{H}$ of functions, a set $\mathcal{X}$ of “objects”, and a set $\mathbb{R}$ of scalars:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/RKHS_sets_schematic.png" alt="drawing" width="450" /></center>

<p><br /></p>

<p>Also note that each function $f \in \mathcal{H}$ defines a unique mapping between $\mathcal{X}$ and $\mathbb{R}$. Two such mappings are illustrated below:</p>

<p><br /></p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/RKHS_schematic_example_functions.png" alt="drawing" width="700" /></center>

<p><br /></p>

<p>In a similar vein, each $x \in \mathcal{X}$ can <em>also</em> define a mapping between $\mathcal{H}$ and $\mathbb{R}$ via the evaluation functionals. Specifically, we pick a value $x \in \mathcal{X}$ and this defines an evaluation functional which maps each function to the scalars:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/RKHS_schematic_example_function_w_functional.png" alt="drawing" width="700" /></center>

<p><br /></p>

<p>Specifically, each value $x \in \mathcal{X}$ defines a unique evaluation functional mapping functions in $\mathcal{H}$ to $\mathbb{R}$. Two such mappings are shown below:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/RKHS_schematic_example_functionals.png" alt="drawing" width="700" /></center>

<p><br /></p>

<p>Finally, let’s revisit the fundamental property of an RKHS that differentiates it from an arbitrary Hilbert space of functions: for any given $x \in \mathcal{X}$, the evaluation functional $\delta_x$ is continuous. Let’s illustrate this definition schematically. Here we are illustrating the evaluation functional for some fixed $x$. Unlike the previous schematics depicting $\mathcal{H}$ as a set of discrete points (each point representing a function), the schematic below depicts $\mathcal{H}$ as a plane of infinitely dense functions with a smooth surface above it representing $\delta_x$. The smoothness of this surface is meant to emphasize the continuity of $\delta_x$:</p>

<p><br /></p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/AxiomRKHS.png" alt="drawing" width="500" /></center>

<p><br /></p>

<p>Note, $\mathcal{H}$ does not actually form a plane (it is not $\mathbb{R}^2$), but this analogy of depicting $\mathcal{H}$ as a plane emphasizes that the functions/vectors in $\mathcal{H}$ are infinitely dense and that two similar functions in $\mathcal{H}$ will be mapped to two similar values by $\delta_x$. (If you can think of a better schematic emphasizing these characteristics more accurately, please let me know!).</p>

<p>Here’s another property of RKHSs that follows from the definition: if we have a sequence of functions $f_1, f_2, f_3, \dots$ that converges to a function $f$ in the norm of $\mathcal{H}$, then these functions also converge <em>pointwise</em> to $f$! Let’s state this more rigorously:</p>

\[\lim_{n \rightarrow \infty} \vert\vert f_n - f \vert\vert_{\mathcal{H}} = 0 \implies \forall x, \ \lim_{n \rightarrow \infty} \vert f_n(x) - f(x) \vert = 0\]

<p>This is proven in Theorem 1 in the Appendix to this post. In other words, in an RKHS, functions that are close to one another in $\mathcal{H}$ are also close to one another, point by point, as mappings between $\mathcal{X}$ and $\mathbb{R}$. For example, if $\mathcal{H}$ consists of continuous univariate functions, then as the sequence converges in $\mathcal{H}$, the curve of each $f_n$ approaches the curve of $f$ at every point. An example of a sequence of univariate continuous functions converging to some function $f$ is illustrated below:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/PointwiseConvergenceRKHS.png" alt="drawing" width="500" /></center>

<p><br /></p>

<p>Here we see a sequence of functions $f_1, f_2, f_3, \dots, f_n$ converging to a function $f$ in the Hilbert space. Consequently, these functions also converge to $f$ point by point for all $x \in \mathcal{X}$.</p>

<h2 id="the-reproducing-kernel">The reproducing kernel</h2>

<p>Strangely, you may have noticed that our definition of an RKHS makes no mention of any object called a “kernel”, let alone a “reproducing kernel”. So what exactly is the “reproducing kernel” in a “reproducing kernel Hilbert space”?</p>

<p>The “kernel” arises from a fundamental property that RKHSs possess: you can “reproduce” any evaluation functional using inner products in the Hilbert space. Specifically, for any given $x \in \mathcal{X}$, there exists some function $k_x \in \mathcal{H}$ such that the following holds for every $f \in \mathcal{H}$:</p>

\[\delta_x(f) = f(x) = \langle f, k_x \rangle\]

<p>This fact is described by the <a href="https://en.wikipedia.org/wiki/Riesz_representation_theorem">Riesz representation theorem</a>. We will not provide a proof for this theorem here; rather, we’ll state an abbreviated version:</p>

<p><span style="color:#0060C6"><strong>Theorem 1 (Riesz representation theorem - abbreviated):</strong> Let $(\mathcal{H}, \mathcal{F}, \langle ., . \rangle)$ be a Hilbert space, where $\mathcal{H}$ is the set of vectors, $\mathcal{F}$ is the field of scalars, $\langle ., . \rangle$ is an inner product on $\mathcal{H}$, and each $f \in \mathcal{H}$ is a function $f : \mathcal{X} \rightarrow \mathbb{R}$ for some set $\mathcal{X}$. Let $\ell: \mathcal{H} \rightarrow \mathbb{R}$ be a continuous linear functional. Then there exists a unique function $f_{\ell} \in \mathcal{H}$ such that $\forall f \in \mathcal{H}, \ \ell(f) = \langle f_{\ell}, f\rangle$.</span></p>

<p>The Riesz representation theorem holds for any Hilbert space, not just RKHSs. What makes RKHSs special is that their evaluation functionals are continuous and linear, so the theorem applies to them. We prove this in Theorem 2 in the appendix to this post.</p>

<p>Applied to $\delta_x$, this theorem says that we can reproduce the action of $\delta_x$ by taking an inner product with some fixed function $k_x$ in our Hilbert space. This is called the <strong>reproducing property</strong> of an RKHS! Here’s a schematic to illustrate this property:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/ReproducingPropertyRKHS.png" alt="drawing" width="500" /></center>

<p>What if we choose a different evaluation functional for another value $y \in \mathcal{X}$? Then, there exists a different function $k_y$ that we can use to reproduce $\delta_y$.  Now, recall the function $k_x$ (used to reproduce $\delta_x$) is an element of $\mathcal{H}$ like any other. We thus see that</p>

\[\delta_y(k_x) = \langle k_x, k_y \rangle\]

<p>This is illustrated below:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/KernelRKHS.png" alt="drawing" width="700" /></center>

<p>If you notice, we have taken two arbitrary values, $x, y \in \mathcal{X}$ and mapped them to two vectors \(k_x, k_y \in \mathcal{H}\) and then computed their inner product. This is the very notion of a kernel that we discussed in the introduction to this blog post!</p>

<p>That is, we can construct a function $K$ that operates on pairs of elements $x, y \in \mathcal{X}$:</p>

\[K(x, y) := \delta_y(k_x) = \langle k_x, k_y \rangle\]

<p>This is the <strong>reproducing kernel</strong> (or simply <strong>kernel</strong>) of the RKHS! To make this clearer, we can denote by \(\phi : \mathcal{X} \rightarrow \mathcal{H}\) the function that maps each element $x \in \mathcal{X}$ to the element $k_x \in \mathcal{H}$ that reproduces its evaluation functional $\delta_x$. That is,</p>

\[\phi(x) := k_x, \ \text{where} \ \forall f \in \mathcal{H}, \ \delta_x(f) = \langle f, k_x \rangle\]

<p>and thus,</p>

\[K(x, y) = \langle \phi(x), \phi(y) \rangle\]

<p>This function, $\phi$, is often called the <strong>feature map</strong> as it maps our objects in $\mathcal{X}$ to a new “feature representation” in the RKHS. We’ll dig a bit more into this new feature representation in a later section.</p>
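<p>For intuition, here is a minimal Python sketch of a kernel that equals an inner product of explicit feature vectors. It assumes the homogeneous quadratic kernel $K(\boldsymbol{x}, \boldsymbol{y}) = (\boldsymbol{x}^T\boldsymbol{y})^2$ on $\mathbb{R}^2$, chosen only for illustration, and the helper names are hypothetical. Note that the explicit map here lands in $\mathbb{R}^3$ rather than in the RKHS of functions, so it is a different, but inner-product-equivalent, representation of the same kernel.</p>

```python
import numpy as np

def poly_kernel(x, y):
    """Homogeneous quadratic kernel K(x, y) = (x . y)^2 on R^2."""
    return np.dot(x, y) ** 2

def feature_map(x):
    """An explicit map into R^3 whose dot products equal poly_kernel.

    (x1 y1 + x2 y2)^2 = x1^2 y1^2 + 2 x1 x2 y1 y2 + x2^2 y2^2.
    """
    x1, x2 = x
    return np.array([x1 ** 2, np.sqrt(2) * x1 * x2, x2 ** 2])

x = np.array([1.0, 2.0])
y = np.array([3.0, -1.0])
print(poly_kernel(x, y))                        # (1*3 + 2*(-1))^2 = 1.0
print(np.dot(feature_map(x), feature_map(y)))   # same value, up to rounding
```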

<h2 id="the-kernel-trick">The kernel trick</h2>

<p>So far, all of this has been very abstract; we have assumed we have a Hilbert space that satisfies the axioms in the definition of an RKHS and showed that we can derive a kernel function from this RKHS. Unfortunately, nothing we have discussed so far explains how one can actually derive a kernel function. That is, how does one find an appropriate function $\phi$ that maps each $x \in \mathcal{X}$ to its $k_x$ in an RKHS that satisfies the reproducing property for $\delta_x$?</p>

<p>It turns out that one does not need to derive the $\phi$ function explicitly. Instead, one needs only to find a function $K: \mathcal{X} \times \mathcal{X} \rightarrow \mathbb{R}$ that satisfies a certain property, <strong>positive-definiteness</strong>, and this will <em>automatically</em> be the kernel function for <em>some</em> unique RKHS! That is, we don’t need to define the RKHS itself; we simply need to find a kernel function, $K$, that satisfies this property. This is stated in the Moore-Aronszajn theorem:</p>

<p><span style="color:#0060C6"><strong>Theorem 2 (Moore-Aronszajn Theorem):</strong> Let $K$ be a symmetric, positive definite function on a set $\mathcal{X}$. Then there exists a unique RKHS of functions on $\mathcal{X}$  for which $K$ is the reproducing kernel.</span></p>

<p>Thus, so long as we define a positive-definite function, $K$, the mapping by $\phi$ happens implicitly! This convenient property is called the <strong>kernel trick</strong>: one does not need to actually map objects in $\mathcal{X}$ to $\mathcal{H}$. Rather, one only needs a positive-definite function $K$, often representing “similarities” between objects in $\mathcal{X}$, and this function implicitly computes inner products between the objects’ images in the RKHS, with no need to explicitly deal with the RKHS at all!</p>
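<p>A standard consequence of the kernel trick is that geometric quantities in $\mathcal{H}$ can be computed from kernel values alone. For example, expanding the inner product gives $\vert\vert \phi(x) - \phi(y) \vert\vert^2 = K(x, x) - 2K(x, y) + K(y, y)$. The minimal Python sketch below applies this to a Gaussian kernel (whose RKHS is infinite-dimensional) and checks it against explicit features for a quadratic kernel. The kernels, the value of <code>gamma</code>, and the helper names are illustrative assumptions.</p>

```python
import numpy as np

def feature_space_sq_distance(kernel, x, y):
    """||phi(x) - phi(y)||^2 computed from kernel values alone (phi is never formed)."""
    return kernel(x, x) - 2 * kernel(x, y) + kernel(y, y)

def gaussian_kernel(x, y, gamma=0.5):
    """Gaussian (RBF) kernel exp(-gamma * ||x - y||^2)."""
    return np.exp(-gamma * np.sum((x - y) ** 2))

def poly_kernel(x, y):
    """Quadratic kernel (x . y)^2 on R^2."""
    return np.dot(x, y) ** 2

def poly_features(x):
    # Explicit feature map for poly_kernel, used only to check the kernel-only formula.
    return np.array([x[0] ** 2, np.sqrt(2) * x[0] * x[1], x[1] ** 2])

x = np.array([1.0, 2.0])
y = np.array([3.0, -1.0])

# Works even though we never write down coordinates for the Gaussian feature map.
print(feature_space_sq_distance(gaussian_kernel, x, y))

# Sanity check against explicit features for the quadratic kernel.
explicit = np.sum((poly_features(x) - poly_features(y)) ** 2)
assert np.isclose(feature_space_sq_distance(poly_kernel, x, y), explicit)
```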

<p>Before moving forward, let’s define a “symmetric positive-definite function”. First, let’s start with a <strong>symmetric function</strong>. A multivariate function is symmetric if the order of its arguments doesn’t matter:</p>

<p><span style="color:#0060C6"><strong>Definition 3 (symmetric function):</strong> Let $K$ be a multivariate function with $n$ arguments, $x_1, \dots, x_n$. $K$ is symmetric if its value is the same no matter the ordering of $x_1, \dots, x_n$. That is, for any two permutations $\sigma_1$ and $\sigma_2$, it holds that $K(x_{\sigma_1(1)}, \dots, x_{\sigma_1(n)}) = K(x_{\sigma_2(1)}, \dots, x_{\sigma_2(n)})$.</span></p>

<p>Now, here is the definition for a <strong>positive-definite function</strong>:</p>

<p><span style="color:#0060C6"><strong>Definition 4 (positive-definite function):</strong> Let $K$ be a symmetric bivariate function, $K: \mathcal{X} \times \mathcal{X} \rightarrow \mathbb{R}$. $K$ is positive definite if $\forall \ n \in \mathbb{N}, \ \forall \boldsymbol{x}_1, \dots, \boldsymbol{x}_n \in \mathcal{X}, \ \forall c_1, \dots, c_n \in \mathbb{R}$, it holds that,</span></p>

<p><span style="color:#0060C6">\(\sum_{i=1}^n \sum_{j=1}^n c_ic_jK(\boldsymbol{x}_i, \boldsymbol{x}_j) \geq 0\)</span></p>

<p>Notably, a positive-definite kernel generalizes the idea of a <a href="https://en.wikipedia.org/wiki/Definite_matrix">positive semi-definite matrix</a>. That is, given any objects $\boldsymbol{x}_1, \boldsymbol{x}_2, \dots, \boldsymbol{x}_n$, the matrix formed by computing all pairwise kernel values (often called the <a href="https://en.wikipedia.org/wiki/Gram_matrix">Gram matrix</a>) is positive semi-definite:</p>

\[\boldsymbol{K} := \begin{bmatrix}K(\boldsymbol{x}_1, \boldsymbol{x}_1) &amp; \dots &amp; K(\boldsymbol{x}_1, \boldsymbol{x}_n) \\ \vdots &amp; \ddots &amp; \vdots \\ K(\boldsymbol{x}_n, \boldsymbol{x}_1) &amp; \dots &amp; K(\boldsymbol{x}_n, \boldsymbol{x}_n) \end{bmatrix}\]
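<p>A quick numerical check of this property, as a minimal sketch: assuming the Gaussian kernel $K(\boldsymbol{x}, \boldsymbol{y}) = \exp(-\gamma \vert\vert \boldsymbol{x} - \boldsymbol{y} \vert\vert^2)$ (a standard positive-definite kernel, chosen here only for illustration) and random points, the Gram matrix should be symmetric with nonnegative eigenvalues up to floating-point round-off, and the quadratic form from Definition 4 should be nonnegative.</p>

```python
import numpy as np

def gaussian_gram(X, gamma=1.0):
    """Gram matrix G[i, j] = exp(-gamma * ||X[i] - X[j]||^2) for the rows of X."""
    sq_dists = np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
    return np.exp(-gamma * sq_dists)

rng = np.random.default_rng(0)
X = rng.normal(size=(20, 3))   # 20 arbitrary objects in R^3
G = gaussian_gram(X)

# Symmetric ...
assert np.allclose(G, G.T)

# ... and positive semi-definite: all eigenvalues >= 0 (up to round-off).
eigvals = np.linalg.eigvalsh(G)
assert eigvals.min() > -1e-10

# Equivalently, sum_ij c_i c_j K(x_i, x_j) >= 0 for an arbitrary coefficient vector c.
c = rng.normal(size=20)
assert c @ G @ c >= -1e-10
```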

<h2 id="what-vector-does-the-feature-map-phi-project-each-object-x-to">What vector does the feature map, $\phi$, project each object, $x$, to?</h2>

<p>Although the kernel implicitly evaluates the feature map $\phi$ on the two arguments to the kernel, one may be curious about the precise form of $\phi(x)$.  We know that $\phi(x)$ is a function because it is a member of the Hilbert space $\mathcal{H}$ that consists of functions, but what is the form of this function?</p>

<p>To answer this, let’s say we have some positive-definite kernel $K$, and let’s fix its first argument so that $K(x, .)$ is a function of the second argument only. Then, we see that</p>

\[\begin{align*}K(x, .) &amp;= \langle \phi(x), \phi(.) \rangle \\  &amp;= \delta_{.}(\phi(x)) \\ &amp;= \phi(x)(.)\end{align*}\]

<p>What does this mean? It means that the function $\phi(x)$ is simply $K(x, .)$. Thus, we see that</p>

\[K(x, y) = \langle K(x, .), K(y, .)\rangle\]
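<p>For functions in the span of kernel sections, $f = \sum_i c_i K(x_i, .)$ and $g = \sum_j e_j K(y_j, .)$, bilinearity of the inner product together with $\langle K(x_i, .), K(y_j, .)\rangle = K(x_i, y_j)$ gives $\langle f, g \rangle = \sum_{i,j} c_i e_j K(x_i, y_j)$. The minimal Python sketch below uses this formula to check the reproducing property $\langle f, K(x, .)\rangle = f(x)$ numerically. The Gaussian kernel, the random points, and the helper names are illustrative assumptions. The check is nearly a restatement of the definition, which is exactly the point.</p>

```python
import numpy as np

def gaussian_kernel(A, B, gamma=1.0):
    """Matrix with entries K(A[i], B[j]) = exp(-gamma * ||A[i] - B[j]||^2)."""
    sq_dists = np.sum((A[:, None, :] - B[None, :, :]) ** 2, axis=-1)
    return np.exp(-gamma * sq_dists)

def inner_product(c, X, e, Y):
    """<f, g>_H for f = sum_i c[i] K(X[i], .) and g = sum_j e[j] K(Y[j], .)."""
    return c @ gaussian_kernel(X, Y) @ e

def evaluate(c, X, x):
    """Pointwise evaluation f(x) = sum_i c[i] K(X[i], x)."""
    return float(c @ gaussian_kernel(X, x[None, :])[:, 0])

rng = np.random.default_rng(1)
X = rng.normal(size=(5, 2))    # hypothetical points X[i] in R^2
c = rng.normal(size=5)         # hypothetical coefficients of f
x = np.array([0.3, -0.7])      # point whose evaluation functional we reproduce

# k_x = K(x, .) is the span element with a single term and coefficient 1.
lhs = inner_product(c, X, np.array([1.0]), x[None, :])   # <f, k_x>_H
rhs = evaluate(c, X, x)                                  # f(x)
print(lhs, rhs)
assert np.isclose(lhs, rhs)
```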

<h2 id="why-is-phi-called-a-feature-map">Why is $\phi$ called a “feature map”?</h2>

<p>In machine learning, we generally consider a <a href="https://en.wikipedia.org/wiki/Feature_(machine_learning)">feature</a> (or in statistical parlance, a <a href="https://en.wikipedia.org/wiki/Dependent_and_independent_variables#Statistics_synonyms">covariate</a>) to be a single, quantifiable property of an object. So far we’ve shown that the feature map, $\phi$, maps an object $x$ to a function $K(x, .) \in \mathcal{H}$. Notably, there don’t seem to be any “features” associated with this function, so why is $\phi$ called a “feature map”?</p>

<p>It turns out that there is another way to represent $\phi(x)$ that is more in line with the idea of $\phi(x)$ mapping $x$ to a new set of “features”. We’ll simplify our discussion to real-valued vectors $\boldsymbol{x} \in \mathbb{R}^n$ (that is, $\mathcal{X} := \mathbb{R}^n$). This new representation comes by way of Mercer’s Theorem:</p>

<p><span style="color:#0060C6"><strong>Theorem 3 (Mercer’s Theorem):</strong> Let $K$ be a continuous, symmetric, positive-definite function defined over a <a href="https://en.wikipedia.org/wiki/Compact_space">compact set</a> $S \subset \mathbb{R}^n$. Let $(S, \Sigma, \mu)$ be a <a href="https://mbernste.github.io/posts/measure_theory_1/">measure space</a> defined over $S$ and $L^2(S, \mu)$ be the space of <a href="https://en.wikipedia.org/wiki/Lp_space">L2-functions</a> over $S$ with respect to the measure $\mu$. Then, define the <a href="https://mbernste.github.io/posts/matrices_linear_transformations/">linear operator</a></span></p>

<center><span style="color:#0060C6">$$T_{K} : L^2(S, \mu) \rightarrow L^2(S, \mu)$$</span></center>

<p><span style="color:#0060C6">such that for a given function $g \in L^2(S, \mu)$, $T_K(g)$ is defined as the function</span></p>

<center><span style="color:#0060C6">$$T_{K}(g)(\boldsymbol{x}) := \int_S K(\boldsymbol{x}, \boldsymbol{s}) g(\boldsymbol{s}) \, d\mu(\boldsymbol{s})$$</span></center>

<p><span style="color:#0060C6">where this integral is a <a href="https://mbernste.github.io/posts/measure_theory_3/">Lebesgue integral</a>. Then, there is a sequence of orthonormal basis functions, $\psi_1, \psi_2, \dots$, that are eigenfunctions of $T_{K}$ and are associated with a sequence of nonnegative eigenvalues, $\lambda_1, \lambda_2, \dots$. Moreover, the kernel function $K$ can be expressed as</span></p>

<center><span style="color:#0060C6">$$K(\boldsymbol{x}, \boldsymbol{y}) = \sum_{i=1}^\infty \lambda_i \psi_i(\boldsymbol{x}) \psi_i(\boldsymbol{y})$$</span></center>

<p>The statement in this theorem is rather complex. We won’t go into extreme detail in this post; rather, we will emphasize the major point necessary to understand where the “features” are coming from.</p>

<p>Specifically, the big idea is as follows: <strong>One can represent $\phi(x) = K(x, .)$ as a coordinate vector (in possibly infinite dimensions)!</strong> We’ll denote this vector as</p>

\[\begin{align*}\boldsymbol{\psi}'(x) &amp;:= [\sqrt{\lambda_1}\psi_1(x), \sqrt{\lambda_2}\psi_2(x), \dots] \\ &amp;= [\psi'_1(x), \psi'_2(x), \dots]\end{align*}\]

<p>where, for ease of notation, $\psi'_i(x) := \sqrt{\lambda_i}\psi_i(x)$ absorbs the constant term. Thus, we can execute the inner product performed by the kernel using a dot product on this new “feature representation”:</p>

\[K(x, y) = \langle K(x, .), K(y, .)\rangle = \boldsymbol{\psi}'(x)^T\boldsymbol{\psi}'(y) = \sum_{i=1}^\infty \psi'_i(x)\psi'_i(y)\]

<p>In this scenario, each $\psi'_i(x)$ can be interpreted as a new “feature” of $\boldsymbol{x}$ for which <a href="https://mbernste.github.io/posts/dot_product/">dot products</a> on these feature vectors compute inner products in the RKHS.</p>

<p>Now, what are these features exactly? As Mercer’s Theorem states, each feature can be constructed by passing $\boldsymbol{x}$ through each $\psi_i$ function and scaling by $\sqrt{\lambda_i}$. These $\psi_i$ functions are eigenfunctions of a specific operator, $T_K$, defined using the kernel $K$. However, in my opinion, the details of where these features come from are not as essential for understanding the big picture as the fact that there does exist a (possibly infinite) sequence of features.</p>
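<p>Mercer’s theorem concerns eigenfunctions of the integral operator $T_K$, which cannot be computed exactly in code. As a rough finite-sample analogue (closely related to the Nyström method), one can replace $T_K$ by the Gram matrix on sampled points: its eigendecomposition yields coordinates whose dot products reproduce the kernel values on those points. The Gaussian kernel, the uniform sample, and the helper name in this minimal Python sketch are illustrative assumptions, not part of the theorem.</p>

```python
import numpy as np

def gaussian_gram(X, gamma=1.0):
    """Gram matrix G[i, j] = exp(-gamma * ||X[i] - X[j]||^2)."""
    sq_dists = np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
    return np.exp(-gamma * sq_dists)

rng = np.random.default_rng(0)
X = rng.uniform(-1, 1, size=(50, 2))   # sample points standing in for the compact set S
G = gaussian_gram(X)

# Eigendecomposition G = U diag(lam) U^T (eigh is for symmetric matrices).
lam, U = np.linalg.eigh(G)
lam = np.clip(lam, 0.0, None)          # remove tiny negative values caused by round-off

# Row i holds sqrt(lam[k]) * U[i, k] for each k: a discrete stand-in for psi'_k(x_i).
features = U * np.sqrt(lam)

# Dot products between rows reproduce the kernel values, mirroring K(x, y) = psi'(x)^T psi'(y).
print(np.abs(features @ features.T - G).max())
assert np.allclose(features @ features.T, G)
```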

<p>In conclusion, we have come to two alternative ways of viewing the feature map:</p>

<ol>
  <li>The feature map $\phi$ maps each object $x \in \mathcal{X}$ to the function $K(x, .)$ in the Hilbert space</li>
  <li>The feature map $\phi$ maps each object $x \in \mathcal{X}$ to a (possibly infinite) coordinate vector $[\psi'_1(x), \psi'_2(x), \dots]$</li>
</ol>

<h2 id="further-reading">Further reading</h2>

<ul>
  <li><a href="https://arxiv.org/pdf/2106.08443">This tutorial</a> by Ghojogh <em>et al.</em> (2021)</li>
</ul>

<h2 id="appendix-proofs-of-properties-of-the-rkhs-and-kernels">Appendix: Proofs of properties of the RKHS and kernels</h2>

<p><span style="color:#0060C6"><strong>Theorem 1 (Convergence of functions implies pointwise convergence):</strong> Given a RKHS, $\mathcal{H}$, the following holds: $ \ \lim_{n \rightarrow \infty} \vert\vert f_n - f \vert\vert_{\mathcal{H}} = 0 \implies \forall x, \ \lim_{n \rightarrow \infty} \vert f_n(x) - f(x) \vert = 0$.</span></p>

<p><strong>Proof:</strong></p>

<p>We assume that we have a sequence of functions that converges to some function $f$. That is, $\lim_{n \rightarrow \infty} \vert\vert f_n - f \vert\vert_{\mathcal{H}} = 0$. By the definition of the limit of a sequence, this means that</p>

\[\forall \epsilon_2 &gt; 0, \ \exists N \ \text{such that} \ n &gt; N \implies  \vert\vert f_n - f  \vert\vert_{\mathcal{H}} &lt; \epsilon_2\]

<p>Let’s keep this in mind as we look at the axioms of an RKHS. Specifically, the defining axiom of an RKHS states that the evaluation functional $\delta_x$ is continuous. The definition of continuity of $\delta_x$ at a given $g \in \mathcal{H}$ is the following:</p>

\[\forall x \in \mathcal{X}, \ \forall \epsilon_1 &gt; 0, \ \exists \Delta &gt; 0 \ \text{such that} \ \forall h \in \mathcal{H}, \ \vert\vert h-g \vert\vert_{\mathcal{H}} \lt \Delta \implies \vert\delta_x(h) - \delta_x(g)\vert &lt; \epsilon_1\]

<p>So far, we’ve only written out the definitions of limits and continuity. Now for the actual proof!</p>

<p>Let us fix $x \in \mathcal{X}$ and $\epsilon_1 &gt; 0$ to arbitrary values, take $g := f$ in the continuity condition above to obtain the corresponding $\Delta$, and let $\epsilon_2 := \Delta$. Then, because $f_1, f_2, \dots$ is a convergent sequence of functions, we know that $\exists N \ \text{such that} \ n &gt; N \implies \vert\vert f_n - f \vert\vert_{\mathcal{H}} &lt; \Delta$.</p>

<p>Based on the axiom of the RKHS, this also implies that</p>

\[\vert \delta_x(f_n) - \delta_x(f)\vert &lt; \epsilon_1\]

<p>and thus,</p>

\[\vert f_n(x) - f(x)\vert &lt; \epsilon_1\]

<p>Now, since our choices of $\epsilon_1$ and $x$ were arbitrary, we see that</p>

\[\forall x, \ \forall \epsilon_1 &gt; 0, \ \exists N \ \text{such that} \ n &gt; N \implies \vert f_n(x) - f(x) \vert &lt; \epsilon_1\]

<p>This is the very definition of the limit:</p>

\[\forall x, \ \lim_{n \rightarrow \infty} \vert f_n(x) - f(x) \vert = 0\]

<p>$\square$</p>
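<p>As a numerical illustration of Theorem 1 (not part of the proof), the sketch below assumes a Gaussian kernel on $\mathbb{R}$ and functions in the span of kernel sections, $f = \sum_i c_i K(x_i, .)$, whose squared RKHS norm is $\boldsymbol{c}^T \boldsymbol{G} \boldsymbol{c}$ with $\boldsymbol{G}$ the Gram matrix of the $x_i$ (this follows from the reproducing property). By the reproducing property and the Cauchy-Schwarz inequality, $\vert f_n(x) - f(x) \vert = \vert\langle f_n - f, K(x, .)\rangle\vert \leq \vert\vert f_n - f \vert\vert_{\mathcal{H}} \sqrt{K(x, x)}$, and $K(x, x) = 1$ for this kernel, so the pointwise error should never exceed the RKHS distance. The centers, coefficients, and helper names are hypothetical.</p>

```python
import numpy as np

def gaussian_kernel(a, b, gamma=1.0):
    """Matrix with entries K(a[i], b[j]) = exp(-gamma * (a[i] - b[j])^2) for 1-D inputs."""
    return np.exp(-gamma * np.subtract.outer(a, b) ** 2)

rng = np.random.default_rng(0)
centers = np.linspace(-2, 2, 6)         # points x_i defining f = sum_i c_i K(x_i, .)
G = gaussian_kernel(centers, centers)   # Gram matrix G[i, j] = K(x_i, x_j)
c = rng.normal(size=6)                  # coefficients of the limit function f
d = rng.normal(size=6)                  # direction of the perturbation
grid = np.linspace(-3, 3, 201)          # points x at which to check the pointwise error

def evaluate(coef, x):
    """Evaluate sum_i coef[i] K(x_i, x) at every x in the array."""
    return gaussian_kernel(x, centers) @ coef

for n in [1, 10, 100, 1000]:
    diff = d / n                             # f_n - f = sum_i (d_i / n) K(x_i, .)
    rkhs_norm = np.sqrt(diff @ G @ diff)     # ||f_n - f||_H
    sup_err = np.max(np.abs(evaluate(c + diff, grid) - evaluate(c, grid)))
    # Since K(x, x) = 1, |f_n(x) - f(x)| <= ||f_n - f||_H at every x.
    print(n, rkhs_norm, sup_err)
    assert sup_err <= rkhs_norm + 1e-12
```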

<p><span style="color:#0060C6"><strong>Theorem 2: The Riesz representation theorem applies to the evaluation functionals of an RKHS</strong></span></p>

<p><strong>Proof:</strong></p>

<p>The Riesz representation theorem makes a statement about continuous linear functionals on Hilbert spaces. To apply it to an RKHS, we must show that each evaluation functional on an RKHS is <a href="https://mbernste.github.io/posts/matrices_linear_transformations/">linear</a> and continuous. We first show linearity: Let $f, g \in \mathcal{H}$. Then,</p>

\[\begin{align*}\delta_x(f + g) &amp;= (f+g)(x) \\ &amp;= f(x) + g(x) = \delta_x(f) + \delta_x(g)\end{align*}\]

<p>Now let $c \in \mathcal{F}$ be a scalar. Then,</p>

\[\begin{align*}\delta_x(cf) &amp;= cf(x) \\ &amp;= c\delta_x(f)\end{align*}\]

<p>The continuity of the evaluation functionals on an RKHS is true by the definition of an RKHS.</p>

<p>$\square$</p>]]></content><author><name>Matthew N. Bernstein</name></author><category term="tutorial" /><category term="mathematics" /><category term="functional analysis" /><summary type="html"><![CDATA[If you’re a practitioner of machine learning, then there is little doubt you have seen or used an algorithm that falls into the general category of kernel methods. The premier example of such methods is the support vector machine. When introduced to these algorithms, one is taught that one must provide the algorithm with a kernel function that, intuitively, computes a degree of “similarity” between the objects you are classifying. In practice, one can get pretty far with only this understanding; however, to understand these methods more deeply, one must understand a mathematical object called a reproducing kernel Hilbert space (RKHS). In this post, I will explain the definition of a RKHS and exactly how they produce the kernels used in kernel methods thereby laying a rigorous foundation for a deeper understanding of these methods.]]></summary></entry><entry><title type="html">Dot product</title><link href="https://mbernste.github.io/posts/dot_product/" rel="alternate" type="text/html" title="Dot product" /><published>2024-12-09T00:00:00-08:00</published><updated>2024-12-09T00:00:00-08:00</updated><id>https://mbernste.github.io/posts/dot_product</id><content type="html" xml:base="https://mbernste.github.io/posts/dot_product/"><![CDATA[<p><em>The dot product is a fundamental operation on two Euclidean vectors that captures a notion of similarity between the vectors. In this post, we’ll define the dot product and offer a number of angles for which to intuit the idea captured by this fundamental operation.</em></p>

<h2 id="introduction">Introduction</h2>

<p>The <strong>dot product</strong> is a fundamental operation on two Euclidean vectors that captures a notion of similarity between the vectors. It is defined as follows:</p>

<p><span style="color:#0060C6"><strong>Definition 1 (dot product):</strong> Given vectors $\boldsymbol{v}, \boldsymbol{u} \in \mathbb{R}^n$, the <strong>dot product</strong> between these vectors is defined as, $\boldsymbol{v} \cdot \boldsymbol{u} := \sum_{i=1}^n v_i u_i$</span></p>

<p>Despite the simplicity of its definition, the dot product can be understood from a number of <a href="https://mbernste.github.io/posts/understanding_3d/">different perspectives</a>. Here are four perspectives I find useful for thinking about the dot product. Ordered from the least abstract to the most abstract, these perspectives are:</p>

<ol>
  <li>The dot product succinctly describes a weighted sum</li>
  <li>The dot product describes a geometric relationship between two Euclidean vectors</li>
  <li>The dot product is an analogy to multiplication between scalars (i.e., plain old multiplication between numbers)</li>
  <li>The dot product describes a notion of similarity between two Euclidean vectors</li>
</ol>

<p>These perspectives are described in the remaining sections of this post.</p>

<h2 id="the-dot-product-as-a-weighted-sum">The dot product as a weighted sum</h2>

<p>The least abstract way of viewing a dot product is as a weighted sum of variables. Let’s say we have a vector $\boldsymbol{x}$ of variables storing some kind of data, and a vector of weights $\boldsymbol{w}$. We want to sum the variables in $\boldsymbol{x}$, where each element $x_i$ in $\boldsymbol{x}$ is multiplied by its weight $w_i$ in $\boldsymbol{w}$. This operation is stated succinctly as $\boldsymbol{w} \cdot \boldsymbol{x}$.</p>
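<p>A minimal NumPy sketch of this view, with made-up data and weights: the explicit weighted sum and <code>np.dot</code> compute the same number.</p>

```python
import numpy as np

x = np.array([2.0, 5.0, 1.0])     # hypothetical data values
w = np.array([0.5, 0.25, 3.0])    # hypothetical weights

# Multiply each variable by its weight, then add everything up.
weighted_sum = sum(w_i * x_i for w_i, x_i in zip(w, x))
print(weighted_sum)      # 0.5*2 + 0.25*5 + 3*1 = 5.25
print(np.dot(w, x))      # the same weighted sum, written as a dot product
```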

<p>Whenever you find a dot product, it often helps to think about the operation as a sum of variables in which each variable is first multiplied by a weight before being summed. Which vector describes the “weights” and which the “variables” depends on the context. This perspective is often helpful in machine learning contexts, where “weights” are often mutable model parameters and “variables” are fixed pieces of data.</p>

<h2 id="the-dot-product-describes-a-geometric-relationship-between-euclidean-vectors">The dot product describes a geometric relationship between Euclidean vectors</h2>

<p>The dot product captures the relationship between the directions in which the two vectors point. More specifically, if the two vectors point in a similar direction, the dot product is larger. If they point in drastically different directions, the dot product is smaller. Now, the question becomes: what do we mean by “point in a similar direction?” More specifically, what do we mean by “similar”? The dot product asserts that the angle between the two vectors measures how similarly they point: for vectors of fixed lengths, the smaller the angle, the larger the dot product.</p>

<p>To show how this works, we note that the dot product between two vectors $\boldsymbol{a}$ and $\boldsymbol{b}$, can be computed using the angle between them as follows (see Theorem 1 in the Appendix to this post):</p>

\[\boldsymbol{a} \cdot \boldsymbol{b} = \vert\vert \boldsymbol{a} \vert\vert \vert\vert \boldsymbol{b} \vert\vert \cos \theta\]

<p>If $\theta = 0$, then the two vectors point in the same direction. In this case, $\cos \theta = 1$ and the dot product reduces to simply computing the product of the two vectors’ magnitudes. If $\theta = \pi / 2$, then the two vectors point in perpendicular directions (i.e., neither vector has any component along the other). We see that $\cos (\pi/2) = 0$ and the dot product between the two vectors is zero.</p>
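<p>The formula can be rearranged to recover the angle from the dot product, $\theta = \arccos\left(\frac{\boldsymbol{a} \cdot \boldsymbol{b}}{\vert\vert \boldsymbol{a} \vert\vert \vert\vert \boldsymbol{b} \vert\vert}\right)$. The NumPy sketch below (the helper <code>angle_between</code> is a hypothetical name) checks the two cases just described; clipping guards against round-off pushing the cosine slightly outside $[-1, 1]$.</p>

```python
import numpy as np

def angle_between(a, b):
    """Angle theta in [0, pi] recovered from a . b = ||a|| ||b|| cos(theta)."""
    cos_theta = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    return np.arccos(np.clip(cos_theta, -1.0, 1.0))  # clip guards against round-off

a = np.array([1.0, 0.0])
same_dir = np.array([2.0, 0.0])
perpendicular = np.array([0.0, 3.0])

print(np.degrees(angle_between(a, same_dir)))        # same direction
print(np.degrees(angle_between(a, perpendicular)))   # perpendicular
print(np.dot(a, perpendicular))                      # dot product of perpendicular vectors
```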

<p>Another way to understand how this works is to look at the projection of one vector onto the other. Given two vectors $\boldsymbol{a}$ and $\boldsymbol{b}$, the dot product computes the product of the magnitudes of $\boldsymbol{a}$ and $\boldsymbol{b}$ along the direction that the two vectors share. Said differently, the dot product $\boldsymbol{a} \cdot \boldsymbol{b}$ can be viewed as the magnitude of the projection of one of the vectors onto the other vector, multiplied by the magnitude of the vector being projected upon (when the angle between the vectors is obtuse, the dot product is the negative of this quantity). Writing $\text{proj}(\boldsymbol{a}, \boldsymbol{b})$ for the projection of $\boldsymbol{a}$ onto $\boldsymbol{b}$, that is,</p>

\[\begin{align*}\boldsymbol{a} \cdot \boldsymbol{b} &amp;= \vert\vert \text{proj}(\boldsymbol{a}, \boldsymbol{b}) \vert\vert  \vert\vert\boldsymbol{b} \vert\vert \\
&amp;= \vert\vert \text{proj}(\boldsymbol{b}, \boldsymbol{a}) \vert\vert \vert\vert \boldsymbol{a} \vert\vert  \end{align*}\]

<p>If the two vectors are orthogonal, then the projection of either vector onto the other will be zero and thus the dot product will be zero.  In contrast, if two vectors point in the same direction, then the projection of the smaller vector onto the larger vector is simply the smaller vector so we multiply the magnitude of the smaller vector by the magnitude of the larger vector (i.e. simply multiply their norms).</p>

<p>Given this geometric interpretation of the dot product, we can see that taking the dot product of some vector $\boldsymbol{a}$ and a <a href="https://mbernste.github.io/posts/normed_vector_space/">unit vector</a> $\boldsymbol{b}$ gives the (signed) length of the projection of $\boldsymbol{a}$ along the axis defined by $\boldsymbol{b}$:</p>

\[\begin{align*} \boldsymbol{a} \cdot \boldsymbol{b} &amp;= \vert\vert \text{proj}(\boldsymbol{a}, \boldsymbol{b}) \vert\vert \vert\vert \boldsymbol{b} \vert\vert \\ &amp;= \vert\vert \text{proj}(\boldsymbol{a}, \boldsymbol{b}) \vert\vert &amp;&amp; \text{because $\vert\vert \boldsymbol{b} \vert\vert = 1$} \end{align*}\]

<p>Thus, whenever one of the vectors in a dot product is a unit vector, the operation can always be viewed as the length of the projection along the axis defined by the unit vector.</p>
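<p>A small NumPy sketch with made-up vectors: the dot product with a unit vector gives the signed length of the projection, and scaling the unit vector by that length gives the projection itself, whose residual is orthogonal to the axis.</p>

```python
import numpy as np

a = np.array([3.0, 4.0])
b = np.array([1.0, 1.0]) / np.sqrt(2)   # unit vector along the diagonal

scalar_proj = np.dot(a, b)              # signed length of a's projection onto b's axis
proj_vector = scalar_proj * b           # the projection itself, a vector along b

# The residual a - proj_vector is orthogonal to b, confirming proj_vector is the projection.
assert np.isclose(np.dot(a - proj_vector, b), 0.0)
assert np.isclose(np.linalg.norm(proj_vector), abs(scalar_proj))
print(scalar_proj)   # 7 / sqrt(2), roughly 4.95
```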

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/dot_product_projection.png" alt="drawing" width="300" /></center>

<h2 id="the-dot-product-is-analogous-to-the-product-on-scalars">The dot product is analogous to the product on scalars</h2>

<p>One way to understand the dot product is as an operation on vectors that is analogous to multiplication between scalars.  Given two scalars, $x, y \in \mathbb{R}$, it is obvious that the more we increase the magnitude (i.e., absolute value) of either $x$ or $y$, the more the magnitude of their product will grow.  The dot product on vectors behaves similarly.  Given two vectors, $\boldsymbol{a}$ and $\boldsymbol{b}$, if we increase the norm of either of the vectors, the magnitude of the dot product increases (unless the vectors are orthogonal, in which case it stays zero).  We see this clearly expressed in the $\vert\vert \boldsymbol{a} \vert\vert \vert\vert \boldsymbol{b} \vert\vert $ term of the geometric definition of the dot product:</p>

\[\boldsymbol{a} \cdot \boldsymbol{b} =\vert\vert \boldsymbol{a}\vert\vert \vert\vert  \boldsymbol{b}\vert\vert   \cos \theta\]

<p>However, unlike multiplication between scalars, the dot product between vectors also takes into account the direction in which the two vectors point. Holding the norms fixed, the dot product is largest when the two vectors point in the same direction, and it decreases as the angle between them grows.</p>
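<p>Both behaviors can be checked numerically. The sketch below is plain Python with hypothetical helper names. It shows that scaling either vector scales the dot product by the same factor. It also shows that, with both norms held at 1, the dot product decreases as the angle between the vectors grows from $0$ to $\pi/2$.</p>

```python
import math


def dot(a, b):
    """Coordinate-wise dot product."""
    return sum(x * y for x, y in zip(a, b))


def scale(c, v):
    """Multiply every entry of v by the scalar c."""
    return tuple(c * x for x in v)


a = (1.0, 2.0)
b = (3.0, -1.0)

# Like scalar multiplication: scaling either factor by c scales the result by c.
base = dot(a, b)                 # 1*3 + 2*(-1) = 1
doubled = dot(scale(2.0, a), b)  # 2 * base
tripled = dot(a, scale(3.0, b))  # 3 * base


# Unlike scalar multiplication, direction matters. Keep both norms at 1 and
# rotate the second vector away from the first: the dot product equals
# cos(theta), which shrinks as the angle theta grows.
def dot_at_angle(theta):
    return dot((1.0, 0.0), (math.cos(theta), math.sin(theta)))


sweep = [dot_at_angle(t) for t in (0.0, math.pi / 6, math.pi / 3, math.pi / 2)]
```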

<p>One feature of multiplication between scalars is that if $x$ and $y$ have opposite signs, then $xy &lt; 0$ (for example, $-2 \times 3 = -6$).  Is this feature shared with the dot product? In a way, yes! But we first need to express the concept of “opposite signs” between two vectors. To do that, note that if the angle between the vectors $\boldsymbol{a}$ and $\boldsymbol{b}$ is obtuse, then their dot product will be negative:</p>

\[\begin{align*} \frac{\pi}{2} &lt; \theta_{\boldsymbol{a}, \boldsymbol{b}} \leq \pi \implies &amp; \cos  \theta_{\boldsymbol{a}, \boldsymbol{b}}  &lt; 0 \\ \implies  &amp; \vert\vert \boldsymbol{a}\vert\vert \vert\vert\boldsymbol{b} \vert\vert \cos  \theta_{\boldsymbol{a}, \boldsymbol{b}} &lt; 0 \\ \implies &amp; \boldsymbol{a} \cdot \boldsymbol{b} &lt; 0 \end{align*}\]

<p>This is visualized below:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/dot_product_acute_obtuse.png" alt="drawing" width="700" /></center>

<p>Thus, in the context of the dot product, two vectors “have opposite signs” if the angle between them is obtuse, that is, greater than $\pi / 2$ (and at most $\pi$).</p>

<p>Please note, this is just an <em>analogy</em> – that is, a way to think about the dot product as sharing certain familiar characteristics with multiplication between numbers.</p>
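<p>To see the sign change numerically, the sketch below (plain Python, hypothetical helper names) fixes $\boldsymbol{a}$ along the $x$-axis and places $\boldsymbol{b}$ at increasing angles from it. The dot product is positive for an acute angle and zero (up to floating-point rounding) at a right angle. It is negative for an obtuse angle and reaches $-\vert\vert \boldsymbol{a} \vert\vert \vert\vert \boldsymbol{b} \vert\vert$ when the vectors point in exactly opposite directions.</p>

```python
import math


def dot(a, b):
    """Coordinate-wise dot product."""
    return sum(x * y for x, y in zip(a, b))


def vector_at(theta, length):
    """2-D vector of the given length at angle theta from the positive x-axis."""
    return (length * math.cos(theta), length * math.sin(theta))


a = (2.0, 0.0)  # lies along the x-axis, so theta is the angle between a and b

acute = dot(a, vector_at(math.pi / 4, 1.5))       # positive
right = dot(a, vector_at(math.pi / 2, 1.5))       # zero, up to rounding
obtuse = dot(a, vector_at(3 * math.pi / 4, 1.5))  # negative
opposite = dot(a, vector_at(math.pi, 1.5))        # -||a|| ||b|| = -3
```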

<h2 id="the-dot-product-as-a-notion-of-similarity">The dot product as a notion of similarity</h2>

<p>Lastly, the dot product can be thought of as a measure of similarity between two vectors. Recall that the dot product between two vectors can be written in terms of the projection of one vector onto another:</p>

\[\begin{align*}\boldsymbol{a} \cdot \boldsymbol{b} &amp;= \vert\vert \text{proj}(\boldsymbol{a}, \boldsymbol{b}) \vert\vert  \vert\vert\boldsymbol{b} \vert\vert \\
&amp;= \vert\vert \text{proj}(\boldsymbol{b}, \boldsymbol{a}) \vert\vert \vert\vert \boldsymbol{a} \vert\vert  \end{align*}\]

<p>If we think of the magnitude of the projection of one vector onto another as the amount of “directionality” that the two vectors share, then the dot product can be understood as performing the following sequence of calculations:</p>

<ol>
  <li><strong>Alignment:</strong> Calculate the amount of “directionality” shared between the two vectors – that is, the length of the projection of one vector onto the other.</li>
  <li><strong>Multiplication:</strong> Multiply this shared directionality by the magnitude of the vector being projected upon.</li>
</ol>

<p>That is, the dot product is the magnitude of one vector multiplied by its shared directionality with the other. (Note that it doesn’t matter which vector we project onto the other; the final result is the same either way.) If the vectors are perpendicular to one another, then they share no directionality, and thus the dot product is zero. In essence, these vectors share nothing in common: they point in completely different directions! No matter how large either vector is, because they don’t share any directionality, their dot product is zero. In this context, what is the interpretation of a negative dot product? I like to think about it as follows: a negative dot product indicates that the two vectors <em>do</em> share some directionality, just with opposite trends. That is, they share a projection along the same line, but they point in opposite directions along that line.</p>

<h2 id="appendix">Appendix</h2>

<p><span style="color:#0060C6"><strong>Theorem 1:</strong> Given that $\theta$ is the angle between the two vectors $\boldsymbol{v}$ and $\boldsymbol{w}$, the following definition of the dot product between $\boldsymbol{v}$ and $\boldsymbol{w}$ is equivalent to Definition 1: $\boldsymbol{v} \cdot \boldsymbol{w} = \vert\vert \boldsymbol{v} \vert\vert \vert\vert \boldsymbol{w} \vert\vert \cos \theta$.</span></p>

<p><strong>Proof:</strong></p>
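<p>What follows is a sketch of the standard argument via the law of cosines. It assumes that Definition 1 is the coordinate-wise definition $\boldsymbol{v} \cdot \boldsymbol{w} = \sum_{i=1}^n v_i w_i$ for $\boldsymbol{v}, \boldsymbol{w} \in \mathbb{R}^n$, and that both vectors are non-zero so that $\theta$ is defined. The vectors $\boldsymbol{v}$, $\boldsymbol{w}$, and $\boldsymbol{v} - \boldsymbol{w}$ form a triangle in which $\theta$ is the angle opposite the side $\boldsymbol{v} - \boldsymbol{w}$, so the law of cosines gives</p>

\[\begin{align*} \vert\vert \boldsymbol{v} - \boldsymbol{w} \vert\vert^2 &amp;= \vert\vert \boldsymbol{v} \vert\vert^2 + \vert\vert \boldsymbol{w} \vert\vert^2 - 2 \vert\vert \boldsymbol{v} \vert\vert \vert\vert \boldsymbol{w} \vert\vert \cos \theta \end{align*}\]

<p>On the other hand, expanding the left-hand side coordinate-wise, and using $\vert\vert \boldsymbol{v} \vert\vert^2 = \sum_{i=1}^n v_i^2$ for the Euclidean norm, gives</p>

\[\begin{align*} \vert\vert \boldsymbol{v} - \boldsymbol{w} \vert\vert^2 &amp;= \sum_{i=1}^n (v_i - w_i)^2 \\ &amp;= \sum_{i=1}^n v_i^2 - 2\sum_{i=1}^n v_i w_i + \sum_{i=1}^n w_i^2 \\ &amp;= \vert\vert \boldsymbol{v} \vert\vert^2 - 2\, \boldsymbol{v} \cdot \boldsymbol{w} + \vert\vert \boldsymbol{w} \vert\vert^2 \end{align*}\]

<p>Setting the two expressions for $\vert\vert \boldsymbol{v} - \boldsymbol{w} \vert\vert^2$ equal, the $\vert\vert \boldsymbol{v} \vert\vert^2$ and $\vert\vert \boldsymbol{w} \vert\vert^2$ terms cancel. This leaves $-2\, \boldsymbol{v} \cdot \boldsymbol{w} = -2 \vert\vert \boldsymbol{v} \vert\vert \vert\vert \boldsymbol{w} \vert\vert \cos \theta$, and dividing both sides by $-2$ yields $\boldsymbol{v} \cdot \boldsymbol{w} = \vert\vert \boldsymbol{v} \vert\vert \vert\vert \boldsymbol{w} \vert\vert \cos \theta$.</p>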

<p>$\square$</p>]]></content><author><name>Matthew N. Bernstein</name></author><category term="tutorial" /><category term="mathematics" /><category term="linear algebra" /><summary type="html"><![CDATA[The dot product is a fundamental operation on two Euclidean vectors that captures a notion of similarity between the vectors. In this post, we’ll define the dot product and offer a number of angles from which to intuit the idea captured by this fundamental operation.]]></summary></entry><entry><title type="html">Intuiting biology (Part 1: Order and chaos in the crowded cell)</title><link href="https://mbernste.github.io/posts/intuit_biology_goodsell/" rel="alternate" type="text/html" title="Intuiting biology (Part 1: Order and chaos in the crowded cell)" /><published>2024-11-24T00:00:00-08:00</published><updated>2024-11-24T00:00:00-08:00</updated><id>https://mbernste.github.io/posts/intuition_biology_goodsell</id><content type="html" xml:base="https://mbernste.github.io/posts/intuit_biology_goodsell/"><![CDATA[<p><em>Cells are crowded spaces packed with biomolecules colliding and interacting with one another. Despite this chaotic environment, biologists routinely describe intracellular functions using the clean mathematical language of networks. In this post I will attempt to reconcile these two seemingly contradictory perspectives of the cell. This post will serve as a first part in a series of blog posts I hope to write where I will collect and connect some of the works that have helped me better “intuit” biology as a person coming to biology from the field of computer science.</em></p>

<h2 id="introduction">Introduction</h2>

<p>As a computer scientist working in biomedical research, I have had to develop my biology knowledge on the fly. Over the course of this effort, there have been certain concepts, articles, and other bodies of work that led to notable step-functions in my ability to “intuit” biological systems. (Given the staggering complexity of biology, my use of the word “intuit” is quite strained here, but I digress.) In this series of blog posts, I will collect some of these works and attempt to connect them together conceptually. I hope this series serves others who are on a journey similar to mine!</p>

<p>In the first post of this series, I will attempt to tie two seemingly contradictory ideas together:</p>

<ol>
  <li>Cells are densely packed, chaotic places</li>
  <li>We can describe cellular processes using the clean language of <a href="https://en.wikipedia.org/wiki/Biological_network">graphs/networks</a></li>
</ol>

<p>I will begin by discussing the visual depictions of cells by <a href="https://en.wikipedia.org/wiki/David_Goodsell">David S. Goodsell</a>. Dr. Goodsell is a structural biologist at the Scripps Research Institute and Rutgers University and is well known for his scientifically accurate depictions of cells and the molecules that compose them. <a href="https://ccsb.scripps.edu/goodsell/">His work</a> is both educational and beautiful, and it helped expand my understanding of biology.</p>

<p>As I will discuss, Dr. Goodsell’s work highlights how densely packed and seemingly chaotic the environments within cells are. The packed, chaotic nature of these environments seems to contradict the fact that biologists routinely describe cells using the clean, mathematical language of networks and graphs, as in the subfield of <a href="https://en.wikipedia.org/wiki/Biological_network">Network Biology</a>. In this post, I will attempt to connect and reconcile these two perspectives of the cell.</p>

<h2 id="cells-are-absolutely-packed">Cells are absolutely packed</h2>

<p>What struck me most from Dr. Goodsell’s work is how densely packed cells really are. Below is an example illustration that depicts the density of the cell:</p>

<center><img src="https://cdn.rcsb.org/pdb101/goodsell/tif/vascular-endothelial-growth-factor-vegf-signaling.tif" alt="drawing" width="1000" /></center>
<center><sup><span style="color:#b8b4b4">Acknowledgement: Illustration by David S. Goodsell, RCSB Protein Data Bank. doi: 10.2210/rcsb_pdb/goodsell-gallery-041</span></sup></center>

<p>This is a far cry from the image I had previously held of cells as little, uncrowded “bags of water”. I believe this incorrect mental model was fostered by images like the following, which are meant to teach the organelles found in the cell:</p>

<center><img src="https://upload.wikimedia.org/wikipedia/commons/4/4b/Cell-organelles-labeled.png" alt="drawing" width="400" /></center>
<center><sup><span style="color:#b8b4b4">Acknowledgement: Bingbongboing as found on Wikipedia (https://en.wikipedia.org/wiki/Cellular_compartment#/media/File:Cell-organelles-labeled.png)</span></sup></center>

<p>Though images like this are effective for teaching the different organelles found in cells, they imply that cells are empty (at least that was my impression). This was a misconception that I had held for a good portion of my computational biology journey.</p>

<p>A second thing that struck me was the interplay between order and chaos that exists within cells. As an example, see this illustration by Goodsell depicting the coronavirus lifecycle:</p>

<center><img src="https://cdn.rcsb.org/pdb101/goodsell/png-800/coronavirus-life-cycle.png" alt="drawing" width="600" /></center>
<center><sup><span style="color:#b8b4b4">Acknowledgement: David S. Goodsell, RCSB Protein Data Bank; doi: 10.2210/rcsb_pdb/goodsell-gallery-023. Integrative illustration for coronavirus outreach (2020) PLoS Biol 18: e3000815 doi: 10.1371/journal.pbio.3000815</span></sup></center>

<p>Notice that despite the messy distribution of proteins and other biomolecules, clear structures form (membranes, gradients, etc.). In the above picture we see, emerging from the chaos, the insidious formation of new viruses!</p>

<h2 id="how-order-emerges-from-the-chaos-connections-to-network-biology">How order emerges from the chaos: connections to network biology</h2>

<p>If you work in computational biology, there is little doubt that you have been exposed to concepts found in <a href="https://en.wikipedia.org/wiki/Biological_network">network biology</a>. A common way to describe the biochemical processes that occur in cells is to depict and model these processes mathematically as networks or graphs. For example, protein-protein interaction networks are networks in which proteins form the nodes of the graph and an edge between two proteins indicates that those two proteins interact or bind with one another. Such protein interaction networks form <a href="https://en.wikipedia.org/wiki/List_of_signalling_pathways">signaling pathways</a> in which the flow of information through the cell is mediated by cascading interactions between proteins and other molecules. For example, below is a depiction of the <a href="https://www.wikipathways.org/pathways/WP231.html">TNF-alpha signaling pathway</a>, which is a signaling pathway used to modulate immune cell function:</p>

<center><img src="https://www.wikipathways.org/wikipathways-assets/pathways/WP231/WP231.png" alt="drawing" width="1500" /></center>
<center><sup><span style="color:#b8b4b4">Acknowledgement: Agrawal A, et al. (2024) WikiPathways 2024: next generation pathway database. NAR.</span></sup></center>
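<p>To make the graph view concrete, the sketch below stores a protein-protein interaction network as an adjacency list in plain Python. It then uses breadth-first search to find a shortest chain of pairwise interactions between two proteins. The handful of edges is an illustrative, heavily simplified subset of textbook TNF-receptor interactions, not the contents of the WikiPathways diagram above. A shortest path is only a crude stand-in for a real signaling cascade, and the function name <code>cascade</code> is hypothetical.</p>

```python
from collections import deque

# A toy, heavily simplified protein-protein interaction network. These edges
# are an illustrative subset of well-known TNF-receptor interactions; they
# are NOT taken from the WikiPathways WP231 diagram.
interactions = [
    ("TNF", "TNFRSF1A"),  # TNF-alpha binds its receptor TNFR1 (gene TNFRSF1A)
    ("TNFRSF1A", "TRADD"),
    ("TRADD", "TRAF2"),
    ("TRADD", "RIPK1"),
    ("TRADD", "FADD"),
    ("FADD", "CASP8"),
]

# Undirected graph as an adjacency list: protein -> set of binding partners.
graph = {}
for p, q in interactions:
    graph.setdefault(p, set()).add(q)
    graph.setdefault(q, set()).add(p)


def cascade(graph, source, target):
    """Shortest chain of pairwise interactions from source to target (BFS),
    or None if the two proteins are not connected."""
    previous = {source: None}
    queue = deque([source])
    while queue:
        node = queue.popleft()
        if node == target:
            # Walk the predecessor links back to the source.
            path = []
            while node is not None:
                path.append(node)
                node = previous[node]
            return path[::-1]
        for partner in sorted(graph.get(node, ())):
            if partner not in previous:
                previous[partner] = node
                queue.append(partner)
    return None


print(cascade(graph, "TNF", "CASP8"))
# ['TNF', 'TNFRSF1A', 'TRADD', 'FADD', 'CASP8']
```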

<p>Despite how crowded cells are, the network model has proved to be a powerful way to describe cellular functions. Why is this exactly?</p>

<p>As far as I understand, there are two “competing” phenomena occurring within the cell that together lead to the network model:</p>

<ol>
  <li>Molecules move extremely fast inside cells leading to many opportunities for interaction</li>
  <li>Most pairs of biomolecules do not interact, even when they collide</li>
</ol>

<p>First, despite the cell being so crowded, molecules move extremely fast within the cell, and this gives them many opportunities to meet one another. Our intuition, based on the macroscale at which we live, breaks down at the subcellular scale. To try to grow some new intuition, let’s look at the numbers.</p>

<p>According to <a href="https://book.bionumbers.org/what-are-the-time-scales-for-diffusion-in-cells/#:~:text=As%20derived%20in%20Figure%201,show%20another%20order%20of%20magnitude">Ron Milo and Rob Phillips in <em>Cell Biology by the Numbers</em></a>, it takes about <strong>10 seconds for a protein to traverse a <a href="https://en.wikipedia.org/wiki/HeLa">HeLa cell</a></strong>. Considering that every protein is moving in this manner, we realize that every few seconds, the arrangement of proteins in the cell has completely changed (subject to compartmentalization within the cell).</p>
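<p>The figure of about 10 seconds can be sanity-checked with a back-of-the-envelope random-walk estimate. In three dimensions, the mean squared displacement of a diffusing particle grows as $\langle r^2 \rangle = 6Dt$, so the time to diffuse a distance $x$ is roughly $t \approx x^2 / (6D)$. The sketch below assumes a diffusion coefficient of about $10\ \mu\text{m}^2/\text{s}$ for a cytoplasmic protein and a size of about $20\ \mu\text{m}$ for a HeLa cell. Both numbers are rough assumptions, not values quoted from the book, so only the order of magnitude should be trusted.</p>

```python
def diffusion_time_s(distance_um, diffusion_coeff_um2_per_s):
    """Rough time (seconds) to diffuse a given distance in 3-D, from the
    mean squared displacement <r^2> = 6 D t, i.e. t = x^2 / (6 D)."""
    return distance_um ** 2 / (6.0 * diffusion_coeff_um2_per_s)


# Assumed, order-of-magnitude inputs (not taken from the book):
D_PROTEIN = 10.0  # um^2/s, assumed diffusion coefficient of a cytoplasmic protein
HELA_SIZE = 20.0  # um, assumed size of a HeLa cell
ECOLI_SIZE = 1.0  # um, assumed size of an E. coli cell

t_hela = diffusion_time_s(HELA_SIZE, D_PROTEIN)    # 400 / 60, about 6.7 s
t_ecoli = diffusion_time_s(ECOLI_SIZE, D_PROTEIN)  # 1 / 60, about 0.017 s
```

<p>Under these assumptions, the estimate for a HeLa cell lands at several seconds, the same order of magnitude as the roughly 10 seconds cited above. The much smaller bacterium is crossed in a small fraction of a second.</p>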

<p>This rearrangement is even more extreme for small molecules. In fact, in an <em>E. coli</em> bacterium, one can expect that <a href="https://book.bionumbers.org/how-many-reactions-do-enzymes-carry-out-each-second/"><strong>every small molecule will collide with <em>every single protein</em> once per second</strong></a>! (Pretty mind-boggling.) Eukaryotic cells are much larger, but we can still expect a small molecule to meet every single protein within a short span of time.</p>

<p>Thus, if we have two biomolecules in the cell that <em>can</em> bind or interact (including proteins), we can deduce that they inevitably <em>will</em> interact within a fairly short span of time (again, subject to their compartmentalization within the cell). As far as I understand, this means that the deterministic picture portrayed by a network diagram is actually fairly accurate!</p>

<p>Now, with this in mind, wouldn’t one expect many aberrant reactions and/or electrostatic binding? The answer is “no” for two main reasons. First, when considering small molecules in the cell, most chemical reactions have a very high <a href="https://en.wikipedia.org/wiki/Activation_energy">activation energy</a> and require a <a href="https://en.wikipedia.org/wiki/Catalysis">catalyst</a>, such as an <a href="https://en.wikipedia.org/wiki/Enzyme">enzyme</a>, to occur. Thus, though small molecules are colliding with each other constantly, they are unlikely to interact. Second, when considering proteins (i.e., larger biomolecules), the targets that proteins bind to are <a href="https://elifesciences.org/articles/60924">hyperspecific</a>. Thus, the specific pairs (or combinations) of proteins that actually interact with one another are an <em>extremely</em> small fraction of all possible pairs.</p>

<p>Putting this all together at a (very) high level: molecules move extremely fast in the cell and will (roughly) meet every other molecule in the cell within a short span of time; however, the fraction of pairs of molecules that actually interact when they meet is extremely small. From these two facts, we gain a sense for how the network model emerges despite the packed and chaotic environment within cells!</p>

<h2 id="related-reading">Related Reading</h2>
<ul>
  <li><a href="https://ccsb.scripps.edu/goodsell/">David S. Goodsell’s homepage</a></li>
  <li><a href="https://book.bionumbers.org"><em>Cell Biology by the Numbers</em> by Ron Milo and Rob Phillips</a></li>
  <li><a href="http://www.righto.com/2011/07/cells-are-very-fast-and-crowded-places.html">This blog post by Ken Shirriff</a></li>
  <li><a href="https://bionumbers.hms.harvard.edu/aboutus.aspx">The BioNumbers database</a></li>
</ul>]]></content><author><name>Matthew N. Bernstein</name></author><category term="biology" /><summary type="html"><![CDATA[Cells are crowded spaces packed with biomolecules colliding and interacting with one another. Despite this chaotic environment, biologists routinely describe intracellular functions using the clean mathematical language of networks. In this post I will attempt to reconcile these two seemingly contradictory perspectives of the cell. This post will serve as a first part in a series of blog posts I hope to write where I will collect and connect some of the works that have helped me better “intuit” biology as a person coming to biology from the field of computer science.]]></summary></entry><entry><title type="html">Notes on _The Art of War_ by Sun Tzu</title><link href="https://mbernste.github.io/posts/art_of_war_part1/" rel="alternate" type="text/html" title="Notes on _The Art of War_ by Sun Tzu" /><published>2024-11-16T00:00:00-08:00</published><updated>2024-11-16T00:00:00-08:00</updated><id>https://mbernste.github.io/posts/Art_of_War_part1</id><content type="html" xml:base="https://mbernste.github.io/posts/art_of_war_part1/"><![CDATA[<p><em>I am currently reading Sun Tzu’s Art of War and am finding much wisdom in it. I have been taking notes during my reading and I thought I’d share them in this post. Here I cover Books 1 and 2.</em></p>

<h2 id="introduction">Introduction</h2>

<p>I am currently reading Sun Tzu’s <a href="https://en.wikipedia.org/wiki/The_Art_of_War"><em>Art of War</em></a> and am finding much wisdom in it. I have been taking notes during my reading and I thought I’d share them in this post.</p>

<p>As a brief background, <em>The Art of War</em> is a treatise on military conflict and strategic assessment written by a Chinese general somewhere around 500 B.C. It is an enduring text because the wisdom that it contains generalizes far beyond military conflict. A small sample of the themes in the book includes treating conflict as a scientific discipline, gaining strategic advantage through information, and the principles of good leadership.</p>

<p>In this post, I’ll cover the first two “chapters” of the book. I’ll write Sun Tzu’s original text in bold (as translated by Thomas Cleary). My commentary on each point will follow.</p>

<h2 id="book-1-strategic-assessments">Book 1: Strategic Assessments</h2>

<p><strong>Military action is important to the nation – it is the ground of death and life, the path of survival and destruction so it is imperative to examine.</strong></p>

<ul>
  <li>Generalizing “military action” to mean any form of “conflict”, Sun Tzu stresses the importance of prioritizing the skill of navigating conflict. It bears stating that the aftermath of conflict can be incredibly consequential, whether at the level of the nation or the individual, and so navigating conflict is a skill that is important to study. 
Conflict is a discipline, and like any discipline, it can be studied scientifically. Here Sun Tzu sets the stage for the rest of the book where he will provide a rigorous, analytical, almost mathematical exploration into the rules and logic of conflict.</li>
</ul>

<p><strong>Therefore measure in terms of five things, use these assessments to make comparisons, and thus find out what the conditions are. The five things are the way, the weather, the terrain, the leadership, and discipline.</strong></p>

<ul>
  <li>From the outset, Sun Tzu quantifies the core components one needs to gain an advantage in conflict. There are exactly five things: a quantity. 
This very first suggestion addresses the act of “assessment”. This will remain a key theme throughout the book. To gain an advantage, one must acquire information and leverage that information. Gathering intelligence and assessing it objectively to gain an advantage is the key to victory.</li>
</ul>

<p><strong>The Way means inducing the people to have the same aim as the leadership, so that they will share death and life, without fear of danger.</strong></p>

<ul>
  <li>Of the five components one should use to assess themselves and their opponent, it is interesting that the very first involves leadership. According to Sun Tzu, the essence of good leadership is the ability to align the goals of the rank and file with the goals of the organization at large. However, it goes beyond mere “alignment”; the word “share” here feels more intimate. The goals of the organization must be “shared” by the leaders and the rank-and-file alike – that is, they are to be held in common in an essential way.
Another lesson so eloquently stated here is that the leaders and the rank and file must share in the consequences of conflict. You cannot have the rank and file lose while the leaders escape unscathed. The leaders and their followers must share in life and death. Any organization whose leaders don’t share in the consequences is doomed to ineffectiveness and defeat. 
Lastly, a key aspect to good leadership is that of instilling courage (“without fear of danger”). An organization that is setting out with a bold, risky vision must be courageous if they are to stand any chance at achieving their vision. Good leaders instill that courage into the bones of the organization so that it is shared by all levels within the organization: from the top down.</li>
</ul>

<p><strong>The weather means the seasons.</strong></p>

<ul>
  <li>I interpret this to mean that one should pay attention to the cyclical nature of things and assess whether the current moment gives you or your opponent the upper hand. In military operations, this could literally mean the season; for example, an army ill-equipped for winter fighting should wait for summer. It could also mean the time of day. In recent military engagements, the U.S. has carried out operations against less technologically advanced opponents (like Al Qaeda) at night, when it could rely on night-vision technology to gain an advantage. In business, one might pay attention to the economic cycle or the political climate to assess whether now is a moment of advantage or disadvantage.</li>
</ul>

<p><strong>The terrain is to be assessed in terms of distance, difficulty or ease of travel, dimension, and safety.</strong></p>

<ul>
  <li>The terrain plays a huge role in modern military operations. Both the Soviet Union and the United States failed to achieve their objectives in their respective occupations of Afghanistan. A huge reason for this is the complex terrain, which was known thoroughly by the locals, but was foreign to the occupying forces.</li>
</ul>

<p><strong>Leadership is a matter of intelligence, trustworthiness, humaneness, courage, and sternness.</strong></p>

<ul>
  <li>These five qualities of good leaders echo some of the major themes of the book. First, a good leader has “intelligence”. A key theme of the book is that of gaining victory by superior use of information. The next two qualities, “trustworthiness” and “humaneness”, preview another major theme of the book: one should not seek violence. Rather, to achieve victory one must also hold the moral high ground. The fourth quality, “courage”, echoes The Way: one must act without fear of death. Only with the very last quality, “sternness”, do we have a quality related to dominance or (vaguely) violence.</li>
</ul>

<p><strong>Discipline means organization, chain of command, and logistics.</strong></p>

<ul>
  <li>Prior to this point, Sun Tzu has discussed the ingredients of gaining an advantage – namely good leadership (“the Way” and “leadership”) as well as the external conditions (the “weather” and “terrain”); however, he has not addressed how one can translate these advantages into victory. To translate these advantages into victory, one must master the details. That is, one must be organized, have a clear chain of command, and master their logistical operations. So many organizations lack these skills and even if they have many advantages, they are too disorganized to translate those advantages into victory.</li>
</ul>

<p><strong>Therefore use these assessments for comparison, to find out what the conditions are. That is to say, which political leadership has the Way? Which general has ability? Who has the better climate and terrain? Whose discipline is effective? Whose troops are stronger? Whose officers and soldiers are better trained? Whose system of rewards and punishments is clearer? This is how you can know who will win.</strong></p>

<ul>
  <li>Interestingly, Sun Tzu suggests that one can know the victor before any conflict occurs. The statement, “This is how you can know who will win” is extremely certain. He does not say “who will probably win.” He says, “who will win.” This is another key theme that pervades the book: one should use information to assess the situation, arrive at a point of clarity, and strike decisively!</li>
</ul>

<p><strong>Assess the advantages in taking advice, then structure your forces accordingly, to supplement extraordinary tactics. Forces are to be structured strategically, based on what is advantageous.</strong></p>

<ul>
  <li>I interpret this point to mean that one should structure their resources (“forces”) based on the intelligence they receive. One should adapt their strategy to the circumstances. Again, Sun Tzu emphasizes the importance of using intelligence to gain an advantage.</li>
</ul>

<p><strong>A military operation involves deception. Even though you are competent, appear to be incompetent. Though effective, appear to be ineffective.</strong></p>

<ul>
  <li>In any conflict, you should not telegraph your capabilities to your opponent. This holds in military conflict, but generalizes to conflict as a whole. There are two reasons for this. First, if you telegraph your capabilities, your opponent can use that information to gain the advantage. Second, if you act incompetent, but are actually competent, and your opponent believes this lie, then they will attack when they have the disadvantage. You can thus bait them into a fight they cannot win.</li>
</ul>

<p><strong>When you are going to attack nearby, make it look as if you are going to go a long way; when you are going to attack far away, make it look as if you are going just a short distance.</strong></p>

<ul>
  <li>This point emphasizes the importance of destabilizing the information that your opponent receives. Because of how critical information assessment is to gaining an advantage in conflict, it follows that an effective strategy for gaining an advantage is to spoil the information received by your opponent.</li>
  <li>In athletic competition, deception is commonly employed. In baseball, the pitcher throws different types of pitches (e.g., fastballs or curveballs) to throw off the batter. In basketball, the pump fake is very literally the act of “making it look as if you are going to go a long way” when actually going “nearby”. In amateur wrestling, there are many moves that require deceiving your opponent into thinking you are making one move, but really are going for another. Examples abound.</li>
</ul>

<p><strong>Draw them in with the prospect of gain, take them by confusion.</strong></p>

<ul>
  <li>Here again, Sun Tzu emphasizes the use of deception, but he goes a bit further with this line of thinking. Rather than simply deceiving one’s opponent with false information, here he emphasizes leveraging one’s opponent’s <em>natural emotional responses</em> against them. Here specifically, that emotion is greed; one should use greed to bait one’s opponent. 
The phrase, “take them by confusion” emphasizes the critical point in time when your opponent realizes that their understanding of reality is incorrect. It is in that critical moment of realization, and confusion, that you possess an advantage and should strike decisively.</li>
</ul>

<p><strong>When they are fulfilled, be prepared against them; when they are strong, avoid them.</strong></p>

<ul>
  <li>Don’t start a fight you can’t finish</li>
  <li>Here Sun Tzu emphasizes the criticality of assessing when you do or don’t have the advantage. If you don’t have the advantage, then you must avoid conflict and work tirelessly to gain the advantage, because at that point you are vulnerable.</li>
</ul>

<p><strong>Use anger to throw them into disarray.</strong></p>

<ul>
  <li>In a previous point, Sun Tzu suggested using your opponent’s greed against them (“Draw them in with the prospect of gain”). Here he suggests also using their anger against them. Altogether, he emphasizes a comprehensive strategy of not just deceiving your opponents with false information (“a military operation involves deception”), but also toying with their emotions. 
Our tendency to get carried away by emotion is an inherent vulnerability within all people. Just as we should take advantage of this vulnerability in others, so too must we guard against it within ourselves. That is, we should seek to regulate our emotional responses and assess situations logically. The more you are carried away by emotion, the worse your decision making will be, and a clever, more rational opponent will use this to gain an advantage over you.</li>
</ul>

<p><strong>Use humility to make them haughty.</strong></p>

<ul>
  <li>Again, this follows the theme of using your opponent’s emotions against them</li>
  <li>A haughty opponent is likely to be more careless and less prepared. This advice pairs well with the previous advice, “Even though you are competent, appear to be incompetent”.
Overall, the three emotions that Sun Tzu says can be used against your opponent are greed, anger, and arrogance. Consequently, if you possess these emotions yourself, you are vulnerable!</li>
</ul>

<p><strong>Tire them by flight.</strong></p>

<ul>
  <li>For me, the book title, “The Art of War”, suggested that it would be about fighting and force, but in fact, it is all about gaining an advantage over your opponent before you actually fight the battle. So far in the book, Sun Tzu has taught us that this advantage can be gained by 1) good leadership and internal organization, 2) gathering and assessment of intelligence, 3) information warfare and deception, 4) manipulating your opponent’s emotions, and now, 5) wearing your opponent down by being faster than them!</li>
</ul>

<p><strong>Cause division among them.</strong></p>

<ul>
  <li>This plays into the theme of emotional warfare – that is, toying with your enemies’ emotions. Whereas prior advice was directed at the individual (e.g., anger is felt individually), here the emotional warfare is directed at the interpersonal. The aim is to destroy the opponent’s structural unity by breaking down the relationships between people.</li>
  <li>This follows logically from an earlier piece of advice on The Way. That is, Sun Tzu taught that “The Way means inducing the people to have the same aim as the leadership, so that they will share death and life, without fear of danger.” The logical consequence is that one should seek to destroy The Way among one’s opponents.</li>
</ul>

<p><strong>Attack when they are unprepared, make your move when they do not expect it.</strong></p>

<ul>
  <li>Essentially this boils down to using the “element of surprise.” The use of surprise attacks is a somewhat well-known, tired trope, but here is given a much richer context by Sun Tzu’s other advice. The use of surprise is another example of assessing your opponent’s mentality and using it against them.</li>
</ul>

<p><strong>The formation and procedure used by the military should not be divulged beforehand</strong></p>

<ul>
  <li>This harkens back to the prior advice of not divulging one’s capabilities, “Even though you are competent, appear to be incompetent. Though effective, appear to be ineffective”. Here, this is extended not just to capabilities, but to tactics as well. Don’t reveal your hand.</li>
</ul>

<p><strong>The one who figures on victory at headquarters before even doing battle is the one who has the most strategic factors on his side. The one who figures on inability to prevail at headquarters before doing battle is the one who has the least strategic factors on his side. The one with many strategic factors in his favor wins, the one with few strategic factors in his favor loses – how much the more so for one with no strategic factors in his favor. Observing the matter this way, I can see who will win and who will lose.</strong></p>

<ul>
  <li>This concluding piece emphasizes not only the importance of gaining the strategic advantage before the fighting starts, but also the fact that it is actually possible to discern who will win and who will lose. Conflict is a science, and through mastery of that science, one can learn when to fight and when not to fight.</li>
</ul>

<h2 id="book-2-doing-battle">Book 2: Doing Battle</h2>

<p><strong>When you do battle, even if you are winning, if you continue for a long time it will dull your forces and blunt your edge; if you besiege a citadel, your strength will be exhausted. If you keep your armies out in the field for a long time, your supplies will be insufficient.</strong></p>

<ul>
  <li>A central theme of this second chapter is that one should avoid a long, drawn-out conflict. This ties into the first chapter’s theme of strategic planning prior to conflict: if you prepare thoroughly for the conflict and, through that preparation, learn that you have the advantage, then the conflict itself should not last long, because you have already prepared a strategy and can execute it quickly.</li>
</ul>

<p><strong>When your forces are dulled, your edge is blunted, your strength is exhausted, and your supplies are gone, then others will take advantage of your debility and rise up. Then even if you have wise advisers you cannot make things turn out well in the end.</strong></p>

<ul>
  <li>Not only should one consider the conflict at hand, but one should also always keep in mind the conflicts yet to come. If you waste your resources on the current problem, you will be ill-prepared for the next problem. This also touches on the idea of “pacing oneself”. When running a long-distance race, one needs to conserve energy for the later miles. So too, when facing a problem, one should consider the future problems that will demand future resources. Said more succinctly, one should always remain future-minded.</li>
  <li>A further generalization of this advice is to not be too narrowly focused on the problem at hand, but rather, one should stay aware of their surroundings and keep an eye on other issues that may crop up. If your full concentration is on the current problem, then other problems can catch you unaware. Said more succinctly, one should always remain aware of their surroundings.</li>
  <li>Sun Tzu mentions “wise advisers” here, which harkens back to the prior theme of using information effectively. Here Sun Tzu says that if you have exhausted your resources, no amount of planning will help you! Thus, in order to use information effectively and form a strategy, you need to make sure you have the resources to execute that strategy.</li>
</ul>

<p><strong>Therefore I have heard of military operations that were clumsy but swift, but I have never seen one that was skillful and lasted a long time. It is never beneficial to a nation to have a military operation continue for a long time.</strong></p>

<ul>
  <li>One can translate this advice into a mathematical axiom: if conflict is long and drawn out, then it is not effective. This axiom provides a nice rule of thumb for ruling out certain strategies: any strategy that will require a long, drawn out conflict is not a viable option.</li>
  <li>It is interesting how, even in the modern day, governments have not learned this lesson! The United States wasted so many precious resources – lives, money, and time – on its operation in Afghanistan. When we invaded Afghanistan, we did not adequately ensure that the conflict would be swift and effective. We had no end state in mind! Because of this, most would consider our decades-long war in Afghanistan a failure.</li>
</ul>

<p><strong>Therefore, those who are not thoroughly aware of the disadvantages in the use of arms cannot be thoroughly aware of the advantages in the use of arms.</strong></p>

<ul>
  <li>I interpret this to mean that one should understand that engaging in a conflict risks that conflict becoming a quagmire. Without understanding this risk, one has not adequately prepared and does not truly have a sound strategy to be effective. The United States did not prepare adequately for this outcome when it invaded Afghanistan.</li>
  <li>A more general interpretation of this advice is that if one does not know the disadvantages of the use of force, then fundamentally, one cannot wield force effectively because one does not understand its true nature. The use of force is a dangerous and serious affair. A lack of respect for it shows a general lack of understanding of it and thus, an amateurish mentality.</li>
</ul>

<p><strong>Those who use the military skillfully do not raise troops twice and do not provide food three times.</strong></p>

<ul>
  <li>If you have a problem, deal with it once and deal with it thoroughly.</li>
  <li>This ties into Sun Tzu’s prior advice on planning strategically. You should plan a strategy that completely solves your problem. You don’t want your incompletely solved problem to keep cropping up over and over again because you failed to deal with it the first time.</li>
  <li>More generally, don’t “half ass” your work. Finish it completely and finish it well.</li>
</ul>

<p><strong>By taking equipment from your own country but feeding off the enemy you can be sufficient in both arms and provisions.</strong></p>

<ul>
  <li>I interpret “feeding off the enemy” to mean looting food provisions from the enemy so that you don’t have to bring them yourself. There is an element of ruthlessness to this advice that I believe generalizes beyond just stealing provisions. When it comes to a fight, you should not hold back from taking what you need from your opponent. To win you must be ruthless and take what you need. This relates to the prior advice, “Those who use the military skillfully do not raise troops twice.” Deal with the enemy once and thoroughly.</li>
  <li>There is a lesson here with regard to planning: one should not make unnecessary preparations. Here, Sun Tzu is saying that you should not prepare your own food because you can simply take food from your enemy once you defeat them. Preparing food yourself is an unnecessary preparation. Thus, one should discern between resources one needs now versus resources that can be acquired later. Energy shouldn’t be expended on unnecessary preparations.</li>
  <li>There is an adage in software engineering attributed to Donald Knuth: “Premature optimization is the root of all evil”. Sun Tzu teaches a similar lesson here: prematurely optimizing (i.e., preparing resources that are not yet needed) is an ineffective allocation of energy and should be avoided at all costs.</li>
  <li>Relying on feeding your troops with your enemy’s provisions implies that you will defeat them. Thus, Sun Tzu is so certain of victory that he is willing to risk starvation. This ties in to prior advice on strategic assessment prior to conflict and the importance of knowing whether you have the advantage, “Therefore use these assessments for comparison, to find out what the conditions are…This is how you can know who will win”. That is, if you prepare adequately and know that you will win, you can strategically depend on that victory to set up your next set of actions (e.g., to “feed off the enemy”).</li>
</ul>

<p><strong>When a country is impoverished by military operations, it is because of transporting supplies to a distant place. Transport supplies to a distant place, and the populace will be impoverished.</strong></p>

<ul>
  <li>Here again, Sun Tzu emphasizes building a proper organizational foundation before starting a conflict. Generalizing this beyond conflict, it is critical to set up the proper logistical and organizational foundation before embarking on a challenge.</li>
</ul>

<p><strong>Those who are near the army sell at high prices. Because of high prices the wealth of the common people is exhausted.</strong></p>

<ul>
  <li>A theme highlighted here is that of unintended consequences. Sun Tzu warns of the unintended economic consequences that result from war.</li>
  <li>Another theme that begins to emerge in the book is compassion for common people. He is not only focused on the success of the military campaign, but also on the well-being of the country as a whole. We too should always keep in mind the consequences our actions have on others. We should not be so focused on our goals that we neglect the well-being of others.</li>
</ul>

<p><strong>When resources are exhausted, then levies are made under pressure. When power and resources are exhausted, then the homeland is drained. The common people are deprived of seventy percent of their budget, while the government’s expenses for equipment amount to sixty percent of its budget.</strong></p>

<ul>
  <li>Two recurring themes are echoed in this point. First, we see Sun Tzu’s analytical and quantitative approach towards strategic assessment. Second, we see again his emphasis on looking out for the well-being of society and considering unintended consequences of one’s actions on others.</li>
</ul>

<p><strong>Therefore a wise general strives to feed off the enemy. Each pound of food taken from the enemy is equivalent to twenty pounds you provide by yourself.</strong></p>

<ul>
  <li>This is a very quantitative piece of advice: he literally quantifies the benefit of “feeding off the enemy” versus preparing provisions yourself.</li>
</ul>

<p><strong>So what kills the enemy is anger, what gets the enemy’s goods is reward.</strong></p>

<ul>
  <li>In prior points (e.g., “Use anger to throw them into disarray”), Sun Tzu had stressed the importance of manipulating the enemy’s emotions and using their impulsive, emotional responses against them. Here, in contrast, he emphasizes utilizing the emotions of one’s own troops and direct reports.</li>
  <li>Emotions are an extremely powerful force and should be leveraged to one’s strategic advantage whether that be leveraging the emotions of the opposition or of one’s own allies.</li>
</ul>

<p><strong>Therefore in a chariot battle, reward the first to capture at least ten chariots.</strong></p>

<ul>
  <li>Sun Tzu emphasizes the importance of using reward to motivate. Interestingly, he does not mention the use of punishment. The emphasis on using reward instead of punishment is, in fact, a fairly recent development in both parenting and animal training. Scientifically, it has been shown that reward is more effective than punishment in inducing a target behavior. It is interesting to see Sun Tzu identify this truth so long ago.</li>
  <li>Here is another example of his use of quantitative analysis: he specifies the threshold for a reward as exactly ten chariots.</li>
</ul>

<p><strong>Change their colors, use them mixed in with your own. Treat the soldiers well, take care of them.</strong></p>

<ul>
  <li>Turning one’s enemies into allies is a better strategy than defeating them. There is this repeated core idea of achieving victory, but avoiding violence. One should seek persuasion rather than conflict.</li>
  <li>After turning enemies into allies, one must also treat them well. One should avoid holding grudges.</li>
  <li>Sun Tzu’s use of the term “mixed in with your own” seems to suggest the importance of enabling these enemies-turned-allies to retain their original identity. Sun Tzu does not seem to be saying that they should be “converted” or transformed into one’s own identity nor should they be separated from one’s own group.</li>
</ul>

<p><strong>This is called overcoming your opponent and increasing your strength to boot.</strong></p>

<ul>
  <li>Turning one’s enemies into allies is a way to not only claim victory, but also to strengthen one’s position simultaneously. In contrast, a direct conflict requires resources and will inevitably weaken you.</li>
</ul>

<p><strong>So the important thing in a military operation is victory, not persistence.</strong></p>

<ul>
  <li>Victory must be achieved quickly and decisively. A long, drawn-out campaign will lead to defeat.</li>
  <li>An alternative interpretation is that what matters is not effort, but outcome. It doesn’t matter how much effort you’ve expended, only the outcome matters. I believe that a common mistake is to overvalue hard work, when what matters more is the outcome.</li>
</ul>

<p><strong>Hence, we know that the leader of the army is in charge of the lives of the people and the safety of the nation.</strong></p>

<ul>
  <li>This re-emphasizes a leader’s responsibility not only to achieve their target goal, but also to tend to the well-being of their people. It can also be interpreted as an echo of a common theme I see woven throughout the book: consideration for the wider society.</li>
</ul>

<h2 id="book-3-planning-a-siege">Book 3: Planning a Siege</h2>

<p><strong>The general rule for use of the military is that it is better to keep a nation intact than to destroy it. It is better to keep an army intact than to destroy it, better to keep a division intact than to destroy it, better to keep a battalion intact than to destroy it, better to keep a unit intact than to destroy it.</strong></p>

<ul>
  <li>Sun Tzu lays out a fundamental principle: Avoid destruction. Not only is this an ethical and moral principle, but a strategically advantageous one as well.</li>
</ul>

<p><strong>If you can keep the opponent’s nation intact, then your own nation will also be intact. So this is best.</strong></p>

<ul>
  <li>Destruction requires the expenditure of energy and will weaken oneself. Destruction will evoke a violent and commensurate response from one’s enemy and thus incurs risk. Destruction also seeds an enemy’s anger and may galvanize the enemy.</li>
</ul>

<p><strong>This means that killing is not the important thing.</strong></p>

<ul>
  <li>In a conflict, the most important thing is the outcome, not necessarily the means by which that outcome is achieved. The act of killing, and destruction more generally, is an extreme and consequential act; it should be treated with the utmost respect.</li>
</ul>

<p><strong>Therefore those who win every battle are not really skillful—those who render others’ armies helpless without fighting are the best of all.</strong></p>

<ul>
  <li>A core theme of the book is stated explicitly here: Dominance is achieved before a conflict begins. In fact, utter dominance usually leads to an avoidance of direct conflict altogether.</li>
</ul>

<p><strong>Therefore the superior militarist strikes while schemes are being laid.</strong></p>

<ul>
  <li>I see two ways to interpret this point. The first interpretation is that one should strike first and quickly when your opponent least expects it (i.e., during the planning stage of a conflict); however, this seems to contradict the prior point that “those who render others’ armies helpless without fighting are the best of all”. Thus, a second interpretation that is more consistent with the prior point is that victory itself is achieved through superior preparation and scheming: winning by rendering your opponent helpless and unable to fight before they are even aware of what you are up to.</li>
</ul>

<p><strong>The next best is to attack alliances.</strong></p>

<ul>
  <li>If victory cannot be achieved through superior scheming, then the next best path to victory is to disrupt the opponent’s alliances, or more generally, their support system.</li>
  <li>This strategy is inferior to winning through scheming (i.e., the prior point) because it risks revealing information. This runs counter to Sun Tzu’s teachings throughout the book, in which he emphasizes the importance of keeping one’s cards close to the chest. The act of disrupting alliances is an outward, active act that may reveal portions of one’s strategy.</li>
</ul>

<p><strong>The next best is to attack the army.</strong></p>

<ul>
  <li>I interpret this to mean “direct conflict”.</li>
  <li>One should engage in direct conflict only when the prior strategies (maneuvering, scheming, and disrupting the opponent’s support system) cannot be executed.</li>
  <li>The fact that direct conflict is ranked as only the third best strategy emphasizes the earlier point that, “Those who render others’ armies helpless without fighting are the best of all.”</li>
</ul>

<p><strong>The lowest is to attack a city. Siege of a city is only done as a last resort.</strong></p>]]></content><author><name>Matthew N. Bernstein</name></author><category term="book review" /><summary type="html"><![CDATA[I am currently reading Sun Tzu’s Art of War and am finding much wisdom in it. I have been taking notes during my reading and I thought I’d share them in this post. Here I cover Books 1 and 2.]]></summary></entry><entry><title type="html">Denoising diffusion probabilistic models (Part 2: Theoretical justification)</title><link href="https://mbernste.github.io/posts/diffusion_part2/" rel="alternate" type="text/html" title="Denoising diffusion probabilistic models (Part 2: Theoretical justification)" /><published>2024-10-20T00:00:00-07:00</published><updated>2024-10-20T00:00:00-07:00</updated><id>https://mbernste.github.io/posts/diffusion_part2</id><content type="html" xml:base="https://mbernste.github.io/posts/diffusion_part2/"><![CDATA[<p><em>In <a href="https://mbernste.github.io/posts/diffusion_part1/">Part 1</a> of this series, we introduced the denoising diffusion probabilistic model for modeling and sampling from complex distributions. We described the diffusion model as a model that can generate new samples by learning how to reverse a diffusion process. In this post, we provide more theoretical justification for the objective function used to fit diffusion models and make connections between the diffusion model and other concepts in statistical inference and probabilistic modeling.</em></p>

<h2 id="introduction">Introduction</h2>

<p>In <a href="https://mbernste.github.io/posts/diffusion_part1/">Part 1</a> of this series, we introduced the denoising diffusion probabilistic model for modeling and sampling from complex distributions. As a brief review, diffusion models learn how to reverse a diffusion process. Specifically, given a data object $\boldsymbol{x}$, this diffusion process iteratively adds noise to $\boldsymbol{x}$ until it becomes pure white noise. The goal of a diffusion model is to learn to reverse this diffusion process via a model $p_\theta$ parameterized by parameters $\theta$:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/diffusion_example_korra_forward_reverse_distributions_approximate.png" alt="drawing" width="800" /></center>

<p>Once we have this model in hand, we can generate an object by first sampling white noise $\boldsymbol{x}_T$ from a standard normal distribution $N(\boldsymbol{0}, \boldsymbol{I})$, and then iteratively sampling $\boldsymbol{x}_{t-1}$ from each learned $p_{\theta}(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_{t})$ distribution. At the end of this process we will have “transformed” the random white noise into an object.</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/diffusion_example_generation_korra.png" alt="drawing" width="800" /></center>

<p>To learn this model, we will fit the joint distribution given by the reverse-diffusion model, $p_{\theta}(\boldsymbol{x}_{0:T})$, to the joint distribution given by the forward-diffusion model, $q(\boldsymbol{x}_{0:T})$. Specifically, we will seek to minimize the KL-divergence from $p_\theta(\boldsymbol{x}_{0:T})$ to $q(\boldsymbol{x}_{0:T})$:</p>

\[\hat{\theta} := \text{arg min}_\theta \ KL( q(\boldsymbol{x}_{0:T}) \ \vert\vert \ p_\theta(\boldsymbol{x}_{0:T}))\]

<p>While the core idea of learning a denoising model that reverses a diffusion process, and then using that denoising model to produce samples, may be intuitive at a high level, one may want a more rigorous theoretical motivation for an objective function that fits $p_\theta(\boldsymbol{x}_{0:T})$ to $q(\boldsymbol{x}_{0:T})$ by minimizing their KL-divergence. That is, recall that the goal in traditional probabilistic modeling is to fit a model $p_\theta(\boldsymbol{x}_0)$ that approximates the real-world, unknown distribution $q(\boldsymbol{x}_0)$. How does fitting a model to reverse a diffusion process accomplish this task? Furthermore, to what traditional statistical inference frameworks is this approach related?</p>

<p>In this post we will answer these questions by discussing several <a href="https://mbernste.github.io/posts/understanding_3d/">perspectives</a> to motivate and understand the diffusion model objective. Specifically, we will view this objective in the following ways:</p>

<ol>
  <li>As maximum-likelihood estimation</li>
  <li>As implicitly minimizing an upper bound on the KL-divergence between $q(\boldsymbol{x}_0)$ and $p_\theta(\boldsymbol{x}_0)$</li>
  <li>As training a hierarchical variational autoencoder that uses a parameterless inference model</li>
  <li>As breaking up a difficult problem into many easier problems</li>
</ol>

<p>A fifth perspective that motivates the loss function lies in its connection with <a href="https://arxiv.org/abs/1907.05600">score matching models</a>; however, this merits a longer discussion that deserves its own post. Stay tuned!</p>

<h2 id="1-as-maximum-likelihood-estimation">1. As maximum-likelihood estimation</h2>

<p>Our reverse diffusion process can be thought of as a model, like any other, over noiseless objects $\boldsymbol{x}_0$, which we can access by marginalizing over all of the intermediate objects $\boldsymbol{x}_{1:T}$:</p>

\[p_\theta(\boldsymbol{x}_0) = \int p_\theta(\boldsymbol{x}_0, \boldsymbol{x}_{1:T}) \ d\boldsymbol{x}_{1:T}\]

<p>It turns out that minimizing the KL-divergence from $p_\theta(\boldsymbol{x}_{0:T})$ to $q(\boldsymbol{x}_{0:T})$ can be viewed as doing approximate <a href="https://en.wikipedia.org/wiki/Maximum_likelihood_estimation">maximum likelihood estimation</a>.</p>

<p>To see why, remember how we showed that minimizing the KL-divergence from $p_\theta(\boldsymbol{x}_{0:T})$ to $q(\boldsymbol{x}_{0:T})$ could be accomplished implicitly by maximizing the ELBO:</p>

\[\begin{align*}\hat{\theta} &amp;:= \text{arg max}_\theta \ \text{ELBO}(\theta) \\ &amp;= \text{arg max}_\theta \  E_{\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0 \sim q}\left[ \log\frac{p_\theta (\boldsymbol{x}_{0:T}) }{q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0) } \right]\end{align*}\]

<p>Notice that if we maximize the ELBO, we are not only minimizing the KL-divergence (our original stated goal), but we are also implicitly maximizing a lower bound of the log-likelihood, $\log p_\theta(\boldsymbol{x}_0)$. That is, we see that</p>

\[\begin{align*} \log p_\theta(\boldsymbol{x}_0) &amp;= KL( q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0) \ \vert\vert \ p_\theta(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0)) + \underbrace{E_{\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0 \sim q} \left[ \log\frac{p_\theta (\boldsymbol{x}_{0:T}) }{q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0) } \right]}_{\text{ELBO}} \\ &amp;\geq  \underbrace{E_{\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0 \sim q}\left[ \log\frac{p_\theta (\boldsymbol{x}_{0:T}) }{q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0) } \right]}_{\text{ELBO}} \ \ \text{Because KL-divergence is non-negative} \end{align*}\]

<p>This idea is depicted schematically below (this figure is adapted from <a href="https://jmtomczak.github.io/blog/4/4_VAE.html">this blog post by Jakub Tomczak</a>):</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/ELBO_vs_log_likelihood.png" alt="drawing" width="600" /></center>

<p>Here, $\theta^*$ represents the maximum likelihood estimate of $\theta$ and $\hat{\theta}$ represents the value of $\theta$ that maximizes the ELBO. If this lower bound is tight, $\hat{\theta}$ will be close to $\theta^*$. Although in most cases it is difficult to know with certainty how tight this lower bound is, in practice this strategy of maximizing the ELBO tends to produce good estimates of $\theta^*$.</p>
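<p>To make the decomposition of $\log p_\theta(\boldsymbol{x}_0)$ into a KL term plus the ELBO concrete, here is a minimal Python sketch on a hypothetical toy model small enough to enumerate exactly: binary states, $T = 2$ steps, a forward process that flips each bit with an assumed probability $\beta_t$, and made-up values for the reverse-model parameters $\theta$. None of these numbers come from this post. The sketch computes $\log p_\theta(\boldsymbol{x}_0)$ by summing out $\boldsymbol{x}_{1:T}$ (the discrete analogue of the marginalization integral), the ELBO, and the KL term, and checks that the identity holds.</p>

```python
import itertools
import math

# Hypothetical toy model: binary states {0, 1} and T = 2 diffusion steps.
# All numbers below are made up for illustration.
BETAS = [0.1, 0.3]  # flip probability of the forward kernel q(x_t | x_{t-1}) at t = 1, 2

# Reverse model p_theta: a prior over x_T and one conditional table per step.
P_XT = {0: 0.5, 1: 0.5}  # p_theta(x_T)
P_REV = [                # P_REV[t - 1][(x_{t-1}, x_t)] = p_theta(x_{t-1} | x_t)
    {(0, 0): 0.8, (1, 0): 0.2, (0, 1): 0.3, (1, 1): 0.7},  # t = 1
    {(0, 0): 0.6, (1, 0): 0.4, (0, 1): 0.4, (1, 1): 0.6},  # t = 2
]


def q_step(x_t, x_prev, beta):
    """Forward 'noising' kernel: flip the bit with probability beta."""
    return beta if x_t != x_prev else 1.0 - beta


def p_joint(xs):
    """p_theta(x_0, ..., x_T) = p_theta(x_T) * prod_t p_theta(x_{t-1} | x_t)."""
    prob = P_XT[xs[-1]]
    for t in range(1, len(xs)):
        prob *= P_REV[t - 1][(xs[t - 1], xs[t])]
    return prob


def q_cond(xs):
    """q(x_1, ..., x_T | x_0) = prod_t q(x_t | x_{t-1})."""
    prob = 1.0
    for t in range(1, len(xs)):
        prob *= q_step(xs[t], xs[t - 1], BETAS[t - 1])
    return prob


def decompose(x0):
    """Return (log p_theta(x_0), ELBO, KL(q(x_{1:T} | x_0) || p_theta(x_{1:T} | x_0)))."""
    paths = [(x0,) + rest for rest in itertools.product((0, 1), repeat=len(BETAS))]
    p_x0 = sum(p_joint(xs) for xs in paths)  # marginalize out x_{1:T}
    elbo = sum(q_cond(xs) * math.log(p_joint(xs) / q_cond(xs)) for xs in paths)
    # The posterior p_theta(x_{1:T} | x_0) is p_theta(x_{0:T}) / p_theta(x_0)
    kl = sum(q_cond(xs) * math.log(q_cond(xs) / (p_joint(xs) / p_x0)) for xs in paths)
    return math.log(p_x0), elbo, kl


for x0 in (0, 1):
    log_px0, elbo, kl = decompose(x0)
    print(f"x0={x0}: log p={log_px0:.4f}  ELBO={elbo:.4f}  KL={kl:.4f}")
    assert math.isclose(log_px0, elbo + kl)  # log p = KL + ELBO
    assert log_px0 >= elbo                   # so the ELBO is a lower bound
```

<p>Because the reverse tables here do not match the true posterior, the KL term is strictly positive and the ELBO sits strictly below $\log p_\theta(\boldsymbol{x}_0)$; the gap would close only if $p_\theta(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0)$ equaled $q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0)$.</p>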

<h2 id="2-as-implicitly-minimizing-an-upper-bound-on-the-kl-divergence-between-qboldsymbolx_0-and-p_thetaboldsymbolx_0">2. As implicitly minimizing an upper bound on the KL-divergence between $q(\boldsymbol{x}_0)$ and $p_\theta(\boldsymbol{x}_0)$</h2>

<p>Recall that our ultimate goal in fitting a denoising diffusion model is to fit a model $p_\theta(\boldsymbol{x}_0)$ that approximates the real-world, unknown distribution $q(\boldsymbol{x}_0)$. As we described in the previous section, $p_\theta(\boldsymbol{x}_0)$ can be obtained by marginalizing over all of the intermediate objects $\boldsymbol{x}_{1:T}$.</p>

<p>As explained eloquently by Alexander Alemi in <a href="https://blog.alexalemi.com/diffusion.html#extra-entropy">his blog post on this topic</a>, by minimizing the KL-divergence between the full diffusion process’s joint distributions, $p_\theta(\boldsymbol{x}_{0:T})$ and $q(\boldsymbol{x}_{0:T})$, we will implicitly minimize an upper bound on the KL-divergence from $p_\theta(\boldsymbol{x}_0)$ to $q(\boldsymbol{x}_0)$ (See Derivation 1 in the Appendix to this post):</p>

\[KL(q(\boldsymbol{x}_{0:T}) \ \vert\vert \ p_\theta(\boldsymbol{x}_{0:T})) \geq KL(q(\boldsymbol{x}_0) \ \vert\vert \ p_\theta(\boldsymbol{x}_0)) \geq 0\]

<p>Thus, by minimizing $KL(q(\boldsymbol{x}_{0:T}) \ \vert\vert \ p_\theta(\boldsymbol{x}_{0:T}) )$, we are implicitly learning to fit $p_\theta(\boldsymbol{x}_0)$ to $q(\boldsymbol{x}_0)$!</p>
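<p>The inequality between the joint and marginal KL-divergences can also be checked numerically. The following sketch uses another hypothetical discrete toy ($T = 1$, binary $\boldsymbol{x}_0$ and $\boldsymbol{x}_1$, made-up probabilities that do not come from this post) and computes both KL-divergences by enumeration. By the chain rule for KL-divergence, the gap between them equals the expected KL-divergence between the conditionals $q(\boldsymbol{x}_1 \mid \boldsymbol{x}_0)$ and $p_\theta(\boldsymbol{x}_1 \mid \boldsymbol{x}_0)$, which is non-negative; this is why the joint KL can never be smaller than the marginal KL.</p>

```python
import math

# Hypothetical toy: T = 1, binary x_0 and x_1, made-up probabilities.
Q_X0 = {0: 0.9, 1: 0.1}  # stand-in for the unknown data distribution q(x_0)
BETA = 0.2               # flip probability of the forward kernel q(x_1 | x_0)
P_X1 = {0: 0.5, 1: 0.5}  # p_theta(x_1)
P_X0_GIVEN_X1 = {(0, 0): 0.7, (1, 0): 0.3, (0, 1): 0.4, (1, 1): 0.6}  # keys are (x_0, x_1)


def kl(q, p):
    """KL(q || p) for two distributions stored as dicts over the same finite support."""
    return sum(qv * math.log(qv / p[k]) for k, qv in q.items() if qv > 0)


states = [(a, b) for a in (0, 1) for b in (0, 1)]
q_xy = {(a, b): Q_X0[a] * (BETA if b != a else 1.0 - BETA) for a, b in states}
p_xy = {(a, b): P_X1[b] * P_X0_GIVEN_X1[(a, b)] for a, b in states}

# Marginals over x_0, obtained by summing out x_1
q_x0 = {a: q_xy[(a, 0)] + q_xy[(a, 1)] for a in (0, 1)}
p_x0 = {a: p_xy[(a, 0)] + p_xy[(a, 1)] for a in (0, 1)}

kl_joint = kl(q_xy, p_xy)
kl_marginal = kl(q_x0, p_x0)

# Chain rule: the gap is E_{q(x_0)}[KL(q(x_1 | x_0) || p_theta(x_1 | x_0))], which is >= 0
gap = sum(
    q_x0[a] * kl({b: q_xy[(a, b)] / q_x0[a] for b in (0, 1)},
                 {b: p_xy[(a, b)] / p_x0[a] for b in (0, 1)})
    for a in (0, 1)
)

print(f"KL joint = {kl_joint:.4f}, KL marginal = {kl_marginal:.4f}, gap = {gap:.4f}")
assert math.isclose(kl_joint, kl_marginal + gap)
assert kl_joint >= kl_marginal >= 0
```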

<h2 id="3-as-training-a-hierarchical-variational-autoencoder-that-uses-a-parameterless-inference-model">3. As training a hierarchical variational autoencoder that uses a parameterless inference model</h2>

<p>In the last section, we showed that fitting $p_\theta(\boldsymbol{x}_{0:T})$ to $q(\boldsymbol{x}_{0:T})$ will implicitly minimize an upper bound on the KL-divergence from $p_\theta(\boldsymbol{x}_0)$ to $q(\boldsymbol{x}_0)$. This raises the question: Why don’t we fit $p_\theta(\boldsymbol{x}_0)$ to $q(\boldsymbol{x}_0)$ directly?</p>

<p>To answer this, recall that it is often a fruitful strategy to posit a latent generative process for the observed data and fit a model of that latent generative process, rather than fit the marginal distribution of the observed data directly. That is just what we are doing in the diffusion model framework! By expanding the model from a distribution over only the random variable representing the (noiseless) observed data, $\boldsymbol{x}_0$, to one that also incorporates the extra random variables $\boldsymbol{x}_{1:T}$, we define a generative, latent variable model of the observed data. It turns out that this latent variable model resembles a <a href="https://mbernste.github.io/posts/vae/">variational autoencoder (VAE)</a>!</p>

<p>As a brief review, recall that in the VAE framework, we assume that every data item/object that we wish to model, $\boldsymbol{x}$, is associated with a latent variable $\boldsymbol{z}$. We specify an inference model, $q_\phi(\boldsymbol{z} \mid \boldsymbol{x})$, that approximates the posterior distribution $p_\theta(\boldsymbol{z} \mid \boldsymbol{x})$. Note that $q_\phi(\boldsymbol{z} \mid \boldsymbol{x})$ is parameterized by a set of parameters $\phi$. One can view $q_\phi(\boldsymbol{z} \mid \boldsymbol{x})$ as a sort of <em>encoder</em>; given a data item $\boldsymbol{x}$, we encode it into a lower-dimensional vector $\boldsymbol{z}$. We also specify a generative model, $p_\theta(\boldsymbol{x} \mid \boldsymbol{z})$, from which, given a lower-dimensional latent vector $\boldsymbol{z}$, we can generate $\boldsymbol{x}$ by sampling. If we have encoded $\boldsymbol{x}$ into $\boldsymbol{z}$, sampling from the distribution $p_\theta(\boldsymbol{x} \mid \boldsymbol{z})$ will produce objects that resemble the original $\boldsymbol{x}$. Thus $p_\theta(\boldsymbol{x} \mid \boldsymbol{z})$ can be viewed as a <em>decoder</em>. This process is depicted schematically below:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/diffusion_VAE_graphical_model.png" alt="drawing" width="170" /></center>

<p><br /></p>

<p>Now, compare this setup to the setup we have described for the diffusion model:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/diffusion_graphical_model_like_VAE.png" alt="drawing" width="400" /></center>

<p>These figures were adapted from <a href="https://angusturner.github.io/generative_models/2021/06/29/diffusion-probabilistic-models-I.html">this blog post</a> by Angus Turner.</p>

<p>In the case of a diffusion model, we have an observed item $\boldsymbol{x}_0$ that we iteratively corrupt into $\boldsymbol{x}_T$. In a way, we can view $\boldsymbol{x}_T$ as a latent representation associated with $\boldsymbol{x}_0$, in a similar way that $\boldsymbol{z}$ is a latent representation of $\boldsymbol{x}$ in the VAE. Note that this is a “hierarchical” VAE since we do not associate a single latent variable with each $\boldsymbol{x}_0$, but rather a whole sequence of latent variables $\boldsymbol{x}_1, \dots, \boldsymbol{x}_T$.</p>

<p>Moreover, the training objectives of the traditional VAE and this “hierarchical” VAE take the same form. In the case of the traditional VAE, our goal is to minimize the expectation of the KL-divergence from $p_\theta(\boldsymbol{z} \mid \boldsymbol{x})$ to $q_\phi(\boldsymbol{z} \mid \boldsymbol{x})$, which we do by maximizing the expectation of the ELBO:</p>

\[\text{ELBO}_{\text{VAE}}(\phi, \theta) := E_{\boldsymbol{z} \mid \boldsymbol{x} \sim q_\phi} \left[ \log \frac{p_\theta( \boldsymbol{x}, \boldsymbol{z})}{q_\phi(\boldsymbol{z} \mid \boldsymbol{x})} \right]\]

<p>In the case of the diffusion model, we seek to minimize the KL-divergence from $p_\theta(\boldsymbol{x}_0, \dots, \boldsymbol{x}_T)$ to $q(\boldsymbol{x}_0, \dots, \boldsymbol{x}_T)$, which we can do by maximizing the expectation of the following ELBO:</p>

\[\text{ELBO}_{\text{Diffusion}}(\theta) := E_{\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0 \sim q}\left[ \log\frac{p_\theta (\boldsymbol{x}_{0:T}) }{q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0) } \right]\]

<p>While the analogy between VAEs and denoising diffusion models is helpful for drawing connections between ideas and gaining intuition, the analogy breaks down in a few key areas:</p>

<ol>
  <li>The “encoder” model in the diffusion model, $q_(\boldsymbol{x}_0, \dots, \boldsymbol{x}_T)$, has no parameters. Rather, it is simple diffusion process that adds noise progressively to $\boldsymbol{x}$ until it turns into pure noise. Thus, the “latent representation” output by this “encoder” is rather meaningless – it’s just noise! Because it is just noise, it cannot be reconstructed (“decoded”) back to $\boldsymbol{x}$. In contrast, the encoder model in a VAE, $q_\phi(\boldsymbol{z} \mid \boldsymbol{x})$, is a true “encoder” in the sense that it outputs a meaningful latent variable $\boldsymbol{z}$ that stores the necessary information needed to reconstruct $\boldsymbol{x}$ (which can be performed by the decoder).</li>
  <li>In a VAE, we optimize the ELBO with respect to $\phi$, the parameters of $q$, which can be interpreted as doing <a href="https://mbernste.github.io/posts/variational_inference/">variational inference</a> to approximate the posterior over $\boldsymbol{z}$ conditioned on $\boldsymbol{x}$. In a diffusion model, $q$ has no parameters, and thus our optimization of the ELBO cannot truly be interpreted as variational inference. Rather, because we optimize only with respect to $\theta$, we can only interpret the optimization procedure as approximate maximum likelihood estimation that maximizes a lower bound of $\log p_\theta(\boldsymbol{x}_0)$.</li>
</ol>

<h2 id="4-as-breaking-up-a-difficult-problem-into-many-easier-problems">4. As breaking up a difficult problem into many easier problems</h2>

<p>Another, higher-level reason why diffusion models tend to perform better than other methods, such as <a href="https://mbernste.github.io/posts/vae/">variational autoencoders</a>, is that diffusion models break up a difficult problem into a series of easier problems. That is, unlike variational autoencoders, where we train a model to produce an object all at once, in diffusion models we train the model to produce the object step-by-step. Intuitively, we train a model to “sculpt” an object out of noise in a step-wise fashion rather than generate the object in one fell swoop.</p>

<p>This step-wise approach is advantageous because it enables the model to learn features of objects at different levels of resolution. At the beginning of the reverse diffusion process (i.e., the sampling process), the model identifies broad, vague features of an object within the noise. At later steps of the reverse diffusion process, it fills in smaller details of the object by removing the last remaining noise.</p>

<h2 id="appendix">Appendix</h2>

<h3 id="derivation-1-deriving-an-upper-bound-over-klqboldsymbolx_0--vertvert--p_thetaboldsymbolx_0">Derivation 1 (Deriving an upper bound over $KL(q(\boldsymbol{x}_0) \ \vert\vert \ p_\theta(\boldsymbol{x}_0))$):</h3>

\[\begin{align*}KL(q(\boldsymbol{x}_{0:T}) \ \vert\vert \ p_\theta(\boldsymbol{x}_{0:T})) &amp;= E_{\boldsymbol{x}_{0:T} \sim q} \left[ \log \frac{q(\boldsymbol{x}_{0:T})}{p_\theta(\boldsymbol{x}_{0:T})} \right] \\ &amp;= E_{\boldsymbol{x}_{0:T} \sim q} \left[ \log \frac{q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0)q(\boldsymbol{x}_0)}{p_\theta(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0)p_\theta(\boldsymbol{x}_0)} \right] \\ &amp;= E_{\boldsymbol{x}_{0:T} \sim q} \left[ \log \frac{q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0)}{p_\theta(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0)} \right] + E_{\boldsymbol{x}_{0} \sim q} \left[ \log \frac{q(\boldsymbol{x}_0)}{p_\theta(\boldsymbol{x}_0)} \right] \\ &amp;= E_{\boldsymbol{x}_0 \sim q} \left[ E_{\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0 \sim q} \left[ \log \frac{q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0)}{p_\theta(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0)} \right]\right] + E_{\boldsymbol{x}_{0} \sim q} \left[ \log \frac{q(\boldsymbol{x}_0)}{p_\theta(\boldsymbol{x}_0)} \right] \\ &amp;= E_{\boldsymbol{x}_0 \sim q} \left[ KL(q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0) \ \vert\vert \ p_\theta(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0)) \right] + KL(q(\boldsymbol{x}_0) \ \vert\vert \ p_\theta(\boldsymbol{x}_0)) \\ &amp;\geq KL(q(\boldsymbol{x}_0) \ \vert\vert \ p_\theta(\boldsymbol{x}_0)) \end{align*}\]

<p>The inequality follows from the fact that KL-divergence is always non-negative.</p>]]></content><author><name>Matthew N. Bernstein</name></author><category term="tutorial" /><category term="deep learning" /><category term="machine learning" /><category term="probabilistic models" /><summary type="html"><![CDATA[In Part 1 of this series, we introduced the denoising diffusion probabilistic model for modeling and sampling from complex distributions. We described the diffusion model as a model that can generate new samples by learning how to reverse a diffusion process. In this post, we provide more theoretical justification for the objective function used to fit diffusion models and make connections between the diffusion model and other concepts in statistical inference and probabilistic modeling.]]></summary></entry><entry><title type="html">Denoising diffusion probabilistic models (Part 1: Definition and derivation)</title><link href="https://mbernste.github.io/posts/diffusion_part1/" rel="alternate" type="text/html" title="Denoising diffusion probabilistic models (Part 1: Definition and derivation)" /><published>2024-06-28T00:00:00-07:00</published><updated>2024-06-28T00:00:00-07:00</updated><id>https://mbernste.github.io/posts/diffusion_part1</id><content type="html" xml:base="https://mbernste.github.io/posts/diffusion_part1/"><![CDATA[<p><em>Diffusion models are a family of state-of-the-art probabilistic generative models that have achieved groundbreaking results in a number of fields ranging from image generation to protein structure design. In Part 1 of this two-part series, I will walk through the denoising diffusion probabilistic model (DDPM) as presented by Ho, Jain, and Abbeel (2020). Specifically, we will walk through the model definition, the derivation of the objective function, and the training and sampling algorithms.
We will conclude by walking through an implementation of a simple diffusion model in PyTorch and applying it to the MNIST dataset of hand-written digits.</em></p>

<h2 id="introduction">Introduction</h2>

<p>In a <a href="https://mbernste.github.io/posts/vae/">previous post</a>, we walked through the theory and implementation of the variational autoencoder, which is a probabilistic generative model that combines variational inference and neural networks to model and sample from complex distributions. In this post, we will walk through another such model: the <strong>denoising diffusion probabilistic model</strong>. Diffusion models were originally proposed by <a href="https://arxiv.org/abs/1503.03585">Sohl-Dickstein et al. (2015)</a> and later extended by <a href="https://arxiv.org/pdf/2006.11239.pdf">Ho, Jain, and Abbeel (2020)</a>.</p>

<p>At the time of this writing, diffusion models are state-of-the-art models used for image generation and have achieved what are, in my opinion, breathtaking results in generating incredibly detailed, realistic images. Below is an example image generated by <a href="https://openai.com/dall-e-3">DALL·E 3</a> (via <a href="https://openai.com/">OpenAI</a>’s <a href="https://openai.com/gpt-4">ChatGPT</a>), which, as far as I understand, uses diffusion models as part of its image-generation machinery.</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/dalle3_example.png" alt="drawing" width="500" /></center>

<p><br /></p>

<p>Diffusion models are also being explored in biomedical research. For example, <a href="https://www.nature.com/articles/s41586-023-06415-8">RFDiffusion</a> and <a href="https://www.nature.com/articles/s41586-023-06728-8">Chroma</a> are two methods that use diffusion models to generate novel protein structures. Diffusion models are <a href="https://www.nature.com/articles/s41551-024-01193-8">also being explored</a> for synthetic biomedical data generation.</p>

<p>Because of these models’ incredible performance in image generation, and their burgeoning use-cases in computational biology, I was curious to understand how they work. While I have a relatively good understanding of the theory behind the variational autoencoder, diffusion models presented a bigger challenge, as the mathematics is more involved. In this two-part series of posts, I will step through my newfound understanding of both the mathematical theory and the practical implementation of diffusion models.</p>

<p>In Part 1 of the series, I will walk through the denoising diffusion probabilistic model (DDPM) as presented by <a href="https://arxiv.org/pdf/2006.11239.pdf">Ho, Jain, and Abbeel (2020)</a>. The mathematical derivations are somewhat lengthy, so I present them in the Appendix to the post so that they do not distract from the core ideas behind the model. We will conclude by walking through an implementation of a simple diffusion model in <a href="https://pytorch.org/">PyTorch</a> and applying it to the <a href="https://en.wikipedia.org/wiki/MNIST_database">MNIST dataset</a> of hand-written digits. In <a href="https://mbernste.github.io/posts/diffusion_part2/">Part 2 of this series</a>, we will dig deeper into the justification and intuition behind the objective function used to train diffusion models. Hopefully, these posts will serve others who, like me, are learning this material. Please let me know if you find anything wrong!</p>

<h2 id="diffusion-models-as-learning-to-reverse-a-diffusion-process">Diffusion models as learning to reverse a diffusion process</h2>

<p>Like all probabilistic generative models, diffusion models can be understood as models that specify a probability distribution, $p(\boldsymbol{x})$, over some set of objects of interest where $\boldsymbol{x}$ is a vector representation of one such object. For example, these objects might be images, text documents, or protein sequences. Generating an image via a diffusion model can be viewed as <em>sampling</em> from $p(\boldsymbol{x})$:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/diffusion_sampling_images.png" alt="drawing" width="750" /></center>

<p>In training a diffusion model, we fit $p(\boldsymbol{x})$ by fitting a <a href="https://en.wikipedia.org/wiki/Diffusion_process">diffusion process</a>. This diffusion process goes as follows: Given a vector $\boldsymbol{x}$ representing an object (e.g., an image), we iteratively add Gaussian noise to $\boldsymbol{x}$ over a series of $T$ timesteps. Let $\boldsymbol{x}_t$ be the object at time step $t$ and let $\boldsymbol{x}_0$ be the original object before noise was added to it.  If $\boldsymbol{x}_0$ is an image of my dog Korra, this diffusion process would look like the following:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/diffusion_example_korra_forward.png" alt="drawing" width="800" /></center>

<p><br /></p>

<p>Here, $q(\boldsymbol{x})$ represents the hypothetical “real world distribution” of objects (which is distinct from the model’s distribution $p(\boldsymbol{x})$, though our goal is to train the model so that $p(\boldsymbol{x})$ resembles $q(\boldsymbol{x})$). Furthermore, if the total number of timesteps $T$ is large enough, then the corrupted object approaches a sample from a standard normal distribution $N(\boldsymbol{0}, \boldsymbol{I})$ – that is, it approaches pure white noise.</p>

<p>Now, the goal of training a diffusion model is to learn how to reverse this diffusion process by iteratively removing noise in the reverse order it was added:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/diffusion_example_korra_forward_reverse.png" alt="drawing" width="800" /></center>

<p><br /></p>

<p>The main idea behind diffusion models is that if our model can remove noise successfully, then we have a ready-made method for generating new objects. Specifically, we can generate a new object by first sampling noise from $N(\boldsymbol{0}, \boldsymbol{I})$, and then applying our model iteratively, removing noise step-by-step until a new object is formed:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/diffusion_example_generation_korra_high_level.png" alt="drawing" width="800" /></center>

<p>In a sense, the model is “sculpting” an object out of noise bit by bit. It is like a sculptor who starts from an amorphous block of granite and slowly chips away at the rock until a form appears!</p>

<p>Now that we have some high-level intuition, let’s make this more mathematically rigorous. First, the forward diffusion process works as follows: For each timestep, $t$, we will sample noise, $\epsilon$, from a standard normal distribution, and then add it to  $\boldsymbol{x}_t$ in order to form the next, noisier object $\boldsymbol{x}_{t+1}$:</p>

\[\begin{align*}\epsilon &amp;\sim N(\boldsymbol{0}, \boldsymbol{I}) \\ \boldsymbol{x}_{t+1} &amp;:= c_1\boldsymbol{x}_t + c_2\epsilon\end{align*}\]

<p>where $c_1$ and $c_2$ are two constants (to be defined in more detail later in the post). Note that the above process can also be described as sampling from a normal distribution with a mean specified by $\boldsymbol{x}_t$:</p>

\[\boldsymbol{x}_{t+1} \sim N\left(c_1\boldsymbol{x}_t, c_2^2 \boldsymbol{I}\right)\]

<p>Thus, we can view the formation of $\boldsymbol{x}_{t+1}$ as the act of <em>sampling</em> from a normal distribution that is conditioned on $\boldsymbol{x}_t$. We will use the notation $q(\boldsymbol{x}_{t+1} \mid \boldsymbol{x}_t)$ to refer to this conditional distribution.</p>

<p>In a similar manner, we can also view the process of removing noise (i.e., reversing a diffusion step) as sampling. Specifically, we can view it as sampling from the <em>posterior</em> distribution, $q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t)$. To reverse the diffusion process, we start from pure noise and iteratively sample from these posteriors:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/diffusion_example_korra_forward_reverse_distributions_exact.png" alt="drawing" width="800" /></center>

<p>Now, how do we derive these posterior distributions? One idea is to use <a href="https://en.wikipedia.org/wiki/Bayes%27_theorem">Bayes’ theorem</a>:</p>

\[q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t+1}) = \frac{q(\boldsymbol{x}_{t+1} \mid \boldsymbol{x}_t)q(\boldsymbol{x}_{t})}{q(\boldsymbol{x}_{t+1})}\]

<p>Unfortunately, this posterior is intractable to compute. Why? First note that in order to compute $q(\boldsymbol{x}_t)$, we have to marginalize over all of the time steps prior to $t$:</p>

\[\begin{align*} q(\boldsymbol{x}_t) &amp;= \int_{\boldsymbol{x}_{t-1},\dots,\boldsymbol{x}_0} q(\boldsymbol{x}_t, \boldsymbol{x}_{t-1}, \dots, \boldsymbol{x}_0) \ d\boldsymbol{x}_{t-1}\dots d\boldsymbol{x}_{0} \\ &amp;= \int_{\boldsymbol{x}_{t-1},\dots,\boldsymbol{x}_0} q(\boldsymbol{x}_0)\prod_{i=1}^{t} q(\boldsymbol{x}_i \mid \boldsymbol{x}_{i-1}) \ d\boldsymbol{x}_{t-1}\dots d\boldsymbol{x}_{0} \end{align*}\]

<p>Notice that this marginalization requires that we define a distribution $q(\boldsymbol{x}_0)$, which is a distribution over noiseless objects (e.g., a distribution over noiseless images). Unfortunately, we don’t know what this distribution is – learning it is the whole purpose of developing a diffusion model!</p>

<p>To get around this problem, we will employ a similar strategy as used in <a href="https://mbernste.github.io/posts/variational_inference/">variational inference</a>: We will <em>approximate</em> the forward diffusion process $q(\boldsymbol{x}_{0:T})$, which is given by,</p>

\[q(\boldsymbol{x}_{0:T}) = q(\boldsymbol{x}_0)\prod_{t=1}^T q(\boldsymbol{x}_{t} \mid \boldsymbol{x}_{t-1})\]

<p>using a surrogate distribution $p_\theta(\boldsymbol{x}_{0:T})$ that is instead factored by the posterior distributions (i.e., the reverse diffusion steps):</p>

\[p_\theta(\boldsymbol{x}_{0:T}) = p_\theta(\boldsymbol{x}_T)\prod_{t=1}^T p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t)\]

<p>Here, $\theta$ represents a set of learnable parameters that we will use to fit $p_{\theta}(\boldsymbol{x}_{0:T})$ as close to $q(\boldsymbol{x}_{0:T})$ as possible. As we will see later in the post, these $p_{\theta}(\boldsymbol{x}_t \mid \boldsymbol{x}_{t+1})$ distributions can incorporate a neural network so that they can represent a distribution complex enough to successfully remove noise.</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/diffusion_example_korra_forward_reverse_distributions_approximate.png" alt="drawing" width="800" /></center>

<p><br /></p>

<p>From our approximate joint distribution $p_\theta(\boldsymbol{x}_{0:T})$, we will obtain a sequence of approximate posterior distributions, each given by $p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t)$. With these approximate posterior distributions in hand, we can generate an object by first sampling white noise $\boldsymbol{x}_T$ from a standard normal distribution $N(\boldsymbol{0}, \boldsymbol{I})$, and then iteratively sampling $\boldsymbol{x}_{t-1}$ from each learned $p_{\theta}(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_{t})$ distribution. At the end of this process we will have “transformed” the random white noise into an object. More specifically, we will have “sampled” an object!</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/diffusion_example_generation_korra.png" alt="drawing" width="800" /></center>

<p>In the next sections, we will more rigorously define and discuss the forward diffusion model and reverse diffusion model.</p>

<h2 id="the-forward-model">The forward model</h2>

<p>As stated previously, the forward model is defined as</p>

\[q(\boldsymbol{x}_{t+1} \mid \boldsymbol{x}_t) := N\left(\boldsymbol{x}_{t+1} ; c_1\boldsymbol{x}_t, c_2^2 \boldsymbol{I}\right)\]

<p>where $c_1$ and $c_2$ are constants. Let us now define these constants. First, let us define values $\beta_1, \beta_2, \dots, \beta_T \in [0, 1]$. These are $T$ values between zero and one, each corresponding to a timestep. The constants $c_1$ and $c_2$ are simply:</p>

\[\begin{align*}c_1 &amp;:= \sqrt{1-\beta_t} \\ c_2 &amp;:= \sqrt{\beta_t}\end{align*}\]

<p>Then, the fully-defined forward model at timestep $t$ (i.e., the step that produces $\boldsymbol{x}_t$ from $\boldsymbol{x}_{t-1}$) is:</p>

\[q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1}) := N\left(\boldsymbol{x}_t; \sqrt{1-\beta_t}\boldsymbol{x}_{t-1}, \beta_t \boldsymbol{I}\right)\]

<p>Here we see that $\beta_t = c_2^2$ sets the variance of the noise added at timestep $t$. In diffusion models, it is common to predefine a function that returns $\beta_t$ at each timestep. This function is called the <strong>variance schedule</strong>. For example, one might use a linear variance schedule defined as:</p>

\[\beta_t := (\text{max} - \text{min})(t/T) + \text{min}\]

<p>where $\text{max}, \text{min} \in [0,1]$ and $\text{min} &lt; \text{max}$ are two small constants. The function above will compute a sequence of $\beta_1, \dots, \beta_T$ that interpolate linearly between $\text{min}$ and $\text{max}$. Note, the specific variance schedule that one uses is a modeling design choice. Instead of a linear variance schedule, such as the one shown above, one may opt for another one. For example, <a href="https://arxiv.org/pdf/2102.09672.pdf">Nichol and Dhariwal (2021)</a> suggest replacing a linear variance schedule with a cosine variance schedule (which we won’t discuss here).</p>
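<p>To make the linear schedule concrete, here is a minimal plain-Python sketch (no PyTorch, for readability) that computes $\beta_1, \dots, \beta_T$ from the formula above. The function name <code>linear_beta_schedule</code> is hypothetical, and the default endpoints $\text{min} = 10^{-4}$ and $\text{max} = 0.02$ with $T = 1000$ are the values reported by Ho, Jain, and Abbeel (2020), used here only for illustration.</p>

```python
def linear_beta_schedule(T, beta_min=1e-4, beta_max=0.02):
    """Return [beta_1, ..., beta_T] with beta_t = (beta_max - beta_min) * (t / T) + beta_min.

    Hypothetical helper for illustration; the default endpoints are the ones
    reported by Ho, Jain, and Abbeel (2020).
    """
    if not (1.0 >= beta_max > beta_min >= 0.0):
        raise ValueError("expected 0 <= beta_min < beta_max <= 1")
    # t runs over 1..T, so beta_T equals beta_max (up to rounding) and beta_1
    # sits slightly above beta_min, by (beta_max - beta_min) / T.
    return [(beta_max - beta_min) * (t / T) + beta_min for t in range(1, T + 1)]


betas = linear_beta_schedule(T=1000)
print(len(betas), betas[0], betas[-1])
```

<p>With this exact formula the schedule starts slightly above $\text{min}$; one could instead interpolate with $(t-1)/(T-1)$ so that both endpoints are hit exactly.</p>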

<p>This raises the question: Why use a different value of $\beta_t$ at each timestep? Why not set $\beta_t$ constant across timesteps? The answer is that, empirically, if $\beta_t$ is large, then the object will turn to noise too quickly, washing away the structure of the object too early in the process and thereby making it challenging to learn $p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t)$. In contrast, if $\beta_t$ is very small, then each step adds only a very small amount of noise, and thus, to turn an object $\boldsymbol{x}_0$ into white noise (and back via reverse diffusion), we would require many timesteps (which, as we will see, would lead to inefficient training of the model).</p>

<p>A solution that balances the need to maintain the object’s structure while keeping the number of timesteps relatively short is to increase the variance at each timestep according to a set schedule so that at the beginning of the diffusion process, only a little bit of noise is added at a time, but towards the end of the process, more noise is added at a time to ensure that $\boldsymbol{x}_T$ approaches a sample from $N(\boldsymbol{0}, \boldsymbol{I})$ (i.e., it becomes pure noise). This is illustrated in the figure below:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/diffusion_variance_schedule_example_korra.png" alt="drawing" width="800" /></center>
<p><br /></p>

<p>Now that we have a better understanding of the second constant, $c_2$, which controls the variance of the added noise, let’s turn our attention to the first constant, $c_1 := \sqrt{1-\beta_t}$, which scales the mean. Why are we scaling the mean with this constant? Doesn’t it make more sense to simply center the mean of the forward noise distribution at $\boldsymbol{x}_t$?</p>

<p>The reason for this term is that it prevents the variance of the object from growing as noise is added: if $\text{Var}(\boldsymbol{x}_{t-1}) = 1$, then $\text{Var}(\boldsymbol{x}_t) = (1-\beta_t)\text{Var}(\boldsymbol{x}_{t-1}) + \beta_t = 1$. That is, $\sqrt{1-\beta_t}$ is precisely the value required to scale the mean of the forward diffusion process distribution, $q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1})$, such that the variance of $\boldsymbol{x}_t$ stays at one. See Derivation 4 in the Appendix to this post for a proof. Recall that our goal is to transform $\boldsymbol{x}_0$ into white noise distributed by a standard normal distribution (which has a variance of 1), and thus, we cannot have the variance grow at each timestep. Below we depict a forward diffusion process on 1-dimensional data using two strategies: the first does not scale the mean and the second does. Notice that the variance continues to grow when we don’t scale the mean, but it remains fixed when we scale the mean by $\sqrt{1-\beta_t}$:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/diffusion_forward_process_mean_scaling_term_1D.png" alt="drawing" width="800" /></center>

<p><br /></p>
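<p>The effect of the $\sqrt{1-\beta_t}$ factor can also be checked numerically. The following is a small Monte Carlo sketch in plain Python, assuming 1-dimensional data with unit variance and a hypothetical constant schedule $\beta_t = 0.05$ for $T = 50$ steps (chosen only so the effect is easy to see). Without mean scaling, the variance after $T$ steps should be about $1 + \sum_t \beta_t = 3.5$; with it, the variance should stay near 1.</p>

```python
import math
import random


def final_variance(betas, scale_mean, n_samples=5000, seed=0):
    """Run a 1-D forward diffusion on unit-variance data and return the sample variance of x_T.

    scale_mean=True:  x_t = sqrt(1 - beta_t) * x_{t-1} + sqrt(beta_t) * eps
    scale_mean=False: x_t = x_{t-1} + sqrt(beta_t) * eps
    """
    rng = random.Random(seed)
    xs = [rng.gauss(0.0, 1.0) for _ in range(n_samples)]  # samples of x_0, variance 1
    for beta in betas:
        c1 = math.sqrt(1.0 - beta) if scale_mean else 1.0
        c2 = math.sqrt(beta)  # standard deviation of the added noise
        xs = [c1 * x + c2 * rng.gauss(0.0, 1.0) for x in xs]
    mean = sum(xs) / len(xs)
    return sum((x - mean) ** 2 for x in xs) / (len(xs) - 1)


betas = [0.05] * 50  # hypothetical constant schedule, for this demo only
var_unscaled = final_variance(betas, scale_mean=False)  # close to 1 + sum(betas) = 3.5
var_scaled = final_variance(betas, scale_mean=True)     # close to 1
print(var_unscaled, var_scaled)
```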

<p>Before we conclude this section, we will also prove a few convenient properties of the forward model that will be useful for deriving the final objective function used to train diffusion models:</p>

<p>1. <strong>$q(\boldsymbol{x}_t \mid \boldsymbol{x}_0)$ has a closed form.</strong> That is, the distribution over a noisy object at timestep $t$ of the diffusion process has a closed form solution. That solution is specifically the following normal distribution (See Derivation 5 in the Appendix to this post):</p>

\[q(\boldsymbol{x}_t \mid \boldsymbol{x}_0) := N\left(\boldsymbol{x}_t; \sqrt{\bar{\alpha}_t} \boldsymbol{x}_0, (1-\bar{\alpha}_t)\boldsymbol{I} \right)\]

<p>where $\alpha_t := 1-\beta_t$ and $\bar{\alpha}_t := \prod_{i=1}^t \alpha_i$ (this notation is used in the original paper by <a href="https://arxiv.org/pdf/2006.11239.pdf">Ho, Jain, and Abbeel (2020)</a> and makes the equations going forward easier to read). This is depicted schematically below:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/diffusion_forward_t_cond_0_korra.png" alt="drawing" width="500" /></center>

<p>Note that because $q(\boldsymbol{x}_t \mid \boldsymbol{x}_0)$ is simply a normal distribution, this enables us to sample noisy images at any arbitrary timestep $t$ without having to run the full diffusion process for $t$ timesteps. That is, instead of having to sample from $t$ normal distributions, which is what would be required to run the forward diffusion process to timestep $t$, we can instead sample from one distribution. As we will show, this will enable us to speed up the training of the model.</p>
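<p>The closed form can be sanity-checked against the step-by-step process. The plain-Python sketch below (scalar per dimension; the helper names are hypothetical) computes $\bar{\alpha}_t$, draws $\boldsymbol{x}_t$ in a single step, and compares the closed-form mean coefficient $\sqrt{\bar{\alpha}_t}$ and variance $1-\bar{\alpha}_t$ with those obtained by pushing the moments through $t$ individual forward steps. The linear schedule from $10^{-4}$ to $0.02$ over $T = 1000$ steps is an illustrative assumption.</p>

```python
import math
import random


def alpha_bars(betas):
    """Return [alpha_bar_1, ..., alpha_bar_T], where alpha_bar_t = prod_{i=1}^t (1 - beta_i)."""
    out, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        out.append(running)
    return out


def sample_xt_given_x0(x0, t, betas, rng):
    """Draw x_t ~ q(x_t | x_0) in one shot (t is 1-indexed), independently per element of x0."""
    a_bar = alpha_bars(betas)[t - 1]
    return [math.sqrt(a_bar) * x + math.sqrt(1.0 - a_bar) * rng.gauss(0.0, 1.0) for x in x0]


def iterated_moments(betas, t):
    """Mean coefficient m and variance v of x_t given x_0, from applying t single steps:
    m_t = sqrt(1 - beta_t) * m_{t-1} and v_t = (1 - beta_t) * v_{t-1} + beta_t, with m_0 = 1, v_0 = 0."""
    m, v = 1.0, 0.0
    for beta in betas[:t]:
        m = math.sqrt(1.0 - beta) * m
        v = (1.0 - beta) * v + beta
    return m, v


betas = [1e-4 + (0.02 - 1e-4) * t / 1000 for t in range(1, 1001)]  # illustrative schedule
x0 = [0.5, -1.0, 2.0]  # a toy "object" with three pixels
xt = sample_xt_given_x0(x0, t=500, betas=betas, rng=random.Random(0))

m, v = iterated_moments(betas, t=500)
a_bar = alpha_bars(betas)[499]
print(m, math.sqrt(a_bar))  # these agree up to rounding
print(v, 1.0 - a_bar)       # and so do these
```

<p>Drawing $\boldsymbol{x}_t$ this way costs one Gaussian draw per element regardless of $t$, whereas simulating the chain costs $t$ draws per element.</p>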

<p>2. <strong>$q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0)$ has a closed form.</strong> Note that we previously discussed how the conditional distribution $q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t)$ is intractable to compute.</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/diffusion_posterior_intractible_korra.png" alt="drawing" width="275" /></center>

<p>However, it turns out that if, instead of conditioning only on $\boldsymbol{x}_t$, we also condition on the original, noiseless object, $\boldsymbol{x}_0$, we <em>can</em> derive a closed form for this posterior distribution. That distribution is a normal distribution (See Derivations 6 and 7 in the Appendix to this post):</p>

\[\begin{align*}q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0) &amp;= N\left(\boldsymbol{x}_{t-1}; \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\boldsymbol{x}_t + \frac{\sqrt{\bar{\alpha}_{t-1}} \beta_t}{1 - \bar{\alpha}_t}\boldsymbol{x}_0, \frac{\beta_t \left(1 - \bar{\alpha}_{t-1}\right)}{1- \bar{\alpha}_t}\boldsymbol{I}\right) &amp;&amp; \text{Derivation 6} \\ &amp;= N\left(\boldsymbol{x}_{t-1}; \frac{1}{\sqrt{\alpha_t}} \left(\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon_t) - \frac{\beta_t}{\sqrt{1- \bar{\alpha}_{t}}}\epsilon_t \right) , \frac{\beta_t \left(1 - \bar{\alpha}_{t-1}\right)}{1- \bar{\alpha}_t}\boldsymbol{I}\right) &amp;&amp; \text{Derivation 7}\end{align*}\]

<p>The second line follows from a reparameterization of $\boldsymbol{x}_t$ by noting that the previously described closed form of $q(\boldsymbol{x}_t \mid \boldsymbol{x}_0)$ implies that $\boldsymbol{x}_t$ can be generated by first sampling $\epsilon_t \sim N(\boldsymbol{0}, \boldsymbol{I})$ and then passing $\boldsymbol{x}_0$ and $\epsilon_t$ into the function,</p>

\[\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon_t) := \sqrt{\bar{\alpha}_t}\boldsymbol{x}_0 + \sqrt{1 - \bar{\alpha}_t}\epsilon_t\]
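<p>As a quick numerical check of this reparameterization, the following plain-Python sketch (scalar case; the function names and the numbers are hypothetical) computes the posterior mean in both the Derivation 6 form and the Derivation 7 form and confirms that they agree when $\boldsymbol{x}_t$ is built from $\boldsymbol{x}_0$ and $\epsilon_t$ as above. The linear schedule is an illustrative assumption.</p>

```python
import math


def posterior_params(x_t, x0, t, betas):
    """Mean (Derivation 6 form) and variance of q(x_{t-1} | x_t, x_0); scalar case, t is 1-indexed."""
    alphas = [1.0 - b for b in betas]
    a_bar_t = math.prod(alphas[:t])
    a_bar_prev = math.prod(alphas[: t - 1])  # alpha_bar_{t-1}; equals 1 when t == 1
    alpha_t, beta_t = alphas[t - 1], betas[t - 1]
    mean = (math.sqrt(alpha_t) * (1.0 - a_bar_prev) * x_t
            + math.sqrt(a_bar_prev) * beta_t * x0) / (1.0 - a_bar_t)
    var = beta_t * (1.0 - a_bar_prev) / (1.0 - a_bar_t)
    return mean, var


def posterior_mean_from_eps(x_t, eps, t, betas):
    """Mean in the Derivation 7 form: (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)."""
    alphas = [1.0 - b for b in betas]
    a_bar_t = math.prod(alphas[:t])
    return (x_t - betas[t - 1] / math.sqrt(1.0 - a_bar_t) * eps) / math.sqrt(alphas[t - 1])


betas = [1e-4 + (0.02 - 1e-4) * t / 1000 for t in range(1, 1001)]  # illustrative schedule
t, x0, eps = 300, 0.7, -1.3
a_bar_t = math.prod(1.0 - b for b in betas[:t])
x_t = math.sqrt(a_bar_t) * x0 + math.sqrt(1.0 - a_bar_t) * eps  # x_t(x_0, eps_t)

mean6, var = posterior_params(x_t, x0, t, betas)
mean7 = posterior_mean_from_eps(x_t, eps, t, betas)
print(mean6, mean7, var)  # the two means agree up to floating-point error
```

<p>Note that the posterior variance depends on neither $\boldsymbol{x}_t$ nor $\boldsymbol{x}_0$, and it is always slightly smaller than $\beta_t$, because $\bar{\alpha}_{t-1} > \bar{\alpha}_t$.</p>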

<p>Sampling $\boldsymbol{x}_{t-1}$ conditioned on $\boldsymbol{x}_t$ and $\boldsymbol{x}_0$ is depicted schematically below:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/diffusion_posterior_tractible_korra.png" alt="drawing" width="700" /></center>

<p><br /></p>

<p>The fact that this posterior has a closed form when conditioning on $\boldsymbol{x}_0$ makes intuitive sense: as we talked about previously, the posterior distribution $q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t)$ requires knowing $q(\boldsymbol{x}_0)$ – that is, in order to turn noise into an object, we need to know what real, noiseless objects look like. However, if we condition on $\boldsymbol{x}_0$, this means we are assuming we <em>know</em> what $\boldsymbol{x}_0$ looks like and the modified posterior, $q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0)$, needs only to take into account subtraction of noise towards this noiseless object.</p>

<h2 id="the-reverse-model">The reverse model</h2>

<p>Let’s now describe the model that we will use to approximate the reverse diffusion steps, $p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t)$. In the most general form, we will define this distribution to be a normal distribution where the mean and variance are defined by two functions of $\boldsymbol{x}_t$ and $t$:</p>

\[p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t) := N(\boldsymbol{x}_{t-1}; \boldsymbol{\mu}_\theta(\boldsymbol{x}_t, t), \boldsymbol{\Sigma}_\theta(\boldsymbol{x}_t, t))\]

<p>where $\boldsymbol{\mu}_\theta(\boldsymbol{x}_t, t)$ and $\boldsymbol{\Sigma}_\theta(\boldsymbol{x}_t, t)$ are two functions that take $\boldsymbol{x}_t$ and $t$ as input and output the mean and covariance, respectively. These functions are parameterized by $\theta$.</p>

<p><a href="https://arxiv.org/pdf/2006.11239.pdf">Ho, Jain, and Abbeel (2020)</a> simplified this model such that the variance is constant at each time step $t$ rather than output by a function (i.e., model). Specifically, they define</p>

\[\boldsymbol{\Sigma}_\theta(\boldsymbol{x}_t, t) := \sigma_t^2 \boldsymbol{I}\]

<p>Thus, the reverse model becomes</p>

\[p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t) := N(\boldsymbol{x}_{t-1}; \boldsymbol{\mu}_\theta(\boldsymbol{x}_t, t), \sigma_t^2\boldsymbol{I})\]

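<p>The authors found that setting $\sigma_t^2 := \beta_t$ worked well in practice (they report similar results with $\sigma_t^2 := \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\beta_t$, the variance of $q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0)$).</p>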
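<p>With the reverse model in hand, sampling is ancestral: draw $\boldsymbol{x}_T$ from $N(\boldsymbol{0}, \boldsymbol{I})$ and then draw each $\boldsymbol{x}_{t-1}$ from $p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t)$ in turn. Below is a scalar, plain-Python sketch under two stated assumptions: $\sigma_t^2 = \beta_t$, and, as in the sampling algorithm of Ho, Jain, and Abbeel (2020), no noise is added at the final step. Because no network is trained here, $\boldsymbol{\mu}_\theta$ is replaced by a hypothetical “oracle” that returns the exact mean of $q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0)$ for a data distribution concentrated on a single value $c$; in practice, this function would be the learned neural network.</p>

```python
import itertools
import math
import operator
import random


def reverse_sample(mu_theta, betas, rng):
    """Ancestral sampling with p_theta(x_{t-1} | x_t) = N(mu_theta(x_t, t), sigma_t^2), sigma_t^2 = beta_t.

    Starts from x_T ~ N(0, 1); no noise is added on the final step (t = 1).
    """
    x = rng.gauss(0.0, 1.0)
    for t in range(len(betas), 0, -1):
        noise = rng.gauss(0.0, 1.0) if t > 1 else 0.0
        x = mu_theta(x, t) + math.sqrt(betas[t - 1]) * noise
    return x


def make_oracle_mean(c, betas):
    """Hypothetical stand-in for a trained network: the exact mean of q(x_{t-1} | x_t, x_0 = c)."""
    alphas = [1.0 - b for b in betas]
    # a_bars[t] = alpha_bar_t for t = 0..T, with alpha_bar_0 = 1
    a_bars = list(itertools.accumulate(alphas, operator.mul, initial=1.0))

    def mu_theta(x_t, t):
        return (math.sqrt(alphas[t - 1]) * (1.0 - a_bars[t - 1]) * x_t
                + math.sqrt(a_bars[t - 1]) * betas[t - 1] * c) / (1.0 - a_bars[t])

    return mu_theta


betas = [1e-4 + (0.02 - 1e-4) * t / 1000 for t in range(1, 1001)]  # illustrative schedule
sample = reverse_sample(make_oracle_mean(c=0.8, betas=betas), betas, random.Random(0))
print(sample)  # lands on c = 0.8 (up to rounding), since the oracle "knows" the data
```

<p>With a real model, the oracle would be replaced by $\boldsymbol{\mu}_\theta(\boldsymbol{x}_t, t)$ and the samples would vary from run to run; the sampling loop itself would stay the same.</p>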

<h2 id="fitting-p_thetaboldsymbolx_0t-to-qboldsymbolx_0t">Fitting $p_\theta(\boldsymbol{x}_{0:T})$ to $q(\boldsymbol{x}_{0:T})$</h2>

<p>To fit $p_\theta(\boldsymbol{x}_{0:T})$ to $q(\boldsymbol{x}_{0:T})$, diffusion models will seek to minimize the KL-divergence from $p_\theta(\boldsymbol{x}_{0:T})$ to $q(\boldsymbol{x}_{0:T})$:</p>

\[\hat{\theta} := \text{arg min}_\theta \ KL( q(\boldsymbol{x}_{0:T}) \ \vert\vert \ p_\theta(\boldsymbol{x}_{0:T}))\]

<p>Let’s attempt to derive a closed form for this objective function. Following Derivation 1 in the Appendix to this post, we can write this KL-divergence as:</p>

\[KL( q(\boldsymbol{x}_{0:T}) \ \vert\vert \ p_\theta(\boldsymbol{x}_{0:T}) ) = E_{\boldsymbol{x}_0 \sim q}\left[ \log q(\boldsymbol{x}_0) \right] - E_{\boldsymbol{x}_{0:T} \sim q}\left[ \log\frac{p_\theta (\boldsymbol{x}_{0:T}) }{q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0) } \right]\]

<p>Notice that the first term, $E_{\boldsymbol{x}_0 \sim q}\left[ \log q(\boldsymbol{x}_0) \right]$, does not depend on our parameters $\theta$, while the second term is a function of $\theta$. Because there is a negative sign in front of this second term, minimizing the KL-divergence is equivalent to maximizing the second term. Thus, we seek:</p>

\[\hat{\theta} = \text{arg max}_\theta \  E_{\boldsymbol{x}_{0:T} \sim q}\left[ \log\frac{p_\theta (\boldsymbol{x}_{0:T}) }{q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0) } \right]\]

<p>We note that this second term is the expectation of an <a href="https://mbernste.github.io/posts/elbo/">evidence lower bound (ELBO)</a> with respect to $\boldsymbol{x}_0 \sim q$. Why is this term called an “evidence lower bound”? As shown in Derivation 2 in the Appendix to this post, we see that the second term shown above is a lower bound for the expected log-likelihood, otherwise known as the “evidence”:</p>

\[\begin{align*} \log p_\theta(\boldsymbol{x}_0) &amp;\geq E_{\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0 \sim q}\left[ \log\frac{p_\theta (\boldsymbol{x}_{0:T}) }{q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0) } \right] \\ \implies E_{\boldsymbol{x}_0 \sim q}\left[ \log p_\theta(\boldsymbol{x}_0) \right] &amp;\geq E_{\boldsymbol{x}_{0:T} \sim q}\left[ \log\frac{p_\theta (\boldsymbol{x}_{0:T}) }{q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0) } \right] \\ &amp;= E_{\boldsymbol{x}_0 \sim q}\left[ \text{ELBO}(\theta) \right] \end{align*}\]

<p>Thus, by maximizing the expectation of the ELBO with respect to $\theta$, we are implicitly maximizing a lower bound of $\log p_\theta(\boldsymbol{x}_0)$, so we can view this procedure as doing approximate maximum likelihood estimation! We will discuss this idea further in <a href="https://mbernste.github.io/posts/diffusion_part2/">Part 2 of this series</a>.</p>
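<p>The bound can be seen concretely in the smallest possible case. The plain-Python sketch below assumes a hypothetical 1-dimensional model with a single diffusion step ($T = 1$): $q(x_1 \mid x_0) = N(\sqrt{1-\beta}\,x_0, \beta)$, $p(x_1) = N(0, 1)$, and $p_\theta(x_0 \mid x_1) = N(a x_1, s^2)$ with parameters $\theta = (a, s^2)$. Because everything is Gaussian, both the ELBO and $\log p_\theta(x_0)$ can be computed exactly, with no sampling.</p>

```python
import math


def log_normal(x, mean, var):
    """Log density of N(mean, var) evaluated at x."""
    return -0.5 * math.log(2.0 * math.pi * var) - 0.5 * (x - mean) ** 2 / var


def elbo_one_step(x0, beta, a, s2):
    """Exact ELBO for a hypothetical 1-D diffusion model with T = 1.

    q(x1 | x0) = N(sqrt(1 - beta) * x0, beta), p(x1) = N(0, 1), p_theta(x0 | x1) = N(a * x1, s2).
    Each expectation over x1 ~ q(x1 | x0) is taken in closed form.
    """
    m, v = math.sqrt(1.0 - beta) * x0, beta  # mean and variance of q(x1 | x0)
    e_log_prior = -0.5 * math.log(2.0 * math.pi) - 0.5 * (m ** 2 + v)
    e_log_decoder = (-0.5 * math.log(2.0 * math.pi * s2)
                     - 0.5 * ((x0 - a * m) ** 2 + a ** 2 * v) / s2)
    e_log_q = -0.5 * math.log(2.0 * math.pi * v) - 0.5  # E[log q(x1 | x0)]
    return e_log_prior + e_log_decoder - e_log_q


def log_evidence(x0, a, s2):
    """log p_theta(x0): integrating out x1 ~ N(0, 1) gives x0 ~ N(0, a^2 + s2)."""
    return log_normal(x0, 0.0, a ** 2 + s2)


for x0 in (-2.0, 0.0, 1.5):
    elbo = elbo_one_step(x0, beta=0.1, a=0.9, s2=0.3)
    print(x0, elbo, log_evidence(x0, a=0.9, s2=0.3))  # the ELBO never exceeds log p_theta(x0)
```

<p>The gap between the two numbers is exactly $KL(q(x_1 \mid x_0) \ \vert\vert \ p_\theta(x_1 \mid x_0))$, so it vanishes when the model’s posterior over $x_1$ matches $q$; in this toy model, $a = \sqrt{1-\beta}$ and $s^2 = \beta$ close it completely.</p>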

<p>Let’s now examine the expected ELBO more closely. It turns out that it can be further manipulated into a form that has a term for each step of the diffusion process (See Derivation 3 in the Appendix to this post):</p>

\[\begin{align*}E_{\boldsymbol{x}_0 \sim q} \left[ \text{ELBO}(\theta) \right] &amp;= E_{\boldsymbol{x}_{0:T} \sim q}\left[ \log \frac{ p_\theta (\boldsymbol{x}_{0:T}) }{q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0)} \right] \\ &amp;= \underbrace{E_{\boldsymbol{x}_1, \boldsymbol{x}_0 \sim q} \left[ \log p_\theta(\boldsymbol{x}_0 \mid \boldsymbol{x}_1) \right]}_{L_0} - \underbrace{\sum_{t=2}^T  E_{\boldsymbol{x}_t, \boldsymbol{x}_0 \sim q} \left[ KL \left( q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0) \ \vert\vert \ p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t) \right) \right]}_{L_1, L_2, \dots, L_{T-1}} - E_{\boldsymbol{x}_0 \sim q} \left[ \underbrace{KL\left( q(\boldsymbol{x}_T \mid \boldsymbol{x}_0) \ \vert\vert \  p_\theta(\boldsymbol{x}_T) \right)}_{L_T} \right]\end{align*}\]

<p>These terms are broken into three categories:</p>

<ol>
  <li>$L_0$ is the expected log-probability that the model assigns to the data $\boldsymbol{x}_0$ conditioned on the output of the very first diffusion step, $\boldsymbol{x}_1$. In the reverse diffusion process, this is the last step required to transform the noise into the original image. This term is called the <strong>reconstruction term</strong> because it is high if the model can successfully predict the original noiseless image $\boldsymbol{x}_0$ from $\boldsymbol{x}_1$, which is the result of the first iteration of the diffusion process.</li>
  <li>$L_1, \dots, L_{T-1}$ are terms that measure how well the model is performing reverse diffusion. That is, it asking how well the posterior probabilities specified by the model, $p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t)$, are pushing the object closer to $\boldsymbol{x}_0$ according to the probabilities $q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0)$. As we will see later in this section, this term can be viewed as a “noise prediction” term where the model must learn how to remove noise from $\boldsymbol{x}_t$ to bring it one step closer to $\boldsymbol{x}_0$.</li>
  <li>$L_T$ simply measures how well the result of the forward diffusion process, $q(\boldsymbol{x}_0)$, which theoretically approaches a standard normal distribution, matches the distribution from which we seed the reverse diffusion process, $p_\theta(\boldsymbol{x}_0)$. This objective is explicitly minimized from the outset because we explicitly define $p_\theta(\boldsymbol{x}_0)$ to be a standard normal distribution.</li>
</ol>
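<p>To get a feel for how small $L_T$ is in practice, here is a minimal sketch that evaluates $KL(q(\boldsymbol{x}_T \mid \boldsymbol{x}_0) \ \vert\vert \ N(\boldsymbol{0}, \boldsymbol{I}))$ for a single scalar pixel using the closed form $q(\boldsymbol{x}_T \mid \boldsymbol{x}_0) = N(\sqrt{\bar{\alpha}_T}\boldsymbol{x}_0, (1-\bar{\alpha}_T)\boldsymbol{I})$. It assumes the linear schedule reported by Ho, Jain, and Abbeel (2020), with $\beta_t$ rising from $10^{-4}$ to $0.02$ over $T = 1000$ steps, and pixel values scaled to $[-1, 1]$; the helper names are hypothetical.</p>

```python
import math

def linear_beta_schedule(T=1000, beta_start=1e-4, beta_end=0.02):
    """Noise variances beta_1, ..., beta_T spaced linearly (assumed schedule)."""
    return [beta_start + (beta_end - beta_start) * i / (T - 1) for i in range(T)]

def cumulative_alpha_bar(betas):
    """alpha_bar_t = product over s = 1..t of (1 - beta_s), for every t."""
    out, running = [], 1.0
    for b in betas:
        running *= 1.0 - b
        out.append(running)
    return out

def kl_to_standard_normal(mean, var):
    """KL( N(mean, var) || N(0, 1) ) for a single scalar dimension."""
    return 0.5 * (var + mean ** 2 - 1.0 - math.log(var))

betas = linear_beta_schedule()
abar_T = cumulative_alpha_bar(betas)[-1]
x0 = 1.0  # a pixel at the edge of the assumed [-1, 1] intensity range
L_T = kl_to_standard_normal(math.sqrt(abar_T) * x0, 1.0 - abar_T)
print(f"alpha_bar_T = {abar_T:.2e}, L_T for this pixel = {L_T:.2e} nats")
```

<p>Because $\bar{\alpha}_T$ is tiny under this schedule, $q(\boldsymbol{x}_T \mid \boldsymbol{x}_0)$ is nearly indistinguishable from $N(\boldsymbol{0}, \boldsymbol{I})$, which is why $L_T$ can safely be left out of the optimization.</p>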

<p>Now we show that by breaking up the expected ELBO into these discrete terms, we can simplify the whole thing into a closed form expression. Let’s start with the last term $L_T$. Recall that we define $p_\theta(\boldsymbol{x}_T)$ to be a standard normal distribution that does not incorporate any model parameters. That is,</p>

\[p_\theta(\boldsymbol{x}_T) := N(\boldsymbol{x}_T; \boldsymbol{0}, \boldsymbol{I})\]

<p>Because it does not incorporate any parameters, we can ignore this term when maximizing the expected ELBO with respect to $\theta$. Thus, our task will be to find:</p>

\[\hat{\theta} := \text{arg max}_\theta \ \underbrace{E_{\boldsymbol{x}_1, \boldsymbol{x}_0 \sim q} \left[ \log p_\theta(\boldsymbol{x}_0 \mid \boldsymbol{x}_1) \right]}_{L_0} -  \underbrace{\sum_{t=2}^T E_{\boldsymbol{x}_t, \boldsymbol{x}_0 \sim q} \left[ KL \left( q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0) \ \vert\vert \ p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t) \right) \right]}_{L_1, L_2, \dots, L_{T-1}}\]

<p>Now, let’s turn to the middle terms $L_1, \dots, L_{T-1}$. Here we see that these terms require calculating KL-divergences from  $p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t)$ to $q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0)$. Recall from the previous sections that both of these distributions are normal distributions. That is,</p>

\[\begin{align*}L_t &amp;:= E_{\boldsymbol{x}_t, \boldsymbol{x}_0 \sim q} \left[ KL\left(q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0) \ \vert\vert \ p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t)\right) \right] \\ &amp;= E_{\epsilon_t, \boldsymbol{x}_0} \left[ KL\left(N\left(\boldsymbol{x}_{t-1}; \frac{1}{\sqrt{\alpha_t}} \left(\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon_t) - \frac{\beta_t}{\sqrt{1- \bar{\alpha}_{t}}}\epsilon_t \right) , \frac{\beta_t \left(1 - \bar{\alpha}_{t-1}\right)}{1- \bar{\alpha}_t}\boldsymbol{I}\right) \ \vert\vert \ N(\boldsymbol{x}_{t-1}; \boldsymbol{\mu}_\theta(\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon_t), t), \sigma_t^2\boldsymbol{I})\right) \right] \\ &amp;= E_{\epsilon_t, \boldsymbol{x}_0} \left[KL\left(N(\boldsymbol{x}_{t-1}; \tilde{\boldsymbol{\mu}}(\boldsymbol{x}_{0}, \epsilon_t), \tilde{\sigma}_t^2\boldsymbol{I}) \ \vert\vert \ N(\boldsymbol{x}_{t-1}; \boldsymbol{\mu}_\theta(\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon_t), t), \sigma_t^2\boldsymbol{I})  \right) \right]\end{align*}\]

<p>where for ease of notation we define,</p>

\[\begin{align*}\tilde{\boldsymbol{\mu}}(\boldsymbol{x}_{0}, \epsilon_t) &amp;:= \frac{1}{\sqrt{\alpha_t}} \left(\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon_t) - \frac{\beta_t}{\sqrt{1- \bar{\alpha}_{t}}}\epsilon_t \right) \\ \tilde{\sigma}_t^2 &amp;:= \frac{\beta_t \left(1 - \bar{\alpha}_{t-1}\right)}{1- \bar{\alpha}_t}\end{align*}\]
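<p>As a sanity check on $\tilde{\boldsymbol{\mu}}$, the sketch below compares it numerically with an equivalent form that writes the posterior mean as a weighted combination of $\boldsymbol{x}_0$ and $\boldsymbol{x}_t$, namely $\frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t}\boldsymbol{x}_0 + \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\boldsymbol{x}_t$ (the form given by Ho, Jain, and Abbeel). It works on scalars, assumes an arbitrary linear schedule with $T = 100$, and uses 1-based $t$ with the convention $\bar{\alpha}_0 = 1$; all function names are hypothetical.</p>

```python
import math
import random

# Assumed linear schedule with T = 100 steps, indexed t = 1..T.
T = 100
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]
alphas = [1.0 - b for b in betas]
alpha_bars, running = [], 1.0
for a in alphas:
    running *= a
    alpha_bars.append(running)

def abar(t):
    """alpha_bar_t for 1-based t, with the convention alpha_bar_0 = 1."""
    return 1.0 if t == 0 else alpha_bars[t - 1]

def mu_tilde_eps_form(x_t, eps, t):
    """Posterior mean written with x_t and the noise eps."""
    return (x_t - betas[t - 1] / math.sqrt(1.0 - abar(t)) * eps) / math.sqrt(alphas[t - 1])

def mu_tilde_x0_form(x0, x_t, t):
    """The same posterior mean written as a weighted combination of x_0 and x_t."""
    c0 = math.sqrt(abar(t - 1)) * betas[t - 1] / (1.0 - abar(t))
    ct = math.sqrt(alphas[t - 1]) * (1.0 - abar(t - 1)) / (1.0 - abar(t))
    return c0 * x0 + ct * x_t

def sigma_tilde_sq(t):
    """Posterior variance beta_t (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)."""
    return betas[t - 1] * (1.0 - abar(t - 1)) / (1.0 - abar(t))

rng = random.Random(0)
max_gap = 0.0
for t in range(2, T + 1):
    x0, eps = rng.uniform(-1.0, 1.0), rng.gauss(0.0, 1.0)
    x_t = math.sqrt(abar(t)) * x0 + math.sqrt(1.0 - abar(t)) * eps
    gap = abs(mu_tilde_eps_form(x_t, eps, t) - mu_tilde_x0_form(x0, x_t, t))
    max_gap = max(max_gap, gap)
print(f"largest disagreement between the two forms: {max_gap:.1e}")
```

<p>The two forms agree up to floating-point rounding, and $\tilde{\sigma}_t^2$ is always a little smaller than $\beta_t$ because $\bar{\alpha}_{t-1} > \bar{\alpha}_t$.</p>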

<p>Note that $L_t$ is now written as an expectation over $\epsilon_t \sim N(\boldsymbol{0}, \boldsymbol{I})$ (and $\boldsymbol{x}_0$) rather than over $\boldsymbol{x}_t$. This reparameterization follows from the previously described closed form of $q(\boldsymbol{x}_t \mid \boldsymbol{x}_0)$, which implies that $\boldsymbol{x}_t$ can be sampled by first sampling $\epsilon_t \sim N(\boldsymbol{0}, \boldsymbol{I})$ and then passing $\boldsymbol{x}_0$ and $\epsilon_t$ into the function,</p>

\[\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon_t) := \sqrt{\bar{\alpha}_t}\boldsymbol{x}_0 + \sqrt{1 - \bar{\alpha}_t}\epsilon_t\]

<p>Said differently, we now assume that the object $\boldsymbol{x}_t$ is generated via the function $\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon_t)$ and the expectation described by $L_t$ is now an expectation of a random value whose stochasticity comes from the stochasticity of $\epsilon_t$.</p>
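<p>In code, this reparameterization is essentially a one-liner. The minimal sketch below works on a flat list of pixel values; the function name <code>q_sample</code>, the plain-list representation, and the value $\bar{\alpha}_t = 0.3$ are assumptions for illustration (a real implementation would operate on tensors).</p>

```python
import math
import random

def q_sample(x0, eps, alpha_bar_t):
    """x_t(x_0, eps) = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, elementwise."""
    a, b = math.sqrt(alpha_bar_t), math.sqrt(1.0 - alpha_bar_t)
    return [a * xi + b * ei for xi, ei in zip(x0, eps)]

rng = random.Random(0)
x0 = [0.5, -1.0, 0.25]                     # a toy 3-pixel "image"
eps = [rng.gauss(0.0, 1.0) for _ in x0]    # eps_t ~ N(0, I)
x_t = q_sample(x0, eps, alpha_bar_t=0.3)   # alpha_bar_t = 0.3 is arbitrary
print(x_t)
```

<p>Because $\boldsymbol{x}_0$ and $\epsilon_t$ fully determine $\boldsymbol{x}_t$, training never needs to simulate the forward chain step by step; it can jump straight to any timestep.</p>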

<p>We now use the following fact: Given two multivariate normal distributions</p>

\[\begin{align*}P_1 &amp;:= N(\boldsymbol{\mu}_1, \boldsymbol{\Sigma}_1) \\ P_2 &amp;:= N(\boldsymbol{\mu}_2, \boldsymbol{\Sigma}_2) \end{align*}\]

<p>it follows that</p>

\[KL(P_1 \ \vert\vert \ P_2) = \frac{1}{2}\left( \left(\boldsymbol{\mu}_2 - \boldsymbol{\mu}_1\right)^T \boldsymbol{\Sigma}_2^{-1} \left(\boldsymbol{\mu}_2 - \boldsymbol{\mu}_1\right) + \text{Trace}\left(\boldsymbol{\Sigma}_2^{-1} \boldsymbol{\Sigma}_1\right) + \log \frac{ \text{Det} \left(\boldsymbol{\Sigma}_2\right) }{\text{Det}\left(\boldsymbol{\Sigma}_1\right)} - d\right)\]

<p>where $d$ is the dimensionality of each multivariate Gaussian. We won’t prove this fact here (see <a href="https://statproofbook.github.io/P/mvn-kl.html">this link</a> for a formal proof).</p>
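<p>For the diagonal (here, isotropic) covariances used in this post, the trace and determinant terms reduce to per-dimension sums. The sketch below specializes the formula to diagonal covariances, representing each covariance matrix by the list of its diagonal entries; the function name is hypothetical.</p>

```python
import math

def kl_diag_gaussians(mu1, var1, mu2, var2):
    """KL(N(mu1, diag(var1)) || N(mu2, diag(var2))), summed over dimensions.

    Each dimension contributes its share of the quadratic term, the trace term,
    and the log-determinant ratio, minus 1 (these -1's add up to -d); the total is halved.
    """
    total = 0.0
    for m1, v1, m2, v2 in zip(mu1, var1, mu2, var2):
        total += (m2 - m1) ** 2 / v2 + v1 / v2 + math.log(v2 / v1) - 1.0
    return 0.5 * total

# Identical distributions have zero divergence.
print(kl_diag_gaussians([0.0, 1.0], [1.0, 2.0], [0.0, 1.0], [1.0, 2.0]))  # 0.0
# Shifting the mean of a 1-D unit Gaussian by 1 costs exactly 1/2 nat.
print(kl_diag_gaussians([0.0], [1.0], [1.0], [1.0]))  # 0.5
```

<p>With $\boldsymbol{\Sigma}_1 = \tilde{\sigma}_t^2\boldsymbol{I}$ and $\boldsymbol{\Sigma}_2 = \sigma_t^2\boldsymbol{I}$, only the quadratic term depends on the model's mean $\boldsymbol{\mu}_\theta$.</p>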

<p>Applying this fact to $L_t$, we see that,</p>

\[\begin{align*}L_t &amp;:= E_{\epsilon_t, \boldsymbol{x}_0} \left[KL\left(N(\boldsymbol{x}_{t-1}; \tilde{\boldsymbol{\mu}}(\boldsymbol{x}_{0}, \epsilon_t), \tilde{\sigma}_t^2\boldsymbol{I}) \ \vert\vert \ N(\boldsymbol{x}_{t-1}; \boldsymbol{\mu}_\theta(\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon_t), t), \sigma_t^2\boldsymbol{I})  \right) \right] \\ &amp;= E_{\epsilon_t, \boldsymbol{x}_0} \left[ \frac{1}{2}\left( \left(\boldsymbol{\mu}_\theta(\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon_t), t) - \tilde{\boldsymbol{\mu}}(\boldsymbol{x}_0, \epsilon_t)\right)^T \left( \sigma_t^2 \boldsymbol{I} \right)^{-1} \left(\boldsymbol{\mu}_\theta(\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon_t), t) - \tilde{\boldsymbol{\mu}}(\boldsymbol{x}_0, \epsilon_t) \right) + \text{Trace}\left( \left( \sigma^2_t\boldsymbol{I} \right)^{-1} \left(\tilde{\sigma}_t^2 \boldsymbol{I} \right) \right) + \log \frac{ \text{Det} \left( \sigma^2_t\boldsymbol{I} \right) }{\text{Det}\left(\tilde{\sigma}_t^2 \boldsymbol{I}\right)} - d \right) \right] \\ &amp;= E_{\epsilon_t, \boldsymbol{x}_0} \left[ \frac{1}{2\sigma_t^2} \vert\vert\boldsymbol{\mu}_\theta(\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon_t), t) - \tilde{\boldsymbol{\mu}}(\boldsymbol{x}_0, \epsilon_t)\vert\vert^2 \right] +  \underbrace{\frac{1}{2}\left(\text{Trace}\left( \left( \sigma^2_t\boldsymbol{I} \right)^{-1} \left(\tilde{\sigma}_t^2 \boldsymbol{I} \right) \right) + \log \frac{ \text{Det} \left( \sigma^2_t\boldsymbol{I} \right) }{\text{Det}\left(\tilde{\sigma}_t^2 \boldsymbol{I}\right)} - d\right)}_{\text{Constant with respect to } \ \theta} \end{align*}\]

<p>Because the last group of terms is constant with respect to $\theta$, we don’t need to consider it when optimizing our objective function, so we can drop it from each $L_t$. Each $L_t$ can thus be redefined as,</p>

\[L_t := E_{\epsilon_t, \boldsymbol{x}_0} \left[ \frac{1}{2\sigma_t^2} \vert\vert\boldsymbol{\mu}_\theta(\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon_t), t) - \tilde{\boldsymbol{\boldsymbol{\mu}}}(\boldsymbol{x}_0, \epsilon_t)\vert\vert^2 \right]\]

<p>Here we see that to optimize the objective function, we must minimize the squared error, weighted by $1/(2\sigma_t^2)$, between the mean of $q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0)$, given by $\tilde{\boldsymbol{\mu}}(\boldsymbol{x}_0, \epsilon_t)$, and the mean of $p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t)$, given by $\boldsymbol{\mu}_\theta(\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon_t), t)$. This mean squared error is taken over the Gaussian noise $\epsilon_t$ and noiseless items $\boldsymbol{x}_0$.</p>

<p>While this equation could be used to train the model, <a href="https://arxiv.org/pdf/2006.11239.pdf">Ho, Jain, and Abbeel (2020)</a> found that a further modification to the objective function improved the stability of training and performance of the model. Specifically, the authors proposed reparameterizing the function $\boldsymbol{\mu}_\theta(\boldsymbol{x}_t, t)$ (i.e., the trainable model) as follows:</p>

\[\boldsymbol{\mu}_\theta(\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon_t), t) := \frac{1}{\sqrt{\alpha_t}}\left( \boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon_t) - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_\theta(\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon_t), t) \right)\]

<p>where $\epsilon_\theta(\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon_t), t)$ is now the trainable function parameterized by $\theta$ that takes as input the object $\boldsymbol{x}_t$ (which is provided by $\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon_t)$) and the timestep $t$. By performing this reparameterization, $L_t$ simplifies to</p>

\[L_t = E_{\epsilon_t, \boldsymbol{x}_0} \left[ \frac{\beta_t^2}{2\sigma_t^2 \alpha_t (1-\bar{\alpha}_t)} \vert\vert \epsilon_t - \epsilon_\theta(\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon_t), t) \vert\vert^2 \right]\]

<p>See Derivation 8 in the Appendix to this post.</p>
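<p>The step from the mean-matching form of $L_t$ to the noise-matching form is easy to check numerically for a single scalar. The sketch below picks arbitrary illustrative values for $\alpha_t$, $\bar{\alpha}_t$, $\sigma_t^2$, the true noise $\epsilon_t$, and a hypothetical model prediction $\epsilon_\theta$, builds both means with the same formula, and compares the two squared-error expressions.</p>

```python
import math

def mean_from_eps(x_t, eps, alpha_t, alpha_bar_t):
    """(1 / sqrt(alpha_t)) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps)."""
    beta_t = 1.0 - alpha_t
    return (x_t - beta_t / math.sqrt(1.0 - alpha_bar_t) * eps) / math.sqrt(alpha_t)

# Arbitrary illustrative values (alpha_bar_t must not exceed alpha_t).
alpha_t, alpha_bar_t, sigma2_t = 0.98, 0.6, 0.02
x0, eps_true, eps_pred = 0.7, -1.3, -0.9
x_t = math.sqrt(alpha_bar_t) * x0 + math.sqrt(1.0 - alpha_bar_t) * eps_true

mu_tilde = mean_from_eps(x_t, eps_true, alpha_t, alpha_bar_t)  # posterior mean
mu_theta = mean_from_eps(x_t, eps_pred, alpha_t, alpha_bar_t)  # model mean

mean_form = (mu_theta - mu_tilde) ** 2 / (2.0 * sigma2_t)
beta_t = 1.0 - alpha_t
weight = beta_t ** 2 / (2.0 * sigma2_t * alpha_t * (1.0 - alpha_bar_t))
noise_form = weight * (eps_true - eps_pred) ** 2
print(mean_form, noise_form)  # equal up to floating-point rounding
```

<p>The $\boldsymbol{x}_t$ terms cancel in $\boldsymbol{\mu}_\theta - \tilde{\boldsymbol{\mu}}$, leaving only the scaled difference between the true and predicted noise.</p>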

<p>Here we see that to minimize $L_t$, our model, $\epsilon_\theta$, must accurately predict the Gaussian noise $\epsilon_t$ that was combined with $\boldsymbol{x}_0$ to produce $\boldsymbol{x}_t$! In other words, the model is a <em>noise predictor</em>. Given a noisy object $\boldsymbol{x}_t$, the goal is to predict which part of $\boldsymbol{x}_t$ is noise and which part is the underlying signal. In the next section, we will discuss how the model $\epsilon_\theta$ is used to <em>remove noise</em> from each $\boldsymbol{x}_t$ in an iterative fashion to generate a new object $\boldsymbol{x}_0$. However, before we get there, let’s finish our discussion of the objective function.</p>

<p>If you are still following along, we are now nearing the final form of the denoising objective function proposed by <a href="https://arxiv.org/pdf/2006.11239.pdf">Ho, Jain, and Abbeel (2020)</a>. The authors simplified $L_t$ by dropping the time-dependent weighting coefficient in front of the squared error term, and they report that this simpler, unweighted objective actually improved sample quality:</p>

\[L_t := E_{\epsilon_t, \boldsymbol{x}_0} \left[ \vert\vert \epsilon_t - \epsilon_\theta(\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon_t), t) \vert\vert^2 \right]\]

<p>With this $L_t$ in hand, the full objective function becomes:</p>

\[\hat{\theta} := \text{arg max}_\theta \ \underbrace{E_{\boldsymbol{x}_1, \boldsymbol{x}_0 \sim q} \left[ \log p_\theta(\boldsymbol{x}_0 \mid \boldsymbol{x}_1) \right] }_{L_0} -  \underbrace{\sum_{t=2}^T E_{\epsilon_t, \boldsymbol{x}_0} \left[ \vert\vert \epsilon_t - \epsilon_\theta(\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon_t), t) \vert\vert^2\right]}_{L_1, L_2, \dots, L_{T-1}}\]

<p>Finally, let’s turn our attention to the first term $L_0$. <a href="https://arxiv.org/pdf/2006.11239.pdf">Ho, Jain, and Abbeel (2020)</a> propose a discrete decoder for $p_\theta(\boldsymbol{x}_0 \mid \boldsymbol{x}_1)$, but in their simplified objective they approximate this term with the same squared-error form as the other terms, so it is effectively folded into the noise-prediction loss at $t = 1$. The final objective function is thus simply,</p>

\[\hat{\theta} := \text{arg min}_\theta \ \sum_{t=1}^T E_{\epsilon_t, \boldsymbol{x}_0} \left[ \vert\vert \epsilon_t - \epsilon_\theta(\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon_t), t) \vert\vert^2 \right]\]

<p>At the end of all of this, we come to a framework in which we simply train a model $\epsilon_\theta$ that will seek to predict the added noise at each timestep $t$! Hence the term “denoising” in the name “Denoising diffusion probabilistic models”.</p>

<h2 id="the-training-algorithm">The training algorithm</h2>

<p>Note that the objective function we derived in the previous section is a sum of per-timestep terms that all share the same parameters $\theta$. Rather than evaluating every term at each iteration, the final training algorithm proposed by <a href="https://arxiv.org/pdf/2006.11239.pdf">Ho, Jain, and Abbeel (2020)</a> samples timesteps at random and updates $\theta$ with a single step of <a href="https://en.wikipedia.org/wiki/Gradient_descent">gradient descent</a> on the corresponding squared-error term.</p>

<p>More specifically, the full training algorithm is as follows: Until training converges (i.e., the objective function no longer improves), we repeat the following steps:</p>

<p>1. Sample a timestep, $t'$, uniformly at random:</p>

\[t' \sim \text{Uniform}(1, \dots, T)\]

<p>2. Sample Gaussian noise, $\epsilon'_t$:</p>

\[\epsilon'_t \sim N(\boldsymbol{0}, \boldsymbol{I})\]

<p>3. Sample an item:</p>

\[\boldsymbol{x}'_0 \sim q(\boldsymbol{x}_0)\]

<p>In practice, this would entail sampling an item randomly from our training set.</p>

<p>4. Compute the gradient,</p>

\[\nabla_\theta \left[ \vert\vert \epsilon'_t - \epsilon_\theta(\boldsymbol{x}_{t'}(\boldsymbol{x}'_0, \epsilon'_t), t') \vert\vert^2 \right]\]

<p>Note that because we randomly sampled an item, $\boldsymbol{x}'_0$, and Gaussian noise, $\epsilon'_t$, we are performing stochastic gradient descent: in expectation over these samples, the above gradient equals the gradient of the $t'$ term of the objective function we derived in the previous section (and, averaging over $t'$ as well, it is proportional to the gradient of the full objective):</p>

\[\nabla_\theta E_{\epsilon_{t'}, \boldsymbol{x}_0} \left[ \vert\vert \epsilon_{t'} - \epsilon_\theta(\boldsymbol{x}_{t'}(\boldsymbol{x}_0, \epsilon_{t'}), {t'}) \vert\vert^2 \right]\]

<p>5. Update the parameters according to the gradient. This can be done by taking a “vanilla” gradient step or by using a more advanced variant of gradient descent such as <a href="https://en.wikipedia.org/wiki/Stochastic_gradient_descent#Adam">Adam</a>.</p>
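<p>To make the five steps concrete without a neural network, here is a minimal sketch that trains a deliberately tiny, hypothetical noise model on one-dimensional data: for each timestep it learns a single weight $w_t$ and predicts $\epsilon_\theta(x_t, t) = w_t x_t$. The data set, the short noise schedule, the learning rate, and the iteration count are all assumptions chosen for illustration, and the gradient is written out by hand instead of using PyTorch autograd. For this linear model the best achievable weight has a closed form, $w_t^* = \sqrt{1-\bar{\alpha}_t} / (\bar{\alpha}_t m_2 + 1 - \bar{\alpha}_t)$, where $m_2$ is the mean of $x_0^2$ over the training set, so we can check that stochastic gradient descent finds it.</p>

```python
import math
import random

rng = random.Random(0)

# Hypothetical toy setup: scalar "images", a short schedule, one weight per timestep.
T = 10
betas = [0.01 + (0.3 - 0.01) * i / (T - 1) for i in range(T)]
alpha_bars, running = [], 1.0
for b in betas:
    running *= 1.0 - b
    alpha_bars.append(running)

train_set = [rng.gauss(0.5, 1.0) for _ in range(2000)]  # stand-in training set
w = [0.0] * T    # theta: eps_theta(x_t, t) = w[t - 1] * x_t
lr = 0.001

for _ in range(60000):
    t = rng.randint(1, T)                    # 1. t' ~ Uniform(1, ..., T)
    eps = rng.gauss(0.0, 1.0)                # 2. eps' ~ N(0, 1)
    x0 = rng.choice(train_set)               # 3. x'_0 drawn from the training set
    abar = alpha_bars[t - 1]
    x_t = math.sqrt(abar) * x0 + math.sqrt(1.0 - abar) * eps
    residual = eps - w[t - 1] * x_t
    grad = -2.0 * residual * x_t             # 4. d/dw of (eps - w * x_t)^2
    w[t - 1] -= lr * grad                    # 5. vanilla gradient-descent step

# Closed-form minimizer of E[(eps - w * x_t)^2] for comparison.
m2 = sum(x * x for x in train_set) / len(train_set)
w_star = [math.sqrt(1.0 - ab) / (ab * m2 + 1.0 - ab) for ab in alpha_bars]
for t in (1, T // 2, T):
    print(f"t={t}: learned w={w[t - 1]:.3f}, optimal w={w_star[t - 1]:.3f}")
```

<p>Swapping the per-timestep weight for a U-Net that takes $t$ as an input, and the hand-written gradient for autograd and an optimizer such as Adam, gives the shape of a real training loop.</p>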

<h2 id="the-sampling-algorithm">The sampling algorithm</h2>

<p>Once we’ve trained the noise-prediction model, $\epsilon_\theta(\boldsymbol{x}_t, t)$, we can use it to execute the reverse diffusion process that will generate a new sample from noise. We start by sampling pure noise $\boldsymbol{x}_T$,</p>

\[\boldsymbol{x}_T \sim N(\boldsymbol{0}, \boldsymbol{I})\]

<p>Then, for steps $t = T, \dots, 2$, we iteratively sample each $\boldsymbol{x}_{t-1}$ from $p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t)$. To do so, recall that $p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t)$ is a normal distribution whose mean is given by:</p>

\[\boldsymbol{\mu}_\theta(\boldsymbol{x}_t, t) := \frac{1}{\sqrt{\alpha_t}}\left( \boldsymbol{x}_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_\theta(\boldsymbol{x}_t, t) \right)\]

<p>This is the same mean as before, except that at generation time $\boldsymbol{x}_t$ is simply the current sample rather than the output of $\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon_t)$, since there is no $\boldsymbol{x}_0$. The variance of this normal distribution is defined to be the constant $\sigma_t^2\boldsymbol{I}$ (where in practice, $\sigma_t^2$ is set to $\beta_t$). Thus, to sample from this distribution, we can perform the following steps:</p>

<p>1. Sample $\boldsymbol{z}$ from a standard normal distribution:</p>

\[\boldsymbol{z} \sim N(\boldsymbol{0}, \boldsymbol{I})\]

<p>2. Transform $\boldsymbol{z}$ into $\boldsymbol{x}_{t-1}$:</p>

\[\boldsymbol{x}_{t-1} := \frac{1}{\sqrt{\alpha_t}}\left( \boldsymbol{x}_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_\theta(\boldsymbol{x}_t, t) \right) + \sigma_t\boldsymbol{z}\]

<p>The final step, to sample $\boldsymbol{x}_0$ from $\boldsymbol{x}_1$, entails removing the predicted noise from $\boldsymbol{x}_1$ without adding any additional noise. That is,</p>

\[\boldsymbol{x}_0 := \frac{1}{\sqrt{\alpha_1}}\left( \boldsymbol{x}_1 - \frac{\beta_1}{\sqrt{1-\bar{\alpha}_1}}\epsilon_\theta(\boldsymbol{x}_1, 1) \right)\]

<p>In essence, in this final step, we simply output the mean of $p_\theta(\boldsymbol{x}_0 \mid \boldsymbol{x}_1)$, which is also the mode (since it is a Gaussian distribution).</p>
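<p>Here is the sampling loop as a minimal sketch on one-dimensional data. To stay self-contained it does not use a trained network; instead it assumes the data distribution is $N(0, 1)$, for which the ideal noise predictor is known in closed form, $\epsilon_\theta(x_t, t) = \sqrt{1-\bar{\alpha}_t}\, x_t$, and it uses $\sigma_t^2 = \beta_t$. The schedule and function names are hypothetical. With this ideal predictor each noisy reverse step maps $N(0, 1)$ back to $N(0, 1)$, and the final noise-free step leaves a variance of $1 - \beta_1$, which gives a simple check.</p>

```python
import math
import random

# Assumed short linear schedule, indexed t = 1..T.
T = 50
betas = [1e-3 + (0.2 - 1e-3) * i / (T - 1) for i in range(T)]
alphas = [1.0 - b for b in betas]
alpha_bars, running = [], 1.0
for a in alphas:
    running *= a
    alpha_bars.append(running)

def ideal_eps_model(x_t, t):
    """Hypothetical stand-in for a trained eps_theta; exact when the data are N(0, 1)."""
    return math.sqrt(1.0 - alpha_bars[t - 1]) * x_t

def sample(eps_model, rng):
    """Reverse diffusion with sigma_t^2 = beta_t; the last step outputs the mean."""
    x = rng.gauss(0.0, 1.0)                       # x_T ~ N(0, 1)
    for t in range(T, 0, -1):                     # t = T, ..., 1
        alpha, beta, abar = alphas[t - 1], betas[t - 1], alpha_bars[t - 1]
        mean = (x - beta / math.sqrt(1.0 - abar) * eps_model(x, t)) / math.sqrt(alpha)
        if t > 1:
            x = mean + math.sqrt(beta) * rng.gauss(0.0, 1.0)  # add sigma_t * z
        else:
            x = mean                              # final step: no added noise
    return x

rng = random.Random(0)
samples = [sample(ideal_eps_model, rng) for _ in range(10000)]
sample_mean = sum(samples) / len(samples)
sample_var = sum((s - sample_mean) ** 2 for s in samples) / len(samples)
print(f"mean {sample_mean:.3f}, variance {sample_var:.3f}, predicted {1.0 - betas[0]:.3f}")
```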

<h2 id="applying-a-diffusion-model-on-mnist">Applying a diffusion model on MNIST</h2>

<p>In this section, we will walk through a relatively simple implementation of a diffusion model in <a href="https://pytorch.org/">PyTorch</a> and apply it to the <a href="https://en.wikipedia.org/wiki/MNIST_database">MNIST dataset</a>  of hand-written digits. I used the following GitHub repositories as guides:</p>

<ul>
  <li><a href="https://github.com/hojonathanho/diffusion">https://github.com/hojonathanho/diffusion</a></li>
  <li><a href="https://github.com/cloneofsimo/minDiffusion">https://github.com/cloneofsimo/minDiffusion</a></li>
  <li><a href="https://github.com/bot66/MNISTDiffusion/tree/main">https://github.com/bot66/MNISTDiffusion/tree/main</a></li>
  <li><a href="https://github.com/usuyama/pytorch-unet">https://github.com/usuyama/pytorch-unet</a></li>
</ul>

<p>My goal was to implement a small model (both small in complexity and size) that would generate realistic digits. In the following sections, I will detail each component and show some of the model’s outputs. All code implementing the model can be found on <a href="https://colab.research.google.com/drive/14ue6jpN7yEM9c11qERpXra8G88ss__99?usp=sharing">Google Colab</a>.</p>

<p><strong>Using a U-Net with ResNet blocks to predict the noise</strong></p>

<p>For the noise model, I used a <a href="https://en.wikipedia.org/wiki/U-Net">U-Net</a> with <a href="https://en.wikipedia.org/wiki/Residual_neural_network">ResNet</a>-like <a href="https://en.wikipedia.org/wiki/Convolutional_neural_network">convolutional</a> blocks – that is, convolutional layers with skip connections between them. This architecture is depicted below:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/diffusion_unet_for_MNIST.png" alt="drawing" width="800" /></center>

<p><br /></p>

<p>Code for my U-Net implementation is found in the Appendix to this blog post as well as on <a href="https://colab.research.google.com/drive/14ue6jpN7yEM9c11qERpXra8G88ss__99?usp=sharing">Google Colab</a>.</p>

<p><strong>Representing the timestep using a time-embedding</strong></p>

<p>As we discussed, the noise model conditions on the timestep, $t$. Thus, we need a way for the neural network to take as input, and utilize, the timestep. To do this, <a href="https://arxiv.org/pdf/2006.11239.pdf">Ho, Jain, and Abbeel (2020)</a> borrowed an idea from the <a href="https://en.wikipedia.org/wiki/Transformer_(deep_learning_architecture)">transformer model</a> originally conceived by <a href="https://arxiv.org/pdf/1706.03762.pdf">Vaswani <em>et al.</em> (2017)</a>. Specifically, each timestep is mapped to a specific sinusoidal <em>embedding</em> vector, and this vector is added element-wise to certain layers of the neural network. The code for generating these embeddings is presented in the Appendix to this post. A heatmap depicting these embeddings is shown below:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/diffusion_time_embedding_example.png" alt="drawing" width="600" /></center>

<p><br /></p>
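<p>As a rough sketch of the idea (not the Appendix code), the function below builds a sinusoidal embedding for an integer timestep: dimension pair $i$ uses the frequency $1/10000^{2i/d}$, following Vaswani <em>et al.</em> (2017), with all the sines in the first half of the vector and all the cosines in the second half. That half-and-half layout is one common convention (the original transformer interleaves sines and cosines), and the embedding size <code>dim</code> and the function name are assumptions.</p>

```python
import math

def timestep_embedding(t, dim=16, max_period=10000.0):
    """Sinusoidal embedding of an integer timestep t as a list of length dim.

    Pair i (for i = 0, ..., dim/2 - 1) uses frequency max_period ** (-2i / dim);
    the sines fill the first half of the vector and the cosines the second half.
    """
    if dim % 2 != 0:
        raise ValueError("dim must be even")
    half = dim // 2
    freqs = [max_period ** (-2.0 * i / dim) for i in range(half)]
    return [math.sin(t * f) for f in freqs] + [math.cos(t * f) for f in freqs]

# One row per timestep, like the rows of a heatmap of embeddings.
rows = [timestep_embedding(t) for t in range(100)]
```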

<p>Recall that at every iteration of the training loop, we sample some objects in the training set (a minibatch) and sample a timestep for each object. Below, we depict a single timestep embedding for a given timestep $t$. The U-Net implementation takes this time embedding, passes it through a feed-forward neural network, re-shapes the vector into a tensor, and then adds it to the input of the up-sampling blocks. This process is depicted below:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/diffusion_unet_for_MNIST_w_timeembedding.png" alt="drawing" width="800" /></center>

<p><br /></p>
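<p>One way to implement “pass through a feed-forward layer, reshape, and add” is sketched below with plain lists: a single linear layer (standing in for the feed-forward network) maps the embedding to one value per channel, and that value is broadcast over every spatial position of a $C \times H \times W$ feature map. This per-channel broadcast is an assumption; the implementation described above may instead reshape the output to the full feature-map shape. The weights, sizes, and function names are hypothetical.</p>

```python
def linear(vec, weight, bias):
    """Dense layer: returns weight @ vec + bias, with weight given as a list of rows."""
    return [sum(w * v for w, v in zip(row, vec)) + b for row, b in zip(weight, bias)]

def add_time_embedding(feature_map, emb, weight, bias):
    """Add a projected time embedding to a C x H x W feature map (nested lists).

    The projection yields one scalar per channel, which is added to every
    pixel of that channel (broadcasting over H and W).
    """
    per_channel = linear(emb, weight, bias)
    if len(per_channel) != len(feature_map):
        raise ValueError("projection must output one value per channel")
    return [
        [[pixel + shift for pixel in row] for row in channel]
        for channel, shift in zip(feature_map, per_channel)
    ]

# Toy example: C = 2 channels of 2x2 zeros and a 3-dimensional embedding.
feature_map = [[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]]
emb = [1.0, 0.5, -1.0]
weight = [[1.0, 0.0, 0.0], [0.0, 2.0, 1.0]]
bias = [0.0, 0.5]
out = add_time_embedding(feature_map, emb, weight, bias)
# channel 0 is shifted by 1.0; channel 1 by 2*0.5 + 1*(-1.0) + 0.5 = 0.5
```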

<p><strong>Example outputs from the model</strong></p>

<p>Once we’ve trained the model and implemented the sampling algorithm, we can generate new MNIST digits! (See the Appendix for the code used to generate new images.) Below is an example of the model generating a “3”. As we examine the image across timesteps of the reverse diffusion process, we see it being successfully transformed from noise into a clear image!</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/diffusion_example_MNIST_reverse_diffusion_5.png" alt="drawing" width="650" /></center>

<p><br /></p>

<p>Here is a sample of hand-selected images of digits output by the model:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/diffusion_example_MNIST_all_digits.png" alt="drawing" width="800" /></center>

<p><br /></p>

<p>The model also output many nonsensical images. While this may not be desirable, I find it interesting that the model honed in on patterns that are “digit-like”. These not-quite digits look like symbols from an alien language:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/diffusion_MNIST_examples_weird_symbols.png" alt="drawing" width="450" /></center>

<p><br /></p>

<p>A better model may output fewer of these nonsensical “digits”; however, I think this demonstrates how these generative models can be used for creative tasks. That is, the model successfully modeled “digit-like patterns”, which in some cases led it to produce nonsensical digits that still look visually interesting (well, interesting to me at least). It did this by assembling these digit-like patterns in new, interesting ways.</p>

<h2 id="further-reading">Further Reading</h2>

<p>Much of my understanding of this material came from the following resources:</p>

<ul>
  <li><a href="https://www.davidinouye.com/course/ece57000-fall-2022/lectures/diffusion-models.pdf">These lecture notes</a> by David I. Inouye</li>
  <li><a href="https://lilianweng.github.io/posts/2021-07-11-diffusion-models/">This blog post</a> by Lilian Weng</li>
  <li><a href="https://paramhanji.github.io/posts/2021/06/ddpm/">This blog post</a> by Param Hanji</li>
  <li><a href="https://angusturner.github.io/generative_models/2021/06/29/diffusion-probabilistic-models-I.html">This blog post</a> by Angus Turner</li>
  <li><a href="https://yang-song.net/blog/2021/score/">This blog post</a> by Yang Song</li>
  <li><a href="https://www.youtube.com/watch?v=687zEGODmHA&amp;t=1212s&amp;ab_channel=MachineLearningatBerkeley">This YouTube lecture</a> at UC Berkeley</li>
  <li><a href="https://www.youtube.com/watch?v=HoKDTa5jHvg&amp;ab_channel=Outlier">This YouTube lecture</a> by Dominic Rampas</li>
  <li><a href="https://jaketae.github.io/study/elbo/">This blog post</a> by Jake Tae</li>
  <li><a href="https://blog.alexalemi.com/diffusion.html#extra-entropy">This blog post</a> by Alexander A. Alemi</li>
</ul>

<h2 id="appendix">Appendix</h2>

<h3 id="derivation-1-re-writing-the-kl-divergence-objective">Derivation 1 (Re-writing the KL-divergence objective)</h3>

\[\begin{align*} KL( q(\boldsymbol{x}_{0:T}) \ \vert\vert \ p_\theta(\boldsymbol{x}_{0:T})) &amp;= E_{\boldsymbol{x}_{0:T} \sim q} \left[ \log \frac{q(\boldsymbol{x}_{0:T})}{p_\theta(\boldsymbol{x}_{0:T})} \right] \\ &amp;= E_{\boldsymbol{x}_{0:T} \sim q} \left[ \log \frac{q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0)q(\boldsymbol{x}_0)}{p_\theta(\boldsymbol{x}_{0:T})} \right] \\ &amp;= E_{\boldsymbol{x}_{0:T} \sim q} \left[ \log \frac{q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0)}{p_\theta(\boldsymbol{x}_{0:T})} \right] + E_{\boldsymbol{x}_0 \sim q} \left[ \log q(\boldsymbol{x}_0) \right] \\ &amp;= -E_{\boldsymbol{x}_{0:T} \sim q} \left[ \log \frac{p_\theta(\boldsymbol{x}_{0:T})}{q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0)} \right] + E_{\boldsymbol{x}_0 \sim q} \left[ \log q(\boldsymbol{x}_0) \right]\end{align*}\]

<h3 id="derivation-2-deriving-the-elbo">Derivation 2 (Deriving the ELBO)</h3>

\[\begin{align*} \log p_\theta(\boldsymbol{x}_0) &amp;= \log \int p_\theta(\boldsymbol{x}_{0:T}) \ d\boldsymbol{x}_{1:T} \\ &amp;= \log \int p_\theta(\boldsymbol{x}_{0:T}) \frac{q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0)}{q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0)} \ d\boldsymbol{x}_{1:T} \\ &amp;= \log E_{\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0 \sim q} \left[ \frac{p_\theta(\boldsymbol{x}_{0:T})}{q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0)} \right] \\ &amp;\geq E_{\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0 \sim q} \left[ \log \frac{p_\theta(\boldsymbol{x}_{0:T})}{q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0)} \right] &amp;&amp; \text{by Jensen's inequality}\end{align*}\]
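<p>Jensen’s inequality can be seen at work in a toy model small enough that $\log p_\theta(x_0)$ is known exactly. The sketch below assumes a single diffusion step ($T = 1$) on scalars, with $p_\theta(x_1) = N(0, 1)$, $p_\theta(x_0 \mid x_1) = N(0.5\,x_1, 0.25)$, and $q(x_1 \mid x_0) = N(\sqrt{1-\beta}\,x_0, \beta)$ with $\beta = 0.5$; all of these numbers are hypothetical. It estimates the ELBO by Monte Carlo and compares it with the exact log-likelihood, $\log N(x_0; 0, 0.5)$.</p>

```python
import math
import random

def log_normal(x, mean, var):
    """Log density of the univariate normal N(mean, var) at x."""
    return -0.5 * math.log(2.0 * math.pi * var) - (x - mean) ** 2 / (2.0 * var)

beta, x0 = 0.5, 1.0
rng = random.Random(0)

# Exact log-likelihood: marginalizing x_1 gives p(x_0) = N(0, 0.5**2 * 1 + 0.25).
log_px0 = log_normal(x0, 0.0, 0.5)

# Monte Carlo ELBO: average log[ p(x_0, x_1) / q(x_1 | x_0) ] with x_1 ~ q(x_1 | x_0).
n = 20000
q_mean = math.sqrt(1.0 - beta) * x0
total = 0.0
for _ in range(n):
    x1 = rng.gauss(q_mean, math.sqrt(beta))
    log_joint = log_normal(x1, 0.0, 1.0) + log_normal(x0, 0.5 * x1, 0.25)
    total += log_joint - log_normal(x1, q_mean, beta)
elbo = total / n

print(f"log p(x0) = {log_px0:.4f}, ELBO estimate = {elbo:.4f}")
```

<p>The gap between the two numbers is $KL(q(x_1 \mid x_0) \ \vert\vert \ p_\theta(x_1 \mid x_0))$, which in this toy model works out to $(1 - 1/\sqrt{2})^2 \approx 0.086$; it would vanish only if $q$ matched the model’s true posterior.</p>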

<h3 id="derivation-3-reformulating-the-expected-elbo">Derivation 3 (Reformulating the expected ELBO)</h3>

\[\begin{align*}E_{\boldsymbol{x}_0 \sim q} \left[ \text{ELBO}(\theta) \right] &amp;:= E_{\boldsymbol{x}_{0:T} \sim q}\left[ \log\frac{p_\theta (\boldsymbol{x}_{0:T}) }{q(\boldsymbol{x}_{1:T} \mid \boldsymbol{x}_0) } \right] \\ &amp;= E_{\boldsymbol{x}_{0:T} \sim q}\left[\log \frac{ p_\theta(\boldsymbol{x}_T) \prod_{t=1}^{T} p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t) }{ \prod_{t=1}^{T} q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1})} \right] \\ &amp;= E_{\boldsymbol{x}_{0:T} \sim q} \left[\log \frac{ p_\theta(\boldsymbol{x}_T)p_\theta(\boldsymbol{x}_0 \mid \boldsymbol{x}_1) \prod_{t=2}^{T} p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t) }{q(\boldsymbol{x}_1 \mid \boldsymbol{x}_0) \prod_{t=2}^{T} q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1})} \right] \\ &amp;= E_{\boldsymbol{x}_{0:T} \sim q} \left[\log \frac{ p_\theta(\boldsymbol{x}_T)p_\theta(\boldsymbol{x}_0 \mid \boldsymbol{x}_1)}{q(\boldsymbol{x}_1 \mid \boldsymbol{x}_0)} + \sum_{t=2}^{T} \log \frac{p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t) }{ q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1})} \right] \\ &amp;= E_{\boldsymbol{x}_{0:T} \sim q} \left[\log \frac{ p_\theta(\boldsymbol{x}_T)p_\theta(\boldsymbol{x}_0 \mid \boldsymbol{x}_1)}{q(\boldsymbol{x}_1 \mid \boldsymbol{x}_0)} + \sum_{t=2}^{T} \log \frac{p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t) }{ q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1}, \boldsymbol{x}_0)} \right] &amp;&amp; \text{Note 1} \\ &amp;= E_{\boldsymbol{x}_{0:T} \sim q} \left[\log \frac{ p_\theta(\boldsymbol{x}_T)p_\theta(\boldsymbol{x}_0 \mid \boldsymbol{x}_1)}{q(\boldsymbol{x}_1 \mid \boldsymbol{x}_0)} + \sum_{t=2}^{T} \log \frac{p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t) }{ \frac{q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0) q(\boldsymbol{x}_t \mid \boldsymbol{x}_0) }{q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_0)} } \right] &amp;&amp; \text{Note 2} \\ &amp;= E_{\boldsymbol{x}_{0:T} \sim q} \left[\log \frac{ p_\theta(\boldsymbol{x}_T)p_\theta(\boldsymbol{x}_0 \mid \boldsymbol{x}_1)}{q(\boldsymbol{x}_1 \mid \boldsymbol{x}_0)} + \sum_{t=2}^{T} \log \frac{p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t) q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_0) }{ q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0) q(\boldsymbol{x}_t \mid \boldsymbol{x}_0) } \right] \\ &amp;= E_{\boldsymbol{x}_{0:T} \sim q} \left[\log \frac{ p_\theta(\boldsymbol{x}_T)p_\theta(\boldsymbol{x}_0 \mid \boldsymbol{x}_1)}{q(\boldsymbol{x}_1 \mid \boldsymbol{x}_0)} + \sum_{t=2}^T \log \frac{q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_0)}{q(\boldsymbol{x}_t \mid \boldsymbol{x}_0)} + \sum_{t=2}^{T} \log \frac{p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t) }{ q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0)} \right] \\ &amp;= E_{\boldsymbol{x}_{0:T} \sim q} \left[\log \frac{ p_\theta(\boldsymbol{x}_T)p_\theta(\boldsymbol{x}_0 \mid \boldsymbol{x}_1)}{q(\boldsymbol{x}_1 \mid \boldsymbol{x}_0)} + \sum_{t=2}^T \log q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_0) - \sum_{t=2}^T \log q(\boldsymbol{x}_t \mid \boldsymbol{x}_0) + \sum_{t=2}^{T} \log \frac{p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t) }{ q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0)} \right] \\ &amp;= E_{\boldsymbol{x}_{0:T} \sim q} \left[\log \frac{ p_\theta(\boldsymbol{x}_T)p_\theta(\boldsymbol{x}_0 \mid \boldsymbol{x}_1)}{q(\boldsymbol{x}_1 \mid \boldsymbol{x}_0)} + \log q(\boldsymbol{x}_1 \mid \boldsymbol{x}_0) - \log q(\boldsymbol{x}_T \mid \boldsymbol{x}_0) + \sum_{t=2}^{T} \log \frac{p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t) }{ q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0)} \right] \\ &amp;= E_{\boldsymbol{x}_{0:T} \sim q} \left[\log p_\theta(\boldsymbol{x}_0 \mid \boldsymbol{x}_1) + \log \frac{p_\theta(\boldsymbol{x}_T)}{q(\boldsymbol{x}_T \mid \boldsymbol{x}_0)} + \sum_{t=2}^{T} \log \frac{p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t) }{ q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0)} \right] \\ &amp;= E_{\boldsymbol{x}_{0:T} \sim q} \left[\log p_\theta(\boldsymbol{x}_0 \mid \boldsymbol{x}_1) \right] + E_{\boldsymbol{x}_{0:T} \sim q} \left[\log \frac{p_\theta(\boldsymbol{x}_T)}{q(\boldsymbol{x}_T \mid \boldsymbol{x}_0)}\right] + \sum_{t=2}^{T} E_{\boldsymbol{x}_{0:T} \sim q}\left[\log \frac{p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t) }{ q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0)} \right] \\ &amp;= E_{\boldsymbol{x}_1, \boldsymbol{x}_0 \sim q} \left[ \log p_\theta(\boldsymbol{x}_0 \mid \boldsymbol{x}_1) \right] - \sum_{t=2}^T E_{\boldsymbol{x}_t, \boldsymbol{x}_0 \sim q} \left[ KL \left( q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0) \ \vert\vert \ p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t) \right) \right] - E_{\boldsymbol{x}_0 \sim q} \left[ KL\left( q(\boldsymbol{x}_T \mid \boldsymbol{x}_0) \ \vert\vert \  p_\theta(\boldsymbol{x}_T) \right)\right] &amp;&amp; \text{Notes 3 and 4}\end{align*}\]

<p><strong>Note 1:</strong> By the <a href="https://en.wikipedia.org/wiki/Markov_property#:~:text=In%20probability%20theory%20and%20statistics,is%20independent%20of%20its%20history.">Markov property</a> of the forward diffusion process, it holds that $q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1}) =  q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1}, \boldsymbol{x}_0)$.</p>

<p><strong>Note 2:</strong> Apply Bayes’ theorem, conditioning on $\boldsymbol{x}_0$ throughout:</p>

\[q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1}, \boldsymbol{x}_0) = \frac{q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0) q(\boldsymbol{x}_t \mid \boldsymbol{x}_0) }{q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_0)}\]

<p><strong>Note 3:</strong></p>

\[\begin{align*}E_{\boldsymbol{x}_{0:T} \sim q} \left[\log \frac{p_\theta(\boldsymbol{x}_T)}{q(\boldsymbol{x}_T \mid \boldsymbol{x}_0)} \right] &amp;= E_{\boldsymbol{x}_{T}, \boldsymbol{x}_0 \sim q} \left[\log \frac{p_\theta(\boldsymbol{x}_T)}{q(\boldsymbol{x}_T \mid \boldsymbol{x}_0)} \right] \\ &amp;= -E_{\boldsymbol{x}_0 \sim q} \left[ KL\left( q(\boldsymbol{x}_T \mid \boldsymbol{x}_0) \ \vert\vert \  p_\theta(\boldsymbol{x}_T) \right) \right]\end{align*}\]

<p><strong>Note 4:</strong>
\(\begin{align*}E_{\boldsymbol{x}_{0:T} \sim q}\left[\log \frac{p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t) }{ q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0)} \right] &amp;= E_{\boldsymbol{x}_{t}, \boldsymbol{x}_{t-1}, \boldsymbol{x}_0 \sim q}\left[\log \frac{p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t) }{ q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0)} \right] \\ &amp;= E_{\boldsymbol{x}_0 \sim q}\left[ \iint q(\boldsymbol{x}_{t-1}, \boldsymbol{x}_t \mid \boldsymbol{x}_0) \log \frac{p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t) }{ q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0)} \ d\boldsymbol{x}_{t-1} d\boldsymbol{x}_t \right] \\ &amp;=  E_{\boldsymbol{x}_0 \sim q} \left[ \iint q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0) q(\boldsymbol{x}_t \mid \boldsymbol{x}_0) \log \frac{p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t) }{ q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0)} \ d\boldsymbol{x}_{t-1} d\boldsymbol{x}_t \right] \\ &amp;= E_{\boldsymbol{x}_0 \sim q}\left[ \int q(\boldsymbol{x}_t \mid \boldsymbol{x}_0) \left[ \int q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0)  \log \frac{p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t) }{ q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0)} \ d\boldsymbol{x}_{t-1} \right] d\boldsymbol{x}_t \right] \\ &amp;= -E_{\boldsymbol{x}_0 \sim q} \left[ \int q(\boldsymbol{x}_t \mid \boldsymbol{x}_0) KL(q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0) \ \vert\vert \ p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t)) d\boldsymbol{x}_t \right] \\ &amp;= -E_{\boldsymbol{x}_t, \boldsymbol{x}_0 \sim q} \left[ KL(q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0) \ \vert\vert \ p_\theta(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t)) \right]\end{align*}\)</p>

<h3 id="derivation-4-scaling-the-mean-by-sqrt1-beta-constrains-the-variance-of-the-forward-process">Derivation 4 (Scaling the mean by $\sqrt{1-\beta}$ constrains the variance of the forward process)</h3>

<p>Let,</p>

\[\boldsymbol{x}_t \sim N(\boldsymbol{\mu}, \boldsymbol{I})\]

<p>for some mean $\boldsymbol{\mu}$. For the next timestep, we have</p>

\[\boldsymbol{x}_{t+1} \sim N(a\boldsymbol{x}_t, \beta \boldsymbol{I})\]

<p>where $a$ is some constant that scales the mean given by $\boldsymbol{x}_t$. We seek a value of $a$ such that $\text{Var}(\boldsymbol{x}_{t+1}) = 1$. To find this value, we use the <a href="https://en.wikipedia.org/wiki/Law_of_total_variance">law of total variance</a>:</p>

\[\begin{align*}\text{Var}(\boldsymbol{x}_{t+1}) &amp;= E\left[\text{Var}(\boldsymbol{x}_{t+1} \mid \boldsymbol{x}_t ) \right] + \text{Var}\left( E\left[\boldsymbol{x}_{t+1} \mid \boldsymbol{x}_t \right]\right) \\ &amp;= E[\beta] + \text{Var}(a\boldsymbol{x}_t) \\ &amp;= \beta + a^2\text{Var}(\boldsymbol{x}_t) \\ &amp;= \beta + a^2\end{align*}\]

<p>Now, if we fix $\text{Var}(\boldsymbol{x}_{t+1}) = 1$, it follows that:</p>

\[\begin{align*}&amp;1 = \beta + a^2 \\ \implies &amp;a = \sqrt{1-\beta}\end{align*}\]
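<p>As a quick sanity check of this result, the sketch below simulates the one-dimensional case with NumPy. The values $\beta = 0.02$ and $\mu = 3$ are arbitrary choices for illustration. With $a = \sqrt{1-\beta}$, the sample variance of $\boldsymbol{x}_{t+1}$ should come out close to 1, whereas leaving the mean unscaled ($a = 1$) should give a variance close to $1 + \beta$.</p>

```python
import numpy as np

rng = np.random.default_rng(seed=0)
beta = 0.02          # illustrative noise variance (assumption)
mu = 3.0             # arbitrary mean of x_t
n_samples = 1_000_000

# x_t ~ N(mu, 1)
x_t = mu + rng.standard_normal(n_samples)
noise = rng.standard_normal(n_samples)

# x_{t+1} ~ N(a x_t, beta) with a = sqrt(1 - beta), as derived above
a = np.sqrt(1.0 - beta)
x_next = a * x_t + np.sqrt(beta) * noise

# For comparison: leave the mean unscaled (a = 1)
x_next_unscaled = x_t + np.sqrt(beta) * noise

var_next = x_next.var()
var_unscaled = x_next_unscaled.var()
print(f"a = sqrt(1 - beta): Var = {var_next:.4f}")
print(f"a = 1:              Var = {var_unscaled:.4f}")
```

<p>With $a = 1$, each step adds $\beta$ to the variance, so over a long chain the variance would grow without bound; the $\sqrt{1-\beta}$ scaling keeps it fixed at 1.</p>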

<h3 id="derivation-5-closed-form-of-qboldsymbolx_t-mid-boldsymbolx_0">Derivation 5 (Closed form of $q(\boldsymbol{x}_t \mid \boldsymbol{x}_0)$)</h3>

<p>We start with $q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1})$. Recall it is given by,</p>

\[q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1}) := N(\boldsymbol{x}_t;  \sqrt{1-\beta_t}\boldsymbol{x}_{t-1}, \beta_t \boldsymbol{I})\]

<p>Because this is a normal distribution, we can generate a sample</p>

\[\boldsymbol{x}_t \sim q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1})\]

<p>by first sampling $\epsilon_{t-1}$ from a standard normal, $N(\boldsymbol{0}, \boldsymbol{I})$, and then transforming it into $\boldsymbol{x}_t$ via,</p>

\[\begin{align*}\boldsymbol{x}_t &amp;= \sqrt{1-\beta_t}\boldsymbol{x}_{t-1} + \sqrt{\beta_t}\epsilon_{t-1} \\ &amp;= \sqrt{\alpha_t}\boldsymbol{x}_{t-1} + \sqrt{1-\alpha_t}\epsilon_{t-1} \end{align*}\]

<p>where $\alpha_t := 1 - \beta_t$ (which will make the notation easier going forward).</p>

<p>Notice that this transformation relies on $\boldsymbol{x}_{t-1}$, which is a sample from  $q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_{t-2})$. From this observation, we realize there is a way to sample $\boldsymbol{x}_t$ not from $q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1})$, but rather from $q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-2})$. Specifically, we can generate <em>two</em> samples from a standard normal distribution,</p>

\[\epsilon_{t-1}, \epsilon_{t-2} \sim N(\boldsymbol{0}, \boldsymbol{I})\]

<p>Then, we can generate a sample</p>

\[\boldsymbol{x}_t \sim q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-2})\]

<p>via the following transformation of $\epsilon_{t-1}$ and $\epsilon_{t-2}$:</p>

\[\begin{align*}\boldsymbol{x}_t &amp;= \sqrt{\alpha_t}\boldsymbol{x}_{t-1} + \sqrt{1-\alpha_t}\epsilon_{t-1} \\ &amp;= \sqrt{\alpha_t}\left(\sqrt{\alpha_{t-1}}\boldsymbol{x}_{t-2} + \sqrt{1-\alpha_{t-1}}\epsilon_{t-2}\right) + \sqrt{1-\alpha_t}\epsilon_{t-1} \\ &amp;= \sqrt{\alpha_t}\sqrt{\alpha_{t-1}}\boldsymbol{x}_{t-2} + \sqrt{\alpha_t}\sqrt{1-\alpha_{t-1}} \epsilon_{t-2} + \sqrt{1-\alpha_{t}}\epsilon_{t-1} \\ &amp;=\sqrt{\alpha_t}\sqrt{\alpha_{t-1}}\boldsymbol{x}_{t-2} + \sqrt{\alpha_t(1-\alpha_{t-1}) + (1-\alpha_{t})}\,\epsilon_{t, t-1} \\ &amp;= \sqrt{\alpha_t\alpha_{t-1}}\boldsymbol{x}_{t-2} + \sqrt{1-\alpha_t \alpha_{t-1}}\,\epsilon_{t, t-1}\end{align*}\]

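<p>where $\epsilon_{t, t-1}$ is a sample of $N(\boldsymbol{0}, \boldsymbol{I})$. Here, we used the fact that multiplying a zero-mean normal random variable by a constant $c$ multiplies its variance by $c^2$, together with the fact that if we have two <em>independent</em> random variables $X$ and $Y$ such that,</p>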

\[\begin{align*}X &amp;\sim N(0, \sigma_X^2) \\ Y &amp;\sim N(0, \sigma_Y^2) \end{align*}\]

<p>then it follows that,</p>

\[X + Y \sim N(0,  \sigma_X^2 +  \sigma_Y^2)\]

<p>though we won’t prove this fact here.</p>
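<p>The two-step collapse can be checked by simulation. The following NumPy sketch works in one dimension with arbitrary illustrative values $\alpha_t = 0.97$, $\alpha_{t-1} = 0.98$, and $\boldsymbol{x}_{t-2} = 1.5$. It applies the forward step twice with independent noise and compares the result against the mean $\sqrt{\alpha_t\alpha_{t-1}}\boldsymbol{x}_{t-2}$ and standard deviation $\sqrt{1-\alpha_t\alpha_{t-1}}$. It also computes what naively adding the two standard deviations would give, which does not match: independent noise terms add in variance, not in standard deviation.</p>

```python
import numpy as np

# Simulate two forward steps, x_{t-2} -> x_{t-1} -> x_t, in one dimension
rng = np.random.default_rng(seed=1)
alpha_t, alpha_tm1 = 0.97, 0.98   # illustrative values (assumptions)
x_tm2 = 1.5                       # fixed starting value x_{t-2}
n_samples = 1_000_000

# Each step draws its own, independent standard normal noise
eps_tm2 = rng.standard_normal(n_samples)
eps_tm1 = rng.standard_normal(n_samples)
x_tm1 = np.sqrt(alpha_tm1) * x_tm2 + np.sqrt(1 - alpha_tm1) * eps_tm2
x_t = np.sqrt(alpha_t) * x_tm1 + np.sqrt(1 - alpha_t) * eps_tm1

# Moments predicted by collapsing the two steps into one
pred_mean = np.sqrt(alpha_t * alpha_tm1) * x_tm2
pred_std = np.sqrt(1 - alpha_t * alpha_tm1)

# Adding the two noise standard deviations directly (incorrect for independent noise)
naive_std = np.sqrt(alpha_t) * np.sqrt(1 - alpha_tm1) + np.sqrt(1 - alpha_t)

print(f"mean: empirical {x_t.mean():.4f}, predicted {pred_mean:.4f}")
print(f"std:  empirical {x_t.std():.4f}, predicted {pred_std:.4f}, naive {naive_std:.4f}")
```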

<p>Now, following the same logic above, we can generate a sample,</p>

\[\boldsymbol{x}_t \sim q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-3})\]

<p>via</p>

\[\boldsymbol{x}_t = \sqrt{\alpha_t\alpha_{t-1}\alpha_{t-2}}\boldsymbol{x}_{t-3} + (\sqrt{1-\alpha_t \alpha_{t-1} \alpha_{t-2}})\epsilon_{t, t-1, t-2}\]

<p>where $\epsilon_{t, t-1, t-2} \sim N(\boldsymbol{0}, \boldsymbol{I})$. If we follow this pattern all the way down to $t=0$, we see that we can generate a sample,</p>

\[\boldsymbol{x}_t \sim q(\boldsymbol{x}_t \mid \boldsymbol{x}_0)\]

<p>via</p>

\[\begin{align*}\boldsymbol{x}_t &amp;= \sqrt{\prod_{i=1}^t \alpha_i}\boldsymbol{x}_0 + \sqrt{1-\prod_{i=1}^t \alpha_i}\epsilon_{t, t-1, \dots, 0} \\ &amp;= \sqrt{\bar{\alpha}_t}\boldsymbol{x}_0 + \sqrt{1-\bar{\alpha}_t}\epsilon_{t, t-1, \dots, 0}\end{align*}\]

<p>where $\bar{\alpha}_t := \prod_{i=1}^t \alpha_i$. Thus, we see that,</p>

\[q(\boldsymbol{x}_t \mid \boldsymbol{x}_0) = N\left(\boldsymbol{x}_t; \sqrt{\bar{\alpha}_t}\boldsymbol{x}_0, \left(1-\bar{\alpha}_t \right) \boldsymbol{I}\right)\]
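<p>This closed form means we can jump directly from $\boldsymbol{x}_0$ to $\boldsymbol{x}_t$ without simulating the intermediate steps. The NumPy sketch below compares the two routes in one dimension: iterating $q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1})$ for $t = 300$ steps versus a single draw from $q(\boldsymbol{x}_t \mid \boldsymbol{x}_0)$. The linear schedule from $10^{-4}$ to $0.02$ over 1,000 steps, the value of $\boldsymbol{x}_0$, and the helper names <code class="language-plaintext highlighter-rouge">forward_step</code> and <code class="language-plaintext highlighter-rouge">sample_xt_given_x0</code> are illustrative assumptions. Both routes should produce samples whose mean and variance match $\sqrt{\bar{\alpha}_t}\boldsymbol{x}_0$ and $1-\bar{\alpha}_t$.</p>

```python
import numpy as np

def forward_step(x_prev, beta_t, rng):
    """One draw from q(x_t | x_{t-1}) = N(sqrt(1 - beta_t) x_{t-1}, beta_t I)."""
    noise = rng.standard_normal(x_prev.shape)
    return np.sqrt(1.0 - beta_t) * x_prev + np.sqrt(beta_t) * noise

def sample_xt_given_x0(x0, t, betas, rng):
    """One draw from the closed form q(x_t | x_0); t is 1-indexed."""
    alpha_bar_t = np.prod(1.0 - betas[:t])
    noise = rng.standard_normal(x0.shape)
    return np.sqrt(alpha_bar_t) * x0 + np.sqrt(1.0 - alpha_bar_t) * noise

rng = np.random.default_rng(seed=2)
betas = np.linspace(1e-4, 0.02, 1000)   # illustrative linear schedule (assumption)
t = 300
x0 = np.full(200_000, 0.8)              # many copies of the same scalar x_0

# Route 1: apply the one-step forward process t times (betas[0] is beta_1)
x_iter = x0.copy()
for i in range(t):
    x_iter = forward_step(x_iter, betas[i], rng)

# Route 2: a single draw from the closed form
x_direct = sample_xt_given_x0(x0, t, betas, rng)

alpha_bar_t = np.prod(1.0 - betas[:t])
theory_mean = np.sqrt(alpha_bar_t) * 0.8
theory_var = 1.0 - alpha_bar_t
print(f"iterated: mean {x_iter.mean():.4f}, var {x_iter.var():.4f}")
print(f"direct:   mean {x_direct.mean():.4f}, var {x_direct.var():.4f}")
print(f"theory:   mean {theory_mean:.4f}, var {theory_var:.4f}")
```

<p>The direct route costs one draw regardless of $t$, which is what makes it practical to train on randomly sampled timesteps.</p>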

<h3 id="derivation-6-closed-form-of-qboldsymbolx_t-1-mid-boldsymbolx_t-boldsymbolx_0">Derivation 6 (Closed form of $q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0)$)</h3>

\[\begin{align*}q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0) &amp;= \frac{q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1}, \boldsymbol{x}_0) q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_0)}{q(\boldsymbol{x}_t \mid \boldsymbol{x}_0)} &amp;&amp; \text{Note 1} \\ &amp;\propto q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1}, \boldsymbol{x}_0) q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_0) &amp;&amp; \text{Note 2} \\ &amp;= N\left(\boldsymbol{x}_t; \sqrt{\alpha_t}\boldsymbol{x}_{t-1}, (1-\alpha_t)\boldsymbol{I} \right) N\left(\boldsymbol{x}_{t-1}; \sqrt{\bar{\alpha}_{t-1}}\boldsymbol{x}_0, (1-\bar{\alpha}_{t-1}) \boldsymbol{I} \right) \\ &amp;\propto \exp \left(-\frac{1}{2} \frac{\left(\boldsymbol{x}_t - \sqrt{\alpha_t}\boldsymbol{x}_{t-1} \right)^2}{1-\alpha_t} \right) \exp \left(-\frac{1}{2} \frac{\left( \boldsymbol{x}_{t-1} - \sqrt{\bar{\alpha}_{t-1}}\boldsymbol{x}_0 \right)^2 }{1-\bar{\alpha}_{t-1}} \right) \\ &amp;= \exp \left(-\frac{1}{2} \left(\frac{\left(\boldsymbol{x}_t - \sqrt{\alpha_t}\boldsymbol{x}_{t-1} \right)^2}{1-\alpha_t} + \frac{ \left(\boldsymbol{x}_{t-1} - \sqrt{\bar{\alpha}_{t-1}}\boldsymbol{x}_0 \right)^2}{1-\bar{\alpha}_{t-1}} \right) \right) \\ &amp;= \exp\left(-\frac{1}{2} \left( \frac{\boldsymbol{x}^2_t - 2\sqrt{\alpha_t}\boldsymbol{x}_t\boldsymbol{x}_{t-1} + \alpha_t \boldsymbol{x}^2_{t-1} }{\beta_t} + \frac{\boldsymbol{x}^2_{t-1} -2\sqrt{\bar{\alpha}_{t-1}}\boldsymbol{x}_0\boldsymbol{x}_{t-1} + \bar{\alpha}_{t-1}\boldsymbol{x}^2_0}{1-\bar{\alpha}_{t-1}} \right) \right) \\ &amp;\propto \exp\left( -\frac{1}{2} \left( \left(\frac{\alpha_t}{\beta_t} + \frac{1}{1-\bar{\alpha}_{t-1}} \right)\boldsymbol{x}^2_{t-1} - \left(\frac{2 \sqrt{\alpha_t} \boldsymbol{x}_t }{\beta_t} + \frac{2 \sqrt{\bar{\alpha}_{t-1}} \boldsymbol{x}_0 }{1 - \bar{\alpha}_{t-1}}  \right)\boldsymbol{x}_{t-1} \right)\right) &amp;&amp; \text{Note 3} \\ &amp;\propto \exp\left(-\frac{1}{2} \left(\underbrace{ \left(\frac{\alpha_t}{\beta_t} + 
\frac{1}{1-\bar{\alpha}_{t-1}}\right)}_{1 / \sigma^2}\left(\boldsymbol{x}_{t-1} - \underbrace{\frac{ \frac{2\sqrt{\alpha_t}\boldsymbol{x}_t }{\beta_t} + \frac{2 \sqrt{\bar{\alpha}_{t-1}}\boldsymbol{x}_0 }{1 - \bar{\alpha}_{t-1}}  }{2\left(\frac{\alpha_t}{\beta_t} + \frac{1}{1-\bar{\alpha}_{t-1}}  \right)} }_{\mu} \right)^2  \right) \right) &amp;&amp; \text{Note 4} \\ &amp;= \exp \left(-\frac{1}{2} \frac{\left( \boldsymbol{x}_{t-1} - \underbrace{\left( \frac{\sqrt{\alpha_t} \left( 1 - \bar{\alpha}_{t-1} \right) }{1-\bar{\alpha}_t} \boldsymbol{x}_t + \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t}\boldsymbol{x}_0 \right)}_{\mu} \right)^2}{\underbrace{ \frac{\beta_t\left(1-\bar{\alpha}_{t-1}\right)}{1-\bar{\alpha}_t}}_{\sigma^2} } \right) &amp;&amp; \text{Notes 5 and 6}\end{align*}\]

<p>Up to a normalizing constant, this is the density function of a normal distribution with mean $\mu$ and variance $\sigma^2$. Thus, we see that,</p>

\[q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0)  = N\left(\boldsymbol{x}_{t-1}; \frac{\sqrt{\alpha_t} \left( 1 - \bar{\alpha}_{t-1} \right) }{1-\bar{\alpha}_t} \boldsymbol{x}_t + \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t}\boldsymbol{x}_0, \frac{\beta_t\left(1-\bar{\alpha}_{t-1}\right)}{1-\bar{\alpha}_t}\boldsymbol{I} \right)\]

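<p><strong>Note 1:</strong> Apply Bayes’ theorem. In the third line, we also use the Markov property of the forward process, $q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1}, \boldsymbol{x}_0) = q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1})$, along with the closed form of $q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_0)$ from Derivation 5.</p>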

<p><strong>Note 2:</strong> Throughout this derivation, we will only consider terms that contain $\boldsymbol{x}_{t-1}$.</p>

<p><strong>Note 3:</strong> Here we remove terms that do not involve $\boldsymbol{x}_{t-1}$ by using the following fact: given a term, $f(\boldsymbol{x}_{t-1})$, and a constant term, $C$, it follows that:</p>

\[\begin{align*}\exp\left(f(\boldsymbol{x}_{t-1}) + C\right) &amp;= \exp\left(f(\boldsymbol{x}_{t-1})\right)\exp\left(C\right) \\ &amp; \propto \exp\left(f(\boldsymbol{x}_{t-1})\right)\end{align*}\]

<p><strong>Note 4:</strong> Here we <a href="https://en.wikipedia.org/wiki/Completing_the_square">complete the square</a> and use the fact that:</p>

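\[ax^2 + bx + c = a\left(x+\frac{b}{2a}\right)^2 + \left(c - \frac{b^2}{4a}\right)\]

<p>In our case, the variable is $\boldsymbol{x}_{t-1}$ and the linear term has coefficient $-b$, so completing the square yields $a\left(\boldsymbol{x}_{t-1} - \frac{b}{2a}\right)^2$, where</p>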

\[\begin{align*}a &amp;:= \frac{\alpha_t}{\beta_t} + \frac{1}{1-\bar{\alpha}_{t-1}} \\ b &amp;:= \frac{2\sqrt{\alpha_t}\boldsymbol{x}_t }{\beta_t} + \frac{2 \sqrt{\bar{\alpha}_{t-1}}\boldsymbol{x}_0 }{1 - \bar{\alpha}_{t-1}}  \end{align*}\]

<p>Note that we can disregard the term, $\left(c - \frac{b^2}{4a}\right)$, since this is a constant with respect to $\boldsymbol{x}_{t-1}$ and it gets “swallowed” by the $\propto$ as described in Note 3.</p>

<p>Moreover, after completing the square, we see that this is the functional form of a normal distribution where we have annotated the mean, $\mu$, and reciprocal of the variance, $1 / \sigma^2$.</p>

<p><strong>Note 5:</strong></p>

\[\begin{align*} \sigma^2 &amp;:= \frac{1}{\frac{\alpha_t}{\beta_t} + \frac{1}{1-\bar{\alpha}_{t-1}} } \\ &amp;= \frac{1}{ \frac{\alpha_t (1-\bar{\alpha}_{t-1}) + \beta_t }{\beta_t (1-\bar{\alpha}_{t-1})} } \\ &amp;= \frac{\beta_t (1- \bar{\alpha}_{t-1})}{\alpha_t(1-\bar{\alpha}_{t-1}) + (1-\alpha_t)} \\ &amp;= \frac{\beta_t(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t}\end{align*}\]

<p><strong>Note 6:</strong></p>

\[\begin{align*}\mu &amp;:= \frac{\frac{2 \sqrt{\alpha_t}}{\beta_t}\boldsymbol{x}_t + \frac{2 \sqrt{\bar{\alpha}_{t-1}} }{1- \bar{\alpha}_{t-1}}\boldsymbol{x}_0 }{2\left( \frac{\alpha_t}{\beta_t} + \frac{1}{1-\bar{\alpha}_{t-1}} \right)} \\ &amp;= \left( \frac{\sqrt{\alpha_t} }{\beta_t}\boldsymbol{x}_t + \frac{\sqrt{\bar{\alpha}_{t-1}} }{1 - \bar{\alpha}_{t-1} }\boldsymbol{x}_0 \right) \frac{\beta_t (1-\bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t} \\ &amp;= \frac{\sqrt{\alpha_t} \left( 1 - \bar{\alpha}_{t-1} \right) }{1 - \bar{\alpha}_t} \boldsymbol{x}_t + \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t}\boldsymbol{x}_0 \end{align*}\]
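<p>Because this derivation involves a fair amount of algebra, it can be reassuring to check it numerically. The NumPy sketch below works in one dimension: it multiplies the two Gaussian factors $q(\boldsymbol{x}_t \mid \boldsymbol{x}_{t-1})$ and $q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_0)$ on a fine grid of $\boldsymbol{x}_{t-1}$ values, normalizes the product, and compares its mean and variance with the closed-form $\mu$ and $\sigma^2$ from Notes 5 and 6. The schedule, timestep, and observed values are arbitrary choices for illustration.</p>

```python
import numpy as np

def normal_pdf(x, mean, var):
    """Density of a one-dimensional normal distribution."""
    return np.exp(-0.5 * (x - mean) ** 2 / var) / np.sqrt(2 * np.pi * var)

# Illustrative linear schedule (an assumption for this check)
betas = np.linspace(1e-4, 0.02, 1000)
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)

t = 500                                  # 1-indexed timestep; arrays are 0-indexed
beta_t, alpha_t = betas[t - 1], alphas[t - 1]
alpha_bar_t, alpha_bar_tm1 = alpha_bars[t - 1], alpha_bars[t - 2]
x0, xt = 0.7, 0.4                        # arbitrary "observed" values

# Unnormalized posterior over x_{t-1}: q(x_t | x_{t-1}) * q(x_{t-1} | x_0)
grid = np.linspace(-5.0, 5.0, 400_001)
dx = grid[1] - grid[0]
unnormalized = normal_pdf(xt, np.sqrt(alpha_t) * grid, beta_t) \
    * normal_pdf(grid, np.sqrt(alpha_bar_tm1) * x0, 1.0 - alpha_bar_tm1)
posterior = unnormalized / (unnormalized.sum() * dx)
grid_mean = np.sum(grid * posterior) * dx
grid_var = np.sum((grid - grid_mean) ** 2 * posterior) * dx

# Closed-form mean and variance from Derivation 6 (Notes 5 and 6)
mu_tilde = np.sqrt(alpha_t) * (1 - alpha_bar_tm1) / (1 - alpha_bar_t) * xt \
    + np.sqrt(alpha_bar_tm1) * beta_t / (1 - alpha_bar_t) * x0
var_tilde = beta_t * (1 - alpha_bar_tm1) / (1 - alpha_bar_t)

print(f"mean: grid {grid_mean:.6f}, closed form {mu_tilde:.6f}")
print(f"var:  grid {grid_var:.6f}, closed form {var_tilde:.6f}")
```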

<h3 id="derivation-7-simplification-of-closed-form-of-qboldsymbolx_t-1-mid-boldsymbolx_t-boldsymbolx_0">Derivation 7 (Simplification of closed form of $q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0)$)</h3>

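<p>Here we rewrite the mean of $q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0)$ from Derivation 6, which we denote $\tilde{\boldsymbol{\mu}}$, in terms of $\boldsymbol{x}_t$ and the noise $\epsilon$ rather than $\boldsymbol{x}_t$ and $\boldsymbol{x}_0$:</p>

\[\begin{align*}\tilde{\boldsymbol{\mu}} &amp;= \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}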
\boldsymbol{x}_t + \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t} \boldsymbol{x}_0 \\ &amp;= \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}
\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon) + \frac{\sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_t}\beta_t \left( \frac{\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon) - \sqrt{1 - \bar{\alpha}_t}\epsilon}{\sqrt{\bar{\alpha}_t}} \right) &amp;&amp; \text{Note 1} \\ &amp;= \frac{\sqrt{\alpha_t} (1-\bar{\alpha}_{t-1}) }{1 - \bar{\alpha}_t} \boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon) + \frac{\sqrt{\bar{\alpha}_{t-1}} \beta_t}{(1-\bar{\alpha}_t) \sqrt{\bar{\alpha}_t}}\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon) - \frac{\sqrt{\bar{\alpha}_{t-1}} \beta_t \sqrt{1-\bar{\alpha}_t}}{(1-\bar{\alpha}_t)\sqrt{\bar{\alpha}_t}} \epsilon \\ &amp;= \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon) + \frac{\beta_t}{(1-\bar{\alpha}_t)\sqrt{\alpha_t}}\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon) - \frac{\sqrt{1-\bar{\alpha}_t} \beta_t}{(1-\bar{\alpha}_t)\sqrt{\alpha_t}} \epsilon &amp;&amp; \text{Note 2} \\ &amp;= \frac{1}{\sqrt{\alpha_t}}\left( \frac{\alpha_t (1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon) + \frac{\beta_t}{1-\bar{\alpha}_{t}}\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon) - \frac{\sqrt{1 - \bar{\alpha}_t} \beta_t}{1-\bar{\alpha}_t}\epsilon \right) \\ &amp;= \frac{1}{\sqrt{\alpha_t}}\left( \frac{\alpha_t (1-\bar{\alpha}_{t-1}) + \beta_t}{1-\bar{\alpha}_t}\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon) - \frac{\sqrt{1 - \bar{\alpha}_t} \beta_t}{1-\bar{\alpha}_t}\epsilon \right) \\ &amp;= \frac{1}{\sqrt{\alpha_t}}\left( \frac{\alpha_t (1-\bar{\alpha}_{t-1}) + \beta_t}{1-\bar{\alpha}_t}\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon) - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon \right) &amp;&amp; \text{Note 3} \\ &amp;= \frac{1}{\sqrt{\alpha_t}}\left(\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon) - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon\right) &amp;&amp; \text{Note 4} \end{align*}\]

<p><strong>Note 1:</strong> Recall from Derivation 5 that,</p>

\[q(\boldsymbol{x}_t \mid \boldsymbol{x}_0) = N\left(\boldsymbol{x}_t; \sqrt{\bar{\alpha}_t}\boldsymbol{x}_0, \left(1-\bar{\alpha}_t \right) \boldsymbol{I}\right)\]

<p>Thus, we can sample $\boldsymbol{x}_t$ by first sampling $\epsilon \sim N(\boldsymbol{0}, \boldsymbol{I})$ and then passing $\boldsymbol{x}_0$ and $\epsilon$ into the function</p>

\[\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon) := \sqrt{\bar{\alpha}_t}\boldsymbol{x}_0 + \sqrt{1-\bar{\alpha}_t}\epsilon\]

<p>Solving for $\boldsymbol{x}_0$ gives \(\boldsymbol{x}_0 = \frac{\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon) - \sqrt{1 - \bar{\alpha}_t}\epsilon}{\sqrt{\bar{\alpha}_t}}\).</p>

<p><strong>Note 2:</strong></p>

\[\frac{\sqrt{\bar{\alpha}_{t-1}}}{\sqrt{\bar{\alpha}_t}} = \sqrt{\frac{\prod_{i=1}^{t-1} \alpha_i }{\prod_{i=1}^{t} \alpha_i}} = \frac{1}{\sqrt{\alpha_t}}\]

<p><strong>Note 3:</strong></p>

\[\frac{\sqrt{1-\bar{\alpha}_t}}{1-\bar{\alpha}_t} = \frac{\sqrt{1-\bar{\alpha}_t}}{\sqrt{1-\bar{\alpha}_t}} \frac{\sqrt{1-\bar{\alpha}_t}}{1-\bar{\alpha}_t} = \frac{1-\bar{\alpha}_t}{\sqrt{1-\bar{\alpha}_t}(1-\bar{\alpha}_t)} = \frac{1}{\sqrt{1-\bar{\alpha}_t}}\]

<p><strong>Note 4:</strong></p>

\[\begin{align*}\frac{\alpha_t(1-\bar{\alpha}_{t-1}) + \beta_t}{1-\bar{\alpha}_t} &amp;= \frac{\alpha_t - \alpha_t\bar{\alpha}_{t-1} + \beta_t}{1 - \bar{\alpha}_t} \\ &amp;= \frac{1 - \beta_t - \bar{\alpha}_t + \beta_t}{1 - \bar{\alpha}_t} \\ &amp;= \frac{1 - \bar{\alpha}_t}{1 - \bar{\alpha}_t} \\ &amp;= 1 \end{align*}\]
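<p>The two expressions for the mean of $q(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0)$, one written with $\boldsymbol{x}_0$ and one written with $\epsilon$, can be compared numerically. In the NumPy sketch below, random 28 x 28 arrays stand in for an image and its noise, and the schedule and timestep are illustrative assumptions. The two forms should agree up to floating-point error.</p>

```python
import numpy as np

rng = np.random.default_rng(seed=3)

# Illustrative schedule and timestep (assumptions for this check)
betas = np.linspace(1e-4, 0.02, 1000)
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)
t = 400                                  # 1-indexed timestep (t must be at least 2)
beta_t, alpha_t = betas[t - 1], alphas[t - 1]
alpha_bar_t, alpha_bar_tm1 = alpha_bars[t - 1], alpha_bars[t - 2]

x0 = rng.standard_normal((28, 28))       # stand-in for an image x_0
eps = rng.standard_normal((28, 28))      # noise used to produce x_t
xt = np.sqrt(alpha_bar_t) * x0 + np.sqrt(1 - alpha_bar_t) * eps

# Mean written in terms of x_t and x_0 (Derivation 6)
mu_from_x0 = np.sqrt(alpha_t) * (1 - alpha_bar_tm1) / (1 - alpha_bar_t) * xt \
    + np.sqrt(alpha_bar_tm1) * beta_t / (1 - alpha_bar_t) * x0

# Simplified mean written in terms of x_t and eps (Derivation 7)
mu_from_eps = (xt - beta_t / np.sqrt(1 - alpha_bar_t) * eps) / np.sqrt(alpha_t)

print("max abs difference:", np.max(np.abs(mu_from_x0 - mu_from_eps)))
```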

<h3 id="derivation-8-reparameterizing-l_t-as-a-noise-prediction-term">Derivation 8 (Reparameterizing $L_t$ as a noise-prediction term)</h3>

\[\begin{align*}L_t &amp;= E_{\epsilon_t, \boldsymbol{x}_0} \left[ \frac{1}{2\sigma_t^2} \lVert \boldsymbol{\mu}_\theta(\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon_t), t ) - \tilde{\boldsymbol{\mu}}(\boldsymbol{x}_0, \epsilon_t) \rVert^2 \right] \\ &amp;= E_{\epsilon_t, \boldsymbol{x}_0} \left[ \frac{1}{2\sigma_t^2} \left\lVert \frac{1}{\sqrt{\alpha_t}}\left(\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon_t) - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_\theta(\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon_t), t) \right) - \frac{1}{\sqrt{\alpha_t}}\left(\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon_t) - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_t \right) \right\rVert^2 \right] \\ &amp;= E_{\epsilon_t, \boldsymbol{x}_0} \left[ \frac{1}{2\sigma_t^2} \left\lVert \frac{1}{\sqrt{\alpha_t}} \left( \boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon_t) - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_\theta(\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon_t), t) - \boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon_t) + \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_t \right) \right\rVert^2 \right] \\ &amp;= E_{\epsilon_t, \boldsymbol{x}_0} \left[ \frac{1}{2\sigma_t^2} \left\lVert \frac{1}{\sqrt{\alpha_t}} \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \left( \epsilon_t - \epsilon_\theta(\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon_t), t) \right) \right\rVert^2 \right] \\ &amp;= E_{\epsilon_t, \boldsymbol{x}_0} \left[ \frac{\beta_t^2}{2\sigma_t^2 \alpha_t (1-\bar{\alpha}_t)} \lVert \epsilon_t - \epsilon_\theta(\boldsymbol{x}_t(\boldsymbol{x}_0, \epsilon_t), t) \rVert^2 \right]\end{align*}\]
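<p>The final line of this derivation can also be checked numerically. The NumPy sketch below uses a random vector as a stand-in for a network’s noise prediction $\epsilon_\theta$ and sets $\sigma_t^2 = \beta_t$ (an assumption here, matching the $\sqrt{\beta_t}$ noise scale used in the sampling function later in this post). It compares the scaled squared distance between the two means with the weighted squared distance between the noises, using the weight $\frac{\beta_t^2}{2\sigma_t^2 \alpha_t (1-\bar{\alpha}_t)}$. The schedule and timestep are illustrative.</p>

```python
import numpy as np

rng = np.random.default_rng(seed=4)

# Illustrative schedule and timestep (assumptions for this check)
betas = np.linspace(1e-4, 0.02, 1000)
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)
t = 700
beta_t, alpha_t, alpha_bar_t = betas[t - 1], alphas[t - 1], alpha_bars[t - 1]
sigma2_t = beta_t   # sigma_t^2 = beta_t (assumption)

x0 = rng.standard_normal(784)          # a flattened 28 x 28 "image"
eps = rng.standard_normal(784)         # the true noise
eps_theta = rng.standard_normal(784)   # stand-in for the network's noise prediction
xt = np.sqrt(alpha_bar_t) * x0 + np.sqrt(1 - alpha_bar_t) * eps

def mean_from_noise(noise):
    """(1 / sqrt(alpha_t)) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * noise)."""
    return (xt - beta_t / np.sqrt(1 - alpha_bar_t) * noise) / np.sqrt(alpha_t)

# Left-hand side: squared distance between the two means, scaled by 1 / (2 sigma_t^2)
lhs = np.sum((mean_from_noise(eps_theta) - mean_from_noise(eps)) ** 2) / (2 * sigma2_t)

# Right-hand side: weighted squared distance between the true and predicted noise
weight = beta_t ** 2 / (2 * sigma2_t * alpha_t * (1 - alpha_bar_t))
rhs = weight * np.sum((eps - eps_theta) ** 2)

print(f"lhs = {lhs:.6f}, rhs = {rhs:.6f}")
```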

<h3 id="implementation-of-a-diffusion-model-for-generating-mnist-digits">Implementation of a diffusion model for generating MNIST digits</h3>

<p>In this section, we will walk through all of the code used to implement a diffusion model. The full code can be run on <a href="https://colab.research.google.com/drive/14ue6jpN7yEM9c11qERpXra8G88ss__99?usp=sharing">Google Colab</a>. We will start with importing the required packages:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
import torch
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
</code></pre></div></div>

<p>Next, we implement the U-Net neural network, which implements the noise model – that is, the model used to predict the noise in an image that has undergone diffusion. To implement the U-Net, we define three subclasses: a <code class="language-plaintext highlighter-rouge">UNetDownBlock</code> class, which represents a set of layers on the downward portion of the U-Net, a <code class="language-plaintext highlighter-rouge">UNetUpBlock</code> class, which represents a set of layers on the upward portion of the U-Net, and a <code class="language-plaintext highlighter-rouge">UNet</code> class, which represents the full neural network:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>class UNetDownBlock(nn.Module):
  def __init__(
      self,
      in_channels,
      out_channels,
      kernel_size=3,
      pad_maxpool=0,
      normgroups=4
    ):
    super(UNetDownBlock, self).__init__()
    # "Same" padding so each convolution preserves height and width
    # (assumes an odd kernel_size, such as the default of 3)
    padding = kernel_size // 2

    self.conv1 = nn.Conv2d(
      in_channels,
      out_channels,
      kernel_size=kernel_size,
      padding=padding
    )
    self.groupnorm1 = nn.GroupNorm(
      normgroups,
      out_channels
    )
    self.relu1 = nn.ReLU()

    self.conv2 = nn.Conv2d(
      out_channels,
      out_channels,
      kernel_size=kernel_size,
      padding=padding
    )
    self.groupnorm2 = nn.GroupNorm(
      normgroups,
      out_channels
    )
    self.relu2 = nn.ReLU()

    self.conv3 = nn.Conv2d(
      out_channels,
      out_channels,
      kernel_size=kernel_size,
      padding=padding
    )
    self.groupnorm3 = nn.GroupNorm(
      normgroups,
      out_channels
    )
    self.relu3 = nn.ReLU()
    self.maxpool = nn.MaxPool2d(
      2, padding=pad_maxpool
    )

  def forward(self, x):
    # First convolution
    x = self.conv1(x)
    x = self.groupnorm1(x)
    x_for_skip = self.relu1(x)

    # Second convolution
    x = self.conv2(x_for_skip)
    x = self.groupnorm2(x)
    x = self.relu2(x)

    # Third convolution (the ReLU is applied after the skip connection)
    x = self.conv3(x)
    x = self.groupnorm3(x)

    # Skip connection
    x = x + x_for_skip
    x = self.relu3(x)

    x = self.maxpool(x)
    return x


class UNetUpBlock(nn.Module):
  def __init__(self, in_channels, out_channels, time_dim, normgroups=4):
    super(UNetUpBlock, self).__init__()
    self.upsample =  nn.Upsample(scale_factor=2, mode='nearest')

    # Convolution 1
    self.conv1 = nn.Conv2d(
        in_channels, out_channels, kernel_size=3, padding=1
    )
    self.groupnorm1 = nn.GroupNorm(normgroups, out_channels)
    self.relu1 = nn.ReLU()

    # Convolution 2
    self.conv2 = nn.Conv2d(
        out_channels, out_channels, kernel_size=3, padding=1
    )
    self.groupnorm2 = nn.GroupNorm(normgroups, out_channels)
    self.relu2 = nn.ReLU()

    # Convolution 3
    self.conv3 = nn.Conv2d(
        out_channels, out_channels, kernel_size=3, padding=1
    )
    self.groupnorm3 = nn.GroupNorm(normgroups, out_channels)
    self.relu3 = nn.ReLU()

    # Parameters to scale and shift the time embedding
    self.time_mlp = nn.Linear(time_dim, time_dim)
    self.time_relu = nn.ReLU()


  def forward(self, x, x_down, t_embed):
    x_up = self.upsample(x)
    x = torch.cat([x_down, x_up], dim=1)

    # Cut embedding to be the size of the current channels
    t_embed = t_embed[:,:x.shape[1]]

    # Enable the neural network to modify the time-embedding
    # as it needs to
    t_embed = self.time_mlp(t_embed)
    t_embed = self.time_relu(t_embed)
    t_embed = t_embed[:,:,None,None].expand(x.shape)

    # Add time-embedding to input.
    x = x + t_embed

    # Convolution 1
    x = self.conv1(x)
    x = self.groupnorm1(x)
    x_for_skip = self.relu1(x)

    # Convolution 2
    x = self.conv2(x_for_skip)
    x = self.groupnorm2(x)
    x = self.relu2(x)

    # Convolution 3
    x = self.conv3(x)
    x = self.groupnorm3(x)

    # Skip connection
    x = x + x_for_skip
    x = self.relu3(x)

    return x


class UNet(nn.Module):
  def __init__(self):
    super(UNet, self).__init__()

    # Down blocks
    self.down1 = UNetDownBlock(1, 4, normgroups=1)
    self.down2 = UNetDownBlock(4, 10, normgroups=1)
    self.down3 = UNetDownBlock(10, 20, normgroups=2)
    self.down4 = UNetDownBlock(20, 40, normgroups=4)

    # Convolutional layer at the bottom of the U-Net
    self.bottom_conv = nn.Conv2d(
        40, 40, kernel_size=3, padding=1
    )
    self.bottom_groupnorm = nn.GroupNorm(4, 40)
    self.bottom_relu = nn.ReLU()

    # Up blocks
    self.up1 = UNetUpBlock(60, 20, 60, normgroups=2) # down4 channels + down3 channels
    self.up2 = UNetUpBlock(30, 10, 30, normgroups=1) # down2 channels + up1 channels
    self.up3 = UNetUpBlock(14, 5, 14, normgroups=1) # down1 channels + up2 channels
    self.up4 = UNetUpBlock(6, 5, 6, normgroups=1) # input channels + up3 channels

    # Final convolution to produce output. This layer injects negative
    # values into the output.
    self.final_conv = nn.Conv2d(
        5, 1, kernel_size=3, padding=1
    )

  def forward(self, x, t_emb):
    """
    Parameters
    ----------
      x: Input tensor representing an image
      t_emb: The time-embedding vectors for the current timesteps (one row
        per item in the batch)
    """
    # Pad the input so that it is 32x32. This enables the four down-blocks to
    # halve the resolution to 16x16, 8x8, 4x4, and finally 2x2 at the bottom of the "U"
    x = F.pad(x, (2,2,2,2), 'constant', 0)

    # Down-blocks of the U-Net compress the image down to a smaller
    # representation
    x_d1 = self.down1(x)
    x_d2 = self.down2(x_d1)
    x_d3 = self.down3(x_d2)
    x_d4 = self.down4(x_d3)

    # Bottom layer performs a final transformation on the compressed
    # representation before re-inflation
    x_bottom = self.bottom_conv(x_d4)
    x_bottom = self.bottom_groupnorm(x_bottom)
    x_bottom = self.bottom_relu(x_bottom)

    # Up-blocks re-inflate the compressed representation back to the original
    # image size while taking as input various representations produced in the
    # down-sampling steps
    x_u1 = self.up1(x_bottom, x_d3, t_emb)
    x_u2 = self.up2(x_u1, x_d2, t_emb)
    x_u3 = self.up3(x_u2, x_d1, t_emb)
    x_u4 = self.up4(x_u3, x, t_emb)

    # Final convolutional layer. Introduces negative values.
    x_u4 = self.final_conv(x_u4)

    # Remove initial pads to produce a 28x28 MNIST digit
    x_u4 = x_u4[:,:,2:-2,2:-2]

    return x_u4
</code></pre></div></div>
<p>Next, we will implement a function that generates the timestep embeddings. Below is an adaptation of the time embedding function by Ho, Jain, and Abbeel from their GitHub repository,
<a href="https://github.com/hojonathanho/diffusion">https://github.com/hojonathanho/diffusion</a>, translated from TensorFlow to PyTorch:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>def get_timestep_embedding(timesteps, embedding_dim):
  """
  Translated from TensorFlow to PyTorch from the original diffusion implementation
  by Ho et al. in https://github.com/hojonathanho/diffusion
  """
  assert len(timesteps.shape) == 1  # and timesteps.dtype == torch.int32

  half_dim = embedding_dim // 2
  emb = np.log(10000) / (half_dim - 1)
  emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
  emb = timesteps[:, None].to(torch.float32) * emb[None, :]
  emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
  if embedding_dim % 2 == 1:  # zero pad
    emb = torch.nn.functional.pad(emb, (0, 1))
  assert emb.shape == (timesteps.shape[0], embedding_dim)
  return emb
</code></pre></div></div>

<p>The function accepts a one-dimensional tensor of timesteps and an integer embedding dimension. Similar to Ho, Jain, and Abbeel, I used 1,000 timesteps (i.e., $T = 1000$), as we will see in the code that follows. The time embedding is added to the input of each up-block of the U-Net. The largest such input belongs to the first up-block, which concatenates the 40 channels from the bottom of the U-Net with the 20 channels from the third down-block, for a total of 60 channels; thus, the embedding dimension is 60, and each up-block keeps only as many leading dimensions of the embedding as it has input channels. The function returns a matrix with one row per timestep passed in and one column per embedding dimension.</p>
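<p>To get a feel for what these embeddings encode, the sketch below re-implements the same sin/cos formula in NumPy. The helper <code class="language-plaintext highlighter-rouge">sinusoidal_embedding</code> is hypothetical and written for inspection only; the model uses the PyTorch function above. Because each sine column is paired with a cosine column at the same frequency $\omega_i$, the inner product of the embeddings of timesteps $t$ and $s$ is $\sum_i \cos(\omega_i (t - s))$, which depends only on the gap $t - s$.</p>

```python
import numpy as np

def sinusoidal_embedding(timesteps, embedding_dim):
    """Hypothetical NumPy helper mirroring get_timestep_embedding (even embedding_dim only).

    Returns the embedding matrix and the per-column frequencies.
    """
    half_dim = embedding_dim // 2
    freqs = np.exp(-np.log(10000) * np.arange(half_dim) / (half_dim - 1))
    angles = np.asarray(timesteps, dtype=np.float64)[:, None] * freqs[None, :]
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1), freqs

emb, freqs = sinusoidal_embedding(np.arange(1000), 60)
print(emb.shape)  # (1000, 60): one row per timestep

# Each sin column pairs with a cos column at the same frequency, so
# emb[t] @ emb[s] = sum_i cos(freqs[i] * (t - s)) depends only on t - s
gap_10_a = emb[100] @ emb[110]
gap_10_b = emb[500] @ emb[510]
predicted = np.sum(np.cos(freqs * 10))
print(gap_10_a, gap_10_b, predicted)
```

<p>The frequencies range from 1 down to roughly $10^{-4}$, so the high-frequency columns distinguish neighboring timesteps while the low-frequency columns vary slowly across the whole chain.</p>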

<p>Next, we will write a function that will produce a linear variance schedule. Given a minimum variance, maximum variance, and number of timesteps, it will create a linear interpolation between the max and min over the given number of timesteps:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>def linear_variance_schedule(min: float, max: float, T: int):
  """
  min: minimum value for beta
  max: maximum value for beta
  T: number of timesteps
  """
  betas = torch.arange(0, T) / T
  betas *= max - min
  betas += min
  return betas
</code></pre></div></div>
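<p>One small detail: because <code class="language-plaintext highlighter-rouge">torch.arange(0, T) / T</code> runs from 0 to $(T-1)/T$, the last value of this schedule falls slightly below <code class="language-plaintext highlighter-rouge">max</code> (about 0.01998 rather than 0.02 for the settings used below). The NumPy sketch below reproduces the same arithmetic and shows an alternative, <code class="language-plaintext highlighter-rouge">np.linspace</code>, which includes both endpoints; <code class="language-plaintext highlighter-rouge">torch.linspace(min, max, T)</code> behaves the same way. With $T = 1000$ the difference is small, so which variant to use is a judgment call.</p>

```python
import numpy as np

T, beta_min, beta_max = 1000, 1e-4, 0.02

# Same arithmetic as linear_variance_schedule: i / T for i = 0, ..., T - 1
betas_arange = np.arange(0, T) / T * (beta_max - beta_min) + beta_min

# Alternative: evenly spaced values that include both endpoints
betas_linspace = np.linspace(beta_min, beta_max, T)

print(betas_arange[0], betas_arange[-1])      # last value is about 0.01998, just short of beta_max
print(betas_linspace[0], betas_linspace[-1])  # last value equals beta_max
```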

<p>Now that we have defined our U-Net model and the functions for generating timestep embeddings and the variance schedule, let’s construct and train the model. We will start by setting the parameters for the training process. We train the model for 300 epochs using a minibatch size of 128 and a learning rate of $10^{-4}$. We use a linear variance schedule spanning from a minimum variance of $10^{-4}$ to a maximum variance of 0.02, as per <a href="https://github.com/cloneofsimo/minDiffusion">https://github.com/cloneofsimo/minDiffusion</a>. The variables storing these parameters are shown below:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Parameters
EPOCHS = 300
T = 1000
LEARNING_RATE = 1e-4
BATCH_SIZE = 128
MIN_VARIANCE = 1e-4
MAX_VARIANCE = 0.02
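DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
device = DEVICE  # Lowercase alias; the training code refers to `device`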
</code></pre></div></div>

<p>Next, let’s load the data. We will use PyTorch’s built-in functionality for loading the MNIST digits data. Note that this implementation <em>rescales</em> the pixel values to span from -1 to 1 (via the <code class="language-plaintext highlighter-rouge">transforms.Normalize((0.5), (0.5))</code> transformation passed to the <code class="language-plaintext highlighter-rouge">MNIST</code> dataset). That is, <code class="language-plaintext highlighter-rouge">transforms.ToTensor()</code> produces pixel values spanning from 0 to 1, and the normalization step maps each value $x$ to $(x - 0.5)/0.5$, so that the range of values is centered at zero. This follows the implementation provided by <a href="https://github.com/cloneofsimo/minDiffusion">https://github.com/cloneofsimo/minDiffusion</a>.</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Load dataset
dataset = MNIST(
  "./data",
  train=True,
  download=True,
  transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5), (0.5)) # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1]
  ])
)
dataloader = DataLoader(
  dataset,
  batch_size=BATCH_SIZE,
  shuffle=True,
  num_workers=1
)
</code></pre></div></div>

<p>Finally, let’s put this all together and train a model. The code below instantiates the variance schedule, time embeddings, and U-Net model and then implements the training loop. The code is heavily commented for pedagogical purposes:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Compute variance schedule
betas = linear_variance_schedule(MIN_VARIANCE, MAX_VARIANCE, T).to(device)

# Compute constants based on variance schedule
alphas = 1 - betas
onemalphas = 1 - alphas
alpha_bar = torch.exp(torch.cumsum(torch.log(alphas), dim=0))
sqrt_alphabar = torch.sqrt(alpha_bar)
onemalphabar = 1-alpha_bar
sqrt_1malphabar = torch.sqrt(1-alpha_bar)

# Instantiate the noise model, loss function, and optimizer 
noise_model = UNet().to(device)
optimizer = optim.Adam(noise_model.parameters(), lr=LEARNING_RATE)
mse_loss = nn.MSELoss().to(device)

# Generate timestep embeddings. Note, the embedding dimension is hardcoded
# and based on the largest number of input channels among the U-Net's
# up-blocks (60 channels, in the first up-block)
time_embeddings = get_timestep_embedding(
  torch.arange(0,T),
  embedding_dim=60 
).to(device)

# The training loop
epoch_losses = []
for epoch in range(EPOCHS):
  loss_sum = 0
  n_batchs = 0
  for b_i, (X_batch, _) in enumerate(dataloader):
    n_batchs += 1

    # Move batch to device
    X_batch = X_batch.to(device)

    # Sample noise for each pixel and image in this batch
    # B x 1 x M x N tensor where B is minibatch size, M is number
    # of rows in each image and N is number of columns in
    # each image (the 1 is the single grayscale channel)
    eps = torch.randn_like(X_batch).to(device)

    # Get a random timepoint for each item in this batch
    # Vector of length B
    ts = torch.randint(
        1, T+1, size=(X_batch.shape[0],)
    ).to(device)

    # Grab the time-embeddings for each of these sampled timesteps
    # B x D matrix where B is minibatch size and D is time embedding
    # dimension
    t_embs = time_embeddings[ts-1].to(device)

    # Compute X_batch after adding noise via the diffusion process for each of
    # the items in the batch (at the sampled per-item timepoints, `ts`)
    # B x 1 x M x N tensor
    sqrt_alphabar_ts = sqrt_alphabar[ts-1]
    sqrt_1malphabar_ts = sqrt_1malphabar[ts-1]
    X_t = sqrt_alphabar_ts[:, None, None, None] * X_batch \
      +  sqrt_1malphabar_ts[:, None, None, None] * eps

    # Predict the noise from our sample using the UNet
    # B x 1 x M x N tensor
    pred_eps = noise_model(X_t, t_embs)

    # Compute the loss between the real noise and predicted noise
    loss = mse_loss(eps, pred_eps)
    loss_sum += float(loss)

    # Update the weights in the U-Net via a step of gradient descent
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  print(f"Epoch: {epoch}. Mean loss: {loss_sum/n_batchs}")
  epoch_losses.append(loss_sum/n_batchs)
</code></pre></div></div>

<p>After this process finishes (it took a couple of hours to train in Google Colab running on an NVIDIA T4 GPU), we will have a trained model that we can use to generate new MNIST digits. To generate a new MNIST digit, we first sample white noise and then run the reverse diffusion process by iteratively applying our trained model. A function for generating images in this manner is shown below:</p>

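<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Disable gradient tracking: sampling does not need gradients, and tracking
# them across ~1,000 steps would keep the whole computation graph in memory
@torch.no_grad()
def sample_from_model(T=999, show_img_mod=None, cmap='viridis'):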
  # Initialize the image to white noise
  X_t = torch.randn(1, 1, 28, 28).to(DEVICE)

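  # This samples according to Algorithm 2, using exactly the same logic. Note
  # that t is a 0-based index into the schedule arrays, so index t corresponds
  # to timestep t + 1 in the algorithm's notation.
  for t in range(T, -1, -1):

    # Sample noise (no noise is added at the final step, i.e., timestep 1)
    if t &gt; 0:
      Z = torch.randn(1, 1, 28, 28).to(DEVICE)
    else:
      Z = torch.zeros(1, 1, 28, 28).to(DEVICE)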

    # Get current time embedding
    t_emb = time_embeddings[t][None,:]

    # Predict the noise from the current image
    pred_eps = noise_model(X_t, t_emb)

    # Compute constants
    one_over_sqrt_alpha_t = 1 / torch.sqrt(alphas[t])
    pred_noise_scale = betas[t] / sqrt_1malphabar[t]
    sqrt_beta_t = torch.sqrt(betas[t])

    # Generate next image in the Markov chain
    X_t = (one_over_sqrt_alpha_t * (X_t - (pred_eps * pred_noise_scale))) \
      + (sqrt_beta_t * Z)

    # Show current image
    if show_img_mod is not None:
      if t % show_img_mod == 0:
        print(f"t = {t}")
        plt.imshow(
          (X_t.detach().cpu().numpy().squeeze() + 1.) / 2.,
          cmap=cmap
        )
        plt.xticks([])
        plt.yticks([])
        plt.show()
    if t ==0:
      print(f"t = {t}")
      plt.imshow(
        (X_t.detach().cpu().numpy().squeeze() + 1.) / 2.,
        cmap=cmap
      )
      plt.xticks([])
      plt.yticks([])
      plt.show()

  return X_t
</code></pre></div></div>]]></content><author><name>Matthew N. Bernstein</name></author><category term="tutorial" /><category term="deep learning" /><category term="machine learning" /><category term="probabilistic models" /><summary type="html"><![CDATA[Diffusion models are a family of state-of-the-art probabilistic generative models that have achieved ground breaking results in a number of fields ranging from image generation to protein structure design. In Part 1 of this two-part series, I will walk through the denoising diffusion probabilistic model (DDPM) as presented by Ho, Jain, and Abbeel (2020). Specifically, we will walk through the model definition, the derivation of the objective function, and the training and sampling algorithms. We will conclude by walking through an implementation of a simple diffusion model in PyTorch and apply it to the MNIST dataset of hand-written digits.]]></summary></entry><entry><title type="html">Assessing the utility of data visualizations based on dimensionality reduction</title><link href="https://mbernste.github.io/posts/dim_reduc/" rel="alternate" type="text/html" title="Assessing the utility of data visualizations based on dimensionality reduction" /><published>2024-03-02T00:00:00-08:00</published><updated>2024-03-02T00:00:00-08:00</updated><id>https://mbernste.github.io/posts/dim_reduc</id><content type="html" xml:base="https://mbernste.github.io/posts/dim_reduc/"><![CDATA[<p><em>We human beings use our vision as our chief sense for understanding the world, and thus when we are confronted with data, we try to understand that data through visualization. Dimensionality reduction methods, such as PCA, t-SNE, and UMAP, are approaches designed to enable the visualization of high-dimensional data. Unfortunately, because these methods inevitably distort aspects of the data, these methods are receiving new scrutiny. 
In this post, I propose that dimensionality reduction requires a “probabilistic” rather than a “deterministic” framework of interpretation, in which any conclusion drawn from a dimensionality reduction plot has some probability of not actually being true of the data. This does not make these plots useless; rather, to evaluate their utility, I will argue that empirical user studies can shed light on whether these methods provide more benefit or more harm in practice.</em></p>

<h2 id="introduction">Introduction</h2>

<p>The advancement of technology has brought with it the ability to generate ever larger and more complex collections of data. This is especially true in biomedical research, where new technologies can produce thousands, or even millions, of biomolecular measurements at a time. Because we human beings use our vision as our chief sense for understanding the world, when we are confronted with data, we try to understand that data through visualization. Moreover, because we evolved in a three-dimensional world, we can only ever visualize up to three dimensions of an object at a time. This limitation poses a fundamental problem for high-dimensional data: such data cannot be visualized in their totality at once without loss of information. But this does not mean we have not tried! The field of <a href="https://en.wikipedia.org/wiki/Dimensionality_reduction#:~:text=Dimensionality%20reduction%2C%20or%20dimension%20reduction,close%20to%20its%20intrinsic%20dimension.">dimensionality reduction</a> studies and develops algorithms that map high-dimensional data to two or three dimensions, where we can visualize them with minimal loss of information. For example, the classical <a href="https://en.wikipedia.org/wiki/Principal_component_analysis">principal components analysis (PCA)</a> uses a linear mapping to project data down to a space that preserves as much variance as possible. More recently, the <a href="https://www.jmlr.org/papers/volume9/vandermaaten08a/vandermaaten08a.pdf">t-SNE</a> and <a href="https://arxiv.org/pdf/1802.03426.pdf">UMAP</a> algorithms use <a href="https://en.wikipedia.org/wiki/Nonlinear_dimensionality_reduction">nonlinear mappings</a> that attempt to preserve the “topology” of the data; that is, they attempt to preserve neighborhoods of nearby data points while preventing dense regions of data from overlapping in the output figure. 
An example of <a href="https://en.wikipedia.org/wiki/Single-cell_sequencing">single-cell RNA-seq</a> data from <a href="https://en.wikipedia.org/wiki/Peripheral_blood_mononuclear_cell">peripheral blood mononuclear cells (PBMCs)</a>, visualized with PCA, t-SNE, and UMAP, is shown below (Data were downloaded via <a href="https://doi.org/10.1186/s13059-017-1382-0">Scanpy’s</a> <a href="https://scanpy.readthedocs.io/en/latest/generated/scanpy.datasets.pbmc3k.html#scanpy.datasets.pbmc3k">pbmc3k function</a>. Code to reproduce this figure can be found on <a href="https://colab.research.google.com/drive/1g4bt9S0aE6qu8BZNvVblKNqIZf-UK62L?usp=sharing">Google Colab</a>. Note that t-SNE and UMAP are used here to visualize the top 50 principal components from PCA.):</p>

<p> </p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/dim_reduc_PBMC_3k_example.png" alt="drawing" width="800" /></center>

<p> </p>

<p>Unfortunately, because it is, in general, mathematically impossible to avoid losing information when mapping data from high to low dimensions, these algorithms inevitably lose some aspect of the data, either by distortion or omission, when plotting it in lower dimensions. This limitation makes the figures generated by these methods easy to misinterpret. Because of this, dimensionality reduction algorithms, especially t-SNE and UMAP, are facing <a href="https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1011288">new scrutiny</a> from those who argue that nonlinear dimensionality reduction algorithms distort the data so heavily that their output is at best useless and at worst harmful. On the other hand, proponents of these methods argue that although distortion is inevitable, these methods can and do reveal aspects of the data’s structure that would be difficult to convey by other means.</p>

<p>In this blog post, I will attempt to provide my views on the matter, which lie somewhere between those held by the critics and the proponents. I will start with a review of dimensionality reduction and describe how it inevitably entails a loss of information. I will then argue that dimensionality reduction methods require a different mentality to use correctly than traditional data visualizations (i.e., those that do not compress high-dimensional data into a few dimensions). As a brief preview, I will argue that dimensionality reduction requires a “probabilistic” framework of interpretation rather than a “deterministic” one: any conclusion one draws from a dimensionality reduction plot has some probability of not actually being true of the data. This does not make these plots useless! Rather, I will argue that empirical <a href="https://en.wikipedia.org/wiki/User_research">user studies</a> are required to evaluate their utility. That is, we must empirically assess whether the conclusions practitioners draw from these figures are more often true than not and, when they are not true, how consequential the errors are.</p>

<p>For much of this blog, I will use data generated by single-cell <a href="https://mbernste.github.io/posts/rna_seq_basics/">RNA-sequencing</a> (scRNA-seq) as the primary example of high-dimensional data which I will use in a case study addressing the risks and merits of using dimension reduction for data visualization. As a brief review, scRNA-seq data is structured as a data table/<a href="https://mbernste.github.io/posts/matrices/">matrix</a> where rows represent individual cells and columns represent genes. Each entry of the matrix stores a measurement of the relative abundance of mRNA molecules transcribed from a given gene in a given cell. scRNA-seq studies routinely generate data for <a href="https://doi.org/10.1126/science.abl4896">hundreds of thousands of cells</a> and provide gene expression measurements for tens of thousands of genomic features such as genes or <a href="https://en.wikipedia.org/wiki/Gene_isoform">isoforms</a>. Thus these data are very high-dimensional. For a comprehensive review on RNA-seq, please see <a href="https://mbernste.github.io/posts/rna_seq_basics/">my previous blog post</a>.</p>

<h2 id="dimensionality-reduction-almost-always-entails-a-loss-of-information">Dimensionality reduction almost always entails a loss of information</h2>

<p>In this section we will review the task of dimensionality reduction and describe why it inevitably entails a loss of information. Before moving forward, let’s formalize what we mean by the “dimensionality” of data. For the purposes of our discussion, we will refer to data as being $d$-dimensional if that data can be represented as a set of coordinate <a href="https://mbernste.github.io/posts/vector_spaces/">vectors</a> in $\mathbb{R}^d$. That is, the dataset can be represented as $\boldsymbol{x}_1, \dots, \boldsymbol{x}_n \in \mathbb{R}^d$. Collectively, we can represent the data as a <a href="https://mbernste.github.io/posts/matrices/">matrix</a> $\boldsymbol{X} \in \mathbb{R}^{n \times d}$ where each row represents a data point. This description thus covers all numeric tabular data. (For a more philosophical treatment of the notion of “dimensionality”, see my <a href="https://mbernste.github.io/posts/intrinsic_dimensionality/">previous blog post</a>).</p>

<p>The task of dimensionality reduction is to find a new set of vectors $\boldsymbol{x}'_1, \dots, \boldsymbol{x}'_n$ in a $d'$-dimensional space, where $d' &lt; d$, such that these lower-dimensional points preserve some aspect of the original data’s structure. Said more succinctly, the task is to convert the high-dimensional data $\boldsymbol{X} \in \mathbb{R}^{n \times d}$ to $\boldsymbol{X}' \in \mathbb{R}^{n \times d'}$ where $d' &lt; d$. This is often cast as an optimization problem of the form:</p>

\[\max_{\boldsymbol{X}' \in \mathbb{R}^{n \times d'}} \text{Similarity}(\boldsymbol{X}, \boldsymbol{X}')\]

<p>where the function $\text{Similarity}(\boldsymbol{X}, \boldsymbol{X}')$ outputs a value that tells us “how well” the pairwise relationships between data points in $\boldsymbol{X}'$ reflect those in $\boldsymbol{X}$. The exact form of $\text{Similarity}(\boldsymbol{X}, \boldsymbol{X}')$ depends on the dimensionality reduction method.</p>
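<p>To make $\text{Similarity}(\boldsymbol{X}, \boldsymbol{X}')$ more concrete, below is a minimal sketch of one possible choice, written in Python with NumPy. It assumes we score an embedding by how well it preserves <em>all</em> pairwise Euclidean distances, using the negative “raw stress” from metric multidimensional scaling. This is an illustrative assumption only: t-SNE and UMAP optimize different, neighborhood-based objectives.</p>

```python
import numpy as np

def pairwise_distances(X):
    """n x n matrix of Euclidean distances between the rows of X."""
    diff = X[:, None, :] - X[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))

def similarity_neg_stress(X, X_low):
    """Illustrative Similarity(X, X'): the negative raw stress.

    Equals 0 when every pairwise distance is preserved exactly and becomes
    more negative as low-dimensional distances drift from the originals.
    (This is the metric-MDS criterion, not the t-SNE or UMAP objective.)
    """
    D_high = pairwise_distances(X)
    D_low = pairwise_distances(X_low)
    iu = np.triu_indices(X.shape[0], k=1)   # count each unordered pair once
    return -np.sum((D_high[iu] - D_low[iu]) ** 2)

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 10))              # n = 100 points in d = 10 dims

score_perfect = similarity_neg_stress(X, X)        # nothing is distorted
score_drop = similarity_neg_stress(X, X[:, :2])    # keep only 2 coordinates
print(score_perfect == 0, score_drop < 0)          # prints: True True
```

<p>Maximizing this score over all $\boldsymbol{X}' \in \mathbb{R}^{n \times d'}$ is one instance of the optimization problem above; swapping in a different scoring function changes which aspect of the data's structure the embedding tries to keep.</p>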

<p>Note that if $d &gt; 3$ then we cannot easily visualize our data as a scatterplot to see the global structure between datapoints. Thus, to visualize data it is common to set $d’$ to either 2 or 3 thereby mapping each datapoint $\boldsymbol{x}_i$ to a new, 2 or 3 dimensional data point $\boldsymbol{x}’_i$ that can be visualized in a scatterplot.</p>

<p>However, there is a crucial problem with visualizing data in this manner: it is not possible (in general) to compress data down to a lower-dimensional space and preserve all of the relative pairwise distances between data points. As an illustrative example, consider three points in 2-dimensional space that are equidistant from one another. If we were to compress these down into one dimension, then <em>inevitably</em> the three pairwise distances would no longer all be equal. This is shown below:</p>

<p><br /></p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/dim_reduc_toy_example_equidistant.png" alt="drawing" width="600" /></center>

<p><br /></p>

<p>Notice how the distance between the blue and green data points is equal neither to the distance between the blue and red data points nor to the distance between the red and green data points, whereas all three distances were equal in the original two-dimensional space. Thus, the distances between this set of three data points have been distorted!</p>
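<p>The three-point example can also be checked numerically. Below is a small sketch in Python with NumPy, assuming the three points are the vertices of a unit equilateral triangle and that “compressing to one dimension” means projecting onto a randomly chosen direction in the plane (one simple choice among many). Whatever the direction, the points end up on a line, where the distance between the two outermost points equals the sum of the other two distances, so the three distances cannot all stay equal.</p>

```python
import numpy as np
from itertools import combinations

def pair_distances(P):
    """Euclidean distance for every unordered pair of rows of P."""
    return [float(np.linalg.norm(P[i] - P[j]))
            for i, j in combinations(range(len(P)), 2)]

# Three equidistant points in 2-D: the vertices of a unit equilateral triangle
points_2d = np.array([[0.0, 0.0], [1.0, 0.0], [0.5, np.sqrt(3) / 2]])
print([round(d, 3) for d in pair_distances(points_2d)])   # [1.0, 1.0, 1.0]

rng = np.random.default_rng(0)
for _ in range(3):
    w = rng.normal(size=2)
    w /= np.linalg.norm(w)                 # random unit direction in the plane
    points_1d = (points_2d @ w)[:, None]   # 3 x 1 matrix of 1-D coordinates
    d = sorted(pair_distances(points_1d))
    # On a line, the outermost pair is separated by the sum of the two gaps
    print([round(x, 3) for x in d], np.isclose(d[2], d[0] + d[1]))
```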

<p>This problem presents itself in the more familiar territory of creating maps. It is mathematically impossible to project the globe onto a 2D surface without distorting distances and/or shapes of continents. Different <a href="https://en.wikipedia.org/wiki/Map_projection">map projections</a> have been devised to reflect certain aspects of the 3D configuration of the Earth’s surface (e.g., shapes of continents), but that comes at the expense of some other aspect of that 3D configuration (e.g., distances between continents). A few examples of map projections are illustrated below (These images were created by <a href="https://en.wikipedia.org/wiki/List_of_map_projections">Daniel R. Strebe and were pulled from Wikipedia</a>):</p>

<p> </p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/globe_projections_by_Daniel_R_Strebe.png" alt="drawing" width="600" /></center>

<p> </p>

<p>In fact, you can demonstrate this problem for yourself in your kitchen! Just peel an orange and attempt to lay the pieces of the orange peel on the table to reconstruct the original surface of the orange… it’s impossible!</p>

<p>It is almost always the case that some information is lost following dimensionality reduction. To visualize high dimensional data in two or three dimensions, one must either throw away dimensions and plot the remaining two/three or devise a more sophisticated approach that maps high dimensional data points to low dimensional data points to preserve <em>some aspect</em> of the high dimensional data’s structure (with respect to the Euclidean distances between data points) at the expense of other aspects. Exactly which aspect of the data’s structure you wish to preserve depends on how you define your $\text{Similarity}(\boldsymbol{X}, \boldsymbol{X}’)$ function described above! (Note, there are scenarios where it is possible to preserve pairwise distances between data points following dimensionality reduction, but those scenarios tend to be uninteresting. An uninteresting example would be a case in which all of your data points lie in a 2-dimensional plane embedded in a higher-dimensional space).</p>
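<p>The “uninteresting” exception mentioned above is easy to verify. The sketch below (Python with NumPy, using a plain SVD-based PCA as the dimensionality reduction method, an assumption made for simplicity) generates two synthetic datasets in $\mathbb{R}^{10}$: one whose points lie exactly in a 2-dimensional plane and one whose points fill all ten dimensions. Projecting each onto its top two principal components preserves every pairwise distance (up to floating-point error) only in the first case.</p>

```python
import numpy as np

def pairwise_distances(X):
    """n x n matrix of Euclidean distances between the rows of X."""
    diff = X[:, None, :] - X[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))

def pca_2d(X):
    """Project rows of X onto their top-2 principal components."""
    Xc = X - X.mean(axis=0)
    # Right singular vectors of the centered data are the principal axes
    _, _, Vt = np.linalg.svd(Xc, full_matrices=False)
    return Xc @ Vt[:2].T

rng = np.random.default_rng(0)
n, d = 200, 10

# Case 1: points lying exactly in a 2-D plane embedded in R^10
coords = rng.normal(size=(n, 2))
basis, _ = np.linalg.qr(rng.normal(size=(d, 2)))   # orthonormal d x 2 basis
X_plane = coords @ basis.T

# Case 2: generic points that fill all 10 dimensions
X_full = rng.normal(size=(n, d))

for name, X in [("plane", X_plane), ("full", X_full)]:
    err = np.abs(pairwise_distances(X) - pairwise_distances(pca_2d(X))).max()
    print(f"{name}: max pairwise-distance error after PCA to 2-D = {err:.3g}")
```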

<h2 id="a-probabistic-mindset-for-thinking-about-inferences-drawn-from-data-visualizations">A probabilistic mindset for thinking about inferences drawn from data visualizations</h2>

<p>Now, I am going to build a framework for thinking about data visualization that draws what I view as an important distinction between “traditional” data visualizations, such as heatmaps or bar charts, and data visualizations based on dimensionality reduction, such as UMAP scatterplots. As a very brief preview, I will argue that traditional data visualizations enable one to make claims about the underlying data with 100% certainty, whereas dimensionality reduction visualizations do not provide the same certainty.</p>

<p>Before we get going, I am going to make a statement that may appear obvious, but is important for laying the foundation for my views on dimensionality reduction: <strong>The primary outputs of a data visualization are a set of statements about the data</strong>. Take the following table and associated bar chart as an example to describe what I mean by this:</p>

<p><br /></p>
<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/dim_reduc_traditional_data_viz.png" alt="drawing" width="600" /></center>

<p><br /></p>

<p>One statement about the data being conveyed by this plot is that Label A is associated with the value 9. Another statement about the data is that Label A is associated with a larger value than Label B. Below is a set of example statements about the data conveyed by this plot:</p>

<p><br /></p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/dim_reduc_statements_on_barchart.png" alt="drawing" width="550" /></center>

<p><br /></p>

<p>Note that I am describing statements <em>about the data</em>, not statements about the <em>world</em>. That is, the bar chart enables one to make claims about the literal values stored in the data table from which this figure was generated. (The task of drawing conclusions about the world based on data and/or data visualizations is the task of science and statistics). Data visualizations make facts about the data easier to understand than the raw data itself (i.e., large tables of numbers) because we human beings are visual animals.</p>

<p>For traditional data visualizations, statements about the data described by the visualization are 100% certain to be true. For example, when one looks at the bar chart above, one <em>knows</em> that Label A is associated with a larger value than Label B (unless, of course, there was an error in generating the visualization, but we will assume no errors were made here). That is because in traditional data visualizations, there is either a one-to-one or a linear mapping between some aspect of the data and some visual or spatial element of the visualization. In the bar chart above, the mappings are as follows:</p>

<ul>
  <li>Magnitude of Value $\rightarrow$ Height of bar (linear mapping)</li>
  <li>Distinct Label $\rightarrow$ Distinct bar (one-to-one mapping)</li>
</ul>

<p>Because these mappings are <a href="https://en.wikipedia.org/wiki/Inverse_function">invertible</a> and <a href="https://en.wikipedia.org/wiki/Deterministic_system">deterministic</a>, we can draw conclusions about the raw data with 100% certainty based on the visual elements in the figure.</p>
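<p>To make this contrast concrete, here is a toy sketch in Python. The chart scale (20 pixels per unit) and the table values are hypothetical, except that Label A's value of 9 comes from the example above. A linear, invertible value-to-height mapping lets us read the table back exactly from the chart, whereas even the simplest dimensionality reduction, dropping a coordinate, sends different data points to the same plotted position, so the data cannot be recovered from the plot.</p>

```python
import numpy as np

# Bar chart: value -> bar height is a linear (hence invertible) mapping
PIXELS_PER_UNIT = 20.0                  # hypothetical chart scale

def value_to_height(value):
    return PIXELS_PER_UNIT * value

def height_to_value(height):
    return height / PIXELS_PER_UNIT     # exact inverse of value_to_height

table = {"A": 9.0, "B": 4.0}            # hypothetical data table
heights = {label: value_to_height(v) for label, v in table.items()}
recovered = {label: height_to_value(h) for label, h in heights.items()}
print(recovered == table)               # True: the chart determines the table

# Dimensionality reduction: R^2 -> R^1 by keeping only the first coordinate
P = np.array([[1.0, 0.0]])
x1 = np.array([3.0, 0.0])
x2 = np.array([3.0, 5.0])               # a different data point...
print(P @ x1, P @ x2)                   # ...with the same plotted position
```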

<p>Now, let’s turn our attention to data visualizations produced by dimensionality reduction methods. Let’s take the UMAP plot of the PBMCs shown above, but this time leave the cells uncolored by cell type and pretend we don’t know much of anything about them.</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/dim_reduc_PBMC_uncolored.png" alt="drawing" width="450" /></center>

<p>What are some (possibly incorrect) statements we might make about the data from this figure? Below are a few examples:</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/dim_reduc_statements_on_umap.png" alt="drawing" width="450" /></center>

<p>The problem is that many of the above statements are probably not true! For UMAP in particular, statements regarding distances are <a href="https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1011288">especially unlikely to be true</a>. This is because, as we discussed before, dimensionality reduction methods distort the data. Statements we make about the data from this figure do not have the same guarantee of truth as statements we made from the bar chart!</p>

<p>This brings us to a probabilistic framework for thinking about data visualizations. Any statement about a dataset, $S$, drawn from a traditional data visualization, like the bar chart, has the following property:</p>

\[P(S = \text{True}) = 1\]

<p>On the other hand, this is not the case for dimensionality reduction plots. Rather,</p>

\[P(S = \text{True}) &lt; 1\]

<p>Of course, the probability $P(S = \text{True})$ cannot really be defined in a <a href="https://en.wikipedia.org/wiki/Frequentist_probability">frequentist sense</a>, since $S$ either is or is not true. Rather, it is more of a <a href="https://en.wikipedia.org/wiki/Bayesian_probability">Bayesian probability</a> (i.e., a degree of certainty) that can be informed by looking across many plots that display a feature similar to the one described by $S$ and asking: for what fraction of the datasets underlying those plots is the corresponding statement true?</p>
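<p>One way to put a number on this degree of certainty is by simulation, in the spirit described above: generate many datasets whose true structure is known, plot each one, read the same kind of statement off every plot, and count how often that statement is true of the underlying data. Below is a minimal sketch in Python with NumPy. The data generator (five Gaussian clusters in 50 dimensions), the statement (“cluster 0 looks closer to cluster 1 than to cluster 2”, or the reverse, whichever the plot suggests), and the use of PCA as the plotting method are all illustrative assumptions; the resulting estimate says nothing about other data or about methods such as UMAP.</p>

```python
import numpy as np

def pca_2d(X):
    """Project rows of X onto their top-2 principal components."""
    Xc = X - X.mean(axis=0)
    _, _, Vt = np.linalg.svd(Xc, full_matrices=False)
    return Xc @ Vt[:2].T

def centroid_dists_from_0(Z, labels):
    """Distances from cluster 0's centroid to the centroids of clusters 1 and 2."""
    c0, c1, c2 = (Z[labels == k].mean(axis=0) for k in range(3))
    return np.linalg.norm(c0 - c1), np.linalg.norm(c0 - c2)

def simulate_dataset(rng, n_clusters=5, d=50, n_per_cluster=50):
    """Hypothetical generator: Gaussian clusters with random centers."""
    centers = rng.normal(scale=2.0, size=(n_clusters, d))
    labels = np.repeat(np.arange(n_clusters), n_per_cluster)
    X = centers[labels] + rng.normal(size=(labels.size, d))
    return X, labels

def estimate_p_true(n_datasets=200, seed=0):
    """Fraction of datasets in which the ordering read off the 2-D plot
    ('0 is closer to 1 than to 2', or vice versa) holds in the original data."""
    rng = np.random.default_rng(seed)
    n_true = 0
    for _ in range(n_datasets):
        X, labels = simulate_dataset(rng)
        plot_01, plot_02 = centroid_dists_from_0(pca_2d(X), labels)
        true_01, true_02 = centroid_dists_from_0(X, labels)
        n_true += int((plot_01 < plot_02) == (true_01 < true_02))
    return n_true / n_datasets

print(estimate_p_true())
```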

<p>I concede that this “probabilistic framework” is a bit hand-wavy and not very rigorous, but to me it illustrates an important distinction between plots based on dimensionality reduction and “traditional” data visualizations. Specifically, the statements one might draw about the data from dimensionality reduction plots may be wrong. Said differently, dimensionality reduction plots help users draw <em>uncertain inferences</em> about the structure of the data, whereas there is no uncertainty in a traditional data visualization.</p>

<p>Expert users of dimensionality reduction plots know this, of course, but use them anyway. Why? Because, they argue, certain “classes” of statements are more often true than not, and this makes the methods useful. For example, when it comes to UMAP, it is <em>often</em> true (but not always) that when you see distinct clusters in the scatterplot, those clusters are real characteristics of the data. Thus, one might say that a statement about clusters, $S_{\text{cluster}}$, is associated with a probability, $P(S_{\text{cluster}} = \text{True})$, that is high enough to be useful. On the other hand, a statement about distances between data points (especially long distances), $S_{\text{distance}}$, is associated with a probability, $P(S_{\text{distance}} = \text{True})$, that is far too low to be useful. The problem, then, is to determine which statements are more likely to be true than others for a given dimensionality reduction method.</p>

<h2 id="addressing-the-criticisms-of-dimensionality-reduction-methods">Addressing the criticisms of dimensionality reduction methods</h2>

<p>As I mentioned in the introduction, dimensionality reduction methods are receiving new scrutiny. Given my “probabilistic mindset” for interpreting data visualizations, I will attempt to summarize my understanding of certain criticisms of dimensionality reduction methods, especially non-linear methods like t-SNE and UMAP:</p>

<p><strong>Concern 1: Distortion caused by popular methods like t-SNE and UMAP is too severe for them to be useful:</strong> That is, the distortion of the data is so severe that practically any interesting statement $S$ that one might make from such a visualization has too low a probability of actually being true of the data to be useful for anything.</p>

<p>To this concern, I am undecided and I will address it more thoroughly in the following section. To preview, I believe that the best way to assess dimensionality reduction methods is to employ <a href="https://en.wikipedia.org/wiki/User_research">user studies</a>. Do these plots lead to new insights <em>in practice</em> despite the fact that they distort the data? Do they lead to more good than harm?</p>

<p><strong>Concern 2: We don’t know what inferences can be drawn reliably from these plots:</strong> That is, we do not have a deep enough understanding into the classes of statements that have high or low probability of being true. Without this understanding, we cannot use these plots effectively.</p>

<p>I mostly agree with this statement. The objective function of t-SNE, for example, is mostly built upon heuristics designed to generate a figure with certain properties. For example, the use of a t-distribution in the underlying model is motivated by the fact that it pushes data points apart in the resultant low-dimensional space and avoids overcrowding of data points. In my opinion, this is not really based on sound statistical theory. Similarly, UMAP assumes certain characteristics of the high-dimensional data that, as far as I know, are difficult or impossible to test. For example, UMAP assumes that “<a href="https://umap-learn.readthedocs.io/en/latest/">The data is uniformly distributed on Riemannian manifold</a>”. I’m not sure how, in general, one can know this without a very sophisticated understanding of the underlying data-generating process.</p>

<p>All of that said, there is ongoing research either to develop new dimensionality reduction methods that are easier to interpret or to help users more accurately interpret plots generated by existing methods. For example, the recent method <a href="https://doi.org/10.1038/s41587-019-0336-3">PHATE</a> claims to better preserve continuums of data points in the high-dimensional space. <a href="https://doi.org/10.1038/s41587-020-00801-7">DensMAP</a> claims to better preserve regions of high or low density. <a href="https://doi.org/10.1186/s13059-023-02998-7">Surprisal Components Analysis (SCA)</a> claims to better preserve small clusters. <a href="https://doi.org/10.1038/s41467-024-45891-y">scDEED</a> identifies features in dimensionality reduction plots that are misleading.</p>

<p><strong>Concern 3: Any visualization in which there is uncertainty around what it says about the data should be avoided:</strong> The argument here is that it is too easy to misinterpret and misuse <em>any</em> data visualization technique in which one can make a reasonable statement about the data, $S$, but that $P(S = \text{True}) &lt; 1$.</p>

<p>The concern here is that if a plot does not provide certainty about the data it describes, then it is too easy to fall victim to confirmation bias when interpreting it. I admit I have fallen prey to this myself. In <a href="https://doi.org/10.1016/j.isci.2020.101913">a paper I led</a> presenting a cell type classification algorithm, called <a href="https://doi.org/10.1016/j.xpro.2021.100705">CellO</a>, I made the following statement based on a UMAP plot (referencing Figure 7): “CellO annotated many of these cells as pancreatic A cells (a.k.a. pancreatic alpha cells), which is plausible owing to both their close position to annotated A cells according to UMAP, which is known to preserve some level of global structure in high dimensional data (Becht et al., 2018)…” Granted, this statement is not very strong; nonetheless, I now ask myself whether I saw in that UMAP plot what I wanted to see. Indeed, because these figures may make us more prone to confirmation bias, I am sympathetic to the argument that we should avoid them altogether. At the very least, one should use extreme caution when using them and make sure to confirm any hypotheses generated from these figures using orthogonal techniques. I know I will pay more attention going forward.</p>

<h2 id="user-studies-may-be-the-optimal-way-to-assess-the-utility-of-dimensionality-reduction-plots">User studies may be the optimal way to assess the utility of dimensionality reduction plots</h2>

<p>While recent studies, such as those by <a href="https://doi.org/10.1371/journal.pcbi.1011288">Chari and Pachter (2023)</a> and <a href="https://doi.org/10.1038/s42003-022-03628-x">Huang <em>et al.</em> (2022)</a>, evaluate dimensionality reduction algorithms quantitatively (and are very valuable), I argue that these studies don’t directly address the fundamental question of whether these methods lead to more harm or more benefit. Because <em>statements</em> about data are the primary output of a data visualization, it is those statements that we should be evaluating. That is, even though it is established that dimensionality reduction methods distort the data, the question remains (in my mind) whether the statements that practitioners in the field draw from these plots have a high or low probability of being true. Do these plots lead to new insights <em>in practice</em> despite the fact that they distort the data? What alternative visualizations would provide the same insights with more certainty?</p>

<p>I am not an expert in how to conduct these kinds of studies and I am not sure what the best strategy would be, but I envision something like the following: Gather a group of scientist volunteers in some specific field and present them with a dimensionality reduction plot (e.g., a UMAP plot) for a dataset from a domain that they are unfamiliar with. Next, ask each volunteer to list the statements/hypotheses about the data that they draw from that figure. Finally, evaluate how many of those statements were actually true of the data. What alternative visualization methods would have led the user to the same correct hypotheses, but avoided the incorrect ones? Were there certain categories of statements (e.g., related to clusters) that tended to be true and others (e.g., related to distances) that tended not to be true? Of course, this would be a fairly qualitative study, but perhaps it would shed light on how these plots are being used in the field.</p>

<h2 id="some-thoughts-on-best-practices-for-using-dimensionality-reduction-plots">Some thoughts on best practices for using dimensionality reduction plots</h2>

<p>I propose that if one seeks to visualize their data with dimensionality reduction, they should use multiple methods in parallel. Because any statement, $S$, that one draws from these figures has a probability of not being true, it helps to assess whether other dimensionality reduction methods lead to the same statement. If many different methods all support $S$, then perhaps it is more likely to be true than if only one method supports $S$. That is because, as long as the methods are “orthogonal” to one another (i.e., are grounded in different theory or approach), then it would be quite a coincidence that $S$ is supported by multiple methods, but not actually true. Viewing these plots requires one to have a “probabilistic mindset” that is not needed for traditional data visualizations.</p>
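<p>A minimal sketch of this cross-checking idea, in Python with NumPy, is shown below. It assumes each method’s output is available as an $n \times 2$ array of coordinates (in Scanpy, for example, embeddings computed with <code>sc.tl.umap</code> and <code>sc.tl.tsne</code> are stored in <code>adata.obsm["X_umap"]</code> and <code>adata.obsm["X_tsne"]</code>). To keep the example self-contained, two stand-in embeddings of synthetic data are used instead: PCA and a random linear projection. The statement $S$, “groups 0 and 1 form distinct clusters”, is checked with an ad hoc criterion chosen purely for illustration.</p>

```python
import numpy as np

def pca_2d(X):
    """Project rows of X onto their top-2 principal components."""
    Xc = X - X.mean(axis=0)
    _, _, Vt = np.linalg.svd(Xc, full_matrices=False)
    return Xc @ Vt[:2].T

def random_projection_2d(X, seed=0):
    """A different linear embedding: multiply by a random Gaussian matrix."""
    R = np.random.default_rng(seed).normal(size=(X.shape[1], 2))
    return X @ R

def groups_separated(Z, labels, a=0, b=1):
    """Ad hoc check of S: the distance between the two group centroids
    exceeds the sum of each group's mean distance to its own centroid."""
    ca, cb = Z[labels == a].mean(axis=0), Z[labels == b].mean(axis=0)
    ra = np.linalg.norm(Z[labels == a] - ca, axis=1).mean()
    rb = np.linalg.norm(Z[labels == b] - cb, axis=1).mean()
    return np.linalg.norm(ca - cb) > ra + rb

def support(statement, embeddings):
    """Which embeddings support the statement, and what fraction of them do."""
    votes = {name: bool(statement(Z)) for name, Z in embeddings.items()}
    return votes, sum(votes.values()) / len(votes)

# Synthetic stand-in data: three Gaussian groups in 20 dimensions
rng = np.random.default_rng(0)
labels = np.repeat([0, 1, 2], 100)
X = rng.normal(scale=3.0, size=(3, 20))[labels] + rng.normal(size=(300, 20))

embeddings = {"PCA": pca_2d(X), "random projection": random_projection_2d(X)}
votes, fraction = support(lambda Z: groups_separated(Z, labels), embeddings)
print(votes, fraction)
```

<p>Agreement between these two particular embeddings is only weak evidence, because both are linear projections of the same matrix; the more the methods differ in their underlying approach, the more informative their agreement becomes.</p>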

<p>As an example, let’s look at another single-cell dataset from <a href="https://en.wikipedia.org/wiki/Cellular_differentiation">differentiating</a> <a href="https://en.wikipedia.org/wiki/Myeloid_tissue">myeloid cells</a> published by <a href="https://doi.org/10.1016/j.cell.2015.11.013">Paul <em>et al.</em> (2015)</a>. Below, I visualize these cells using six different dimensionality reduction methods: PCA, t-SNE, UMAP, <a href="https://doi.org/10.1371/journal.pone.0098679">force-directed layout</a> of the <a href="https://en.wikipedia.org/wiki/Nearest_neighbor_graph">k-nearest neighbors graph</a>, PHATE, and Surprisal Components Analysis (SCA) (Data were downloaded via <a href="https://doi.org/10.1186/s13059-017-1382-0">Scanpy’s</a> <a href="https://scanpy.readthedocs.io/en/stable/generated/scanpy.datasets.paul15.html">paul15 function</a>. Code to reproduce this figure can be found on <a href="https://colab.research.google.com/drive/1g4bt9S0aE6qu8BZNvVblKNqIZf-UK62L?usp=sharing">Google Colab</a>. Note that t-SNE, UMAP, and force-directed layout are used here to visualize the top 50 principal components from PCA.):</p>

<center><img src="https://raw.githubusercontent.com/mbernste/mbernste.github.io/master/images/dim_reduc_Paul15_multiple_methods.png" alt="drawing" width="800" /></center>

<p>Note that all of the figures here present a continuum of cells originating at megakaryocyte/erythrocyte progenitors (MEP) and extending outward along two “branches”. Because this feature appears in <em>all</em> of the plots, I think it is a reasonable hypothesis that there is indeed a continuum of cells starting from this cell type in the high-dimensional gene expression space. But of course, this may not be true. In my analysis, t-SNE, UMAP, and force-directed layout all operate on the top 50 principal components from PCA, so they are not perfectly orthogonal. Similarly, UMAP, PHATE, and force-directed layout all operate on a <a href="https://en.wikipedia.org/wiki/Nearest_neighbor_graph">k-nearest neighbors graph</a>. While t-SNE does not explicitly operate on a k-nearest neighbors graph, it centers a Gaussian distribution on each point with a bandwidth chosen to capture a fixed effective number of neighbors (set by the perplexity), so it effectively operates on one as well. Thus, these methods are even more similar to one another than they might first appear.</p>
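<p>One rough way to quantify how “non-orthogonal” two representations are is to compare the neighborhoods they produce. The sketch below (Python with NumPy) computes, for each point, the Jaccard overlap between its $k$ nearest neighbors in two representations and averages over points. The choice of $k = 15$, the synthetic data, and the use of the full 50-dimensional data versus its top-2 PCA projection as the two representations are illustrative assumptions. In practice one could compare, e.g., a UMAP and a t-SNE embedding, keeping in mind that high overlap may reflect a shared starting point (such as the same top 50 principal components) rather than independent confirmation.</p>

```python
import numpy as np

def knn_sets(Z, k=15):
    """Index set of each row's k nearest neighbors (excluding the row itself)."""
    D = np.linalg.norm(Z[:, None, :] - Z[None, :, :], axis=-1)
    np.fill_diagonal(D, np.inf)              # a point is not its own neighbor
    nearest = np.argsort(D, axis=1)[:, :k]
    return [set(row.tolist()) for row in nearest]

def mean_knn_jaccard(Z1, Z2, k=15):
    """Average Jaccard similarity of each point's k-NN sets in Z1 and Z2."""
    pairs = zip(knn_sets(Z1, k), knn_sets(Z2, k))
    return float(np.mean([len(a & b) / len(a | b) for a, b in pairs]))

def pca(X, n_components):
    """Project rows of X onto their top principal components."""
    Xc = X - X.mean(axis=0)
    _, _, Vt = np.linalg.svd(Xc, full_matrices=False)
    return Xc @ Vt[:n_components].T

rng = np.random.default_rng(0)
X = rng.normal(size=(300, 50))               # synthetic stand-in data

print(mean_knn_jaccard(X, X))                # 1.0: identical neighborhoods
print(mean_knn_jaccard(X, pca(X, 2)))        # lower: 2-D projection reshuffles neighbors
```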

<p>In conclusion, no statement can be made with absolute certainty from dimensionality reduction plots. We must be diligent in confirming any hypotheses generated by these methods using alternative, statistically grounded approaches. Lastly, when using these methods, we must remain self-aware enough to avoid the confirmation bias that these methods may promote.</p>

<p><strong>Final note:</strong> Please let me know if I mischaracterized any work cited above.</p>]]></content><author><name>Matthew N. Bernstein</name></author><category term="tutorial" /><category term="data science" /><category term="visualization" /><summary type="html"><![CDATA[We human beings use our vision as our chief sense for understanding the world, and thus when we are confronted with data, we try to understand that data through visualization. Dimensionality reduction methods, such as PCA, t-SNE, and UMAP, are approaches designed to enable the visualization of high-dimensional data. Unfortunately, because these methods inevitably distort aspects of the data, these methods are receiving new scrutiny. In this post, I propose that dimensionality reduction requires a “probabilistic” framework of interpretation rather than a “deterministic” one wherein conclusions one draws from a dimensionality reduction plot have some probability of not actually being true of the data. I will propose that this does not mean these plots are not useful. Rather, to evaluate their utility, I will argue that empirical user studies of these methods will shed light on whether these methods provide more benefit or more harm in practice.]]></summary></entry></feed>