# Policy-Based Data Structures in C++

Certain problems in competitive programming call for more advanced data structures than those built into Java's or C++'s standard libraries. Two examples are an order statistic tree and a priority queue that lets you modify priorities. It's questionable whether these implementations are useful outside of competitive programming, since you could just use Boost.

## Order Statistic Tree

Consider the problem ORDERSET. An order statistic tree solves this problem trivially, and implementing one is actually not so difficult. You can find the implementation in GCC's libstdc++ (the tree_order_statistics_node_update policy). Basically, you maintain a node invariant

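```cpp
inline void
operator()(node_iterator node_it, node_const_iterator end_nd_it) const {
  node_iterator l_it = node_it.get_l_child();
  const size_type l_rank = (l_it == end_nd_it) ? 0 : l_it.get_metadata();

  node_iterator r_it = node_it.get_r_child();
  const size_type r_rank = (r_it == end_nd_it) ? 0 : r_it.get_metadata();

  // subtree size = left subtree + right subtree + this node
  const_cast<metadata_reference>(node_it.get_metadata()) = 1 + l_rank + r_rank;
}
```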


where each node contains a count of nodes in its subtree. Every time you insert a new node or delete a node, you can maintain the invariant in $O(\log N)$ time by bubbling up to the root.
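To make the invariant concrete outside of PBDS, here is a minimal sketch on a plain, unbalanced binary search tree. The `Node` type, `sizeOf`, `recompute`, and `insert` are hypothetical names, not part of the library. After recursing into a child, `insert` recomputes `size = 1 + size(left) + size(right)` on the way back up, which is the same "bubbling up to the root" update described above.

```cpp
#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <memory>

// Hypothetical node type: `size` counts the nodes in this node's subtree.
struct Node {
    int key;
    std::size_t size = 1;
    std::unique_ptr<Node> left, right;
    explicit Node(int k) : key(k) {}
};

std::size_t sizeOf(const std::unique_ptr<Node>& node) {
    return node ? node->size : 0;
}

// The node invariant: size = 1 + size(left) + size(right).
void recompute(Node& node) {
    node.size = 1 + sizeOf(node.left) + sizeOf(node.right);
}

// Unbalanced BST insert that ignores duplicates, like a set.
// Each node on the search path recomputes its size after the recursive call returns.
void insert(std::unique_ptr<Node>& root, int key) {
    if (!root) { root = std::make_unique<Node>(key); return; }
    if (key < root->key) insert(root->left, key);
    else if (root->key < key) insert(root->right, key);
    else return;  // key already present, no sizes change
    recompute(*root);
}
```

A balanced tree would also have to call `recompute` on the nodes touched by each rotation, which is what the PBDS node update policy hooks into.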

With this extra data in each node, we can implement two new methods: (1) find_by_order and (2) order_of_key. find_by_order takes a nonnegative integer as an argument and returns an iterator to the element at that index, where our data is sorted and we use $0$-based indexing.

```cpp
find_by_order(size_type order) {
  node_iterator it = node_begin();
  node_iterator end_it = node_end();

  while (it != end_it) {
    node_iterator l_it = it.get_l_child();
    const size_type o = (l_it == end_it)? 0 : l_it.get_metadata();

    if (order == o) {
      return *it;
    } else if (order < o) {
      it = l_it;
    } else {
      order -= o + 1;
      it = it.get_r_child();
    }
  }

  return base_type::end_iterator();
}
```


It works like this, walking down from the root. Call the index we're trying to find $k$. Let $l$ be the number of nodes in the left subtree.

• $k = l$: If you're looking for the $k$th-indexed element, there are exactly $k$ nodes smaller than it, so if the left subtree has $k$ nodes, the current node is the answer.
• $k < l$: The $k$th-indexed element is in the left subtree, so replace the root with the left child.
• $k > l$: The $k$th-indexed element is in the right subtree. It's equivalent to looking for the $(k - l - 1)$th-indexed element in the right subtree: we subtract away all the nodes in the left subtree and the root, and replace the root with the right child.

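The three cases translate directly into a loop. Below is a minimal sketch on a hypothetical size-augmented, unbalanced BST (not the PBDS tree). `Node`, `insert`, and `findByOrder` are illustrative names, and `findByOrder` returns a pointer to the key, or `nullptr` when $k$ is out of range.

```cpp
#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <memory>

// Hypothetical size-augmented BST node (no balancing, for illustration only).
struct Node {
    int key;
    std::size_t size = 1;  // number of nodes in this subtree
    std::unique_ptr<Node> left, right;
    explicit Node(int k) : key(k) {}
};

std::size_t sizeOf(const Node* node) { return node ? node->size : 0; }

// Insert with set semantics, keeping every subtree size up to date.
void insert(std::unique_ptr<Node>& root, int key) {
    if (!root) { root = std::make_unique<Node>(key); return; }
    if (key == root->key) return;
    insert(key < root->key ? root->left : root->right, key);
    root->size = 1 + sizeOf(root->left.get()) + sizeOf(root->right.get());
}

// Returns the k-th smallest key (0-based), or nullptr if k >= number of keys.
const int* findByOrder(const Node* node, std::size_t k) {
    while (node) {
        std::size_t l = sizeOf(node->left.get());
        if (k == l) return &node->key;    // exactly k keys to our left
        if (k < l) {
            node = node->left.get();      // answer is in the left subtree
        } else {
            k -= l + 1;                   // skip the left subtree and this node
            node = node->right.get();
        }
    }
    return nullptr;
}
```

For the keys 8, 3, 10, 1, 6, 14, 4, 7 (sorted: 1 3 4 6 7 8 10 14), index 3 is 6 and index 7 is 14.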
order_of_key takes a key of whatever type is stored in the nodes as its argument. Since keys are comparable, it returns the number of elements strictly less than the argument, which is also the index of the smallest element that is greater than or equal to the argument (its lower bound).

```cpp
order_of_key(key_const_reference r_key) const {
  node_const_iterator it = node_begin();
  node_const_iterator end_it = node_end();

  const cmp_fn& r_cmp_fn = const_cast<PB_DS_CLASS_C_DEC*>(this)->get_cmp_fn();
  size_type ord = 0;
  while (it != end_it) {
    node_const_iterator l_it = it.get_l_child();

    if (r_cmp_fn(r_key, this->extract_key(*(*it)))) {
      it = l_it;
    } else if (r_cmp_fn(this->extract_key(*(*it)), r_key)) {
      ord += (l_it == end_it)? 1 : 1 + l_it.get_metadata();
      it = it.get_r_child();
    } else {
      ord += (l_it == end_it)? 0 : l_it.get_metadata();
      it = end_it;
    }
  }
  return ord;
}
```


This is a simple tree traversal, where we keep track of the order as we go down the tree. Every time we take the right branch, we add $1$ for every node in the left subtree plus $1$ for the current node. If we find a node that is equal to our key, we add $1$ for every node in its left subtree and stop.

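The same traversal on a hypothetical size-augmented BST looks like this. It is a sketch, not the PBDS code: `Node`, `insert`, and `orderOfKey` are illustrative names, and keys are plain `int`s compared with `<`.

```cpp
#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <memory>

// Hypothetical size-augmented BST node (no balancing, for illustration only).
struct Node {
    int key;
    std::size_t size = 1;  // number of nodes in this subtree
    std::unique_ptr<Node> left, right;
    explicit Node(int k) : key(k) {}
};

std::size_t sizeOf(const Node* node) { return node ? node->size : 0; }

// Insert with set semantics, keeping every subtree size up to date.
void insert(std::unique_ptr<Node>& root, int key) {
    if (!root) { root = std::make_unique<Node>(key); return; }
    if (key == root->key) return;
    insert(key < root->key ? root->left : root->right, key);
    root->size = 1 + sizeOf(root->left.get()) + sizeOf(root->right.get());
}

// Returns the number of stored keys strictly less than `key`.
std::size_t orderOfKey(const Node* node, int key) {
    std::size_t ord = 0;
    while (node) {
        if (key < node->key) {
            node = node->left.get();
        } else if (node->key < key) {
            ord += 1 + sizeOf(node->left.get());  // left subtree plus this node
            node = node->right.get();
        } else {
            ord += sizeOf(node->left.get());      // exact match: only the left subtree
            break;
        }
    }
    return ord;
}
```

With the keys 8, 3, 10, 1, 6, 14, 4, 7, both `orderOfKey(root, 6)` and `orderOfKey(root, 5)` count the three keys 1, 3, 4.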
While not entirely trivial, one could write this code during a contest. But what happens when we need a balanced tree? Both Java's TreeSet and the common C++ implementations of set use a red-black tree, but their APIs are such that the trees are not easily extensible. Here's where Policy-Based Data Structures come into play. They have a mechanism to create a node update policy, so we can keep track of metadata like the number of nodes in a subtree. Conveniently, tree_order_statistics_node_update has been written for us. Now, our problem can be solved quite easily. I have to make some adjustments because find_by_order is $0$-indexed while the K queries are $1$-indexed. Here's the code.

```cpp
#include <functional>
#include <iostream>

#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

using namespace std;

namespace phillypham {
template<typename T,
         typename cmp_fn = less<T>>
using order_statistic_tree =
  __gnu_pbds::tree<T,
                   __gnu_pbds::null_type,
                   cmp_fn,
                   __gnu_pbds::rb_tree_tag,
                   __gnu_pbds::tree_order_statistics_node_update>;
}

int main(int argc, char *argv[]) {
  ios::sync_with_stdio(false); cin.tie(NULL);

  int Q; cin >> Q; // number of queries

  phillypham::order_statistic_tree<int> orderStatisticTree;
  for (int q = 0; q < Q; ++q) {
    char operation;
    int parameter;
    cin >> operation >> parameter;
    switch (operation) {
      case 'I':
        orderStatisticTree.insert(parameter);
        break;
      case 'D':
        orderStatisticTree.erase(parameter);
        break;
      case 'K':
        if (1 <= parameter && parameter <= orderStatisticTree.size()) {
          cout << *orderStatisticTree.find_by_order(parameter - 1) << '\n';
        } else {
          cout << "invalid\n";
        }
        break;
      case 'C':
        cout << orderStatisticTree.order_of_key(parameter) << '\n';
        break;
    }
  }
  cout << flush;
  return 0;
}
```

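Because the tree uses `null_type` as its mapped type, it behaves like `std::set`: inserting a duplicate is a no-op. The sketch below shows `find_by_order` and `order_of_key` on a small input, plus a common workaround for duplicates, which stores `(value, unique id)` pairs. The alias names `ordered_set` and `ordered_multiset` and the helper `countLessThan` are hypothetical. Only the `__gnu_pbds` types and the two member functions come from the library.

```cpp
#include <cassert>
#include <climits>
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <utility>

#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

// Hypothetical alias; same template arguments as phillypham::order_statistic_tree.
template <typename T, typename Cmp = std::less<T>>
using ordered_set = __gnu_pbds::tree<T, __gnu_pbds::null_type, Cmp,
                                     __gnu_pbds::rb_tree_tag,
                                     __gnu_pbds::tree_order_statistics_node_update>;

// A set drops duplicates, so tag each value with a unique id to keep them all.
// This is a usage convention, not part of the PBDS API.
using ordered_multiset = ordered_set<std::pair<int, int>>;

// Number of stored values strictly less than x: (x, INT_MIN) sorts before every (x, id).
std::size_t countLessThan(const ordered_multiset& ms, int x) {
    return ms.order_of_key(std::make_pair(x, INT_MIN));
}
```

`find_by_order` returns `end()` when the index is out of range, which is why the ORDERSET solution checks the bounds first.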

## Dijkstra's algorithm and Priority Queues

Consider the problem SHPATH. Shortest path means Dijkstra's algorithm, of course. Optimal versions of Dijkstra's algorithm call for exotic data structures like Fibonacci heaps, which let us achieve a running time of $O(E + V\log V)$, where $E$ is the number of edges and $V$ is the number of vertices. Even the fairly basic implementation in the classic CLRS needs more than what the standard priority queues in Java and C++ offer, namely a way to decrease a key that is already in the queue. Either we implement our own priority queue, or we use a slow $O(V^2)$ version of Dijkstra's algorithm.

Thanks to policy-based data structures, it's easy to use a fancy heap for our priority queue.

```cpp
#include <algorithm>
#include <climits>
#include <functional>
#include <iostream>
#include <stdexcept>  // out_of_range, logic_error
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <ext/pb_ds/priority_queue.hpp>

using namespace std;

namespace phillypham {
template<typename T,
         typename cmp_fn = less<T>>  // max queue by default
class priority_queue {
 private:
  // Entries are (key, value) pairs; only the value determines priority.
  struct pq_cmp_fn {
    bool operator()(const pair<size_t, T> &a, const pair<size_t, T> &b) const {
      return cmp_fn()(a.second, b.second);
    }
  };
  typedef typename __gnu_pbds::priority_queue<pair<size_t, T>,
                                              pq_cmp_fn,
                                              __gnu_pbds::pairing_heap_tag> pq_t;
  typedef typename pq_t::point_iterator pq_iterator;
  pq_t pq;
  vector<pq_iterator> map;  // map[key] points to key's entry in pq, or nullptr

 public:
  class entry {
   private:
    size_t _key;
    T _value;

   public:
    entry(size_t key, T value) : _key(key), _value(value) {}

    size_t key() const { return _key; }

    T value() const { return _value; }
  };

  priority_queue() {}

  priority_queue(int N) : map(N, nullptr) {}

  size_t size() const {
    return pq.size();
  }

  size_t capacity() const {
    return map.size();
  }

  bool empty() const {
    return pq.empty();
  }

  /**
   * Usually, in C++ operator[] returns a reference that you can modify.
   * I choose not to allow this, however, because changing a value in place
   * would bypass the heap and break its ordering. Use modify() instead.
   */
  T operator[](size_t key) const {
    return map[key] -> second;
  }

  T at(size_t key) const {
    if (map.at(key) == nullptr) throw out_of_range("Key does not exist!");
    return map.at(key) -> second;
  }

  entry top() const {
    return entry(pq.top().first, pq.top().second);
  }

  int count(size_t key) const {
    // size_t is unsigned, so key can never be negative
    if (key >= map.size() || map[key] == nullptr) return 0;
    return 1;
  }

  pq_iterator push(size_t key, T value) {
    // could be really inefficient if there's a lot of resizing going on
    if (key >= map.size()) map.resize(key + 1, nullptr);
    if (map[key] != nullptr) throw logic_error("There can only be 1 value per key!");
    map[key] = pq.push(make_pair(key, value));
    return map[key];
  }

  void modify(size_t key, T value) {
    pq.modify(map[key], make_pair(key, value));
  }

  void pop() {
    if (empty()) throw logic_error("The priority queue is empty!");
    map[pq.top().first] = nullptr;
    pq.pop();
  }

  void erase(size_t key) {
    if (map[key] == nullptr) throw out_of_range("Key does not exist!");
    pq.erase(map[key]);
    map[key] = nullptr;
  }

  void clear() {
    pq.clear();
    fill(map.begin(), map.end(), nullptr);
  }
};
}
```


By replacing __gnu_pbds::pairing_heap_tag with __gnu_pbds::binomial_heap_tag, __gnu_pbds::rc_binomial_heap_tag, or __gnu_pbds::thin_heap_tag, we can easily try different types of heaps. See the priority_queue interface. Unfortunately, we cannot try the binary heap, because modifying the heap can invalidate its point iterators. Conveniently enough, the library lets us check this condition dynamically:

```cpp
#include <functional>
#include <iostream>
#include <typeinfo>  // typeid

#include <ext/pb_ds/priority_queue.hpp>

using namespace std;

int main(int argc, char *argv[]) {
  __gnu_pbds::priority_queue<int, less<int>, __gnu_pbds::binary_heap_tag> pq;
  cout << (typeid(__gnu_pbds::container_traits<decltype(pq)>::invalidation_guarantee) ==
           typeid(__gnu_pbds::basic_invalidation_guarantee)) << endl;
  // prints 1
  cout << (typeid(__gnu_pbds::container_traits<__gnu_pbds::priority_queue<int, less<int>, __gnu_pbds::binary_heap_tag>>::invalidation_guarantee) ==
           typeid(__gnu_pbds::basic_invalidation_guarantee)) << endl;
  // prints 1
  return 0;
}
```


See the documentation for basic_invalidation_guarantee. We need at least point_invalidation_guarantee for the below code to work since we keep a vector of iterators in our phillypham::priority_queue.
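The same condition can also be checked at compile time. PBDS models the guarantees as a class hierarchy (range_invalidation_guarantee derives from point_invalidation_guarantee, which derives from basic_invalidation_guarantee), so `std::is_base_of` asks whether a heap offers at least the point guarantee. The helper names `keepsPointIterators` and `int_heap` are hypothetical; `container_traits` and the tags are the library's.

```cpp
#include <cassert>
#include <functional>
#include <type_traits>

#include <ext/pb_ds/priority_queue.hpp>

// True when Heap's point iterators stay valid across operations that do not
// erase the element they point to (point_invalidation_guarantee or stronger).
template <typename Heap>
constexpr bool keepsPointIterators() {
    return std::is_base_of<
        __gnu_pbds::point_invalidation_guarantee,
        typename __gnu_pbds::container_traits<Heap>::invalidation_guarantee>::value;
}

// Hypothetical shorthand for an int max-heap with a given PBDS heap tag.
template <typename Tag>
using int_heap = __gnu_pbds::priority_queue<int, std::less<int>, Tag>;
```

A `static_assert(keepsPointIterators<...>(), "...")` next to the `pq_t` typedef in phillypham::priority_queue would turn swapping in the binary heap into a compile error instead of a runtime bug.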

```cpp
vector<int> findShortestDistance(const vector<vector<pair<int, int>>> &adjacencyList,
                                 int sourceIdx) {
  const int N = adjacencyList.size();  // number of vertices
  phillypham::priority_queue<int, greater<int>> minDistancePriorityQueue(N);
  for (int i = 0; i < N; ++i) {
    minDistancePriorityQueue.push(i, i == sourceIdx ? 0 : INT_MAX);
  }
  vector<int> distances(N, INT_MAX);
  while (!minDistancePriorityQueue.empty()) {
    phillypham::priority_queue<int, greater<int>>::entry minDistanceVertex =
      minDistancePriorityQueue.top();
    minDistancePriorityQueue.pop();
    // All remaining vertices are unreachable; stop before INT_MAX + cost overflows.
    if (minDistanceVertex.value() == INT_MAX) break;
    distances[minDistanceVertex.key()] = minDistanceVertex.value();
    for (pair<int, int> nextVertex : adjacencyList[minDistanceVertex.key()]) {
      int newDistance = minDistanceVertex.value() + nextVertex.second;
      if (minDistancePriorityQueue.count(nextVertex.first) &&
          minDistancePriorityQueue[nextVertex.first] > newDistance) {
        minDistancePriorityQueue.modify(nextVertex.first, newDistance);
      }
    }
  }
  return distances;
}
```


Fear not: I ended up using the binary heap that I wrote for Dijkstra, Paths, Hashing, and the Chinese Remainder Theorem. Now we can benchmark all these different implementations against each other with the following driver.

```cpp
int main(int argc, char *argv[]) {
  ios::sync_with_stdio(false); cin.tie(NULL);
  int T; cin >> T;  // number of tests
  for (int t = 0; t < T; ++t) {
    int N; cin >> N;  // number of nodes
    unordered_map<string, int> cityIdx;
    vector<vector<pair<int, int>>> adjacencyList(N);  // (neighbor, cost) pairs
    for (int i = 0; i < N; ++i) {
      string city;
      cin >> city;
      cityIdx[city] = i;
      int M; cin >> M;  // number of neighbors
      for (int j = 0; j < M; ++j) {
        int neighborIdx, cost;
        cin >> neighborIdx >> cost;
        --neighborIdx;  // convert to 0-based indexing
        adjacencyList[i].emplace_back(neighborIdx, cost);
      }
    }
    // compute output
    int R; cin >> R;  // number of subtests
    for (int r = 0; r < R; ++r) {
      string sourceCity, targetCity;
      cin >> sourceCity >> targetCity;
      int sourceIdx = cityIdx[sourceCity];
      int targetIdx = cityIdx[targetCity];
      vector<int> distances = findShortestDistance(adjacencyList, sourceIdx);
      cout << distances[targetIdx] << '\n';
    }
  }
  cout << flush;
  return 0;
}
```


I found that all of the policy-based data structures are faster than my own hand-written binary heap:

| Algorithm | Time (seconds) |
|---|---|
| PBDS Pairing Heap, Lazy Push | 0.41 |
| PBDS Pairing Heap | 0.44 |
| PBDS Binomial Heap | 0.48 |
| PBDS Thin Heap | 0.54 |
| PBDS RC Binomial Heap | 0.60 |
| Personal Binary Heap | 0.72 |

Lazy push is a small optimization, where we add vertices to the heap only as we encounter them instead of pushing every vertex up front. We save a few hundredths of a second at the expense of increased code complexity.

```cpp
vector<int> findShortestDistance(const vector<vector<pair<int, int>>> &adjacencyList,
                                 int sourceIdx) {
  const int N = adjacencyList.size();  // number of vertices
  vector<int> distances(N, INT_MAX);
  phillypham::priority_queue<int, greater<int>> minDistancePriorityQueue(N);
  minDistancePriorityQueue.push(sourceIdx, 0);
  while (!minDistancePriorityQueue.empty()) {
    phillypham::priority_queue<int, greater<int>>::entry minDistanceVertex =
      minDistancePriorityQueue.top();
    minDistancePriorityQueue.pop();
    distances[minDistanceVertex.key()] = minDistanceVertex.value();
    for (pair<int, int> nextVertex : adjacencyList[minDistanceVertex.key()]) {
      int newDistance = minDistanceVertex.value() + nextVertex.second;
      if (distances[nextVertex.first] == INT_MAX) {
        minDistancePriorityQueue.push(nextVertex.first, newDistance);
        distances[nextVertex.first] = newDistance;
      } else if (minDistancePriorityQueue.count(nextVertex.first) &&
                 minDistancePriorityQueue[nextVertex.first] > newDistance) {
        minDistancePriorityQueue.modify(nextVertex.first, newDistance);
        distances[nextVertex.first] = newDistance;
      }
    }
  }
  return distances;
}
```

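For readers who don't need the wrapper class, here is a sketch of the same lazy-push idea that uses __gnu_pbds::priority_queue directly. It keeps the point_iterator returned by `push` and calls `modify` for decrease-key. Heap entries are `(distance, vertex)` pairs with `std::greater`, which makes the heap a min-heap. The function name `dijkstra` and the `handle`/`inQueue`/`done` arrays are our own choices and are not taken from the benchmarked code. Edge costs are assumed nonnegative.

```cpp
#include <cassert>
#include <climits>
#include <functional>
#include <utility>
#include <vector>

#include <ext/pb_ds/priority_queue.hpp>

// Dijkstra with decrease-key via PBDS point iterators.
// adj[u] holds (neighbor, cost) pairs; unreachable vertices keep INT_MAX.
std::vector<int> dijkstra(const std::vector<std::vector<std::pair<int, int>>>& adj,
                          int source) {
    using Entry = std::pair<int, int>;  // (distance, vertex)
    using Heap = __gnu_pbds::priority_queue<Entry, std::greater<Entry>,
                                            __gnu_pbds::pairing_heap_tag>;
    const int n = static_cast<int>(adj.size());
    std::vector<int> dist(n, INT_MAX);
    std::vector<Heap::point_iterator> handle(n);  // meaningful only while inQueue[v]
    std::vector<bool> inQueue(n, false), done(n, false);
    Heap heap;
    dist[source] = 0;
    handle[source] = heap.push(Entry(0, source));
    inQueue[source] = true;
    while (!heap.empty()) {
        const int u = heap.top().second;  // copy before pop
        heap.pop();
        inQueue[u] = false;
        done[u] = true;
        for (const auto& edge : adj[u]) {
            const int v = edge.first;
            const int candidate = dist[u] + edge.second;
            if (done[v] || candidate >= dist[v]) continue;
            dist[v] = candidate;
            if (inQueue[v]) {
                heap.modify(handle[v], Entry(candidate, v));  // decrease-key
            } else {
                handle[v] = heap.push(Entry(candidate, v));   // lazy push
                inQueue[v] = true;
            }
        }
    }
    return dist;
}
```

Storing `handle[v]` across later pushes and pops is only safe because the pairing heap provides point_invalidation_guarantee, which is the same reason the binary heap tag cannot be swapped in here.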

All in all, I found learning to use these data structures quite fun. It's nice to have such easy access to powerful data structures. I also learned a lot about C++ templating on the way.

# Segmented Sieve of Eratosthenes

A classic programming problem is to find all the primes up to a certain number. This problem admits a classic solution, the sieve of Eratosthenes. Here it is in Java.

```java
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Deque;

/**
 * @param upper bound exclusive
 * @return a list of primes strictly less than upper
 */
public static Deque<Integer> findPrimes(int upper) {
    Deque<Integer> primes = new ArrayDeque<Integer>();
    if (upper <= 2) return primes;
    primes.add(2);
    boolean[] isPrime = new boolean[(upper-2)/2]; // isPrime[i] refers to 2*i + 3, so index 0 is 3
    Arrays.fill(isPrime, true);
    for (int p = 3; p < upper; p += 2) {
        if (isPrime[p/2 - 1]) {
            primes.add(p);
            // only need to start from p^2 since we already crossed off p*m, where m < p
            for (long q = ((long) p)*((long) p); q < upper; q += 2*p) {
                isPrime[((int) q)/2 - 1] = false;
            }
        }
    }
    return primes;
}
```

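For comparison with the C++ in the first half of this post, here is a sketch of the same odd-only sieve in C++. It is a port, not code from the original post; it uses the same indexing, so slot `i` stands for the odd number `2*i + 3`, and it stores "composite" flags instead of "prime" flags so that `std::vector<bool>` can start out all false.

```cpp
#include <cassert>
#include <vector>

// Odd-only sieve of Eratosthenes: all primes strictly less than upper.
// isComposite[i] refers to the odd number 2*i + 3.
std::vector<int> findPrimes(int upper) {
    std::vector<int> primes;
    if (upper <= 2) return primes;
    primes.push_back(2);
    std::vector<bool> isComposite((upper - 2) / 2, false);
    for (int p = 3; p < upper; p += 2) {
        if (isComposite[p / 2 - 1]) continue;
        primes.push_back(p);
        // Multiples p*m with m < p already have a smaller prime factor, so start at p*p.
        // Step by 2p to visit only odd multiples; long long avoids overflow in p*p.
        for (long long q = 1LL * p * p; q < upper; q += 2LL * p) {
            isComposite[q / 2 - 1] = true;
        }
    }
    return primes;
}
```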

The problem is the isPrime array: this solution uses $O(n)$ space and $O(n \log \log n)$ time. We may only be interested in finding large primes, say from 999,900,000 to 1,000,000,000, as in the problem PRIME1. It doesn't make sense to check numbers less than 999,900,000 or to allocate space for them.

Hence, we use a segmented sieve. The underlying idea is that to check if a number $P$ is prime by trial division, we only need to check that it is not divisible by any prime numbers $q \leq \sqrt{P}$. Thus, if we want to find all the primes between $m$ and $n$, we first generate all the primes that are less than or equal to $\sqrt{n}$ with the traditional sieve. Let $S$ be the set of those primes.

Then, let $L$ be some constant number. We work in segments $[m, m + L)$, $[m + L, m + 2L)$, $\ldots$, $[m + kL, n + 1)$. In each of these segments, we identify all the multiples of the primes in $S$ and mark them as not prime. Now we only need $O(\max(|S|, L))$ space, and the computation is $$O\left(|S| \cdot \frac{n-m}{L} + (n-m)\right),$$ and we can set $L$ to be as large or small as we want.

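Here is a sketch of that segment-by-segment procedure in C++. It is deliberately simpler than the Java code that follows: it does not skip even numbers, and it uses closed segments `[start, end]`. The name `segmentedPrimes` and its `(lo, hi, L)` parameters are our own; `L` must be positive. The sieving primes $S$ come from a plain sieve up to $\lfloor\sqrt{hi}\rfloor$. Each segment starts crossing off at the first multiple of $p$ that is at least $\max(p^2, \text{start})$, so a sieving prime that falls inside the segment is never marked composite.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Segmented sieve: all primes in [lo, hi], processed in segments of length L (L > 0).
std::vector<long long> segmentedPrimes(long long lo, long long hi, long long L) {
    std::vector<long long> result;
    if (hi < 2) return result;
    lo = std::max(lo, 2LL);
    // root = floor(sqrt(hi)), corrected for floating-point error.
    long long root = static_cast<long long>(std::sqrt(static_cast<double>(hi)));
    while (root * root > hi) --root;
    while ((root + 1) * (root + 1) <= hi) ++root;
    // S = primes <= root, from an ordinary sieve.
    std::vector<bool> small(root + 1, true);
    std::vector<long long> S;
    for (long long p = 2; p <= root; ++p) {
        if (!small[p]) continue;
        S.push_back(p);
        for (long long q = p * p; q <= root; q += p) small[q] = false;
    }
    std::vector<bool> isPrime(L);  // isPrime[i] refers to start + i
    for (long long start = lo; start <= hi; start += L) {
        long long end = std::min(start + L - 1, hi);  // current segment [start, end]
        std::fill(isPrime.begin(), isPrime.end(), true);
        for (long long p : S) {
            // First multiple of p in the segment, but never below p*p.
            long long first = std::max(p * p, (start + p - 1) / p * p);
            for (long long q = first; q <= end; q += p) isPrime[q - start] = false;
        }
        for (long long x = start; x <= end; ++x)
            if (isPrime[x - start]) result.push_back(x);
    }
    return result;
}
```

Only `S` and one segment of length `L` are alive at any time, which matches the $O(\max(|S|, L))$ space bound.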
By the prime number theorem, $|S|$ is typically not very large. Asymptotically, $$|S| = \pi(\sqrt{n}) \sim \frac{\sqrt{n}}{\log \sqrt{n}}.$$ For $L$, we have a tradeoff. If $L$ is large, we may need a lot of space. If $L$ is too small, each segment is short and may not contain many multiples of the primes in $S$, yet we still loop over all of $S$ for every segment, which results in wasted computation. Here is the code with some tweaks to avoid even numbers.

/**
* Find primes in range
* @param lower bound, inclusive
* @param upper bound exclusive
* @param sieveSize space to use
* @return list of primes in range
*/
public static Deque<Integer> findPrimes(int lower, int upper, int sieveSize) {
if (lower >= upper) throw new IllegalArgumentException("lower must be less than upper");
int sievingPrimesUpper = (int) Math.sqrt(upper);
if (lower <= sievingPrimesUpper || sievingPrimesUpper <= 2) {
Deque<Integer> primes = findPrimes(upper);
while (!primes.isEmpty() && primes.peekFirst() < lower) primes.removeFirst(); // peekFirst() is null once empty
return primes;
}
if (sieveSize < 5) sieveSize = 10;
Deque<Integer> primes = new ArrayDeque<Integer>();
Deque<Integer> sievingPrimes = findPrimes(sievingPrimesUpper + 1);
sievingPrimes.removeFirst(); // get rid of 2
while (!sievingPrimes.isEmpty() &&
sievingPrimes.getLast()*sievingPrimes.getLast() >= upper) sievingPrimes.removeLast();
if (lower % 2 == 0) lower += 1; // make lower odd
boolean[] isPrime = new boolean[sieveSize]; // isPrime[i] refers to lower + 2*i
/**
* Find first odd multiple for each sieving prime. lower + 2*nextMultipleOffset[i]
* will be the first odd multiple of sievingPrimes[i] that is greater than or
* equal to lower.
*/
int[] nextMultipleOffset = new int[sievingPrimes.size()];
int idx = 0;
for (int p : sievingPrimes) {
int nextMultiple = lower - (lower % p); // make it a multiple of p
if (nextMultiple < lower)  nextMultiple += p; // make it at least lower
if (nextMultiple % 2 == 0) nextMultiple += p; // make sure it's odd
nextMultipleOffset[idx++] = (nextMultiple - lower)/2;
}
while (lower < upper) {
Arrays.fill(isPrime, true);
idx = 0;
for (int p : sievingPrimes) {
int offset = nextMultipleOffset[idx];
for (int j = offset; j < sieveSize; j += p) isPrime[j] = false;
/**
* We want (lower + 2*sieveSize + 2*(nextMultipleOffset[idx] + k)) % p == 0
* and (lower + 2*sieveSize + 2*(nextMultipleOffset[idx] + k)) % 2 == 1,
* where k is the correction term. Second equation is always true.
* First reduces to 2*(sieveSize + k) % p == 0 ==> (sieveSize + k) % p == 0
* since 2 must be invertible in the field F_p. Thus, we have that
* k % p = (-sieveSize) % p. Then, we make sure that the offset is nonnegative.
*/
nextMultipleOffset[idx] = (nextMultipleOffset[idx] - sieveSize) % p;
if (nextMultipleOffset[idx] < 0) nextMultipleOffset[idx] += p;
++idx;
}
for (int i = 0; i < sieveSize; ++i) {
int newPrime = lower + i*2;
if (newPrime >= upper) break;