Files
cas-pml/ML/aufgaben/kmeans/kmeans_digits.py
T

61 lines
2.2 KiB
Python

"""
Use k-means to try to match handwritten digits and see if changing the parameters
results in better recognition.
- This is an example of an unsupervised ML algorithm
- it does not use the labels of the training data while fitting
- it discovers the structure on its own
- thus the cluster numbers are arbitrary and do not correspond to the class labels
- the true labels are only used afterwards to evaluate how well the clusters match the digits
"""
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.cluster import KMeans
from sklearn import metrics
from sklearn.decomposition import PCA
# Load the handwritten digits dataset: 1797 samples (~180 per digit 0-9),
# each an 8x8 image flattened into 64 pixel features.
digits = datasets.load_digits()
print(digits.data.shape)  # (n_samples, n_features) == (1797, 64)

# Parameter variants worth trying (compare the scores printed below):
#   KMeans(n_clusters=10, init="random", n_init=1)  # a single random initialisation
#   KMeans(n_clusters=10)                           # scikit-learn defaults
# k-means++ seeding with 10 restarts keeps the run with the lowest inertia.
# random_state pins the seeding so repeated runs produce the same clusters.
kmeans = KMeans(n_clusters=10, init="k-means++", n_init=10, random_state=0)
kmeans.fit(digits.data)
cluster_labels = kmeans.labels_

# Cluster ids are arbitrary, so instead of dumping all (digit, cluster) pairs
# print a contingency matrix: row = true digit, column = cluster id,
# entry = number of samples with that combination.
print(metrics.cluster.contingency_matrix(digits.target, cluster_labels))

# External scores compare the clusters with the true labels (1.0 = perfect).
print(f"homogeneity:   {metrics.homogeneity_score(digits.target, cluster_labels):.3f}")
print(f"completeness:  {metrics.completeness_score(digits.target, cluster_labels):.3f}")
print(f"adjusted Rand: {metrics.adjusted_rand_score(digits.target, cluster_labels):.3f}")
# The silhouette score uses only the data, not the labels (-1..1, higher is better).
print(f"silhouette:    {metrics.silhouette_score(digits.data, cluster_labels):.3f}")
# Project the 64-dimensional pixel vectors onto the first two principal
# components so the clustering can be inspected in a 2-D scatter plot.
pca = PCA(n_components=2)
X2d = pca.fit_transform(digits.data)
# The centroids live in the same 64-D space, so reuse the fitted projection.
centroids2d = pca.transform(kmeans.cluster_centers_)
var_ratio = pca.explained_variance_ratio_

fig_pca, ax_pca = plt.subplots(figsize=(10, 8))
# One point per sample, coloured by its assigned cluster id.
scatter = ax_pca.scatter(X2d[:, 0], X2d[:, 1], c=kmeans.labels_, cmap='tab10', s=10, alpha=0.6)
# Projected cluster centres drawn on top as large red crosses.
ax_pca.scatter(centroids2d[:, 0], centroids2d[:, 1], c='red', marker='X', s=200, edgecolors='black')
ax_pca.set_xlabel(f'PC1 ({var_ratio[0]:.1%} var)')
ax_pca.set_ylabel(f'PC2 ({var_ratio[1]:.1%} var)')
ax_pca.set_title('K-Means on Digits (PCA projection)')
fig_pca.colorbar(scatter, ax=ax_pca, label='Cluster')
fig_pca.savefig('kmeans_digits.png', dpi=150, bbox_inches='tight')
# Show every cluster centre as an 8x8 image. Each centroid is (approximately)
# the mean image of its cluster, so it should look like a "typical" digit.
centers = kmeans.cluster_centers_
n_cols = 5
# Ceiling division: enough rows for all centroids, whatever n_clusters is
# (10 clusters -> 2x5 grid with figsize (10, 4), as before).
n_rows = -(-len(centers) // n_cols)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), squeeze=False)
for i, ax in enumerate(axes.flat):
    if i < len(centers):
        ax.imshow(centers[i].reshape(8, 8), cmap='gray_r')
        ax.set_title(f'Cluster {i}')
    # Hide the frame on every cell, including unused trailing cells.
    ax.axis('off')
fig.savefig('kmeans_digits_centroids.png', dpi=150, bbox_inches='tight')
"""
Takeaway:
- Hier ist k-means nicht der richtige Algorithmus: k-means geht von kompakten, etwa kugelförmigen Clustern ähnlicher Größe aus. Die Daten sind aber nicht schön kugelförmig verteilt und lassen sich deshalb nicht gut clustern.
"""