Machine Learning (Chapter 32): Stopping Criteria & Pruning
Stopping Criteria
In machine learning, particularly in iterative procedures such as training neural networks, growing decision trees, and running gradient-based optimization, stopping criteria are crucial for ensuring that an algorithm converges efficiently without overfitting or underfitting. A stopping criterion determines when to halt the training process based on predefined conditions.
1. Stopping Criteria for Iterative Algorithms
For iterative algorithms, stopping criteria can include:
Maximum Iterations: The algorithm stops after a predefined number of iterations, i.e., when

$t \geq T_{\max}$

where $t$ is the current iteration and $T_{\max}$ is the maximum number of iterations.
Convergence of Loss Function: The algorithm stops if the change in the loss function between iterations is below a threshold $\epsilon$:

$|L_t - L_{t-1}| < \epsilon$

where $L_t$ is the loss function value at iteration $t$. Both of these criteria are demonstrated in the sketch after this list.
Early Stopping: Uses a validation set to monitor performance and stop training when performance starts to degrade, indicating potential overfitting.
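Before the early-stopping example, here is a minimal sketch of the first two criteria in a plain gradient descent loop; the quadratic loss, learning rate, and thresholds are illustrative choices, not values from this chapter:

python:
# Gradient descent on L(w) = (w - 3)^2 with two stopping criteria:
# a maximum iteration count T_max and convergence of the loss.
max_iters = 1000   # T_max
epsilon = 1e-8     # threshold on the change in loss
lr = 0.1           # learning rate

w = 0.0
prev_loss = float("inf")
for t in range(max_iters):                 # criterion 1: stop when t reaches T_max
    loss = (w - 3.0) ** 2
    if abs(prev_loss - loss) < epsilon:    # criterion 2: |L_t - L_{t-1}| < epsilon
        print(f"Converged at iteration {t}, loss = {loss:.2e}")
        break
    prev_loss = loss
    w -= lr * 2.0 * (w - 3.0)              # gradient step
else:
    print("Stopped at the maximum number of iterations")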
Example in Python (Early Stopping)
python:
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score
# Load data
data = load_iris()
X = data.data
y = data.target
# Split data
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
# Initialize classifier
# max_iter=1 with warm_start=True lets us train one epoch per fit() call
clf = MLPClassifier(hidden_layer_sizes=(10,), max_iter=1, warm_start=True)
best_val_acc = 0
early_stopping_rounds = 5  # patience: stop after 5 epochs without improvement
no_improvement = 0
for epoch in range(100):
    clf.fit(X_train, y_train)  # one additional epoch (a ConvergenceWarning here is expected)
    val_preds = clf.predict(X_val)
    val_acc = accuracy_score(y_val, val_preds)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        no_improvement = 0
    else:
        no_improvement += 1
    if no_improvement >= early_stopping_rounds:
        print(f"Early stopping at epoch {epoch}")
        break
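Note that MLPClassifier can also do this for you: passing early_stopping=True (together with validation_fraction and n_iter_no_change) makes scikit-learn hold out a validation split and stop automatically, so a manual loop like the one above is mainly useful when you need custom monitoring or patience logic.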
Pruning
Pruning is a technique used to reduce the size of decision trees by removing nodes that provide little predictive power. This helps to prevent overfitting and improve model generalization.
2. Pruning Techniques
Cost Complexity Pruning (CCP): Also known as weakest link pruning, it removes the branches that contribute least to the tree’s accuracy relative to their size.
The cost complexity measure of a tree $T$ is defined as:

$R_\alpha(T) = R(T) + \alpha |T|$

where $R(T)$ is the classification error of the tree, $|T|$ is the number of leaf (terminal) nodes, and $\alpha \geq 0$ is the complexity parameter.
Trees are pruned by selecting $\alpha$ values that balance error and complexity: larger $\alpha$ penalizes larger trees and yields more aggressive pruning.
Reduced Error Pruning: Working from the bottom of the tree upward, each internal node is tentatively replaced by a leaf, and the change is kept if accuracy on a held-out validation set does not decrease (see the sketch below).
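scikit-learn exposes no public API for reduced error pruning, so the rough sketch below mutates the internal tree_ arrays (children_left / children_right, where -1 marks a leaf). This relies on an implementation detail of scikit-learn, not a supported interface, and is shown only to illustrate the algorithm:

python:
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

TREE_LEAF = -1  # sentinel scikit-learn uses internally for "no child"

def reduced_error_prune(clf, X_val, y_val):
    """Greedily replace internal nodes with leaves while validation accuracy does not drop."""
    tree = clf.tree_
    best_acc = accuracy_score(y_val, clf.predict(X_val))
    # Nodes are stored in depth-first order, so reversed() visits children before parents.
    for node in reversed(range(tree.node_count)):
        if tree.children_left[node] == TREE_LEAF:
            continue  # already a leaf
        left, right = tree.children_left[node], tree.children_right[node]
        # Tentatively prune: turn this node into a leaf (its class counts are already stored).
        tree.children_left[node] = TREE_LEAF
        tree.children_right[node] = TREE_LEAF
        acc = accuracy_score(y_val, clf.predict(X_val))
        if acc >= best_acc:
            best_acc = acc  # no loss in accuracy: keep the prune
        else:
            tree.children_left[node] = left  # accuracy dropped: revert
            tree.children_right[node] = right
    return clf

data = load_iris()
X_train, X_val, y_train, y_val = train_test_split(
    data.data, data.target, test_size=0.2, random_state=42)
clf = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)
reduced_error_prune(clf, X_val, y_val)
print("Validation accuracy after pruning:",
      accuracy_score(y_val, clf.predict(X_val)))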
Example in Python (Cost Complexity Pruning)
python:
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
# Load data
data = load_iris()
X = data.data
y = data.target
# Split data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Train a full (unpruned) decision tree
clf = DecisionTreeClassifier(random_state=42)
clf.fit(X_train, y_train)
# Compute the effective alphas along the minimal cost-complexity pruning path
path = clf.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas, impurities = path.ccp_alphas, path.impurities
# Refit, evaluate, and plot the tree at each alpha value
for ccp_alpha in ccp_alphas:
    pruned_clf = DecisionTreeClassifier(random_state=42, ccp_alpha=ccp_alpha)
    pruned_clf.fit(X_train, y_train)
    y_pred = pruned_clf.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    print(f"Alpha: {ccp_alpha:.4f}, Accuracy: {accuracy:.4f}")
    plt.figure(figsize=(10, 6))
    plot_tree(pruned_clf, filled=True, feature_names=data.feature_names, class_names=data.target_names)
    plt.title(f"Decision Tree with alpha = {ccp_alpha:.4f}")
    plt.show()
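In practice you would keep only the alpha that maximizes held-out accuracy (for example via a GridSearchCV over ccp_alpha) rather than inspecting every tree. Note that the largest alpha on the path always prunes the tree down to a single root node, so its plot is trivial.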
Conclusion
Stopping criteria and pruning are vital techniques in machine learning to improve model efficiency and generalization. Stopping criteria ensure that algorithms converge appropriately, while pruning techniques help manage model complexity and prevent overfitting. By using these methods, you can enhance the performance and robustness of your machine learning models.