feature: add a comparison between all algorithms for each dataset to see which performs best
This commit is contained in:
@@ -0,0 +1,113 @@
|
|||||||
|
|
||||||
|
============================================================
|
||||||
|
IRIS
|
||||||
|
============================================================
|
||||||
|
|
||||||
|
--- Decision Tree ---
|
||||||
|
Accuracy: 1.000
|
||||||
|
Adj. Rand: 1.000
|
||||||
|
precision recall f1-score support
|
||||||
|
|
||||||
|
setosa 1.00 1.00 1.00 19
|
||||||
|
versicolor 1.00 1.00 1.00 13
|
||||||
|
virginica 1.00 1.00 1.00 13
|
||||||
|
|
||||||
|
accuracy 1.00 45
|
||||||
|
macro avg 1.00 1.00 1.00 45
|
||||||
|
weighted avg 1.00 1.00 1.00 45
|
||||||
|
|
||||||
|
|
||||||
|
--- Naive Bayes ---
|
||||||
|
Accuracy: 0.978
|
||||||
|
Adj. Rand: 0.943
|
||||||
|
precision recall f1-score support
|
||||||
|
|
||||||
|
setosa 1.00 1.00 1.00 19
|
||||||
|
versicolor 1.00 0.92 0.96 13
|
||||||
|
virginica 0.93 1.00 0.96 13
|
||||||
|
|
||||||
|
accuracy 0.98 45
|
||||||
|
macro avg 0.98 0.97 0.97 45
|
||||||
|
weighted avg 0.98 0.98 0.98 45
|
||||||
|
|
||||||
|
|
||||||
|
--- K-Means (mapped) ---
|
||||||
|
Accuracy: 0.893
|
||||||
|
Adj. Rand: 0.730
|
||||||
|
precision recall f1-score support
|
||||||
|
|
||||||
|
setosa 1.00 1.00 1.00 50
|
||||||
|
versicolor 0.77 0.96 0.86 50
|
||||||
|
virginica 0.95 0.72 0.82 50
|
||||||
|
|
||||||
|
accuracy 0.89 150
|
||||||
|
macro avg 0.91 0.89 0.89 150
|
||||||
|
weighted avg 0.91 0.89 0.89 150
|
||||||
|
|
||||||
|
|
||||||
|
============================================================
|
||||||
|
DIGITS
|
||||||
|
============================================================
|
||||||
|
|
||||||
|
--- Decision Tree ---
|
||||||
|
Accuracy: 0.843
|
||||||
|
Adj. Rand: 0.685
|
||||||
|
precision recall f1-score support
|
||||||
|
|
||||||
|
0 0.92 0.91 0.91 53
|
||||||
|
1 0.74 0.78 0.76 50
|
||||||
|
2 0.83 0.74 0.79 47
|
||||||
|
3 0.78 0.85 0.81 54
|
||||||
|
4 0.81 0.85 0.83 60
|
||||||
|
5 0.92 0.86 0.89 66
|
||||||
|
6 0.93 0.94 0.93 53
|
||||||
|
7 0.85 0.84 0.84 55
|
||||||
|
8 0.89 0.77 0.82 43
|
||||||
|
9 0.78 0.85 0.81 59
|
||||||
|
|
||||||
|
accuracy 0.84 540
|
||||||
|
macro avg 0.85 0.84 0.84 540
|
||||||
|
weighted avg 0.85 0.84 0.84 540
|
||||||
|
|
||||||
|
|
||||||
|
--- Naive Bayes ---
|
||||||
|
Accuracy: 0.852
|
||||||
|
Adj. Rand: 0.710
|
||||||
|
precision recall f1-score support
|
||||||
|
|
||||||
|
0 1.00 0.98 0.99 53
|
||||||
|
1 0.86 0.74 0.80 50
|
||||||
|
2 0.86 0.66 0.75 47
|
||||||
|
3 0.95 0.76 0.85 54
|
||||||
|
4 0.98 0.85 0.91 60
|
||||||
|
5 0.94 0.94 0.94 66
|
||||||
|
6 0.89 0.96 0.93 53
|
||||||
|
7 0.72 0.98 0.83 55
|
||||||
|
8 0.57 0.91 0.70 43
|
||||||
|
9 0.89 0.71 0.79 59
|
||||||
|
|
||||||
|
accuracy 0.85 540
|
||||||
|
macro avg 0.87 0.85 0.85 540
|
||||||
|
weighted avg 0.88 0.85 0.85 540
|
||||||
|
|
||||||
|
|
||||||
|
--- K-Means (mapped) ---
|
||||||
|
Accuracy: 0.794
|
||||||
|
Adj. Rand: 0.667
|
||||||
|
precision recall f1-score support
|
||||||
|
|
||||||
|
0 0.99 0.99 0.99 178
|
||||||
|
1 0.62 0.30 0.41 182
|
||||||
|
2 0.84 0.84 0.84 177
|
||||||
|
3 0.86 0.85 0.85 183
|
||||||
|
4 0.99 0.92 0.95 181
|
||||||
|
5 0.87 0.75 0.81 182
|
||||||
|
6 0.97 0.98 0.98 181
|
||||||
|
7 0.86 0.95 0.90 179
|
||||||
|
8 0.45 0.59 0.51 174
|
||||||
|
9 0.58 0.77 0.66 180
|
||||||
|
|
||||||
|
accuracy 0.79 1797
|
||||||
|
macro avg 0.80 0.79 0.79 1797
|
||||||
|
weighted avg 0.80 0.79 0.79 1797
|
||||||
|
|
||||||
@@ -0,0 +1,59 @@
|
|||||||
|
"""
|
||||||
|
Compare Decision Tree, Naive Bayes (supervised) and K-Means (unsupervised)
|
||||||
|
on the Iris and Digits datasets using the same metrics.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from sklearn import datasets
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
from sklearn.tree import DecisionTreeClassifier
|
||||||
|
from sklearn.naive_bayes import GaussianNB
|
||||||
|
from sklearn.cluster import KMeans
|
||||||
|
from sklearn.metrics import accuracy_score, classification_report, adjusted_rand_score
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def kmeans_accuracy(X, y, n_classes):
|
||||||
|
"""Map each cluster to its majority true label, then compute accuracy."""
|
||||||
|
kmeans = KMeans(n_clusters=n_classes, init="k-means++", n_init=10, random_state=42)
|
||||||
|
kmeans.fit(X)
|
||||||
|
labels = np.zeros_like(kmeans.labels_)
|
||||||
|
for i in range(n_classes):
|
||||||
|
mask = kmeans.labels_ == i
|
||||||
|
if mask.sum() > 0:
|
||||||
|
labels[mask] = np.bincount(y[mask]).argmax()
|
||||||
|
return labels, kmeans
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(name, dataset, target_names):
|
||||||
|
print(f"\n{'='*60}")
|
||||||
|
print(f" {name}")
|
||||||
|
print(f"{'='*60}")
|
||||||
|
|
||||||
|
X_train, X_test, y_train, y_test = train_test_split(
|
||||||
|
dataset.data, dataset.target, test_size=0.3, random_state=42
|
||||||
|
)
|
||||||
|
|
||||||
|
# supervised
|
||||||
|
for clf_name, clf in [("Decision Tree", DecisionTreeClassifier(random_state=42)),
|
||||||
|
("Naive Bayes", GaussianNB())]:
|
||||||
|
clf.fit(X_train, y_train)
|
||||||
|
y_pred = clf.predict(X_test)
|
||||||
|
print(f"\n--- {clf_name} ---")
|
||||||
|
print(f"Accuracy: {accuracy_score(y_test, y_pred):.3f}")
|
||||||
|
print(f"Adj. Rand: {adjusted_rand_score(y_test, y_pred):.3f}")
|
||||||
|
print(classification_report(y_test, y_pred, target_names=target_names))
|
||||||
|
|
||||||
|
# unsupervised (evaluated on full dataset)
|
||||||
|
n_classes = len(target_names)
|
||||||
|
mapped_labels, kmeans = kmeans_accuracy(dataset.data, dataset.target, n_classes)
|
||||||
|
print(f"\n--- K-Means (mapped) ---")
|
||||||
|
print(f"Accuracy: {accuracy_score(dataset.target, mapped_labels):.3f}")
|
||||||
|
print(f"Adj. Rand: {adjusted_rand_score(dataset.target, kmeans.labels_):.3f}")
|
||||||
|
print(classification_report(dataset.target, mapped_labels, target_names=target_names))
|
||||||
|
|
||||||
|
|
||||||
|
iris = datasets.load_iris()
|
||||||
|
evaluate("IRIS", iris, list(iris.target_names))
|
||||||
|
|
||||||
|
digits = datasets.load_digits()
|
||||||
|
evaluate("DIGITS", digits, [str(n) for n in digits.target_names])
|
||||||
Reference in New Issue
Block a user