59 lines
2.1 KiB
Python
59 lines
2.1 KiB
Python
"""
Use a Gaussian naive Bayes classifier to classify the digits data set.

- This is an example of a supervised ML algorithm.
- The training data has labels.
- During training you tell the model: "this sample is class X".
"""
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
|
|
from sklearn import datasets
|
|
from sklearn.model_selection import train_test_split
|
|
from sklearn.naive_bayes import GaussianNB
|
|
|
|
# Load scikit-learn's bundled handwritten-digits data set and report the
# shape of its feature matrix.
digits = datasets.load_digits()
features, targets = digits.data, digits.target
print(features.shape)

# Hold out 20% of the samples for testing; the fixed random_state makes the
# split (and therefore the reported accuracy) reproducible between runs.
x_train, x_test, y_train, y_test = train_test_split(
    features,
    targets,
    random_state=0,
    test_size=0.2,
)
|
|
|
|
# Gaussian naive Bayes: for each class, every pixel is modeled as an
# independent Gaussian whose mean and variance are estimated from the
# training split. fit() returns the estimator itself, so construction and
# training can be chained.
classifier = GaussianNB().fit(x_train, y_train)

# Evaluate on the held-out test split and print its accuracy.
score = classifier.score(x_test, y_test)
print(score)
|
|
|
|
# Visualize the learned per-class means (GaussianNB.theta_) as 8x8 images.
# Every panel uses the same vmin/vmax: imshow would otherwise rescale each
# image to its own min/max, so equal shades would stand for different pixel
# values in different classes and the panels could not be compared.
fig, axes = plt.subplots(2, 5, figsize=(12, 5))
mean_min, mean_max = classifier.theta_.min(), classifier.theta_.max()
for ax in axes.flat:
    ax.axis('off')
# Title each panel with the actual class label (classes_) instead of assuming
# that row i of theta_ belongs to digit i.
for ax, label, means in zip(axes.flat, classifier.classes_, classifier.theta_):
    image = ax.imshow(
        means.reshape(8, 8), cmap='gray_r', vmin=mean_min, vmax=mean_max
    )
    ax.set_title(f'Class {label}')
# All panels share one normalization, so any one of their images can drive
# the shared colorbar.
fig.colorbar(image, ax=axes, label='mean pixel value')
fig.suptitle('NB: Mean pixel intensity per class')
fig.savefig('naivebayes_digits_means.png', dpi=150, bbox_inches='tight')
# The figure has been written to disk; release it.
plt.close(fig)
|
|
|
|
# The variance plot shows where pixels vary most within a class:
# - high variance (bright) means the pixel differs a lot between samples of
#   that class, so its Gaussian likelihood is broad and it is weak evidence
#   for the class;
# - low variance (dark) means the pixel is consistent within the class.
# All panels share one color scale (plus a colorbar), so "bright" means the
# same variance in every class; imshow's default per-panel autoscaling would
# make brightness incomparable across panels.
fig, axes = plt.subplots(2, 5, figsize=(12, 5))
var_min, var_max = classifier.var_.min(), classifier.var_.max()
for ax in axes.flat:
    ax.axis('off')
# Title each panel with the actual class label (classes_) instead of assuming
# that row i of var_ belongs to digit i.
for ax, label, variances in zip(axes.flat, classifier.classes_, classifier.var_):
    image = ax.imshow(
        variances.reshape(8, 8), cmap='hot', vmin=var_min, vmax=var_max
    )
    ax.set_title(f'Class {label}')
# All panels share one normalization, so any one of their images can drive
# the shared colorbar.
fig.colorbar(image, ax=axes, label='pixel variance')
fig.suptitle('NB: Pixel variance per class')
fig.savefig('naivebayes_digits_variance.png', dpi=150, bbox_inches='tight')
# The figure has been written to disk; release it.
plt.close(fig)
|
|
|
|
# Compare the fitted per-pixel variances of two commonly confused digits.
# Bright pixels are where one class is much more (or less) consistent than
# the other. GaussianNB's per-pixel log-likelihood depends on both the class
# means (theta_) and the variances (var_), so this map shows only the
# variance part of what separates the two digits, not everything NB uses.
digit_a, digit_b = 3, 8
# Map class labels to rows of var_ rather than assuming row i is digit i;
# an unknown digit raises KeyError instead of silently using the wrong row.
class_row = {label: row for row, label in enumerate(classifier.classes_)}
var_a = classifier.var_[class_row[digit_a]]
var_b = classifier.var_[class_row[digit_b]]
diff = np.abs(var_a - var_b)

fig, ax = plt.subplots()
image = ax.imshow(diff.reshape(8, 8), cmap='hot')
ax.set_title(f'Variance difference: {digit_a} vs {digit_b}')
fig.colorbar(image, ax=ax)
fig.savefig(
    f'naivebayes_{digit_a}v{digit_b}_variance.png', dpi=150, bbox_inches='tight'
)
# The figure has been written to disk; release it.
plt.close(fig)
|