In this article, we will implement an algorithm in Java to find the distance (the number of edges on the path) between any two nodes in a binary tree. The basic idea is to first find the lowest common ancestor (LCA) of the given nodes. We then find the distances d1 and d2 from the LCA to each of the two nodes. The sum d1 + d2 is the final result.

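We will use the following sample binary tree. The root `5` has children `7` and `10`. Node `7` has children `14` and `19`. Node `10` has children `30` and `15`. Node `15` has a left child `25`. If we take the nodes `30` and `25`, the distance between them is `3`, along the path 30 → 10 → 15 → 25.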
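Here is the computation written out for the sample nodes, whose LCA is `10`. The last line is an equivalent cross-check using depths measured from the root, with the root `5` at depth 0. It is a standard identity, not a step of this article's algorithm.

```latex
\begin{aligned}
d_1 &= \operatorname{dist}(10, 30) = 1 \\
d_2 &= \operatorname{dist}(10, 25) = 2 \\
\operatorname{dist}(30, 25) &= d_1 + d_2 = 1 + 2 = 3 \\
\operatorname{dist}(30, 25) &= \operatorname{depth}(30) + \operatorname{depth}(25) - 2\,\operatorname{depth}(10) = 2 + 3 - 2 \cdot 1 = 3
\end{aligned}
```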

As discussed, our first task is to find the LCA. The next task is to find the distance from the LCA to each of the two nodes.

## Binary Tree Node Implementation

Below is our representation of the binary tree node in Java.

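```java
public class BinaryTreeNode {
    private int data;
    private BinaryTreeNode leftNode;
    private BinaryTreeNode rightNode;

    public BinaryTreeNode(int data) { this.data = data; }

    // Getters and setters used by the code in this article
    public int getData() { return data; }
    public BinaryTreeNode getLeftNode() { return leftNode; }
    public void setLeftNode(BinaryTreeNode leftNode) { this.leftNode = leftNode; }
    public BinaryTreeNode getRightNode() { return rightNode; }
    public void setRightNode(BinaryTreeNode rightNode) { this.rightNode = rightNode; }
}
```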

Below is the skeleton of our `DistanceBetweenNodes` class. The following sections fill in each method body.

```java
public class DistanceBetweenNodes {

    public static int findDistanceBetweenNodes(BinaryTreeNode root, BinaryTreeNode n1, BinaryTreeNode n2) {
    }

    private static int findDistanceBetween(BinaryTreeNode root, BinaryTreeNode n1, int distance) {
    }

    private static BinaryTreeNode findLCA(BinaryTreeNode root, BinaryTreeNode n1, BinaryTreeNode n2) {
    }

    public static void main(String[] args) {
    }
}
```

## Find LCA of Two Nodes

In my last article, we discussed the algorithm to find the LCA of two nodes. You can visit that article for a detailed explanation.

The LCA of two nodes is the deepest node that has both given nodes as descendants, where a node counts as a descendant of itself. The algorithm searches the tree recursively. If the current node is one of the given nodes, it is returned. Otherwise the node returns whatever its subtrees found, or `null` if they found nothing. Hence, a node is the LCA when its left and right subtrees both return a non-null node, meaning one given node lies on each side.

Let us write the code directly as we have already discussed its algorithm in the last article.

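```java
private static BinaryTreeNode findLCA(BinaryTreeNode root, BinaryTreeNode n1, BinaryTreeNode n2) {
    // Moved past a leaf without finding either node
    if (root == null) {
        return null;
    }
    // The current node is one of the targets, so report it upward
    if (root == n1 || root == n2) {
        return root;
    }
    BinaryTreeNode leftNode = findLCA(root.getLeftNode(), n1, n2);
    BinaryTreeNode rightNode = findLCA(root.getRightNode(), n1, n2);
    // One target in each subtree: this node is the LCA
    if (leftNode != null && rightNode != null) {
        return root;
    }
    // Neither subtree contains a target
    if (leftNode == null && rightNode == null) {
        return null;
    }
    // Only one side found something; pass it up
    return leftNode == null ? rightNode : leftNode;
}
```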
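The following minimal, self-contained sketch runs the same recursion on its own. It assumes a hypothetical `Node` class with public `left`/`right` fields in place of `BinaryTreeNode` and its getters. It also uses a hypothetical `sampleTree()` helper that builds the tree from this article.

```java
// Hypothetical standalone sketch of the recursive LCA search used by findLCA().
public class LcaSketch {

    // Hypothetical stand-in for BinaryTreeNode, with public fields instead of getters/setters.
    public static final class Node {
        public final int data;
        public Node left;
        public Node right;

        public Node(int data) {
            this.data = data;
        }
    }

    // Same rules as findLCA(): report a matching node upward, and return the current
    // node when each subtree reported a match.
    public static Node findLca(Node root, Node n1, Node n2) {
        if (root == null) {
            return null;                      // moved past a leaf
        }
        if (root == n1 || root == n2) {
            return root;                      // this node is one of the targets
        }
        Node left = findLca(root.left, n1, n2);
        Node right = findLca(root.right, n1, n2);
        if (left != null && right != null) {
            return root;                      // targets are in different subtrees
        }
        return left != null ? left : right;   // pass up whatever was found (or null)
    }

    // Hypothetical helper: 5 has children 7 and 10; 7 has 14 and 19;
    // 10 has 30 and 15; 15 has a left child 25.
    public static Node sampleTree() {
        Node root = new Node(5);
        root.left = new Node(7);
        root.right = new Node(10);
        root.left.left = new Node(14);
        root.left.right = new Node(19);
        root.right.left = new Node(30);
        root.right.right = new Node(15);
        root.right.right.left = new Node(25);
        return root;
    }
}
```

On the sample tree, `findLca` on `30` and `25` returns `10`. Suppose one node is an ancestor of the other, for example `10` and `25`. Then the early `root == n1 || root == n2` check returns `10` before its subtrees are searched, so the ancestor itself is reported as the LCA.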

## Find Distance between LCA and Given Node

Once we have the LCA, we need to find the distance from the LCA to each of the given nodes, one at a time.

This step is the same as finding the level of a node in a tree, with the LCA treated as the root. There are two base conditions:

- We move past a leaf (the current node is `null`). We return `-1` to signal that the node was not found on this path.
- We reach the node we are looking for. We return the accumulated distance.

Each recursive call passes `distance + 1` to its children. If the left subtree returns a value other than `-1`, that value is returned immediately and the right subtree is not searched.

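```java
private static int findDistanceBetween(BinaryTreeNode root, BinaryTreeNode n1, int distance) {
    // Moved past a leaf: n1 is not on this path
    if (root == null) {
        return -1;
    }
    // Found the target: distance is the number of edges from the starting node
    if (root == n1) {
        return distance;
    }
    // Search the left subtree one level deeper; stop early if found
    int d = findDistanceBetween(root.getLeftNode(), n1, distance + 1);
    if (d != -1) {
        return d;
    }
    // Otherwise search the right subtree (returns -1 if not found there either)
    return findDistanceBetween(root.getRightNode(), n1, distance + 1);
}
```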
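The same distance can also be computed without recursion. The sketch below swaps in a different technique, which is not the article's method. It performs a level-order (breadth-first) traversal that starts at the LCA and counts levels until the target is dequeued. The `Node` class and the `distanceFrom` name are hypothetical.

```java
import java.util.ArrayDeque;
import java.util.Deque;

// Hypothetical iterative alternative to findDistanceBetween(), using breadth-first search.
public class BfsDistanceSketch {

    // Hypothetical stand-in for BinaryTreeNode with public fields.
    public static final class Node {
        public final int data;
        public Node left;
        public Node right;

        public Node(int data) {
            this.data = data;
        }
    }

    // Returns the number of edges from 'from' down to 'target', or -1 if target is not below 'from'.
    public static int distanceFrom(Node from, Node target) {
        if (from == null) {
            return -1;
        }
        Deque<Node> queue = new ArrayDeque<>();
        queue.add(from);
        int level = 0;
        while (!queue.isEmpty()) {
            int levelSize = queue.size();      // number of nodes on the current level
            for (int i = 0; i < levelSize; i++) {
                Node node = queue.poll();
                if (node == target) {
                    return level;
                }
                // ArrayDeque rejects null elements, so enqueue only existing children
                if (node.left != null) {
                    queue.add(node.left);
                }
                if (node.right != null) {
                    queue.add(node.right);
                }
            }
            level++;
        }
        return -1;
    }
}
```

As a rough trade-off, the queue holds at most about one level of nodes at a time. The recursive version instead uses call-stack space proportional to the tree's height, which can matter for very deep, skewed trees.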

Now, let us define the public method that our main method will invoke. It first calls `findLCA()`, then calls `findDistanceBetween()` once for each of the two given nodes, starting from the LCA. The sum of these two distances is our final result.

```java
public static int findDistanceBetweenNodes(BinaryTreeNode root, BinaryTreeNode n1, BinaryTreeNode n2) {
    BinaryTreeNode lca = findLCA(root, n1, n2);
    // Distance from the LCA down to each of the two nodes
    int distance1 = findDistanceBetween(lca, n1, 0);
    int distance2 = findDistanceBetween(lca, n2, 0);
    System.out.println(distance1 + distance2);
    return distance1 + distance2;
}
```
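`findDistanceBetweenNodes()` assumes that both nodes are in the tree. If one is missing, `findLCA()` returns the node that was found, and the `-1` "not found" marker from the other search is added in, giving `-1`. If both are missing, `lca` is `null` and the result is `-2`. The sketch below makes that check explicit and always returns `-1` in those cases. The `Node` class and all method names in it are hypothetical.

```java
// Hypothetical sketch: LCA plus two distances, returning -1 if either node is not in the tree.
public class SafeDistanceSketch {

    // Hypothetical stand-in for BinaryTreeNode with public fields.
    public static final class Node {
        public final int data;
        public Node left;
        public Node right;

        public Node(int data) {
            this.data = data;
        }
    }

    // Same recursion as findLCA()
    static Node findLca(Node root, Node n1, Node n2) {
        if (root == null) return null;
        if (root == n1 || root == n2) return root;
        Node left = findLca(root.left, n1, n2);
        Node right = findLca(root.right, n1, n2);
        if (left != null && right != null) return root;
        return left != null ? left : right;
    }

    // Same recursion as findDistanceBetween(): edges from root down to target, or -1
    static int distanceFrom(Node root, Node target, int distance) {
        if (root == null) return -1;
        if (root == target) return distance;
        int d = distanceFrom(root.left, target, distance + 1);
        return d != -1 ? d : distanceFrom(root.right, target, distance + 1);
    }

    public static int distanceBetween(Node root, Node n1, Node n2) {
        Node lca = findLca(root, n1, n2);
        if (lca == null) {
            return -1;                 // neither node is in the tree
        }
        int d1 = distanceFrom(lca, n1, 0);
        int d2 = distanceFrom(lca, n2, 0);
        if (d1 == -1 || d2 == -1) {
            return -1;                 // only one of the nodes is in the tree
        }
        return d1 + d2;
    }

    // Hypothetical helper that builds the article's sample tree
    public static Node sampleTree() {
        Node root = new Node(5);
        root.left = new Node(7);
        root.right = new Node(10);
        root.left.left = new Node(14);
        root.left.right = new Node(19);
        root.right.left = new Node(30);
        root.right.right = new Node(15);
        root.right.right.left = new Node(25);
        return root;
    }
}
```

With this sketch, `30` and `25` still give `3`, and `14` and `25` give `5` (LCA `5`, distances 2 and 3). A node compared with itself gives `0`.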

## Binary Tree Runner

Let us construct the sample binary tree and invoke `findDistanceBetweenNodes()` on the nodes `30` and `25`.

```java
public static void main(String[] args) {
    // Build the sample tree
    BinaryTreeNode root = new BinaryTreeNode(5);
    root.setLeftNode(new BinaryTreeNode(7));
    root.setRightNode(new BinaryTreeNode(10));
    root.getLeftNode().setLeftNode(new BinaryTreeNode(14));
    root.getLeftNode().setRightNode(new BinaryTreeNode(19));
    root.getRightNode().setLeftNode(new BinaryTreeNode(30));
    root.getRightNode().setRightNode(new BinaryTreeNode(15));
    root.getRightNode().getRightNode().setLeftNode(new BinaryTreeNode(25));

    // Distance between 30 and 25; prints 3
    DistanceBetweenNodes.findDistanceBetweenNodes(root,
            root.getRightNode().getLeftNode(),
            root.getRightNode().getRightNode().getLeftNode());
}
```

## Conclusion

In this article, we implemented an algorithm in Java to find the distance between any two nodes in a binary tree.