mac服務(wù)器 做網(wǎng)站免費(fèi)google賬號(hào)注冊(cè)入口
Scikit-Learn決策樹
- 1、決策樹分類
- 2、Scikit-Learn決策樹分類
- 2.1、Scikit-Learn決策樹API
- 2.2、Scikit-Learn決策樹初體驗(yàn)
- 2.3、Scikit-Learn決策樹實(shí)踐(葡萄酒分類)
1、決策樹分類
2、Scikit-Learn決策樹分類
2.1、Scikit-Learn決策樹API
官方文檔:https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html#sklearn.tree.DecisionTreeClassifier
中文官方文檔:https://scikit-learn.org.cn/view/784.html
2.2、Scikit-Learn決策樹初體驗(yàn)
下面我們使用Scikit-Learn提供的API制作兩個(gè)交錯(cuò)的半圓形狀數(shù)據(jù)集來(lái)演示Scikit-Learn決策樹
1)制作數(shù)據(jù)集
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets# 生成兩個(gè)交錯(cuò)的半圓形狀數(shù)據(jù)集
X, y = datasets.make_moons(noise=0.25, random_state=666)
plt.scatter(X[y == 0, 0], X[y == 0, 1])
plt.scatter(X[y == 1, 0], X[y == 1, 1])
plt.show()
2)訓(xùn)練決策樹分類模型
from sklearn.tree import DecisionTreeClassifier # 決策樹分類器# 使用CART分類樹的默認(rèn)參數(shù)
dt_clf = DecisionTreeClassifier()
# dt_clf = DecisionTreeClassifier(max_depth=2, max_leaf_nodes=4)
# 訓(xùn)練擬合
dt_clf.fit(X, y)
3)繪制決策邊界
# 繪制決策邊界
decision_boundary_fill(dt_clf, axis=[-1.5, 2.5, -1.0, 1.5])
plt.scatter(X[y == 0, 0], X[y == 0, 1])
plt.scatter(X[y == 1, 0], X[y == 1, 1])
plt.show()
其中,使用到的繪制函數(shù)詳見(jiàn)文章:傳送門
當(dāng)使用CART分類樹的默認(rèn)參數(shù)時(shí),其決策邊界如圖所示:
由圖可見(jiàn),在不加限制的情況下,一棵決策樹會(huì)生長(zhǎng)到所有的葉子都是純凈的或者或者沒(méi)有更多的特征可用為止。這樣的決策樹往往會(huì)過(guò)擬合,也就是說(shuō),它在訓(xùn)練集上表現(xiàn)的很好,而在測(cè)試集上卻表現(xiàn)的很糟糕
當(dāng)我們限制決策樹的最大深度max_depth=2
,并且最大葉子節(jié)點(diǎn)數(shù)max_leaf_nodes=4
時(shí),其決策邊界如下圖所示:
通過(guò)限制一些參數(shù),對(duì)決策樹進(jìn)行剪枝,可以讓我們的決策樹具有更好的泛化性
2.3、Scikit-Learn決策樹實(shí)踐(葡萄酒分類)
2.3.1、葡萄酒數(shù)據(jù)集
葡萄酒(Wine)數(shù)據(jù)集是來(lái)自加州大學(xué)歐文分校(UCI)的公開(kāi)數(shù)據(jù)集,這些數(shù)據(jù)是對(duì)意大利同一地區(qū)種植的葡萄酒進(jìn)行化學(xué)分析的結(jié)果。數(shù)據(jù)集共178個(gè)樣本,包括三個(gè)不同品種,每個(gè)品種的葡萄酒中含有13種成分(特征)、一個(gè)類別標(biāo)簽,分別使是0/1/2來(lái)代表葡萄酒的三個(gè)分類
數(shù)據(jù)集的屬性信息(13特征+1標(biāo)簽)如下:
from sklearn.datasets import load_winewine = load_wine()
data = pd.DataFrame(data=wine.data, columns=wine.feature_names)
data['class'] = wine.target
print(data.head().to_string())
'''alcohol malic_acid ash alcalinity_of_ash magnesium total_phenols flavanoids nonflavanoid_phenols proanthocyanins color_intensity hue od280/od315_of_diluted_wines proline class
0 14.23 1.71 2.43 15.6 127.0 2.80 3.06 0.28 2.29 5.64 1.04 3.92 1065.0 0
1 13.20 1.78 2.14 11.2 100.0 2.65 2.76 0.26 1.28 4.38 1.05 3.40 1050.0 0
2 13.16 2.36 2.67 18.6 101.0 2.80 3.24 0.30 2.81 5.68 1.03 3.17 1185.0 0
3 14.37 1.95 2.50 16.8 113.0 3.85 3.49 0.24 2.18 7.80 0.86 3.45 1480.0 0
4 13.24 2.59 2.87 21.0 118.0 2.80 2.69 0.39 1.82 4.32 1.04 2.93 735.0 0
'''
屬性/標(biāo)簽 | 說(shuō)明 |
---|---|
alcohol | 酒精含量(百分比) |
malic_acid | 蘋果酸含量(克/升) |
ash | 灰分含量(克/升) |
alcalinity_of_ash | 灰分堿度(mEq/L) |
magnesium | 鎂含量(毫克/升) |
total_phenols | 總酚含量(毫克/升) |
flavanoids | 類黃酮含量(毫克/升) |
nonflavanoid_phenols | 非黃酮酚含量(毫克/升) |
proanthocyanins | 原花青素含量(毫克/升) |
color_intensity | 顏色強(qiáng)度(單位absorbance) |
hue | 色調(diào)(在1至10之間的一個(gè)數(shù)字) |
od280/od315_of_diluted_wines | 稀釋葡萄酒樣品的光密度比值,用于測(cè)量葡萄酒中各種化合物的濃度 |
proline | 脯氨酸含量(毫克/升) |
class | 分類標(biāo)簽(class_0(59)、class_1(71)、class_2(48)) |
數(shù)據(jù)集的概要信息如下:
# 數(shù)據(jù)集大小
print(wine.data.shape) # (178, 13)
# 標(biāo)簽名稱
print(wine.target_names) # ['class_0' 'class_1' 'class_2']
# 分類標(biāo)簽
print(data.groupby('class')['class'].count())
'''
class
0 59
1 71
2 48
Name: class, dtype: int64
'''
數(shù)據(jù)集的缺失值情況:
# 缺失值:無(wú)缺失值
print(data.isnull().sum())
2.3.2、決策樹實(shí)踐(葡萄酒分類)
未完待續(xù)…