refactor: make it easier to add other sl algorithms
This commit is contained in:
@@ -20,6 +20,7 @@ from sklearn.metrics import accuracy_score, classification_report, adjusted_rand
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def kmeans_true_labels(X, y, n_classes):
|
||||
"""
|
||||
Since k-means is unsupervised it comes up with its own classes, that
|
||||
@@ -61,12 +62,15 @@ def evaluate(name, dataset, target_names):
|
||||
dataset.data, dataset.target, test_size=0.3, random_state=42
|
||||
)
|
||||
|
||||
# test the supervised learning algorithms
|
||||
for classifier_name, classifier in [
|
||||
# all the supervised learning algorithms to test
|
||||
algorithms = [
|
||||
("Decision Tree", DecisionTreeClassifier(random_state=42)),
|
||||
("Naive Bayes", GaussianNB()),
|
||||
("Random Forest", RandomForestClassifier(n_estimators=100, random_state=42)),
|
||||
]:
|
||||
]
|
||||
|
||||
# do the test on the supervised learning algorithms
|
||||
for classifier_name, classifier in algorithms:
|
||||
classifier.fit(X_train, y_train)
|
||||
y_pred = classifier.predict(X_test)
|
||||
print(f"\n--- {classifier_name} ---")
|
||||
@@ -74,15 +78,14 @@ def evaluate(name, dataset, target_names):
|
||||
print(f"Adj. Rand: {adjusted_rand_score(y_test, y_pred):.3f}")
|
||||
print(classification_report(y_test, y_pred, target_names=target_names))
|
||||
|
||||
# test the unsupervised learning algorithm (evaluated on full dataset)
|
||||
# do the test on the unsupervised learning algorithm (evaluated on full dataset)
|
||||
n_classes = len(target_names)
|
||||
mapped_labels, kmeans = kmeans_true_labels(dataset.data, dataset.target, n_classes)
|
||||
print(f"\n--- K-Means (mapped) ---")
|
||||
print(f"Accuracy: {accuracy_score(dataset.target, mapped_labels):.3f}")
|
||||
print(f"Adj. Rand: {adjusted_rand_score(dataset.target, kmeans.labels_):.3f}")
|
||||
print(
|
||||
classification_report(dataset.target, mapped_labels, target_names=target_names)
|
||||
)
|
||||
print(classification_report(dataset.target, mapped_labels, target_names=target_names))
|
||||
|
||||
|
||||
# evaluate all ML algorithms on iris
|
||||
iris = datasets.load_iris()
|
||||
|
||||
Reference in New Issue
Block a user