Files
cas-pml/ML/aufgaben/decisiontree/decisiontree_iris.py
T

43 lines
1.2 KiB
Python

"""
Use a decision tree classifier to predict iris species from sepal and petal features.
- This is an example of a supervised ML algorithm:
- the training data comes with labels,
- i.e. during training you tell the model "this sample belongs to class X".
"""
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
# load the iris data set and look at its dimensions
iris = datasets.load_iris()
# .shape reports (rows, columns); .size would only give the total number of
# elements, which hides how many samples and features there are
print(iris.data.shape)
print(iris.target.shape)
print(iris.feature_names)
print(iris.target_names)
# use a decision tree classifier; fixing random_state makes the fitted tree
# (and therefore the saved plot) reproducible, because sklearn randomly
# permutes the candidate features at every split
classifier = DecisionTreeClassifier(random_state=0)
# use all but the last sample for training
classifier.fit(iris.data[:-1], iris.target[:-1])
# use the model to predict the last (held-out) data sample; slicing with [-1:]
# instead of [-1] keeps the 2-D (1, n_features) shape that predict() expects
last_sample = iris.data[-1:]
last_target = iris.target[-1:]
predicted = classifier.predict(last_sample)
print(
    f"predicted: {predicted} {iris.target_names[predicted]} "
    f"vs real: {last_target} {iris.target_names[last_target]}"
)
# plot the tree and save it as a PNG for visual inspection
fig, ax = plt.subplots(figsize=(20, 10))
tree.plot_tree(
    classifier,
    feature_names=iris.feature_names,
    class_names=iris.target_names,
    filled=True,
    rounded=True,
    ax=ax,
)
fig.savefig("decisiontree_iris.png", dpi=150, bbox_inches="tight")