// TREE.CPP - implementation of balanced binary trees.

#include <stream.hpp>
#include <stddef.h>
#include <stdlib.h>
#include <string.h>
#include "tree.hpp"

// Balance factors stored in NODE::balance.  LEFT/RIGHT mean the subtree on
// that side is one level deeper; they are also added to a balance factor to
// record that the corresponding side grew (or the opposite side shrank).

#define BALANCED     0
#define LEFT        -1
#define RIGHT        1

// Balance factors that are out of tolerance and trigger rebalance().

#define DOUBLE_LEFT  (LEFT*2)
#define DOUBLE_RIGHT (RIGHT*2)

#define RETURN_CURRENT  0     // getprev()/getnext(): 'current' already holds the answer
#define RETURN_PARENT   1     // getprev()/getnext(): answer (if any) is an ancestor

// maximum allowable imbalance in a subtree; a larger absolute balance
// factor causes a rotation

#define ALLOWABLE_IMBALANCE   1

// Result codes remove_node() hands from a child back to its parent:
//   DELETE_THIS_NODE - the node examined is a leaf holding the key; the
//                      caller must free it and clear its own link to it
//   TREE_SHRANK      - the subtree examined is now one level shallower
//   NO_CHANGE        - the depth of the subtree examined is unchanged
//   NODE_NOT_FOUND   - the key is not present

#define DELETE_THIS_NODE   2
#define TREE_SHRANK        1
#define NO_CHANGE          0
#define NODE_NOT_FOUND    -2

// Construct an empty tree: no root, no 'current' position, no error pending.
TREE::TREE(void)
   {
   errval = OK;
   head = current = NULL;
   }

// Release everything hanging below 'nodeptr' (both child NODEs and their
// payloads) and then the payload of 'nodeptr' itself.  The NODE structure
// at 'nodeptr' is NOT freed here; that is left to the caller.
// 'nodeptr' must not be NULL.
void TREE::delete_subtree(NODE *nodeptr)
   {
   NODE *right_child = nodeptr->right;
   if(right_child != NULL)
      {
      delete_subtree(right_child);     // frees right_child's payload and descendants
      delete right_child;
      }

   NODE *left_child = nodeptr->left;
   if(left_child != NULL)
      {
      delete_subtree(left_child);      // frees left_child's payload and descendants
      delete left_child;
      }

   delete_data(nodeptr);
   }

// Destroy the tree, releasing every node and payload it still holds.
TREE::~TREE(void)
   {
   if(head == NULL)
      return;                 // empty tree - nothing to release

   delete_subtree(head);      // descendants plus the root's payload
   delete head;               // then the root NODE itself
   head = NULL;
   }

// Restore the balance of the subtree rooted at 'parent', whose balance
// factor has reached DOUBLE_LEFT or DOUBLE_RIGHT, with a single or double
// rotation.  Any value other than DOUBLE_LEFT is handled as DOUBLE_RIGHT.
//
// The rotation is performed in place.  Rather than returning a new subtree
// root, the CONTENTS of the NODE at 'parent' are exchanged with the contents
// of the node being promoted (p1 for a single rotation, p2 for a double
// rotation), so the address 'parent' stays the root of the subtree and the
// link that refers to it needs no update.  Just before the exchange the
// promoted node is given a link to ITSELF; after the exchange that link
// lives in the storage at 'parent' and refers to the storage that now holds
// the old parent.
//
// NOTE: because whole NODE structures are exchanged, any NODE pointer held
// elsewhere that referred to one of the two exchanged nodes will afterwards
// see the other node's contents.
void TREE::rebalance(NODE *parent)
   {
   NODE *p1, *p2;
   NODE swapnode;

   if(parent->balance == DOUBLE_LEFT)
      {
      if(parent->left->balance != RIGHT)      // LL imbalance - single rotation
         {
         p1           = parent->left;
         parent->left = p1->right;
         p1->right    = p1;   // will resolve properly after node swapping

         if(p1->balance == LEFT)
            parent->balance = p1->balance = BALANCED;
         else                 // p1 was BALANCED (only arises after a removal)
            {
            parent->balance = LEFT;
            p1->balance = RIGHT;
            }

         swapnode = *parent;
         *parent  = *p1;
         *p1      = swapnode;
         }
      else                                   // LR imbalance - double rotation
         {
         p1           = parent->left;
         p2           = p1->right;
         p1->right    = p2->left;
         p2->left     = p1;
         parent->left = p2->right;
         p2->right    = p2;   // will resolve properly after node swapping

         // p2 becomes the subtree root and always ends up BALANCED; which of
         // p1 / parent is left one level short depends on which side of p2
         // was the deeper one.

         if(p2->balance == RIGHT)
            {
            parent->balance = BALANCED;
            p1->balance     = LEFT;
            p2->balance     = BALANCED;   // was RIGHT: left the new root wrongly marked
            }
         else if(p2->balance == LEFT)
            {
            parent->balance = RIGHT;
            p1->balance     = BALANCED;
            p2->balance     = BALANCED;
            }
         else
            parent->balance = p1->balance = p2->balance = BALANCED;

         swapnode = *parent;
         *parent  = *p2;
         *p2      = swapnode;
         }
      }
   else     // parent->balance == DOUBLE_RIGHT
      {
      if(parent->right->balance != LEFT)    // RR imbalance - single rotation
         {
         p1            = parent->right;
         parent->right = p1->left;
         p1->left      = p1;  // will resolve properly after node swapping

         if(p1->balance == RIGHT)
            parent->balance = p1->balance = BALANCED;
         else                 // p1 was BALANCED (only arises after a removal)
            {
            parent->balance = RIGHT;
            p1->balance = LEFT;
            }

         swapnode = *parent;
         *parent  = *p1;
         *p1      = swapnode;
         }
      else                                   // RL imbalance - double rotation
         {
         p1 = parent->right;
         p2 = p1->left;
         p1->left = p2->right;
         p2->right = p1;
         parent->right = p2->left;
         p2->left = p2;       // will resolve properly after node swapping

         if(p2->balance == RIGHT)
            {
            parent->balance = LEFT;
            p1->balance     = BALANCED;
            p2->balance     = BALANCED;
            }
         else if(p2->balance == LEFT)
            {
            parent->balance = BALANCED;
            p1->balance     = RIGHT;
            p2->balance     = BALANCED;
            }
         else
            parent->balance = p1->balance = p2->balance = BALANCED;

         swapnode = *parent;
         *parent  = *p2;
         *p2      = swapnode;
         }
      }
   }

// Recursively insert 'newnode' into the subtree rooted at 'parent'
// ('parent' must not be NULL), updating balance factors and rotating on
// the way back up.
//
// Keys are ordered with compare(), using the size stored in 'parent'.
// A key equal to an existing one is rejected: errval is set to NO_DUPES,
// the tree is left untouched and 'newnode' is NOT linked in (the caller
// still owns it).
//
// Returns nonzero when the subtree rooted at 'parent' became one level
// deeper, 0 otherwise.
int TREE::addnode(NODE *parent, NODE *newnode)
   {
   int retval;
   int addside;

   if((retval = compare(parent->data,newnode->data,parent->dsize)) != 0)
      {
      if(retval < 0)         // parent < newnode - add on RIGHT side of parent
         {
         if(parent->right != NULL)
            retval = addnode(parent->right,newnode);
         else
            {
            // The right side was empty and has just grown by one level.
            // This must always be reported so that the balance factor is
            // updated below; reporting 0 when the parent leaned LEFT left
            // it marked LEFT although both sides were now equal.
            parent->right = newnode;
            retval = 1;
            }
         addside = RIGHT;
         }
      else                    // parent > newnode - add on LEFT side of parent
         {
         if(parent->left != NULL)
            retval = addnode(parent->left,newnode);
         else
            {
            parent->left = newnode;     // mirror image of the case above
            retval = 1;
            }
         addside = LEFT;
         }
      }
   else     // duplicate items not allowed
      {
      errval = NO_DUPES;
      return 0;
      }

   // Part B - adjust the parent balance if necessary

   if(retval != 0)      // the side we added to became deeper
      {
      parent->balance += addside;
      if(abs(parent->balance) > ALLOWABLE_IMBALANCE)
         rebalance(parent);
      }

   // A parent that ends up BALANCED absorbed the growth (either its shorter
   // side caught up or a rotation was done), so nothing propagates upward.

   return (parent->balance == BALANCED) ? 0 : retval;
   }

int TREE::add(void *data, size_t size)
   {
   // build the node to be added

   errval = OK;

   NODE *nptr = new NODE;
   if(nptr == NULL)     // unable to allocate new node
      {
      errval = MEM_ALLOC_FAIL;
      return 0;
      }

   nptr->left    = NULL;
   nptr->right   = NULL;
   nptr->balance = BALANCED;
   nptr->dsize   = size;
   alloc_data(nptr);
   if(errval != OK)
      return 0;

   copy_data(nptr, data);

   // Now determine where to add the new node

   if(head == NULL)
      head = nptr;
   else
      addnode(head, nptr);

   if(errval == OK)
      return 1;
   else
      return 0;
   }

// Follow one edge of the subtree rooted at 'nptr' as far as it goes and
// return the last node reached: the right edge when walk_side == RIGHT,
// the left edge for any other value.  'nptr' must not be NULL.
static NODE *walk_tree(NODE *nptr, int walk_side)
   {
   NODE *edge = nptr;

   if(walk_side == RIGHT)
      {
      while(edge->right != NULL)
         edge = edge->right;
      }
   else     // walk_side == LEFT
      {
      while(edge->left != NULL)
         edge = edge->left;
      }

   return edge;
   }

// Recursively remove the entry matching 'key' from the subtree rooted at
// 'currnode' ('currnode' must not be NULL), fixing balance factors and
// rotating on the way back up.
//
// Returns one of:
//   NODE_NOT_FOUND   - no entry matches 'key'; nothing was changed
//   DELETE_THIS_NODE - 'currnode' itself is the match and is a leaf; the
//                      caller must free its payload and NODE and clear the
//                      link that pointed to it
//   TREE_SHRANK      - the entry was removed and this subtree is now one
//                      level shallower
//   NO_CHANGE        - the entry was removed (or a payload allocation
//                      failed, see errval) and this subtree's depth is
//                      unchanged
int TREE::remove_node(NODE *currnode, void *key)
   {
   int retval,crval,remove_side;
   NODE *new_head_ptr;
   NODE worknode;

   if((crval = compare(currnode->data,key,currnode->dsize)) < 0)
      {
      if(currnode->right != NULL)
         {
         remove_side = RIGHT;
         retval = remove_node(currnode->right, key);
         }
      else
         return NODE_NOT_FOUND;
      }
   else if(crval > 0)
      {
      if(currnode->left != NULL)
         {
         remove_side = LEFT;
         retval = remove_node(currnode->left, key);
         }
      else
         return NODE_NOT_FOUND;
      }
   else     // currnode is the node to remove
      {
      if(currnode->left == NULL && currnode->right == NULL)    // no descendants
         return DELETE_THIS_NODE;
      else if(currnode->left == NULL || currnode->right == NULL)
         {     // one descendant - pull it up into this node's storage
         NODE *saveptr = (currnode->left != NULL) ? currnode->left
                                                  : currnode->right;

         // Free the payload being removed BEFORE the structure copy
         // overwrites the only pointer to it.
         delete_data(currnode);
         *currnode = *saveptr;
         delete saveptr;

         return TREE_SHRANK;
         }
      else     // two descendants
         {
         // Replace this entry with its in-order neighbour taken from the
         // deeper side, so that removing the neighbour can never push this
         // node itself out of balance.

         if(currnode->balance == LEFT)
            new_head_ptr = walk_tree(currnode->left,RIGHT);
         else
            new_head_ptr = walk_tree(currnode->right,LEFT);

         worknode.dsize = new_head_ptr->dsize;

         alloc_data(&worknode);
         if(errval != OK)
            return NO_CHANGE;

         copy_data(&worknode, new_head_ptr->data);

         // Remove the neighbour through the public entry point so that every
         // balance factor from the root down is maintained.  The private
         // copy is used as the key because the neighbour's own payload is
         // released during that removal.
         remove(worknode.data);

         delete_data(currnode);
         currnode->data  = worknode.data;
         currnode->dsize = worknode.dsize;   // keep the size in step with the payload

         // The nested remove() has already adjusted all ancestors.
         return NO_CHANGE;
         }
      }

   // we're on the way back up the recursion stack

   if(retval == NODE_NOT_FOUND)
      return NODE_NOT_FOUND;     // pass the failure up instead of reporting success

   if(retval == DELETE_THIS_NODE)
      {
      // the child we descended into is the matching leaf - free it here
      if(remove_side == RIGHT)
         {
         delete_data(currnode->right);
         delete currnode->right;
         currnode->right = NULL;
         }
      else
         {
         delete_data(currnode->left);
         delete currnode->left;
         currnode->left = NULL;
         }
      }
   else if(retval != TREE_SHRANK)
      return NO_CHANGE;

   // the side we descended into is one level shallower - lean the other way

   currnode->balance += (remove_side == RIGHT) ? LEFT : RIGHT;

   if(abs(currnode->balance) > ALLOWABLE_IMBALANCE)
      rebalance(currnode);

   // A node that ends up BALANCED has lost a level; one that is still
   // leaning kept its depth.

   return (currnode->balance == BALANCED) ? TREE_SHRANK : NO_CHANGE;
   }

int TREE::remove(void *key)     // delete a node
   {
   int retval;

   errval = OK;

   if(head == NULL)
      {
      errval = TREE_EMPTY;
      return 0;
      }
   else
      {
      retval = remove_node(head,key);
      switch(retval)
         {
         case NODE_NOT_FOUND:
            errval = NO_SUCH_NODE;
            return 0;

         case DELETE_THIS_NODE:
            delete_data(head);
            delete head;
            head = NULL;
            return 1;

         default:
            return 1;
         }
      }
   }

// Search the subtree rooted at 'currnode' ('currnode' must not be NULL)
// for an entry that compare() reports equal to 'key'.
// Returns the matching node, or NULL when there is none.
NODE *TREE::findnode(NODE *currnode, void *key)
   {
   NODE *probe = currnode;

   do
      {
      int order = compare(probe->data, key, probe->dsize);

      if(order == 0)          // this must be the place!
         return probe;

      // node < key: continue to the right; node > key: continue to the left
      probe = (order < 0) ? probe->right : probe->left;
      }
   while(probe != NULL);

   return NULL;               // ran off the tree - node not found
   }

// Locate the entry matching 'key'.
//
// On success the entry becomes the 'current' position and a pointer to its
// stored data is returned.  On failure NULL is returned, 'current' is left
// unchanged and errval is TREE_EMPTY (no entries at all) or NO_SUCH_NODE.
void *TREE::find(void *key)     // locate a node
   {
   NODE *nodeptr;

   errval = OK;

   if(head == NULL)
      {
      errval = TREE_EMPTY;
      return NULL;
      }

   nodeptr = findnode(head,key);
   if(nodeptr == NULL)
      {
      // previously an uninitialised pointer was returned on this path
      errval = NO_SUCH_NODE;
      return NULL;
      }

   current = nodeptr;
   return nodeptr->data;
   }

// Make the leftmost entry of the tree the 'current' position and return a
// pointer to its data.  For an empty tree NULL is returned and errval is
// set to NO_SUCH_NODE.
void *TREE::findfirst(void)   // locate the first entry in a tree
   {
   errval = OK;

   if(head == NULL)
      {
      errval = NO_SUCH_NODE;
      return NULL;
      }

   current = walk_tree(head, LEFT);
   return current->data;
   }

// Make the rightmost entry of the tree the 'current' position and return a
// pointer to its data.  For an empty tree NULL is returned and errval is
// set to NO_SUCH_NODE.
void *TREE::findlast(void)    // locate the last entry in a tree
   {
   errval = OK;

   if(head == NULL)
      {
      errval = NO_SUCH_NODE;
      return NULL;
      }

   current = walk_tree(head, RIGHT);
   return current->data;
   }

// Helper for findprev(): descend from 'currnode' towards the node that
// 'current' refers to and move 'current' to its in-order predecessor.
//
// Returns RETURN_CURRENT when 'current' has been moved to the predecessor,
// or RETURN_PARENT when the predecessor (if there is one) must be an
// ancestor of 'currnode'; 'current' is unchanged in that case.
//
// ('static' has been dropped from this definition: the function uses the
// members 'current' and compare(), and the keyword is not permitted on a
// member definition outside the class body.)
int TREE::getprev(NODE *currnode)
   {
   int side;
   int retval;
   int cresult;

   if(currnode == NULL)    // ran off the tree without meeting 'current'
      return RETURN_PARENT;

   cresult = compare(currnode->data, current->data, currnode->dsize);
   if(cresult < 0)         // node < data - go right
      {
      side = RIGHT;
      retval = getprev(currnode->right);
      }
   else if(cresult > 0)    // node > data - go left
      {
      side = LEFT;
      retval = getprev(currnode->left);
      }
   else                    // this must be the place!
      {
      if(current->left != NULL)
         {
         current = current->left;         // step to the left
         while(current->right != NULL)    // then walk the right edge of the
            current = current->right;     // subtree
         return RETURN_CURRENT;
         }
      else
         return RETURN_PARENT;
      }

   // unwind the recursion stack

   // The predecessor is the nearest ancestor we left by going RIGHT.

   if(retval == RETURN_PARENT && side != LEFT)
      {
      current = currnode;
      return RETURN_CURRENT;
      }

   // Either still looking for such an ancestor (RETURN_PARENT) or the
   // answer was already found deeper down (RETURN_CURRENT); previously the
   // latter case fell off the end of the function without a return value.

   return retval;
   }

void *TREE::findprev(void)
   {
   int retval;

   errval = OK;

   retval = getprev(head);
   if(retval == RETURN_PARENT)
      {
      errval = NO_SUCH_NODE;  // 'current' is pointing to first node in tree
      return NULL;
      }
   else
      return current->data;
   }

// Helper for findnext(): descend from 'currnode' towards the node that
// 'current' refers to and move 'current' to its in-order successor.
//
// Returns RETURN_CURRENT when 'current' has been moved to the successor,
// or RETURN_PARENT when the successor (if there is one) must be an
// ancestor of 'currnode'; 'current' is unchanged in that case.
//
// ('static' has been dropped from this definition: the function uses the
// members 'current' and compare(), and the keyword is not permitted on a
// member definition outside the class body.)
int TREE::getnext(NODE *currnode)
   {
   int side;
   int retval;
   int cresult;

   if(currnode == NULL)    // ran off the tree without meeting 'current'
      return RETURN_PARENT;

   cresult = compare(currnode->data, current->data, currnode->dsize);
   if(cresult < 0)         // node < data - go right
      {
      side = RIGHT;
      retval = getnext(currnode->right);
      }
   else if(cresult > 0)    // node > data - go left
      {
      side = LEFT;
      retval = getnext(currnode->left);
      }
   else                    // this must be the place!
      {
      if(current->right != NULL)
         {
         current = current->right;        // step to the right
         while(current->left != NULL)     // then walk the left edge of the
            current = current->left;      // subtree
         return RETURN_CURRENT;
         }
      else
         return RETURN_PARENT;
      }

   // unwind the recursion stack

   // The successor is the nearest ancestor we left by going LEFT.

   if(retval == RETURN_PARENT && side != RIGHT)
      {
      current = currnode;
      return RETURN_CURRENT;
      }

   // Either still looking for such an ancestor (RETURN_PARENT) or the
   // answer was already found deeper down (RETURN_CURRENT); previously the
   // latter case fell off the end of the function without a return value.

   return retval;
   }

void *TREE::findnext(void)    // locate the 'next' entry in the tree
   {
   int retval;

   errval = OK;

   retval = getnext(head);
   if(retval == RETURN_PARENT)
      {
      errval = NO_SUCH_NODE;  // 'current' is pointing to last node in tree
      return NULL;
      }
   else
      return current->data;
   }

// Allocate a payload buffer of nptr->dsize bytes and attach it to 'nptr'.
// If the allocation yields NULL, errval is set to MEM_ALLOC_FAIL.
void TREE::alloc_data(NODE *nptr)
   {
   char *buffer = new char[nptr->dsize];

   nptr->data = buffer;
   if(buffer == NULL)
      errval = MEM_ALLOC_FAIL;
   }

// Copy to->dsize bytes from 'from' into the payload buffer already
// attached to 'to'.
void TREE::copy_data(NODE *to, void *from)
   {
   const size_t nbytes = to->dsize;

   memcpy(to->data, from, nbytes);
   }

// Release the payload buffer attached to 'nptr' and clear the pointer.
// The NODE structure itself is not freed.
void TREE::delete_data(NODE *nptr)
   {
   // alloc_data() obtains the buffer with new char[], so it has to be
   // released with the array form of delete, through a char pointer.
   delete [] (char *) nptr->data;
   nptr->data = NULL;
   }

// Order two payloads by a byte-wise comparison of their first 'size' bytes.
// Returns a negative value, zero or a positive value as data1 sorts before,
// equal to or after data2 (the result of memcmp).
int TREE::compare(void *data1, void *data2, size_t size)
   {
   const int ordering = memcmp(data1,data2,size);

   return ordering;
   }

// Seven-character column fragments used by the tree printing routines.

#define VERTBAR      "|      "
#define BLANK        "       "
#define RIGHT_DOWN   "--+    "
#define LEFT_DOWN    "+------"

// Bit array with one bit per tree level, most significant bit of each byte
// first.  print_node() sets the bit for a level when the node printed there
// has a left child; print_bars() clears it again when it prints the
// LEFT_DOWN connector for that level.  Allocated and released by
// TREE::print_tree().
static unsigned char *mapptr;

// Print the leading columns for a node at depth 'level': one column per
// ancestor level, chosen from the bits in 'mapptr'.
//
// A clear bit prints BLANK.  A set bit prints VERTBAR, except in the last
// column (level-1) when 'isright' is zero: there LEFT_DOWN is printed and
// the bit is turned off.
static void print_bars(int level, int isright)
   {
   int column = 0;

   while(column < level)
      {
      const int byte_ndx = column / 8;
      const unsigned char mask = 0x80 >> (column % 8);

      if(!(mapptr[byte_ndx] & mask))
         printf(BLANK);
      else if(column < level-1 || isright)
         printf(VERTBAR);
      else     // last column of a left child
         {
         printf(LEFT_DOWN);
         mapptr[byte_ndx] ^= mask;     // turn off the appropriate line
         }

      ++column;
      }
   }

// Print the node 'nptr' (at depth 'level'; 'isright' is nonzero when it is
// its parent's right child) followed, recursively, by its right subtree and
// then its left subtree.
//
// The node's data is printed with a %s conversion, followed by its balance
// in parentheses: B for BALANCED, L for LEFT, R for anything else.
static void print_node(NODE *nptr, int level, int isright)
   {
   const int has_left  = (nptr->left  != NULL);
   const int has_right = (nptr->right != NULL);
   char balchar;

   // print bars and leaders

   print_bars(level, isright);

   // print the actual node

   if(nptr->balance == BALANCED)
      balchar = 'B';
   else if(nptr->balance == LEFT)
      balchar = 'L';
   else
      balchar = 'R';

   printf("%02s(%c)",nptr->data,balchar);

   if(has_right)              // add trailing right extension if appropriate
      printf(RIGHT_DOWN);

   printf("\n");

   // remember that a left child is still to be drawn below this level

   if(has_left)
      mapptr[level / 8] |= (unsigned char)(0x80 >> (level % 8));

   if(has_right)
      print_node(nptr->right, level+1, 1);

   if(has_left)
      print_node(nptr->left, level+1, 0);

   // a left-hand leaf is followed by a spacer line

   if(!isright && !has_left && !has_right)
      {
      print_bars(level-1,1);
      printf("\n");
      }
   }

void TREE::print_tree(void)
   {
   // figure out the maximum possible tree depth

   NODE *nptr = head;

   for(int maxdepth = 0 ; nptr != NULL ; ++maxdepth, nptr = nptr->right)
      ;  // placed here to avoid warning message

   maxdepth += 2;    // max difference between left & right tree is 2

   // allocate bit array used for printing vertical bars

   mapptr = new unsigned char[(maxdepth/8)+1];
   if(mapptr == NULL)
      {
      errval = MEM_ALLOC_FAIL;
      return;
      }

   for(int i = 0 ; i < (maxdepth/8)+1 ; ++i)
      mapptr[i] = 0;

   print_node(head,0,0);

   delete mapptr;
   }

