// ==================================================================
// os.h
//	Header for order statistics tree (OSTree) class.
// ==================================================================

#include "btree.h"

// Node of an order-statistic tree.  Besides what BinaryNode<T> stores, each
// node records in `size` how many descendants it has (the nodes of its
// subtree, not counting itself): a node built with the pointer constructor
// starts at 0, Insert increments every node on the search path, and the
// consistency checks expect NumNodes() == size + 1.
template<class T> class OSNode : public BinaryNode<T> {
public:
    int size;   // number of descendants, excluding this node
public:
    OSNode(const T& x, int Size);
    OSNode(const T& x, OSNode<T> *p = 0, 
	OSNode<T> *l = 0, OSNode<T> *r = 0);

    // BinaryNode operations wrapped so that `size` stays consistent.
    BinaryNode<T> *LeftRotate(BinaryNode<T> *root);
    BinaryNode<T> *RightRotate(BinaryNode<T> *root);
    BinaryNode<T> *DeleteNode(BinaryNode<T> *z);
    BinaryNode<T> *Insert(const T& AddMe);

    // Element at 0-based in-order position i within this subtree, or 0.
    T *Select(int i);
    // Prints elements in order starting at this subtree's minimum.
    void PrintNodes();
    // Counts the nodes of this subtree by traversal (this node included).
    int NumNodes();
    // 1-based position of this node within its own subtree.
    int Rank();
    // Declared as a friend *template* so that it names the CheckNumNodes
    // function template defined in this header; a plain friend declaration
    // would name a separate non-template function that is never defined.
    template<class U> friend void CheckNumNodes(BinaryNode<U> *x);
};

// Builds a node holding X whose descendant count is preset to Size.
template<class T>
OSNode<T>::OSNode(const T& X, int Size)
    : BinaryNode<T>(X), size(Size)
{
}

// Builds a node holding x linked to parent p and children l/r, with a
// descendant count of 0.  The default arguments live only on the in-class
// declaration; repeating them on this out-of-class definition is ill-formed.
template<class T>
OSNode<T>::OSNode(const T& x, OSNode<T> *p, 
    OSNode<T> *l, OSNode<T> *r)
    : BinaryNode<T>(x, p, l, r), size(0)
{
}

// Left-rotates around this node via BinaryNode<T>::LeftRotate, then repairs
// the descendant counts of the two nodes whose subtrees changed.
// Returns whatever the base rotation returns.
// NOTE(review): assumes the base rotation always succeeds, so that this->p
// is the promoted node afterwards. Confirm in btree.h what happens when
// there is no right child.
template<class T>
BinaryNode<T> *
OSNode<T>::LeftRotate(BinaryNode<T> *root) 
{
    BinaryNode<T> *ret = BinaryNode<T>::LeftRotate(root);

    // The promoted node now roots the subtree this node used to root, so it
    // inherits this node's old descendant count.
    static_cast<OSNode<T> *>(this->p)->size = size;

    // Recount this node from its (possibly new) children.
    // `this->` is required: l/r/p live in a dependent base class.
    size = (this->l) ? static_cast<OSNode<T> *>(this->l)->size + 1 : 0;
    size += (this->r) ? static_cast<OSNode<T> *>(this->r)->size + 1 : 0;

    return ret;
}

// Right-rotates around this node via BinaryNode<T>::RightRotate, then repairs
// the descendant counts of the two nodes whose subtrees changed.
// Returns whatever the base rotation returns.
// NOTE(review): assumes the base rotation always succeeds, so that this->p
// is the promoted node afterwards. Confirm in btree.h what happens when
// there is no left child.
template<class T>
BinaryNode<T> *
OSNode<T>::RightRotate(BinaryNode<T> *root) 
{
    BinaryNode<T> *ret = BinaryNode<T>::RightRotate(root);

    // The promoted node now roots the subtree this node used to root, so it
    // inherits this node's old descendant count.
    static_cast<OSNode<T> *>(this->p)->size = size;

    // Recount this node from its (possibly new) children.
    // `this->` is required: l/r/p live in a dependent base class.
    size = (this->l) ? static_cast<OSNode<T> *>(this->l)->size + 1 : 0;
    size += (this->r) ? static_cast<OSNode<T> *>(this->r)->size + 1 : 0;

    return ret;
}

// Returns a pointer to the element at 0-based in-order position i within
// this node's subtree, or 0 when i is negative or past the last element.
template<class T>
T *
OSNode<T>::Select(int i) 
{
    for (OSNode<T> *node = this; node != 0; ) {
        OSNode<T> *left = static_cast<OSNode<T> *>(node->l);
        // Number of elements that precede `node` within its subtree.
        const int leftCount = left ? left->size + 1 : 0;

        if (i < leftCount) {
            node = left;
        } else if (i > leftCount) {
            // Skip the left subtree and `node` itself.
            i -= leftCount + 1;
            node = static_cast<OSNode<T> *>(node->r);
        } else {
            return &node->x;
        }
    }
    return 0;
}

// Counts the nodes in this subtree (this node included) by recursive
// traversal, independently of the cached `size` field.
template<class T>
int
OSNode<T>::NumNodes()
{
    int count = 1;  // this node
    // `this->` is required: l/r live in a dependent base class.
    if (this->l)
        count += static_cast<OSNode<T> *>(this->l)->NumNodes();
    if (this->r)
        count += static_cast<OSNode<T> *>(this->r)->NumNodes();
    return count;
}

// Returns the 1-based in-order position of this node within its own subtree,
// i.e. one more than the number of nodes in its left subtree.  Consistent
// with Select: Select(Rank() - 1) on this node yields this node's element.
// This is a subtree-local rank; it does not walk up through the parents.
template<class T>
int 
OSNode<T>::Rank() 
{
    // `size` excludes the node itself, so the left subtree holds
    // left->size + 1 nodes (the original omitted the "+ 1").
    int leftCount = (this->l) ? static_cast<OSNode<T> *>(this->l)->size + 1 : 0;
    return leftCount + 1;
}




// Debug check: writes "Problems" to cerr if x's cached descendant count does
// not agree with an actual traversal of its subtree.  x must point to an
// OSNode<T>; a null x is ignored.
template<class T> void
CheckNumNodes(BinaryNode<T> *x)
{
    if (!x)
        return;
    OSNode<T> *node = static_cast<OSNode<T> *>(x);
    // `size` counts descendants only, so a consistent subtree contains
    // size + 1 nodes (matching the check in OSTree<T>::CheckNodes).
    if (node->NumNodes() != node->size + 1)
	cerr << "Problems\n";
}

// Removes z via BinaryNode<T>::DeleteNode, first decrementing the descendant
// counts along the path from the node that is physically unlinked up to the
// root. That node is z itself, or z's successor when z has two children.
// Returns whatever the base DeleteNode returns.
// NOTE(review): assumes the base DeleteNode removes the successor (rather
// than relinking it in z's place) when z has two children. Verify against
// btree.h.
template<class T> 
BinaryNode<T> *
OSNode<T>::DeleteNode(BinaryNode<T> *z)
{
    BinaryNode<T> *spliced = z;
    if (z != 0 && z->l != 0 && z->r != 0)
        spliced = z->Succ();

    // Every node from the spliced node up to the root loses one descendant.
    for (OSNode<T> *n = static_cast<OSNode<T> *>(spliced); n != 0;
         n = static_cast<OSNode<T> *>(n->p))
        --n->size;

    return BinaryNode<T>::DeleteNode(z);
}

// Inserts AddMe below this node and returns the result of
// BinaryNode<T>::InsertNode. The search path goes left when
// AddMe < key and right otherwise, so ties go right. Every node on that
// path gains one descendant.
template<class T> BinaryNode<T> *
OSNode<T>::Insert(const T& AddMe)
{
    OSNode<T> *parent = 0;
    OSNode<T> *cur = this;
    while (cur != 0) {
        ++cur->size;
        parent = cur;
        if (AddMe < cur->x)
            cur = static_cast<OSNode<T> *>(cur->l);
        else
            cur = static_cast<OSNode<T> *>(cur->r);
    }
    // The new leaf starts with no descendants; InsertNode links it under parent.
    OSNode<T> *leaf = new OSNode<T>(AddMe, parent);
    return BinaryNode<T>::InsertNode(leaf);
}

// Prints elements to cout in order, starting from this subtree's minimum.
// Each element is followed by its descendant count in parentheses.
// NOTE(review): Succ() presumably walks the whole tree. Called on a
// non-root node, this may therefore print past the subtree. OSTree only
// calls it on the root.
template<class T> void
OSNode<T>::PrintNodes()
{
    // `this->` is required: Min() lives in a dependent base class.
    for (OSNode<T> *n = static_cast<OSNode<T> *>(this->Min()); n != 0;
         n = static_cast<OSNode<T> *>(n->Succ()))
	cout << n->x << "(" << n->size << ") ";
}

// Order-statistic tree: a BinaryTree whose nodes are OSNode<T> objects, so
// every node also caches its descendant count and elements can be selected
// by in-order position.
template<class T> class OSTree : public BinaryTree<T> {
public:
    // Creates an empty tree.
    OSTree();
    // Pointer to the element at 0-based in-order position i, or 0 if the
    // tree is empty or i is out of range.
    T *Select(int i);
    // Adds AddMe, creating the root node if the tree is empty.
    void Insert(const T& AddMe);
    // Prints all elements in order, each followed by its descendant count.
    void PrintNodes();
    // Debug check: writes "Problems" to cerr for each node whose cached
    // count disagrees with an actual traversal.
    void CheckNodes();

};

// Creates an empty tree.  `this->` is required because root is a member of
// the dependent base BinaryTree<T>.
template<class T>
OSTree<T>::OSTree()
{
    this->root = 0;
}

// Returns a pointer to the element at 0-based in-order position i, or 0 if
// the tree is empty or i is out of range.
template<class T>
T *
OSTree<T>::Select(int i)
{
    // `this->` is required: root lives in a dependent base class.
    if (!this->root)
        return 0;
    return static_cast<OSNode<T> *>(this->root)->Select(i);
}

// Adds AddMe to the tree. An empty tree gets a fresh root node. Otherwise
// the root is replaced by whatever OSNode<T>::Insert returns, since the
// insertion may restructure the tree.
template<class T>
void
OSTree<T>::Insert(const T& AddMe)
{
    // `this->` is required: root lives in a dependent base class.
    if (!this->root)
        this->root = new OSNode<T>(AddMe);
    else
        this->root = static_cast<OSNode<T> *>(this->root)->Insert(AddMe);
}

// Prints every element in order (with its descendant count); prints nothing
// for an empty tree.
template<class T>
void
OSTree<T>::PrintNodes()
{
    // `this->` is required: root lives in a dependent base class.
    if (this->root)
	static_cast<OSNode<T> *>(this->root)->PrintNodes();
}

// Debug check: visits every node in order and writes "Problems" to cerr for
// each one whose cached descendant count (`size`) plus one differs from the
// number of nodes actually found in its subtree.  O(n^2) overall because
// NumNodes() re-traverses each subtree.
template<class T>
void
OSTree<T>::CheckNodes()
{
    // `this->` is required: root lives in a dependent base class.
    if (!this->root)
        return;
    for (BinaryNode<T> *trav = this->root->Min(); trav; trav = trav->Succ()) {
        OSNode<T> *node = static_cast<OSNode<T> *>(trav);
        if (node->NumNodes() != node->size + 1)
            cerr << "Problems\n";
    }
}




