refactor: update kmeans examples with better plots
This commit is contained in:
@@ -30,7 +30,7 @@ kmeans = KMeans(n_clusters=3)
|
||||
# fit auf daten
|
||||
kmeans.fit(iris.data)
|
||||
|
||||
# gegenüberstellung gold standard vs prediction
|
||||
# Gegenüberstellung gold standard vs prediction
|
||||
print("gold standard vs. prediction")
|
||||
for target_label, predicted_label in zip(iris.target, kmeans.labels_):
|
||||
print(f"{target_label} -> {predicted_label}")
|
||||
@@ -41,12 +41,13 @@ print(metrics.completeness_score(iris.target, kmeans.labels_))
|
||||
print(metrics.adjusted_rand_score(iris.target, kmeans.labels_))
|
||||
print(metrics.silhouette_score(iris.data, kmeans.labels_))
|
||||
|
||||
# plot vorbereiten
|
||||
# plot vorbereiten (Idee von kmeans digits)
|
||||
# Transformation 4D nach 2D via Projektionsfit
|
||||
pca = PCA(n_components=2)
|
||||
X2d = pca.fit_transform(iris.data)
|
||||
centroids2d = pca.transform(kmeans.cluster_centers_)
|
||||
|
||||
# plot
|
||||
# plotten der Punktewolke und einzeichnen der Centroiden
|
||||
plt.scatter(X2d[:, 0], X2d[:, 1], c=kmeans.labels_, cmap="viridis", s=30, alpha=0.7)
|
||||
plt.scatter(
|
||||
centroids2d[:, 0], centroids2d[:, 1], c="red", marker="X", s=200, edgecolors="black"
|
||||
|
||||
Reference in New Issue
Block a user