61 lines
2.2 KiB
Python
61 lines
2.2 KiB
Python
"""
|
|
Use k-means to try to match handwritten digits and see if changing the parameters
|
|
results in better recognition.
|
|
|
|
- This is an example of an unsupervised ML algorithm
|
|
- it has no labels on the training data
|
|
- it discovers the structure on its own
|
|
- thus the cluster numbers are arbitrary and do not correspond to the class labels
|
|
"""
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
from sklearn import datasets
|
|
from sklearn.cluster import KMeans
|
|
from sklearn import metrics
|
|
from sklearn.decomposition import PCA
|
|
|
|
# Load scikit-learn's bundled handwritten-digits dataset: 1797 grey-scale
# 8x8 images, roughly 180 per digit 0-9. Each image is flattened into one
# 64-feature row of `digits.data`; `digits.target` holds the true digit.
digits = datasets.load_digits()

# Expected output: (1797, 64) -> (number of samples, pixels per image)
print(digits.data.shape)
|
|
|
|
# Try different parameter settings and compare the scores printed below.
# random_state is fixed so that runs are reproducible; without it every run
# starts from different random centres and the scores of different settings
# cannot be compared fairly.
# kmeans = KMeans(n_clusters=10, init="random", n_init=1, random_state=0)
# NOTE: the default n_init differs between scikit-learn versions ("auto" since 1.4).
# kmeans = KMeans(n_clusters=10, random_state=0)

# One cluster per digit class; k-means++ seeding, best of 10 restarts.
kmeans = KMeans(n_clusters=10, init="k-means++", n_init=10, random_state=0)
kmeans.fit(digits.data)  # unsupervised: only the pixels are used, never digits.target
|
|
|
|
# k-means cluster ids are arbitrary, so compare them to the true digits with a
# contingency table: row i = true digit i, column j = cluster j, and each entry
# counts the samples with that (digit, cluster) pair. A good clustering puts
# most of each row into a single column.
print(metrics.cluster.contingency_matrix(digits.target, kmeans.labels_))

# External scores compare against the true labels and ignore cluster renumbering.
# homogeneity: 1.0 when every cluster contains samples of only one digit.
print(f"homogeneity:  {metrics.homogeneity_score(digits.target, kmeans.labels_):.3f}")
# completeness: 1.0 when all samples of a digit end up in the same cluster.
print(f"completeness: {metrics.completeness_score(digits.target, kmeans.labels_):.3f}")
# adjusted Rand index: 1.0 for a perfect match, about 0.0 for a random assignment.
print(f"adjusted RI:  {metrics.adjusted_rand_score(digits.target, kmeans.labels_):.3f}")

# Internal score (no labels): in [-1, 1]; higher = tighter, better-separated clusters.
print(f"silhouette:   {metrics.silhouette_score(digits.data, kmeans.labels_):.3f}")
|
|
|
|
# Project the 64-dimensional pixel vectors onto their first two principal
# components so the clustering can be inspected visually. The cluster centres
# go through the same fitted PCA, so they land in the same 2-D space.
pca = PCA(n_components=2)
X2d = pca.fit_transform(digits.data)
centroids2d = pca.transform(kmeans.cluster_centers_)

fig, ax = plt.subplots(figsize=(10, 8))

# Samples, coloured by their (arbitrary) k-means cluster id.
scatter = ax.scatter(X2d[:, 0], X2d[:, 1], c=kmeans.labels_, cmap='tab10', s=10, alpha=0.6)
# Cluster centres, drawn on top as large red crosses.
ax.scatter(centroids2d[:, 0], centroids2d[:, 1], c='red', marker='X', s=200, edgecolors='black')

# Axis labels report how much of the total variance each component explains.
ax.set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.1%} var)')
ax.set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.1%} var)')
ax.set_title('K-Means on Digits (PCA projection)')
fig.colorbar(scatter, ax=ax, label='Cluster')

fig.savefig('kmeans_digits.png', dpi=150, bbox_inches='tight')
|
|
|
|
# Show each cluster centre as an image. k-means centres live in pixel space,
# so they look like an "average digit" of their cluster.
# The grid is sized from the actual number of centres (at most 5 per row),
# so changing n_clusters above neither crashes nor hides clusters.
centers = kmeans.cluster_centers_
n_centers = len(centers)
n_cols = min(5, n_centers)
n_rows = (n_centers + n_cols - 1) // n_cols  # ceiling division
# squeeze=False keeps `axes` 2-D even for a single row/column.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), squeeze=False)
for i, ax in enumerate(axes.flat):
    if i < n_centers:
        # Each centre is a 64-vector; the digits images are 8x8 pixels.
        ax.imshow(centers[i].reshape(8, 8), cmap='gray_r')
        ax.set_title(f'Cluster {i}')
    ax.axis('off')  # hide ticks/frames; also blanks any padding cells
fig.savefig('kmeans_digits_centroids.png', dpi=150, bbox_inches='tight')
|
|
|
|
"""
|
|
Takaway:
|
|
- Hier ist k-means nicht der richtige algorithmus, weil die Daten nicht schön kugelförmig verteilt sind und sich nicht gut clustern lassen.
|
|
"""
|