#decision tree#classifier#regressor#scikit-learn#python#pruning

Decision Tree in Python

Train decision tree classifiers and regressors in Python with scikit-learn. Covers splitting criteria, key hyperparameters, pruning, and model evaluation.

May 22, 2026 at 11:15 AM | 10 min read

Topics You Will Master

  • How decision trees split data using information gain, gain ratio, and Gini impurity
  • Training a DecisionTreeRegressor on the scikit-learn Diabetes dataset
  • Training a DecisionTreeClassifier on the Iris dataset and evaluating with a confusion matrix
  • Controlling tree complexity with max_depth and pruning to prevent overfitting

Best For

Python developers who understand basic supervised learning and want a hands-on introduction to interpretable tree-based models.

Expected Outcome

Two working decision tree models — a regressor and a classifier — trained, evaluated, and visualised on real datasets, with a clear understanding of how the tree makes its decisions.

A decision tree is one of the most transparent models in machine learning. Instead of fitting a black-box function, it learns a series of if-else rules that partition the feature space — rules you can read, trace, and explain to a non-technical audience. This interpretability is what makes decision trees a natural starting point for supervised learning.

In this tutorial you will build two models using scikit-learn: a Decision Tree Regressor trained on the Diabetes dataset to predict disease progression, and a Decision Tree Classifier trained on the Iris dataset to identify flower species.

Prerequisites: Python 3.x, scikit-learn, pandas, NumPy, Matplotlib, seaborn, and mlxtend (used only for the confusion-matrix plot).

How Decision Trees Work

A decision tree learns by repeatedly asking: "Which feature, split at which value, best separates the data?" It starts at the root with the full dataset, selects the best split, divides the data into two subsets, and repeats the process on each subset — a technique called recursive binary splitting. This continues until a stopping condition is met, such as a maximum depth or a minimum number of samples in a node.

The steps at each node are:

  • Select the best attribute using an Attribute Selection Measure (ASM) to split the records
  • Make that attribute a decision node and break the dataset into smaller subsets
  • Repeat recursively for each child until all samples in a node belong to the same class, no attributes remain, or no samples remain
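
The steps above can be sketched in plain Python. This is a minimal illustration, not scikit-learn's implementation: it assumes numeric features, uses Gini impurity as the selection measure, and tries the midpoints between consecutive distinct values as thresholds. The names gini, best_split, build_tree, predict, and the toy data are all hypothetical helpers introduced for this example.

```python
from collections import Counter


def gini(labels):
    """Gini impurity of a list of class labels (0.0 for an empty or pure node)."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def best_split(rows, labels):
    """Return (feature_index, threshold, weighted_impurity) or None if no split exists."""
    best = None
    n = len(rows)
    for j in range(len(rows[0])):
        values = sorted({row[j] for row in rows})
        # Candidate thresholds: midpoints between consecutive distinct values.
        for lo, hi in zip(values, values[1:]):
            threshold = (lo + hi) / 2
            left = [y for row, y in zip(rows, labels) if row[j] <= threshold]
            right = [y for row, y in zip(rows, labels) if row[j] > threshold]
            score = (len(left) * gini(left) + len(right) * gini(right)) / n
            if best is None or score < best[2]:
                best = (j, threshold, score)
    return best


def build_tree(rows, labels, depth=0, max_depth=3):
    """Split recursively until the node is pure, unsplittable, or max_depth is reached."""
    split = None
    if depth < max_depth and len(set(labels)) > 1:
        split = best_split(rows, labels)
    if split is None:
        # Leaf node: predict the majority class of the samples that reached it.
        return Counter(labels).most_common(1)[0][0]
    j, threshold, _ = split
    left_idx = [i for i, row in enumerate(rows) if row[j] <= threshold]
    right_idx = [i for i, row in enumerate(rows) if row[j] > threshold]
    return {
        "feature": j,
        "threshold": threshold,
        "left": build_tree([rows[i] for i in left_idx],
                           [labels[i] for i in left_idx], depth + 1, max_depth),
        "right": build_tree([rows[i] for i in right_idx],
                            [labels[i] for i in right_idx], depth + 1, max_depth),
    }


def predict(tree, row):
    """Follow the if-else rules from the root down to a leaf."""
    while isinstance(tree, dict):
        branch = "left" if row[tree["feature"]] <= tree["threshold"] else "right"
        tree = tree[branch]
    return tree


# Toy data (made up for illustration): feature 0 separates the classes, feature 1 does not.
toy_rows = [[1.0, 5.0], [2.0, 4.0], [3.0, 7.0], [6.0, 5.0], [7.0, 6.0], [8.0, 4.0]]
toy_labels = ["a", "a", "a", "b", "b", "b"]
toy_tree = build_tree(toy_rows, toy_labels)
print(toy_tree)
```

On this toy data a single split on feature 0 already yields two pure children, so the recursion stops after one level.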

Two popular algorithms implement recursive splitting differently:

  • CART (Classification and Regression Trees): uses Gini impurity as the splitting metric for classification and variance reduction (squared error) for regression
  • ID3 (Iterative Dichotomiser 3): uses entropy and information gain

For a deep theoretical treatment, see Chapter 8 of An Introduction to Statistical Learning.

Attribute Selection Measures

At every split the tree needs a metric to rank which feature and threshold produces the cleanest separation. The three most common measures are Information Gain, Gain Ratio, and Gini Index.

Information Gain

To define information gain we first need entropy, a measure of impurity from information theory. A pure node where all samples belong to one class has entropy 0; a perfectly mixed node has maximum entropy.

$$H(S) = -\sum_{c \in C} p(c) \log_2 p(c)$$

Where:

  • $S$: the current dataset
  • $C$: the set of classes in $S$
  • $p(c)$: the proportion of samples belonging to class $c$
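
As a small sketch of the entropy calculation, the function below computes the base-2 entropy of a list of class labels: the sum over classes of -p * log2(p), where p is each class's share of the samples. The function name entropy and the sample labels are illustrative choices for this example only.

```python
import math
from collections import Counter


def entropy(labels):
    """Base-2 entropy of a non-empty list of class labels."""
    n = len(labels)
    proportions = [count / n for count in Counter(labels).values()]
    return sum(-p * math.log2(p) for p in proportions)


pure_node = ["setosa"] * 4
mixed_node = ["setosa", "virginica", "setosa", "virginica"]
print(entropy(pure_node))   # a single class: no uncertainty
print(entropy(mixed_node))  # two equally likely classes: one bit of uncertainty
```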

Information Gain then measures how much a split on attribute $A$ reduces entropy. The tree always picks the attribute that maximises information gain, which minimises uncertainty in the resulting subsets.

$$IG(A, S) = H(S) - \sum_{t \in T} p(t)\, H(t)$$

Where:

  • $H(S)$: entropy of the full dataset before the split
  • $T$: the subsets created by splitting $S$ on attribute $A$
  • $p(t)$: the proportion of samples routed to subset $t$
  • $H(t)$: entropy of subset $t$
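
A worked example may help here. The sketch below computes information gain as the parent's entropy minus the size-weighted entropy of the child subsets. The helper names and the yes/no labels are made up for illustration.

```python
import math
from collections import Counter


def entropy(labels):
    """Base-2 entropy of a non-empty list of class labels."""
    n = len(labels)
    return sum(-(c / n) * math.log2(c / n) for c in Counter(labels).values())


def information_gain(parent, subsets):
    """Entropy of the parent minus the weighted entropy of the subsets it is split into."""
    n = len(parent)
    weighted = sum(len(subset) / n * entropy(subset) for subset in subsets)
    return entropy(parent) - weighted


parent = ["yes"] * 5 + ["no"] * 5

# An imperfect split: each side is 80 % one class.
imperfect = [["yes"] * 4 + ["no"], ["no"] * 4 + ["yes"]]
# A perfect split: each side is pure.
perfect = [["yes"] * 5, ["no"] * 5]

ig_imperfect = information_gain(parent, imperfect)
ig_perfect = information_gain(parent, perfect)
print(round(ig_imperfect, 4), round(ig_perfect, 4))
```

The perfect split removes all of the parent's one bit of entropy, while the 80/20 split removes only about 0.28 bits, so a tree maximising information gain would prefer the former.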

Gain Ratio

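Information Gain favours attributes with many distinct values, which can lead to poor generalisation. Gain Ratio corrects this bias by dividing the information gain by a term called Split Information, which grows as a split produces more and smaller subsets:

$$\text{SplitInfo}(A, S) = -\sum_{t \in T} p(t) \log_2 p(t), \qquad \text{GainRatio}(A, S) = \frac{IG(A, S)}{\text{SplitInfo}(A, S)}$$

Here $T$ is the set of subsets produced by splitting $S$ on attribute $A$, $p(t)$ is the proportion of samples routed to subset $t$, and $IG(A, S)$ is the information gain of the split.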
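
To make the bias concrete, the sketch below compares a two-way split with a split on an ID-like attribute that puts every sample in its own subset. Both reach the same information gain, but gain ratio (information gain divided by split information) penalises the many-way split. All function names and the data are hypothetical, and returning 0.0 when split information is zero is a convention chosen for this example.

```python
import math
from collections import Counter


def entropy(labels):
    """Base-2 entropy of a non-empty list of class labels."""
    n = len(labels)
    return sum(-(c / n) * math.log2(c / n) for c in Counter(labels).values())


def information_gain(parent, subsets):
    n = len(parent)
    return entropy(parent) - sum(len(s) / n * entropy(s) for s in subsets)


def split_information(parent, subsets):
    """Entropy of the subset sizes themselves, ignoring class labels."""
    n = len(parent)
    return sum(-(len(s) / n) * math.log2(len(s) / n) for s in subsets if s)


def gain_ratio(parent, subsets):
    info = split_information(parent, subsets)
    # Assumed convention: a split that keeps everything in one subset scores 0.
    return information_gain(parent, subsets) / info if info > 0 else 0.0


parent = ["yes"] * 4 + ["no"] * 4
two_way = [["yes"] * 4, ["no"] * 4]        # one meaningful binary split
id_like = [[label] for label in parent]    # one subset per sample

print(information_gain(parent, two_way), information_gain(parent, id_like))
print(gain_ratio(parent, two_way), gain_ratio(parent, id_like))
```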

Gini Index

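Gini impurity estimates the probability that a randomly chosen sample would be misclassified if labelled at random according to the class distribution of its node:

$$\text{Gini}(S) = 1 - \sum_{c \in C} p(c)^2$$

Here $p(c)$ is the proportion of samples in node $S$ that belong to class $c$. A pure node has Gini 0; a node with $k$ equally represented classes reaches the maximum value, $1 - 1/k$. scikit-learn uses Gini as the default criterion for classification. It needs no logarithm, so it is slightly cheaper to compute than entropy.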
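
A short sketch of the Gini calculation follows: one minus the sum of squared class proportions. The function name gini and the sample labels are illustrative only.

```python
from collections import Counter


def gini(labels):
    """Gini impurity: 1 minus the sum of squared class proportions."""
    n = len(labels)
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


pure = ["setosa"] * 6
two_classes = ["setosa"] * 3 + ["virginica"] * 3
three_classes = ["setosa", "versicolor", "virginica"] * 2

print(gini(pure), gini(two_classes), round(gini(three_classes), 4))
```

With two balanced classes the impurity is 0.5, and with three balanced classes it rises to about 0.667, so the maximum depends on the number of classes.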

Key Hyperparameters

Three parameters have the greatest impact on tree behaviour:

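  • criterion: the splitting metric. For DecisionTreeClassifier the options include 'gini' (default) and 'entropy'; both usually produce similar accuracy, and 'gini' is slightly cheaper to compute. DecisionTreeRegressor uses error-based criteria instead.
  • splitter: the split strategy. 'best' chooses the best split among the candidates it evaluates; 'random' draws a random threshold for each candidate feature and keeps the best of those random splits, which adds variance but reduces compute time. The number of features considered at each split is controlled separately by max_features.
  • max_depth: the maximum number of levels in the tree. None grows the tree until leaves are pure or too small to split. Higher values risk overfitting; lower values risk underfitting.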
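
The effect of max_depth is easy to observe directly. The sketch below fits three classifiers on the full Iris data and reports the depth, leaf count, and training accuracy of each. It is only an illustration: scoring on the training data shows how well each tree can fit, not how well it generalises, and the variable names are chosen for this example.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X_demo, y_demo = load_iris(return_X_y=True)

fitted = {}
for depth in (1, 2, None):
    tree = DecisionTreeClassifier(
        criterion="gini",   # splitting metric
        splitter="best",    # evaluate candidate thresholds and keep the best
        max_depth=depth,    # None lets the tree grow until leaves are pure
        random_state=0,
    )
    tree.fit(X_demo, y_demo)
    fitted[depth] = tree
    print(
        f"max_depth={depth}: depth={tree.get_depth()}, "
        f"leaves={tree.get_n_leaves()}, "
        f"train accuracy={tree.score(X_demo, y_demo):.3f}"
    )
```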

Pruning

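An unconstrained tree keeps growing until it fits the training data almost perfectly, a classic case of overfitting. Pruning shrinks the tree by removing branches that contribute little to predictive performance, which cuts complexity and usually improves generalisation. It can be applied while the tree grows (for example by limiting max_depth) or afterwards (for example with cost-complexity pruning).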

The diagram below shows an unpruned tree vs a pruned version:

Pruned decision tree with reduced depth showing cleaner structure and improved generalisation
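
One way to prune in scikit-learn is cost-complexity pruning through the ccp_alpha parameter. The sketch below grows a full regressor on the Diabetes data, asks it for the sequence of candidate alpha values, and refits with one of them. Picking the value halfway along the path is an arbitrary choice for illustration; in practice alpha would normally be chosen by cross-validation. The variable names are specific to this example.

```python
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

X_p, y_p = load_diabetes(return_X_y=True)
Xp_train, Xp_test, yp_train, yp_test = train_test_split(
    X_p, y_p, test_size=0.2, random_state=42
)

# Grow an unconstrained tree, then compute its pruning path.
full_tree = DecisionTreeRegressor(random_state=42).fit(Xp_train, yp_train)
path = full_tree.cost_complexity_pruning_path(Xp_train, yp_train)

# Arbitrary illustrative choice: the alpha halfway along the path.
alpha = float(path.ccp_alphas[len(path.ccp_alphas) // 2])
pruned_tree = DecisionTreeRegressor(random_state=42, ccp_alpha=alpha).fit(
    Xp_train, yp_train
)

print("nodes before pruning:", full_tree.tree_.node_count)
print("nodes after pruning: ", pruned_tree.tree_.node_count)
print("test R^2 before:", round(full_tree.score(Xp_test, yp_test), 3))
print("test R^2 after: ", round(pruned_tree.score(Xp_test, yp_test), 3))
```

Larger alpha values remove more branches; an alpha of 0 leaves the tree unpruned.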

Regression: The Diabetes Dataset

The scikit-learn Diabetes dataset contains 10 baseline features (age, sex, BMI, average blood pressure, and six blood serum measurements) for 442 patients, along with a quantitative measure of disease progression one year later. The goal is to predict this continuous target value.

Start by importing all the libraries needed for this section:

PYTHON
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
# Jupyter/IPython magic; omit this line when running as a plain script
%matplotlib inline
from sklearn import datasets, metrics
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

Load the dataset and inspect its keys:

PYTHON
diabetes = datasets.load_diabetes()
diabetes.keys()
OUTPUT
dict_keys(['data', 'target', 'frame', 'DESCR', 'feature_names', 'data_filename', 'target_filename'])

Print the full dataset description to understand what each feature represents:

PYTHON
print(diabetes.DESCR)
OUTPUT
.. _diabetes_dataset:

Diabetes dataset
----------------

Ten baseline variables, age, sex, body mass index, average blood
pressure, and six blood serum measurements were obtained for each of n =
442 diabetes patients, as well as the response of interest, a
quantitative measure of disease progression one year after baseline.

**Data Set Characteristics:**

  :Number of Instances: 442

  :Number of Attributes: First 10 columns are numeric predictive values

  :Target: Column 11 is a quantitative measure of disease progression one year after baseline

  :Attribute Information:
      - age     age in years
      - sex
      - bmi     body mass index
      - bp      average blood pressure
      - s1      tc, T-Cells (a type of white blood cells)
      - s2      ldl, low-density lipoproteins
      - s3      hdl, high-density lipoproteins
      - s4      tch, thyroid stimulating hormone
      - s5      ltg, lamotrigine
      - s6      glu, blood sugar level

Note: Each of these 10 feature variables have been mean centered and scaled by the standard deviation times `n_samples` (i.e. the sum of squares of each column totals 1).

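The serum labels in this output come from an older scikit-learn release. Current releases describe s1 as total serum cholesterol, s4 as the ratio of total cholesterol to HDL, and s5 as possibly the log of serum triglycerides level.

Inspect the feature names and the first ten target values: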

PYTHON
diabetes.feature_names
OUTPUT
['age', 'sex', 'bmi', 'bp', 's1', 's2', 's3', 's4', 's5', 's6']
PYTHON
diabetes.target[:10]
OUTPUT
array([151.,  75., 141., 206., 135.,  97., 138.,  63., 110., 310.])

Confirm the shape of the feature matrix and target vector:

PYTHON
X = diabetes.data
y = diabetes.target
X.shape, y.shape
OUTPUT
((442, 10), (442,))

Build a DataFrame for easier exploration and preview the first five rows:

PYTHON
df = pd.DataFrame(X, columns=diabetes.feature_names)
df['target'] = y
df.head()
OUTPUT
        age       sex       bmi        bp        s1        s2        s3        s4        s5        s6  target
0  0.038076  0.050680  0.061696  0.021872 -0.044223 -0.034821 -0.043401 -0.002592  0.019908 -0.017646   151.0
1 -0.001882 -0.044642 -0.051474 -0.026328 -0.008449 -0.019163  0.074412 -0.039493 -0.068330 -0.092204    75.0
2  0.085299  0.050680  0.044451 -0.005671 -0.045599 -0.034194 -0.032356 -0.002592  0.002864 -0.025930   141.0
3 -0.089063 -0.044642 -0.011595 -0.036656  0.012191  0.024991 -0.036038  0.034309  0.022692 -0.009362   206.0
4  0.005383 -0.044642 -0.036385  0.021872  0.003935  0.015596  0.008142 -0.002592 -0.031991 -0.046641   135.0

A pair plot shows the pairwise relationships and distributions across all features, giving a quick sense of which variables correlate most strongly with the target:

PYTHON
sns.pairplot(df)
plt.show()

Seaborn pairplot showing pairwise relationships between all Diabetes dataset features

Training the Decision Tree Regressor

Split the data into 80 % training and 20 % test sets, then fit the regressor and generate predictions:

PYTHON
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, random_state = 42)
PYTHON
regressor = DecisionTreeRegressor(random_state=42)
regressor.fit(X_train, y_train)
y_pred = regressor.predict(X_test)

Plot the predicted values against the true test values to see how closely the regressor tracks actual disease progression scores:

PYTHON
# The x-axis is simply the position of each sample in the test set
plt.figure(figsize=(16, 4))
plt.plot(y_pred, label='Predicted (y_pred)')
plt.plot(y_test, label='Actual (y_test)')
plt.xlabel('Test sample index', fontsize=14)
plt.ylabel('Disease progression', fontsize=14)
plt.title('Predicted vs. actual values on the test set')
plt.legend()
plt.show()

Calculate the Root Mean Squared Error (RMSE) to measure average prediction error in the same units as the target:

PYTHON
np.sqrt(metrics.mean_squared_error(y_test, y_pred))
OUTPUT
70.61829663921893

Compare the RMSE against the standard deviation of the test targets — an RMSE close to the standard deviation suggests the model is barely better than predicting the mean:

PYTHON
y_test.std()
OUTPUT
72.78840394263774

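An RMSE of 70.6 against a standard deviation of 72.8 means the unconstrained tree is barely better than always predicting the mean. Overfitting is the likely cause: a tree grown to full depth can fit the training data almost exactly, yet those memorised rules transfer poorly to unseen patients. Limiting max_depth is the usual first remedy and is likely to lower the test RMSE.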
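
To check the overfitting explanation, compare training and test RMSE at several depths. The sketch below repeats the same 80/20 split and adds a predict-the-mean baseline. The depth values, the rmse helper, and the variable names are choices made for this example, and the printed numbers may vary with the scikit-learn version, so none are quoted here.

```python
import numpy as np
from sklearn.datasets import load_diabetes
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

X_d, y_d = load_diabetes(return_X_y=True)
Xd_train, Xd_test, yd_train, yd_test = train_test_split(
    X_d, y_d, test_size=0.2, random_state=42
)


def rmse(y_true, y_hat):
    """Root mean squared error, in the same units as the target."""
    return float(np.sqrt(mean_squared_error(y_true, y_hat)))


# Baseline: always predict the mean of the training targets.
baseline_rmse = rmse(yd_test, np.full_like(yd_test, yd_train.mean()))
print(f"baseline (predict the mean): test RMSE = {baseline_rmse:.1f}")

depth_results = {}
for depth in (2, 3, 4, 6, None):
    model = DecisionTreeRegressor(max_depth=depth, random_state=42)
    model.fit(Xd_train, yd_train)
    depth_results[depth] = (
        rmse(yd_train, model.predict(Xd_train)),
        rmse(yd_test, model.predict(Xd_test)),
    )
    train_err, test_err = depth_results[depth]
    print(f"max_depth={depth}: train RMSE = {train_err:.1f}, test RMSE = {test_err:.1f}")
```

A large gap between training and test RMSE for the unconstrained tree, together with a smaller gap for shallow trees, is the signature of overfitting.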

Classification: Iris Flower Recognition

The Iris dataset contains 150 samples across three flower species — setosa, versicolor, and virginica — each described by four measurements: sepal length, sepal width, petal length, and petal width. The goal is to classify each sample into the correct species.

PYTHON
from sklearn.tree import DecisionTreeClassifier

Load the dataset and inspect the class names and feature names:

PYTHON
iris = datasets.load_iris()
iris.target_names
OUTPUT
array(['setosa', 'versicolor', 'virginica'], dtype='<U10')
PYTHON
iris.feature_names
OUTPUT
['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']

Build a DataFrame and preview the first five rows:

PYTHON
X = iris.data
y = iris.target
df = pd.DataFrame(X, columns=iris.feature_names)
df['target'] = y
df.head()
OUTPUT
   sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)  target
0                5.1               3.5                1.4               0.2       0
1                4.9               3.0                1.4               0.2       0
2                4.7               3.2                1.3               0.2       0
3                4.6               3.1                1.5               0.2       0
4                5.0               3.6                1.4               0.2       0

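The pair plot below, coloured by species, shows that petal length and petal width separate the classes most cleanly, with setosa visually distinct from the other two species:

PYTHON
# hue colours each point by its class label so the species can be told apart
sns.pairplot(df, hue='target')
plt.show()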

Train the classifier using Gini impurity on a stratified split, then evaluate accuracy on the test set:

PYTHON
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 1, test_size = 0.2, stratify = y)
clf = DecisionTreeClassifier(criterion='gini', random_state=1)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
print('Accuracy: ', metrics.accuracy_score(y_test, y_pred))
OUTPUT
Accuracy:  0.9666666666666667

The classifier achieves 96.7 % accuracy — only one misclassification across 30 test samples. Visualise the full breakdown with a confusion matrix, where each cell shows the count of actual vs. predicted class pairs:

PYTHON
# mlxtend is a separate package (pip install mlxtend)
from mlxtend.evaluate import confusion_matrix
from mlxtend.plotting import plot_confusion_matrix
print('Confusion Matrix')
cm = confusion_matrix(y_test, y_pred)
fig, ax = plot_confusion_matrix(conf_mat=cm)
plt.title('Actual class vs. predicted class (counts)')
plt.show()
OUTPUT
Confusion Matrix
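
If mlxtend is not installed, scikit-learn can produce the same counts. The sketch below is an alternative to the mlxtend call, not the method used above: it repeats the stratified split and classifier settings from this section under separate variable names, then prints the matrix as a NumPy array. Rows are actual classes and columns are predicted classes.

```python
from sklearn.datasets import load_iris
from sklearn.metrics import confusion_matrix as sk_confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris_demo = load_iris()
Xi_train, Xi_test, yi_train, yi_test = train_test_split(
    iris_demo.data,
    iris_demo.target,
    test_size=0.2,
    random_state=1,
    stratify=iris_demo.target,
)

clf_demo = DecisionTreeClassifier(criterion="gini", random_state=1)
clf_demo.fit(Xi_train, yi_train)

# Rows: actual class; columns: predicted class.
cm_demo = sk_confusion_matrix(yi_test, clf_demo.predict(Xi_test))
print(iris_demo.target_names)
print(cm_demo)
```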

Classification Report

The classification report breaks down precision, recall, and F1-score per class alongside overall accuracy. Print it to see exactly where the classifier succeeds and where it struggles:

PYTHON
print(metrics.classification_report(y_test, y_pred))
OUTPUT
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        10
           1       0.91      1.00      0.95        10
           2       1.00      0.90      0.95        10

    accuracy                           0.97        30
   macro avg       0.97      0.97      0.97        30
weighted avg       0.97      0.97      0.97        30

Setosa (class 0) is classified perfectly. The single error is a virginica sample (class 2, recall 0.90) predicted as versicolor (class 1, precision 0.91). The petal dimensions of these two species overlap slightly, a well-known characteristic of this dataset.
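
The figures in the report follow directly from the confusion matrix. The sketch below uses a matrix reconstructed from the report above (an assumption: ten samples per class, with one class-2 sample predicted as class 1) and derives precision, recall, F1, and accuracy with NumPy.

```python
import numpy as np

# Reconstructed from the classification report; rows are actual, columns predicted.
cm_report = np.array([
    [10, 0, 0],
    [0, 10, 0],
    [0, 1, 9],
])

true_positives = np.diag(cm_report)
precision = true_positives / cm_report.sum(axis=0)  # correct / all predicted as the class
recall = true_positives / cm_report.sum(axis=1)     # correct / all actual members of the class
f1 = 2 * precision * recall / (precision + recall)
accuracy = true_positives.sum() / cm_report.sum()

for label in range(3):
    print(f"class {label}: precision={precision[label]:.2f}, "
          f"recall={recall[label]:.2f}, f1={f1[label]:.2f}")
print(f"accuracy={accuracy:.2f}")
```

Class 1 loses precision because it receives one prediction it should not have (10 correct out of 11), while class 2 loses recall because one of its ten samples is missed.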

Conclusion

In this tutorial you trained a Decision Tree Regressor on the Diabetes dataset and a Decision Tree Classifier on the Iris dataset using scikit-learn. The classifier achieved 96.7 % accuracy on a 30-sample test set, while the unconstrained regressor produced an RMSE nearly equal to the target's standard deviation, which illustrates the cost of overfitting.

Key takeaways:

  • Decision trees split data by choosing the feature and threshold that maximises information gain or minimises Gini impurity at each node.
  • An unpruned tree memorises training data — always constrain growth with max_depth or cost-complexity pruning via ccp_alpha.
  • Trees are fully interpretable: every prediction traces back to a readable sequence of feature conditions.
  • The criterion parameter ('gini' vs 'entropy') rarely changes accuracy significantly — 'gini' is the faster default.
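
The interpretability point can be seen by printing a tree's rules as text. The sketch below fits a deliberately shallow classifier on Iris (max_depth=2 is an arbitrary choice to keep the output short) and uses sklearn.tree.export_text to list the learned conditions. The variable names are specific to this example.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris_rules = load_iris()

# A shallow tree keeps the printed rule set small enough to read at a glance.
small_tree = DecisionTreeClassifier(max_depth=2, random_state=1)
small_tree.fit(iris_rules.data, iris_rules.target)

rules = export_text(small_tree, feature_names=list(iris_rules.feature_names))
print(rules)
```

Each indented line is one condition on a feature, and each "class:" line is a leaf, so any prediction can be traced from the top of the printout to the leaf that produced it.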

Next steps:

  • Explore Random Forest to see how an ensemble of trees overcomes the single-tree overfitting problem.
  • Read Ensemble Learning for a broader look at bagging and boosting techniques.
  • Tune max_depth on the Diabetes regressor and observe how RMSE changes with tree depth — then use sklearn.tree.plot_tree() to inspect the learned rules.
