refactor: update kmeans examples with better plots

This commit is contained in:
2026-05-01 11:04:31 +02:00
parent 959e53b7b3
commit d5258a6edf
5 changed files with 20 additions and 10 deletions
+16 -7
View File
@@ -3,16 +3,20 @@ Use k-means to try to match handwritten digits and see if changing the parameter
results in better recognition.
- This is an example of an unsupervised ML algorithm
- it has no labels on the training data
- it has no labels in the training data
- it discovers the structure on its own
- thus the cluster numbers are arbitrary and do not correspond to the class labels
Takaway:
- Hier ist k-means nicht der beste algorithmus, weil die Daten nicht in schön kugelförmig
verteilten Clustern angeordnet sind und k-means Mühe hat die Centroiden sauber zu bestimmen.
"""
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.cluster import KMeans
from sklearn import metrics
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
# get the digits dataset
@@ -34,10 +38,18 @@ print(metrics.completeness_score(digits.target, kmeans.labels_))
print(metrics.adjusted_rand_score(digits.target, kmeans.labels_))
print(metrics.silhouette_score(digits.data, kmeans.labels_))
# Wikipedia:
# PCA (Principal Component Analysis) finds the directions in your data with
# the most variance and projects everything onto those axes.
#
# Irgendwas mit Eigenvektoren und Kovarianz Matrix, TODO: Anschauen
# Transformiert 64 dimensionalen Vektor möglichst gut in eine 2D Projektion
pca = PCA(n_components=2)
X2d = pca.fit_transform(digits.data)
centroids2d = pca.transform(kmeans.cluster_centers_)
# Punktewolke plotten und Centroiden einzeichnen, tab10 gibt 10 versch. Farben für die Legende
# Hier sieht man grosse überlappung zwischen den Clustern -> ein Hinweis, das K-Means nicht optimal ist?
plt.figure(figsize=(10, 8))
scatter = plt.scatter(X2d[:, 0], X2d[:, 1], c=kmeans.labels_, cmap='tab10', s=10, alpha=0.6)
plt.scatter(centroids2d[:, 0], centroids2d[:, 1], c='red', marker='X', s=200, edgecolors='black')
@@ -47,14 +59,11 @@ plt.title('K-Means on Digits (PCA projection)')
plt.colorbar(scatter, label='Cluster')
plt.savefig('kmeans_digits.png', dpi=150, bbox_inches='tight')
# Centroiden als 8x8 Bild darstellen, indem man das "durchschnittliche zeichen" um das Zentrum plottet
# Dieser plot zeigt was die K-Means "gelernt" hat, man sieht die Zuweisung von Cluster zu Zahl sofort
fig, axes = plt.subplots(2, 5, figsize=(10, 4))
for i, ax in enumerate(axes.flat):
ax.imshow(kmeans.cluster_centers_[i].reshape(8, 8), cmap='gray_r')
ax.set_title(f'Cluster {i}')
ax.axis('off')
fig.savefig('kmeans_digits_centroids.png', dpi=150, bbox_inches='tight')
"""
Takaway:
- Hier ist k-means nicht der richtige algorithmus, weil die Daten nicht schön kugelförmig verteilt sind und sich nicht gut clustern lassen.
"""