模式识别合集-决策树(附ID3算法matlab代码)
模式识别合集-决策树(附ID3算法matlab代码)
- 决策树
- 基本方法
- ID3方法(交互式二分法 Interactive Dichotomizer-3)
- C4.5算法
- 过学习与决策树的剪枝
- matlab代码
- 测试结果
决策树
- 利用一定的训练样本,从数据中“学习”出决策规则,自动构造出决策树
基本方法
ID3方法(交互式二分法 Interactive Dichotomizer-3)
- 适用于每个节点下划分多个子节点的情况
- 通过选择有 辨别力的特征 对数据进行划分,直到每个叶节点上只包含单一类型的数据为止
- 基础:信息论中的熵(entropy)
- 一个事件有 k k k种可能结果,每个结果对应概率为 P i P_i Pi,则对事件结果观察后得到的信息量定义:
I = − ( P 1 l o g 2 P 1 + P 2 l o g 2 P 2 + . . . + P k l o g 2 P k ) = − ∑ i = 1 k P i l o g 2 P i I=-(P_1log_2P_1+P_2log_2P_2+...+P_klog_2P_k)=-\sum_{i=1}^{k}P_ilog_2P_i I=−(P1log2P1+P2log2P2+...+Pklog2Pk)=−i=1∑kPilog2Pi - 对某个节点上的样本,称为 熵不纯度 ,反映节点上特征对样本分类的 不纯度
- 希望引入特征后,熵不纯度能尽量减少
- 总的熵不纯度是所有样本计算的不纯度按照样本比例加权求和
- 如果特征把 N N N个样本划分为 m m m组,每组 N m N_m Nm个样本,则 不纯度减少量 的计算公式:
Δ I ( N ) = I ( N ) − ( P 1 I ( N 1 ) + P 2 I ( N 2 ) + . . . P m ( N m ) ) \Delta I(N)=I(N)-(P_1I(N_1)+P_2I(N_2)+...P_m(N_m)) ΔI(N)=I(N)−(P1I(N1)+P2I(N2)+...Pm(Nm))
- 一个事件有 k k k种可能结果,每个结果对应概率为 P i P_i Pi,则对事件结果观察后得到的信息量定义:
- 算法具体流程
- 计算当前节点包含的所有样本的熵不纯度
- 比较采用不同特征值进行分枝将得到的信息增益(不纯度减少量)
- 选取具有最大信息增益的特征赋予当前节点
- 构造决策树的下一层节点,需要分别考察两组样本上采用不同特征所得到的不纯度减少,采用较大不纯度减少量的特征构建下一层
- 若后续节点只包含一类样本,停止该枝生长;若还包含不同类样本,再重复以上步骤
- 除了香农熵,还有其他度量,比如
- Gini不纯度度量(方差不纯度)
- 误差不纯度
C4.5算法
- 采用信息增益率代替信息增益
Δ I R ( N ) = I ( N ) I ( N ) \Delta I_R(N)=\frac{I(N)}{I(N)} ΔIR(N)=I(N)I(N)
过学习与决策树的剪枝
- 推广性问题:准确率
- 过学习:测试数据或新数据的表现与训练数据差别很大
- 主要手段
- 控制决策树生成算法的终止条件
- 对决策树进行剪枝
- 主要手段
- 先剪枝
- 控制决策树的生长,生长过程中决定某节点是否需要继续分支还是直接作为叶节点
- 判断决策树何时停止的方法
- 数据划分法:数据划分为训练样本和测试样本,在训练集上生长决策树,在测试集上分类错误率达到最小时停止生长
- 阈值法:设定一个信息增益阈值,当信息增益小于设定阈值时停止树的生长
- 统计显著性分析:对已有节点获得的所有信息增益统计其分布,若继续生长得到的信息增益与该分布相比不显著则停止树的生长,通常用卡方检验
matlab代码
使用西瓜数据集进行ID3决策树分类
- decisionTree.m
clear;% 西瓜数据集
data=["青绿","蜷缩","浊响","清晰","凹陷","硬滑","是";"乌黑","蜷缩","沉闷","清晰","凹陷","硬滑","是";"乌黑","蜷缩","浊响","清晰","凹陷","硬滑","是";"青绿","蜷缩","沉闷","清晰","凹陷","硬滑","是";"浅白","蜷缩","浊响","清晰","凹陷","硬滑","是";"青绿","稍蜷","浊响","清晰","稍凹","软粘","是";"乌黑","稍蜷","浊响","稍糊","稍凹","软粘","是";"乌黑","稍蜷","浊响","清晰","稍凹","硬滑","是";"乌黑","稍蜷","沉闷","稍糊","稍凹","硬滑","否";"青绿","硬挺","清脆","清晰","平坦","软粘","否";"浅白","硬挺","清脆","模糊","平坦","硬滑","否";"浅白","蜷缩","浊响","模糊","平坦","软粘","否";"青绿","稍蜷","浊响","稍糊","凹陷","硬滑","否";"浅白","稍蜷","沉闷","稍糊","凹陷","硬滑","否";"乌黑","稍蜷","浊响","清晰","稍凹","软粘","否";"浅白","蜷缩","浊响","模糊","平坦","硬滑","否";"青绿","蜷缩","沉闷","稍糊","稍凹","硬滑","否"];label = ["色泽","根蒂","敲声","纹理","脐部","触感","好瓜"];% 参数预定义
datasetRate = 1;
dataSize = size(data);% 数据预处理
index = randperm(17,round(datasetRate*(dataSize(1,1)-1)));
trainSet = data(index,:);
testSet = data;
testSet(index,:) = [];% 所有标签
deepth = ones(1,dataSize(1,2)-1);
% 生成树
rootNode = makeTree(label,trainSet,deepth,'null');
% 画出决策树
drawTree(rootNode);
- calculateImpurity.m
% 计算熵不纯度
function res = calculateImpurity(examples_)P1 = 0;P2 = 0;[m_,n_] = size(examples_);P1 = sum(examples_(:,n_) == '是');P2 = sum(examples_(:,n_) == '否');P1 = P1 / m_;P2 = P2 / m_;if P1 == 1 || P1 == 0res = 0;elseres = -(P1*log2(P1)+P2*log2(P2));end
end
- getBestlabel.m
% 决策过程 获取信息增量最大的分类标准
function label = getBestlabel(impurity_,features_,samples_)% impurity_:划分前的熵不纯度% features_:当前可供分类的标签 是01矩阵% samples_:当前需要分类的样本[m,n]=size(samples_);delta_impurity = zeros(1,n-1);% 遍历每个特征 每个特征把m个样本分为t组 每组m_t个样本 计算每个特征的不纯度减少量delta_impurity(i)% 输入样本为m行n列矩阵 特征总数量为n-1for i = 1:n-1% 存放分类结果count = 1;grouping_res = strings;sample_nums = [];grouped_impurity = [];% 分类结果按分组计算熵不纯度grouped_P = [];% 如果features_(i)为1 说明该分支上该标签还未用于分类if features_(i) == 1% 分组for j = 1:mpos = grouping_res == samples_(j,i);if sum(pos)% 分类样本 计算同一标签类别的样本数量sample_nums(pos) = sample_nums(pos) + 1;else % 将标签的类别添加到统计结果sample_nums = [sample_nums 1];grouping_res(count) = samples_(j,i);count = count + 1;endend% 计算该分类结果的不纯度减少量% 按分组计算熵不纯度for k = grouping_ressub_sample = samples_(samples_(:,i)==k,:);grouped_impurity = [grouped_impurity calculateImpurity(sub_sample)];grouped_P = [grouped_P sum(sub_sample(:,n)=='是')/sum(samples_(:,i)==k)];enddelta_impurity(i) = impurity_ - sum(grouped_P.*grouped_impurity);endend% 返回的label是索引数组temp = delta_impurity==max(delta_impurity);% 如果存在多个结果一样的标签 则使用第一个label = find(temp,1);
end
- makeTree.m
% 生成决策树
function node = makeTree(features,examples,deepth,branch)% feature:样本分类依据的所有标签% examples:样本% deepth:树的深度,每被分类一次与分类标签对应的值置零% value:分类结果,若为null则表示该节点是分支节点% label:节点划分标签% branch:分支值% children:子节点node = struct('value','null','label',[],'branch',branch,'children',[]);[m,n] = size(examples);sample = examples(1,n);check_res = true;for i = 1:mif sample ~= examples(i,n)check_res = false;endend% 若样本中全为同一分类结果 则作为叶节点if check_resnode.value = examples(1,n);return;end% 计算熵不纯度impurity = calculateImpurity(examples);% 选择合适的标签bestLabel = getBestlabel(impurity,deepth,examples);deepth(bestLabel) = 0;node.label = features(bestLabel);% 分类grouping_res = strings;count = 1;for i = 1:mpos = grouping_res == examples(i,bestLabel);if sum(pos)% 分类样本 计算同一标签类别的样本数量else % 将标签的类别添加到统计结果grouping_res(count) = examples(i,bestLabel);count = count + 1;endendfor k = grouping_ressub_sample = examples(examples(:,bestLabel)==k,:);node.children = [node.children makeTree(features,sub_sample,deepth,k)];endend
- drawTree.m
% 画出决策树
function [] = drawTree(node)% 遍历树nodeVec = [];nodeSpec = [];edgeSpec = [];[nodeVec,nodeSpec,edgeSpec,total] = travesing(node,0,0,nodeVec,nodeSpec,edgeSpec);treeplot(nodeVec);[x,y] = treelayout(nodeVec);[m,n] = size(nodeVec);x = x';y = y';text(x(:,1),y(:,1),nodeSpec,'VerticalAlignment','bottom','HorizontalAlignment','right');x_branch = [];y_branch = [];for i = 2:nx_branch = [x_branch; (x(i,1)+x(nodeVec(i),1))/2];y_branch = [y_branch; (y(i,1)+y(nodeVec(i),1))/2];endtext(x_branch(:,1),y_branch(:,1),edgeSpec(1,2:n),'VerticalAlignment','bottom','HorizontalAlignment','right');
end% 遍历树
function [nodeVec,nodeSpec,edgeSpec,current_count] = travesing(node,current_count,last_node,nodeVec,nodeSpec,edgeSpec)nodeVec = [nodeVec last_node];if node.value == 'null'nodeSpec = [nodeSpec node.label];elseif node.value == '是'nodeSpec = [nodeSpec '好瓜'];elsenodeSpec = [nodeSpec '坏瓜'];endendedgeSpec = [edgeSpec node.branch];current_count = current_count + 1;current_node = current_count;if node.value ~= 'null'return;endfor next_ndoe = node.children[nodeVec,nodeSpec,edgeSpec,current_count] = travesing(next_ndoe,current_count,current_node,nodeVec,nodeSpec,edgeSpec);end
end
测试结果

本文来自互联网用户投稿,文章观点仅代表作者本人,不代表本站立场,不承担相关法律责任。如若转载,请注明出处。 如若内容造成侵权/违法违规/事实不符,请点击【内容举报】进行投诉反馈!
