# =========================================================
#
# From: www.youtube.com/watch?v=tNa99PG8hR8
# Flower data set:
#   en.wikipedia.org/wiki/Iris_flower_data_set
# =========================================================
# Goal:
#   1. Import dataset.
#   2. Train classifier.
#   3. Predict label for new flower.
#   4. Visualize the tree.
# =========================================================

# StringIO comes from the standard library: the vendored copy in
# sklearn.externals.six is no longer shipped with current scikit-learn.
from io import StringIO

import numpy as np
import pydot  # NOTE: pydotplus is an alternative with a similar API.
from sklearn import tree
from sklearn.datasets import load_iris

# ---------------------------------------------------------
# 1. Import dataset.
# ---------------------------------------------------------
iris = load_iris()

# Rows held out of training and used only for testing. The iris data is
# ordered by class with 50 samples each, so these indices pick the first
# example of each of the three species.
test_idx = [0, 50, 100]

# training data: everything except the held-out rows
train_target = np.delete(iris.target, test_idx)
train_data = np.delete(iris.data, test_idx, axis=0)

# testing data: only the held-out rows
test_target = iris.target[test_idx]
test_data = iris.data[test_idx]

# ---------------------------------------------------------
# 2. Train classifier.
# ---------------------------------------------------------
clf = tree.DecisionTreeClassifier()
clf.fit(train_data, train_target)

# ---------------------------------------------------------
# 3. Predict label for new flower.
# ---------------------------------------------------------
# The expected labels and the predicted labels are printed one after the
# other so they can be compared by eye.
print(test_target)
print(clf.predict(test_data))

# ---------------------------------------------------------
# 4. Visualize the tree (viz code).
# ---------------------------------------------------------
# Export the fitted tree as Graphviz DOT text into an in-memory buffer.
dot_data = StringIO()
tree.export_graphviz(
    clf,
    out_file=dot_data,
    feature_names=iris.feature_names,
    class_names=iris.target_names,
    filled=True,
    rounded=True,
    impurity=False,
)

# pydot >= 1.2 returns a list of graphs, while older releases returned a
# single graph object; accept both so write_pdf is always called on a graph.
_parsed = pydot.graph_from_dot_data(dot_data.getvalue())
graph = _parsed[0] if isinstance(_parsed, (list, tuple)) else _parsed

# Rendering the PDF needs the Graphviz binaries to be installed.
graph.write_pdf('iris.pdf')

# Show one test sample with its label, then the feature and class names,
# to help follow the exported tree by hand.
print(test_data[0], test_target[0])
print(iris.feature_names, iris.target_names)