51 lines
1.5 KiB
Python
51 lines
1.5 KiB
Python
"""
|
|
Use the random forest classifier to classify the digits data set.
|
|
|
|
- This is an example of a supervised ML algorithm
|
|
- it has labels on the training data
|
|
- you tell the model: this is class X during training
|
|
"""
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
from sklearn import datasets
|
|
from sklearn import tree
|
|
from sklearn.model_selection import train_test_split
|
|
from sklearn.ensemble import RandomForestClassifier
|
|
|
|
# Load scikit-learn's bundled handwritten-digits dataset. The returned Bunch
# is dict-like, so its fields can be read by key as well as by attribute.
digits = datasets.load_digits()

# Show the shape of the flattened feature matrix.
print(digits["data"].shape)
|
|
|
|
# Hold out a fifth of the samples as a test set. Seeding the shuffle with
# random_state=0 makes the same split come out on every run.
x_train, x_test, y_train, y_test = train_test_split(
    digits.data,
    digits.target,
    random_state=0,
    test_size=0.2,
)
|
|
|
|
# Build a 100-tree random forest; the fixed seed makes the fitted forest
# reproducible from run to run.
classifier = RandomForestClassifier(random_state=42, n_estimators=100)

# fit() returns the estimator itself, so training on the training split and
# scoring on the held-out split can be chained. For a classifier, score()
# reports the mean accuracy on (x_test, y_test).
score = classifier.fit(x_train, y_train).score(x_test, y_test)
print(score)
|
|
|
|
# Render the first decision tree of the fitted forest and save it as an image.
# classifier.estimators_ holds the individual trees; index 0 is simply the
# first one, not a "best" or representative tree.
fig, ax = plt.subplots(figsize=(20, 10))
tree.plot_tree(
    classifier.estimators_[0],
    feature_names=digits.feature_names,
    # target_names are the digit values; plot_tree wants string labels.
    class_names=[str(n) for n in digits.target_names],
    filled=True,
    rounded=True,
    ax=ax,
)
fig.savefig("randomforest_digits_tree_0.png", dpi=150, bbox_inches="tight")
# Release the figure once saved; pyplot otherwise keeps it alive for the
# remainder of the run.
plt.close(fig)
|
|
|
|
# For digits, plot the forest's feature importances as a heatmap laid out like
# the input images (8x8 for this dataset), so each cell shows the importance
# of one pixel. The grid shape is taken from digits.images rather than
# hard-coded.
importances = classifier.feature_importances_
image_shape = digits.images.shape[1:]

# Use the object-oriented API, matching the tree plot, so the colorbar is
# attached explicitly to this image instead of "whatever is current".
fig, ax = plt.subplots()
heatmap = ax.imshow(importances.reshape(image_shape), cmap='hot')
ax.set_title('Random Forest: Feature Importance')
fig.colorbar(heatmap, ax=ax)
fig.savefig('randomforest_digits_feature_importance.png', dpi=150, bbox_inches='tight')
# Release the figure once it has been written to disk.
plt.close(fig)
|