# btree_a.py

#! /usr/bin/python3
# ===================================================================
# The balance algorithm starts at the bottom of the tree (leaf nodes)
# and proceeds up the tree towards the root node.
# ===================================================================
# create a sort of balanced tree
#
# The node class in this code has only one data item. It is the
# node's data and also is the node's unique key. In the
# "real world" a node would contain much more data and perhaps
# have a unique key that is separate from the node's data.
#
# Note: A node's unique key is used to position a node's place
#       in the tree.
# ===================================================================
# Tree Functions...
# tree.set_debug              set debug flag
# tree.balance                balance the tree (sort of)
# tree.count_nodes            return the number of nodes in the tree
# tree.max_height             return the max depth (levels) in the tree
# tree.avg_height             return average search depth in the tree
# tree.insert                 insert node into tree
# tree.search                 search for a value (node) in the tree
# tree.remove_leaf_node       remove a leaf node from the tree
#                             (leaf nodes have no right or left links)
# tree.print_tree_nodes       print the tree node information
# tree.print_tree_values      print the tree node values 
# tree.print_tree_structure   print the tree structure
# ===================================================================
# Note about the code:
# a. curnode is the current working node
# b. newnode is the node being added to the tree
# c. parnode is the parent node of current working node
# d. chdnode is a child node of the current working node 
# e. the tree root node may be modified but not replaced
# ===================================================================
#
# Testing:
#
# from random import randint
# from random import sample
# import btree_9 as bt9
#
# # ---- function - create random node values -----------------------
# # ----            with potential duplicates -----------------------
#
# def create_random_node_values_dups(elms=20,maxint=100):
#     values  = []
#     for _ in range(elms):
#         values.append(randint(1,maxint))
#     return values
#
# # ---- function - create random node values -----------------------
# # ----            with no duplicates ------------------------------
#
# def create_random_node_values(elms=20,maxint=100):
#     values = sample(range(1,maxint),elms)
#     return values
#
# #---- function - fill a tree with nodes (values) ------------------
#
# def fill_tree(tree,values):
#     for i in values:
#         tree.insert(bt9.Node(i))
#
# # ---- test values ------------------------------------------------
#
# ##values = [30,20,10,5,1,40,50,60,70]
# ##values = [30,40,50,60,70]
# ##values = [30,20,10,5,1]
# ##values = [10,9,8,7,6,5,4,3,2,1,11,12,13,14,15,16,17,18,19]
#
# values = create_random_node_values(100,10000)
#
# print('\n---- create tree (un-balanced) ------------------------')
#
# tree = bt9.BinaryTree()
# tree.set_debug(False)
# fill_tree(tree,values)
# ##tree.print_tree_values()
# ##tree.print_tree_nodes()
# tree.print_tree_structure()
# print("Maximum height of tree is {}".format(tree.max_height()))
# print("Average height of tree is {:.2f}".format(tree.avg_height()))
#
# print('\n---- create a sort of balanced tree -------------------')
#
# tree.balance()
# ##tree.print_tree_values()
# ##tree.print_tree_nodes()
# tree.print_tree_structure()
# print("Maximum height of tree is {}".format(tree.max_height()))
# print("Average height of tree is {:.2f}".format(tree.avg_height()))
#
# ===================================================================

import numpy as np
from random import randint
from random import sample

# -------------------------------------------------------------------
# class definitions
# -------------------------------------------------------------------

class Node:
    """
    A single node of the binary tree.

    A node is a container for user data. The only required
    field is 'data', which doubles as the unique key used to
    position the node in the tree.

    Attributes:
      data:          node data and node unique key.
      left (Node):   link to the root of the subtree holding
                     keys less than this node's key, or None.
      right (Node):  link to the root of the subtree holding
                     keys greater than this node's key, or None.
      parent (Node): link to this node's parent node, or None
                     when the node has no parent.
    """

    def __init__(self,data,print_flag=False):
        '''
        Create an unlinked node.

        Parameters:
          data:              the node's data (key) value.
          print_flag (bool): if True, print a one line message
                             announcing the new node. The default
                             is False (print nothing).
        '''
        self.data   = data
        self.left   = None
        self.right  = None
        self.parent = None
        if print_flag:
            print('\n>>>> create node {} <<<<'.format(data))

    def print_node(self,title='',single_line=True):
        '''
        Print the node's data and the data of its linked nodes.

        Parameters:
           title (str):        text printed on its own line before
                               the node information. Nothing is
                               printed when it is empty or None.
           single_line (bool): if True print everything on one
                               line, otherwise print one line per
                               field.
        '''
        # identity tests: "== None" would call a user defined
        # __eq__ on the data and could report a real value as None
        d = 'None' if self.data   is None else self.data
        l = 'None' if self.left   is None else self.left.data
        r = 'None' if self.right  is None else self.right.data
        p = 'None' if self.parent is None else self.parent.data
        if title:
            print(title)
        if single_line:
            print('(d={},l={},r={},p={})'.format(d,l,r,p))
        else:
            print('data   = {}'.format(d))
            print('left   = {}'.format(l))
            print('right  = {}'.format(r))
            print('parent = {}'.format(p))

class BinaryTree:

    # ---- init btree ----------------------------------------------

    def __init__(self):
        '''
        Create an empty tree.

        Attributes:
          root      The tree's root node. It starts out as None;
                    the first node inserted becomes the root.
          debug     When True, debug messages are printed.
        '''
        self.root, self.debug = None, False

    # ---- set internal state --------------------------------------

    def set_debug(self,flag=False):
        '''
        Set debug flag. The default is False.
        '''
        self.debug = flag
        ##if self.debug:
        ##    print('debug flags is True')

    ################################################################
    # ---- debug support functions ---------------------------------

    def _get_user_input(self,prompt):
        '''
        Display prompt and wait for the user to enter a line.
        Return the text the user entered.
        '''
        reply = input(prompt)
        return reply

    def _pause_program(self):
        '''
        Block until the user presses enter. Whatever the
        user typed is discarded.
        '''
        prompt = '\nPress enter to continue: '
        self._get_user_input(prompt)

    def _debug_info(self,node,msg='',pause=False):
        '''
        Print node debug information.
        '''
        if msg: print(msg)
        ##print(type(node))
        node.print_node()
        if (pause): self._pause_program()
    ################################################################

    # ---- find the maximum tree height ----------------------------

    def max_height(self):
        '''
        Return the number of levels on the longest root to leaf
        path. An empty tree has 0 levels.
        '''
        deepest = self._max_height(self.root,0)
        return deepest

    def _max_height(self,curnode,curheight):
        '''
        Max height helper function.

        curheight is the number of nodes above curnode on the
        current path. Running off the bottom of the tree
        (curnode is None) returns that count unchanged.
        '''
        if curnode is None:
            return curheight
        below = curheight + 1
        return max(self._max_height(curnode.left,below),
                   self._max_height(curnode.right,below))

    # ---- find the average tree height ----------------------------

    def avg_height(self):
        '''
        Return the average depth at which a search falls off
        the tree. An empty tree returns 0.0.

        Note: every missing (None) left or right link counts
              as one path end, so a node in the middle of the
              tree that has only one child ends a path too.
              This models worst case (not found) searches.

        Note: the helper accumulates into a two element list:
              [0] = sum of the depths of all path ends
              [1] = number of path ends
        '''
        if self.root is None:
            return 0.0
        totals = [0,0]
        self._avg_height(self.root,0,totals)
        depth_sum, path_count = totals
        return float(depth_sum/path_count)

    def _avg_height(self,curnode,curheight,x):
        '''
        Average height helper function.

        For each of curnode's two links: a missing link is a
        path end, so curheight is added to x[0] and x[1] is
        incremented; an existing link is followed one level
        deeper. x is updated in place; nothing is returned.
        '''
        for child in (curnode.left,curnode.right):
            if child is None:
                x[0] += curheight
                x[1] += 1
            else:
                self._avg_height(child,curheight+1,x)
        return

    # ---- insert a node into tree ---------------------------------

    def insert(self,newnode):
        '''
        Add newnode to the tree.

        The first node inserted becomes the tree root. Return
        True if the node was added, False if a node with the
        same data (key) is already in the tree.
        '''
        if self.root is not None:
            return self._insert(newnode,self.root)
        self.root = newnode
        return True

    def _insert(self,newnode,curnode):
        '''
        Insert helper function.

        Walk down from curnode, going left when newnode's data is
        smaller and right when it is larger, and attach newnode at
        the first empty link found.

        The walk is a loop rather than one recursive call per
        level, so a degenerate tree (e.g. built from sorted
        values) can not exceed Python's recursion limit.

        Parameters:
          newnode:  the node to add to the tree.
          curnode:  the node to start the walk from (not None).

        Returns:
          True if newnode was linked into the tree. False (after
          printing a message) if a node with the same data was
          found.
        '''
        while True:
            if newnode.data < curnode.data:
                if curnode.left is None:
                    newnode.parent = curnode
                    curnode.left   = newnode
                    return True
                curnode = curnode.left
            elif newnode.data > curnode.data:
                if curnode.right is None:
                    newnode.parent = curnode
                    curnode.right  = newnode
                    return True
                curnode = curnode.right
            else:
                # neither smaller nor larger - duplicate key
                print("Tree node {} already in tree".format(curnode.data))
                return False

    # ---- balance the tree ----------------------------------------

    def balance(self):
        '''
        Balance the tree (sort of).

        The root node and the level just below it are skipped;
        the balance helper is started on each of the root's
        (up to four) grandchildren, left side first.

        Return True if every pass succeeded (an empty tree also
        returns True), False as soon as one pass fails.
        '''

        # ---- no tree root - nothing to do -------------------------

        if self.root is None:
            return True

        for side in ('left','right'):

            if not getattr(self.root,side):
                continue

            for branch in ('left','right'):

                # ---- the root's child is looked up again on every
                # ---- pass because a balance pass can put a
                # ---- different node directly under the root

                grandchild = getattr(getattr(self.root,side),branch)

                if grandchild and not self._balance(grandchild):
                    return False

        return True


    def _balance(self,curnode):
        '''
        Balance helper function. (This does most of the work.)

        The subtrees of curnode are processed first (bottom up).
        Then, if curnode is not a leaf and its parent is not the
        tree root, curnode is moved up to take its parent's place
        provided the link the parent must hang from is free:

          - curnode smaller than its parent and curnode.right is
            None: the parent becomes curnode's right child.
          - curnode larger than its parent and curnode.left is
            None: the parent becomes curnode's left child.

        The grandparent is re-linked to curnode and the parent's
        link to curnode is cleared. If the needed link is already
        in use the tree is left unchanged.

        Parameters:
          curnode:  the node to process.

        Returns:
          True if the pass completed (whether or not anything was
          moved). False, after printing a diagnostic, if curnode
          is None, a parent/child link is inconsistent, or a
          duplicate key is found.
        '''

        # ---- is the current node None? ---------------------------

        if curnode is None:
            print('Oops, curnode is None')
            return False

        if self.debug:
            print('balance curnode = {}'.format(curnode.data))

        # ---- process the current node's left and right subtrees ---

        if curnode.left:
            r = self._balance(curnode.left)
            if r != True:
                return False

        if curnode.right:
            r = self._balance(curnode.right)
            if r != True:
                return False

        # ---- is the current node a leaf node? ---------------------

        if curnode.left is None and curnode.right is None:
            return True

        # ---- test current node's parent node ----------------------

        parnode = curnode.parent

        # ---- no parent (curnode is a root) - nothing to move above

        if parnode is None:
            return True

        # ---- the tree root may be modified but not replaced -------

        if parnode == self.root:
            if self.debug:
                print('parnode {} is tree root'.format(parnode.data))
            return True
        if parnode.left != curnode and parnode.right != curnode:
            print('Oops, parnode {} is not the parent of curnode {}'.
                format(parnode.data,curnode.data))
            parnode.print_node('node: parent')
            curnode.print_node('node: curnode')
            print('node: curnode.left {}'.
                format(curnode.left.data if curnode.left else 'None'))
            print('node: curnode.right {}'.
                format(curnode.right.data if curnode.right else 'None'))
            return False

        # ---- test current node's grand parent node ----------------

        pparnode = parnode.parent

        if pparnode is None:
            print('Oops, parnode {} is not the tree root but has no parent'.
                format(parnode.data))
            return False

        if pparnode.left != parnode and pparnode.right != parnode:
            print('Oops, pparnode {} is not the parent of parnode {}'.
                format(pparnode.data,parnode.data))
            pparnode.print_node('node: pparnode')
            parnode.print_node('node: parent')
            curnode.print_node('node: curnode')
            print('node: curnode.left {}'.
                format(curnode.left.data if curnode.left else 'None'))
            print('node: curnode.right {}'.
                format(curnode.right.data if curnode.right else 'None'))
            return False

        # ---- balance the current node -----------------------------

        # ---- debug output reports on this tree (self), not on a
        # ---- module level variable that may not exist

        if self.debug:
            print('---- begin balance ---------------------')
            print('curnode = {}'.format(curnode.data))
            print('----------------------------------------')
            self.print_tree_structure()
            print("Maximum height of tree is {}".format(self.max_height()))

        # ---- hang the parent below curnode, but only if the link
        # ---- it needs is not already taken

        if curnode.data < parnode.data:
            if curnode.right is not None:
                return True
            curnode.right = parnode
        elif curnode.data > parnode.data:
            if curnode.left is not None:
                return True
            curnode.left = parnode
        else:
            print('Oops, a duplicate node in the tree {}'.
                format(curnode.data))
            return False

        # ---- curnode takes the parent's place under the grand parent

        if pparnode.left == parnode:
            pparnode.left = curnode
        else:
            pparnode.right = curnode

        curnode.parent = pparnode

        # ---- the parent no longer links down to curnode -----------

        if parnode.left == curnode:
            parnode.left = None
        elif parnode.right == curnode:
            parnode.right = None

        parnode.parent = curnode

        if self.debug:
            print('---- after balance ----------------------')
            self.print_tree_structure()
            print("Maximum height of tree is {}".format(self.max_height()))
            print("Average height of tree is {:.2f}".format(self.avg_height()))

        return True

    # ---- search tree for a value ----------------------------------
    # ---- if found, return the node it is in -----------------------

    def search(self,value):
        '''
        Look for the node whose data equals value.

        Return the node if it is found. Return None if it is
        not found or the tree is empty.
        '''
        return None if self.root is None else self._search(self.root,value)

    def _search(self,curnode,value):
        '''
        Search helper function.

        Walk down from curnode, going left when value is smaller
        than the node's data and right otherwise, until a node
        with matching data is found or the walk falls off the
        tree.

        The walk is a loop rather than one recursive call per
        level, so a degenerate (very deep) tree can not exceed
        Python's recursion limit.

        Parameters:
          curnode:  the node to start from (may be None).
          value:    the data (key) value to look for.

        Returns:
          The node whose data equals value, or None if there is
          no such node.
        '''
        while curnode is not None:
            if value == curnode.data:
                return curnode
            curnode = curnode.left if value < curnode.data else curnode.right
        return None

    # ---- remove node ----------------------------------------------
    # ---- only leaf nodes can be removed ---------------------------
    # ---- return the node removed ----------------------------------

    def remove_leaf_node(self,value):
        '''
        Remove the leaf node whose data equals value.

        Return the removed node. Return None if the tree is
        empty or the helper did not remove a node.
        '''
        if self.root is not None:
            return self._remove_leaf_node(self.root,value)
        return None

    def _remove_leaf_node(self,curnode,value):
        '''
        Remove helper function.

        Search down from curnode for the node whose data equals
        value. If it is a leaf it is unlinked from its parent
        (or the tree root is cleared when it has no parent), its
        own links are reset to None, and it is returned.

        None is returned, and nothing is removed, when the value
        is not found, when the node found is not a leaf, or when
        the node's parent does not link back to it. The last two
        cases print a message.
        '''
        if curnode is None:
            return None

        # ---- not this node - keep searching -----------------------

        if curnode.data != value:
            if value < curnode.data:
                return self._remove_leaf_node(curnode.left,value)
            return self._remove_leaf_node(curnode.right,value)

        # ---- only leaf nodes can be removed -----------------------

        if curnode.left is not None or curnode.right is not None:
            print('can not remove node {}'.format(value))
            print('it is not a leaf node')
            return None

        # ---- unlink the node from whatever points at it -----------

        owner = curnode.parent
        if owner is None:
            self.root = None
        elif owner.left is curnode:
            owner.left = None
        elif owner.right is curnode:
            owner.right = None
        else:
            print('Oops, node {} is not the parent of node {}'. \
                format(owner.data,value))
            return None

        curnode.parent = None
        curnode.left   = None
        curnode.right  = None
        return curnode

    # ---- print tree values ----------------------------------------

    def print_tree_values(self):
        '''
        print a tree's values - one per line
        '''
        if self.root == None:
            print("tree empty")
        else:
            self._print_tree_values(self.root)

    def _print_tree_values(self,curnode):
        '''
        print tree helper function
        '''
        if curnode.left:
            self._print_tree_values(curnode.left)
        print("{}".format(curnode.data))
        if curnode.right:
            self._print_tree_values(curnode.right)

    # ---- print tree structure -------------------------------------

    def print_tree_structure(self):
        '''
        print a tree's structure
        '''
        if self.root == None:
            print("tree empty")
        else:
            self._print_tree_structure(self.root,1)

    def _print_tree_structure(self,curnode,n):
        '''
        print tree structure helper function
        '''
        if curnode == None: return n
        if curnode.right:
            self._print_tree_structure(curnode.right,n+1)
        pd = 'None' if curnode.parent == None else curnode.parent.data
        print('{} {}  (p={})'.format('---'*n,curnode.data,pd))
        if curnode.left:
            self._print_tree_structure(curnode.left,n+1)

    # ---- print tree nodes -----------------------------------------

    def print_tree_nodes(self,sep1='',sep2=''):
        '''
        print a tree nodes - one per line
        '''
        if sep1: print(sep1)
        if self.root == None:
            print("tree empty")
        else:
            self._print_tree_nodes(self.root)
        if sep2: print(sep2)

    def _print_tree_nodes(self,curnode):
        '''
        print tree nodes helper function
        ''' 
        if curnode == None: return
        if curnode.right: self._print_tree_nodes(curnode.right)
        curnode.print_node()
        if curnode.left: self._print_tree_nodes(curnode.left)

    # ---- count all tree nodes -------------------------------------

    def count_nodes(self):
        '''
        Return how many nodes are in the tree (0 if it is empty).
        '''
        total = self._count_nodes(self.root)
        return total

    def _count_nodes(self,curnode):
        '''
        Count nodes helper function.

        Return the number of nodes in the subtree rooted at
        curnode; a None subtree counts as 0.
        '''
        if curnode is None:
            return 0
        left_count  = self._count_nodes(curnode.left)
        right_count = self._count_nodes(curnode.right)
        return 1 + left_count + right_count


# -------------------------------------------------------------------
# main testing
# -------------------------------------------------------------------

if __name__ == '__main__':

    # ---------------------------------------------------------------
    # create various trees for testing
    # ---------------------------------------------------------------

    # ---- create a generic tree
    def create_generic_tree(tree,values):
        '''
        Insert a new Node into tree for every entry of values,
        in the order given.
        '''
        for key in values:
            node = Node(key)
            tree.insert(node)

    # ---- create a worst case tree
    def worst_case_tree(tree):
        '''
        Fill tree with 10 down to 1 followed by 11 up to 19,
        which gives one long left chain and one long right chain.
        '''
        print('create worst case tree')
        values = list(range(10,0,-1)) + list(range(11,20))
        print(values)
        create_generic_tree(tree,values)

    # ---- create the tree from my documentation
    def my_doc_tree(tree):
        '''
        Fill tree with the four values used in the documentation.
        '''
        print('create tree for my documentation')
        values = [30,40,50,55]
        print(values)
        create_generic_tree(tree,values)

    # ---- create a tree with random nodes
    # ---- random with replacement - potential duplicates
    def random_nodes(tree,elms=20,maxint=99):
        '''
        Try to insert elms nodes with random values from 1 to
        maxint (inclusive) into tree. Values are drawn with
        replacement, so duplicates may occur.
        '''
        print('create tree with random nodes - potential duplicates')
        for _ in range(elms):
            tree.insert(Node(randint(1,maxint)))

    # ---- create a tree with random nodes
    # ---- random without replacement - no duplicates
    def random_nodes_unique(tree,elms=20,maxint=99):
        '''
        Fill tree with elms distinct random values.

        The values are sampled without replacement from
        range(1,maxint), i.e. 1 up to maxint-1.

        Parameters:
          tree:    the tree to insert the nodes into.
          elms:    how many values to insert.
          maxint:  exclusive upper bound of the values.
        '''
        print('create tree with random nodes - no duplicates')
        values = sample(range(1,maxint),elms)
        print(values)
        # pass the sampled list (the original named an undefined variable)
        create_generic_tree(tree,values)

    # ---- create a tree with no nodes
    def no_root_node(tree):
        '''
        Leave tree untouched (test case: a tree with no nodes).
        '''
        return

    # ---- create a tree with only 1 node
    def only_root_node(tree):
        '''
        Fill tree with a single node (value 100).
        '''
        print('create tree with 1 node')
        values = [100]
        print(values)
        create_generic_tree(tree,values)

    # ---- create a tree with 2 nodes
    def root_node_plus_one(tree):
        '''
        Fill tree with two nodes: a root (100) and one child (200).

        Parameters:
          tree:  the tree to insert the nodes into.
        '''
        values = [100,200]
        print('create tree with 2 nodes')
        # print the values like the other tree builders do
        print(values)
        # the tree must be passed along with the values
        create_generic_tree(tree,values)

    # ---------------------------------------------------------------
    # test support functions
    # ---------------------------------------------------------------

    # ---- ask the user for input
    def get_user_input(prompt):
        '''
        Display prompt and return the line the user enters.
        '''
        reply = input(prompt)
        return reply

    # ---- pause the program until the user is ready
    def pause_program():
        '''
        Wait until the user presses enter; the input is discarded.
        '''
        prompt = '\nPress enter to continue: '
        get_user_input(prompt)

    # ---------------------------------------------------------------
    # basic debugging tests - skip other tests
    # ---------------------------------------------------------------

    #################################################################

    def create_random_node_values(elms=20,maxint=100):
        '''
        Return a list of elms distinct random values sampled
        from range(1,maxint).
        '''
        return sample(range(1,maxint),elms)

    # ---- alternative node values for experimenting ----------------
    ##values = [30,20,10,5,1,40,50,60,70]
    ##values = [30,40,50,60,70]
    ##values = [30,20,10,5,1]
    ##values = create_random_node_values()

    # ---- worst case: 10 down to 1, then 11 up to 19 ---------------
    values = list(range(10,0,-1)) + list(range(11,20))

    # ---- print a banner, the tree structure and its heights -------
    def show_tree(tree,banner):
        print(banner)
        tree.print_tree_structure()
        print("Maximum height of tree is {}".format(tree.max_height()))
        print("Average height of tree is {:.2f}".format(tree.avg_height()))

    print('\nnode values =')
    print(values)

    print('\n==== create tree (un-balanced) ========================')

    tree = BinaryTree()
    tree.set_debug(False)
    create_generic_tree(tree,values)
    show_tree(tree,
        '\n---------- initial tree structure --------------------')

    print('\n==== balanced tree ====================================')

    tree.set_debug(False)
    tree.balance()
    show_tree(tree,
        '\n---------- final tree structure -----------------------')

    ##pause_program()

    # ---- stop here - the tests further down are skipped -----------
    quit()
    #################################################################


    # ---------------------------------------------------------------
    # tests
    # ---------------------------------------------------------------

    # NOTE: this section only runs when the quit() call at the end
    #       of the basic debugging tests above is removed.

    print("\n----create-tree--------------------------------------")

    tree = BinaryTree()

    # ---- un-comment one of these to choose the tree to test -------

    ##no_root_node(tree)
    ##only_root_node(tree)
    ##root_node_plus_one(tree)
    ##my_doc_tree(tree)
    ##worst_case_tree(tree)
    ##random_nodes(tree,20,99)
    ##random_nodes_unique(tree,20,99)

    print("\n----print-tree-values-(in-order)---------------------")

    tree.print_tree_values()

    print("\n----max-tree-height----------------------------------")

    print("Maximum height of tree is {}".format(tree.max_height()))

    print("\n----print-tree-structure-(view-it-sideways)----------")

    tree.print_tree_structure()

    print("\n----print-nodes-values-------------------------------")

    # BinaryTree has no print_nodes_values; the node printer is
    # print_tree_nodes
    tree.print_tree_nodes()

    print("\n----count-nodes--------------------------------------")

    print('number of tree nodes = {}'.format(tree.count_nodes()))

    print("\n----search-the-tree----------------------------------")

    # ---- loop until the user enters an empty line -----------------

    while True:

        val = get_user_input('\nEnter search value: ')

        sval = val.strip()

        if sval == '': break

        # ---- only unsigned integers are accepted

        if not sval.isdigit():
            print('\nIllegal value entered ({})'.format(sval))
            pause_program()
            continue

        rnode = tree.search(int(sval))

        if rnode is None:
            print('Value {} not found in tree'.format(sval))
        else:
            print('Value {} found in tree'.format(rnode.data))

    print("\n----remove-leaf-node-from-tree-----------------------")

    # ---- loop until the user enters an empty line -----------------

    while True:

        val = get_user_input('\nEnter value (leaf node) to remove: ')

        rval = val.strip()

        if rval == '': break

        if not rval.isdigit():
            print('\nIllegal value entered ({})'.format(rval))
            pause_program()
            continue

        rnode = tree.remove_leaf_node(int(rval))

        if rnode is None:
            print('node {} not removed from tree'.format(rval))
            continue

        print('node {} removed from tree'.format(rnode.data))

        # ---- show what the tree looks like after the removal

        tree.print_tree_structure()