Decision Tree Practice with Python

Chapter 7: Tree-Based Models I — Decision Trees

Learning objectives

  • Build a decision tree classifier using scikit-learn on well-log data
  • Visualise the tree structure with plot_tree and export_text
  • Extract and interpret feature importance
  • Demonstrate overfitting with deep trees vs shallow trees
  • Apply pruning with max_depth and cost-complexity pruning (ccp_alpha)
  • Use cross-validation to find the optimal tree depth

Decision Trees in Python — Complete Workflow

In this lab, we build decision trees for lithology classification from well-log data, explore how tree depth affects overfitting, visualise the learned rules, and use cross-validation to find the optimal pruning parameters.

Step 1: Preparing Well-Log Data

We generate synthetic well-log data with four common measurements: gamma ray (GR), resistivity, bulk density, and neutron porosity. Each measurement responds differently to rock types, making them useful features for classification.

%%%python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier, plot_tree, export_text
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import (accuracy_score, classification_report,
                             confusion_matrix)

# Generate synthetic well-log data for 3 lithologies
np.random.seed(42)
n = 150  # total samples: 60 shale + 50 sandstone + 40 limestone

# Gamma ray (API): shale=high, sandstone=moderate, limestone=low
gr = np.concatenate([np.random.normal(85, 12, 60),
                     np.random.normal(35, 10, 50),
                     np.random.normal(20, 7, 40)])

# Resistivity (ohm-m): shale=low, sandstone=moderate, limestone=high
res = np.concatenate([np.random.normal(8, 3, 60),
                      np.random.normal(60, 20, 50),
                      np.random.normal(300, 80, 40)])

# Bulk density (g/cc)
den = np.concatenate([np.random.normal(2.55, 0.05, 60),
                      np.random.normal(2.30, 0.08, 50),
                      np.random.normal(2.70, 0.04, 40)])

# Neutron porosity (%)
nphi = np.concatenate([np.random.normal(35, 5, 60),
                       np.random.normal(20, 6, 50),
                       np.random.normal(3, 2, 40)])

X = np.column_stack([gr, res, den, nphi])
y = np.array([0]*60 + [1]*50 + [2]*40)
feature_names = ["Gamma Ray", "Resistivity", "Density", "Neutron Porosity"]
class_names = ["Shale", "Sandstone", "Limestone"]

# Count samples per class from the labels rather than hard-coding them
counts = np.bincount(y)
print(f"Dataset: {X.shape[0]} samples, {X.shape[1]} features")
print(f"Classes: Shale={counts[0]}, Sandstone={counts[1]}, Limestone={counts[2]}")

# Train/test split (stratified, so class proportions are preserved)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)
print(f"Train: {X_train.shape[0]}, Test: {X_test.shape[0]}")
%%%

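Note: decision trees do not require feature scaling. Each split compares a single feature with a threshold. Any strictly increasing rescaling of a feature, such as standardisation, min-max scaling, or a change of units, moves the thresholds but leaves the ordering of samples unchanged, so the possible partitions stay the same. This is a key advantage over distance-based methods such as KNN and over neural networks, whose gradient-based training is sensitive to feature scale.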
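One quick check of this claim is to fit the same tree on raw and rescaled copies of the features and compare the results. The sketch below makes two assumptions. First, it regenerates stand-in well-log data using the Step 1 means and spreads, but with different random draws, so it runs on its own. Second, it uses illustrative per-column factors `a`. These are powers of two, which keeps the rescaling exact in floating point.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Stand-in logs: Step 1 means/spreads, but different random draws (assumption).
# Columns: GR (API), resistivity (ohm-m), density (g/cc), NPHI (%)
rng = np.random.default_rng(42)
means = np.array([[85, 8, 2.55, 35], [35, 60, 2.30, 20], [20, 300, 2.70, 3]])
stds = np.array([[12, 3, 0.05, 5], [10, 20, 0.08, 6], [7, 80, 0.04, 2]])
counts = [60, 50, 40]  # shale, sandstone, limestone
X = np.vstack([rng.normal(m, s, size=(k, 4))
               for m, s, k in zip(means, stds, counts)])
y = np.repeat([0, 1, 2], counts)

# Positive per-column rescaling x' = a * x (illustrative factors).
# Powers of two are exact in floating point, so no rounding noise creeps in.
a = np.array([2.0, 8.0, 1024.0, 4.0])
X_rescaled = X * a

raw = DecisionTreeClassifier(random_state=42).fit(X, y)
rescaled = DecisionTreeClassifier(random_state=42).fit(X_rescaled, y)

# 1) Same split feature at every node (leaves carry a negative feature id)
same_structure = np.array_equal(raw.tree_.feature, rescaled.tree_.feature)

# 2) Internal-node thresholds are rescaled by the same factor as their feature
internal = raw.tree_.feature >= 0
f = raw.tree_.feature[internal]
thresholds_match = same_structure and np.allclose(
    a[f] * raw.tree_.threshold[internal], rescaled.tree_.threshold[internal])

# 3) Identical predictions on the correspondingly rescaled training data
same_predictions = np.array_equal(raw.predict(X), rescaled.predict(X_rescaled))
print(same_structure, thresholds_match, same_predictions)
```

Standardisation, or a unit change that includes an offset, should behave the same way up to floating-point rounding of the thresholds.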

Step 2: Building an Unpruned Tree

First, let us build an unpruned tree (no depth limit) to see how deep it grows and whether it overfits:

%%%python
# Build unpruned tree
tree_full = DecisionTreeClassifier(random_state=42)
tree_full.fit(X_train, y_train)

train_acc = tree_full.score(X_train, y_train)
test_acc = tree_full.score(X_test, y_test)

print(f"Unpruned tree:")
print(f" Train accuracy: {train_acc:.3f}")
print(f" Test accuracy: {test_acc:.3f}")
print(f" Tree depth: {tree_full.get_depth()}")
print(f" Number of leaves: {tree_full.get_n_leaves()}")
%%%

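With no depth limit, the tree keeps splitting until every leaf is pure, unless identical feature values carry different labels. Training accuracy is therefore typically 1.000. If test accuracy is noticeably lower, that is the classic sign of overfitting: part of what the tree learned is noise specific to the training samples.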
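The purity claim can be checked directly from the fitted tree's internal arrays. The sketch below assumes stand-in data regenerated with the Step 1 means and spreads (different random draws). It uses `tree_.children_left` to find the leaves and `tree_.impurity` to read their Gini impurity.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Stand-in logs: Step 1 means/spreads, but different random draws (assumption).
rng = np.random.default_rng(42)
means = np.array([[85, 8, 2.55, 35], [35, 60, 2.30, 20], [20, 300, 2.70, 3]])
stds = np.array([[12, 3, 0.05, 5], [10, 20, 0.08, 6], [7, 80, 0.04, 2]])
counts = [60, 50, 40]  # shale, sandstone, limestone
X = np.vstack([rng.normal(m, s, size=(k, 4))
               for m, s, k in zip(means, stds, counts)])
y = np.repeat([0, 1, 2], counts)

# Default settings: no max_depth, min_samples_leaf=1
tree = DecisionTreeClassifier(random_state=42).fit(X, y)

t = tree.tree_
is_leaf = t.children_left == -1            # leaves have no children
max_leaf_gini = t.impurity[is_leaf].max()  # 0.0 when every leaf is pure

print(f"Leaves: {is_leaf.sum()} (get_n_leaves: {tree.get_n_leaves()})")
print(f"Largest leaf Gini impurity: {max_leaf_gini:.3f}")
print(f"Training accuracy: {tree.score(X, y):.3f}")
```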

Step 3: Visualising the Tree

One of the biggest advantages of decision trees is interpretability. We can visualise the entire learned model:

%%%python
# Plot the top 3 levels of the tree
plt.figure(figsize=(20, 10))
plot_tree(tree_full, feature_names=feature_names, class_names=class_names,
          filled=True, rounded=True, fontsize=10, max_depth=3)
plt.title("Decision Tree for Lithology Classification (top 3 levels)")
plt.tight_layout()
plt.show()
%%%

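Each internal node shows five things: the split condition (e.g., "Gamma Ray <= 55.5"), the Gini impurity, the number of training samples reaching the node, the class distribution (value), and the majority class. Leaves show the same information without a split condition. The fill colour indicates the majority class, and its intensity reflects how pure the node is.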
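The Gini impurity shown in each box is 1 - Σ p_k², where p_k is the fraction of the node's samples in class k. The sketch below recomputes it by hand for every node and compares the result with `tree_.impurity`. It assumes stand-in data with the Step 1 means and spreads, and uses a hypothetical helper `gini`. `decision_path` is used to find which samples reach each node. For the root, with 60/50/40 samples, the arithmetic is 1 - (0.4² + (1/3)² + (4/15)²) ≈ 0.658.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

def gini(class_counts):
    """Hypothetical helper: Gini impurity 1 - sum_k p_k**2 from class counts."""
    p = np.asarray(class_counts, dtype=float) / np.sum(class_counts)
    return 1.0 - np.sum(p ** 2)

# Stand-in logs: Step 1 means/spreads, but different random draws (assumption).
rng = np.random.default_rng(42)
means = np.array([[85, 8, 2.55, 35], [35, 60, 2.30, 20], [20, 300, 2.70, 3]])
stds = np.array([[12, 3, 0.05, 5], [10, 20, 0.08, 6], [7, 80, 0.04, 2]])
counts = [60, 50, 40]  # shale, sandstone, limestone
X = np.vstack([rng.normal(m, s, size=(k, 4))
               for m, s, k in zip(means, stds, counts)])
y = np.repeat([0, 1, 2], counts)

tree = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X, y)

# Root node: all 150 samples
print(f"Root Gini by hand:   {gini(np.bincount(y)):.3f}")  # 0.658
print(f"Root Gini (sklearn): {tree.tree_.impurity[0]:.3f}")

# Every node: boolean matrix (n_samples, n_nodes) of which samples pass through it
in_node = tree.decision_path(X).toarray().astype(bool)
by_hand = np.array([gini(np.bincount(y[in_node[:, j]], minlength=3))
                    for j in range(tree.tree_.node_count)])
print("Max |difference| over all nodes:",
      np.abs(by_hand - tree.tree_.impurity).max())
```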

%%%python
# Print the tree as human-readable text rules
tree_rules = export_text(tree_full, feature_names=feature_names, max_depth=4)
print(tree_rules)
%%%

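These text rules can be shared with geologists who do not use Python, who can then check that the splits make geological sense. For example, a rule such as "Gamma Ray > 55 leads to shale" agrees with the known high natural radioactivity of clay-rich shales. Because export_text is called here without class names, leaves are labelled with integer codes (class: 0 = Shale, 1 = Sandstone, 2 = Limestone).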

Step 4: Feature Importance

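Decision trees compute feature importance as a by-product of training. In scikit-learn, feature_importances_ is the mean decrease in impurity (MDI). Each split credits its impurity reduction, weighted by the number of training samples reaching that node, to the feature it splits on. The per-feature totals are then normalised to sum to 1.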

%%%python
importances = tree_full.feature_importances_

# Bar chart
plt.figure(figsize=(8, 5))
bars = plt.barh(feature_names, importances, color="steelblue")
plt.xlabel("Importance (Gini reduction)")
plt.title("Feature Importance — Decision Tree")
for bar, imp in zip(bars, importances):
    plt.text(bar.get_width() + 0.01, bar.get_y() + bar.get_height()/2,
             f"{imp:.3f}", va="center")
plt.tight_layout()
plt.show()

# Print sorted importances
print("Feature importance ranking:")
for name, imp in sorted(zip(feature_names, importances),
                        key=lambda x: x[1], reverse=True):
    print(f" {name:>20}: {imp:.4f}")
%%%

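In real well-log data, gamma ray and neutron porosity are often among the most discriminating features for separating shale, sandstone, and limestone. GR responds to natural radioactivity, which in these rocks is largely carried by clay minerals. NPHI responds to hydrogen, both in pore fluids and in clay-bound water. In this synthetic dataset, however, resistivity is also very well separated between the three classes, so check which features your tree actually ranks highest rather than assuming. Also keep in mind that impurity-based importances from a single tree can be unstable. When two features separate the classes almost equally well, the tree splits on one of them, and the other may receive little or no importance.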
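To make the MDI definition concrete, the sketch below recomputes feature_importances_ from the fitted tree's arrays and compares the result with scikit-learn's value. It reads `tree_.impurity`, `tree_.weighted_n_node_samples`, `tree_.feature`, and `tree_.children_left` / `tree_.children_right`. The helper name `mdi_importances` is hypothetical, and the data is a stand-in regenerated with the Step 1 means and spreads.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Stand-in logs: Step 1 means/spreads, but different random draws (assumption).
rng = np.random.default_rng(42)
means = np.array([[85, 8, 2.55, 35], [35, 60, 2.30, 20], [20, 300, 2.70, 3]])
stds = np.array([[12, 3, 0.05, 5], [10, 20, 0.08, 6], [7, 80, 0.04, 2]])
counts = [60, 50, 40]  # shale, sandstone, limestone
X = np.vstack([rng.normal(m, s, size=(k, 4))
               for m, s, k in zip(means, stds, counts)])
y = np.repeat([0, 1, 2], counts)

tree = DecisionTreeClassifier(random_state=42).fit(X, y)

def mdi_importances(fitted_tree):
    """Hypothetical helper: mean decrease in impurity per feature, normalised to 1."""
    t = fitted_tree.tree_
    w = t.weighted_n_node_samples  # (weighted) training samples reaching each node
    totals = np.zeros(fitted_tree.n_features_in_)
    for node in range(t.node_count):
        left, right = t.children_left[node], t.children_right[node]
        if left == -1:  # leaves do not split, so they contribute nothing
            continue
        # Sample-weighted impurity before the split minus after it
        decrease = (w[node] * t.impurity[node]
                    - w[left] * t.impurity[left]
                    - w[right] * t.impurity[right])
        totals[t.feature[node]] += decrease
    return totals / totals.sum()

manual = mdi_importances(tree)
names = ["Gamma Ray", "Resistivity", "Density", "Neutron Porosity"]
for name, m, s in zip(names, manual, tree.feature_importances_):
    print(f"{name:>18}: by hand {m:.4f}   sklearn {s:.4f}")
```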

Step 5: Overfitting Demonstration — Deep vs Shallow Trees

Let us systematically compare trees with different max_depth values to see the overfitting effect:

%%%python
# Use single quotes inside the f-string: reusing double quotes is a
# SyntaxError before Python 3.12
print(f"{'Depth':<8} {'Train':>8} {'Test':>8} "
      f"{'Leaves':>8} {'Gap':>8}")
print("-" * 44)

for depth in [1, 2, 3, 4, 5, 7, 10, None]:
    dt = DecisionTreeClassifier(max_depth=depth, random_state=42)
    dt.fit(X_train, y_train)
    tr = dt.score(X_train, y_train)
    te = dt.score(X_test, y_test)
    gap = tr - te  # train minus test accuracy
    d_str = str(depth) if depth else "None"
    print(f"{d_str:<8} {tr:>8.3f} {te:>8.3f} "
          f"{dt.get_n_leaves():>8} {gap:>8.3f}")
%%%

Observations:

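  • Very shallow trees have high bias. A depth-1 tree has only two leaves, so it can predict at most two of the three lithologies, which caps both train and test accuracy. Whether depth 2 is already enough depends on how cleanly the classes separate.
  • Moderate depths (roughly 3-5) often give the best test accuracy: enough complexity to capture real patterns without memorising noise.
  • Deep trees (10+, None) have low bias but high variance: training accuracy approaches 1.000 while test accuracy stalls or drops.
  • The "Gap" column (train minus test accuracy) quantifies overfitting: larger gaps mean more overfitting. With only 30 test samples, one extra misclassification changes test accuracy by about 0.033, so small differences between depths should not be over-interpreted.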
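The capacity limit of a depth-1 tree (a "stump") is simple arithmetic. With two leaves, at best the two largest classes are classified perfectly, giving (60 + 50) / 150 ≈ 0.733 on the full 150-sample set. The sketch below checks this on stand-in data regenerated with the Step 1 means and spreads (an assumption; the draws differ from Step 1).

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Stand-in logs: Step 1 means/spreads, but different random draws (assumption).
rng = np.random.default_rng(42)
means = np.array([[85, 8, 2.55, 35], [35, 60, 2.30, 20], [20, 300, 2.70, 3]])
stds = np.array([[12, 3, 0.05, 5], [10, 20, 0.08, 6], [7, 80, 0.04, 2]])
counts = [60, 50, 40]  # shale, sandstone, limestone
X = np.vstack([rng.normal(m, s, size=(k, 4))
               for m, s, k in zip(means, stds, counts)])
y = np.repeat([0, 1, 2], counts)

stump = DecisionTreeClassifier(max_depth=1, random_state=42).fit(X, y)
predicted_classes = np.unique(stump.predict(X))

# Best case for a 2-leaf tree: the two largest classes (60 + 50) are all correct
ceiling = (60 + 50) / 150
print(f"Leaves: {stump.get_n_leaves()}, classes predicted: {predicted_classes}")
print(f"Train accuracy: {stump.score(X, y):.3f} (ceiling {ceiling:.3f})")  # ceiling 0.733
```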

Step 6: Visualising Overfitting

%%%python
depths = range(1, 16)
train_accs, test_accs = [], []
for d in depths:
    dt = DecisionTreeClassifier(max_depth=d, random_state=42)
    dt.fit(X_train, y_train)
    train_accs.append(dt.score(X_train, y_train))
    test_accs.append(dt.score(X_test, y_test))

plt.figure(figsize=(9, 5))
plt.plot(depths, train_accs, "o-", label="Train", linewidth=2)
plt.plot(depths, test_accs, "s-", label="Test", linewidth=2)
plt.xlabel("Max Depth")
plt.ylabel("Accuracy")
plt.title("Decision Tree: Accuracy vs Max Depth")
plt.legend()
plt.xticks(list(depths))
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()

best_depth = list(depths)[np.argmax(test_accs)]
print(f"Best max_depth = {best_depth} "
      f"(test acc = {max(test_accs):.3f})")
%%%

Step 7: Cost-Complexity Pruning (ccp_alpha)

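An alternative to pre-pruning with max_depth is post-pruning: grow the full tree, then cut it back using the cost-complexity parameter α (ccp_alpha). Among the subtrees T of the full tree, pruning keeps the one that minimises R_α(T) = R(T) + α|T|. Here R(T) is the total impurity of the leaves, each weighted by its fraction of the training samples, and |T| is the number of leaves. Larger α therefore penalises leaves more heavily and produces simpler trees; α = 0 keeps the unpruned tree.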

%%%python
# Get the pruning path
path = tree_full.cost_complexity_pruning_path(X_train, y_train)
alphas = path.ccp_alphas
print(f"Number of alpha values to test: {len(alphas)}")
print(f"Alpha range: [{alphas.min():.6f}, {alphas.max():.6f}]")

train_scores, test_scores = [], []
for alpha in alphas:
    dt = DecisionTreeClassifier(ccp_alpha=alpha, random_state=42)
    dt.fit(X_train, y_train)
    train_scores.append(dt.score(X_train, y_train))
    test_scores.append(dt.score(X_test, y_test))

plt.figure(figsize=(9, 5))
plt.plot(alphas, train_scores, "o-", label="Train", markersize=4, linewidth=1.5)
plt.plot(alphas, test_scores, "s-", label="Test", markersize=4, linewidth=1.5)
plt.xlabel("Alpha (complexity parameter)")
plt.ylabel("Accuracy")
plt.title("Cost-Complexity Pruning Path")
plt.legend()
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()

best_alpha = alphas[np.argmax(test_scores)]
print(f"Best alpha = {best_alpha:.6f} "
      f"(test acc = {max(test_scores):.3f})")
%%%
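The alphas in the pruning path come from weakest-link pruning. For every internal node t, the effective alpha is α_eff(t) = (R(t) - R(T_t)) / (|T_t| - 1). Here R(T_t) is the summed leaf impurity of the branch under t and |T_t| is its number of leaves. The node with the smallest α_eff is collapsed first. The sketch below computes α_eff by hand on stand-in data (Step 1 means and spreads, different draws) and compares the smallest value with the second entry of ccp_alphas. The first entry is 0, which corresponds to the unpruned tree. The helper `branch_stats` is hypothetical.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Stand-in logs: Step 1 means/spreads, but different random draws (assumption).
rng = np.random.default_rng(42)
means = np.array([[85, 8, 2.55, 35], [35, 60, 2.30, 20], [20, 300, 2.70, 3]])
stds = np.array([[12, 3, 0.05, 5], [10, 20, 0.08, 6], [7, 80, 0.04, 2]])
counts = [60, 50, 40]  # shale, sandstone, limestone
X = np.vstack([rng.normal(m, s, size=(k, 4))
               for m, s, k in zip(means, stds, counts)])
y = np.repeat([0, 1, 2], counts)

full = DecisionTreeClassifier(random_state=42).fit(X, y)
t = full.tree_
w = t.weighted_n_node_samples
R_node = w * t.impurity / w[0]  # R(t): node impurity weighted by its sample fraction

def branch_stats(node):
    """Hypothetical helper: (R(T_t), |T_t|) = summed leaf R and leaf count under node."""
    left, right = t.children_left[node], t.children_right[node]
    if left == -1:  # leaf
        return R_node[node], 1
    r_left, n_left = branch_stats(left)
    r_right, n_right = branch_stats(right)
    return r_left + r_right, n_left + n_right

# Effective alpha of each internal node: where pruning its branch breaks even
alpha_eff = {}
for node in range(t.node_count):
    if t.children_left[node] != -1:
        R_branch, n_leaves = branch_stats(node)
        alpha_eff[node] = (R_node[node] - R_branch) / (n_leaves - 1)

weakest = min(alpha_eff, key=alpha_eff.get)
path = full.cost_complexity_pruning_path(X, y)
print(f"Weakest link: node {weakest}, alpha_eff = {alpha_eff[weakest]:.6f}")
print(f"Second entry of ccp_alphas: {path.ccp_alphas[1]:.6f}")

# The largest alpha on the path prunes the tree all the way back to its root
pruned = DecisionTreeClassifier(random_state=42,
                                ccp_alpha=path.ccp_alphas[-1]).fit(X, y)
print("Nodes left at the largest alpha:", pruned.tree_.node_count)
```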

Step 8: Cross-Validation for Optimal Depth

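Choosing max_depth or ccp_alpha by test accuracy, as in Steps 6 and 7, tunes the model to the 30 test samples and makes the reported test score optimistic. Instead, use 5-fold cross-validation on the training set to pick the depth, and keep the test set for the final evaluation in Step 9: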

%%%python
depths = range(1, 16)
cv_means, cv_stds = [], []

for d in depths:
    dt = DecisionTreeClassifier(max_depth=d, random_state=42)
    scores = cross_val_score(dt, X_train, y_train, cv=5, scoring="accuracy")
    cv_means.append(scores.mean())
    cv_stds.append(scores.std())

cv_means = np.array(cv_means)
cv_stds = np.array(cv_stds)

plt.figure(figsize=(9, 5))
plt.plot(depths, cv_means, "o-", linewidth=2)
plt.fill_between(depths, cv_means - cv_stds, cv_means + cv_stds, alpha=0.2)
plt.xlabel("Max Depth")
plt.ylabel("5-Fold CV Accuracy")
plt.title("Cross-Validation: Optimal Tree Depth")
plt.xticks(list(depths))
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()

best_cv_depth = list(depths)[np.argmax(cv_means)]
print(f"Best depth by CV = {best_cv_depth} "
      f"(mean acc = {cv_means.max():.3f} "
      f"+/- {cv_stds[np.argmax(cv_means)]:.3f})")
%%%
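The manual cross-validation loop above can also be expressed with scikit-learn's `GridSearchCV`. With `cv=5` on a classifier, both use the same unshuffled stratified folds, so they should produce the same mean scores. The sketch below checks this on stand-in data (Step 1 means and spreads, different draws; an assumption) split the same way as in Step 1.

```python
import numpy as np
from sklearn.model_selection import GridSearchCV, cross_val_score, train_test_split
from sklearn.tree import DecisionTreeClassifier

# Stand-in logs: Step 1 means/spreads, but different random draws (assumption).
rng = np.random.default_rng(42)
means = np.array([[85, 8, 2.55, 35], [35, 60, 2.30, 20], [20, 300, 2.70, 3]])
stds = np.array([[12, 3, 0.05, 5], [10, 20, 0.08, 6], [7, 80, 0.04, 2]])
counts = [60, 50, 40]  # shale, sandstone, limestone
X = np.vstack([rng.normal(m, s, size=(k, 4))
               for m, s, k in zip(means, stds, counts)])
y = np.repeat([0, 1, 2], counts)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y)

depths = list(range(1, 16))

# Manual loop: mean 5-fold CV accuracy for each depth
manual_means = np.array([
    cross_val_score(DecisionTreeClassifier(max_depth=d, random_state=42),
                    X_train, y_train, cv=5, scoring="accuracy").mean()
    for d in depths])

# Same search with GridSearchCV
grid = GridSearchCV(DecisionTreeClassifier(random_state=42),
                    param_grid={"max_depth": depths}, cv=5, scoring="accuracy")
grid.fit(X_train, y_train)

print("Best max_depth:", grid.best_params_["max_depth"])
print(f"Best mean CV accuracy: {grid.best_score_:.3f}")
print(f"Test accuracy of refitted model: {grid.score(X_test, y_test):.3f}")
```

With the default `refit=True`, GridSearchCV retrains the best configuration on the whole training set. `grid.score(X_test, y_test)` therefore plays the same role as the final evaluation below.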

Step 9: Final Model and Evaluation

%%%python
# Build the final pruned tree
tree_final = DecisionTreeClassifier(
    max_depth=best_cv_depth, random_state=42
)
tree_final.fit(X_train, y_train)
y_pred = tree_final.predict(X_test)

print(f"Final tree (depth={best_cv_depth}):")
print(f" Accuracy: {accuracy_score(y_test, y_pred):.3f}")
print(f" Leaves: {tree_final.get_n_leaves()}")
print()
print(classification_report(y_test, y_pred, target_names=class_names))

# Confusion matrix
cm = confusion_matrix(y_test, y_pred)
fig, ax = plt.subplots(figsize=(6, 5))
im = ax.imshow(cm, cmap="Blues")
plt.colorbar(im)
ax.set_xticks([0, 1, 2])
ax.set_xticklabels(class_names)
ax.set_yticks([0, 1, 2])
ax.set_yticklabels(class_names)
ax.set_xlabel("Predicted")
ax.set_ylabel("Actual")
ax.set_title(f"Decision Tree Confusion Matrix (depth={best_cv_depth})")
for i in range(3):
    for j in range(3):
        color = "white" if cm[i, j] > cm.max() / 2 else "black"
        ax.text(j, i, str(cm[i, j]), ha="center", va="center",
                color=color, fontsize=14)
plt.tight_layout()
plt.show()
%%%

Step 10: Comparing the Final Tree with the Unpruned Tree

%%%python
# Side-by-side comparison
fig, axes = plt.subplots(1, 2, figsize=(24, 8))

# Pruned tree
plot_tree(tree_final, feature_names=feature_names, class_names=class_names,
          filled=True, rounded=True, fontsize=9, ax=axes[0])
axes[0].set_title(f"Pruned (depth={best_cv_depth}, "
                  f"leaves={tree_final.get_n_leaves()})")

# Unpruned tree (top 4 levels)
plot_tree(tree_full, feature_names=feature_names, class_names=class_names,
          filled=True, rounded=True, fontsize=8, ax=axes[1], max_depth=4)
axes[1].set_title(f"Unpruned (depth={tree_full.get_depth()}, "
                  f"leaves={tree_full.get_n_leaves()})")

plt.tight_layout()
plt.show()

print(f"Pruned: train={tree_final.score(X_train, y_train):.3f}, "
      f"test={tree_final.score(X_test, y_test):.3f}")
print(f"Unpruned: train={tree_full.score(X_train, y_train):.3f}, "
      f"test={tree_full.score(X_test, y_test):.3f}")
%%%

Step 11: Making Predictions and Explaining Them

%%%python
# Predict lithology for a new well-log reading
# GR=75 API, Res=15 ohm-m, Den=2.50 g/cc, NPHI=28%
new_sample = np.array([[75, 15, 2.50, 28]])
prediction = tree_final.predict(new_sample)
proba = tree_final.predict_proba(new_sample)

print(f"Predicted: {class_names[prediction[0]]}")
print("Probabilities:")
for name, p in zip(class_names, proba[0]):
    print(f" {name}: {p:.3f}")

# Show the decision path (node IDs from root to leaf)
path = tree_final.decision_path(new_sample)
node_ids = path.indices
# .tolist() prints plain ints (list() would show np.int32(...) under NumPy 2)
print(f"\nDecision path through nodes: {node_ids.tolist()}")
%%%

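The decision path lists the IDs of the nodes the sample passed through, from the root to its leaf. Together with each node's split feature and threshold, it gives the exact chain of tests behind the prediction, making the prediction fully explainable. This interpretability is valuable when presenting results to geologists or regulatory bodies.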
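Node IDs alone are not very readable. The sketch below turns them into the actual tests by looking up `tree_.feature` and `tree_.threshold` for each node on the path. It makes three assumptions: stand-in data with the Step 1 means and spreads, a depth-3 tree standing in for `tree_final`, and a hypothetical helper `explain_prediction`. The sample is cast to float32 because scikit-learn compares float32 feature values against the thresholds when traversing the tree.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Stand-in logs: Step 1 means/spreads, but different random draws (assumption).
rng = np.random.default_rng(42)
means = np.array([[85, 8, 2.55, 35], [35, 60, 2.30, 20], [20, 300, 2.70, 3]])
stds = np.array([[12, 3, 0.05, 5], [10, 20, 0.08, 6], [7, 80, 0.04, 2]])
counts = [60, 50, 40]  # shale, sandstone, limestone
X = np.vstack([rng.normal(m, s, size=(k, 4))
               for m, s, k in zip(means, stds, counts)])
y = np.repeat([0, 1, 2], counts)
feature_names = ["Gamma Ray", "Resistivity", "Density", "Neutron Porosity"]
class_names = ["Shale", "Sandstone", "Limestone"]

tree = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X, y)

def explain_prediction(fitted_tree, sample):
    """Hypothetical helper: node IDs and the split tests one sample passes."""
    sample = np.asarray(sample, dtype=np.float32).reshape(1, -1)
    t = fitted_tree.tree_
    node_ids = fitted_tree.decision_path(sample).indices  # root ... leaf
    tests = []
    for node in node_ids[:-1]:  # the last ID is the leaf, which has no test
        f, thr = t.feature[node], t.threshold[node]
        op = "<=" if sample[0, f] <= thr else ">"
        tests.append(f"node {node}: {feature_names[f]} = "
                     f"{sample[0, f]:.2f} {op} {thr:.2f}")
    return node_ids, tests

new_sample = [75, 15, 2.50, 28]  # GR, Res, Den, NPHI as in Step 11
node_ids, tests = explain_prediction(tree, new_sample)
for line in tests:
    print(line)
pred = tree.predict(np.array([new_sample]))[0]
print(f"leaf {node_ids[-1]} -> {class_names[pred]}")
```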

[Library: scikit-learn DecisionTreeClassifier]

