Classification Decision Trees: An Intro

What’s a Decision Tree?

Rowais Hanna
Oct 24, 2020

A decision tree is a popular machine learning algorithm used for predictive modeling. Decision trees are used in a variety of fields and are good tools for making data-driven decisions, even with little or no domain knowledge. As its name suggests, a decision tree is structured like a tree that mimics a flowchart of yes/no decisions: each internal node tests a feature, each “branch” represents an outcome of that test, and each “leaf” denotes a class label. Decision trees make predictions by dividing a big decision into smaller and smaller sub-decisions (think branches) until a “final” decision is reached.


What is a Decision Tree useful for?

There are many types of decision trees with different use cases, but they all boil down to two general kinds of prediction:

  1. Classification Predictions: predicting which class or category a hypothetical observation falls under.

An example of a classification prediction would be determining whether a hypothetical customer would make a purchase (yes) or not make a purchase (no), given a set of data. Any situation where the prediction is categorical or discrete can be considered a classification problem.

  2. Regression Predictions: predicting a continuous target value given a set of observations.

An example of regression predictions would be determining the sale price of a house, given a set of houses that have been sold, and data about them. Any situation where the prediction is a continuous variable can be considered a regression problem.

For the purposes of this article, I’ll focus on explaining decision trees through the lens of classification predictions, but regression trees work in a very similar fashion (they typically choose splits that reduce the variance of the target rather than the Gini impurity).

source: mc.ai

How Does a Decision Tree Work?

Definitions

Before jumping into the intricacies of how a decision tree can be used to classify a hypothetical observation, we need to cover a few key terms:

  1. Root (parent) Node: The topmost node of the tree. It contains the entire dataset and is the first node to be divided into two or more sets, which makes it the parent of the nodes below it. The feature used at this node is chosen with an attribute selection measure, such as Gini impurity.
  2. Branch or Sub-Tree: A part of the entire decision tree is called a branch or sub-tree.
  3. Splitting: Dividing a node into two or more sub-nodes based on if-else conditions.
  4. Decision (child) Node: A sub-node that is itself split into further sub-nodes is called a decision node.
  5. Leaf or Terminal Node: A node at the end of a branch that is not split into further sub-nodes; it holds the predicted class.
  6. Pruning: Removing sub-nodes from the tree, typically to simplify it and reduce overfitting, is called pruning.
source: https://www.kdnuggets.com/

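7. Gini Impurity: A measure of how likely a randomly chosen observation in a node is to be misclassified if it were labeled at random according to the node’s class distribution. It is calculated as follows:

$$G = 1 - \sum_{i=1}^{C} p_i^2$$

where $p_i$ is the probability of an object being classified into class $i$ (the proportion of the node’s observations that belong to class $i$) and $C$ is the number of classes.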

The more decisive a split is, the closer to 0 the Gini impurity of the resulting nodes will be. A perfect split of the data produces nodes that each contain a single class, with a Gini impurity of 0, whereas the worst case is a node with a 50/50 mix of classes, which has a Gini impurity of 0.5 (assuming the classification problem has 2 possible outcomes).
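To make the formula concrete, here is a worked example for a hypothetical node of four observations with two classes (yes/no), showing how the impurity moves from 0 to 0.5 as the node becomes more mixed:

$$
\begin{aligned}
\text{4 yes, 0 no:}\quad G &= 1 - (1^2 + 0^2) = 0 \\
\text{3 yes, 1 no:}\quad G &= 1 - (0.75^2 + 0.25^2) = 1 - 0.625 = 0.375 \\
\text{2 yes, 2 no:}\quad G &= 1 - (0.5^2 + 0.5^2) = 0.5
\end{aligned}
$$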

Process

Now that we have the lingo down, let’s look at the process. A decision tree algorithm works by splitting the data in a parent node into child nodes based on optimal thresholds in the data, then repeating the process on each child. This continues until one of two things happens: either a predetermined maximum depth is reached, or no split can produce child nodes whose (weighted) Gini impurity is lower than the Gini impurity of the parent node. Said another way, the goal of a decision tree is to reduce the Gini impurity at each node for as long as it is able to.

Let’s take a look at the practical steps taken with the algorithm to achieve this:

Step 1 — Calculate Gini Impurity

Starting at the root node, which contains all of the data, the algorithm first determines how pure, or homogeneous, the class labels in the node are. The number of observations, the number of unique classes, and the number of occurrences of each class are counted, and then the Gini impurity is calculated using the formula above.
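A minimal Python sketch of this step, assuming the node’s class labels are held in a plain list. The function name `gini_impurity` is illustrative, not taken from the author’s class:

```python
from collections import Counter

def gini_impurity(labels):
    """Gini impurity of one node: 1 minus the sum of squared class proportions."""
    n = len(labels)  # number of observations in the node
    if n == 0:
        return 0.0   # treat an empty node as pure
    counts = Counter(labels)  # number of occurrences of each unique class
    return 1.0 - sum((count / n) ** 2 for count in counts.values())

print(gini_impurity(["yes", "yes", "yes", "yes"]))  # 0.0
print(gini_impurity(["yes", "yes", "yes", "no"]))   # 0.375
print(gini_impurity(["yes", "yes", "no", "no"]))    # 0.5
```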

Step 2 — Identify split threshold

Go through the data and identify the split point that yields the lowest Gini impurity. The candidate thresholds can be identified as the halfway points between consecutive (sorted) values of the feature. For example, if the data contains feature values of 0, 3, 5, 7, and 9, your thresholds would be the halfway points between those values: 1.5, 4, 6, and 8. Calculate the Gini impurity for each of those thresholds (the weighted average of the impurities of the two resulting child nodes), and the winner is the threshold with the lowest number, i.e. the one that produces the most homogeneous child nodes.
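The midpoint rule for candidate thresholds can be sketched like this. `candidate_thresholds` is a hypothetical helper name; duplicate feature values are removed first so that every threshold sits strictly between two distinct values:

```python
def candidate_thresholds(values):
    """Halfway points between consecutive distinct values of one feature."""
    distinct = sorted(set(values))  # sort and drop duplicate values
    return [(a + b) / 2 for a, b in zip(distinct, distinct[1:])]

print(candidate_thresholds([0, 3, 5, 7, 9]))  # [1.5, 4.0, 6.0, 8.0]
```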

Step 3 — Split the data

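Split the original data at the threshold identified in Step 2, then calculate the Gini impurity of each resulting child node. The quality of the split as a whole is the weighted average of these child impurities, where each child is weighted by its share of the observations.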
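The sketch below assumes a single numeric feature and hypothetical toy data. It splits at a threshold (values less than or equal to it go left), then computes each child’s Gini impurity and their weighted average. The helper names `split_at` and `weighted_gini` are illustrative:

```python
from collections import Counter

def gini_impurity(labels):
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())

def split_at(xs, ys, threshold):
    """Send observations with x <= threshold left and the rest right."""
    left = [y for x, y in zip(xs, ys) if x <= threshold]
    right = [y for x, y in zip(xs, ys) if x > threshold]
    return left, right

def weighted_gini(left, right):
    """Child impurities weighted by each child's share of the observations."""
    n = len(left) + len(right)
    return len(left) / n * gini_impurity(left) + len(right) / n * gini_impurity(right)

# Hypothetical toy data: one feature and a yes/no label
xs = [0, 3, 5, 7, 9]
ys = ["no", "no", "yes", "yes", "yes"]

left, right = split_at(xs, ys, 4.0)
print(left, right)                                 # ['no', 'no'] ['yes', 'yes', 'yes']
print(gini_impurity(left), gini_impurity(right))   # 0.0 0.0
print(weighted_gini(left, right))                  # 0.0
```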

Step 4 — Iterate through the data and calculate Gini for each split, and the corresponding children nodes

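For every candidate threshold, this means splitting the data, calculating the Gini impurity of both child nodes, and comparing their weighted average with the other candidates. Thanks to computers, we no longer have to do this manually. So if you’re running code for this, take a quick moment to appreciate your CPU, and apologize if you ever complained that it’s slow.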
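Putting Steps 1–3 into a loop gives a search over every candidate threshold of one feature. This is a minimal sketch on hypothetical toy data (`best_threshold` is an illustrative name); it returns the winning threshold with its weighted Gini impurity, or `None` as the threshold if no split beats the parent node:

```python
from collections import Counter

def gini_impurity(labels):
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())

def best_threshold(xs, ys):
    """Try every midpoint threshold and keep the lowest weighted child Gini."""
    best_t, best_score = None, gini_impurity(ys)  # baseline: the parent's impurity
    distinct = sorted(set(xs))
    for a, b in zip(distinct, distinct[1:]):
        t = (a + b) / 2
        left = [y for x, y in zip(xs, ys) if x <= t]
        right = [y for x, y in zip(xs, ys) if x > t]
        score = (len(left) * gini_impurity(left)
                 + len(right) * gini_impurity(right)) / len(ys)
        if score < best_score:
            best_t, best_score = t, score
    return best_t, best_score

xs = [0, 3, 5, 7, 9]
ys = ["no", "no", "yes", "yes", "yes"]
print(best_threshold(xs, ys))  # (4.0, 0.0)
```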

Step 5 — Build a decision tree based on the lowest Gini values

Similar to the simple animal example above, the idea is to have each split of the tree be as simple as a yes/no question, where the answer is either YES! or NO! Steps 1–4 are the long and arduous process of doing that mathematically. Once that’s done, a decision tree can be built from the splits with the lowest Gini impurity, i.e. the splits that best separate the classes in the dataset, by applying the same procedure again to each child node.
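Here is one way the recursion could look, sketched under simplifying assumptions: a single numeric feature, a hypothetical nested-dict representation for decision nodes, and plain class labels for leaves. It applies the two stopping rules described in the Process section, a maximum depth and no further drop in Gini impurity:

```python
from collections import Counter

def gini_impurity(labels):
    """1 minus the sum of squared class proportions (0.0 for an empty node)."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())

def best_threshold(xs, ys):
    """Midpoint threshold with the lowest weighted child Gini, or None if
    no threshold lowers the impurity of the parent node."""
    best_t, best_score = None, gini_impurity(ys)
    distinct = sorted(set(xs))
    for a, b in zip(distinct, distinct[1:]):
        t = (a + b) / 2
        left = [y for x, y in zip(xs, ys) if x <= t]
        right = [y for x, y in zip(xs, ys) if x > t]
        score = (len(left) * gini_impurity(left)
                 + len(right) * gini_impurity(right)) / len(ys)
        if score < best_score:
            best_t, best_score = t, score
    return best_t

def build_tree(xs, ys, depth=0, max_depth=3):
    """Split recursively; stop at max_depth or when no split lowers Gini."""
    t = best_threshold(xs, ys) if depth < max_depth else None
    if t is None:
        # Leaf: predict the majority class of the observations that reach it
        return Counter(ys).most_common(1)[0][0]
    left = [(x, y) for x, y in zip(xs, ys) if x <= t]
    right = [(x, y) for x, y in zip(xs, ys) if x > t]
    return {
        "threshold": t,
        "left": build_tree([x for x, _ in left], [y for _, y in left], depth + 1, max_depth),
        "right": build_tree([x for x, _ in right], [y for _, y in right], depth + 1, max_depth),
    }

# Hypothetical toy data: one numeric feature and a yes/no label
xs = [1, 2, 3, 4, 5, 6]
ys = ["no", "yes", "yes", "no", "no", "no"]
print(build_tree(xs, ys))
# {'threshold': 3.5, 'left': {'threshold': 1.5, 'left': 'no', 'right': 'yes'}, 'right': 'no'}
```

Lowering `max_depth` to 1 in this sketch stops after the first split and turns the left child into a majority-vote leaf, which is the depth-based stopping rule at work.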

Step 6 — Repeat Steps 1–5 for each feature

All this work was for just one feature. At every node, the process needs to be repeated for each feature in the dataset. If a split on a new feature has a lower Gini impurity than the best split found so far, that feature (with its threshold) should be the one used to split the data. Remember, the goal is to separate the data into the purest yes/no groups possible.
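Extending the search to every feature might look like this. It is a sketch that assumes each observation is a tuple of numeric features. The toy rows (age, pages viewed) and their labels are made up for illustration, and `best_split` is a hypothetical name:

```python
from collections import Counter

def gini_impurity(labels):
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())

def best_split(rows, labels):
    """Search every feature and every midpoint threshold. Returns
    (feature_index, threshold, weighted_gini); the first two are None
    if no split lowers the parent node's impurity."""
    n = len(labels)
    best = (None, None, gini_impurity(labels))
    for f in range(len(rows[0])):
        values = sorted({row[f] for row in rows})
        for a, b in zip(values, values[1:]):
            t = (a + b) / 2
            left = [y for row, y in zip(rows, labels) if row[f] <= t]
            right = [y for row, y in zip(rows, labels) if row[f] > t]
            score = (len(left) * gini_impurity(left)
                     + len(right) * gini_impurity(right)) / n
            if score < best[2]:  # keep the feature/threshold with the lowest Gini
                best = (f, t, score)
    return best

# Hypothetical customers: (age, pages_viewed) -> made a purchase?
rows = [(25, 2), (32, 8), (47, 3), (51, 9), (38, 1), (29, 7)]
labels = ["no", "yes", "no", "yes", "no", "yes"]
print(best_split(rows, labels))  # (1, 5.0, 0.0)
```

Here no age threshold separates the classes cleanly, so the pages-viewed feature (index 1) wins with a perfect split.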

Step 7 — Make a prediction

Once we’re sure the dataset is segregated as much as possible, by mathematically deriving the lowest possible Gini impurity values based on all the features we’re considering, it’s time to put the tree to the test. A new observation starts at the root node and, at each decision node, follows the branch whose condition it satisfies (for example, a feature value less than or equal to the threshold goes one way, otherwise the other) until it reaches a leaf. The prediction is the majority class of the training observations in that leaf node. In other words, we follow the tree like a flowchart, or a decision map.
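Prediction then reduces to a loop from the root to a leaf. The tree below is hand-built and hypothetical. It is stored as nested dicts with `feature`, `threshold`, `left`, and `right` keys, and its leaves are class labels:

```python
def predict(tree, row):
    """Walk from the root to a leaf: go left when the row's feature value is
    <= the node's threshold, otherwise right. Leaves are class labels."""
    node = tree
    while isinstance(node, dict):
        branch = "left" if row[node["feature"]] <= node["threshold"] else "right"
        node = node[branch]
    return node

# Hypothetical hand-built tree over (age, pages_viewed) rows
tree = {
    "feature": 1, "threshold": 5.0,       # pages_viewed <= 5.0?
    "left": "no",
    "right": {
        "feature": 0, "threshold": 60.0,  # age <= 60.0?
        "left": "yes",
        "right": "no",
    },
}

print(predict(tree, (30, 2)))  # no
print(predict(tree, (30, 8)))  # yes
print(predict(tree, (65, 8)))  # no
```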

Bringing it to life

I’ve created my own Decision Tree Classifier class using Python which you can find here. The best way to learn and understand something is to break it down and rebuild it yourself, so give it a try!

Better yet, work with it on your own dataset and see if you can make some useful predictions using scikit-learn’s DecisionTreeClassifier.
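If you want to try this with scikit-learn, a minimal example on the built-in iris dataset might look like the following. `criterion="gini"` is the library’s default and is spelled out only for clarity. The depth cap of 3 and the 70/30 train/test split are arbitrary choices for the example:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.3, random_state=0
)

# Gini-based tree, capped at depth 3 (an arbitrary choice for this example)
clf = DecisionTreeClassifier(criterion="gini", max_depth=3, random_state=0)
clf.fit(X_train, y_train)

# Print the learned splits as a text flowchart
print(export_text(clf, feature_names=list(iris.feature_names)))

predictions = clf.predict(X_test)
accuracy = clf.score(X_test, y_test)  # mean accuracy on the held-out data
print(f"Test accuracy: {accuracy:.3f}")
```

`export_text` prints the same kind of threshold-based flowchart described above, which makes it easy to compare the library’s splits with ones you compute by hand.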
