// ==================================================================
// btree.h
//	Template-based balanced binary tree class.
//	Copyright (C) 1992 by Nicholas Wilt.  All rights reserved.
// ==================================================================

#ifndef __BTREE__

#define __BTREE__

// BinaryNode: one node of a red-black tree, together with every primitive
// that operates on the tree containing it.  BinaryTree is a thin owning
// wrapper that forwards its operations to these primitives.
template<class T>
class BinaryNode {
protected:
    // Node colors used by the red-black balancing rules.
    enum RedBlack { Black, Red };

public:
    T x;                    // Node contents
    enum RedBlack clr;      // Node color (Red or Black)
    BinaryNode *l;          // Left child, or 0
    BinaryNode *r;          // Right child, or 0
    BinaryNode *p;          // Parent, or 0 at the root

protected:
    // Rebalancing helpers used by insertion and deletion.  Each returns
    // the (possibly new) root of the tree.
    virtual BinaryNode *LeftRotate(BinaryNode *root);
    virtual BinaryNode *RightRotate(BinaryNode *root);
    virtual BinaryNode *DeleteFixup(BinaryNode *x, BinaryNode *p);

public:
    // Constructors.  A new node is always colored red.
    BinaryNode();
    BinaryNode(const T& X, BinaryNode *P = 0,
               BinaryNode *L = 0, BinaryNode *R = 0);
    virtual ~BinaryNode() { }

    // Frees the whole subtree rooted at r, children before parent.
    static void PostOrderDeletion(BinaryNode *r);

    // Deep-copies the subtree rooted at this node; the copy's parent is P.
    virtual BinaryNode *Dup(BinaryNode *P) const;

    // Navigation, O(lg N) on a balanced tree.  Min/Max return a node of
    // this subtree; Pred/Succ/Query return 0 when no such node exists.
    virtual BinaryNode *Min();
    virtual BinaryNode *Max();
    virtual BinaryNode *Pred();
    virtual BinaryNode *Succ();
    virtual BinaryNode *Query(const T& q);

    // Mutation, O(lg N).  Called on the current root; each returns the
    // new root of the tree.
    virtual BinaryNode *InsertNode(BinaryNode *addme);
    virtual BinaryNode *Insert(const T& AddMe);
    virtual BinaryNode *DeleteNode(BinaryNode *z);
    virtual BinaryNode *DeleteItem(const T& q);
    virtual BinaryNode *DeletePassbk(T q, T *passbk);

    // Returns 0 if the red-black invariant holds for this subtree.
    virtual int CheckInvariant(int i, int num);

    // Number of black nodes on the path from this node down to its
    // leftmost descendant, inclusive.
    int BlackToMin();
};


template<class T>
class BinaryTreeIter {
public:
    // Create iterator for tree, initially pointing at root.
    BinaryTreeIter(BinaryTree<T>& tree);

    // Create iterator for tree, initially pointing at the node
    // queried by q.
    BinaryTreeIter(BinaryTree<T>& tree, const T& q);

    // Reset iterator to point to root of given tree.
    void Reset(BinaryTree<T>& tree);

    // Returns pointer to the contents of the current node, or
    // 0 if the current node is 0.
    T *Contents() const;

    // Sets iterator to point to minimum node in the subtree.
    void Min();

    // Sets iterator to point to maximum node in the subtree.
    void Max();

    // Sets iterator to point to the current node's predecessor.
    void Pred();

    // Sets iterator to point to the current node's successor.
    void Succ();

    // Queries subtree for the given key.
    int Query(const T&);

protected:
    BinaryTree<T> *tree;	// Pointer to the tree being scanned

    BinaryNode<T> *subtree;	// Subtree currently being considered
};

// Declare collaborators so this class does not depend on the order in
// which they appear.  Redeclaring an already-defined template is harmless.
template<class T> class BinaryNode;
template<class T> class BinaryTreeIter;

// BinaryTree: owning handle for a red-black tree of T values.
// T needs operator< (two items are "equal" when neither is less than the
// other), copy construction and copy assignment.
template<class T> 
class BinaryTree {

protected:
    BinaryNode<T> *root;	// Root node, or 0 when the tree is empty

public:
    // Default constructor: creates an empty tree.
    BinaryTree();

    // Copy constructor: deep-copies x.
    BinaryTree(const BinaryTree<T>& x);

    // Assignment operator: replaces the contents with a deep copy of x.
    BinaryTree<T>& operator= (const BinaryTree<T>& x);

    // Destructor: frees every node.  Virtual because this class has
    // virtual members and is meant to be derived from (see the protected
    // section), so deleting a derived tree through a BinaryTree<T>* must
    // run the derived destructor.
    virtual ~BinaryTree();

    // Smallest / largest item, or 0 if the tree is empty.
    virtual T *Min() const;
    virtual T *Max() const;

    // Item before / after the item matching q.  Returns 0 if q is not
    // present or has no predecessor / successor.
    virtual T *Pred(const T& q) const;
    virtual T *Succ(const T& q) const;

    // Item matching q, or 0 if not present.
    virtual T *Query(const T& q) const;

    // Adds a copy of addme.  Duplicate keys are allowed.
    virtual void Insert(const T& addme);

    // Removes one item matching q, if any.
    virtual void DeleteItem(const T& q);

    // Like DeleteItem, but first copies the removed item into *passbk.
    // *passbk is left untouched when q is not present.
    virtual void DeletePassbk(const T& q, T *passbk);

    // Nonzero if the tree holds no items.
    virtual int IsEmpty() const;

    // Returns 0 if the red-black invariant holds (an empty tree passes).
    virtual int CheckInvariant();

    // NOTE: the T* results above point into node storage.  A deletion may
    // free a node or overwrite its contents (a removed item's node can
    // receive its successor's item), so treat them as invalidated by any
    // deletion.

// The following are accessible only to classes that inherit
// from BinaryTree, since they deal directly with BinaryNodes.
protected:
    virtual BinaryNode<T> *InsertPassbk(const T& addme);
    virtual BinaryNode<T> *QueryNode(const T& q) const;
    virtual void DeleteNode(BinaryNode<T> *delme);

    // The iterator reads root directly.
    friend class BinaryTreeIter<T>;
};

// Creates a cursor over Tree positioned at its root (0 if Tree is empty).
template<class T>
BinaryTreeIter<T>::BinaryTreeIter(BinaryTree<T>& Tree)
    : tree(&Tree), subtree(Tree.root)
{
}

// Creates a cursor over Tree positioned at the node matching q, or at 0
// when Tree is empty or holds no such item.
template<class T>
BinaryTreeIter<T>::BinaryTreeIter(BinaryTree<T>& Tree, const T& q)
    : tree(&Tree), subtree(Tree.root ? Tree.root->Query(q) : 0)
{
}

// Points the cursor at Tree and positions it at Tree's root.
template<class T>
void
BinaryTreeIter<T>::Reset(BinaryTree<T>& Tree)
{
    subtree = (tree = &Tree)->root;
}

// Returns the address of the current node's contents, or 0 when the
// cursor designates no node.
template<class T>
T *
BinaryTreeIter<T>::Contents() const
{
    if (! subtree)
        return 0;
    return &subtree->x;
}

// Moves to the leftmost node of the current subtree; no-op at 0.
template<class T>
void
BinaryTreeIter<T>::Min()
{
    if (! subtree)
        return;
    subtree = subtree->Min();
}

// Moves to the rightmost node of the current subtree; no-op at 0.
template<class T>
void
BinaryTreeIter<T>::Max()
{
    if (! subtree)
        return;
    subtree = subtree->Max();
}

// Steps back to the in-order predecessor (0 past the first node);
// no-op at 0.
template<class T>
void
BinaryTreeIter<T>::Pred()
{
    if (! subtree)
        return;
    subtree = subtree->Pred();
}

// Steps forward to the in-order successor (0 past the last node);
// no-op at 0.
template<class T>
void
BinaryTreeIter<T>::Succ()
{
    if (! subtree)
        return;
    subtree = subtree->Succ();
}

// Searches the current subtree for x.  Moves to the matching node and
// returns 1, or moves to 0 and returns 0.
template<class T>
int
BinaryTreeIter<T>::Query(const T& x)
{
    if (subtree)
        subtree = subtree->Query(x);
    return subtree ? 1 : 0;
}

// (The red-black color enum, RedBlack { Black, Red }, is declared as a
// protected member of BinaryNode.)

// Default constructor: an empty tree has a null root.
template<class T>
BinaryTree<T>::BinaryTree()
    : root(0)
{
}

// Copy constructor: deep-copies x's nodes, colors included.  Copying an
// empty tree yields an empty tree; the null root is not dereferenced.
template<class T>
BinaryTree<T>::BinaryTree(const BinaryTree<T>& x)
    : root(x.root ? x.root->Dup(0) : 0)
{
}

// Assignment: replaces this tree's nodes with a deep copy of x's.
// Self-assignment is a no-op.  The copy is built before the old nodes
// are freed, so if an allocation throws during the copy this tree is
// left unchanged.  An empty x yields an empty tree.
template<class T>
BinaryTree<T>&
BinaryTree<T>::operator=(const BinaryTree<T>& x)
{
    if (this != &x) {
        BinaryNode<T> *copy = x.root ? x.root->Dup(0) : 0;
        BinaryNode<T>::PostOrderDeletion(root);
        root = copy;
    }
    return *this;
}

// Destructor: releases every node still owned by the tree.
template<class T>
BinaryTree<T>::~BinaryTree()
{
    if (root)
        BinaryNode<T>::PostOrderDeletion(root);
}

// Address of the smallest item, or 0 for an empty tree.
template<class T>
T *
BinaryTree<T>::Min() const
{
    if (! root)
        return 0;
    BinaryNode<T> *smallest = root->Min();
    return &smallest->x;
}

// Address of the largest item, or 0 for an empty tree.
template<class T>
T *
BinaryTree<T>::Max() const
{
    if (! root)
        return 0;
    BinaryNode<T> *largest = root->Max();
    return &largest->x;
}

// Address of the item preceding the one matching q.  Returns 0 when the
// tree is empty, q is absent, or q's node has no predecessor.
template<class T>
T *
BinaryTree<T>::Pred(const T& q) const
{
    if (! root)
        return 0;
    BinaryNode<T> *hit = root->Query(q);
    if (! hit)
        return 0;
    BinaryNode<T> *before = hit->Pred();
    return before ? &before->x : 0;
}

// Address of the item following the one matching q.  Returns 0 when the
// tree is empty, q is absent, or q's node has no successor.
template<class T>
T *
BinaryTree<T>::Succ(const T& q) const
{
    if (! root)
        return 0;
    BinaryNode<T> *hit = root->Query(q);
    if (! hit)
        return 0;
    BinaryNode<T> *after = hit->Succ();
    return after ? &after->x : 0;
}

// Address of an item matching q, or 0 if there is none.
template<class T>
T *
BinaryTree<T>::Query(const T& q) const
{
    BinaryNode<T> *hit = 0;
    if (root)
        hit = root->Query(q);
    return hit ? &hit->x : 0;
}

// Inserts a copy of addme.  Duplicate keys are allowed: BinaryNode::Insert
// sends equal keys to the right.
template<class T>
void
BinaryTree<T>::Insert(const T& addme)
{
    if (root) {
        root = root->Insert(addme);
    }
    else {
        // A freshly constructed node is red, but a red-black root must be
        // black.  Passing the lone, parentless node through InsertNode
        // gives it the same final root recoloring as every other insertion.
        BinaryNode<T> *node = new BinaryNode<T>(addme);
        root = node->InsertNode(node);
    }
}

// Inserts a copy of addme and returns a node holding an item equal to it.
// NOTE(review): duplicates are allowed, so the returned node is whichever
// equal item Query finds first.  That need not be the node just created.
template<class T>
BinaryNode<T> *
BinaryTree<T>::InsertPassbk(const T& addme)
{
    if (root) {
        root = root->Insert(addme);
    }
    else {
        // Recolor the new lone root black via InsertNode (new nodes are red).
        BinaryNode<T> *node = new BinaryNode<T>(addme);
        root = node->InsertNode(node);
    }
    return root->Query(addme);
}

// Removes one item matching q; does nothing if the tree is empty or q is
// absent.
template<class T>
void
BinaryTree<T>::DeleteItem(const T& q)
{
    if (! root)
        return;
    root = root->DeleteItem(q);
}

// Removes one item matching q, first copying it into *passbk.  *passbk is
// not written when the tree is empty or q is absent.
template<class T>
void
BinaryTree<T>::DeletePassbk(const T& q, T *passbk)
{
    if (! root)
        return;
    root = root->DeletePassbk(q, passbk);
}

// 1 if the tree holds no nodes, else 0.
template<class T>
int
BinaryTree<T>::IsEmpty() const
{
    return ! root;
}

// Checks the red-black invariant over the whole tree; 0 means it holds.
// The black count along the leftmost path is the reference that every
// other root-to-leaf path is compared against.
template<class T>
int
BinaryTree<T>::CheckInvariant()
{
    if (! root)
        return 0;       // an empty tree is trivially valid
    int blackHeight = root->BlackToMin();
    return root->CheckInvariant(0, blackHeight);
}

// Node holding an item matching q, or 0 if there is none.
template<class T>
BinaryNode<T> *
BinaryTree<T>::QueryNode(const T& q) const
{
    if (root)
        return root->Query(q);
    return 0;
}

// Removes the given node (assumed to belong to this tree) and updates the
// root.  Does nothing for an empty tree.
template<class T>
void
BinaryTree<T>::DeleteNode(BinaryNode<T> *delme)
{
    if (! root)
        return;
    root = root->DeleteNode(delme);
}

// Default constructor: a detached red node with default-initialized contents.
template<class T>
BinaryNode<T>::BinaryNode()
    : clr(Red), l(0), r(0), p(0)
{
}

// Constructs a red node holding a copy of X, with parent P and children
// L and R.  Their defaults (0) are specified in the class declaration;
// repeating default arguments in this out-of-class definition is
// ill-formed, so they are omitted here.
template<class T>
BinaryNode<T>::BinaryNode(const T& X, BinaryNode<T> *P,
    BinaryNode<T> *L, BinaryNode<T> *R)
    : x(X), clr(Red), l(L), r(R), p(P)
{
}

// Deletes every node of the subtree rooted at node.  Children are freed
// before their parent, because the parent holds the only links to them.
// A null argument is ignored.
template<class T>
void
BinaryNode<T>::PostOrderDeletion(BinaryNode<T> *node)
{
    if (! node)
        return;
    PostOrderDeletion(node->l);
    PostOrderDeletion(node->r);
    delete node;
}

// Returns a deep copy of the subtree rooted at this node, colors
// included.  The copy's root gets parent P.  Children are copied through
// their own (virtual) Dup.  Missing children are left as 0 rather than
// dereferenced.
template<class T>
BinaryNode<T> *
BinaryNode<T>::Dup(BinaryNode<T> *P) const
{
    BinaryNode<T> *ret = new BinaryNode<T>(x, P);
    ret->clr = clr;
    ret->l = (l) ? l->Dup(ret) : 0;
    ret->r = (r) ? r->Dup(ret) : 0;
    return ret;
}

// Min and Max were each defined twice in this header, which is a
// redefinition error.  A single definition of each is kept.

// Returns the node with the smallest key in the subtree rooted here: the
// leftmost descendant, or this node itself if it has no left child.
template<class T> 
BinaryNode<T> *
BinaryNode<T>::Min()
{
    BinaryNode<T> *trav = this;
    while (trav->l)
        trav = trav->l;
    return trav;
}

// Returns the node with the largest key in the subtree rooted here: the
// rightmost descendant, or this node itself if it has no right child.
template<class T> 
BinaryNode<T> *
BinaryNode<T>::Max()
{
    BinaryNode<T> *trav = this;
    while (trav->r)
        trav = trav->r;
    return trav;
}

// Returns this node's in-order predecessor, or 0 if it is the minimum.
template<class T> 
BinaryNode<T> *
BinaryNode<T>::Pred()
{
    // With a left subtree, the predecessor is that subtree's maximum.
    if (l)
        return l->Max();

    // Otherwise climb while we are a left child.  The first ancestor we
    // reach from its right side is the predecessor (0 past the root).
    BinaryNode<T> *child = this;
    BinaryNode<T> *anc = p;
    for (; anc && child == anc->l; anc = anc->p)
        child = anc;
    return anc;
}

// Returns this node's in-order successor, or 0 if it is the maximum.
template<class T> 
BinaryNode<T> *
BinaryNode<T>::Succ()
{
    // With a right subtree, the successor is that subtree's minimum.
    if (r)
        return r->Min();

    // Otherwise climb while we are a right child.  The first ancestor we
    // reach from its left side is the successor (0 past the root).
    BinaryNode<T> *child = this;
    BinaryNode<T> *anc = p;
    for (; anc && child == anc->r; anc = anc->p)
        child = anc;
    return anc;
}

// Searches the subtree rooted here for an item equal to q and returns its
// node, or 0.  Only operator< is used: "equal" means neither item is less
// than the other.
template<class T> 
BinaryNode<T> *
BinaryNode<T>::Query(const T& q)
{
    for (BinaryNode<T> *node = this; node; ) {
        if (q < node->x)
            node = node->l;
        else if (node->x < q)
            node = node->r;
        else
            return node;
    }
    return 0;
}

// Rotates left around this node.  Its right child y takes this node's
// place under the parent, this node becomes y's left child, and y's old
// left subtree becomes this node's right subtree.
// Precondition: r (the right child) is non-null; it is dereferenced.
// Returns the tree root: y if this node had no parent, otherwise the
// root argument unchanged.
template<class T> 
BinaryNode<T> *
BinaryNode<T>::LeftRotate(BinaryNode<T> *root)
{
    BinaryNode<T> *ret = root;
    BinaryNode<T> *y = r;
    // y's left subtree becomes our right subtree.
    r = y->l;
    if (r)
	r->p = this;
    // y takes our place under our parent, or becomes the new root.
    y->p = p;
    if (p) {
	if (this == p->l)
	    p->l = y;
	else
	    p->r = y;
    }
    else
	ret = y;
    // We now hang below y as its left child.
    y->l = this;
    p = y;
    return ret;
}

// Rotates right around this node; the mirror image of LeftRotate.  Its
// left child takes this node's place under the parent, this node becomes
// that child's right child, and the child's old right subtree becomes
// this node's left subtree.
// Precondition: l (the left child) is non-null; it is dereferenced.
// Returns the tree root: the former left child if this node had no
// parent, otherwise the root argument unchanged.
// NOTE: the local x below is a node pointer; it shadows the contents
// member x.
template<class T> 
BinaryNode<T> *
BinaryNode<T>::RightRotate(BinaryNode<T> *root)
{
    BinaryNode<T> *ret = root;
    BinaryNode<T> *x = l;
    // x's right subtree becomes our left subtree.
    l = x->r;
    if (l)
	l->p = this;
    // x takes our place under our parent, or becomes the new root.
    x->p = p;
    if (p) {
	if (this == p->l)
	    p->l = x;
	else
	    p->r = x;
    }
    else
	ret = x;
    // We now hang below x as its right child.
    x->r = this;
    p = x;
    return ret;
}

// Links addme into the tree rooted at this node and restores the
// red-black properties.  Returns the (possibly new) root.
//
// addme must already be allocated with addme->p set to its BST parent,
// or 0 if the tree is empty.  It becomes the parent's left child if its
// key is less than the parent's, otherwise the right child.  This
// matches the descent in Insert(), where equal keys go right.
template<class T>
BinaryNode<T> *
BinaryNode<T>::InsertNode(BinaryNode<T> *addme)
{
    BinaryNode<T> *root = this;
    if (! addme->p)
        root = addme;
    else {
        if (addme->x < addme->p->x)
            addme->p->l = addme;
        else
            addme->p->r = addme;
    }
    // The newly linked node is colored red.  Coloring `this` here instead
    // would leave the old root red whenever addme becomes the root.
    addme->clr = Red;
    while (addme != root && addme->p->clr == Red) {
        BinaryNode<T> *y;           // addme's uncle

        // A parent without a grandparent is the root.  The final
        // recoloring below makes it black.
        if (! addme->p->p)
            break;
        if (addme->p == addme->p->p->l) {
            y = addme->p->p->r;
            if (y && y->clr == Red) {
                // Red uncle: recolor and continue from the grandparent.
                addme->p->clr = Black;
                y->clr = Black;
                addme->p->p->clr = Red;
                addme = addme->p->p;
            }
            else {
                // Black uncle: turn an inner child into an outer one,
                // then rotate at the grandparent.
                if (addme == addme->p->r) {
                    addme = addme->p;
                    root = addme->LeftRotate(root);
                }
                addme->p->clr = Black;
                if (addme->p->p) {
                    addme->p->p->clr = Red;
                    root = addme->p->p->RightRotate(root);
                }
            }
        }
        else {
            // Mirror image of the above with left and right exchanged.
            y = addme->p->p->l;
            if (y && y->clr == Red) {
                addme->p->clr = Black;
                y->clr = Black;
                addme->p->p->clr = Red;
                addme = addme->p->p;
            }
            else {
                if (addme == addme->p->l) {
                    addme = addme->p;
                    root = addme->RightRotate(root);
                }
                addme->p->clr = Black;
                if (addme->p->p) {
                    addme->p->p->clr = Red;
                    root = addme->p->p->LeftRotate(root);
                }
            }
        }
    }
    root->clr = Black;
    return root;
}

// Inserts a new node holding a copy of AddMe into the tree rooted here.
// Returns the (possibly new) root.  Equal keys descend to the right, so
// duplicates are kept.
template<class T> 
BinaryNode<T> *
BinaryNode<T>::Insert(const T& AddMe)
{
    // Ordinary BST descent to find the new node's parent.
    BinaryNode<T> *parent = 0;
    for (BinaryNode<T> *cur = this; cur; cur = (AddMe < cur->x) ? cur->l : cur->r)
        parent = cur;

    // InsertNode links the node under parent and rebalances.
    return InsertNode(new BinaryNode<T>(AddMe, parent));
}

// Restores the red-black properties after DeleteNode unlinked a black
// node.  It is called on the current root and returns the (possibly new)
// root.
//   x - the node that took the removed node's place (possibly 0).
//   p - x's parent.
// x carries an "extra black".  The loop either pushes it up toward the
// root or resolves it with rotations, using the four standard delete-
// fixup cases (mirrored when x is a right child).
//
// The early returns when p or the sibling w is 0 should not trigger in a
// valid red-black tree.  They stop the fixup instead of dereferencing 0.
template<class T> 
BinaryNode<T> *
BinaryNode<T>::DeleteFixup(BinaryNode<T> *x, BinaryNode<T> *p)
{
    BinaryNode<T> *root = this;
    while (x != root && (! x || x->clr == Black)) {
        BinaryNode<T> *w;           // x's sibling

        // p is dereferenced immediately (p->l), so it must be tested
        // first.  Within an iteration it stays non-null: rotations relink
        // nodes but never reassign the local p.
        if (! p)
            return root;
        if (x == p->l) {
            w = p->r;
            if (! w)
                return root;
            if (w->clr == Red) {
                // Case 1: red sibling.  Rotate so the sibling is black.
                w->clr = Black;
                p->clr = Red;
                root = p->LeftRotate(root);
                w = p->r;
                if (! w)
                    return root;
            }
            if ( ((! w->l) || w->l->clr == Black) &&
                 ((! w->r) || w->r->clr == Black)) {
                // Case 2: both of w's children black.  Recolor w and move
                // the extra black up to p.
                w->clr = Red;
                x = p;
                p = p->p;
                continue;
            }
            else if ((! w->r) || w->r->clr == Black) {
                // Case 3: w's far child black, near child red.  Rotate at
                // w so that the far child becomes red.
                w->l->clr = Black;
                w->clr = Red;
                root = w->RightRotate(root);
                w = p->r;
                if (! w)
                    return root;
            }
            // Case 4: w's far child red.  Rotate at p; the fixup is done.
            w->clr = p->clr;
            p->clr = Black;
            w->r->clr = Black;
            root = p->LeftRotate(root);
            x = root;
        }
        else {
            // Mirror image of the above with left and right exchanged.
            w = p->l;
            if (! w)
                return root;
            if (w->clr == Red) {
                w->clr = Black;
                p->clr = Red;
                root = p->RightRotate(root);
                w = p->l;
                if (! w)
                    return root;
            }
            if ( ((! w->r) || w->r->clr == Black) &&
                 ((! w->l) || w->l->clr == Black)) {
                w->clr = Red;
                x = p;
                p = p->p;
                continue;
            }
            else if ((! w->l) || w->l->clr == Black) {
                w->r->clr = Black;
                w->clr = Red;
                root = w->LeftRotate(root);
                w = p->l;
                if (! w)
                    return root;
            }
            w->clr = p->clr;
            p->clr = Black;
            w->l->clr = Black;
            root = p->RightRotate(root);
            x = root;
        }
    }
    if (x)
        x->clr = Black;
    return root;
}

// Removes node z from the tree rooted at this node and returns the
// (possibly new) root.  A null z leaves the tree unchanged.
//
// If z has at most one child, z itself is unlinked and freed.  Otherwise
// z's in-order successor y, which has no left child, is handled instead:
// y's contents are copied into z, then y is unlinked and freed.  So the
// node actually freed may not be z, and pointers to the successor's node
// become dangling.
// If the unlinked node was black, DeleteFixup restores the red-black
// properties.  The fixup is skipped when the tree has become empty.
template<class T> 
BinaryNode<T> *
BinaryNode<T>::DeleteNode(BinaryNode<T> *z)
{
    BinaryNode<T> *root = this;
    BinaryNode<T> *x, *y;

    if (! z)
	return root;
    // y: the node to unlink -- z itself, or z's successor if z has two
    // children.
    y = (! z->l || ! z->r) ? z : z->Succ();
    // x: y's only child (or 0), which takes y's place.
    x = (y->l) ? y->l : y->r;

    if (x)
	x->p = y->p;

    if (y->p) {
	if (y == y->p->l)
	    y->p->l = x;
	else
	    y->p->r = x;
    }
    else
	root = x;
    // Move the successor's contents into z before y is freed.
    if (y != z)
	z->x = y->x;
    // Removing a black node can unbalance black heights; repair them.
    if (y->clr == Black) {
	if (root)
	    root = root->DeleteFixup(x, y->p);
    }
    delete y;
    return root;
}

// Removes one node whose item matches q (if any) from the tree rooted
// here.  Returns the (possibly new) root.
template<class T> 
BinaryNode<T> *
BinaryNode<T>::DeleteItem(const T& q)
{
    BinaryNode<T> *victim = Query(q);
    return DeleteNode(victim);      // DeleteNode(0) is a no-op
}

// Like DeleteItem, but first copies the matching item into *passbk.
// *passbk is written only when a match exists.  Returns the (possibly
// new) root.
template<class T> 
BinaryNode<T> *
BinaryNode<T>::DeletePassbk(T q, T *passbk)
{
    BinaryNode<T> *victim = Query(q);
    if (victim != 0)
        *passbk = victim->x;        // copy out before the node is removed
    return DeleteNode(victim);
}

// Checks the red-black invariant for the subtree rooted at this node.
//   i   - number of black nodes strictly above this node on the path
//         from the tree root.
//   num - expected black count of every root-to-leaf path (the caller
//         passes root->BlackToMin()).
// Returns 0 if both conditions hold, -1 on the first violation:
//   * no red node has a red child;
//   * every path that ends at a missing child has exactly num black nodes.
template<class T> 
int
BinaryNode<T>::CheckInvariant(int i, int num)
{
    // Black nodes on the path down to and including this node.  The
    // original test `i != num + clr == Black` parsed as
    // `i != ((num + clr) == Black)`, so it never compared black heights.
    int blacks = i + (clr == Black);

    if (clr == Red && ((l && l->clr == Red) || (r && r->clr == Red)))
        return -1;

    // A missing child ends a root-to-leaf path.
    if ((! l || ! r) && blacks != num)
        return -1;

    int ret = (l) ? l->CheckInvariant(blacks, num) : 0;
    if (ret)
        return ret;
    return (r) ? r->CheckInvariant(blacks, num) : 0;
}

// Counts the black nodes from this node down the left spine to the
// minimum, inclusive.  This serves as the reference black height for
// CheckInvariant.
template<class T> 
int
BinaryNode<T>::BlackToMin()
{
    int blacks = 0;
    for (BinaryNode<T> *node = this; node != 0; node = node->l)
        if (node->clr == Black)
            ++blacks;
    return blacks;
}

#endif

