AVL Balanced Binary Search Tree in C++
Implementation of an AVL balanced binary search tree in C++.
GitHub: sourcecode/AVLbalancedtree.h
#include <algorithm>
#include <iostream>
#include <queue>
#include <stack>
#include <utility>
#include "queueusinglinkedlist.h"
using namespace std;
// AVL balanced binary search tree holding unique values of type T.
//
// Every node caches the height of its subtree (a leaf has height 1, an empty
// subtree counts as 0) and its balance factor, defined here as
// height(right) - height(left). After every insertion or removal the nodes on
// the search path are refreshed and rotated so that each balance factor stays
// within [-1, +1].
//
// T must be copy-constructible, copy-assignable, comparable with operator<
// and streamable with operator<< (for the print functions).
//
// The helper functions taking or returning Node* remain public only because
// they were public before; they are implementation details of the tree.
template <typename T> class avlbalancedbst
{
private:
    // A single tree node.
    struct Node
    {
        Node* left;   // left subtree (values smaller than data)
        Node* right;  // right subtree (values greater than data)
        T data;       // stored value
        int height;   // height of the subtree rooted here; 1 for a leaf
        int bf;       // balance factor: height(right) - height(left)

        // A freshly created node is a leaf: height 1, balance factor 0.
        explicit Node(const T& value)
            : left(nullptr), right(nullptr), data(value), height(1), bf(0)
        {
        }
    };

    // Root of the tree; nullptr while the tree is empty.
    Node* root;

    // Frees every node of the subtree rooted at `node`.
    static void destroy(Node* node)
    {
        if (node == nullptr)
        {
            return;
        }
        destroy(node->left);
        destroy(node->right);
        delete node;
    }

    // Returns a deep copy of the subtree rooted at `source`.
    static Node* clone(const Node* source)
    {
        if (source == nullptr)
        {
            return nullptr;
        }
        Node* copy = new Node(source->data);
        copy->height = source->height;
        copy->bf = source->bf;
        try
        {
            copy->left = clone(source->left);
            copy->right = clone(source->right);
        }
        catch (...)
        {
            // Do not leak the partially built copy if an allocation fails.
            destroy(copy);
            throw;
        }
        return copy;
    }

public:
    // Creates an empty tree.
    avlbalancedbst() : root(nullptr)
    {
    }

    // Deep-copies `other`.
    avlbalancedbst(const avlbalancedbst& other) : root(clone(other.root))
    {
    }

    // Takes over the nodes of `other`, leaving it empty.
    avlbalancedbst(avlbalancedbst&& other) noexcept : root(other.root)
    {
        other.root = nullptr;
    }

    // Copy/move assignment via copy-and-swap; the previous nodes are released
    // when the by-value parameter is destroyed.
    avlbalancedbst& operator=(avlbalancedbst other) noexcept
    {
        std::swap(root, other.root);
        return *this;
    }

    // Releases every node owned by the tree.
    ~avlbalancedbst()
    {
        destroy(root);
    }

    // Inserts `elem` (recursive implementation). Values already present are
    // ignored, so the tree never holds duplicates.
    void addrecursive(T elem)
    {
        root = addrecursive(root, elem);
    }

    // Inserts `elem` into the subtree rooted at `node` and returns the
    // (possibly new) root of that subtree after rebalancing.
    Node* addrecursive(Node* node, T elem)
    {
        if (node == nullptr)
        {
            return new Node(elem);
        }
        if (elem < node->data)
        {
            node->left = addrecursive(node->left, elem);
        }
        else if (node->data < elem)
        {
            node->right = addrecursive(node->right, elem);
        }
        else
        {
            // Equal value: nothing was inserted, so nothing changes.
            return node;
        }
        findbalancefactor(node);
        return balance(node);
    }

    // Restores the AVL invariant at `node`, whose cached height and balance
    // factor must be up to date. Returns the root of the rebalanced subtree.
    Node* balance(Node* node)
    {
        if (node == nullptr)
        {
            return nullptr;
        }
        if (node->bf <= -2)
        {
            // Left-heavy: a single right rotation is enough unless the left
            // child leans right, which needs the left-right double rotation.
            return (node->left->bf <= 0) ? leftleftrotation(node)
                                         : leftrightrotation(node);
        }
        if (node->bf >= 2)
        {
            // Right-heavy: mirror image of the case above.
            return (node->right->bf >= 0) ? rightrightrotation(node)
                                          : rightleftrotation(node);
        }
        // Balance factor is -1, 0 or +1: already balanced.
        return node;
    }

    // Left-left case: resolved by a single right rotation.
    Node* leftleftrotation(Node* node)
    {
        return rightrotation(node);
    }

    // Right-right case: resolved by a single left rotation.
    Node* rightrightrotation(Node* node)
    {
        return leftrotation(node);
    }

    // Right-left case: rotate the right child right, then `node` left.
    Node* rightleftrotation(Node* node)
    {
        node->right = rightrotation(node->right);
        return leftrotation(node);
    }

    // Left-right case: rotate the left child left, then `node` right.
    Node* leftrightrotation(Node* node)
    {
        node->left = leftrotation(node->left);
        return rightrotation(node);
    }

    // Rotates `node` to the left; its right child becomes the subtree root.
    // `node` must have a right child.
    Node* leftrotation(Node* node)
    {
        Node* newparent = node->right;
        node->right = newparent->left;
        newparent->left = node;
        // `node` is now the lower of the two, so refresh it first.
        findbalancefactor(node);
        findbalancefactor(newparent);
        return newparent;
    }

    // Rotates `node` to the right; its left child becomes the subtree root.
    // `node` must have a left child.
    Node* rightrotation(Node* node)
    {
        Node* newparent = node->left;
        node->left = newparent->right;
        newparent->right = node;
        // `node` is now the lower of the two, so refresh it first.
        findbalancefactor(node);
        findbalancefactor(newparent);
        return newparent;
    }

    // Recomputes the cached height and balance factor of `node` from the
    // cached heights of its children. Does nothing for nullptr.
    void findbalancefactor(Node* node)
    {
        if (node == nullptr)
        {
            return;
        }
        const int lh = (node->left != nullptr) ? node->left->height : 0;
        const int rh = (node->right != nullptr) ? node->right->height : 0;
        node->height = std::max(lh, rh) + 1;
        node->bf = rh - lh;
    }

    // Inserts `elem` (iterative implementation). Values already present are
    // ignored. The tree is rebalanced exactly as with addrecursive().
    void add(T elem)
    {
        // Each entry is the address of the pointer (root, or a child field)
        // through which a node on the search path is reached, so a rotation
        // can replace that node in place on the way back up.
        std::stack<Node**> path;
        Node** link = &root;
        while (*link != nullptr)
        {
            Node* current = *link;
            path.push(link);
            if (elem < current->data)
            {
                link = &current->left;
            }
            else if (current->data < elem)
            {
                link = &current->right;
            }
            else
            {
                // Duplicate: nothing allocated, nothing to rebalance.
                return;
            }
        }
        *link = new Node(elem);
        // Walk back towards the root, refreshing and rebalancing each node.
        while (!path.empty())
        {
            Node** slot = path.top();
            path.pop();
            findbalancefactor(*slot);
            *slot = balance(*slot);
        }
    }

    // Prints every node (value, balance factor, height) in level order.
    // Prints nothing at all for an empty tree.
    void print_levelordertraversal() const
    {
        if (root == nullptr)
        {
            return;
        }
        std::cout << "level order traversal " << '\n';
        std::queue<const Node*> pending;
        pending.push(root);
        while (!pending.empty())
        {
            const Node* node = pending.front();
            pending.pop();
            std::cout << "data " << node->data << " bf = " << node->bf
                      << " height = " << node->height << '\n';
            if (node->left != nullptr)
            {
                pending.push(node->left);
            }
            if (node->right != nullptr)
            {
                pending.push(node->right);
            }
        }
        std::cout << '\n';
    }

    // Removes `elem` from the tree. Prints a message and leaves the tree
    // untouched when the tree is empty or the value is not present.
    void remove(T elem)
    {
        if (root == nullptr)
        {
            std::cout << "no element in bst" << '\n';
            return;
        }
        if (!findelement(root, elem))
        {
            std::cout << "element not found in bst" << '\n';
            return;
        }
        root = remove(root, elem);
    }

    // Returns true when `elem` is stored in the subtree rooted at `node`.
    bool findelement(Node* node, T elem) const
    {
        while (node != nullptr)
        {
            if (elem < node->data)
            {
                node = node->left;
            }
            else if (node->data < elem)
            {
                node = node->right;
            }
            else
            {
                return true;
            }
        }
        return false;
    }

    // Removes `elem` from the subtree rooted at `node` and returns the
    // (possibly new) root of that subtree after rebalancing.
    Node* remove(Node* node, T elem)
    {
        if (node == nullptr)
        {
            return nullptr;
        }
        if (elem < node->data)
        {
            node->left = remove(node->left, elem);
        }
        else if (node->data < elem)
        {
            node->right = remove(node->right, elem);
        }
        else if (node->left != nullptr && node->right != nullptr)
        {
            // Two children: overwrite this node with its in-order successor
            // (smallest value of the right subtree), then delete that
            // successor from the right subtree instead.
            Node* successor = findminnode(node->right);
            node->data = successor->data;
            node->right = remove(node->right, node->data);
        }
        else
        {
            // At most one child: splice the node out and let the child (or
            // nullptr for a leaf) take its place.
            Node* child = (node->left != nullptr) ? node->left : node->right;
            delete node;
            node = child;
        }
        findbalancefactor(node);
        return balance(node);
    }

    // Returns the node holding the smallest value of the subtree rooted at
    // `node` (its left-most node), or nullptr for an empty subtree.
    Node* findminnode(Node* node) const
    {
        if (node == nullptr)
        {
            return nullptr;
        }
        while (node->left != nullptr)
        {
            node = node->left;
        }
        return node;
    }

    // Returns the node holding the largest value of the subtree rooted at
    // `node` (its right-most node), or nullptr for an empty subtree.
    Node* findmaxnode(Node* node) const
    {
        if (node == nullptr)
        {
            return nullptr;
        }
        while (node->right != nullptr)
        {
            node = node->right;
        }
        return node;
    }

    // Returns the height of the whole tree; 0 when it is empty.
    int height() const
    {
        return height(root);
    }

    // Returns the height of the subtree rooted at `node`; 0 for nullptr.
    // Uses the cached value, which every insertion and removal keeps current.
    int height(Node* node) const
    {
        return (node == nullptr) ? 0 : node->height;
    }

    // Prints the values in pre-order (node, left subtree, right subtree) on
    // one line. Only the line break is printed for an empty tree.
    void print_preordertraversal() const
    {
        if (root != nullptr)
        {
            std::stack<const Node*> pending;
            pending.push(root);
            std::cout << " pre order ";
            while (!pending.empty())
            {
                const Node* node = pending.top();
                pending.pop();
                std::cout << node->data << " ";
                // Push right before left so that left is visited first.
                if (node->right != nullptr)
                {
                    pending.push(node->right);
                }
                if (node->left != nullptr)
                {
                    pending.push(node->left);
                }
            }
        }
        std::cout << '\n';
    }

    // Prints the values in ascending (in-order) sequence on one line. Only
    // the line break is printed for an empty tree.
    void print_inordertraversal() const
    {
        if (root != nullptr)
        {
            std::stack<const Node*> ancestors;
            const Node* node = root;
            std::cout << " in order ";
            while (node != nullptr || !ancestors.empty())
            {
                if (node != nullptr)
                {
                    // Descend as far left as possible, remembering the path.
                    ancestors.push(node);
                    node = node->left;
                }
                else
                {
                    // Left side exhausted: visit the node, then go right.
                    node = ancestors.top();
                    ancestors.pop();
                    std::cout << node->data << " ";
                    node = node->right;
                }
            }
        }
        std::cout << '\n';
    }
};
// Demo driver for avlbalancedbst. The template parameter T is not used by the
// demo itself (both trees are avlbalancedbst<int>); it is kept so existing
// instantiations such as testavlbalancedbst<int> continue to compile.
template <typename T> class testavlbalancedbst
{
public:
    // Builds two integer trees, prints their traversals and heights, and
    // exercises removal, including a value (0) that is not in the tree.
    void testexecution()
    {
        // Automatic storage: the trees go out of scope with the function
        // instead of being allocated with `new` and never deleted.
        avlbalancedbst<int> bst;
        bst.addrecursive(10);
        bst.addrecursive(9);
        bst.addrecursive(7);
        bst.addrecursive(16);
        bst.addrecursive(18);
        bst.addrecursive(15);
        bst.addrecursive(5);
        bst.addrecursive(40);
        bst.addrecursive(20);
        bst.print_levelordertraversal();
        bst.print_inordertraversal();
        bst.print_preordertraversal();
        bst.print_levelordertraversal();
        std::cout << "height of the bst === " << bst.height() << '\n';

        bst.remove(5);
        bst.remove(9);
        bst.remove(18);
        bst.remove(0);  // not present: reports "element not found in bst"
        bst.print_levelordertraversal();
        std::cout << '\n' << '\n';

        // Second tree: watch the height grow as values are inserted.
        avlbalancedbst<int> bst1;
        std::cout << "bst1 height " << bst1.height() << '\n';
        bst1.addrecursive(10);
        std::cout << "bst1 height " << bst1.height() << '\n';
        bst1.addrecursive(5);
        std::cout << "bst1 height " << bst1.height() << '\n';
        bst1.addrecursive(19);
        std::cout << "bst1 height " << bst1.height() << '\n';
        bst1.addrecursive(23);
        bst1.addrecursive(7);
        bst1.print_levelordertraversal();
    }
};