A decision tree is one of the most transparent models in machine learning. Instead of fitting a black-box function, it learns a series of if-else rules that partition the feature space — rules you can read, trace, and explain to a non-technical audience. This interpretability is what makes decision trees a natural starting point for supervised learning.
In this tutorial you will build two models using Scikit-learn: a Decision Tree Regressor trained on the Diabetes dataset to predict disease progression, and a Decision Tree Classifier trained on the Iris dataset to identify flower species.
Prerequisites: Python 3.x, Scikit-learn, Pandas, NumPy, Matplotlib, Seaborn, and mlxtend (used for the confusion matrix plot). The code is written for a Jupyter notebook.
How Decision Trees Work
A decision tree learns by repeatedly asking: "Which feature, split at which value, best separates the data?" It starts at the root with the full dataset, selects the best split, divides the data into two subsets, and repeats the process on each subset — a technique called recursive binary splitting. This continues until a stopping condition is met, such as a maximum depth or a minimum number of samples in a node.
The steps at each node are:
- Select the best attribute using an Attribute Selection Measure (ASM) to split the records
- Make that attribute a decision node and break the dataset into smaller subsets
- Repeat recursively for each child until all samples in a node belong to the same class, no attributes remain, or no samples remain
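The split-selection step in the list above can be sketched for a single numeric feature. This is a minimal illustration and not scikit-learn's implementation: the helper names `gini` and `best_split` and the toy arrays are hypothetical, and weighted Gini impurity (covered later in this tutorial) is assumed as the selection measure.

```python
import numpy as np

def gini(labels):
    """Gini impurity of a 1-D array of class labels."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def best_split(x, y):
    """Return (threshold, weighted_impurity) of the best binary split on x."""
    best_t, best_score = None, np.inf
    values = np.unique(x)  # sorted distinct feature values
    # Candidate thresholds: midpoints between consecutive distinct values
    for t in (values[:-1] + values[1:]) / 2:
        left, right = y[x <= t], y[x > t]
        # Impurity of the two children, weighted by their sizes
        score = (len(left) * gini(left) + len(right) * gini(right)) / len(y)
        if score < best_score:
            best_t, best_score = t, score
    return best_t, best_score

# Toy data (hypothetical): small x values are class 0, large ones class 1
x = np.array([1.0, 2.0, 3.0, 10.0, 11.0, 12.0])
y = np.array([0, 0, 0, 1, 1, 1])
threshold, impurity = best_split(x, y)
print(threshold, impurity)
```

A full tree would run this search over every feature, keep the best one, and then recurse on the two resulting subsets until a stopping condition is met.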
Two popular algorithms implement recursive splitting differently:
- CART (Classification and Regression Trees): uses Gini impurity as the splitting metric for classification and variance reduction (squared error) for regression
- ID3 (Iterative Dichotomiser 3): uses Entropy and Information Gain
For a deep theoretical treatment, see Chapter 8 of An Introduction to Statistical Learning.
Attribute Selection Measures
At every split the tree needs a metric to rank which feature and threshold produce the cleanest separation. The three most common measures are Information Gain, Gain Ratio, and Gini Index.
Information Gain
To define information gain we first need entropy, a measure of impurity from information theory. A pure node where all samples belong to one class has entropy 0; a perfectly mixed node has maximum entropy.
$$H(S) = -\sum_{c \in C} p(c) \log_2 p(c)$$
Where:
- $S$: the current dataset
- $C$: the set of classes in $S$
- $p(c)$: the proportion of samples belonging to class $c$
Information Gain then measures how much a split on attribute $A$ reduces entropy. The tree always picks the attribute that maximises information gain, minimising uncertainty in the resulting subsets.
$$IG(S, A) = H(S) - \sum_{t \in T} p(t)\, H(t)$$
Where:
- $H(S)$: entropy of the full dataset before the split
- $T$: the subsets created by splitting $S$ on attribute $A$
- $p(t)$: the proportion of samples routed to subset $t$
- $H(t)$: entropy of subset $t$
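To make the two definitions concrete, here is a small sketch that computes entropy and information gain with NumPy. It assumes the standard Shannon entropy in bits (minus the sum of p times log2 p over the classes); the function names and the toy label arrays are illustrative only.

```python
import numpy as np

def entropy(labels):
    """Shannon entropy (in bits) of a 1-D array of class labels."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))

def information_gain(labels, subsets):
    """Entropy of `labels` minus the size-weighted entropy of its `subsets`."""
    n = len(labels)
    weighted = sum(len(s) / n * entropy(s) for s in subsets)
    return entropy(labels) - weighted

parent = np.array([0, 0, 1, 1])                  # 50/50 mix of two classes
perfect = [np.array([0, 0]), np.array([1, 1])]   # each subset is pure
useless = [np.array([0, 1]), np.array([0, 1])]   # each subset is still mixed

print(entropy(parent))                    # 1.0 bit for a 50/50 node
print(information_gain(parent, perfect))  # 1.0: all uncertainty removed
print(information_gain(parent, useless))  # 0.0: nothing learned
```

The split that separates the classes completely recovers the full bit of entropy, while the split that leaves both subsets mixed gains nothing.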
Gain Ratio
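Information Gain favours attributes with many distinct values, which can lead to poor generalisation. Gain Ratio corrects this bias by penalising broad splits using a term called Split Information:
$$\text{GainRatio}(S, A) = \frac{IG(S, A)}{\text{SplitInfo}(S, A)}, \qquad \text{SplitInfo}(S, A) = -\sum_{t \in T} p(t) \log_2 p(t)$$
Here $T$ is the set of subsets produced by splitting $S$ on attribute $A$, and $p(t)$ is the proportion of samples routed to subset $t$.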
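The effect of the penalty is easiest to see on a toy example. The sketch below assumes the usual definition of gain ratio (information gain divided by the entropy of the subset sizes) and at least two non-empty subsets; `gain_ratio` and the arrays are hypothetical names for illustration.

```python
import numpy as np

def entropy(labels):
    """Shannon entropy (in bits) of a 1-D array of class labels."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))

def gain_ratio(labels, subsets):
    """Information gain divided by split information."""
    n = len(labels)
    weights = np.array([len(s) / n for s in subsets])
    gain = entropy(labels) - sum(w * entropy(s) for w, s in zip(weights, subsets))
    split_info = -np.sum(weights * np.log2(weights))  # entropy of the subset sizes
    return gain / split_info

labels = np.array([0, 0, 1, 1])
two_way = [labels[:2], labels[2:]]               # two pure halves
four_way = [labels[i:i + 1] for i in range(4)]   # one sample per subset, like an ID column

print(gain_ratio(labels, two_way))   # 1.0
print(gain_ratio(labels, four_way))  # 0.5
```

Both splits have the same information gain of 1 bit, but the four-way split is penalised because it fragments the data into more pieces than it needs.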
Gini Index
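Gini impurity estimates the probability that a randomly chosen sample would be misclassified if labelled according to the class distribution of its node:
$$\text{Gini}(S) = 1 - \sum_{c \in C} p(c)^2$$
where $p(c)$ is the proportion of samples in class $c$. A pure node has Gini 0; a node with $k$ equally frequent classes has the maximum value, $1 - 1/k$. Gini is the default criterion of scikit-learn's DecisionTreeClassifier, and it is slightly cheaper to compute than entropy because it needs no logarithm.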
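As a quick numerical check of these boundary cases, the sketch below computes Gini impurity directly from class proportions, assuming the standard definition of one minus the sum of squared proportions. The helper name `gini_from_proportions` is hypothetical.

```python
import numpy as np

def gini_from_proportions(p):
    """Gini impurity given the class proportions of a node (must sum to 1)."""
    p = np.asarray(p, dtype=float)
    return 1.0 - np.sum(p ** 2)

pure = gini_from_proportions([1.0, 0.0])             # all samples in one class
mixed_two = gini_from_proportions([0.5, 0.5])        # two classes, evenly mixed
mixed_three = gini_from_proportions([1/3, 1/3, 1/3]) # three classes, evenly mixed

print(pure, mixed_two, mixed_three)
```

The maximum grows with the number of classes: 0.5 for two evenly mixed classes and about 0.667 for three.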
Key Hyperparameters
Three parameters have the greatest impact on tree behaviour:
- criterion: the splitting metric. For classifiers, 'gini' (default) or 'entropy'; both usually produce similar accuracy, and 'gini' is slightly faster to compute. Regressors use error-based criteria such as squared error instead.
- splitter: the split strategy. 'best' always picks the best split found; 'random' picks the best among randomly drawn candidate splits, which adds randomness to the tree but reduces compute time.
- max_depth: the maximum number of levels in the tree. None grows the tree until leaves are pure or too small to split. Higher values risk overfitting; lower values risk underfitting.
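The following sketch shows how these three parameters are passed to the estimator and how `max_depth` limits tree size. It fits on the full Iris dataset purely for illustration; the variable names and the specific values (`max_depth=2`, `random_state=0`) are assumptions, not recommendations.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# Unconstrained tree: max_depth=None lets it grow until the leaves are pure
full_tree = DecisionTreeClassifier(criterion='gini', splitter='best',
                                   max_depth=None, random_state=0).fit(X, y)

# Constrained tree: entropy criterion, at most two levels of splits
small_tree = DecisionTreeClassifier(criterion='entropy', splitter='best',
                                    max_depth=2, random_state=0).fit(X, y)

# get_depth() and get_n_leaves() report the size of the fitted tree
print(full_tree.get_depth(), full_tree.get_n_leaves())
print(small_tree.get_depth(), small_tree.get_n_leaves())
```

A tree capped at depth 2 can have at most four leaves, which makes it far easier to read than the fully grown version.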
Pruning
An unconstrained tree keeps growing until it fits the training data almost perfectly, a classic case of overfitting. Pruning shrinks the tree by removing branches that contribute little to predictive performance, cutting complexity and improving generalisation.
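One way to prune in scikit-learn is minimal cost-complexity pruning, controlled by the `ccp_alpha` parameter. The sketch below is an illustration under assumptions: it uses the Diabetes data with the same 80/20 split as later in this tutorial, and it picks three alphas from the pruning path (first, middle, last) only to show the trend.

```python
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

X, y = load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42)

# Effective alphas at which successive branches of the full tree get pruned
path = DecisionTreeRegressor(random_state=42).cost_complexity_pruning_path(
    X_train, y_train)
alphas = path.ccp_alphas

# Take the smallest, a middle, and the largest alpha; clip tiny negative
# values that can appear through floating-point error
picks = [max(float(a), 0.0)
         for a in (alphas[0], alphas[len(alphas) // 2], alphas[-1])]

leaf_counts = []
for alpha in picks:
    tree = DecisionTreeRegressor(random_state=42, ccp_alpha=alpha)
    tree.fit(X_train, y_train)
    leaf_counts.append(tree.get_n_leaves())
    print(f"ccp_alpha={alpha:.2f}  leaves={tree.get_n_leaves()}  "
          f"test R^2={tree.score(X_test, y_test):.3f}")
```

Larger values of `ccp_alpha` remove more branches. In practice the value is usually chosen by cross-validation rather than by inspecting the test set.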
The diagram below shows an unpruned tree vs a pruned version:

Regression: The Diabetes Dataset
The scikit-learn Diabetes dataset contains 10 baseline features (age, sex, BMI, average blood pressure, and six blood serum measurements) for 442 patients, along with a quantitative measure of disease progression one year later. The goal is to predict this continuous target value.
Start by importing all the libraries needed for this section:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline
from sklearn import datasets, metrics
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor
Load the dataset and inspect its keys:
diabetes = datasets.load_diabetes()
diabetes.keys()
dict_keys(['data', 'target', 'frame', 'DESCR', 'feature_names', 'data_filename', 'target_filename'])
Print the full dataset description to understand what each feature represents:
print(diabetes.DESCR)
.. _diabetes_dataset:
Diabetes dataset
----------------
Ten baseline variables, age, sex, body mass index, average blood
pressure, and six blood serum measurements were obtained for each of n =
442 diabetes patients, as well as the response of interest, a
quantitative measure of disease progression one year after baseline.
**Data Set Characteristics:**
:Number of Instances: 442
:Number of Attributes: First 10 columns are numeric predictive values
:Target: Column 11 is a quantitative measure of disease progression one year after baseline
:Attribute Information:
- age age in years
- sex
- bmi body mass index
- bp average blood pressure
- s1 tc, T-Cells (a type of white blood cells)
- s2 ldl, low-density lipoproteins
- s3 hdl, high-density lipoproteins
- s4 tch, thyroid stimulating hormone
- s5 ltg, lamotrigine
- s6 glu, blood sugar level
Note: Each of these 10 feature variables have been mean centered and scaled by the standard deviation times `n_samples` (i.e. the sum of squares of each column totals 1).
Inspect the feature names and the first ten target values:
diabetes.feature_names
['age', 'sex', 'bmi', 'bp', 's1', 's2', 's3', 's4', 's5', 's6']
diabetes.target[: 10]
array([151., 75., 141., 206., 135., 97., 138., 63., 110., 310.])
Confirm the shape of the feature matrix and target vector:
X = diabetes.data
y = diabetes.target
X.shape, y.shape
((442, 10), (442,))
Build a DataFrame for easier exploration and preview the first five rows:
df = pd.DataFrame(X, columns=diabetes.feature_names)
df['target'] = y
df.head()
| | age | sex | bmi | bp | s1 | s2 | s3 | s4 | s5 | s6 | target |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.038076 | 0.050680 | 0.061696 | 0.021872 | -0.044223 | -0.034821 | -0.043401 | -0.002592 | 0.019908 | -0.017646 | 151.0 |
| 1 | -0.001882 | -0.044642 | -0.051474 | -0.026328 | -0.008449 | -0.019163 | 0.074412 | -0.039493 | -0.068330 | -0.092204 | 75.0 |
| 2 | 0.085299 | 0.050680 | 0.044451 | -0.005671 | -0.045599 | -0.034194 | -0.032356 | -0.002592 | 0.002864 | -0.025930 | 141.0 |
| 3 | -0.089063 | -0.044642 | -0.011595 | -0.036656 | 0.012191 | 0.024991 | -0.036038 | 0.034309 | 0.022692 | -0.009362 | 206.0 |
| 4 | 0.005383 | -0.044642 | -0.036385 | 0.021872 | 0.003935 | 0.015596 | 0.008142 | -0.002592 | -0.031991 | -0.046641 | 135.0 |
A pair plot shows the pairwise relationships and distributions across all features, giving a quick sense of which variables correlate most strongly with the target:
sns.pairplot(df)
plt.show()
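A pair plot with eleven columns can be hard to read, so a numeric summary may help. The sketch below ranks the features by their Pearson correlation with the target; it rebuilds the same DataFrame so that it runs on its own. Keep in mind that correlation only captures linear relationships, whereas a tree can also exploit non-linear ones.

```python
import pandas as pd
from sklearn.datasets import load_diabetes

diabetes = load_diabetes()
df = pd.DataFrame(diabetes.data, columns=diabetes.feature_names)
df['target'] = diabetes.target

# Pearson correlation of every feature with the target
correlations = df.corr()['target'].drop('target')

# Order by absolute strength, strongest first
order = correlations.abs().sort_values(ascending=False).index
print(correlations.reindex(order))
```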

Training the Decision Tree Regressor
Split the data into 80 % training and 20 % test sets, then fit the regressor and generate predictions:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, random_state = 42)
regressor = DecisionTreeRegressor(random_state=42)
regressor.fit(X_train, y_train)
y_pred = regressor.predict(X_test)
Plot the predicted values against the true test values to see how closely the regressor tracks actual disease progression scores:
plt.figure(figsize=(16, 4))
plt.plot(y_pred, label='y_pred')
plt.plot(y_test, label='y_test')
plt.xlabel('Test sample index', fontsize=14)
plt.ylabel('Disease progression', fontsize=14)
plt.title('Comparing predicted values and true values')
plt.legend()
plt.show()
Calculate the Root Mean Squared Error (RMSE) to measure average prediction error in the same units as the target:
np.sqrt(metrics.mean_squared_error(y_test, y_pred))
70.61829663921893
Compare the RMSE against the standard deviation of the test targets — an RMSE close to the standard deviation suggests the model is barely better than predicting the mean:
y_test.std()
72.78840394263774
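An RMSE of 70.6 against a standard deviation of 72.8 means the unconstrained tree predicts unseen patients barely better than always guessing the mean. This is the typical symptom of overfitting: a fully grown tree fits the training data almost perfectly but generalises poorly. Limiting max_depth is a common way to narrow this gap.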
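Overfitting shows up most clearly when training and test error are put side by side. The sketch below repeats the split used above and compares both errors for an unconstrained tree and a few depth limits. The depth values and the helper `rmse` are illustrative choices, and the exact numbers may vary with the scikit-learn version.

```python
import numpy as np
from sklearn.datasets import load_diabetes
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

X, y = load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42)

def rmse(model, X, y):
    """Root mean squared error of `model` on (X, y)."""
    return np.sqrt(mean_squared_error(y, model.predict(X)))

results = {}
for depth in (None, 2, 3, 5):
    model = DecisionTreeRegressor(max_depth=depth, random_state=42)
    model.fit(X_train, y_train)
    # Store (train RMSE, test RMSE) for each depth setting
    results[depth] = (rmse(model, X_train, y_train), rmse(model, X_test, y_test))
    print(f"max_depth={depth}: train RMSE={results[depth][0]:.1f}, "
          f"test RMSE={results[depth][1]:.1f}")
```

A large gap between a very low training error and a high test error is the signature of a tree that has memorised its training set.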
Classification: Iris Flower Recognition
The Iris dataset contains 150 samples across three flower species — setosa, versicolor, and virginica — each described by four measurements: sepal length, sepal width, petal length, and petal width. The goal is to classify each sample into the correct species.
from sklearn.tree import DecisionTreeClassifier
Load the dataset and inspect the class names and feature names:
iris = datasets.load_iris()
iris.target_names
array(['setosa', 'versicolor', 'virginica'], dtype='<U10')
iris.feature_names
['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
Build a DataFrame and preview the first five rows:
X = iris.data
y = iris.target
df = pd.DataFrame(X, columns=iris.feature_names)
df['target'] = y
df.head()
| | sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | target |
|---|---|---|---|---|---|
| 0 | 5.1 | 3.5 | 1.4 | 0.2 | 0 |
| 1 | 4.9 | 3.0 | 1.4 | 0.2 | 0 |
| 2 | 4.7 | 3.2 | 1.3 | 0.2 | 0 |
| 3 | 4.6 | 3.1 | 1.5 | 0.2 | 0 |
| 4 | 5.0 | 3.6 | 1.4 | 0.2 | 0 |
The pair plot below reveals that petal length and petal width are the most separable features — setosa is visually distinct from the other two species:
# Colour the points by species so the class separation is visible
sns.pairplot(df, hue='target')
plt.show()
Train the classifier using Gini impurity on a stratified split, then evaluate accuracy on the test set:
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 1, test_size = 0.2, stratify = y)
clf = DecisionTreeClassifier(criterion='gini', random_state=1)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
print('Accuracy: ', metrics.accuracy_score(y_test, y_pred))
Accuracy: 0.9666666666666667
The classifier achieves 96.7 % accuracy — only one misclassification across 30 test samples. Visualise the full breakdown with a confusion matrix, where each cell shows the count of actual vs. predicted class pairs:
from mlxtend.evaluate import confusion_matrix
from mlxtend.plotting import plot_confusion_matrix
print('Confusion Matrix')
cm = confusion_matrix(y_test, y_pred)
fig, ax = plot_confusion_matrix(conf_mat=cm)
plt.title('Confusion matrix: actual vs. predicted class counts')
plt.show()
Confusion Matrix
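If mlxtend is not installed, the same counts are available from scikit-learn itself. The sketch below refits the classifier with the settings used above so that it runs on its own, then prints the raw matrix using `sklearn.metrics.confusion_matrix` as an alternative to the mlxtend helper.

```python
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=1, stratify=iris.target)

clf = DecisionTreeClassifier(criterion='gini', random_state=1)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)

# Rows are actual classes, columns are predicted classes
cm = confusion_matrix(y_test, y_pred)
print(cm)

# The diagonal holds the correct predictions, so trace / total equals accuracy
print(cm.trace() / cm.sum(), accuracy_score(y_test, y_pred))
```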
Classification Report
The classification report breaks down precision, recall, and F1-score per class alongside overall accuracy. Print it to see exactly where the classifier succeeds and where it struggles:
print(metrics.classification_report(y_test, y_pred))
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        10
           1       0.91      1.00      0.95        10
           2       1.00      0.90      0.95        10

    accuracy                           0.97        30
   macro avg       0.97      0.97      0.97        30
weighted avg       0.97      0.97      0.97        30
Setosa is classified perfectly. The single error occurs between versicolor and virginica, whose petal dimensions overlap slightly — a well-known characteristic of this dataset.
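To see which feature conditions drive these predictions, the learned rules can be printed as text with `sklearn.tree.export_text`. The sketch below is illustrative only: it fits a shallow tree (`max_depth=2`, an assumption made to keep the output short) on the full Iris dataset rather than reusing the classifier trained above.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()

# A shallow tree keeps the printed rule set short and readable
clf = DecisionTreeClassifier(max_depth=2, random_state=1)
clf.fit(iris.data, iris.target)

# Each indented line is one if-else condition; leaves show the predicted class
rules = export_text(clf, feature_names=list(iris.feature_names))
print(rules)
```

Reading the output from top to bottom traces the exact path a sample follows from the root to its predicted class.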
Conclusion
In this tutorial you trained a Decision Tree Regressor on the Diabetes dataset and a Decision Tree Classifier on the Iris dataset using Scikit-learn. The classifier achieved 96.7 % accuracy on a 30-sample test set, while the unconstrained regressor generalised poorly, with a test RMSE nearly equal to the standard deviation of the test targets.
Key takeaways:
- Decision trees split data by choosing the feature and threshold that maximises information gain or minimises Gini impurity at each node.
- An unpruned tree memorises training data, so always constrain growth with max_depth or cost-complexity pruning via ccp_alpha.
- Trees are fully interpretable: every prediction traces back to a readable sequence of feature conditions.
- The criterion parameter ('gini' vs 'entropy') rarely changes accuracy significantly; 'gini' is the faster default.
Next steps:
- Explore Random Forest to see how an ensemble of trees overcomes the single-tree overfitting problem.
- Read Ensemble Learning for a broader look at bagging and boosting techniques.
- Tune max_depth on the Diabetes regressor and observe how RMSE changes with tree depth, then use sklearn.tree.plot_tree() to inspect the learned rules.
