feature(decisiontree): add digits example with decision tree
This commit is contained in:
Binary file not shown.
|
After Width: | Height: | Size: 834 KiB |
@@ -0,0 +1,44 @@
|
|||||||
|
"""
|
||||||
|
Use a decision tree classifier to predict handwritten digits
|
||||||
|
|
||||||
|
- This is an example of a supervised ML algorithm
|
||||||
|
- it has labels on the training data
|
||||||
|
- you tell the model: this is class X during training
|
||||||
|
"""
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
from sklearn import datasets
|
||||||
|
from sklearn.tree import DecisionTreeClassifier
|
||||||
|
from sklearn import tree
|
||||||
|
|
||||||
|
# load the digits dataset
|
||||||
|
digits = datasets.load_digits()
|
||||||
|
# get a feel for it
|
||||||
|
print(digits.data.size)
|
||||||
|
print(digits.target.size)
|
||||||
|
print(digits.feature_names)
|
||||||
|
print(digits.target_names)
|
||||||
|
|
||||||
|
# use a decision tree classifier
|
||||||
|
# limit max_depth to 5; otherwise the plotted tree gets too large to read
|
||||||
|
classifier = DecisionTreeClassifier(max_depth=5)
|
||||||
|
# use all but the last sample for training
|
||||||
|
classifier.fit(digits.data[:-1], digits.target[:-1])
|
||||||
|
|
||||||
|
# use the model to predict the last data sample
|
||||||
|
last_sample = digits.data[-1:]
|
||||||
|
last_target = digits.target[-1:]
|
||||||
|
print(f"predicted: {classifier.predict(last_sample)} vs real: {last_target}")
|
||||||
|
|
||||||
|
# plot the tree and save it as a PNG for visual inspection
|
||||||
|
fig, ax = plt.subplots(figsize=(20, 10))
|
||||||
|
tree.plot_tree(
|
||||||
|
classifier,
|
||||||
|
feature_names=digits.feature_names,
|
||||||
|
class_names=[str(i) for i in digits.target_names],
|
||||||
|
filled=True,
|
||||||
|
rounded=True,
|
||||||
|
ax=ax,
|
||||||
|
)
|
||||||
|
fig.savefig("decisiontree_digits.png", dpi=500, bbox_inches="tight")
|
||||||
@@ -19,7 +19,7 @@ print(iris.target.size)
|
|||||||
print(iris.feature_names)
|
print(iris.feature_names)
|
||||||
print(iris.target_names)
|
print(iris.target_names)
|
||||||
|
|
||||||
# use a decition tree classifier
|
# use a decision tree classifier
|
||||||
classifier = DecisionTreeClassifier()
|
classifier = DecisionTreeClassifier()
|
||||||
# use all but the last sample for training
|
# use all but the last sample for training
|
||||||
classifier.fit(iris.data[:-1], iris.target[:-1])
|
classifier.fit(iris.data[:-1], iris.target[:-1])
|
||||||
|
|||||||
Reference in New Issue
Block a user