"""
Classify the digits data set with a naive Bayes classifier.

- This is an example of a supervised ML algorithm
- the training data carries labels
- during training the model is told: this sample is class X
"""
import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB

digits = datasets.load_digits()
print(digits.data.shape)

# hold out 20% of the samples for testing; the fixed seed keeps the split
# identical from run to run
x_train, x_test, y_train, y_test = train_test_split(
    digits.data,
    digits.target,
    test_size=0.2,
    random_state=0,
)

# create a Gaussian NB classifier and train it on the training split
# (fit() hands back the fitted estimator itself)
classifier = GaussianNB().fit(x_train, y_train)

# evaluate the model on the held-out split and print the accuracy
score = classifier.score(x_test, y_test)
print(score)

# draw the learned per-class means as 8x8 images, one panel per class
fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(12, 5))
for i, ax in enumerate(axes.ravel()):
    mean_image = classifier.theta_[i].reshape(8, 8)
    ax.imshow(mean_image, cmap='gray_r')
    ax.set_title(f'Class {i}')
    ax.axis('off')
fig.suptitle('NB: Mean pixel intensity per class')
fig.savefig('naivebayes_digits_means.png', dpi=150, bbox_inches='tight')

# The variance plot shows where pixels vary most within a class:
# - high variance (bright) means that pixel isn't reliable for classification
# - low variance (dark) means it's consistent.
# one panel per class, each showing that class's per-pixel variance as an
# 8x8 heat map
fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(12, 5))
for i, ax in enumerate(axes.ravel()):
    variance_image = classifier.var_[i].reshape(8, 8)
    ax.imshow(variance_image, cmap='hot')
    ax.set_title(f'Class {i}')
    ax.axis('off')
fig.suptitle('NB: Pixel variance per class')
fig.savefig('naivebayes_digits_variance.png', dpi=150, bbox_inches='tight')

# Plot the absolute variance difference between two commonly confused digits
# (3 and 8) to highlight the pixels whose spread differs most between them.
diff = abs(classifier.var_[3] - classifier.var_[8])
diff_image = diff.reshape(8, 8)

plt.figure()
plt.imshow(diff_image, cmap='hot')
plt.title('Variance difference: 3 vs 8')
plt.colorbar()
plt.savefig('naivebayes_3v8_variance.png', dpi=150, bbox_inches='tight')