决策树是一种比较常用的分类算法,所谓决策树分类就是用决策条件构成的一个树状预测模型,通过这个模型,我们可以对未知类别的数据进行分类。决策树是一种树形结构,其中每个内部节点表示一个属性上的测试,每个分支代表一个测试输出,每个叶节点代表一种类别。
决策树的构造过程不依赖领域知识,它使用属性选择度量来选择将元组最好地划分成不同的类的属性。所谓决策树的构造就是进行属性选择度量确定各个特征属性之间的拓扑结构。
构造决策树的关键步骤是分裂属性。所谓分裂属性就是在某个节点处按照某一特征属性的不同划分构造不同的分支,其目标是让各个分裂子集尽可能地“纯”。尽可能“纯”就是尽量让一个分裂子集中待分类项属于同一类别。分裂属性分为三种不同的情况:
1. 属性是离散值且不要求生成二叉决策树。此时用属性的每一个划分作为一个分支。
2. 属性是离散值且要求生成二叉决策树。此时使用属性划分的一个子集进行测试,按照“属于此子集”和“不属于此子集”分成两个分支。
3. 属性是连续值。此时确定一个值作为分裂点split point,按照>split point和<=split point生成两个分支。
构造决策树的关键性内容是进行属性选择度量,属性选择度量是一种选择分裂准则,是将给定了类标记的训练集合划分,“最好”地分成个体类的启发式方法,它决定了拓扑结构及分裂点split point的选择。
属性选择度量算法有很多,一般使用自顶向下递归分治法,并采用不回溯的贪心策略,常用的算法有ID3和C4.5。
在实际构造决策树时,通常要进行剪枝,这是为了处理由于数据中的噪声和离群点导致的过分拟合问题。剪枝有两种:
1. 先剪枝——在构造过程中,当某个节点满足剪枝条件,则直接停止此分支的构造。
2. 后剪枝——先构造完成完整的决策树,再通过某些条件遍历树进行剪枝。
因为在实际的训练中,训练的结果对于训练集的拟合程度通常还是挺好的(初试条件敏感),但是对于训练集之外的数据的拟合程度通常就不那么令人满意了。因此我们通常并不会把所有的数据集都拿来训练,而是分出一部分来(这一部分不参加训练)对训练集生成的参数进行测试,相对客观的判断这些参数对训练集之外的数据的符合程度。这种思想就称为交叉验证。
train_test_split函数
train_test_split来自sklearn.model_selection,是交叉验证中常用的函数,它能从样本中按比例随机选取训练集和测试集。其用法如下:
X_train, X_test, y_train, y_test = cross_validation.train_test_split(train_data, train_target, test_size=0.25, random_state=None)
参数解释:
1. train_data: 所要划分的样本特征集。
2. train_target: 所要划分的样本结果。
3. test_size: 样本占比,如果是整数的话就是样本的数量。
4. random_state: 是随机数的种子。
tree.DecisionTreeClassifier函数
DecisionTreeClassifier函数用于创建决策树分类器。其用法如下:
clf = tree.DecisionTreeClassifier()
常用参数解释:
1. criterion: string类型,可选(默认为"gini")。指定使用哪种方法衡量分类的质量。支持的标准有"gini"代表的是Gini impurity(不纯度)与"entropy"代表的是information gain(信息增益)。
2. splitter: string类型,可选(默认为"best")。指定在节点中选择分类的策略。支持的策略有"best",选择最好的分类,"random"选择最好的随机分类。
3. max_depth: int or None,可选(默认为"None")。表示树的最大深度。
4. min_samples_split: int,float,可选(默认为2)。一个内部节点需要的最少的样本数。
5. max_features: int,float,string or None类型,可选(默认为None)。在进行分类时需要考虑的特征数。
6. random_state: 可为int类型,RandomState 实例或None,可选(默认为"None")。如果是int,random_state
是随机数字发生器的种子;如果是RandomState,random_state是随机数字发生器,如果是None,随机数字发生器是np.random使用的RandomState instance.
iris数据集是一个经典的用于多类分类的数据集。通过sklearn.datasets.load_iris()函数可导入。sklearn中的iris数据集有5个key,分别如下:
1. target_names: 类别名称,分别为setosa、versicolor和virginica;
2. data:特征集,5列,150行;
3. target: 样本类别值;
4. DESCR:关于数据的描述信息;
5. feature_names:特征名称,分别为sepal length (cm),sepal width (cm),petal length (cm)和petal width (cm)。
iris.data前五条数据如下:
>>> from sklearn.datasets import load_iris
>>> dataset = load_iris()
>>> dataset.data[0:5]
array([[ 5.1, 3.5, 1.4, 0.2],
[ 4.9, 3. , 1.4, 0.2],
[ 4.7, 3.2, 1.3, 0.2],
[ 4.6, 3.1, 1.5, 0.2],
[ 5. , 3.6, 1.4, 0.2]])
iris.target前五条数据如下:
>>> dataset.target[0:5]
array([0, 0, 0, 0, 0])