Introduction to Decision Trees¶
Decision Trees are a form of supervised learning used for both classification and regression tasks. They work by learning decision rules inferred from the features of the data. The model resembles a tree structure, where each internal node represents a "test" or "decision" on a feature, each branch corresponds to the outcome of the test, and each leaf node represents a class label (in classification) or a numerical value (in regression). What makes decision trees especially appealing is their interpretability. Unlike black-box models such as neural networks, decision trees provide a transparent view of how decisions are made based on feature splits. This makes them popular in domains where understanding the model’s reasoning is crucial, such as healthcare, finance, and policy-making.
Decision trees are widely used in the fields of operations research, decision analysis, and machine learning to model decisions and their possible outcomes. They serve as a visual and analytical tool that helps in evaluating different strategies or courses of action, especially in situations involving uncertainty and multiple alternatives. By providing a structured way to break down a decision-making problem, decision trees enable individuals and organizations to identify the most promising strategy that leads to a desired goal.
Historical Background and Development¶
The concept of decision-making trees has its roots in statistical analysis and pattern recognition from the mid-20th century. One of the earliest formal methods for constructing decision trees was the AID (Automatic Interaction Detection) algorithm introduced in 1963. This method laid the groundwork by recursively partitioning the data to find patterns.
The real breakthrough came with the development of the CART (Classification and Regression Trees) algorithm by Breiman et al. in 1984. CART provided a robust, theoretically sound framework that supported both classification and regression tasks. It introduced the binary splitting mechanism and formalized the use of impurity metrics like Gini impurity and variance reduction. Around the same time, algorithms like ID3 and later C4.5 by Quinlan popularized entropy-based splitting, especially in classification tasks.
Understanding the Concept of a Tree¶
Before diving into the concept of decision trees, it is important to understand what a tree means in the context of programming and computer science. A tree is a hierarchical data structure used to represent relationships between different pieces of data. This structure is made up of elements called nodes. Each node contains some form of data - this could be a number, a character, a string, or any other meaningful value.
Connecting these nodes are lines known as edges or branches. These edges define the relationship between nodes and serve as links from one node to another. In most tree structures, these connections only occur between nodes at adjacent levels - that is, from a node to its direct descendants, never skipping levels or looping back upward. As this pattern of nodes and edges repeats, it forms the entire tree structure.
One of the defining features of a tree is its clear hierarchy, which is always viewed from top to bottom. The topmost node in a tree is known as the root node. It is the origin point from which all other nodes descend. As we move down the structure, we come across internal nodes that play the role of parents - they connect to one or more child nodes below them. A node that is directly connected beneath another is called a child node, and the node from which it originates is its parent. This parent-child relationship helps maintain the structure and clarity of the tree’s layout.
Eventually, we reach the end points of the tree - nodes from which no further branches extend. These are called leaf nodes. Leaf nodes do not have any children and often represent the final outcome or classification in many tree-based algorithms. Depending on the application, trees can be broad and complex, with multiple children stemming from each node, or they can be more restrictive and structured.
A particularly well-known variant of the tree is the binary tree. In a binary tree, each node is allowed to have at most two children. This creates a left and right child for every parent node, providing a simple and efficient framework for many algorithmic tasks. Binary trees are a foundational concept in computer science, especially because of their use in searching and sorting. A notable example is the binary search tree (BST), which organizes data in such a way that searching can be performed in logarithmic time. In a BST, for every node, all the elements in the left subtree are less than the node, and all elements in the right subtree are greater. This structure allows for efficient insertion, deletion, and lookup operations.
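To make these ideas concrete, here is a minimal Python sketch of a binary search tree. The Node class and the bst_insert/bst_contains helpers are illustrative names chosen for this example, not part of any library.
class Node:
    """A single node of a binary search tree."""
    def __init__(self, value):
        self.value = value
        self.left = None   # left subtree: values smaller than self.value
        self.right = None  # right subtree: values larger than self.value

def bst_insert(root, value):
    """Insert a value and return the (possibly new) root."""
    if root is None:
        return Node(value)
    if value < root.value:
        root.left = bst_insert(root.left, value)
    elif value > root.value:
        root.right = bst_insert(root.right, value)
    return root

def bst_contains(root, value):
    """Search by repeatedly descending to the left or right child."""
    while root is not None:
        if value == root.value:
            return True
        root = root.left if value < root.value else root.right
    return False

root = None
for v in [8, 3, 10, 1, 6]:
    root = bst_insert(root, v)
print(bst_contains(root, 6), bst_contains(root, 7))  # True False
Because every comparison discards one half of the remaining subtree, a balanced BST answers such lookups in logarithmic time, which is exactly the efficiency property described above.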
Another important concept associated with trees is the height of the tree. The height is determined by the number of levels it contains, starting from the root at level one and counting downward to the deepest leaf node. The height has a significant impact on the performance of tree-based operations. A balanced tree - one where the height is kept to a minimum - allows for faster computation, while an unbalanced tree may result in inefficiencies similar to a linear list. There are many other types of trees in computer science beyond binary and binary search trees. These include AVL trees, red-black trees, heaps, B-trees, and others - each designed to address specific computational problems or optimize particular operations. Some trees are used in databases, others in compiler design, and many in artificial intelligence.
In this context, we will now explore a specific type of tree used in the field of machine learning - the decision tree. This structure builds upon the fundamentals of tree data structures but applies them to the task of making predictions or decisions based on input data. Unlike general-purpose binary search trees, decision trees are typically built by analyzing a dataset and finding the best way to split it in order to classify or estimate outcomes. Understanding the foundational structure of trees - with their nodes, edges, levels, and hierarchy - will make it easier to grasp how decision trees function and why they are such a powerful tool in data science and artificial intelligence.
Structure of a Decision Tree¶
The structure of a decision tree is hierarchical. It starts at the root node, which represents the entire dataset. At each level, the dataset is split into subsets based on a feature value that best separates the data with respect to the target.
- Internal nodes: These represent decisions based on features.
- Branches: Outcomes of a decision or test.
- Leaf nodes: Final output predictions.
The tree grows by choosing the best possible feature and threshold to split the data at every node, recursively.
The construction of a decision tree follows a greedy approach. At each step, the algorithm selects the feature and threshold that result in the best split. For classification tasks, this typically means maximizing the information gain or minimizing impurity, while for regression, it involves minimizing the mean squared error (MSE) or variance (we will go into detail about these shortly). A minimal code sketch of this procedure follows the list of steps below.
The general steps are:
- Start with the entire dataset at the root.
- For each feature, evaluate all possible splits.
- Choose the split that provides the best improvement according to a criterion.
- Split the data into two parts and repeat the process for each part.
- Continue until a stopping condition is met (e.g., maximum depth, minimum samples per leaf, or no improvement).
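The following is a minimal sketch of that greedy, recursive procedure for a binary classification tree using Gini impurity. The helper names (gini, best_split, build_tree) and the dictionary representation of the tree are illustrative choices made for exposition, not a production implementation.
import numpy as np

def gini(y):
    """Gini impurity of a label array."""
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def best_split(X, y):
    """Greedily pick the (feature, threshold) with the lowest weighted child impurity."""
    best = None
    for j in range(X.shape[1]):
        for t in np.unique(X[:, j]):
            left, right = y[X[:, j] <= t], y[X[:, j] > t]
            if len(left) == 0 or len(right) == 0:
                continue
            score = (len(left) * gini(left) + len(right) * gini(right)) / len(y)
            if best is None or score < best[0]:
                best = (score, j, t)
    return best

def build_tree(X, y, depth=0, max_depth=3):
    """Recursively split until the node is pure, max_depth is reached, or no split exists."""
    split = best_split(X, y)
    if gini(y) == 0 or depth == max_depth or split is None:
        values, counts = np.unique(y, return_counts=True)
        return {"leaf": values[np.argmax(counts)]}  # predict the majority class
    _, j, t = split
    mask = X[:, j] <= t
    return {"feature": j, "threshold": t,
            "left": build_tree(X[mask], y[mask], depth + 1, max_depth),
            "right": build_tree(X[~mask], y[~mask], depth + 1, max_depth)}

X = np.array([[2.0], [3.0], [10.0], [11.0]])
y = np.array([0, 0, 1, 1])
print(build_tree(X, y))  # splits on feature 0 at threshold 3.0, producing two pure leaves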
Let's take an example to understand how Decision Trees work¶
Let's say the tree begins with a single question at the top: "Weather?" This is the root node of the tree and represents the first and most important decision. There are three possible weather conditions shown:
- Sunny
- Cloudy
- Rainy (represented by a dark cloud)
Each of these conditions leads to a different path.
Path 1: Sunny Weather¶
If the weather is sunny, the next question becomes: "Time?" This node checks whether the available time is more than 30 minutes or less than 30 minutes.
- If the available time is more than 30 minutes, the outcome is "Walk".
- If the available time is less than 30 minutes, the outcome is "Bus".
This path represents a logical prioritization: if the weather is nice and you have ample time, walking is a healthy and cost-effective option. If time is limited, a bus is more efficient.
Path 2: Cloudy Weather¶
If the weather is cloudy, the decision shifts from time to physical condition, asking: "Hungry?"
- If the person is hungry, the outcome is "Walk".
- If the person is not hungry, the outcome is "Bus".
This decision might seem counterintuitive at first, but it could represent a scenario where walking to grab food is preferred if one is hungry. If not hungry, perhaps the person prefers saving energy or reaching the destination quickly by bus.
Path 3: Rainy Weather¶
If the weather is rainy (dark cloud), the decision is made directly without further questions. The outcome is "Bus". This is a straightforward case assuming no one wants to walk in the rain.
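Written as code, the whole example collapses into a few nested conditionals. The function below is a hand-written illustration of the decision logic described above (the name choose_transport and its arguments are made up for this sketch); it is not a learned model.
def choose_transport(weather: str, minutes_available: int, hungry: bool) -> str:
    """Follow the decision tree from the root (weather) down to a leaf."""
    if weather == "sunny":
        # Path 1: decide based on available time
        return "Walk" if minutes_available > 30 else "Bus"
    if weather == "cloudy":
        # Path 2: decide based on hunger
        return "Walk" if hungry else "Bus"
    # Path 3: rainy weather goes straight to a leaf
    return "Bus"

print(choose_transport("sunny", 45, hungry=False))   # Walk
print(choose_transport("cloudy", 45, hungry=False))  # Bus
print(choose_transport("rainy", 10, hungry=True))    # Bus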
Key Characteristics Demonstrated by the Diagram¶
This tree visually exhibits the following key aspects of decision trees:
- Root Node: The first decision point (weather) that guides all other decisions.
- Internal Nodes: Represent conditional questions like time and hunger.
- Branches: Show the possible outcomes of these decisions.
- Leaf Nodes: Final decisions such as "Walk" or "Bus".
Each path from the root to a leaf is a unique decision rule or logic sequence. By following these paths, one can make an informed decision based on the current situation.
Using Decision Trees to Solve Real-World Problems¶
Once we understand the structure and components of a decision tree, the next step is to see how it can be applied to practical problem-solving. Decision trees act like a series of questions leading to an answer. If you provide answers to all the questions asked by a decision tree - in other words, if you provide all the feature values (inputs) - and follow the correct path from the top to a leaf node, you will reach a specific result or decision (output).
When a decision tree is used as a machine learning model, the idea is to allow the algorithm to build the tree on its own using historical data. The tree is no longer manually drawn by a person; instead, it is automatically created by learning from the dataset. This means the machine examines patterns, correlations, and splits in the data to decide which features to use at each step.
Imagine we are dealing with just two features - for instance, "age" and "income" - and we want to predict whether someone will purchase a product. Although there are only two features, the way the tree splits based on values and the order of the splits can result in a large number of possible trees. One tree might first ask about income and then about age, while another might reverse the order. Depending on how the splits are made and which thresholds are chosen, the resulting trees - and their predictions - can be very different. Moreover, one answer in a tree might lead to another question, and that to another, forming a deep structure of nested decisions. This nesting continues until a stopping condition is reached, such as when further splits no longer improve the model significantly, or a maximum depth is reached.
A key concept in decision trees is that of hierarchy. Questions at the top of the tree (closer to the root node) carry more weight in determining the final prediction. These are the first splits and therefore affect a larger portion of the data. As a result, it becomes important for the algorithm to choose wisely when deciding what question or feature to test at the top. The algorithm evaluates all available features and chooses the one that provides the best "information gain" or reduction in impurity. For classification tasks, this could mean reducing Gini impurity or entropy; for regression, it could mean minimizing variance or mean squared error. Choosing the wrong feature at the top could lead to a suboptimal tree with poor predictive power. As the number of features increases, the number of potential trees increases exponentially. This makes the task of finding the best tree quite complex. The goal of the algorithm is to search through these possibilities and choose the tree that best fits the data, without overfitting.
One of the strengths of decision trees is their ability to handle both numerical and categorical data. For numerical features, the tree can choose a threshold value - such as "is age > 40?" - to split the data. For categorical features, it can create branches for each category - for instance, "if occupation is teacher, engineer, or doctor, go this way." This flexibility allows decision trees to be applied in a wide range of problems, from predicting loan defaults (using numeric data like credit score and income) to classifying customer types (using categorical variables like region or product preferences).
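One practical caveat: the scikit-learn implementation used later in this notebook expects purely numeric input, so categorical features are usually one-hot encoded before fitting. The sketch below uses pandas get_dummies on a small made-up dataset to illustrate this preprocessing step; the column names and values are hypothetical.
import pandas as pd
from sklearn.tree import DecisionTreeClassifier

# Hypothetical mixed-type data: one numeric and one categorical feature.
df = pd.DataFrame({
    "age": [25, 47, 38, 52],
    "occupation": ["teacher", "engineer", "doctor", "teacher"],
    "purchased": [0, 1, 1, 0],
})

# One-hot encode the categorical column so the tree sees only numbers.
X = pd.get_dummies(df[["age", "occupation"]], columns=["occupation"])
y = df["purchased"]

clf = DecisionTreeClassifier(max_depth=2).fit(X, y)
print(list(X.columns))  # e.g. ['age', 'occupation_doctor', 'occupation_engineer', 'occupation_teacher']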
Moreover, decision trees can be used for both classification and regression. In classification, the goal is to assign a class label, such as "Yes" or "No." In regression, the goal is to predict a continuous value, such as the price of a house or the temperature on a given day.
How the Algorithm Chooses the Tree¶
The algorithms used to create decision trees are designed to search for the best way to split the data at each node. They do this using statistical criteria that quantify the effectiveness of each possible split. For example, the CART algorithm evaluates every feature and every possible threshold to find the one that minimizes the total impurity of the child nodes after a split. This process is recursive - once the first split is made, the algorithm repeats the same analysis on each resulting subset of the data. It continues to split and split again, building the tree from the top down, until some stopping rule is met. This could be:
- A minimum number of samples in a node
- A maximum depth
- No further improvement in split quality
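In scikit-learn, these stopping rules map directly onto constructor parameters of DecisionTreeClassifier (and DecisionTreeRegressor); the specific values below are arbitrary examples chosen for illustration, not recommendations.
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(
    max_depth=4,                 # stop once the tree reaches a maximum depth
    min_samples_split=10,        # a node needs at least 10 samples to be split
    min_samples_leaf=5,          # every leaf must keep at least 5 samples
    min_impurity_decrease=0.01,  # ignore splits that barely improve purity
).fit(X, y)
print(clf.get_depth(), clf.get_n_leaves())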
Despite the apparent simplicity of the tree’s final form, the underlying algorithm that creates it is quite complex and computationally demanding, especially for large datasets.
When Decision Tree Regression May Not Be Suitable¶
While decision trees can be adapted for regression problems (using DecisionTreeRegressor), there are scenarios where they may not perform well. One key weakness is that regression trees tend to produce piecewise constant predictions. This means that within each region of the input space, the predicted value is simply the average of the training samples in that region, without any smooth transitions between intervals.
Let us consider an example:
Suppose you are trying to model the relationship between years of experience and salary in a company. The actual relationship is smooth and linear - salary tends to increase steadily with experience. However, if you use a decision tree regressor, the model will break the data into intervals like "experience < 3 years," "3 ≤ experience < 7 years," and so on, and assign an average salary for each bucket. The resulting prediction curve will look like a staircase rather than a smooth upward trend. This step-wise prediction is unrealistic for many real-world regression problems, especially when the underlying function is continuous. This makes decision trees less suitable for regression tasks where continuity and smooth trends are important. In such cases, more sophisticated models like linear regression, support vector regression, or ensemble methods such as Random Forest Regressor or Gradient Boosting Regressor tend to perform better.
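The staircase effect is easy to reproduce. The sketch below fits a DecisionTreeRegressor and a LinearRegression model to synthetic experience/salary data (the numbers are made up for illustration) so the piecewise-constant tree predictions can be contrasted with the smooth linear fit.
import numpy as np
from sklearn.tree import DecisionTreeRegressor
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(0)
experience = np.sort(rng.uniform(0, 20, 80)).reshape(-1, 1)                # years of experience
salary = 30_000 + 2_500 * experience.ravel() + rng.normal(0, 2_000, 80)    # roughly linear relationship

tree = DecisionTreeRegressor(max_depth=3).fit(experience, salary)
line = LinearRegression().fit(experience, salary)

grid = np.linspace(0, 20, 200).reshape(-1, 1)
tree_pred = tree.predict(grid)   # piecewise constant: a "staircase" with one level per leaf
line_pred = line.predict(grid)   # smooth, steadily increasing

# The tree prediction takes only a handful of distinct values (one per leaf),
# while the linear fit changes continuously across the grid.
print(len(np.unique(np.round(tree_pred, 2))), "distinct tree predictions")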
Understanding Splitting Criteria¶
For Classification:¶
Gini Impurity: Measures how often a randomly chosen element would be incorrectly labeled.
$$ Gini = 1 - \sum_{i=1}^{k} p_i^2 $$
Entropy and Information Gain: Based on Shannon entropy.
$$ Entropy = - \sum_{i=1}^{k} p_i \log_2(p_i) $$
Information Gain is the decrease in entropy after a dataset is split on an attribute.
For Regression:¶
Mean Squared Error (MSE):
$$ MSE = \frac{1}{n} \sum_{i=1}^{n}(y_i - \bar{y})^2 $$
Mean Absolute Error (MAE) and Reduction in Variance can also be used.
These criteria guide the algorithm in finding splits that create the most homogeneous child nodes.
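For reference, these criteria can also be written as small NumPy functions. The implementations below are plain illustrations of the formulas above, not the routines scikit-learn uses internally.
import numpy as np

def gini(labels):
    """Gini impurity: 1 - sum of squared class proportions."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def entropy(labels):
    """Shannon entropy in bits: -sum(p_i * log2(p_i)) over the classes present."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return float(-np.sum(p * np.log2(p)))

def mse(values):
    """Mean squared error around the node mean (equal to the node variance)."""
    values = np.asarray(values, dtype=float)
    return float(np.mean((values - values.mean()) ** 2))

print(round(gini(["A"] * 4 + ["B"] * 6), 2))   # 0.48
print(round(entropy([1] * 9 + [0] * 5), 3))    # 0.94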
Pruning: Preventing Overfitting¶
Decision trees can easily overfit, especially when allowed to grow deep. Overfitting occurs when the tree captures noise in the training data rather than general patterns. To prevent this, pruning techniques are employed.
Pre-Pruning (Early Stopping): Stops the tree from growing beyond a certain point based on user-defined limits like maximum depth or minimum samples per split.
Post-Pruning: First allows the tree to grow fully, and then removes branches that have little importance using cross-validation to determine their utility.
Pruning helps in simplifying the tree and improving its generalization to unseen data.
Understanding the Metrics Behind Decision Tree Splitting: A Detailed Look at Gini Impurity¶
When you train a decision tree model and visualize it, each node in the tree typically displays four important pieces of information (you will see this in the code section at the end). These values help in interpreting the decisions made at each stage of the model. The first value is the condition — a threshold involving a specific feature, such as "Age < 30", which tells us how the data is being split at that node.
Next, you will notice a metric named Gini, which represents the Gini impurity score for that node. It quantifies how "pure" or "impure" the node is in terms of class distribution. The samples count indicates how many data points are being considered at that particular node. Finally, the value shows the distribution of samples from different classes, for example: [42, 58]
might mean 42 belong to class 0 and 58 to class 1. In classification problems, we can also see the predicted class, which is the class with the majority count among the samples.
Let us now dig deeper into the metric that drives these splits: Gini impurity. When building a decision tree, the algorithm does not construct the entire tree all at once. Instead, it works locally, node by node, deciding the best way to split the data at each point. For each node (or split), the algorithm must evaluate how good that split is - how well it separates the classes and reduces uncertainty.
To do this, it uses metrics that quantify the impurity (or disorder) at each node. The goal is to make each child node as pure as possible - ideally, each node should contain data points from only one class. There are two widely used impurity metrics:
- Gini impurity
- Entropy (used in Information Gain)
Gini Impurity: The Concept¶
Gini impurity tells us how mixed the classes are in a node of a decision tree. It measures how often we would misclassify a data point if we randomly labeled it based on the class distribution in that node.
- If all data points in a node belong to the same class, the node is pure, and Gini impurity is 0.
- If the classes are evenly split (like 50% class A and 50% class B in binary classification), Gini impurity is at its maximum of 0.5.
The more mixed the classes, the higher the impurity. The goal of decision trees is to split the data in a way that reduces this impurity as much as possible. Gini impurity helps the tree decide which feature and split to choose to make the nodes purer.
Let us say we are working with a small dataset that consists of two classes:
- Red (Class A)
- Blue (Class B)
Suppose we have a node containing 10 data points: 4 Red and 6 Blue. We want to compute the Gini impurity for this node.
Step 1: Define the Class Probabilities¶
We first calculate the probability of selecting each class randomly:
Probability of Red (Class A):
$$ P_A = \frac{4}{10} = 0.4 $$
Probability of Blue (Class B):
$$ P_B = \frac{6}{10} = 0.6 $$
Step 2: Apply the Gini Formula¶
The formula for Gini impurity for a node with $k$ classes is:
$$ Gini = 1 - \sum_{i=1}^{k} p_i^2 $$
Where $p_i$ is the probability of class $i$ in the node.
In our case with two classes:
$$ Gini = 1 - (P_A^2 + P_B^2) $$
$$ = 1 - (0.4^2 + 0.6^2) $$
$$ = 1 - (0.16 + 0.36) $$
$$ = 1 - 0.52 = 0.48 $$
So the Gini impurity of this node is 0.48, meaning the node is somewhat impure because it contains a mix of both classes. The Gini score tells us how likely we are to misclassify a randomly chosen point from this node if we assign labels randomly according to the node’s class distribution. In this case, the chance of misclassification is 48 percent.
This value is not terrible, but it is far from pure. The algorithm will now try to split this node in such a way that the Gini impurity of the resulting child nodes is lower than the current impurity. It continues to do this at each stage, trying to produce the cleanest splits possible.
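A few lines of arithmetic make this concrete. The candidate split below (4 Red and 1 Blue going left, 5 Blue going right) is hypothetical, chosen only to show how a split can lower the weighted impurity relative to the parent's 0.48.
# Parent node: 4 Red, 6 Blue  ->  Gini = 0.48
p_red, p_blue = 4 / 10, 6 / 10
gini_parent = 1 - (p_red ** 2 + p_blue ** 2)

# Hypothetical split: left child gets 4 Red + 1 Blue, right child gets 5 Blue
gini_left = 1 - ((4 / 5) ** 2 + (1 / 5) ** 2)     # 0.32
gini_right = 1 - ((0 / 5) ** 2 + (5 / 5) ** 2)    # 0.0 (pure)
weighted = (5 / 10) * gini_left + (5 / 10) * gini_right   # 0.16 < 0.48, so the split helps

print(round(gini_parent, 2), round(gini_left, 2), round(gini_right, 2), round(weighted, 2))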
Gini Impurity with Perfect Purity Example¶
Suppose another node contains 10 data points, and all of them are Blue.
Probability of Red:
$$ P_A = 0 $$
Probability of Blue:
$$ P_B = 1 $$
Gini Impurity:
$$ Gini = 1 - (0^2 + 1^2) = 1 - 1 = 0 $$
This means the node is completely pure, and there is no risk of misclassification. This is the ideal situation that the decision tree algorithm tries to achieve at every node.
Why Gini Prefers Balanced and Informative Splits¶
When the data is balanced - meaning all classes are represented more equally - the Gini impurity is higher. But if a split separates the classes more cleanly, the Gini impurity for the child nodes becomes lower. The decision tree algorithm evaluates all possible splits and computes the weighted average Gini of the resulting child nodes. It then chooses the split that minimizes this combined impurity.
This is why having a balanced dataset is important - the algorithm relies heavily on the distribution of classes to determine useful splits. If the data is skewed toward one class, the tree may not split effectively and could become biased.
Understanding the Metrics Behind Decision Tree Splitting: A Detailed Look at Information Gain¶
In addition to Gini impurity, another widely used metric for evaluating the quality of a split in decision trees is Information Gain. Information Gain is rooted in information theory and relies on the concept of entropy, which measures the amount of disorder or impurity in a set of data. While Gini impurity quantifies how often a randomly chosen element would be misclassified, Information Gain measures how much uncertainty is reduced after a split.
Let us explore this concept deeply with formulas, examples, and interpretation.
What is Entropy?¶
Entropy is a measure of randomness or unpredictability in the data. If a node contains elements from only one class, it is completely pure, and its entropy is zero. If the classes are evenly mixed, the entropy is maximum, indicating high disorder.
The formula for entropy $H(S)$ of a dataset $S$ with $k$ classes is:
$$ H(S) = - \sum_{i=1}^{k} p_i \log_2(p_i) $$
Where:
- $p_i$ is the proportion of elements in class $i$ in the dataset $S$.
- $\log_2(p_i)$ represents the information content (in bits) of choosing class $i$.
What is Information Gain?¶
Information Gain (IG) measures the reduction in entropy achieved by partitioning the dataset based on a feature. It tells us how much “information” we gained by knowing the outcome of a specific split.
The formula for Information Gain is:
$$ IG(S, A) = H(S) - \sum_{v \in \text{Values}(A)} \frac{|S_v|}{|S|} H(S_v) $$
Where:
- $S$ is the parent dataset before splitting.
- $A$ is the attribute (feature) used for splitting.
- $\text{Values}(A)$ are the possible values of attribute $A$.
- $S_v$ is the subset of data where attribute $A$ takes value $v$.
- $\frac{|S_v|}{|S|}$ is the proportion of the dataset that falls into subset $S_v$.
So, Information Gain = Entropy before split - Weighted average entropy after split. The higher the information gain, the better the feature is at classifying the data.
Let us go through an example to illustrate the concept of entropy and information gain. You have a dataset of 14 samples about whether people Play Tennis or not. One of the features is Weather, with values: Sunny, Overcast, and Rain.
Here is the class distribution:
- 9 samples: Play = Yes
- 5 samples: Play = No
Let’s first calculate the entropy of the parent node (before any split):
Step 1: Entropy of the Parent Dataset¶
$$ p_{yes} = \frac{9}{14}, \quad p_{no} = \frac{5}{14} $$
$$ H(S) = -\left( \frac{9}{14} \log_2 \frac{9}{14} + \frac{5}{14} \log_2 \frac{5}{14} \right) $$
$$ = -\left( 0.643 \log_2(0.643) + 0.357 \log_2(0.357) \right) $$
$$ = - (0.643 \times -0.637 + 0.357 \times -1.485) $$
$$ = 0.940 \text{ bits} $$
So the entropy of the whole dataset is 0.940.
Step 2: Splitting on the Feature "Weather"¶
Let’s say the data distribution after splitting on Weather is:
- Sunny: 5 samples → 2 Yes, 3 No
- Overcast: 4 samples → 4 Yes, 0 No
- Rain: 5 samples → 3 Yes, 2 No
Now we calculate entropy for each subset.
Sunny Subset:¶
$$ H(Sunny) = -\left( \frac{2}{5} \log_2 \frac{2}{5} + \frac{3}{5} \log_2 \frac{3}{5} \right) $$
$$ = - (0.4 \log_2(0.4) + 0.6 \log_2(0.6)) = - (0.4 \times -1.322 + 0.6 \times -0.737) \approx 0.971 $$
Overcast Subset:¶
All samples are "Yes" ⇒ pure node.
$$ H(Overcast) = 0 $$
Rain Subset:¶
$$ H(Rain) = -\left( \frac{3}{5} \log_2 \frac{3}{5} + \frac{2}{5} \log_2 \frac{2}{5} \right) \approx 0.971 $$
Step 3: Weighted Average Entropy After Split¶
Now compute the weighted entropy after the split:
$$ H_{after} = \frac{5}{14} \times 0.971 + \frac{4}{14} \times 0 + \frac{5}{14} \times 0.971 $$
$$ = \frac{10}{14} \times 0.971 = 0.694 $$
Step 4: Calculate Information Gain¶
$$ IG(S, \text{Weather}) = H(S) - H_{after} = 0.940 - 0.694 = 0.246 $$
So, by splitting on Weather, the entropy is reduced by 0.246 bits, which is the Information Gain. This tells us that Weather is somewhat informative but may not be the best split compared to other features.
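The same calculation can be verified in a few lines of Python. The helper H below is just a convenience for this example; the 0.247 it prints differs from the 0.246 above only because the hand calculation rounded its intermediate values.
from math import log2

def H(pos, neg):
    """Entropy (in bits) of a node with pos 'Yes' and neg 'No' samples."""
    total = pos + neg
    result = 0.0
    for count in (pos, neg):
        if count:
            p = count / total
            result -= p * log2(p)
    return result

H_parent = H(9, 5)                                                    # entropy before the split
H_after = (5 / 14) * H(2, 3) + (4 / 14) * H(4, 0) + (5 / 14) * H(3, 2)  # weighted entropy after
print(round(H_parent, 3), round(H_after, 3), round(H_parent - H_after, 3))  # 0.94 0.694 0.247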
Key Points to Remember¶
- Entropy measures uncertainty: more mixed classes = higher entropy.
- Information Gain measures reduction in entropy from a split.
- Higher Information Gain means a feature does a better job of separating the classes.
- The decision tree algorithm tries all possible splits and selects the one with the highest Information Gain (or lowest impurity).
Other Splitting Criteria¶
Besides Gini impurity and Information Gain (based on entropy), there are other splitting criteria that decision trees and similar algorithms can use to measure the quality of splits. Two such alternatives are:
1. Classification Error (Misclassification Rate)¶
Classification error, sometimes referred to as misclassification rate, is a simple impurity measure that looks at how often the most frequent class label in a node does not match the actual labels of the samples in that node. Unlike entropy or Gini, which give a more nuanced picture of class distribution, classification error only cares about the dominant class and how many samples do not belong to it. This measure is often used as a diagnostic tool or for pruning the decision tree, rather than for building it, because it is less sensitive to changes in the underlying distribution compared to Gini or entropy.
For a node $S$ with $k$ classes and class probabilities $p_1, p_2, ..., p_k$, the classification error is defined as:
$$ \text{Classification Error} = 1 - \max_{i} (p_i) $$
Where $\max(p_i)$ is the probability of the majority class in the node.
Suppose we have a node with 10 samples:
- 7 samples are Class A
- 3 samples are Class B
Then:
- $p_A = \frac{7}{10} = 0.7$
- $p_B = \frac{3}{10} = 0.3$
So:
$$ \text{Classification Error} = 1 - \max(0.7, 0.3) = 1 - 0.7 = 0.3 $$
This tells us that 30% of the samples would be misclassified if we used the majority class (Class A) as the predicted label for this node.
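The same computation as a short Python helper; the function name is illustrative, and the class counts are the ones from this example.
def classification_error(counts):
    """Misclassification rate of a node: 1 - proportion of the majority class."""
    total = sum(counts)
    return 1 - max(counts) / total

print(round(classification_error([7, 3]), 2))   # 0.3  (7 Class A, 3 Class B)
print(round(classification_error([10, 0]), 2))  # 0.0  (pure node)
print(round(classification_error([5, 5]), 2))   # 0.5  (perfectly balanced binary node)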
- Minimum value: 0 (when all data belongs to one class).
- Maximum value: 0.5 in the binary case (reached when the node is perfectly balanced), and $1 - \frac{1}{k}$ for $k$ equally represented classes.
- Lower values mean purer nodes, as fewer samples are misclassified.
Because classification error is not sensitive to class proportions beyond the majority, it is less informative during tree building. However, it’s often used:
- During pruning, where the goal is to simplify the tree without significantly sacrificing accuracy.
- In model evaluation, when a coarse measure of impurity is sufficient.
2. Variance Reduction (for Regression Trees)¶
While Gini, entropy, and classification error are used for classification tasks, regression trees use a different impurity metric since their targets are continuous values, not classes. In such cases, we use variance reduction to evaluate splits. The idea is to measure how much the spread (variance) of target values decreases after a split. A good split groups similar values together, resulting in low variance within each child node.
Let:
- $S$ be the parent dataset.
- $S_1$ and $S_2$ be the child subsets after the split.
The Variance of a node $S$ is:
$$ Var(S) = \frac{1}{|S|} \sum_{i=1}^{|S|} (y_i - \bar{y})^2 $$
Where $\bar{y}$ is the mean of the target variable in set $S$.
The Variance Reduction (VR) is:
$$ VR = Var(S) - \left( \frac{|S_1|}{|S|} Var(S_1) + \frac{|S_2|}{|S|} Var(S_2) \right) $$
The larger the variance reduction, the better the split.
Imagine you have a regression node with 6 values:
- $y = [10, 12, 14, 13, 11, 12]$
Step 1: Compute Parent Variance¶
Mean = $\bar{y} = \frac{10 + 12 + 14 + 13 + 11 + 12}{6} = 12$
$$ Var(S) = \frac{(10 - 12)^2 + (12 - 12)^2 + (14 - 12)^2 + (13 - 12)^2 + (11 - 12)^2 + (12 - 12)^2}{6} = \frac{4 + 0 + 4 + 1 + 1 + 0}{6} = \frac{10}{6} \approx 1.667 $$
Step 2: Split the Data¶
Split into:
- Left (S1): $y = [10, 11, 12]$ → Mean = 11
- Right (S2): $y = [12, 13, 14]$ → Mean = 13
Compute variances:
- $Var(S_1) = \frac{(10-11)^2 + (11-11)^2 + (12-11)^2}{3} = \frac{1 + 0 + 1}{3} = 0.667$
- $Var(S_2) = \frac{(12-13)^2 + (13-13)^2 + (14-13)^2}{3} = \frac{1 + 0 + 1}{3} = 0.667$
Step 3: Compute Variance Reduction¶
$$ VR = 1.667 - \left( \frac{3}{6} \times 0.667 + \frac{3}{6} \times 0.667 \right) = 1.667 - 0.667 = 1.0 $$
So the variance reduction is 1.0, which is significant and indicates a good split.
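The same numbers can be checked with a short snippet; the variance helper below is a plain implementation of the formula above, applied to the example values.
import numpy as np

def variance(values):
    """Population variance of the target values in a node."""
    values = np.asarray(values, dtype=float)
    return float(np.mean((values - values.mean()) ** 2))

parent = [10, 12, 14, 13, 11, 12]
left, right = [10, 11, 12], [12, 13, 14]

vr = variance(parent) - (len(left) / len(parent) * variance(left)
                         + len(right) / len(parent) * variance(right))
print(round(variance(parent), 3), round(vr, 3))   # 1.667 1.0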
- Variance reduction is only used in regression trees, like in the CART algorithm for regression.
- The goal is to reduce the spread of the target values within each child node.
- A higher reduction implies better prediction consistency in child nodes.
Pruning in Decision Trees: An Effective Strategy Against Overfitting¶
As powerful and intuitive as decision trees are, they come with a significant drawback: overfitting. A fully grown decision tree tends to model not just the underlying patterns in the data, but also the noise - random fluctuations or rare cases that do not represent general trends. This makes the model perform well on the training data but poorly on unseen test data, indicating poor generalization. This is where pruning becomes a vital technique.
Pruning is the process of reducing the size of a decision tree by removing parts of the tree that do not provide significant power in predicting target variables. These are usually the deeper branches of the tree that are highly specific to the training data and contribute minimally (or negatively) to predictive accuracy. Pruning can be thought of as the process of simplifying the tree structure to make it more robust, general, and interpretable, while preserving or even improving its predictive performance on new data.
Why Overfitting Happens in Decision Trees¶
Before diving into pruning methods, it’s important to understand why overfitting happens in decision trees:
- Greedy Splitting: Decision trees grow recursively by splitting nodes using the best possible split at each level (based on impurity reduction), without looking ahead. This can result in splits that only marginally improve purity but lead to a large, deep tree.
- Low Bias, High Variance: Trees have low bias, meaning they can fit very complex patterns. But this comes with high variance, meaning small changes in the training data can lead to very different trees.
- Capturing Noise: If not controlled, a tree can create branches that specifically model rare outliers or noise points in the training data.
Types of Pruning¶
There are two major approaches to pruning:
1. Pre-Pruning (Early Stopping)¶
This approach limits the growth of the tree during training itself by placing constraints on the tree’s expansion.
Common pre-pruning techniques include:
- Maximum depth: Limit how deep the tree can go.
- Minimum samples per leaf: Require a minimum number of data points to form a leaf node.
- Minimum information gain: Stop a split if the improvement in impurity is too small.
- Maximum number of leaf nodes: Restrict the total number of leaf nodes.
Advantages:
- Simple and fast to implement.
- Prevents the tree from becoming overly complex.
Disadvantages:
- Might stop too early and miss important patterns that appear in deeper levels.
2. Post-Pruning (Pruning After Full Tree is Grown)¶
In post-pruning, a tree is allowed to grow to its full depth, potentially overfitting the training data. Then, subtrees are evaluated and removed if they don’t contribute meaningfully to predictive power.
Common post-pruning methods:
Cost Complexity Pruning (a.k.a. Minimal Cost-Complexity Pruning):
This is the standard method used in CART (Classification and Regression Trees).
For a tree $T$, the cost function is defined as:
$$ C_\alpha(T) = R(T) + \alpha \cdot |T| $$
Where:
- $R(T)$ is the misclassification rate (or mean squared error in regression).
- $|T|$ is the number of terminal nodes (leaves) in the tree.
- $\alpha$ is a complexity parameter that penalizes larger trees.
The goal is to find a subtree that minimizes this cost. As $\alpha$ increases, we prefer simpler trees more.
Reduced Error Pruning: Removes nodes only if their removal does not degrade accuracy on a separate validation set. This method requires holding out part of the training data as validation data.
Pessimistic Pruning: Makes use of estimated errors and confidence intervals to prune without needing a separate validation set.
Benefits of Pruning¶
- Reduces Overfitting: Prevents the model from learning noise or outliers in the training data.
- Improves Generalization: Simplified trees often perform better on unseen data.
- Enhances Interpretability: Smaller trees are easier to visualize and explain.
- Reduces Complexity and Storage: Smaller trees consume less memory and are faster to use.
- Makes Models More Stable: Reduces sensitivity to small variations in the training dataset.
The visual effect of pruning is shown in the code section near the end of the notebook.
Let's understand Minimal Cost-Complexity Pruning (MCCP) in detail¶
Minimal Cost-Complexity Pruning is the standard and most widely used post-pruning technique in decision trees, especially in the CART algorithm. It balances the trade-off between the tree's complexity and its accuracy on the training data. The main idea is to prevent the tree from overfitting by pruning it back to a simpler form, while minimizing the loss in predictive power. When a decision tree is fully grown, it may have many nodes that provide very small gains in impurity reduction but lead to greater overfitting. MCCP systematically removes those parts of the tree that do not offer significant improvement, according to a cost function that penalizes complexity.
The Cost-Complexity Function¶
Let’s define the cost-complexity function used to evaluate the quality of a tree:
$$ C_\alpha(T) = R(T) + \alpha \cdot |T| $$
Where:
- $T$ is a subtree of the original decision tree.
- $R(T)$ is the total misclassification error (or mean squared error in case of regression) of the tree $T$ on the training data.
- $|T|$ is the number of leaf nodes (terminal nodes) in the subtree $T$.
- $\alpha$ is the complexity parameter, which controls the penalty for having a large tree.
Meaning of the Alpha ($\alpha$) Parameter¶
- The alpha parameter $\alpha$ balances between accuracy and complexity.
- It determines how much we penalize additional leaf nodes in the tree.
- A larger $\alpha$ puts more emphasis on keeping the tree small, potentially sacrificing some accuracy.
- A smaller $\alpha$ favors keeping the tree large, minimizing error on training data but increasing the risk of overfitting.
Range of Alpha Values:
$\alpha \geq 0$
In practice:
- When $\alpha = 0$, pruning does not penalize complexity at all. The algorithm retains the full tree (possibly overfitting).
- As $\alpha$ increases, the algorithm starts pruning more aggressively.
- For very high $\alpha$, the tree may reduce to a single node (the root), which always predicts the majority class.
How MCCP Works: Step-by-Step¶
- Grow the full tree: Build a deep tree that may overfit the training data.
- Evaluate subtrees: Compute the cost-complexity function $C_\alpha(T)$ for various subtrees of the full tree.
- Generate a sequence of pruned trees:
- For each increasing value of $\alpha$, find the subtree $T_\alpha$ that minimizes the cost.
- This creates a sequence of increasingly smaller trees.
- Select the optimal subtree:
- Evaluate all candidate trees (from the pruning sequence) on a validation set or using cross-validation.
- Choose the tree with the best validation performance (lowest error).
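scikit-learn exposes this procedure through the cost_complexity_pruning_path method and the ccp_alpha parameter. The sketch below selects alpha by 5-fold cross-validation on the iris dataset (loaded here so the cell is self-contained); it is one reasonable way to pick alpha, not the only one.
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# Steps 1-3: grow the full tree and get the sequence of effective alphas.
path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(X, y)
ccp_alphas = path.ccp_alphas

# Step 4: evaluate one pruned tree per alpha and keep the best by cross-validation.
scores = [cross_val_score(DecisionTreeClassifier(random_state=0, ccp_alpha=a), X, y, cv=5).mean()
          for a in ccp_alphas]
best_alpha = ccp_alphas[int(np.argmax(scores))]
print(best_alpha, max(scores))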
Example (Conceptual)¶
Suppose we grow a tree with 50 leaf nodes and a very low training error. However, we observe poor performance on validation data. We then apply cost-complexity pruning with different values of $\alpha$:
| Alpha ($\alpha$) | Tree Size (Leaf Nodes) | Training Error | Validation Error |
|---|---|---|---|
| 0.00 | 50 | 5% | 18% |
| 0.01 | 30 | 6% | 14% |
| 0.05 | 15 | 8% | 10% |
| 0.10 | 8 | 12% | 10% |
| 0.20 | 1 (root only) | 30% | 20% |
From this, we may select the tree corresponding to $\alpha = 0.05$, as it gives the best generalization performance.
Regression Trees: The Precursor to Decision Trees¶
Regression trees are a subtype of decision trees designed specifically for predicting continuous variables. The idea is to partition the dataset into regions where the target variable exhibits minimal variance. For example, imagine a dataset with housing prices based on features like size, location, and age. A regression tree might first split the data based on size, then further divide by location, eventually predicting the average house price in each resulting subset. Unlike classification trees that rely on entropy or Gini, regression trees use metrics such as mean squared error or variance reduction to determine splits.
The algorithm is similar to classification trees:
- Find a feature and threshold that minimizes the MSE.
- Split the data accordingly.
- Repeat the process recursively.
Regression trees formed the conceptual backbone for classification trees. Researchers realized that the recursive partitioning strategy could be applied not only to continuous outputs but also to categorical labels, giving rise to the modern decision tree.
Strengths and Weaknesses of Decision Trees¶
Advantages:
- Simple to understand and interpret.
- No need for feature scaling or normalization.
- Can handle both numerical and categorical data.
- Non-parametric: no assumptions about data distribution.
Disadvantages:
- Prone to overfitting, especially deep trees.
- High variance: small changes in data can result in very different trees.
- Not as accurate as ensemble models or deep learning on complex tasks.
- Biased towards features with more levels if not handled carefully.
These limitations led to the development of more robust ensemble methods like Random Forests.
Implementing a Basic Decision Tree Project using sklearn¶
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree
import numpy as np
import matplotlib.pyplot as plt
Load Dataset¶
iris = load_iris()
# Separating features and target
X, y = iris.data, iris.target
X[3], X[60], X[130]
(array([4.6, 3.1, 1.5, 0.2]), array([5. , 2. , 3.5, 1. ]), array([7.4, 2.8, 6.1, 1.9]))
y[3], y[60], y[130]
(0, 1, 2)
np.shape(X) #150 records with 4 features each
(150, 4)
Create the Decision Tree model¶
clf = DecisionTreeClassifier() # Instantiating the model
clf = clf.fit(X, y) # Fit (train) the model on the data
clf
DecisionTreeClassifier()
# Let's try predicting on an arbitrary input with 4 features
clf.predict([[2.3, 1.2, 4.3, 2.3]])
array([2])
Gives class 2, which means the input features most closely resemble class 2 according to our decision tree model.
Visualize the tree¶
plt.figure(figsize = (10, 6))
plot_tree(clf, filled = True,
feature_names= iris.feature_names,
class_names=iris.target_names)
plt.show()
clf = DecisionTreeClassifier(ccp_alpha=0.1) # Instantiating the model again with Pruning
clf = clf.fit(X, y)
plt.figure(figsize = (10, 6))
plot_tree(clf, filled = True,
feature_names= iris.feature_names,
class_names=iris.target_names)
plt.show()