# Very Basic Diffusion

Credit for the title image is from https://yang-song.net/blog/2021/score/

After being at Waymo for a few years, I feel like I've fallen behind on the current state-of-the-art in machine learning. I've decided to brush up on some fundamentals. I wanted to start with diffusion models.

There was a time when I believed the only way to learn something was to prove theorems from first principles. But a more practical approach is to just look at the code. 🤗 Diffusers has done a great job of creating the right abstractions for this. This is something that Google has never been good at in my opinion, but that's another topic.

Some presentations of diffusion models can be very mathematical with stochastic differential equations and variational inference. You can get a PhD in math and not be familiar with those concepts.

My mini project was to create something like Figure 2 from Score-Based Generative Modeling through Stochastic Differential Equations.

Here's my multimodal one-dimensional distribution: a mixture of three normal distributions.

!pip install -q diffusers
import numpy as np
import jax
from jax import sharding
from jax.experimental import mesh_utils

# Print small floats in fixed-point rather than scientific notation.
np.set_printoptions(suppress=True)
# Use the partitionable threefry PRNG so random values can be generated
# directly in sharded form. NOTE: in this mode jax.random.split(key, n)[j]
# does not depend on n, so one key must not be split for two purposes.
jax.config.update('jax_threefry_partitionable', True)
# 1-D mesh over all jax.device_count() devices, with a single axis named
# 'data' that is used for data parallelism.
mesh = sharding.Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), ('data',))

from jax import numpy as jnp
from matplotlib import pyplot as plt

# Mixture weights, means, and standard deviations of the three components.
P = np.array((0.3, 0.1, 0.6), np.float32)
MU = np.array((-5, 0, 6), np.float32)
SIGMA = np.array((1, 0.5, 2), np.float32)

def make_data(key: jax.Array, shape=()):
    """Draw samples with the given shape from the Gaussian mixture (P, MU, SIGMA).

    Each sample picks component k with probability P[k] and returns
    MU[k] + SIGMA[k] * z, where z is a standard normal draw.
    """
    component_key, gaussian_key = jax.random.split(key, 2)
    weights = jnp.array(P)
    component = jax.random.choice(
        component_key, np.arange(weights.shape[0]), shape, p=weights)
    centers = jnp.array(MU)[component]
    scales = jnp.array(SIGMA)[component]
    z = jax.random.normal(gaussian_key, centers.shape)
    return centers + z * scales

# Histogram of 10,000 samples from the mixture (trailing ';' suppresses notebook output).
plt.hist(make_data(jax.random.key(100), shape=(10000,)), bins=100);


## Forward Process

Pretty much the only thing that makes diffusion different from a normal neural network training loop is the scheduler. We can use one to recreate the forward process, where we add noise to our data.

import diffusers

# DDPM scheduler with 1_000 timesteps. The network will be trained with the
# 'v_prediction' target (see get_velocity in loss_fn).
scheduler = diffusers.schedulers.FlaxDDPMScheduler(
    1_000, clip_sample=False,
    prediction_type='v_prediction',
    variance_type='fixed_small')
scheduler_state = scheduler.create_state()
print(f'{scheduler_state.init_noise_sigma=}')

xs = make_data(jax.random.key(100), shape=(10000,))

fig = plt.figure(figsize=(10, 11))
axes = fig.subplots(nrows=5, sharex=True)
# Top panel: the clean data.
axes[0].hist(xs, bins=100)
axes[0].set_title('timestep=0')
# Remaining panels: evenly spaced indices into scheduler_state.timesteps,
# taken in reverse order.
for ax, timestep_idx in zip(axes[1:], reversed(range(0, scheduler_state.timesteps.shape[0], scheduler_state.timesteps.shape[0] // (len(axes) - 1)))):
    timestep = int(scheduler_state.timesteps[timestep_idx])
    # Noise every sample to the same timestep using one fixed noise draw.
    noisy_xs = scheduler.add_noise(
        scheduler_state, xs, jax.random.normal(jax.random.key(0), xs.shape),
        jnp.ones(xs.shape[:1], jnp.int32) * timestep,)
    ax.hist(noisy_xs, bins=100);
    ax.set_title(f'{timestep=}')


So after we add enough noise, the data just looks like $\mathrm{Normal}(0, \sigma^2)$, where $\sigma$ is $\texttt{init\_noise\_sigma}$.

## Reverse Process

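The reverse process is about training a neural network using backpropagation to predict the noise at the various timesteps. Because the scheduler uses `prediction_type='v_prediction'`, the network here actually predicts the velocity $v = \sqrt{\bar\alpha_t}\,\epsilon - \sqrt{1 - \bar\alpha_t}\,x_0$ (what `scheduler.get_velocity` returns), which combines the noise $\epsilon$ with the clean sample $x_0$. People have decided that we can just choose timesteps randomly.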

### Training

What follows is a basic JAX training loop with some flourish for handling dropout and non-trainable state for batch normalization. I also make basic use of data parallelism since Google Colab gives me 8 TPUv2 cores.

import flax
from flax.training import train_state as train_state_lib
import optax

class TrainState(train_state_lib.TrainState):
    """Flax TrainState plus BatchNorm statistics, a parameter EMA, and the
    diffusion scheduler state."""
    # Non-trainable BatchNorm statistics (the model's 'batch_stats' collection).
    batch_stats: ...
    # Decay factor of the exponential moving average of the parameters.
    ema_decay: jax.Array
    # Exponential moving average of `params`: a pytree with the same structure
    # as `params` (not a single array).
    ema_params: ...
    # Scheduler state carried along so jitted steps can read the timesteps.
    scheduler_state: ...

    def apply_gradients(self, *, grads, **kwargs) -> 'TrainState':
        """Apply `grads` with the optimizer, then update the parameter EMA.

        Extra keyword arguments (e.g. ``batch_stats``) are forwarded to the base
        class and set on the returned state.
        """
        train_state = super().apply_gradients(grads=grads, **kwargs)
        # ema <- decay * ema + (1 - decay) * new_params, leaf by leaf.
        # jax.tree_util.tree_map replaces the deprecated jax.tree_map alias.
        ema_params = jax.tree_util.tree_map(
            lambda ema, p: ema * self.ema_decay + (1 - self.ema_decay) * p,
            self.ema_params, train_state.params,
        )
        return train_state.replace(ema_params=ema_params)

    @classmethod
    def create(cls, *, params, tx, **kwargs):
        """Create a state whose EMA starts equal to `params`.

        `apply_fn` is set to None because the model is applied explicitly
        via `model.apply` rather than through the state.
        """
        ema_params = jax.tree_util.tree_map(lambda x: x, params)
        return super().create(apply_fn=None, params=params, tx=tx, ema_params=ema_params, **kwargs)

from diffusers.models import embeddings_flax
from flax import linen as nn

class ResnetBlock(nn.Module):
    """Residual MLP block: xs + Dropout(GELU(Dense(128)(BatchNorm(xs)))).

    Returns ``(xs, ())`` so it can serve as an ``nn.scan`` body: ``xs`` is the
    carry and ``()`` is the empty per-step output. The residual add requires
    ``xs`` to already have 128 features.
    """
    # Dropout rate: a float, matching VelocityModel.dropout_rate (it was
    # previously mis-annotated as bool).
    dropout_rate: float
    # Forwarded to nn.BatchNorm as use_running_average.
    use_running_average: bool

    @nn.compact
    def __call__(self, xs):
        # Dropout is always non-deterministic here; a rate of 0 disables it.
        xs += nn.Dropout(deterministic=False, rate=self.dropout_rate)(
            nn.gelu(nn.Dense(128)(nn.BatchNorm(self.use_running_average)(xs))))
        return xs, ()

class VelocityModel(nn.Module):
    """MLP that predicts the diffusion velocity target from samples and timesteps.

    `xs` is concatenated with a timestep embedding, projected to 128 features,
    passed through 3 scanned ResnetBlocks, and mapped to a single output.
    """
    # Dropout rate for every dropout layer (0 disables dropout).
    dropout_rate: float
    # Forwarded to every nn.BatchNorm (False in the training model, True in
    # the inference model).
    use_running_average: bool

    @nn.compact
    def __call__(self, xs, ts):
        # Timestep features from diffusers' FlaxTimesteps, then FlaxTimestepEmbedding.
        ts = embeddings_flax.FlaxTimesteps()(ts)
        ts = embeddings_flax.FlaxTimestepEmbedding()(ts)
        # Condition on time by concatenating along the last (feature) axis.
        xs = jnp.concatenate((xs, ts), -1)
        xs = nn.BatchNorm(self.use_running_average)(xs)
        xs = nn.Dropout(deterministic=False, rate=self.dropout_rate)(nn.gelu(nn.Dense(128)(xs)))
        # 3 ResnetBlocks via nn.scan: params and batch_stats are stacked along
        # axis 0 (one slice per layer), and params/dropout RNGs are split per layer.
        xs, _ = nn.scan(
            ResnetBlock,
            length=3,
            variable_axes={"params": 0, "batch_stats": 0},
            split_rngs={"params": True, "dropout": True},
            metadata_params={nn.PARTITION_NAME: None})(
                self.dropout_rate, self.use_running_average)(xs)
        xs = nn.BatchNorm(self.use_running_average)(xs)
        return nn.Dense(1)(xs)

# Training-mode model: dropout rate 0.1, BatchNorm with use_running_average=False.
model = VelocityModel(dropout_rate=0.1, use_running_average=False)
# Initialize the variables ('params' and 'batch_stats') from a dummy batch of
# two (2, 1)-shaped samples and two timesteps.
model_vars = model.init(jax.random.key(0),
                        make_data(jax.random.key(0), shape=(2, 1)),
                        jnp.ones((2,), jnp.int32))

import optax

def make_train_state(model_vars):
    """Build the initial TrainState from the variables returned by `model.init`.

    Optimizer: clip gradients to global norm 1.0, then AdamW. The learning
    rate stays at 3e-4 for the first 500 steps and then decays linearly to 0
    over the following 5,000 steps (reaching 0 at step 5,500).
    """
    learning_rate = optax.linear_schedule(
        init_value=3e-4, end_value=0, transition_steps=5_000, transition_begin=500)
    return TrainState.create(
        params=model_vars['params'],
        tx=optax.chain(
            optax.clip_by_global_norm(1.0),
            optax.adamw(learning_rate)),
        batch_stats=model_vars['batch_stats'],
        ema_decay=jnp.array(0.9),
        # Module-level scheduler state, looked up at call time as before.
        scheduler_state=scheduler_state,
    )

import functools

def apply_fn(model, scheduler, key, state, xs):
    """Noise a batch at random timesteps and run the model on it.

    Returns ``((predicted_velocity, mutable_vars), noise, timesteps)``, where
    ``mutable_vars`` holds the updated 'batch_stats' collection.
    """
    keys = jax.random.split(key, 3)
    dropout_key, noise_key, timesteps_key = keys[0], keys[1], keys[2]
    batch_size = xs.shape[0]
    # One timestep per example, drawn from the scheduler's timesteps.
    timesteps = jax.random.choice(
        timesteps_key, state.scheduler_state.timesteps, (batch_size,))
    noise = jax.random.normal(noise_key, xs.shape)
    noisy_xs = scheduler.add_noise(state.scheduler_state, xs, noise, timesteps)
    variables = {'params': state.params, 'batch_stats': state.batch_stats}
    model_out = model.apply(
        variables,
        noisy_xs, timesteps,
        rngs={'dropout': dropout_key},
        mutable=['batch_stats'])
    return model_out, noise, timesteps

def loss_fn(scheduler, noise, predicted_velocity, scheduler_state, xs, ts):
    """Velocity-prediction mean squared error.

    The target is ``scheduler.get_velocity(scheduler_state, xs, noise, ts)``.
    Returns ``(mean squared error, elementwise squared errors)``.
    """
    target_velocity = scheduler.get_velocity(scheduler_state, xs, noise, ts)
    residual = predicted_velocity - target_velocity
    squared_errors = jnp.square(residual)
    return jnp.mean(squared_errors), squared_errors

def _update_fn(model, scheduler, key, state, xs):
    """One training step.

    Computes the velocity MSE loss and its gradient with respect to
    ``state.params``, then applies the gradient and stores the updated
    BatchNorm statistics. Returns ``(new_state, loss)``.
    """

    @functools.partial(jax.value_and_grad, has_aux=True)
    def _loss_fn(params):
        # Differentiate w.r.t. params only; the rest of `state` stays fixed.
        (predicted_velocity, mutable_vars), noise, timesteps = apply_fn(
            model, scheduler, key, state.replace(params=params), xs
        )
        loss, _ = loss_fn(scheduler, noise, predicted_velocity, state.scheduler_state, xs, timesteps)
        # The updated batch_stats ride along as the auxiliary output.
        return loss, mutable_vars

    (loss, mutable_vars), grads = _loss_fn(state.params)
    state = state.apply_gradients(grads=grads, batch_stats=mutable_vars['batch_stats'])
    return state, loss

# Jitted training step. The key and state use default shardings. The batch `xs`
# is sharded along its leading axis over the 'data' mesh axis. The old train
# state (argument 1) is donated so its buffers can be reused.
update_fn = jax.jit(
    functools.partial(_update_fn, model, scheduler),
    in_shardings=(None, None, sharding.NamedSharding(mesh, sharding.PartitionSpec('data'))),
    donate_argnums=1)

# Build the initial train state replicated across the mesh.
train_state = jax.jit(
    make_train_state,
    out_shardings=sharding.NamedSharding(mesh, sharding.PartitionSpec()))(
    model_vars)

# Generate each batch directly with its leading axis sharded over 'data'.
make_sharded_data = jax.jit(
    make_data,
    static_argnums=1,
    out_shardings=sharding.NamedSharding(mesh, sharding.PartitionSpec('data')),
)
key = jax.random.key(1)
for i in range(5_000):
    key = jax.random.fold_in(key, i)
    # Use independent keys for data generation and for the update step.
    # make_data and apply_fn both split the key they receive and draw normal
    # noise from split index 1 with the same shape. With
    # jax_threefry_partitionable enabled, jax.random.split(key, n)[j] does not
    # depend on n. Passing one key to both therefore made the diffusion noise
    # identical to the data's own Gaussian noise.
    data_key, update_key = jax.random.split(key)
    xs = make_sharded_data(data_key, (1 << 18, 1))
    train_state, loss = update_fn(update_key, train_state, xs)
    if i % 500 == 0:
        print(f'{i=}, {loss=}')
loss


which gives us

i=0, loss=Array(24.140306, dtype=float32)
i=500, loss=Array(9.757293, dtype=float32)
i=1000, loss=Array(9.72407, dtype=float32)
i=1500, loss=Array(9.700769, dtype=float32)
i=2000, loss=Array(9.7025585, dtype=float32)
i=2500, loss=Array(9.676901, dtype=float32)
i=3000, loss=Array(9.736925, dtype=float32)
i=3500, loss=Array(9.661432, dtype=float32)
i=4000, loss=Array(9.6644335, dtype=float32)
i=4500, loss=Array(9.635541, dtype=float32)
Array(9.662968, dtype=float32)


It's always good when loss goes down.

### Sampling

Now for sampling we use the scheduler's step function.

# Inference copy of the model: dropout rate 0 and BatchNorm with
# use_running_average=True. It reuses the trained variables.
inference_model = VelocityModel(dropout_rate=0, use_running_average=True)
# Reverse-process schedule with 1000 inference steps.
inference_scheduler_state = scheduler.set_timesteps(train_state.scheduler_state, 1000)

fig = plt.figure(figsize=(10, 11))
axes = fig.subplots(nrows=5, sharex=True)

@jax.jit
def step_fn(model_vars, state, t, xs, key):
    """One reverse step: predict the velocity at timestep `t` for every sample,
    then return the scheduler's previous (less noisy) sample."""
    velocity = inference_model.apply(model_vars, xs, jnp.broadcast_to(t, xs.shape[:1]))
    return scheduler.step(state, velocity, t, xs, key=key).prev_sample

# Start from standard normal samples.
key = jax.random.key(100)
xs = jax.random.normal(key, (10000, 1))
for i, t in enumerate(inference_scheduler_state.timesteps):
    # Snapshot into panels 0-3 every quarter of the schedule.
    if i % (inference_scheduler_state.num_inference_steps // 4) == 0:
        ax = axes[i // (inference_scheduler_state.num_inference_steps // 4)]
        ax.hist(jnp.squeeze(xs, -1), bins=100)
        timestep = int(t)
        ax.set_title(f'{timestep=}')
    key = jax.random.fold_in(key, t)
    # Sample with the EMA parameters rather than the raw trained parameters.
    xs = step_fn(dict(params=train_state.ema_params, batch_stats=train_state.batch_stats),
                 inference_scheduler_state, t, xs, key)
timestep = int(t)
# Last panel: samples after the final reverse step.
axes[-1].hist(jnp.squeeze(xs, -1), bins=100);
axes[-1].set_title(f'{timestep=}');


We found some success? The modes and the mean and variance of the modes are correct. But the distribution among the modes seems off. I found it surprisingly hard to train this model correctly even for a simple distribution. There were some bugs and I had to tweak the model and optimization schedule.

### Score Matching and Langevin dynamics

To be honest, I don't understand it yet and I hope to explore it more in another blog post, but here's the basic code for the Langevin dynamics alluded to in https://yang-song.net/blog/2021/score/#langevin-dynamics.

$$x_{i+1} = x_i + \epsilon\nabla_x\log{p(x_i)} + \sqrt{2\epsilon}z_i,$$

where $z_i \sim \mathrm{Normal}(0, 1)$.

from jax import scipy as jscipy

def log_pdf(xs):
    """Log-density of the Gaussian mixture (P, MU, SIGMA), evaluated elementwise.

    Computed as logsumexp_k(log P[k] + log N(x; MU[k], SIGMA[k])) rather than
    log(sum_k P[k] * N(x; ...)). Far from every mode, the individual pdfs
    underflow to 0 in float32, which made the result -inf and its gradient NaN.
    """
    xs = jnp.asarray(xs)
    # Add a trailing component axis: shape (...,) -> (..., len(P)).
    component_log_pdfs = jscipy.stats.norm.logpdf(
        xs[..., jnp.newaxis], loc=MU, scale=SIGMA)
    return jscipy.special.logsumexp(jnp.log(P) + component_log_pdfs, axis=-1)

def score(xs):
    """Score function d/dx log p(x), evaluated elementwise for each entry of `xs`."""
    # log_pdf is elementwise, so a VJP with a cotangent of ones is the gradient.
    primals_out, vjp = jax.vjp(log_pdf, xs)
    return vjp(jnp.ones_like(primals_out))[0]

@jax.jit
def step_fn(state, t, xs, key):
    """One Langevin dynamics step using the exact score of the mixture.

    `state` and `t` are accepted but ignored.
    """
    del state, t
    eps = 1e-2  # step size epsilon
    noise = jax.random.normal(key, xs.shape)
    # x <- x + eps * score(x) + sqrt(2 * eps) * z
    return xs + eps * score(xs) + jnp.sqrt(2 * eps) * noise

# Start from standard normal samples and run 5,000 steps; `t` only seeds the
# per-step key.
key = jax.random.key(100)
xs = jax.random.normal(key, (10000, 1))
for t in jnp.arange(5_000 - 1, -1, -1, dtype=jnp.int32):
    key = jax.random.fold_in(key, t)
    xs = step_fn(inference_scheduler_state, t, xs, key);
plt.hist(jnp.squeeze(xs, -1), bins=100);


It seems to work way better and does actually get the weights of the modes correct. This could be because the score matching is a better objective or I cheated and derived the exact score.

## Conclusion

Despite the complex math behind it, it's remarkable how little the code differs from a basic neural network training loop. It's one of the remarkable (bitter?) lessons of engineering that simple stuff scales better and eventually beats complex stuff. It's not surprising that diffusion is powering mind-blowing things like https://openai.com/sora.

## Colab

You can find the full colab here: https://colab.research.google.com/drive/1wBv3L-emUAu-4Ml2KLyDCVJVPLVsb9WA?usp=sharing. There's a nice example of how to profile JAX code there.

# Arithmetic Progressions

Recently, a problem from the USACO training pages has been bothering me. I had solved it years ago in Java, but my friend Robert Won challenged me to do it in Python. Since Python is many times slower, this means my code has to be much smarter.

## Problem

An arithmetic progression is a sequence of the form $a$, $a+b$, $a+2b$, $\ldots$, $a+nb$ where $n=0, 1, 2, 3, \ldots$. For this problem, $a$ is a non-negative integer and $b$ is a positive integer.

Write a program that finds all arithmetic progressions of length $n$ in the set $S$ of bisquares. The set of bisquares is defined as the set of all integers of the form $p^2 + q^2$ (where $p$ and $q$ are non-negative integers).

TIME LIMIT: 5 secs

PROGRAM NAME: ariprog

INPUT FORMAT

• Line 1: $N$ ($3 \leq N \leq 25$), the length of progressions for which to search
• Line 2: $M$ ($1 \leq M \leq 250$), an upper bound to limit the search to the bisquares with $0 \leq p,q \leq M$.

SAMPLE INPUT (file ariprog.in)

5
7


OUTPUT FORMAT

If no sequence is found, a single line reading NONE. Otherwise, output one or more lines, each with two integers: the first element in a found sequence and the difference between consecutive elements in the same sequence. The lines should be ordered with smallest-difference sequences first and smallest starting number within those sequences first.

There will be no more than 10,000 sequences.

SAMPLE OUTPUT (file ariprog.out)

1 4
37 4
2 8
29 8
1 12
5 12
13 12
17 12
5 20
2 24


## Dynamic Programming Solution

My initial solution that I translated from C++ to Python was not fast enough. I wrote a new solution that I thought was clever. We iterate over all possible deltas, and for each delta, we use dynamic programming to find the longest sequence with that delta.

def find_arithmetic_progressions(N, M):
    """Return ``(difference, start)`` for every length-N progression of bisquares.

    Bisquares are p*p + q*q with 0 <= p, q <= M. For each difference, a DP
    over the sorted bisquares tracks run_length[k], the length of the longest
    progression with that difference ending at the k-th bisquare. Results are
    ordered by difference, then by start.
    """
    values = sorted({p * p + q * q for p in range(M + 1) for q in range(p, M + 1)})
    position = [-1] * (2 * M * M + 1)
    for k, value in enumerate(values):
        position[value] = k

    found = []
    first = 0
    for delta in range(1, values[-1] // (N - 1) + 1):
        run_length = [1] * len(values)
        # Values smaller than delta cannot have a predecessor.
        while values[first] < delta:
            first += 1
        for value in values[first:]:
            prev = position[value - delta]
            if prev == -1:
                continue
            length = run_length[prev] + 1
            run_length[position[value]] = length
            if length >= N:
                found.append((delta, value - (N - 1) * delta))
    return found


Too slow!

Executing...
Test 1: TEST OK [0.011 secs, 9352 KB]
Test 2: TEST OK [0.011 secs, 9352 KB]
Test 3: TEST OK [0.011 secs, 9168 KB]
Test 4: TEST OK [0.011 secs, 9304 KB]
Test 5: TEST OK [0.031 secs, 9480 KB]
Test 6: TEST OK [0.215 secs, 9516 KB]
Test 7: TEST OK [2.382 secs, 9676 KB]
> Run 8: Execution error: Your program (ariprog') used more than
the allotted runtime of 5 seconds (it ended or was stopped at
5.242 seconds) when presented with test case 8. It used 12948 KB
of memory.

------ Data for Run 8 [length=7 bytes] ------
22
250
----------------------------
Test 8: RUNTIME 5.242>5 (12948 KB)


I managed to put my mathematics background to good use here: $p^2 + q^2 \not\equiv 3 \pmod 4$ and $p^2 + q^2 \not\equiv 6 \pmod 8$. This means that a bisquare arithmetic progression with more than 3 elements must have delta divisible by 4. If $b \equiv 1 \pmod 4$ or $b \equiv 3 \pmod 4$, four consecutive terms cover every residue modulo 4, so there would have to be a bisquare $p^2 + q^2 \equiv 3 \pmod 4$, which is impossible. If $b \equiv 2 \pmod 4$, there would have to be a bisquare $p^2 + q^2 \equiv 6 \pmod 8$ (or, if the terms are odd, one $\equiv 3 \pmod 4$), which is also impossible.

This optimization makes it fast enough.

def find_arithmetic_progressions(N, M):
    """Return ``(difference, start)`` for every length-N progression of bisquares.

    Same DP as the plain version. For N > 3, only differences divisible by 4
    are tried: bisquares are never 3 (mod 4) or 6 (mod 8), so any progression
    with at least 4 terms must step by a multiple of 4.
    """
    index_of = [-1] * (2 * M * M + 1)
    values = sorted({p * p + q * q for q in range(M + 1) for p in range(q + 1)})
    for k, value in enumerate(values):
        index_of[value] = k

    step = 1 if N == 3 else 4
    found = []
    lo = 0
    for delta in range(step, values[-1] // (N - 1) + 1, step):
        run = [1] * len(values)
        while values[lo] < delta:
            lo += 1
        for value in values[lo:]:
            prev = index_of[value - delta]
            if prev < 0:
                continue
            length = run[prev] + 1
            run[index_of[value]] = length
            if length >= N:
                found.append((delta, value - (N - 1) * delta))
    return found

Executing...
Test 1: TEST OK [0.010 secs, 9300 KB]
Test 2: TEST OK [0.011 secs, 9368 KB]
Test 3: TEST OK [0.015 secs, 9248 KB]
Test 4: TEST OK [0.014 secs, 9352 KB]
Test 5: TEST OK [0.045 secs, 9340 KB]
Test 6: TEST OK [0.078 secs, 9464 KB]
Test 7: TEST OK [0.662 secs, 9756 KB]
Test 8: TEST OK [1.473 secs, 9728 KB]
Test 9: TEST OK [1.313 secs, 9740 KB]

All tests OK.


## Even Faster!

Not content to merely pass, I wanted to see if we could pass all test cases in less than 1 second (the time limit was 5 seconds). Indeed, we can. The solution in the official analysis takes advantage of the fact that the sequence length is short. The dynamic programming optimization is not that helpful; it's better to optimize for traversing the bisquares less. Instead, we take pairs of bisquares carefully: we break out when the delta is too big. The official solution has some inefficiencies like using a hash map. If we instead use indexed array lookups, we can be very fast.

def find_arithmetic_progressions(N, M):
    """Return sorted ``(difference, start)`` pairs for length-N bisquare progressions.

    Each pair (last, second-to-last) of bisquares fixes a difference. The
    progression is then walked backwards with O(1) membership checks, stopping
    early once the difference is too large for N terms to stay non-negative.
    """
    member = [False] * (2 * M * M + 1)
    for p in range(M + 1):
        for q in range(p, M + 1):
            member[p * p + q * q] = True
    # Ascending, duplicate-free list of bisquares.
    values = [v for v, present in enumerate(member) if present]

    span = N - 1
    found = []
    for hi in range(len(values) - 1, -1, -1):
        last = values[hi]
        largest_step = last // span  # keeps the first term non-negative
        for lo in range(hi - 1, -1, -1):
            step = last - values[lo]
            if step > largest_step:
                break
            if N > 3 and step % 4:
                continue
            first = last - span * step
            cursor = values[lo]
            # Walk back while the previous term is still a bisquare.
            while cursor > first and member[cursor - step]:
                cursor -= step
            if cursor == first:
                found.append((step, first))
    return sorted(found)

Executing...
Test 1: TEST OK [0.013 secs, 9280 KB]
Test 2: TEST OK [0.012 secs, 9284 KB]
Test 3: TEST OK [0.013 secs, 9288 KB]
Test 4: TEST OK [0.012 secs, 9208 KB]
Test 5: TEST OK [0.018 secs, 9460 KB]
Test 6: TEST OK [0.051 secs, 9292 KB]
Test 7: TEST OK [0.421 secs, 9552 KB]
Test 8: TEST OK [0.896 secs, 9588 KB]
Test 9: TEST OK [0.786 secs, 9484 KB]

All tests OK.


Yay!

# The Skyline Problem

The only question that really stumped me during my Google interviews was The Skyline Problem. I remember only being able to write up a solution in pseudocode, after being given many hints, before my time was up.

It's been banned for some time now, so I thought I'd dump the solution here. Maybe, I will elaborate and clean up the code some other time. It's one of the cleverest uses of an ordered map (usually implemented as a tree map) that I've seen.

#include <algorithm>
#include <iostream>
#include <map>
#include <sstream>
#include <utility>
#include <vector>

using namespace std;

namespace {
// One vertical edge of a building: its x position, its height, and whether
// the building starts (LEFT) or ends (RIGHT) there.
struct Wall {
  enum Type : int {
    LEFT = 1,
    RIGHT = 0
  };

  int position;
  int height;
  Type type;

  Wall(int position, int height, Wall::Type type) :
      position(position), height(height), type(type) {}

  // Orders walls by x position only; walls sharing a position are processed
  // as one group below, so their relative order does not matter.
  // Const-qualified so Walls can be compared through const references
  // (e.g. by std::less<Wall> or inside ordered containers).
  bool operator<(const Wall &other) const {
    return position < other.position;
  }
};

// Debug printing helper.
ostream& operator<<(ostream& stream, const Wall &w) {
  return stream << "Position: " << to_string(w.position) << ';'
                << " Height: " << to_string(w.height) << ';'
                << " Type: " << (w.type == Wall::Type::LEFT ? "Left" : "Right");
}
}  // namespace

class Solution {
 public:
  // Returns the skyline key points {x, height} for buildings given as
  // {left, right, height}, in increasing x order.
  vector<vector<int>> getSkyline(vector<vector<int>>& buildings) {
    vector<Wall> walls;
    for (const vector<int>& building : buildings) {
      walls.emplace_back(building[0], building[2], Wall::Type::LEFT);
      walls.emplace_back(building[1], building[2], Wall::Type::RIGHT);
    }
    sort(walls.begin(), walls.end());
    vector<vector<int>> skyline;
    // Multiset of active heights as height -> count; the largest key is the
    // current skyline height.
    map<int, int> heightCount;
    for (vector<Wall>::const_iterator wallPtr = walls.cbegin(); wallPtr != walls.cend();) {
      int currentPosition = wallPtr->position;
      // Apply every wall at this x before emitting a point, so buildings that
      // touch at the same x don't produce spurious key points.
      do {
        if (wallPtr->type == Wall::Type::LEFT) {
          ++heightCount[wallPtr->height];
        } else if (wallPtr->type == Wall::Type::RIGHT) {
          if (--heightCount[wallPtr->height] == 0) {
            heightCount.erase(wallPtr->height);
          }
        }
        ++wallPtr;
      } while (wallPtr != walls.cend() && wallPtr->position == currentPosition);
      // Emit a key point only when the visible height changes (or drops to 0).
      if (skyline.empty() || heightCount.empty() ||
          heightCount.crbegin()->first != skyline.back()[1]) {
        skyline.emplace_back(vector<int>{
            currentPosition, heightCount.empty() ? 0 : heightCount.crbegin()->first});
      }
    }
    return skyline;
  }
};


# Cycle Detection with $O(1)$ Memory

The easiest way to detect cycles in a linked list is to put all the seen nodes into a set and check that you don't have a repeat as you traverse the list. This unfortunately can blow up in memory for large lists.

Floyd's Tortoise and Hare algorithm gets around this by using two pointers that iterate through the list at different speeds. It's not immediately obvious why this should work.

/*
* For your reference:
*
* SinglyLinkedListNode {
*     int data;
*     SinglyLinkedListNode* next;
* };
*
*/
namespace {
// Floyd's tortoise and hare: advance `tortoise` by one node and `hare` by two
// nodes per step; they meet iff the list has a cycle. Written as a loop rather
// than tail recursion so that long lists cannot overflow the call stack when
// the compiler does not eliminate tail calls.
template <typename Node>
bool has_cycle(const Node* tortoise, const Node* hare) {
  while (tortoise != hare) {
    // The hare reached the end of the list, so there is no cycle. The
    // tortoise trails the hare, so its next pointer is always valid here.
    if (hare->next == nullptr || hare->next->next == nullptr) return false;
    tortoise = tortoise->next;
    hare = hare->next->next;
  }
  return true;
}
}  // namespace

// Returns true iff the list starting at `head` contains a cycle.
bool has_cycle(SinglyLinkedListNode* head) {
  if (head == nullptr ||
      head->next == nullptr ||
      head->next->next == nullptr) return false;
  return has_cycle(head, head->next->next);
}


The above algorithm solves HackerRank's Cycle Detection.

To see why this works, consider a cycle that starts at index $\mu$ and has length $l$. If there is a cycle, we should have $x_i = x_j$ for some $i,j \geq \mu$ and $i \neq j$. This should occur when $$i - \mu \equiv j - \mu \pmod{l}. \label{eqn:cond}$$

In the tortoise and hare algorithm, the tortoise moves with speed 1, and the hare moves with speed 2. Let $i$ be the location of the tortoise. Let $j$ be the location of the hare.

The cycle starts at $\mu$, so the earliest that we could see a cycle is when $i = \mu$. Then, $j = 2\mu$. Let $k$ be the number of steps we take after $i = \mu$. We'll satisfy Equation \ref{eqn:cond} when \begin{align*} i - \mu \equiv j - \mu \pmod{l} &\Leftrightarrow \left(\mu + k\right) - \mu \equiv \left(2\mu + 2k\right) - \mu \pmod{l} \\ &\Leftrightarrow k \equiv \mu + 2k \pmod{l} \\ &\Leftrightarrow 0 \equiv \mu + k \pmod{l}. \end{align*}

This will happen for some $0 \leq k < l$, so the algorithm terminates within $\mu + k$ steps if there is a cycle. Otherwise, if there is no cycle, the algorithm terminates when it reaches the end of the list.

# Overrandomized: Approximations with Integrals

Consider the problem Overrandomized. Intuitively, one can see something like Benford's law. Indeed, counting the leading digit works:

#include <algorithm>
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

using namespace std;

// Reads 10000 (Q, R) query lines and returns the 10-letter code, where the
// letter at index d encodes digit d. Leading letters ordered by frequency give
// digits 1..9; the one letter that never leads a number is digit 0.
string Decode() {
  unordered_map<char, int> leading_counts;
  unordered_set<char> letters;
  for (int query = 0; query < 10000; ++query) {
    long long q;  // Can exceed 32 bits, so it must be read as long long.
    string response;
    cin >> q >> response;
    ++leading_counts[response[0]];
    for (char c : response) letters.insert(c);
  }
  vector<pair<int, char>> by_frequency;
  for (const auto& entry : leading_counts) {
    by_frequency.emplace_back(entry.second, entry.first);
  }
  sort(by_frequency.begin(), by_frequency.end());
  for (const auto& entry : by_frequency) letters.erase(entry.second);
  // The remaining letter is the digit 0.
  string code(1, *letters.begin());
  // Then append from most to least frequent leading letter: digits 1..9.
  for (auto it = by_frequency.rbegin(); it != by_frequency.rend(); ++it) {
    code += it->second;
  }
  return code;
}

// Reads the number of test cases, then decodes and prints each case.
int main(int argc, char *argv[]) {
  ios::sync_with_stdio(false);
  cin.tie(nullptr);
  int num_cases;
  cin >> num_cases;
  for (int case_index = 1; case_index <= num_cases; ++case_index) {
    int digits;  // U: not needed for decoding but must be consumed.
    cin >> digits;
    cout << "Case #" << case_index << ": " << Decode() << '\n';
  }
  cout << flush;
  return 0;
}


Take care to read Q as a long long because it can be large.

It occurred to me that there's no reason the logarithms of the randomly generated numbers should be uniformly distributed, so I decided to look into this probability distribution closer. Let $R$ be the random variable representing the return value of a query.

\begin{align*} P(R = r) &= \sum_{m = r}^{10^U - 1} P(M = m, R = r) \\ &= \sum_{m = r}^{10^U - 1} P(R = r \mid M = m)P(M = m) \\ &= \frac{1}{10^U - 1}\sum_{m = r}^{10^U - 1} \frac{1}{m}. \end{align*} since $P(M = m) = 1/(10^U - 1)$ for all $m$.

The probability that we get a $k$ digit number that starts with a digit $d$ is then \begin{align*} P(d \times 10^{k-1} \leq R < (d + 1) \times 10^{k-1}) &= \frac{1}{10^U - 1} \sum_{r = d \times 10^{k-1}}^{(d + 1) \times 10^{k-1} - 1} \sum_{m = r}^{10^U - 1} \frac{1}{m}. \end{align*}

Here, you can already see that for a fixed $k$, smaller $d$s will have more terms, so they should occur as leading digits with higher probability. It's interesting to try to figure out how much more frequently this should happen, though. To get rid of the summation, we can use integrals! This will make the computation tractable for large $k$ and $U$. Here, I start dropping the $-1$s in the approximations.

\begin{align*} P\left(d \times 10^{k-1} \leq R < (d + 1) \times 10^{k-1}\right) &= \frac{1}{10^U - 1} \sum_{r = d \times 10^{k-1}}^{(d + 1) \times 10^{k-1} - 1} \sum_{m = r}^{10^U - 1} \frac{1}{m} \\ &\approx \frac{1}{10^U} \sum_{r = d \times 10^{k-1}}^{(d + 1) \times 10^{k-1} - 1} \left[\log 10^U - \log r \right] \\ &=\frac{10^{k - 1}}{10^{U}}\left[ U\log 10 - \frac{1}{10^{k - 1}}\sum_{r = d \times 10^{k-1}}^{(d + 1) \times 10^{k-1} - 1} \log r \right]. \end{align*}

Again, we can apply integration. Using integration by parts, we have $\int_a^b \log x \,dx = b\log b - b - \left(a\log a - a\right)$, so \begin{align*} \sum_{r = d \times 10^{k-1}}^{(d + 1) \times 10^{k-1} - 1} \log r &\approx 10^{k-1}\left[ (k - 1)\log 10 + (d + 1) \log (d + 1) - d \log d - 1 \right]. \end{align*}

Substituting, we end up with \begin{align*} P&\left(d \times 10^{k-1} \leq R < (d + 1) \times 10^{k-1}\right) \approx \\ &\frac{1}{10^{U - k + 1}}\left[ 1 + (U - k + 1)\log 10 - \left[(d + 1) \log(d+1) - d\log d\right] \right]. \end{align*}

We can make a few observations. Numbers with lots of digits are more likely to occur since for larger $k$, the denominator is much smaller. This makes sense: there are many more large numbers than small numbers. Independent of $k$, if $d$ is larger, the quantity inside the inner brackets is larger since $x \log x$ is convex, so the probability decreases with $d$. Thus, smaller digits occur more frequently. While the formula follows the spirit of Benford's law, the formula is not quite the same.

This was the first time I had to use integrals for a competitive programming problem!

# Induction on Prefix Trees

One of my favorite things about personal coding projects is that you're free to over-engineer and prematurely optimize your code to your heart's content. Production code written in a shared code base needs to be maintained, and hence, should favor simplicity and readability. For personal projects, I optimize for fun, and what could be more fun than elaborate abstractions, unnecessary optimizations, and abusing recursion?

To that end, I present my solution to the Google Code Jam 2019 Round 1A problem, Alien Rhyme.

In this problem, we maximize the number of pairs of words that could possibly rhyme. I guess this problem has some element of realism as it's similar in spirit to using frequency analysis to decode or identify a language.

After reversing the strings, this problem reduces to greedily taking pairs of words with the longest common prefix. Each time we select a prefix, we update the sizes of the remaining prefixes. If there are $N$ words, this algorithm is $O\left(N^2\right)$ and can be implemented with a linked list in C++:

// Reverses every word so that common suffixes become common prefixes, then
// sorts the results so words sharing a prefix are adjacent.
vector<string> NormalizeSuffixes(const vector<string>& words) {
  vector<string> suffixes;
  suffixes.reserve(words.size());
  for (const string& word : words) {
    suffixes.emplace_back(word.rbegin(), word.rend());
  }
  sort(suffixes.begin(), suffixes.end());
  return suffixes;
}

// Returns the length of the longest common prefix of `a` and `b`.
int CountPrefix(const string &a, const string &b) {
  const size_t limit = min(a.length(), b.length());
  size_t size = 0;
  while (size < limit && a[size] == b[size]) ++size;
  return static_cast<int>(size);
}

// Returns how many words can be paired so that each pair shares a distinct
// rhyming suffix, by greedily claiming the longest common suffixes first.
int MaximizePairs(const vector<string>& words) {
  // Nothing to pair. Without this guard the final unsigned subtraction
  // wraps around (0 - 1) and the function returns -1.
  if (words.empty()) return 0;
  const vector<string> suffixes = NormalizeSuffixes(words);
  // Common-prefix lengths of adjacent sorted suffixes, padded with zeros as
  // if empty strings were at the beginning and end.
  list<int> prefix_sizes{0};
  for (size_t i = 1; i < suffixes.size(); ++i)
    prefix_sizes.push_back(CountPrefix(suffixes[i - 1], suffixes[i]));
  prefix_sizes.push_back(0);
  // Repeatedly claim the longest remaining common prefix. The first entry
  // stays 0, so the loop ends once no positive common prefix remains.
  list<int>::iterator max_prefix_size;
  while ((max_prefix_size = max_element(prefix_sizes.begin(), prefix_sizes.end())) !=
         prefix_sizes.begin()) {
    // Among equal adjacent maxima, advance to the last one. Each skipped
    // entry shrinks by one because this prefix length is now claimed.
    while (*next(max_prefix_size) == *max_prefix_size) {
      --(*max_prefix_size);
      ++max_prefix_size;
    }
    // Removing the paired words joins their neighbours, whose common prefix
    // is the smaller of the two surrounding gaps.
    *next(max_prefix_size) = min(*prev(max_prefix_size), *next(max_prefix_size));
    prefix_sizes.erase(prefix_sizes.erase(prev(max_prefix_size)));
  }
  // prefix_sizes started with n + 1 entries; each claimed pair removed two.
  return static_cast<int>(suffixes.size() - (prefix_sizes.size() - 1));
}


A single file example can be found on GitHub. Since $N \leq 1000$ in this problem, this solution is more than adequate.

## Asymptotically Optimal Solution

We can use the fact that the number of characters in each word $W$ is at most 50 and obtain a $O\left(N\max\left(\log N, W\right)\right)$ solution.

### Induction

Suppose we have a tree where each node is a prefix (sometimes called a trie). In the worst case, each prefix will have a single character. The title image shows such a tree for the words: PREFIX, PRELIM, PROF, SUFFER, SUFFIX, SUM, SWIFT, SWIFTER, SWOLE.

Associated with each node is a count of how many words have that prefix as a maximal prefix. The depth of each node is the sum of the traversed prefix sizes.

The core observation is that at any given node, any words in the subtree can have a common prefix with length at least the depth of the node. Greedily selecting the longest common prefixes corresponds to pairing all possible prefixes in a subtree with length greater than the depth of the parent. The unused words can then be used higher up in the tree to make additional prefixes. Tree algorithms are best expressed recursively. Here's the Swift code.

/// Counts the words that can be paired within the subtree rooted at `root`.
///
/// - Parameters:
///   - root: Subtree root; `root.count` words end exactly at this node.
///   - depth: Total prefix length from the tree root down to `root`.
///   - minPairWordCount: Depth of `root`'s parent. Prefix lengths in
///     (minPairWordCount, depth] are available to pairs formed here.
/// - Returns: Words used in pairs, and words still unpaired that can be
///   paired higher up in the tree.
func maximizeTreePairs<T: Collection>(
  root: Node<T>, depth: Int, minPairWordCount: Int) -> (used: Int, unused: Int)
  where T.Element: Hashable {
  var used = 0
  var unused = root.count
  for (edge, child) in root.children {
    let childResult = maximizeTreePairs(
      root: child, depth: depth + edge.count, minPairWordCount: depth)
    used += childResult.used
    unused += childResult.unused
  }
  // Each available prefix length pairs two words, limited by how many
  // unpaired words (rounded down to even) are left.
  let pairedHere = min(2 * (depth - minPairWordCount), (unused / 2) * 2)
  return (used + pairedHere, unused - pairedHere)
}

func maximizePairs(_ words: [String]) -> Int {
let suffixes = normalizeSuffixes(words)
let prefixTree = compress(makePrefixTree(suffixes))
return prefixTree.children.reduce(
    0, { $0 + maximizeTreePairs(root: $1.value, depth: $1.key.count, minPairWordCount: 0).used })
}

Since the tree has maximum depth $W$ and there are $N$ words, recursing through the tree is $O\left(NW\right)$.

### Making the Prefix Tree

The simplest way to construct a prefix tree is to start at the root for each word and descend into the tree character by character, creating any nodes necessary. Update the count of the node when reaching the end of the word. This is $O\left(NW\right)$.

As far as I know, the worst case will always be $O\left(NW\right)$. In practice, though, if there are many words with lengthy shared common prefixes, we can avoid retracing paths through the tree. In our example, consider SWIFT and SWIFTER. If we naively construct a tree, we will need to traverse through $5 + 7 = 12$ nodes. But if we insert our words in lexicographic order, we don't need to retrace the first 5 characters and only need to traverse 7 nodes.

Swift has somewhat tricky value semantics. `struct`s are always copied, so we need to construct this tree recursively.

func makePrefixTree<T: StringProtocol>(_ words: [T]) -> Node<T.Element> {
  let prefixCounts = words.reduce(
    into: (counts: [0], word: "" as T), {
      $0.counts.append(countPrefix($0.word, $1))
      $0.word = $1
    }).counts
let minimumPrefixCount = MinimumRange(prefixCounts)
let words = [""] + words
/// Inserts words[i] into a rooted tree.
///
/// - Parameters:
///  - root: The root node of the tree.
///  - state: The index of the word for the current path and depth of root.
///  - i: The index of the word to be inserted.
/// - Returns: The index of the next word to be inserted.
func insert(_ root: inout Node<T.Element>,
_ state: (node: Int, depth: Int),
_ i: Int) -> Int {
// Start inserting only for valid indices and at the right depth.
if i >= words.count { return i }
// Max number of nodes that can be reused for words[i].
let prefixCount = state.node == i ?
prefixCounts[i] : minimumPrefixCount.query(from: state.node + 1, through: i)
// Either (a) inserting can be done more efficiently at a deeper node;
// or (b) we're too deep in the wrong state.
if prefixCount > state.depth || (prefixCount < state.depth && state.node != i) { return i }
// Start insertion process! If we're at the right depth, insert and move on.
if state.depth == words[i].count {
root.count += 1
return insert(&root, (i, state.depth), i + 1)
}
// Otherwise, possibly create a node and traverse deeper.
let key = words[i][words[i].index(words[i].startIndex, offsetBy: state.depth)]
if root.children[key] == nil {
root.children[key] = Node<T.Element>(children: [:], count: 0)
}
// After finishing traversal insert the next word.
return insert(
&root, state, insert(&root.children[key]!, (i, state.depth + 1), i))
}
var root = Node<T.Element>(children: [:], count: 0)
let _ = insert(&root, (0, 0), 1)
return root
}


While a naive construction of the trie would involve $48$ node visits (the sum of the lengths of the words), this algorithm needs only $28$, as shown in the title image. Each word insertion has its edges colored separately in the title image.

For this algorithm to work efficiently, each new word must start being inserted at the right depth: the length of the longest prefix it shares with the previously inserted word.

### Minimum Range Query

Computing the longest common prefix of any two words reduces to a range minimum query. If we order the words lexicographically, we can compute the longest common prefix size between adjacent words. The longest common prefix size of two words $i$ and $j$, where $i < j$, is then:

$$\textrm{LCP}(i, j) = \min\left\{\textrm{LCP}(i, i + 1), \textrm{LCP}(i + 1, i + 2), \ldots, \textrm{LCP}(j - 1, j)\right\}.$$

A classic dynamic-programming technique, the sparse table, precomputes these range minima in $O\left(N\log N\right)$ time so that each query takes only $O\left(1\right)$ time.

We'll $0$-index to make the math easier to translate into code. Given an array $A$ of size $N$, let

$$P_{i,j} = \min\left\{A_k : i \leq k < \min\left(i + 2^{j}, N\right)\right\}.$$

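Then, taking $A_k = \mathrm{LCP}\left(k, k + 1\right)$, we can write $\mathrm{LCP}\left(i, j\right) = \min\left\{A_i, \ldots, A_{j - 1}\right\} = \min\left(P_{i, l}, P_{j - 2^l, l}\right)$, where $l = \max\left\{l : l \in \mathbb{Z}, 2^l \leq j - i\right\} = \left\lfloor \log_2\left(j - i\right) \right\rfloor$, since $\left([i, i + 2^l) \cup [j - 2^l, j)\right) \cap \mathbb{Z} = \left\{i, i + 1, \ldots , j - 1\right\}$.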

We initialize $P_{i,0} = A_{i}$, and for $j > 0$ we compute

$$P_{i,j} = \begin{cases} \min\left(P_{i, j - 1}, P_{i + 2^{j - 1}, j - 1}\right) & i + 2^{j -1} < N; \\ P_{i, j - 1} & \text{otherwise}. \\ \end{cases}$$

See a Swift implementation.

/// A sparse table that answers range queries over a fixed collection in O(1)
/// after O(n log n) preprocessing. By default it answers range-minimum queries.
///
/// `memo[i][j]` holds `reduce` applied over the elements at positions
/// `i ..< min(i + 2^j, count)`.
///
/// - Important: `query` combines two windows that may overlap, so `reduce`
///   must be idempotent (`reduce(x, x) == x`), e.g. `min` or `max`. An
///   operation such as `+` would count the overlap twice.
struct MinimumRange<T: Collection> where T.Element: Comparable {
  private let memo: [[T.Element]]
  private let reduce: (T.Element, T.Element) -> T.Element

  /// Precomputes the table.
  ///
  /// - Parameters:
  ///   - collection: The values to query. May be empty.
  ///   - reduce: An associative, idempotent combining function; defaults to `min`.
  init(_ collection: T,
       reducer reduce: @escaping (T.Element, T.Element) -> T.Element = min) {
    let k = collection.count
    // Level 0: every position covers just itself.
    var memo: [[T.Element]] = collection.map { [$0] }
    // floor(log2(k)) + 1 levels for k > 0, and 0 levels for k == 0.
    let levelCount = k.bitWidth - k.leadingZeroBitCount
    // `stride` yields nothing when levelCount <= 1; `1..<levelCount` would trap
    // for an empty collection (levelCount == 0).
    for j in stride(from: 1, to: levelCount, by: 1) {
      let offset = 1 << (j - 1)
      for i in 0..<k {
        // Merge two half-size windows; near the end the right half falls off
        // the collection and the left half is carried up unchanged.
        memo[i].append(
          i + offset < k
            ? reduce(memo[i][j - 1], memo[i + offset][j - 1])
            : memo[i][j - 1])
      }
    }
    self.memo = memo
    self.reduce = reduce
  }

  /// Returns `reduce` applied over the half-open range `from ..< to`.
  ///
  /// The bounds are clamped to `0 ... count`. The clamped range must be
  /// non-empty.
  func query(from: Int, to: Int) -> T.Element {
    let (from, to) = (max(from, 0), min(to, memo.count))
    let rangeCount = to - from
    precondition(rangeCount > 0, "MinimumRange.query requires a non-empty range")
    // 2^bitShift is the largest power of two not exceeding rangeCount.
    let bitShift = rangeCount.bitWidth - rangeCount.leadingZeroBitCount - 1
    let offset = 1 << bitShift
    // [from, from + offset) and [to - offset, to) together cover [from, to).
    return self.reduce(self.memo[from][bitShift], self.memo[to - offset][bitShift])
  }

  /// Returns `reduce` applied over the closed range `from ... through`.
  func query(from: Int, through: Int) -> T.Element {
    return query(from: from, to: through + 1)
  }
}


### Path Compression

Perhaps my most unnecessary optimization is path compression, especially since we only traverse the tree once in the induction step. If we were to traverse the tree multiple times, however, it might be worth it. This optimization collapses each node that has count $0$ and exactly $1$ child into its parent.

/// Applies path compression to a character-keyed prefix tree. Not necessary,
/// but it's fun!
///
/// Every non-root node with `count == 0` and exactly one child is spliced
/// out: the edge into it and the edge out of it become one edge keyed by the
/// concatenated string. Children are compressed first, so a whole chain of
/// such nodes collapses into a single edge. The root itself is never merged.
///
/// - Parameter uncompressedRoot: Root of a tree whose edges are single characters.
/// - Returns: An equivalent tree whose edges are strings; every node that
///   survives keeps its `count`.
func compress(_ uncompressedRoot: Node<Character>) -> Node<String> {
  var compressed = Node<String>(children: [:], count: uncompressedRoot.count)
  for (character, child) in uncompressedRoot.children {
    let edge = String(character)
    let subtree = compress(child)
    guard subtree.count == 0, subtree.children.count == 1,
          let (suffix, grandChild) = subtree.children.first else {
      // The child ends a word, branches, or is a leaf: keep it as a node.
      compressed.children[edge] = subtree
      continue
    }
    // Pass-through node: extend the edge key and attach the grandchild directly.
    compressed.children[edge + suffix] = grandChild
  }
  return compressed
}


### Full Code Example

A full example with everything wired together can be found on GitHub.

## GraphViz

Also, if you're interested in the graph in the title image, I used GraphViz. It's pretty neat. About a year ago, I made a trivial commit to the project: https://gitlab.com/graphviz/graphviz/commit/1cc99f32bb1317995fb36f215fb1e69f96ce9fed.

// Prefix tree (trie) for the example words PREFIX, PRELIM, PROF, SUFFER,
// SUFFIX, SUM, SWIFT, SWIFTER and SWOLE, as drawn in the title image.
// The empty root is not drawn; P and S1 are the first letters.
// Node IDs carry a numeric suffix (F3, E2, ...) only to tell repeated letters
// apart; `label` sets the displayed letter, and nodes without a label
// (P, U, W) display their ID.
// All edges created by one word's insertion share one color. Edge labels
// number the edges in insertion order, with words inserted lexicographically:
// PREFIX 1-5, PRELIM, PROF, SUFFER, SUFFIX, SUM, SWIFT, SWIFTER, SWOLE 26-28.
// The word sections below are not listed in that order.
// NOTE(review): label 8 is never used (PRELIM jumps from 7 to 9), so 27 edges
// carry labels up to 28. Confirm whether the numbering or the "28 visits" in
// the accompanying text should change.
digraph {
// Lay the tree out left to right.
rankdir=LR;
// SUFFER
S1 [label="S"];
F3 [label="F"];
F4 [label="F"];
E2 [label="E"];
R2 [label="R"];
S1 -> U [label=12, color="#984ea3"];
U -> F3 [label=13, color="#984ea3"];
F3 -> F4 [label=14, color="#984ea3"];
F4 -> E2 [label=15, color="#984ea3"];
E2 -> R2 [label=16, color="#984ea3"];
// PREFIX
E1 [label="E"];
R1 [label="R"];
F1 [label="F"];
I1 [label="I"];
X1 [label="X"];
P -> R1 [label=1, color="#e41a1c"];
R1 -> E1 [label=2, color="#e41a1c"];
E1 -> F1 [label=3, color="#e41a1c"];
F1 -> I1 [label=4, color="#e41a1c"];
I1 -> X1 [label=5, color="#e41a1c"];
// PRELIM (branches off the shared "PRE" at E1)
L1 [label="L"];
I2 [label="I"];
M1 [label="M"];
E1 -> L1 [label=6, color="#377eb8"];
L1 -> I2 [label=7, color="#377eb8"];
// NOTE(review): label 9 follows 7 here; no edge is labeled 8.
I2 -> M1 [label=9, color="#377eb8"];
// PROF (branches off the shared "PR" at R1)
O1 [label="O"];
F2 [label="F"];
R1 -> O1 [label=10, color="#4daf4a"];
O1 -> F2 [label=11, color="#4daf4a"];
// SUFFIX (branches off the shared "SUFF" at F4)
I3 [label="I"];
X2 [label="X"];
F4 -> I3 [label=17, color="#ff7f00"];
I3 -> X2 [label=18, color="#ff7f00"];
// SUM (branches off the shared "SU" at U)
M2 [label="M"];
U -> M2 [label=19, color="#ffff33"];
// SWIFT
I4 [label="I"];
F5 [label="F"];
T1 [label="T"];
S1 -> W [label=20, color="#a65628"];
W -> I4 [label=21, color="#a65628"];
I4 -> F5 [label=22, color="#a65628"];
F5 -> T1 [label=23, color="#a65628"];
// SWIFTER (extends SWIFT from T1)
E3 [label="E"];
R3 [label="R"];
T1 -> E3 [label=24, color="#f781bf"];
E3 -> R3 [label=25, color="#f781bf"];
// SWOLE (branches off the shared "SW" at W)
O2 [label="O"];
L2 [label="L"];
E4 [label="E"];
W -> O2 [label=26, color="#999999"];
O2 -> L2 [label=27, color="#999999"];
L2 -> E4 [label=28, color="#999999"];
}


# Hamiltonian Paths in Nearly Complete Graphs

In general, the Hamiltonian path problem is NP-complete. But in some special cases, polynomial-time algorithms exist.

One such case is in Pylons, the Google Code Jam 2019 Round 1A problem. In this problem, we are presented with a grid graph. Each cell is a node, and a node is connected to every other node except those along its diagonals, in the same column, or in the same row. In the example, from the blue cell, we can move to any other cell except the red cells.

If there are $N$ cells, an $O\left(N^2\right)$ greedy algorithm is to always visit the next available cell that has the most unvisited cells it cannot jump to. Why? We should visit those cells early because, if we wait too long, we will get stuck at them when we inevitably need to visit them.

I've implemented the solution in findSequence (see full Swift solution on GitHub):

func findSequence(N: Int, M: Int) -> [(row: Int, column: Int)]? {
// Other cells to which we are not allowed to jump.
var badNeighbors: [Set<Int>] = Array(repeating: Set(), count: N * M)
for i in 0..<(N * M) {
let (ri, ci) = (i / M, i % M)
for j in 0..<(N * M) {
let (rj, cj) = (j / M, j % M)
if ri == rj || ci == cj || ri - ci == rj - cj || ri + ci == rj + cj {
badNeighbors[i].insert(j)
badNeighbors[j].insert(i)
}
}
}
// Greedily select the cell which has the most unallowable cells.
var sequence: [(row: Int, column: Int)] = []
var visited: Set<Int> = Set()
while sequence.count < N * M {
guard let i = (badNeighbors.enumerated().filter {

Yay, 2017!