diff --git a/ML/aufgaben/decisiontree/decisiontree_digits.png b/ML/aufgaben/decisiontree/decisiontree_digits.png new file mode 100644 index 0000000..94ff5f8 Binary files /dev/null and b/ML/aufgaben/decisiontree/decisiontree_digits.png differ diff --git a/ML/aufgaben/decisiontree/decisiontree_digits.py b/ML/aufgaben/decisiontree/decisiontree_digits.py new file mode 100644 index 0000000..37da4cd --- /dev/null +++ b/ML/aufgaben/decisiontree/decisiontree_digits.py @@ -0,0 +1,44 @@ +""" +Use a decision tree classifier to predict handwritten digits + +- This is an example of a supervised ML algorithm + - it has labels on the training data + - you tell the model: this is class X during training +""" + +import matplotlib.pyplot as plt + +from sklearn import datasets +from sklearn.tree import DecisionTreeClassifier +from sklearn import tree + +# load the digits dataset +digits = datasets.load_digits() +# get a feel for it +print(digits.data.size) +print(digits.target.size) +print(digits.feature_names) +print(digits.target_names) + +# use a decision tree classifier +# set max_depth to 5 otherwise the tree will get huge +classifier = DecisionTreeClassifier(max_depth=5) +# use all but the last sample for training +classifier.fit(digits.data[:-1], digits.target[:-1]) + +# use the model to predict the last data sample +last_sample = digits.data[-1:] +last_target = digits.target[-1:] +print(f"predicted: {classifier.predict(last_sample)} vs real: {last_target}") + +# plot the tree and save it as a PNG for visual inspection +fig, ax = plt.subplots(figsize=(20, 10)) +tree.plot_tree( + classifier, + feature_names=digits.feature_names, + class_names=[str(i) for i in digits.target_names], + filled=True, + rounded=True, + ax=ax, +) +fig.savefig("decisiontree_digits.png", dpi=500, bbox_inches="tight") diff --git a/ML/aufgaben/decisiontree/decisiontree_iris.py b/ML/aufgaben/decisiontree/decisiontree_iris.py index d361715..95118a8 100644 --- a/ML/aufgaben/decisiontree/decisiontree_iris.py +++ 
b/ML/aufgaben/decisiontree/decisiontree_iris.py @@ -19,7 +19,7 @@ print(iris.target.size) print(iris.feature_names) print(iris.target_names) -# use a decition tree classifier +# use a decision tree classifier classifier = DecisionTreeClassifier() # use all but the last sample for training classifier.fit(iris.data[:-1], iris.target[:-1])