Notes of Machine Learning

CART具体实现

CART可用于分类和回归。

对数据只做二元切分。

回归树与分类树的思想类似,但叶节点的数据类型不是离散的,而是连续的。

1. 存储树的数据结构:

class treeNode():
    def __init__(self, feat, val, right, left):
        featureToSplitOn = feat     #待切分的特征
        valueOfSplit = val            #切分值
        rightBranch = right            #右子树。当不再需要切分时,也可以是单个值
        leftBranch = left            #左子树。同右子树

2. 回归树的构建:

函数createTree()的伪代码大致如下:

找到最佳的待切分特征:
    如果该节点不能再分,将该节点存为叶节点
    执行二元切分
    在右子树调用createTree()方法
    在左子树调用createTree()方法

3. 找到最佳切分(待切分特征和切分点):

函数chooseBestSplit()的伪代码大致如下:

对每一维特征:
    对每个特征值:
        将数据集切分成两份
        计算切分误差
        如果当前误差小于当前最小误差,那么更新最佳切分
返回最佳切分(特征和阈值)

4. 树剪枝(pruning)

通过降低决策树的复杂度来避免过拟合。

4.1 预剪枝(prepruning)

通过设定生长停止条件

  • 树的深度
  • 误差的容忍度tol
  • 节点中样本的个数少于用户指定的个数
  • 节点达到完全纯性

4.2 后剪枝(postpruning)

后剪枝首先允许决策树充分生长,然后自下而上逐层剪枝。

需要将数据集分成训练集和测试集,用测试集来判断两个叶节点合并后是否能降低测试误差,如果能的话,就合并。

函数pruning()的伪代码如下:

如果右子树是一棵树,则在右子树递归剪枝过程。
如果左子树是一棵树,则在左子树递归剪枝过程。
如果左右子树都是单个值,则计算合并后的误差和不合并的误差:
    如果合并会降低误差,就将叶节点合并。

5. 模型树

回归树的叶节点是单个值,这是对数据最朴素的划分。如果将其替换成分段线性函数,就得到了模型树。

线性模型的权值直接求解析解