Decision Tree Practice with Python
Learning objectives
- Build a decision tree classifier using scikit-learn on well-log data
- Visualise the tree structure with plot_tree and export_text
- Extract and interpret feature importance
- Demonstrate overfitting with deep trees vs shallow trees
- Apply pruning with max_depth and cost-complexity pruning (ccp_alpha)
- Use cross-validation to find the optimal tree depth
Decision Trees in Python — Complete Workflow
In this lab, we build decision trees for lithology classification from well-log data, explore how tree depth affects overfitting, visualise the learned rules, and use cross-validation to find the optimal pruning parameters.
Step 1: Preparing Well-Log Data
We generate synthetic well-log data with four common measurements: gamma ray (GR), resistivity, bulk density, and neutron porosity. Each measurement responds differently to rock types, making them useful features for classification.
%%%python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier, plot_tree, export_text
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import (accuracy_score, classification_report,
                             confusion_matrix)

# Generate synthetic well-log data for 3 lithologies
np.random.seed(42)
n = 150

# Gamma ray: shale=high, sandstone=moderate, limestone=low
gr = np.concatenate([np.random.normal(85, 12, 60),
                     np.random.normal(35, 10, 50),
                     np.random.normal(20, 7, 40)])

# Resistivity: shale=low, sandstone=moderate, limestone=high
res = np.concatenate([np.random.normal(8, 3, 60),
                      np.random.normal(60, 20, 50),
                      np.random.normal(300, 80, 40)])

# Bulk density
den = np.concatenate([np.random.normal(2.55, 0.05, 60),
                      np.random.normal(2.30, 0.08, 50),
                      np.random.normal(2.70, 0.04, 40)])

# Neutron porosity
nphi = np.concatenate([np.random.normal(35, 5, 60),
                       np.random.normal(20, 6, 50),
                       np.random.normal(3, 2, 40)])

X = np.column_stack([gr, res, den, nphi])
y = np.array([0]*60 + [1]*50 + [2]*40)
feature_names = ["Gamma Ray", "Resistivity", "Density", "Neutron Porosity"]
class_names = ["Shale", "Sandstone", "Limestone"]

print(f"Dataset: {X.shape[0]} samples, {X.shape[1]} features")
print(f"Classes: Shale={60}, Sandstone={50}, Limestone={40}")

# Train/test split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)
print(f"Train: {X_train.shape[0]}, Test: {X_test.shape[0]}")
%%%
Note: decision trees do not require feature scaling because they split on individual features using thresholds. This is a key advantage over KNN and neural networks.
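The scaling claim can be checked directly. The sketch below is self-contained: a hypothetical helper, make_well_logs, re-creates the Step 1 data and split, and the tree is then fitted once on the raw logs and once on logs standardised with scikit-learn's StandardScaler. Standardising is an increasing affine transform of each feature, so it preserves the sample ordering the splitter searches over; both trees should find the same partitions and agree on every test prediction, barring rare floating-point ties.

%%%python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier


def make_well_logs(seed=42):
    """Hypothetical helper: rebuild the Step 1 well logs and train/test split."""
    np.random.seed(seed)
    sizes = (60, 50, 40)  # shale, sandstone, limestone

    def draw(params):
        # One normal draw per lithology, concatenated in class order
        return np.concatenate([np.random.normal(m, s, k)
                               for (m, s), k in zip(params, sizes)])

    gr = draw([(85, 12), (35, 10), (20, 7)])
    res = draw([(8, 3), (60, 20), (300, 80)])
    den = draw([(2.55, 0.05), (2.30, 0.08), (2.70, 0.04)])
    nphi = draw([(35, 5), (20, 6), (3, 2)])
    X = np.column_stack([gr, res, den, nphi])
    y = np.repeat([0, 1, 2], sizes)
    return train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)


X_train, X_test, y_train, y_test = make_well_logs()

# Standardise every log to zero mean / unit variance (fitted on train only)
scaler = StandardScaler().fit(X_train)

raw_tree = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)
scaled_tree = DecisionTreeClassifier(random_state=42).fit(
    scaler.transform(X_train), y_train)

pred_raw = raw_tree.predict(X_test)
pred_scaled = scaled_tree.predict(scaler.transform(X_test))

print("Same number of leaves:",
      raw_tree.get_n_leaves() == scaled_tree.get_n_leaves())
print("Identical test predictions:", np.array_equal(pred_raw, pred_scaled))
%%%

Only the stored thresholds differ between the two trees (they are expressed in standardised units for the second one); the splits themselves are the same.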
Step 2: Building an Unpruned Tree
First, let us build an unpruned tree (no depth limit) to see how deep it grows and whether it overfits:
%%%python
# Build unpruned tree
tree_full = DecisionTreeClassifier(random_state=42)
tree_full.fit(X_train, y_train)

train_acc = tree_full.score(X_train, y_train)
test_acc = tree_full.score(X_test, y_test)

print("Unpruned tree:")
print(f"  Train accuracy: {train_acc:.3f}")
print(f"  Test accuracy: {test_acc:.3f}")
print(f"  Tree depth: {tree_full.get_depth()}")
print(f"  Number of leaves: {tree_full.get_n_leaves()}")
%%%
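With no depth limit, the tree keeps splitting until every leaf is pure, so training accuracy reaches 1.000 (barring duplicate log readings with different labels). A noticeably lower test accuracy is the classic sign of overfitting; because the synthetic lithologies here are well separated, though, the gap may turn out small.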
Step 3: Visualising the Tree
One of the biggest advantages of decision trees is interpretability. We can visualise the entire learned model:
%%%python
# Plot the top 3 levels of the tree
plt.figure(figsize=(20, 10))
plot_tree(tree_full, feature_names=feature_names, class_names=class_names,
          filled=True, rounded=True, fontsize=10, max_depth=3)
plt.title("Decision Tree for Lithology Classification (top 3 levels)")
plt.tight_layout()
plt.show()
%%%
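Each internal node shows the split condition (e.g., "Gamma Ray <= 55.5"), the Gini impurity, the number of training samples reaching the node, the class distribution ("value"), and the majority class; leaves show the same information without a split. With filled=True, the fill colour indicates the majority class and its intensity reflects how pure the node is.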
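The Gini impurity printed at each node is $G = 1 - \sum_k p_k^2$, where $p_k$ is the fraction of the node's training samples in class $k$. After the stratified split the root holds 48 shale, 40 sandstone, and 32 limestone samples, so $G = 1 - (0.4^2 + 0.333^2 + 0.267^2) \approx 0.658$. A minimal sketch that recomputes this, and every other node's impurity, from the fitted tree's tree_.value array and compares it with tree_.impurity; make_well_logs and gini are hypothetical helpers introduced here.

%%%python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier


def make_well_logs(seed=42):
    """Hypothetical helper: rebuild the Step 1 well logs and train/test split."""
    np.random.seed(seed)
    sizes = (60, 50, 40)  # shale, sandstone, limestone

    def draw(params):
        # One normal draw per lithology, concatenated in class order
        return np.concatenate([np.random.normal(m, s, k)
                               for (m, s), k in zip(params, sizes)])

    gr = draw([(85, 12), (35, 10), (20, 7)])
    res = draw([(8, 3), (60, 20), (300, 80)])
    den = draw([(2.55, 0.05), (2.30, 0.08), (2.70, 0.04)])
    nphi = draw([(35, 5), (20, 6), (3, 2)])
    X = np.column_stack([gr, res, den, nphi])
    y = np.repeat([0, 1, 2], sizes)
    return train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)


def gini(class_weights):
    """Gini impurity 1 - sum(p_k^2) from per-class counts or fractions."""
    p = np.asarray(class_weights, dtype=float)
    p = p / p.sum()  # normalise, so counts and fractions both work
    return 1.0 - np.sum(p ** 2)


X_train, X_test, y_train, y_test = make_well_logs()
tree = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)

root_counts = np.bincount(y_train)  # [48, 40, 32] after the stratified split
print("Root class counts:", root_counts.tolist())
print(f"Hand-computed root Gini: {gini(root_counts):.4f}")
print(f"sklearn root impurity:   {tree.tree_.impurity[0]:.4f}")

# tree_.value stores per-class counts in older scikit-learn releases and
# fractions in newer ones; gini() normalises, so the check works for both
manual = np.array([gini(v[0]) for v in tree.tree_.value])
print("All nodes match:", np.allclose(manual, tree.tree_.impurity))
%%%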
%%%python
# Print the tree as human-readable text rules
tree_rules = export_text(tree_full, feature_names=feature_names, max_depth=4)
print(tree_rules)
%%%
These text rules can be shared with geologists who do not use Python. They can verify that the splits make geological sense — for example, "Gamma Ray > 55 implies shale" aligns with the known high radioactivity of clay-rich shales.
Step 4: Feature Importance
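Decision trees provide a built-in measure of how much each feature contributes to the predictions. In scikit-learn, feature_importances_ is the mean decrease in impurity: for every split on a feature, the impurity reduction is weighted by the fraction of training samples reaching that node, these contributions are summed per feature, and the totals are normalised to sum to 1.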
%%%python
importances = tree_full.feature_importances_

# Bar chart
plt.figure(figsize=(8, 5))
bars = plt.barh(feature_names, importances, color="steelblue")
plt.xlabel("Importance (Gini reduction)")
plt.title("Feature Importance — Decision Tree")
for bar, imp in zip(bars, importances):
    plt.text(bar.get_width() + 0.01, bar.get_y() + bar.get_height()/2,
             f"{imp:.3f}", va="center")
plt.tight_layout()
plt.show()

# Print sorted importances
print("Feature importance ranking:")
for name, imp in sorted(zip(feature_names, importances),
                        key=lambda x: x[1], reverse=True):
    print(f"  {name:>20}: {imp:.4f}")
%%%
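In real well-log data, gamma ray and neutron porosity are often among the most discriminating features for separating shale, sandstone, and limestone. This fits geological understanding: GR records natural radioactivity, which tracks clay content, and NPHI responds to hydrogen in pore fluids and in clay-bound water. In this synthetic dataset, however, resistivity is also very well separated between the classes, so the printed ranking may differ. A single tree also tends to credit only one of several correlated features, because once it splits on one of them the others have little impurity left to reduce.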
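The impurity-reduction definition of importance can be reproduced by hand from the fitted tree's low-level arrays (children_left, children_right, feature, impurity, weighted_n_node_samples). A minimal sketch, assuming the same hypothetical make_well_logs helper that rebuilds the Step 1 data: each internal node credits its split feature with the weighted impurity it removes, and the totals are normalised.

%%%python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier


def make_well_logs(seed=42):
    """Hypothetical helper: rebuild the Step 1 well logs and train/test split."""
    np.random.seed(seed)
    sizes = (60, 50, 40)  # shale, sandstone, limestone

    def draw(params):
        # One normal draw per lithology, concatenated in class order
        return np.concatenate([np.random.normal(m, s, k)
                               for (m, s), k in zip(params, sizes)])

    gr = draw([(85, 12), (35, 10), (20, 7)])
    res = draw([(8, 3), (60, 20), (300, 80)])
    den = draw([(2.55, 0.05), (2.30, 0.08), (2.70, 0.04)])
    nphi = draw([(35, 5), (20, 6), (3, 2)])
    X = np.column_stack([gr, res, den, nphi])
    y = np.repeat([0, 1, 2], sizes)
    return train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)


feature_names = ["Gamma Ray", "Resistivity", "Density", "Neutron Porosity"]
X_train, X_test, y_train, y_test = make_well_logs()
tree = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)
t = tree.tree_

importance = np.zeros(X_train.shape[1])
for node in range(t.node_count):
    left, right = t.children_left[node], t.children_right[node]
    if left == right:  # leaf (both children are -1): no split to credit
        continue
    # Impurity removed by this split, weighted by the samples it acts on
    decrease = (t.weighted_n_node_samples[node] * t.impurity[node]
                - t.weighted_n_node_samples[left] * t.impurity[left]
                - t.weighted_n_node_samples[right] * t.impurity[right])
    importance[t.feature[node]] += decrease

importance /= importance.sum()  # normalise so the importances sum to 1

for name, mine, skl in zip(feature_names, importance, tree.feature_importances_):
    print(f"{name:>18}: manual={mine:.4f}  sklearn={skl:.4f}")
%%%

Weighting by raw sample counts rather than fractions makes no difference here, because the final normalisation divides out the constant.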
Step 5: Overfitting Demonstration — Deep vs Shallow Trees
Let us systematically compare trees with different max_depth values to see the overfitting effect:
%%%python
# Header row (single quotes inside the f-strings keep this valid before Python 3.12)
print(f"{'Depth':<8} {'Train':>8} {'Test':>8} "
      f"{'Leaves':>8} {'Gap':>8}")
print("-" * 44)

for depth in [1, 2, 3, 4, 5, 7, 10, None]:
    dt = DecisionTreeClassifier(max_depth=depth, random_state=42)
    dt.fit(X_train, y_train)
    tr = dt.score(X_train, y_train)
    te = dt.score(X_test, y_test)
    gap = tr - te  # train minus test accuracy
    d_str = str(depth) if depth else "None"
    print(f"{d_str:<8} {tr:>8.3f} {te:>8.3f} "
          f"{dt.get_n_leaves():>8} {gap:>8.3f}")
%%%
Observations:
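- Very shallow trees (depth 1-2) have high bias. A depth-1 tree makes a single split, so it has only two leaves and can predict at most two of the three lithologies; both train and test accuracy are therefore capped well below 1. Depth 2 can already separate all three classes and may score well on data this clean.
- Moderate depth (3-5) typically gives the best test accuracy: enough complexity to capture real patterns without memorising noise.
- Deep trees (10+, None) have low bias but high variance: training accuracy reaches 1.000, while test accuracy tends to fall as the extra splits fit noise. Once max_depth exceeds the depth the unpruned tree actually reaches, the rows are identical because the limit is never hit.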
- The "Gap" column (train - test) quantifies overfitting. Larger gaps mean more overfitting.
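The depth-1 case is easy to verify: a single split yields two leaves, and each leaf predicts one class. The self-contained sketch below (make_well_logs is a hypothetical helper re-creating the Step 1 data) fits such a stump and compares its training accuracy with the best any two-leaf tree could reach on the 48/40/32 training split, namely two whole classes correct: (48 + 40) / 120.

%%%python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier


def make_well_logs(seed=42):
    """Hypothetical helper: rebuild the Step 1 well logs and train/test split."""
    np.random.seed(seed)
    sizes = (60, 50, 40)  # shale, sandstone, limestone

    def draw(params):
        # One normal draw per lithology, concatenated in class order
        return np.concatenate([np.random.normal(m, s, k)
                               for (m, s), k in zip(params, sizes)])

    gr = draw([(85, 12), (35, 10), (20, 7)])
    res = draw([(8, 3), (60, 20), (300, 80)])
    den = draw([(2.55, 0.05), (2.30, 0.08), (2.70, 0.04)])
    nphi = draw([(35, 5), (20, 6), (3, 2)])
    X = np.column_stack([gr, res, den, nphi])
    y = np.repeat([0, 1, 2], sizes)
    return train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)


class_names = ["Shale", "Sandstone", "Limestone"]
X_train, X_test, y_train, y_test = make_well_logs()

stump = DecisionTreeClassifier(max_depth=1, random_state=42).fit(X_train, y_train)
predicted_classes = np.unique(stump.predict(X_train))

print("Leaves:", stump.get_n_leaves())
print("Lithologies the stump can output:",
      [class_names[c] for c in predicted_classes])
print(f"Train accuracy: {stump.score(X_train, y_train):.3f}")
# Best case for two leaves: each captures one whole class (shale 48 + sandstone 40)
print(f"Upper bound for two leaves: {(48 + 40) / 120:.3f}")
%%%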
Step 6: Visualising Overfitting
%%%python
depths = range(1, 16)
train_accs, test_accs = [], []
for d in depths:
    dt = DecisionTreeClassifier(max_depth=d, random_state=42)
    dt.fit(X_train, y_train)
    train_accs.append(dt.score(X_train, y_train))
    test_accs.append(dt.score(X_test, y_test))

plt.figure(figsize=(9, 5))
plt.plot(depths, train_accs, "o-", label="Train", linewidth=2)
plt.plot(depths, test_accs, "s-", label="Test", linewidth=2)
plt.xlabel("Max Depth")
plt.ylabel("Accuracy")
plt.title("Decision Tree: Accuracy vs Max Depth")
plt.legend()
plt.xticks(list(depths))
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()

best_depth = list(depths)[np.argmax(test_accs)]
print(f"Best max_depth = {best_depth} "
      f"(test acc = {max(test_accs):.3f})")
%%%
Step 7: Cost-Complexity Pruning (ccp_alpha)
An alternative to pre-pruning with max_depth is post-pruning using the cost-complexity parameter $\alpha$ (ccp_alpha). Larger $\alpha$ produces simpler trees by penalising the number of leaves.
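For reference, the scikit-learn user guide formulates minimal cost-complexity pruning as follows. For a tree $T$ with leaf set $\widetilde{T}$, the penalised cost is

$$
R_\alpha(T) = R(T) + \alpha\,|\widetilde{T}|,
\qquad
R(T) = \sum_{t \in \widetilde{T}} \frac{N_t}{N}\, i(t),
$$

where $N_t$ is the number of training samples reaching leaf $t$, $N$ is the total number of training samples, and $i(t)$ is the leaf's Gini impurity. Pruning proceeds by "weakest link": for an internal node $t$ with subtree $T_t$,

$$
\alpha_{\mathrm{eff}}(t) = \frac{R(t) - R(T_t)}{|\widetilde{T}_t| - 1},
\qquad
R(t) = \frac{N_t}{N}\, i(t),
$$

and the node with the smallest $\alpha_{\mathrm{eff}}$ is collapsed into a leaf, repeatedly, until every remaining $\alpha_{\mathrm{eff}}$ exceeds ccp_alpha. The ccp_alphas returned by cost_complexity_pruning_path are these successive effective alphas, so each value corresponds to one step of pruning.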
%%%python
# Get the pruning path
path = tree_full.cost_complexity_pruning_path(X_train, y_train)
alphas = path.ccp_alphas
print(f"Number of alpha values to test: {len(alphas)}")
print(f"Alpha range: [{alphas.min():.6f}, {alphas.max():.6f}]")

train_scores, test_scores = [], []
for alpha in alphas:
    dt = DecisionTreeClassifier(ccp_alpha=alpha, random_state=42)
    dt.fit(X_train, y_train)
    train_scores.append(dt.score(X_train, y_train))
    test_scores.append(dt.score(X_test, y_test))

plt.figure(figsize=(9, 5))
plt.plot(alphas, train_scores, "o-", label="Train", markersize=4, linewidth=1.5)
plt.plot(alphas, test_scores, "s-", label="Test", markersize=4, linewidth=1.5)
plt.xlabel("Alpha (complexity parameter)")
plt.ylabel("Accuracy")
plt.title("Cost-Complexity Pruning Path")
plt.legend()
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()

best_alpha = alphas[np.argmax(test_scores)]
print(f"Best alpha = {best_alpha:.6f} "
      f"(test acc = {max(test_scores):.3f})")
%%%
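Choosing best_alpha by test accuracy, as in this step, lets the test set influence model selection, so the reported test accuracy becomes optimistic. A sketch of the alternative, assuming the same hypothetical make_well_logs helper: take the candidate alphas from the training-set pruning path and let scikit-learn's GridSearchCV pick ccp_alpha by 5-fold cross-validation, so the test set is used only once, at the end. The last path alpha is dropped because it prunes the tree back to a single node.

%%%python
import numpy as np
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier


def make_well_logs(seed=42):
    """Hypothetical helper: rebuild the Step 1 well logs and train/test split."""
    np.random.seed(seed)
    sizes = (60, 50, 40)  # shale, sandstone, limestone

    def draw(params):
        # One normal draw per lithology, concatenated in class order
        return np.concatenate([np.random.normal(m, s, k)
                               for (m, s), k in zip(params, sizes)])

    gr = draw([(85, 12), (35, 10), (20, 7)])
    res = draw([(8, 3), (60, 20), (300, 80)])
    den = draw([(2.55, 0.05), (2.30, 0.08), (2.70, 0.04)])
    nphi = draw([(35, 5), (20, 6), (3, 2)])
    X = np.column_stack([gr, res, den, nphi])
    y = np.repeat([0, 1, 2], sizes)
    return train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)


X_train, X_test, y_train, y_test = make_well_logs()

# Candidate alphas from the training data only; the clip guards against
# floating-point round-off producing a value just below zero
path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(
    X_train, y_train)
candidate_alphas = np.clip(path.ccp_alphas[:-1], 0.0, None)

search = GridSearchCV(DecisionTreeClassifier(random_state=42),
                      param_grid={"ccp_alpha": candidate_alphas},
                      cv=5, scoring="accuracy")
search.fit(X_train, y_train)  # refits the best alpha on all training data

best = search.best_params_["ccp_alpha"]
print(f"CV-selected alpha = {best:.6f} (mean CV acc = {search.best_score_:.3f})")
print(f"Held-out test acc = {search.score(X_test, y_test):.3f}")
print(f"Leaves in refitted tree = {search.best_estimator_.get_n_leaves()}")
%%%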
Step 8: Cross-Validation for Optimal Depth
Use 5-fold cross-validation for a more robust estimate of the best tree depth:
%%%python
depths = range(1, 16)
cv_means, cv_stds = [], []

for d in depths:
    dt = DecisionTreeClassifier(max_depth=d, random_state=42)
    scores = cross_val_score(dt, X_train, y_train, cv=5, scoring="accuracy")
    cv_means.append(scores.mean())
    cv_stds.append(scores.std())

cv_means = np.array(cv_means)
cv_stds = np.array(cv_stds)

plt.figure(figsize=(9, 5))
plt.plot(depths, cv_means, "o-", linewidth=2)
plt.fill_between(depths, cv_means - cv_stds, cv_means + cv_stds, alpha=0.2)
plt.xlabel("Max Depth")
plt.ylabel("5-Fold CV Accuracy")
plt.title("Cross-Validation: Optimal Tree Depth")
plt.xticks(list(depths))
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()

best_cv_depth = list(depths)[np.argmax(cv_means)]
print(f"Best depth by CV = {best_cv_depth} "
      f"(mean acc = {cv_means.max():.3f} "
      f"+/- {cv_stds[np.argmax(cv_means)]:.3f})")
%%%
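For a classifier, passing cv=5 to cross_val_score uses stratified 5-fold splitting (StratifiedKFold without shuffling), so each validation fold keeps roughly the 48/40/32 class balance of the training set. The sketch below reproduces one entry of the depth scan by hand; max_depth=3 is an arbitrary illustrative choice, and make_well_logs is a hypothetical helper rebuilding the Step 1 data.

%%%python
import numpy as np
from sklearn.base import clone
from sklearn.model_selection import (StratifiedKFold, cross_val_score,
                                     train_test_split)
from sklearn.tree import DecisionTreeClassifier


def make_well_logs(seed=42):
    """Hypothetical helper: rebuild the Step 1 well logs and train/test split."""
    np.random.seed(seed)
    sizes = (60, 50, 40)  # shale, sandstone, limestone

    def draw(params):
        # One normal draw per lithology, concatenated in class order
        return np.concatenate([np.random.normal(m, s, k)
                               for (m, s), k in zip(params, sizes)])

    gr = draw([(85, 12), (35, 10), (20, 7)])
    res = draw([(8, 3), (60, 20), (300, 80)])
    den = draw([(2.55, 0.05), (2.30, 0.08), (2.70, 0.04)])
    nphi = draw([(35, 5), (20, 6), (3, 2)])
    X = np.column_stack([gr, res, den, nphi])
    y = np.repeat([0, 1, 2], sizes)
    return train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)


X_train, X_test, y_train, y_test = make_well_logs()
model = DecisionTreeClassifier(max_depth=3, random_state=42)

# Manual loop: fit a fresh copy on 4 folds, score accuracy on the 5th
manual_scores = []
for fit_idx, val_idx in StratifiedKFold(n_splits=5).split(X_train, y_train):
    fold_model = clone(model).fit(X_train[fit_idx], y_train[fit_idx])
    manual_scores.append(fold_model.score(X_train[val_idx], y_train[val_idx]))
manual_scores = np.array(manual_scores)

auto_scores = cross_val_score(model, X_train, y_train, cv=5, scoring="accuracy")

print("Manual folds:   ", np.round(manual_scores, 3))
print("cross_val_score:", np.round(auto_scores, 3))
print(f"Mean +/- std: {manual_scores.mean():.3f} +/- {manual_scores.std():.3f}")
%%%

The test set never enters this loop, which is what makes the CV-selected depth a fairer choice than the test-set pick in Step 6.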
Step 9: Final Model and Evaluation
%%%python
# Build the final pruned tree
tree_final = DecisionTreeClassifier(
    max_depth=best_cv_depth, random_state=42
)
tree_final.fit(X_train, y_train)
y_pred = tree_final.predict(X_test)

print(f"Final tree (depth={best_cv_depth}):")
print(f"  Accuracy: {accuracy_score(y_test, y_pred):.3f}")
print(f"  Leaves: {tree_final.get_n_leaves()}")
print()
print(classification_report(y_test, y_pred, target_names=class_names))

# Confusion matrix
cm = confusion_matrix(y_test, y_pred)
fig, ax = plt.subplots(figsize=(6, 5))
im = ax.imshow(cm, cmap="Blues")
plt.colorbar(im)
ax.set_xticks([0, 1, 2])
ax.set_xticklabels(class_names)
ax.set_yticks([0, 1, 2])
ax.set_yticklabels(class_names)
ax.set_xlabel("Predicted")
ax.set_ylabel("Actual")
ax.set_title(f"Decision Tree Confusion Matrix (depth={best_cv_depth})")
for i in range(3):
    for j in range(3):
        # White text on dark cells, black text on light cells
        color = "white" if cm[i, j] > cm.max() / 2 else "black"
        ax.text(j, i, str(cm[i, j]), ha="center", va="center",
                color=color, fontsize=14)
plt.tight_layout()
plt.show()
%%%
Step 10: Comparing the Final Tree with the Unpruned Tree
%%%python
# Side-by-side comparison
fig, axes = plt.subplots(1, 2, figsize=(24, 8))

# Pruned tree
plot_tree(tree_final, feature_names=feature_names, class_names=class_names,
          filled=True, rounded=True, fontsize=9, ax=axes[0])
axes[0].set_title(f"Pruned (depth={best_cv_depth}, "
                  f"leaves={tree_final.get_n_leaves()})")

# Unpruned tree (top 4 levels)
plot_tree(tree_full, feature_names=feature_names, class_names=class_names,
          filled=True, rounded=True, fontsize=8, ax=axes[1], max_depth=4)
axes[1].set_title(f"Unpruned (depth={tree_full.get_depth()}, "
                  f"leaves={tree_full.get_n_leaves()})")

plt.tight_layout()
plt.show()

print(f"Pruned:   train={tree_final.score(X_train, y_train):.3f}, "
      f"test={tree_final.score(X_test, y_test):.3f}")
print(f"Unpruned: train={tree_full.score(X_train, y_train):.3f}, "
      f"test={tree_full.score(X_test, y_test):.3f}")
%%%
Step 11: Making Predictions and Explaining Them
%%%python
# Predict lithology for a new well-log reading
# GR=75 API, Res=15 ohm-m, Den=2.50 g/cc, NPHI=28%
new_sample = np.array([[75, 15, 2.50, 28]])
prediction = tree_final.predict(new_sample)
proba = tree_final.predict_proba(new_sample)

print(f"Predicted: {class_names[prediction[0]]}")
print("Probabilities:")
for name, p in zip(class_names, proba[0]):
    print(f"  {name}: {p:.3f}")

# Show the decision path (node IDs from root to leaf)
path = tree_final.decision_path(new_sample)
node_ids = path.indices
print(f"\nDecision path through nodes: {node_ids.tolist()}")
%%%
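The decision path lists the IDs of the nodes the sample passed through, from the root to its leaf. Paired with the feature and threshold stored at each node (tree_final.tree_.feature and tree_final.tree_.threshold), it turns the prediction into a short chain of readable rules, which is valuable when presenting results to geologists or regulatory bodies.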
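To turn node IDs into rules, each node's split can be read from the fitted tree's low-level tree_ arrays, in the spirit of scikit-learn's "Understanding the decision tree structure" example. This self-contained sketch refits a depth-3 tree on the rebuilt Step 1 data (make_well_logs is a hypothetical helper, and depth 3 merely stands in for best_cv_depth) and prints one test per internal node on the path of the Step 11 sample.

%%%python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier


def make_well_logs(seed=42):
    """Hypothetical helper: rebuild the Step 1 well logs and train/test split."""
    np.random.seed(seed)
    sizes = (60, 50, 40)  # shale, sandstone, limestone

    def draw(params):
        # One normal draw per lithology, concatenated in class order
        return np.concatenate([np.random.normal(m, s, k)
                               for (m, s), k in zip(params, sizes)])

    gr = draw([(85, 12), (35, 10), (20, 7)])
    res = draw([(8, 3), (60, 20), (300, 80)])
    den = draw([(2.55, 0.05), (2.30, 0.08), (2.70, 0.04)])
    nphi = draw([(35, 5), (20, 6), (3, 2)])
    X = np.column_stack([gr, res, den, nphi])
    y = np.repeat([0, 1, 2], sizes)
    return train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)


feature_names = ["Gamma Ray", "Resistivity", "Density", "Neutron Porosity"]
class_names = ["Shale", "Sandstone", "Limestone"]
X_train, X_test, y_train, y_test = make_well_logs()
tree = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X_train, y_train)

new_sample = np.array([[75, 15, 2.50, 28]])        # GR, Res, Den, NPHI
node_ids = tree.decision_path(new_sample).indices  # root first, leaf last
leaf_id = tree.apply(new_sample)[0]

t = tree.tree_
rules = []
for node in node_ids:
    if node == leaf_id:
        continue  # the leaf holds the prediction, not a test
    f, thr = t.feature[node], t.threshold[node]
    value = new_sample[0, f]
    op = "<=" if value <= thr else ">"  # left child is taken when value <= threshold
    rules.append(f"node {node}: {feature_names[f]} = {value:g} {op} {thr:.3f}")

print("\n".join(rules))
print("=> predicted:", class_names[tree.predict(new_sample)[0]])
%%%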
[Library: scikit-learn DecisionTreeClassifier]