
Because I never took any computer science courses beyond the freshman level, I have some glaring holes in my computer science knowledge. According to Customs and Border Protection, every self-respecting software engineer should be able to balance a binary search tree.

I used a library-provided red-black tree in my earlier post, Policy-Based Data Structures in C++, so I've known that these trees exist for quite some time. However, I've never rolled my own. I set out to fix this deficiency in my knowledge a few weeks ago.

I've created an AVL tree augmented with subtree sizes to re-solve the SPOJ problem ORDERSET.

The Problem

The issue with vanilla binary search trees is that many of the operations have worst-case $O(N)$ complexity. To see this, suppose your data is sorted, so you insert it in ascending order. Every new key becomes the right child of the previous one, and the binary search tree degenerates into a linked list of height $N - 1$.

Definition of an AVL Tree

AVL trees solve this issue by maintaining the invariant that, at every node, the heights of the left and right subtrees differ by at most 1. The height of a tree is the maximum number of edges between the root and a leaf node, so a single node has height $0$; by convention, an empty tree has height $-1$. For any node $x$, we denote its height $h_x$. Now, consider a node $x$ with left child $l$ and right child $r$. We define the balance factor of $x$ to be $$b_x = h_l - h_r.$$ In an AVL tree, $b_x \in \{-1,0,1\}$ for all $x$.

Maintaining the Invariant

Insertion

As we insert new values into the tree, it may become unbalanced. Since each insertion creates only one node, it can increase the height of any subtree by at most 1, so if the tree becomes unbalanced, some node has a balance factor of $\pm 2$. Without loss of generality, let us assume that it's $-2$, since the two cases are symmetric.

Now, the only nodes whose heights are affected are those on the path between the root and the new node. A node's balance factor depends only on its children's heights, so only the parents of these nodes, which also lie on this path, can have their balance factors altered. Thus, we should start at the deepest node with an incorrect balance factor. If we can find a way to rebalance this subtree, we can rebalance the whole tree by walking up toward the root and fixing each node along the way.

So, let us assume we have a tree whose root has a balance factor of $-2$; that is, the height of the right subtree exceeds the height of the left subtree by $2$. There are two cases, depending on the balance factor of the right child.

Right-left Rotation: The Right Child Has a Balance Factor of $1$

In these diagrams, circles denote nodes, and rectangles denote subtrees.

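In this example, we first rotate the right subtree so that $5$ becomes the new right child. We still have an unbalanced tree, since the root has balance factor $-2$, but now the right child has a right subtree of greater height, so we can finish with the right-right rotation described next.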

Right-right Rotation: The Right Child Has a Balance Factor of $-2$, $-1$, or $0$

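To fix this situation, the right child becomes the root: the old root becomes its left child, and the right child's former left subtree becomes the old root's new right subtree. Every key in that subtree lies between the old root and the right child, so the binary search tree ordering is preserved.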

We see that after this rotation is finished, we have a balanced tree, so our invariant is restored!

Deletion

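When we remove a node with two children, we need to replace it with its in-order successor, the next largest node in the tree. (A node with at most one child is simply replaced by that child.) Here's how to find the successor: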

  1. Go right. Now all nodes in this subtree are greater.
  2. Find the smallest node in this subtree by going left until you can't.

We replace the deleted node, say $x$, with the smallest node we found in its right subtree, say $y$, and remember the path from $x$ to $y$. Since $y$ has no left child, the right subtree of $y$ becomes the new left subtree of the parent of $y$ (unless that parent is $x$ itself, in which case $y$ simply keeps its right subtree). Balance factors may have changed starting from the parent of $y$, so we start there, walk back up to the root, and perform any rotations needed to correct the balance factors along the way.

Implementation

Here is the code for these rotations with templates for reusability and a partial implementation of the STL interface.

#include <cassert>
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <functional>
#include <iterator>
#include <ostream>
#include <stack>
#include <string>
#include <utility>

using namespace std;

namespace phillypham {
  namespace avl {
    template <typename T>
    struct node {
      T key;
      node<T>* left;
      node<T>* right;
      size_t subtreeSize;
      size_t height;
    };

    template <typename T>
    int height(node<T>* root) {
      if (root == nullptr) return -1;
      return root -> height;
    }

    template <typename T>
    void recalculateHeight(node<T> *root) {
      if (root == nullptr) return;
      root -> height = max(height(root -> left), height(root -> right)) + 1;
    }

    template <typename T>
    size_t subtreeSize(node<T>* root) {
      if (root == nullptr) return 0;
      return root -> subtreeSize;
    }

    template <typename T>
    void recalculateSubtreeSize(node<T> *root) {
      if (root == nullptr) return;
      root -> subtreeSize = subtreeSize(root -> left) + subtreeSize(root -> right) + 1;
    }

    template <typename T>
    void recalculate(node<T> *root) {
      recalculateHeight(root);
      recalculateSubtreeSize(root);
    }

    template <typename T>
    int balanceFactor(node<T>* root) {
      if (root == nullptr) return 0;
      return height(root -> left) - height(root -> right);
    }

    template <typename T>
    node<T>*& getLeftRef(node<T> *root) {
      return root -> left;    
    }

    template <typename T>
    node<T>*& getRightRef(node<T> *root) {
      return root -> right;
    }

    template <typename T>
    node<T>* rotateSimple(node<T> *root,
                          node<T>*& (*newRootGetter)(node<T>*),
                          node<T>*& (*replacedChildGetter)(node<T>*)) {
      node<T>* newRoot = newRootGetter(root);
      newRootGetter(root) = replacedChildGetter(newRoot);
      replacedChildGetter(newRoot) = root;
      recalculate(replacedChildGetter(newRoot));
      recalculate(newRoot);
      return newRoot;
    }

    template <typename T>
    void swapChildren(node<T> *root,
                      node<T>*& (*childGetter)(node<T>*),
                      node<T>*& (*grandChildGetter)(node<T>*)) {
      node<T>* newChild = grandChildGetter(childGetter(root));
      grandChildGetter(childGetter(root)) = childGetter(newChild);
      childGetter(newChild) = childGetter(root);
      childGetter(root) = newChild;
      recalculate(childGetter(newChild));
      recalculate(newChild);
    }

    template <typename T>
    node<T>* rotateRightRight(node<T>* root) {
      return rotateSimple(root, getRightRef, getLeftRef);
    }

    template <typename T>
    node<T>* rotateLeftLeft(node<T>* root) {
      return rotateSimple(root, getLeftRef, getRightRef);    
    }

    template <typename T>
    node<T>* rotate(node<T>* root)  {
      int bF = balanceFactor(root);
      if (-1 <= bF && bF <= 1) return root;
      if (bF < -1) { // right side is too heavy
        assert(root -> right != nullptr);
        if (balanceFactor(root -> right) != 1) {
          return rotateRightRight(root);
        } else { // right left case
          swapChildren(root, getRightRef, getLeftRef);
          return rotate(root);
        }
      } else { // left side is too heavy
        assert(root -> left != nullptr);
        // left left case
        if (balanceFactor(root -> left) != -1) {
          return rotateLeftLeft(root);
        } else { // left right case
          swapChildren(root, getLeftRef, getRightRef);
          return rotate(root);
        }
      }
    }

    template <typename T, typename cmp_fn = less<T>>
    node<T>* insert(node<T>* root, T key, const cmp_fn &comparator = cmp_fn()) {
      if (root == nullptr) {
        node<T>* newRoot = new node<T>();
        newRoot -> key = key;
        newRoot -> left = nullptr;
        newRoot -> right = nullptr;
        newRoot -> height = 0;
        newRoot -> subtreeSize = 1;
        return newRoot;
      }
      if (comparator(key, root -> key)) { 
        root -> left = insert(root -> left, key, comparator);
      } else if (comparator(root -> key, key)) {
        root -> right = insert(root -> right, key, comparator);
      }
      recalculate(root);
      return rotate(root);
    }

    template <typename T, typename cmp_fn = less<T>>
    node<T>* erase(node<T>* root, T key, const cmp_fn &comparator = cmp_fn()) {
      if (root == nullptr) return root; // nothing to delete
      if (comparator(key, root -> key)) { 
        root -> left = erase(root -> left, key, comparator);
      } else if (comparator(root -> key, key)) {
        root -> right = erase(root -> right, key, comparator);
      } else { // actual work when key == root -> key
        if (root -> right == nullptr) {
          node<T>* newRoot = root -> left;
          delete root;
          return newRoot;
        } else if (root -> left == nullptr) {
          node<T>* newRoot = root -> right;
          delete root;
          return newRoot;
        } else {
          stack<node<T>*> path;
          path.push(root -> right);
          while (path.top() -> left != nullptr) path.push(path.top() -> left);
          // swap with root
          node<T>* newRoot = path.top(); path.pop();
          newRoot -> left = root -> left;
          delete root;

          node<T>* currentNode = newRoot -> right;
          while (!path.empty()) {
            path.top() -> left = currentNode;
            currentNode = path.top(); path.pop();
            recalculate(currentNode);
            currentNode = rotate(currentNode);
          }
          newRoot -> right = currentNode;
          recalculate(newRoot);
          return rotate(newRoot);
        }        
      }
      recalculate(root);
      return rotate(root);
    }

    template <typename T>
    stack<node<T>*> find_by_order(node<T>* root, size_t idx) {
      assert(idx < subtreeSize(root)); // idx is a size_t, so it is never negative
      stack<node<T>*> path;
      path.push(root);
      while (idx != subtreeSize(path.top() -> left)) {
        if (idx < subtreeSize(path.top() -> left)) {
          path.push(path.top() -> left);
        } else {
          idx -= subtreeSize(path.top() -> left) + 1;
          path.push(path.top() -> right);
        }
      }
      return path;
    }

    template <typename T>
    size_t order_of_key(node<T>* root, T key) {
      if (root == nullptr) return 0ULL;
      if (key == root -> key) return subtreeSize(root -> left);
      if (key < root -> key) return order_of_key(root -> left, key);
      return subtreeSize(root -> left) + 1ULL + order_of_key(root -> right, key);
    }

    template <typename T>
    void delete_recursive(node<T>* root) {
      if (root == nullptr) return;
      delete_recursive(root -> left);
      delete_recursive(root -> right);
      delete root;
    }
  }

  template <typename T, typename cmp_fn = less<T>>
  class order_statistic_tree {
  private:
    cmp_fn comparator;
    avl::node<T>* root;
  public:
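    class const_iterator {
    public:
      // Iterator traits, declared directly because std::iterator is deprecated in C++17.
      using iterator_category = std::bidirectional_iterator_tag;
      using value_type = T;
      using difference_type = std::ptrdiff_t;
      using pointer = const T*;
      using reference = const T&;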
      friend const_iterator order_statistic_tree<T, cmp_fn>::cbegin() const;
      friend const_iterator order_statistic_tree<T, cmp_fn>::cend() const;
      friend const_iterator order_statistic_tree<T, cmp_fn>::find_by_order(size_t) const;
    private:
      cmp_fn comparator;
      stack<avl::node<T>*> path;
      avl::node<T>* beginNode;
      avl::node<T>* endNode;
      bool isEnded;

      const_iterator(avl::node<T>* root) {
        setBeginAndEnd(root);
        if (root != nullptr) {
          path.push(root);
        }
      }

      const_iterator(avl::node<T>* root, stack<avl::node<T>*>&& path): path(move(path)) {
        setBeginAndEnd(root);
      }

      void setBeginAndEnd(avl::node<T>* root) {
        beginNode = root;
        while (beginNode != nullptr && beginNode -> left != nullptr)
          beginNode = beginNode -> left;
        endNode = root;        
        while (endNode != nullptr && endNode -> right != nullptr)
          endNode = endNode -> right;
        isEnded = (root == nullptr); // always initialize; an empty tree starts at its end
      }

    public:
      bool isBegin() const {
        // An empty path (empty tree) counts as the beginning.
        return path.empty() || path.top() == beginNode;
      }

      bool isEnd() const {
        return isEnded;
      }

      const T& operator*() const {
        return path.top() -> key;
      }

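      bool operator==(const const_iterator &other) const {
        // Iterators over an empty tree have empty paths and compare equal.
        if (path.empty() || other.path.empty()) return path.empty() && other.path.empty();
        if (path.top() == other.path.top()) {
          // The last element and the end iterator share a node; isEnded tells them apart.
          return path.top() != endNode || isEnded == other.isEnded;
        }
        return false;
      }

      bool operator!=(const const_iterator &other) const {
        return !((*this) == other);
      }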

      const_iterator& operator--() {
        if (path.empty()) return *this;
        if (path.top() == beginNode) return *this;
        if (path.top() == endNode && isEnded) {
          isEnded = false;
          return *this;
        }
        if (path.top() -> left == nullptr) {
          T& key = path.top() -> key;
          do {
            path.pop();
          } while (comparator(key, path.top() -> key));
        } else {
          path.push(path.top() -> left);
          while (path.top() -> right != nullptr) {
            path.push(path.top() -> right);
          }
        }
        return *this;
      }

      const_iterator& operator++() {
        if (path.empty()) return *this;
        if (path.top() == endNode) {
          isEnded = true;
          return *this;
        }
        if (path.top() -> right == nullptr) {
          T& key = path.top() -> key;
          do {
            path.pop();
          } while (comparator(path.top() -> key, key));
        } else {
          path.push(path.top() -> right);
          while (path.top() -> left != nullptr) {
            path.push(path.top() -> left);
          }
        }
        return *this;
      }
    };

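    order_statistic_tree(): root(nullptr) {}

    // The tree owns its nodes, so a shallow copy would delete them twice.
    order_statistic_tree(const order_statistic_tree&) = delete;
    order_statistic_tree& operator=(const order_statistic_tree&) = delete;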

    void insert(T newKey) {
      root = avl::insert(root, newKey, comparator);
    }

    void erase(T key) {
      root = avl::erase(root, key, comparator);
    }

    size_t size() const {
      return avl::subtreeSize(root);
    }

    // 0-based indexing
    const_iterator find_by_order(size_t idx) const {
      return const_iterator(root, move(avl::find_by_order(root, idx)));
    }

    // returns the number of keys strictly less than the given key
    size_t order_of_key(T key) const {
      return avl::order_of_key(root, key);
    }

    ~order_statistic_tree() {
      avl::delete_recursive(root);
    }

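    // Builds the root-to-minimum path directly in O(log n) rather than
    // stepping backward from the root one element at a time.
    const_iterator cbegin() const {
      stack<avl::node<T>*> path;
      for (avl::node<T>* n = root; n != nullptr; n = n -> left) path.push(n);
      return const_iterator(root, move(path));
    }

    // The end iterator sits on the maximum node with isEnded set.
    const_iterator cend() const {
      stack<avl::node<T>*> path;
      for (avl::node<T>* n = root; n != nullptr; n = n -> right) path.push(n);
      const_iterator it(root, move(path));
      it.isEnded = true;
      return it;
    }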
  };
}

Solution

With this data structure, the solution is fairly short and is, in fact, identical to my previous solution in Policy-Based Data Structures in C++. Performance-wise, my implementation holds up well: according to SPOJ, its running time differs from that of the policy-based tree set by only a few hundredths of a second.

int main(int argc, char *argv[]) {
  ios::sync_with_stdio(false); cin.tie(NULL);

  phillypham::order_statistic_tree<int> orderStatisticTree;
  int Q; cin >> Q; // number of queries
  for (int q = 0; q < Q; ++q) {
    char operation; 
    int parameter;
    cin >> operation >> parameter;
    switch (operation) {
      case 'I':
        orderStatisticTree.insert(parameter);
        break;
      case 'D':
        orderStatisticTree.erase(parameter);
        break;
      case 'K':
        if (1 <= parameter && static_cast<size_t>(parameter) <= orderStatisticTree.size()) {
          cout << *orderStatisticTree.find_by_order(parameter - 1) << '\n';
        } else {
          cout << "invalid\n";
        }
        break;
      case 'C':
        cout << orderStatisticTree.order_of_key(parameter) << '\n';
        break;
    }
  }
  cout << flush;
  return 0;
}
