refactor: add more documentation strings
This commit is contained in:
@@ -1,6 +1,13 @@
|
||||
"""
|
||||
Compare Decision Tree, Naive Bayes (supervised) and K-Means (unsupervised)
|
||||
on the Iris and Digits datasets using the same metrics.
|
||||
Compare different ML algorithms against iris & digits dataset.
|
||||
|
||||
Supervised:
|
||||
- Decision Tree
|
||||
- Naive Bayes (GaussianNB)
|
||||
Unsuprvised:
|
||||
- K-Means
|
||||
|
||||
Use metrics (classification_report) to try to evaluate the algorithms.
|
||||
"""
|
||||
|
||||
from sklearn import datasets
|
||||
@@ -9,23 +16,33 @@ from sklearn.tree import DecisionTreeClassifier
|
||||
from sklearn.naive_bayes import GaussianNB
|
||||
from sklearn.cluster import KMeans
|
||||
from sklearn.metrics import accuracy_score, classification_report, adjusted_rand_score
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def kmeans_accuracy(X, y, n_classes):
|
||||
def kmeans_true_labels(X, y, n_classes):
|
||||
"""
|
||||
Map each cluster to its majority true label, then compute accuracy.
|
||||
This function handles the cluster→label mapping via majority vote.
|
||||
Since k-means is unsupervised it comes up with its own classes, that
|
||||
do not reflect the classes from the gold standard.
|
||||
This function maps each cluster to its majority true label, to help compute accuracy.
|
||||
|
||||
Each cluster gets assigned the most common true label in it.
|
||||
"""
|
||||
# train classifier and do a fit, set the rng seed to a fixed value
|
||||
kmeans = KMeans(n_clusters=n_classes, init="k-means++", n_init=10, random_state=42)
|
||||
kmeans.fit(X)
|
||||
|
||||
# creates an empty array the same shape as the cluster assignments
|
||||
labels = np.zeros_like(kmeans.labels_)
|
||||
# for each cluster i ...
|
||||
for i in range(n_classes):
|
||||
# boolean array, true for every sample that K-Means put in cluster i
|
||||
mask = kmeans.labels_ == i
|
||||
if mask.sum() > 0:
|
||||
labels[mask] = np.bincount(y[mask]).argmax()
|
||||
# just skip empty clusters
|
||||
if mask.sum() == 0:
|
||||
continue
|
||||
# set true label as the label most prominent in this cluster
|
||||
# e.g: if cluster 2 contains a mix of digits 7, 7, 7, 3, 7 this gives you 7
|
||||
labels[mask] = np.bincount(y[mask]).argmax()
|
||||
return labels, kmeans
|
||||
|
||||
|
||||
@@ -38,25 +55,26 @@ def evaluate(name, dataset, target_names):
|
||||
print(f" {name}")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
# split the dataset into train and test data and use a fixed rng seed
|
||||
X_train, X_test, y_train, y_test = train_test_split(
|
||||
dataset.data, dataset.target, test_size=0.3, random_state=42
|
||||
)
|
||||
|
||||
# supervised
|
||||
for clf_name, clf in [
|
||||
# test the supervised leanring algorithms
|
||||
for classifier_name, classifier in [
|
||||
("Decision Tree", DecisionTreeClassifier(random_state=42)),
|
||||
("Naive Bayes", GaussianNB()),
|
||||
]:
|
||||
clf.fit(X_train, y_train)
|
||||
y_pred = clf.predict(X_test)
|
||||
print(f"\n--- {clf_name} ---")
|
||||
classifier.fit(X_train, y_train)
|
||||
y_pred = classifier.predict(X_test)
|
||||
print(f"\n--- {classifier_name} ---")
|
||||
print(f"Accuracy: {accuracy_score(y_test, y_pred):.3f}")
|
||||
print(f"Adj. Rand: {adjusted_rand_score(y_test, y_pred):.3f}")
|
||||
print(classification_report(y_test, y_pred, target_names=target_names))
|
||||
|
||||
# unsupervised (evaluated on full dataset)
|
||||
# test the unsupervised learning algorithm (evaluated on full dataset)
|
||||
n_classes = len(target_names)
|
||||
mapped_labels, kmeans = kmeans_accuracy(dataset.data, dataset.target, n_classes)
|
||||
mapped_labels, kmeans = kmeans_true_labels(dataset.data, dataset.target, n_classes)
|
||||
print(f"\n--- K-Means (mapped) ---")
|
||||
print(f"Accuracy: {accuracy_score(dataset.target, mapped_labels):.3f}")
|
||||
print(f"Adj. Rand: {adjusted_rand_score(dataset.target, kmeans.labels_):.3f}")
|
||||
@@ -64,9 +82,10 @@ def evaluate(name, dataset, target_names):
|
||||
classification_report(dataset.target, mapped_labels, target_names=target_names)
|
||||
)
|
||||
|
||||
|
||||
# evaluate all ML algorithms on iris
|
||||
iris = datasets.load_iris()
|
||||
evaluate("IRIS", iris, list(iris.target_names))
|
||||
|
||||
# evaluate all ML algorithms on digits
|
||||
digits = datasets.load_digits()
|
||||
evaluate("DIGITS", digits, [str(n) for n in digits.target_names])
|
||||
|
||||
Reference in New Issue
Block a user