SparkMlib 之决策树及其案例
创始人
2024-02-29 13:31:11
0

文章目录

    • 什么是决策树?
    • 决策树的优缺点
    • 决策树示例——鸢尾花分类

什么是决策树?

决策树及其集成是分类和回归机器学习任务的流行方法。决策树被广泛使用,因为它们易于解释,处理分类特征,扩展到多类分类设置,不需要特征缩放,并且能够捕获非线性和特征相互作用。随机森林和增强算法等树集成算法在分类和回归任务中表现最佳。

常应用于以下类型的场景:

  1. 预测用户贷款是否能够按时还款;
  2. 预测邮件是否是垃圾邮件;
  3. 预测用户是否会购买某件商品等等

官网:分类和回归

决策树的优缺点

优点:

  1. 决策树算法易理解,机理解释起来简单。

  2. 决策树算法可以用于小数据集。

  3. 决策树算法的时间复杂度较小,为用于训练决策树的数据点的对数。

  4. 相比于其他算法智能分析一种类型变量,决策树算法可处理数字和数据的类别。

  5. 能够处理多输出的问题。

  6. 对缺失值不敏感。

  7. 可以处理不相关特征数据。

  8. 效率高,决策树只需要一次构建,反复使用,每一次预测的最大计算次数不超过决策树的深度。

缺点:

  1. 对连续性的字段比较难预测。

  2. 容易出现过拟合。

  3. 当类别太多时,错误可能就会增加的比较快。

  4. 在处理特征关联性比较强的数据时表现得不是太好。

  5. 对于各类别样本数量不一致的数据,在决策树当中,信息增益的结果偏向于那些具有更多数值的特征。

参考博客:决策树算法优缺点

决策树示例——鸢尾花分类

数据集下载:

链接:
https://pan.baidu.com/s/1AshgNxx1wOWhLgKxgjrZww?pwd=lz3l 提取码:
lz3l

数据集介绍:

iris.data 数据集中共有五个字段,逗号分隔,前四个为特征字段,最后一个为标签字段。

标签字段列一共有三种值,分别是:Iris-setosaIris-versicolorIris-virginica

将数据集中的随机百分之70作为训练集,剩余的作为测试集。

需求实现:

import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SparkSession}object Iris {// TODO 鸢尾花种类判断def main(args: Array[String]): Unit = {val sc: SparkSession = SparkSession.builder().appName("Iris").master("local[*]").getOrCreate()// 1.加载鸢尾花数据val train_data: RDD[String] = sc.read.textFile("iris.data").rdd// 2.将随机百分之70的数据设置为训练集,其余为测试集val data: Array[RDD[String]] = train_data.randomSplit(Array(0.7, 0.3))// 3.向量转换import sc.implicits._val trainDF: DataFrame = data(0).map(lines => {val arr: Array[String] = lines.split(",")LabeledPoint(if (arr(4).equals("Iris-setosa")) {1D} else if (arr(4).equals("Iris-versicolor")) {2D} else {3D},Vectors.dense(arr.take(4).map(_.toDouble)))}).toDF("label", "features")// 4.创建决策树对象val classifier = new DecisionTreeClassifier()// 设置最大深度、分支、质量、特征列classifier.setMaxDepth(5).setMaxBins(32).setImpurity("gini").setFeaturesCol("features")// 5.训练模型val model: DecisionTreeClassificationModel = classifier.fit(trainDF)// 打印模型println(model.toDebugString)// 6.将测试集转换成向量val testDF: DataFrame = data(1).map(lines => {val arr: Array[String] = lines.split(",")LabeledPoint(if (arr(4).equals("Iris-setosa")) {1D} else if (arr(4).equals("Iris-versicolor")) {2D} else {3D},Vectors.dense(arr.take(4).map(_.toDouble)))}).toDF("label", "features")// 7.模型预测val result: DataFrame = model.transform(testDF.select("label", "features"))// 8.模型预测评估result.select("label", "features","prediction").show(100)// 9.计算错误率val error: Double = result.where("label = prediction").count.toDouble/result.countprintln("错误率为:"+(1-error))}}

相关内容

热门资讯

监控摄像头接入GB28181平... 流程简介将监控摄像头的视频在网站和APP中直播,要解决的几个问题是:1&...
Windows10添加群晖磁盘... 在使用群晖NAS时,我们需要通过本地映射的方式把NAS映射成本地的一块磁盘使用。 通过...
protocol buffer... 目录 目录 什么是protocol buffer 1.protobuf 1.1安装  1.2使用...
Fluent中创建监测点 1 概述某些仿真问题,需要创建监测点,用于获取空间定点的数据࿰...
educoder数据结构与算法...                                                   ...
MySQL下载和安装(Wind... 前言:刚换了一台电脑,里面所有东西都需要重新配置,习惯了所...
MFC文件操作  MFC提供了一个文件操作的基类CFile,这个类提供了一个没有缓存的二进制格式的磁盘...
在Word、WPS中插入AxM... 引言 我最近需要写一些文章,在排版时发现AxMath插入的公式竟然会导致行间距异常&#...
有效的括号 一、题目 给定一个只包括 '(',')','{','}'...
【Ctfer训练计划】——(三... 作者名:Demo不是emo  主页面链接:主页传送门 创作初心ÿ...