# balance_tree_values.py

#!/usr/bin/python3
# ====================================================================
# build and balance trees
# ====================================================================

import user_interface as ui

# --------------------------------------------------------------------
# ---- tree node class
# ----
# ---- Traditionally the child nodes of a parent node are named
# ---- left and right. In this code I named them less and more
# ---- (less than and more than) to indicate their relations
# ---- to the parent node.
# --------------------------------------------------------------------

class Node:
    """One binary-tree node: a value plus links to its parent and children.

    ``less`` plays the role of the traditional left child and ``more``
    the role of the traditional right child.
    """

    def __init__(self, value):
        self.value = value
        # every link starts out empty; callers wire nodes together later
        self.parent = self.less = self.more = None

# --------------------------------------------------------------------
# ---- display node
# --------------------------------------------------------------------

def display_node(node):
    """Print a node's value and the values of its parent and two children.

    A missing link is printed as ``None``.
    """

    def link_text(other):
        # value of a linked node, or the literal word None when unlinked
        return 'None' if other is None else other.value

    print(f'------ node {node.value:<3}')
    print(f'parent {link_text(node.parent)}')
    print(f'less   {link_text(node.less)}')
    print(f'more   {link_text(node.more)}')

# --------------------------------------------------------------------
# ---- calculate the maximum height of the tree
# ---- (call function recursively)
# --------------------------------------------------------------------

def height(node):
    """Return the number of nodes on the longest root-to-leaf path.

    An empty tree (``None``) has height 0 and a lone node has height 1.
    """

    if node is None:
        return 0

    # the taller of the two subtrees, plus one for this node
    return 1 + max(height(node.less), height(node.more))

# --------------------------------------------------------------------
# ---- traverse a tree in order - return array(s)
# ---- (call function recursively)
# --------------------------------------------------------------------

def traverse_tree_in_order(node,nodes,values):
    """Walk the subtree at *node* in order (less, self, more).

    Each visited node is appended to *nodes* and its value to *values*;
    pass ``None`` for either list to skip collecting it. Returns nothing.
    """

    if node is not None:
        traverse_tree_in_order(node.less, nodes, values)

        if nodes is not None:
            nodes.append(node)
        if values is not None:
            values.append(node.value)

        traverse_tree_in_order(node.more, nodes, values)


# --------------------------------------------------------------------
# ---- display tree node values in order
# --------------------------------------------------------------------

def display_node_values(root):
    """Print the tree's values on one line, in in-order sequence.

    Prints 'value list is empty' instead when the tree has no nodes.
    """

    ordered = []
    traverse_tree_in_order(root, None, ordered)

    print()
    if not ordered:
        print('value list is empty')
        return

    # every value is followed by one space, then a blank line ends the list
    print(''.join(f'{v} ' for v in ordered), end='')
    print('\n')

# --------------------------------------------------------------------
# ---- balance a binary tree
# ---- build a tree from a list of ordered node values
# ---- (setup then call a recursive function)
# ---- based on:
# ---- www.geeksforgeeks.org/convert-normal-bst-balanced-bst/
# --------------------------------------------------------------------

def construct_balanced_tree(root):
    """Build and return a new height-balanced tree holding root's values.

    The existing tree is only read. Its values are collected in order and
    copied into fresh Node objects, so the original nodes are not modified.
    Returns None when the existing tree is empty.

    For each index range, the middle value becomes the subtree root and the
    two halves become its less/more subtrees. The result is therefore
    balanced, and its in-order sequence equals the original's. With
    duplicate values, an equal value may sit in a node's less subtree,
    whereas add_node always sends equal values to the more side.
    """

    def balanced_tree_util(values,start,end):
        # build the subtree for values[start..end] and return its root node
        if start > end: return None

        # ---- middle value becomes this subtree's root
        mid = (start + end)//2
        node = Node(values[mid])

        # ---- construct the less and more subtree links
        node.less = balanced_tree_util(values,start,mid-1)
        if node.less is not None:
            node.less.parent = node
        node.more = balanced_tree_util(values,mid+1,end)
        if node.more is not None:
            node.more.parent = node

        return node

    # ---- get a list of existing tree node values in order
    values = []
    traverse_tree_in_order(root,None,values)

    # ---- create and return a new balanced tree
    # (recursion depth is only about log2(len(values)))
    return balanced_tree_util(values,0,len(values)-1)

# --------------------------------------------------------------------
# ---- add a node to a tree (below the parent)
# ---- (walk down the tree iteratively so deep, unbalanced trees
# ----  do not hit Python's recursion limit)
# --------------------------------------------------------------------

def add_node(node,parent):
    """Insert *node* into the subtree rooted at *parent*.

    A value strictly less than a node's value goes into that node's ``less``
    subtree; an equal or greater value goes into its ``more`` subtree. The
    new node is attached at the first empty link found, and its ``parent``
    is set to the node it was attached to. Returns nothing.
    """

    current = parent
    while True:
        if node.value < current.value:
            if current.less is None:
                current.less = node
                break
            current = current.less
        else:
            if current.more is None:
                current.more = node
                break
            current = current.more

    node.parent = current

# --------------------------------------------------------------------
# ---- tree statistics - count nodes
# ---- Note: a list is used to hold the counts because it is mutable
# ----       integers, etc. are not
# ---- (setup then call recursive function)
# --------------------------------------------------------------------

def tree_statistics(root):
    """Return ``(node_count, none_count)`` for the tree rooted at *root*.

    ``none_count`` counts every empty (``None``) child link. An empty root
    counts as one such link.
    """

    # [node_count, none_count]; a list so the nested walker can update it
    counts = [0, 0]

    def walk(node):
        if node is None:
            counts[1] += 1
            return
        counts[0] += 1
        for child in (node.less, node.more):
            walk(child)

    walk(root)
    node_count, none_count = counts
    return (node_count, none_count)
       
# --------------------------------------------------------------------
# ---- display tree nodes
# ---- (setup then call recursive function)
# --------------------------------------------------------------------

def list_tree(root):
    """Print every node of the tree in pre-order (self, less, more).

    Each node is printed with display_node(). An empty tree prints
    'tree is empty'.
    """

    def show(subtree):
        display_node(subtree)
        for child in (subtree.less, subtree.more):
            if child is not None:
                show(child)

    print()
    if root is None:
        print('tree is empty')
    else:
        show(root)

# --------------------------------------------------------------------
# ---- create a new node - ask the user for its value
# --------------------------------------------------------------------

def new_node():
    """Prompt the user for a node value and return a new Node holding it.

    The prompt repeats until the user enters an integer in [-999, 999].
    Returns None if the user enters nothing.
    """

    while True:

        # ---- get node value from the user

        print()
        s = ui.get_user_input('Enter a tree node value [-999 to 999]: ')
        if not s:
            return None

        tf,v = ui.is_integer(s)
        if tf and -999 <= v <= 999:
            return Node(v)

        print()
        print(f'bad node value input ({s})')

# --------------------------------------------------------------------
# ---- construct a tree from values in a CSV string
# --------------------------------------------------------------------

def construct_tree_from_csv_string():
    """Ask the user for a list of integers and build a tree from them.

    Values may be separated by commas and/or whitespace. The first value
    becomes the root, and the rest are inserted in order with add_node().

    Returns the new root node. Returns None (and creates nothing) when:
      - the input is blank or contains only separators, or
      - any value is not an integer in [-999, 999].
    """

    # ---- get CSV string

    print()
    csv_str = ui.get_user_input('Enter CSV string: ')

    # ---- break string into tokens (commas and whitespace both separate)

    tokens = csv_str.replace(',',' ').split() if csv_str else []

    # an input made only of separators (e.g. ",,,") has no values either
    if not tokens:
        print()
        print('tree is unchanged')
        return None

    # ---- validate every token before building anything

    values = []
    for s in tokens:

        tf,v = ui.is_integer(s)

        if not tf or v < -999 or v > 999:
            print()
            print(f'illegal node value in CSV string ({s})')
            print('exit function - nothing modified or created')
            return None

        values.append(v)

    # ---- build the tree

    root = Node(values[0])
    for v in values[1:]:
        add_node(Node(v),root)

    return root

# --------------------------------------------------------------------
# ---- main
# --------------------------------------------------------------------

if __name__ == '__main__':

    # interactive menu loop; the current tree starts as a single node (value 0)

    menu = '''
 option  description
 ------  -------------------------------- 
    1    add node to tree
    2    list tree nodes
    3    tree statistics
    4    new tree (create only root node)
    5    create tree from CSV string

   10    construct balance tree

   20    display node values in order

   99    exit
'''

    root = Node(0)

    while True:

        # --- get and verify user option selection

        print(menu)
        s = ui.get_user_input(' select an option: ')
        if not s: break

        # same (ok, value) integer validator used by the other prompts in
        # this file
        tf,opt = ui.is_integer(s)
        if not tf:
            print(f'illegal option ({s}) - try again')
            continue

        # ---- add node to tree
        if opt == 1:
            node = new_node()
            # new_node() returns None when the user enters nothing
            if node is not None: add_node(node,root)
            continue

        # ---- list tree nodes
        if opt == 2:
            list_tree(root)
            continue

        # ---- display tree statistics
        if opt == 3:
            stats = tree_statistics(root)
            print()
            print(f'node count  = {stats[0]}')
            print(f'none count  = {stats[1]}')
            print(f'tree height = {height(root)}')
            continue

        # ---- delete current tree
        # ---- initialize a new tree root node (value=0)
        if opt == 4:
            root = Node(0)
            continue

        # ---- construct a tree from CSV string values
        if opt == 5:
            x = construct_tree_from_csv_string()
            if x is not None: root = x
            continue

        # ---- construct a balanced tree from current tree
        if opt == 10:
            print()
            print(f'max height before balancing = {height(root)}')
            x = construct_balanced_tree(root)
            if x is not None: root = x
            print(f'max height after  balancing = {height(root)}')
            continue

        # ---- display tree node values in order
        if opt == 20:
            display_node_values(root)
            continue

        # ---- exit program
        if opt == 99:
            break

        print(f'illegal option ({opt}) - try again')