"""
Use the random forest classifier to classify the digits data set.

- This is an example of a supervised ML algorithm
- it has labels on the training data
- you tell the model: this is class X during training

Running the script prints the shape of ``digits.data`` and the accuracy on
the held-out test split, then writes two images:

- ``randomforest_digits_tree_0.png``: the first tree of the fitted forest
- ``randomforest_digits_feature_importance.png``: the forest's feature
  importances reshaped to an 8x8 heatmap
"""
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn import tree
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier

# Depth limit used when rendering the tree image. ``None`` draws every node.
# The forest below is built without a max_depth, so its trees may be deep.
# A small value such as 3 usually gives a far more readable picture.
TREE_PLOT_MAX_DEPTH = None

digits = datasets.load_digits()
print(digits.data.shape)

# split into training and test data (20% held out for testing)
x_train, x_test, y_train, y_test = train_test_split(
    digits.data, digits.target, test_size=0.2, random_state=0
)

# use a random forest of 100 trees; the fixed seed makes runs reproducible
classifier = RandomForestClassifier(n_estimators=100, random_state=42)

# train on the training split
classifier.fit(x_train, y_train)

# test the model on the held-out split and print its accuracy
score = classifier.score(x_test, y_test)
print(score)

# get the first tree of the forest and turn it into an image
fig, ax = plt.subplots(figsize=(20, 10))
tree.plot_tree(
    classifier.estimators_[0],
    max_depth=TREE_PLOT_MAX_DEPTH,
    feature_names=digits.feature_names,
    # plot_tree expects string labels; the digit targets are integers
    class_names=[str(n) for n in digits.target_names],
    filled=True,
    rounded=True,
    ax=ax,
)
fig.savefig("randomforest_digits_tree_0.png", dpi=150, bbox_inches="tight")

# for digits, plot feature importance as an 8x8 heatmap (one value per pixel).
# Use the same explicit Figure/Axes API as the tree plot instead of pyplot's
# implicit "current figure" state. Separate names keep `fig`/`ax` bound to
# the tree figure.
importances = classifier.feature_importances_
fig_importance, ax_importance = plt.subplots()
heatmap = ax_importance.imshow(importances.reshape(8, 8), cmap="hot")
ax_importance.set_title("Random Forest: Feature Importance")
fig_importance.colorbar(heatmap, ax=ax_importance)
fig_importance.savefig(
    "randomforest_digits_feature_importance.png", dpi=150, bbox_inches="tight"
)