#! /usr/bin/python3
# ===================================================================
# The balance algorithm tries to balance a new node's
# child nodes every time a new node is inserted into the tree
# ===================================================================
# create a sort of balanced tree
#
# The node class in this code has only one data item. It is the
# node's data and also is the node's unique key. In the
# "real world" a node would contain much more data and perhaps
# have a unique key that is separate from the node's data.
#
# Note: A node's unique key is used to position a node's place
#       in the tree.
# ===================================================================
# Tree Functions...
#   tree.set_debug              set debug flag
#   tree.set_balance            set balance processing flag
#   tree.count_nodes            return the number of nodes in the tree
#   tree.max_height             return the max depth (levels) in the tree
#   tree.avg_height             return average search depth in the tree
#   tree.insert                 insert node into tree
#   tree.search                 search for a value (node) in the tree
#   tree.remove_leaf_node       remove a leaf node from the tree
#   tree.print_tree_nodes       print the tree node information
#   tree.print_tree_values      print the tree node values
#   tree.print_tree_structure   print the tree structure
# ===================================================================
# Note about the code:
#   a. curnode is the current working node
#   b. newnode is the node being added to the tree
#   c. parnode is the parent node of current working node
#   d. the tree root node is modified but not replaced
#   e. the BinaryTree class has two special methods. They control
#      BinaryTree operations. They are:
#        tree.set_debug()    print messages used for debugging
#        tree.set_balance()  use my balance "stuff" when creating
#                            the tree
# ===================================================================
#
# Testing:
#
# from random import randint
# from random import sample
# import btree_9 as bt9
#
# # ---- function - create random node values -----------------------
# # ----            with potential duplicates -----------------------
#
# def create_random_node_values_dups(elms=20,maxint=100):
#     values = []
#     for _ in range(elms):
#         values.append(randint(1,maxint))
#     return values
#
# # ---- function - create random node values -----------------------
# # ----            with no duplicates ------------------------------
#
# def create_random_node_values(elms=20,maxint=100):
#     values = sample(range(1,maxint),elms)
#     return values
#
# #---- function - fill a tree with nodes (values) ------------------
#
# def fill_tree(tree,values):
#     for i in values:
#         tree.insert(bt9.Node(i))
#
# # ---- test values ------------------------------------------------
#
# ##values = [30,20,10,5,1,40,50,60,70]
# ##values = [30,40,50,60,70]
# ##values = [30,20,10,5,1]
# ##values = [10,9,8,7,6,5,4,3,2,1,11,12,13,14,15,16,17,18,19]
#
# values = create_random_node_values(100,10000)
#
# print('\n---- create un-balanced tree --------------------------')
#
# tree = bt9.BinaryTree()
# tree.set_balance(False)
# tree.set_debug(False)
# fill_tree(tree,values)
# ##tree.print_tree_values()
# ##tree.print_tree_nodes()
# tree.print_tree_structure()
# print("Maximum height of tree is {}".format(tree.max_height()))
# print("Average height of tree is {}".format(tree.avg_height()))
#
# print('\n---- create a sort of balanced tree -------------------')
#
# tree = bt9.BinaryTree()
# tree.set_balance(True)
# tree.set_debug(False)
# fill_tree(tree,values)
# ##tree.print_tree_values()
# ##tree.print_tree_nodes()
# tree.print_tree_structure()
# print("Maximum height of tree is {}".format(tree.max_height()))
# print("Average height of tree is {}".format(tree.avg_height()))
#
# ===================================================================

import numpy as np
from random import randint
from random import sample


# -------------------------------------------------------------------
# class definitions
# -------------------------------------------------------------------

class Node:
    """
    A node that can be inserted into a BinaryTree.

    The only required field is 'data'. It is both the user data held
    by the node and the unique key used to position the node in the
    tree.
    """

    def __init__(self,data,print_flag=False):
        '''
        Constructor for btree Node class.

        Attributes:
          data:    node data and node unique key used to position
                   the node in the btree.
          left:    link to the root node of a subtree containing
                   nodes with data (keys) less than this node.
          right:   link to the root node of a subtree containing
                   nodes with data (keys) greater than this node.
          parent:  link to this node's parent node in the tree.
                   None when the node has no parent (tree root or a
                   node that is not in a tree).

        Parameters:
          data:       the node's data (key) value.
          print_flag: if True, print a "create node" debug message.
                      The default is False.
        '''
        self.data = data
        self.left = None
        self.right = None
        self.parent = None
        if print_flag:
            print('\n>>>> create node {} <<<<'.format(data))

    def print_node(self,title='',single_line = True):
        '''
        Print node information on a single line or on multiple lines.

        Parameters:
          title (str):        text printed before the node
                              information. Nothing is printed when it
                              is None, empty, or not provided.
          single_line (bool): if True print one line, otherwise print
                              one line per field.
        '''
        # ---- links are reported by the key of the linked node
        d = 'None' if self.data is None else self.data
        l = 'None' if self.left is None else self.left.data
        r = 'None' if self.right is None else self.right.data
        p = 'None' if self.parent is None else self.parent.data
        if title:
            print(title)
        if single_line:
            print('(d={},l={},r={},p={})'.format(d,l,r,p))
        else:
            print('data   = {}'.format(d))
            print('left   = {}'.format(l))
            print('right  = {}'.format(r))
            print('parent = {}'.format(p))


class BinaryTree:
    """
    A binary search tree of Node objects keyed by Node.data.

    When the balance flag is set, every insert tries a single local
    rotation around the new node's parent. The root node is never
    replaced by a rotation.
    """

    # ---- init btree ----------------------------------------------

    def __init__(self):
        '''
        Constructor for the BinaryTree class.

        Attributes:
          root     The root node of the btree. Initially it is None.
                   The first node inserted becomes the root node.
          balance  If True, try to balance around the parent of each
                   newly inserted node.
          debug    If True, print debug messages.
        '''
        self.root = None
        self.balance = False
        self.debug = False

    # ---- set internal state --------------------------------------

    def set_balance(self,flag=False):
        '''
        Set "create balanced tree" flag. The default is False.
        '''
        self.balance = flag

    def set_debug(self,flag=False):
        '''
        Set debug flag. The default is False.
        '''
        self.debug = flag

    ################################################################

    # ---- debug support functions ---------------------------------

    def _get_user_input(self,prompt):
        '''
        Prompt the user and return the text they entered.
        '''
        return input(prompt)

    def _pause_program(self):
        '''
        Pause the program until the user presses enter.
        '''
        self._get_user_input('\nPress enter to continue: ')

    def _debug_info(self,node,msg='',pause=False):
        '''
        Print node debug information, optionally preceded by msg and
        optionally followed by a "press enter" pause.
        '''
        if msg:
            print(msg)
        node.print_node()
        if pause:
            self._pause_program()

    ################################################################

    # ---- find the maximum tree height ----------------------------

    def max_height(self):
        '''
        Return the maximum number of levels in the tree
        (0 for an empty tree, 1 for a tree with only a root).
        '''
        return self._max_height(self.root,0)

    def _max_height(self,curnode,curheight):
        '''
        Max height helper function.

        Return curheight plus the number of levels in the subtree
        rooted at curnode. An explicit stack is used instead of
        recursion so that a degenerate (list-like) tree can not
        exceed the interpreter's recursion limit.
        '''
        deepest = curheight
        pending = [(curnode,curheight)]
        while pending:
            node,height = pending.pop()
            if node is None:
                # ---- fell off the tree: height counts the levels above
                if height > deepest:
                    deepest = height
                continue
            pending.append((node.left,height+1))
            pending.append((node.right,height+1))
        return deepest

    # ---- find the average tree height ----------------------------

    def avg_height(self):
        '''
        Return the average depth of the tree's empty child links.

        Every missing left or right link is where a "not found"
        search ends. For each one the depth of the node owning the
        link is recorded (the root is depth 0) and the average of
        those depths is returned. An empty tree returns 0.0.

        The totals are accumulated in the list x, where:
          x[0] = total of the recorded depths
          x[1] = number of empty links found
        '''
        if self.root is None:
            return 0.0
        x = [0,0]
        self._avg_height(self.root,0,x)
        return float(x[0]/x[1])

    def _avg_height(self,curnode,curheight,x):
        '''
        Average height helper function.

        Walk the subtree rooted at curnode (which must not be None
        and is at depth curheight) and add its totals to x in place.
        Iterative for the same reason as _max_height.
        '''
        pending = [(curnode,curheight)]
        while pending:
            node,height = pending.pop()
            for child in (node.left,node.right):
                if child is None:
                    x[0] += height
                    x[1] += 1
                else:
                    pending.append((child,height+1))

    # ---- insert a node into tree ---------------------------------

    def insert(self,newnode):
        '''
        Insert a new node into the tree.

        Return True if the node was inserted. Return False if a node
        with the same data (key) is already in the tree, or if the
        balance step found inconsistent parent links.
        '''
        if self.root is None:
            self.root = newnode
            return True
        return self._insert(newnode,self.root)

    def _insert(self,newnode,curnode):
        '''
        Insert helper function.

        Descend from curnode to the empty link where newnode belongs
        and attach it there. Written as a loop so that inserting into
        a degenerate tree does not hit the recursion limit.
        '''
        while True:
            if newnode.data < curnode.data:
                if curnode.left is not None:
                    curnode = curnode.left
                    continue
                curnode.left = newnode
            elif newnode.data > curnode.data:
                if curnode.right is not None:
                    curnode = curnode.right
                    continue
                curnode.right = newnode
            else:
                print("Tree node {} already in tree".format(curnode.data))
                return False
            # ---- newnode is now a child of curnode
            newnode.parent = curnode
            if self.balance:
                return self._balance(curnode)
            return True

    # ---- balance tree helper function -----------------------------

    def _report_bad_link(self,curnode):
        '''
        Print curnode and its child keys (used by _balance when it
        finds parent/child links that do not agree).
        '''
        curnode.print_node('node: curnode')
        print('node: curnode.left  {}'.
              format(curnode.left.data if curnode.left else 'None'))
        print('node: curnode.right {}'.
              format(curnode.right.data if curnode.right else 'None'))

    def _balance(self,curnode):
        '''
        Balance inserted node helper function.

        curnode is the parent of the node that was just inserted.
        If curnode has exactly one child, on the side away from its
        own parent, curnode is rotated up into its parent's place and
        the old parent becomes curnode's other child.

        Nothing is changed when curnode is the root or a child of the
        root (the root node is never replaced).

        Return True when the tree is consistent (rotated or not).
        Return False when curnode is None or the parent links are
        inconsistent.
        '''
        if self.debug:
            print('---- begin balance ---------------------')
            self.print_tree_structure()
            print("Maximum height of tree is {}".format(self.max_height()))

        ## ---- test current node -----------------------------------

        if curnode is None:
            # ---- do not touch curnode.data here, there is no node
            if self.debug:
                print('curnode is None')
            return False
        if curnode is self.root:
            if self.debug:
                print('curnode {} is tree root'.format(curnode.data))
            return True

        # ---- test current node's parent node ----------------------

        parnode = curnode.parent
        if parnode is self.root:
            if self.debug:
                print('parnode {} is tree root'.format(parnode.data))
            return True
        if parnode.left is not curnode and parnode.right is not curnode:
            print('Oops, parnode {} is not the parent of curnode {}'.
                  format(parnode.data,curnode.data))
            parnode.print_node('node: parent')
            self._report_bad_link(curnode)
            return False

        # ---- test current node's grand parent node ----------------

        pparnode = parnode.parent
        if pparnode.left is not parnode and pparnode.right is not parnode:
            print('Oops, pparnode {} is not the parent of parnode {}'.
                  format(pparnode.data,parnode.data))
            pparnode.print_node('node: pparnode')
            parnode.print_node('node: parent')
            self._report_bad_link(curnode)
            return False

        # ---- balance the current node -----------------------------
        # ---- the link that will hold parnode must be empty

        if curnode.data < parnode.data:
            if curnode.right is not None:
                return True
            curnode.right = parnode
        elif curnode.data > parnode.data:
            if curnode.left is not None:
                return True
            curnode.left = parnode
        else:
            return True

        # ---- curnode takes parnode's place under the grand parent

        if pparnode.left is parnode:
            pparnode.left = curnode
        else:
            pparnode.right = curnode
        curnode.parent = pparnode

        # ---- parnode gives up curnode and becomes its child

        if parnode.left is curnode:
            parnode.left = None
        elif parnode.right is curnode:
            parnode.right = None
        parnode.parent = curnode

        if self.debug:
            print('---- after balance ----------------------')
            self.print_tree_structure()
            print("Maximum height of tree is {}".format(self.max_height()))

        return True

    # ---- search tree for a value ----------------------------------
    # ---- if found, return the node it is in -----------------------

    def search(self,value):
        '''
        Search the tree for a value (key).

        Return the node holding the value, or None if not found.
        '''
        if self.root is None:
            return None
        return self._search(self.root,value)

    def _search(self,curnode,value):
        '''
        Search helper function.

        Return the node holding value in the subtree rooted at
        curnode, or None. Iterative so deep trees can be searched.
        '''
        while curnode is not None and value != curnode.data:
            curnode = curnode.left if value < curnode.data else curnode.right
        return curnode

    # ---- remove node ----------------------------------------------
    # ---- only leaf nodes can be removed ---------------------------
    # ---- return the node removed ----------------------------------

    def remove_leaf_node(self,value):
        '''
        Remove the leaf node holding value from the tree.

        Return the removed node (with its links cleared). Return None
        if the value is not in the tree or its node is not a leaf.
        '''
        if self.root is None:
            return None
        return self._remove_leaf_node(self.root,value)

    def _remove_leaf_node(self,curnode,value):
        '''
        Remove helper function.

        Locate value in the subtree rooted at curnode and unlink its
        node if it is a leaf.
        '''
        # ---- find the node (iterative, see _search)
        while curnode is not None and curnode.data != value:
            curnode = curnode.left if value < curnode.data else curnode.right
        if curnode is None:
            return None

        if curnode.left is not None or curnode.right is not None:
            print('can not remove node {}'.format(value))
            print('it is not a leaf node')
            return None

        # ---- unlink the leaf from its parent (or from the tree root)
        parnode = curnode.parent
        if parnode is None:
            self.root = None
        elif parnode.left is curnode:
            parnode.left = None
        elif parnode.right is curnode:
            parnode.right = None
        else:
            print('Oops, node {} is not the parent of node {}'.
                  format(parnode.data,value))
            return None

        curnode.parent = None
        curnode.left = None
        curnode.right = None
        return curnode

    # ---- print tree values ----------------------------------------

    def print_tree_values(self):
        '''
        Print the tree's values in ascending order - one per line.
        '''
        if self.root is None:
            print("tree empty")
        else:
            self._print_tree_values(self.root)

    def _print_tree_values(self,curnode):
        '''
        Print tree values helper function (recursive, in-order).
        '''
        if curnode.left:
            self._print_tree_values(curnode.left)
        print("{}".format(curnode.data))
        if curnode.right:
            self._print_tree_values(curnode.right)

    # ---- print tree structure -------------------------------------

    def print_tree_structure(self):
        '''
        Print the tree's structure. The tree is drawn sideways: the
        right subtree is printed first and each level is indented by
        one more '---'.
        '''
        if self.root is None:
            print("tree empty")
        else:
            self._print_tree_structure(self.root,1)

    def _print_tree_structure(self,curnode,n):
        '''
        Print tree structure helper function (recursive).
        n is the level of curnode, the root is level 1.
        '''
        if curnode is None:
            return
        if curnode.right:
            self._print_tree_structure(curnode.right,n+1)
        pd = 'None' if curnode.parent is None else curnode.parent.data
        print('{} {} (p={})'.format('---'*n,curnode.data,pd))
        if curnode.left:
            self._print_tree_structure(curnode.left,n+1)

    # ---- print tree nodes -----------------------------------------

    def print_tree_nodes(self,sep1='',sep2=''):
        '''
        Print the tree's nodes - one per line, largest key first.
        sep1 and sep2, when not empty, are printed before and after
        the nodes.
        '''
        if sep1:
            print(sep1)
        if self.root is None:
            print("tree empty")
        else:
            self._print_tree_nodes(self.root)
        if sep2:
            print(sep2)

    def _print_tree_nodes(self,curnode):
        '''
        Print tree nodes helper function (recursive, right subtree
        first).
        '''
        if curnode is None:
            return
        if curnode.right:
            self._print_tree_nodes(curnode.right)
        curnode.print_node()
        if curnode.left:
            self._print_tree_nodes(curnode.left)

    # ---- count all tree nodes -------------------------------------

    def count_nodes(self):
        '''
        Return the number of nodes in the tree.
        '''
        return self._count_nodes(self.root)

    def _count_nodes(self,curnode):
        '''
        Return the number of nodes in the subtree rooted at curnode
        (0 when curnode is None). Iterative so deep trees can be
        counted.
        '''
        total = 0
        pending = [curnode]
        while pending:
            node = pending.pop()
            if node is None:
                continue
            total += 1
            pending.append(node.left)
            pending.append(node.right)
        return total


# -------------------------------------------------------------------
# main testing
# -------------------------------------------------------------------

if __name__ == '__main__':

    # ---------------------------------------------------------------
    # create various trees for testing
    # ---------------------------------------------------------------

    # ---- create a generic tree

    def create_generic_tree(tree,values):
        for i in values:
            tree.insert(Node(i))

    # ---- create a worst case tree

    def worst_case_tree(tree):
        values = [10,9,8,7,6,5,4,3,2,1,11,12,13,14,15,16,17,18,19]
        print('create worst case tree')
        print(values)
        create_generic_tree(tree,values)

    # ---- create the tree from my documentation

    def my_doc_tree(tree):
        values = [30,40,50,55]
        print('create tree for my documentation')
        print(values)
        create_generic_tree(tree,values)

    # ---- create a tree with random nodes
    # ---- random with replacement - potential duplicates

    def random_nodes(tree,elms=20,maxint=99):
        print('create tree with random nodes - potential duplicates')
        for _ in range(elms):
            i = randint(1,maxint)
            ##print('RandomInt = {}'.format(i))
            tree.insert(Node(i))

    # ---- create a tree with random nodes
    # ---- random without replacement - no duplicates

    def random_nodes_unique(tree,elms=20,maxint=99):
        print('create tree with random nodes - no duplicates')
        values = sample(range(1,maxint),elms)
        print(values)
        create_generic_tree(tree,values)

    # ---- create a tree with no nodes

    def no_root_node(tree):
        pass

    # ---- create a tree with only 1 node

    def only_root_node(tree):
        values = [100]
        print('create tree with 1 node')
        print(values)
        create_generic_tree(tree,values)

    # ---- create a tree with 2 nodes

    def root_node_plus_one(tree):
        values = [100,200]
        print('create tree with 2 nodes')
        create_generic_tree(tree,values)

    # ---------------------------------------------------------------
    # test support functions
    # ---------------------------------------------------------------

    # ---- ask the user for input

    def get_user_input(prompt):
        return input(prompt)

    # ---- pause the program until the user is ready

    def pause_program():
        get_user_input('\nPress enter to continue: ')

    # ---------------------------------------------------------------
    # basic debugging
    # ---------------------------------------------------------------

    #################################################################

    ##print("\n----debug----skip-other-testing----------------------")

    ##values = [30,20,10,5,1,40,50,60,70]
    ##values = [30,40,50,60,70]
    ##values = [30,20,10,5,1]
    values = [10,9,8,7,6,5,4,3,2,1,11,12,13,14,15,16,17,18,19]

    print('\nnode values =')
    print(values)

    print('\n---- create un-balanced tree --------------------------')

    tree = BinaryTree()
    tree.set_debug(False)
    tree.set_balance(False)
    create_generic_tree(tree,values)
    ##tree.print_tree_values()
    ##tree.print_tree_nodes()
    tree.print_tree_structure()
    print("Maximum height of tree is {}".format(tree.max_height()))
    print("Average height of tree is {}".format(tree.avg_height()))

    print('\n---- create balanced tree -----------------------------')

    tree = BinaryTree()
    tree.set_debug(False)
    tree.set_balance(True)
    create_generic_tree(tree,values)
    ##tree.print_tree_values()
    ##tree.print_tree_nodes()
    tree.print_tree_structure()
    print("Maximum height of tree is {}".format(tree.max_height()))
    print("Average height of tree is {}".format(tree.avg_height()))

    ##pause_program()

    # ---- stop here on purpose: the interactive tests below are
    # ---- skipped while debugging. Remove this line to run them.
    raise SystemExit

    #################################################################

    # ---------------------------------------------------------------
    # tests
    # ---------------------------------------------------------------

    print("\n----create-tree--------------------------------------")

    tree = BinaryTree()

    ##no_root_node(tree)
    ##only_root_node(tree)
    ##root_node_plus_one(tree)
    ##my_doc_tree(tree)
    ##worst_case_tree(tree)
    ##random_nodes(tree,20,99)
    ##random_nodes_unique(tree,20,99)

    print("\n----print-tree-values-(in-order)---------------------")

    tree.print_tree_values()

    print("\n----max-tree-height----------------------------------")

    print("Maximum height of tree is {}".format(tree.max_height()))

    print("\n----print-tree-structure-(view-it-sideways)----------")

    tree.print_tree_structure()

    print("\n----print-nodes-values-------------------------------")

    tree.print_tree_nodes()

    print("\n----count-nodes--------------------------------------")

    print('number of tree nodes = {}'.format(tree.count_nodes()))

    print("\n----search-the-tree----------------------------------")

    while True:
        val = get_user_input('\nEnter search value: ')
        sval = val.strip()
        if sval == '':
            break
        if not sval.isdigit():
            print('\nIllegal value entered ({})'.format(sval))
            pause_program()
            continue
        rnode = tree.search(int(sval))
        if rnode is None:
            print('Value {} not found in tree'.format(sval))
        else:
            print('Value {} found in tree'.format(rnode.data))

    print("\n----remove-leaf-node-from-tree-----------------------")

    while True:
        val = get_user_input('\nEnter value (leaf node) to remove: ')
        rval = val.strip()
        if rval == '':
            break
        if not rval.isdigit():
            print('\nIllegal value entered ({})'.format(rval))
            pause_program()
            continue
        rnode = tree.remove_leaf_node(int(rval))
        if rnode is None:
            print('node {} not removed from tree'.format(rval))
            continue
        print('node {} removed from tree'.format(rnode.data))
        tree.print_tree_structure()