logo头像
Snippet 博客主题

Spark随机森林遍历模型

树模型的一个重要优点是可解释性好:训练完成后,遍历树的结构即可得到相应的决策规则。


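Spark将树结构抽象为Node类和Split类:Node有两个子类型LeafNode(叶子节点)和InternalNode(内部节点);Split也有两个子类型ContinuousSplit(连续特征划分)和CategoricalSplit(类别特征划分)。遍历时先判断当前节点是否为LeafNode:若是,则取出其预测值(prediction)和不纯度(impurity);若不是,则根据其split分别继续遍历左、右子树。具体的划分条件可以从Split的属性中获得:对于ContinuousSplit,特征值 <= threshold 时走左子树,否则(> threshold)走右子树;对于CategoricalSplit,leftCategories给出走左子树的类别集合,其余类别走右子树。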


训练模型

数据处理

1
2
3
4
5
6
7
import org.apache.spark.ml.feature.VectorAssembler

// Load the row id, the label and the raw feature columns.
// `table` is a placeholder for the real source table; it is assumed to contain a `label`
// column, which the RandomForestClassifier below reads via setLabelCol("label").
val input = spark.sql("select id, label, fe1, fe2, fe3 from table")
// Every column except the row id and the label is a model feature. The position of a name in
// this array is the feature index that the trained trees refer to.
val features = input.columns.filterNot(c => c == "id" || c == "label")
val feMaker = new VectorAssembler().setInputCols(features).setOutputCol("features")
val data = feMaker.transform(input)
// A fixed seed makes the 90/10 train/test split reproducible across runs.
val Array(train, test) = data.randomSplit(Array(0.9, 0.1), seed = 42L)

模型训练

1
2
3
4
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}

// Configure the forest. Spark's setters mutate the estimator in place, so each one is
// applied as a separate statement. The braces keep the definition a single spark-shell input.
val rf: RandomForestClassifier = {
  val clf = new RandomForestClassifier()
  clf.setLabelCol("label")       // column holding the class label
  clf.setFeaturesCol("features") // vector column assembled by VectorAssembler
  clf.setNumTrees(30)            // number of trees in the forest
  clf.setMaxDepth(4)             // shallow trees keep the extracted rules short
  clf
}

// Train on the training split produced during data preparation.
val model: RandomForestClassificationModel = rf.fit(train)

特征重要性

模型中记录的是特征的索引(即特征在VectorAssembler输入列中的位置),需要将其映射回特征名才便于理解。

1
2
3
4
5
// Spark identifies features by their position in the assembled vector; map each index back
// to the original column name.
val featureInfo: Map[Int, String] = features.zipWithIndex.map { case (name, idx) => idx -> name }.toMap
val featureImportance = model.featureImportances
// Convert once and reuse, rather than calling toSparse for both arrays.
val sparseImportance = featureImportance.toSparse
val indices = sparseImportance.indices
val scores = sparseImportance.values
// Sort on the numeric score (descending) and format only for display. Sorting the formatted
// strings, as before, only matches numeric order when every string has the same shape.
val importance = scores.indices.sortBy(i => -scores(i)).map { i =>
  (featureInfo(indices(i)), "%.4f".format(scores(i)))
}

获取规则

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import org.apache.spark.ml.tree._

/** One condition on a root-to-leaf path, e.g. `RuleItem("fe1", "<=", "0.5")`. */
case class RuleItem(col: String, op: String, value: String)

/** A complete rule: the conjunction of `nodes` leads to a leaf with this prediction and impurity. */
case class RuleExpr(nodes: Array[RuleItem], prediction: Double, impurity: Double)

/**
 * Enumerates every root-to-leaf path of a single decision tree as a rule.
 *
 * Continuous splits become `<=` (left child) and `>` (right child) conditions on the
 * threshold, matching `ContinuousSplit.shouldGoLeft` (`features(featureIndex) <= threshold`).
 * Categorical splits become `in` / `not in` conditions on the split's left-category set.
 *
 * @param root        root node of the tree to walk
 * @param FeatureInfo feature index -> column name, used to label each condition
 * @return one RuleExpr per leaf, in left-to-right order
 */
def visit(root: Node, FeatureInfo: Map[Int, String]): Array[RuleExpr] = {
  // Builds the (left, right) conditions for one internal node's split.
  def splitConditions(split: Split): (RuleItem, RuleItem) = {
    val col = FeatureInfo(split.featureIndex)
    split match {
      case cont: ContinuousSplit =>
        val threshold = cont.threshold.toString
        (RuleItem(col, "<=", threshold), RuleItem(col, ">", threshold))
      case cat: CategoricalSplit =>
        val categories = cat.leftCategories.mkString("{", ",", "}")
        (RuleItem(col, "in", categories), RuleItem(col, "not in", categories))
    }
  }

  // Depth-first walk; `path` holds the conditions from the root down to `node`.
  // Recursion depth equals the tree depth, which is small here (maxDepth = 4).
  def visitTree(node: Node, path: Vector[RuleItem]): Vector[RuleExpr] = node match {
    case leaf: LeafNode =>
      Vector(RuleExpr(path.toArray, leaf.prediction, leaf.impurity))
    case internal: InternalNode =>
      val (goLeft, goRight) = splitConditions(internal.split)
      visitTree(internal.leftChild, path :+ goLeft) ++
        visitTree(internal.rightChild, path :+ goRight)
  }

  visitTree(root, Vector.empty).toArray
}

// Rules of the whole forest: the concatenation of every tree's rules. This is defined after
// `visit` so it also works when pasted line by line into spark-shell, and flatMap (unlike
// reduce) does not throw on an empty forest.
val rules: Array[RuleExpr] = model.trees.flatMap(tree => visit(tree.rootNode, featureInfo))

Spark源码

Node

1
2
3
4
5
// Excerpt: abstract members declared on Spark's `Node` (the enclosing class is not shown).
// LeafNode and InternalNode both override these, so they can be read on any node.
/** Prediction a leaf node makes, or which an internal node would make if it were a leaf node */
def prediction: Double

/** Impurity measure at this node (for training data) */
def impurity: Double

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
// Companion object of Node: conversion from the legacy `OldNode` tree representation.
private[ml] object Node {

/**
* Create a new Node from the old Node format, recursively creating child nodes as needed.
* @param oldNode node in the legacy format; leaves map to LeafNode, others to InternalNode
* @param categoricalFeatures passed through unchanged to Split.fromOld for each internal node
* @return the equivalent new-API node, with its whole subtree converted
*/
def fromOld(oldNode: OldNode, categoricalFeatures: Map[Int, Int]): Node = {
if (oldNode.isLeaf) {
// TODO: Once the implementation has been moved to this API, then include sufficient
// statistics here.
// NOTE(review): the LeafNode constructor quoted further below takes only (prediction, impurity);
// this call also passes impurityStats, so the excerpts presumably come from different Spark versions.
new LeafNode(prediction = oldNode.predict.predict,
impurity = oldNode.impurity, impurityStats = null)
} else {
// Missing split statistics fall back to a gain of 0.0.
val gain = if (oldNode.stats.nonEmpty) {
oldNode.stats.get.gain
} else {
0.0
}
// Both children are converted recursively; impurityStats is not carried over (null).
new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity,
gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures),
rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures),
split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null)
}
}
}

LeafNode

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
/**
* Decision tree leaf node.
* @param prediction Prediction this node makes
* @param impurity Impurity measure at this node (for training data)
*/
final class LeafNode private[ml] (
override val prediction: Double,
override val impurity: Double) extends Node {

override def toString: String = s"LeafNode(prediction = $prediction, impurity = $impurity)"

// A leaf ignores the feature vector and always returns its stored prediction.
override private[ml] def predict(features: Vector): Double = prediction

// A leaf has no children.
override private[tree] def numDescendants: Int = 0

// Renders the leaf as one "Predict: ..." line, prefixed by `indentFactor` spaces.
override private[tree] def subtreeToString(indentFactor: Int = 0): String = {
val prefix: String = " " * indentFactor
prefix + s"Predict: $prediction\n"
}

// The depth of a subtree consisting of a single leaf is 0.
override private[tree] def subtreeDepth: Int = 0

// Converts to the legacy OldNode with isLeaf = true; the remaining optional fields are None.
override private[ml] def toOld(id: Int): OldNode = {
// NOTE: We do NOT store 'prob' in the new API currently.
new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = true,
None, None, None, None)
}
}

InternalNode

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
// Companion object of InternalNode (excerpt: only the split-printing helper is shown).
private object InternalNode {

/**
* Helper method for [[Node.subtreeToString()]].
* @param split Split to print
* @param left Indicates whether this is the part of the split going to the left, or that going to the right.
*/
// Continuous splits print as "<= threshold" (left) / "> threshold" (right); categorical splits
// print as "in {...}" / "not in {...}" using the sorted left-category set.
private def splitToString(split: Split, left: Boolean): String = {
val featureStr = s"feature ${split.featureIndex}"
split match {
case contSplit: ContinuousSplit =>
if (left) {
s"$featureStr <= ${contSplit.threshold}"
} else {
s"$featureStr > ${contSplit.threshold}"
}
case catSplit: CategoricalSplit =>
val categoriesStr = catSplit.leftCategories.mkString("{", ",", "}")
if (left) {
s"$featureStr in $categoriesStr"
} else {
s"$featureStr not in $categoriesStr"
}
}
}
}

Split

CategoricalSplit

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
/**
* Split which tests a categorical feature.
* @param featureIndex Index of the feature to test
* @param _leftCategories If the feature value is in this set of categories, then the split goes left. Otherwise, it goes right.
* @param numCategories Number of categories for this feature.
*/
final class CategoricalSplit private[ml] (
override val featureIndex: Int,
_leftCategories: Array[Double],
private val numCategories: Int)
extends Split {

// Every left category must be a valid category id in [0, numCategories).
require(_leftCategories.forall(cat => 0 <= cat && cat < numCategories), "Invalid leftCategories" +
s" (should be in range [0, $numCategories)): ${_leftCategories.mkString(",")}")

/**
* If true, then "categories" is the set of categories for splitting to the left, and vice versa.
*/
// The left set is stored directly when it has at most numCategories / 2 entries; otherwise
// its complement (the right set) is stored instead.
private val isLeft: Boolean = _leftCategories.length <= numCategories / 2

/** Set of categories determining the splitting rule, along with [[isLeft]]. */
private val categories: Set[Double] = {
if (isLeft) {
_leftCategories.toSet
} else {
setComplement(_leftCategories.toSet)
}
}

// Go left iff the feature value is a left category, whichever set is stored.
override private[ml] def shouldGoLeft(features: Vector): Boolean = {
if (isLeft) {
categories.contains(features(featureIndex))
} else {
!categories.contains(features(featureIndex))
}
}

// NOTE(review): a matching hashCode override is not visible in this excerpt.
override def equals(o: Any): Boolean = {
o match {
case other: CategoricalSplit => featureIndex == other.featureIndex &&
isLeft == other.isLeft && categories == other.categories
case _ => false
}
}

// The legacy OldSplit always carries the left categories, so the complement is taken if needed.
override private[tree] def toOld: OldSplit = {
val oldCats = if (isLeft) {
categories
} else {
setComplement(categories)
}
OldSplit(featureIndex, threshold = 0.0, OldFeatureType.Categorical, oldCats.toList)
}

/** Get sorted categories which split to the left */
def leftCategories: Array[Double] = {
val cats = if (isLeft) categories else setComplement(categories)
cats.toArray.sorted
}

/** Get sorted categories which split to the right */
def rightCategories: Array[Double] = {
val cats = if (isLeft) setComplement(categories) else categories
cats.toArray.sorted
}

/** [0, numCategories) \ cats */
private def setComplement(cats: Set[Double]): Set[Double] = {
Range(0, numCategories).map(_.toDouble).filter(cat => !cats.contains(cat)).toSet
}
}

ContinuousSplit

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
/**
* Split which tests a continuous feature.
* @param featureIndex Index of the feature to test
* @param threshold If the feature value is <= this threshold, then the split goes left. Otherwise, it goes right.
*/
final class ContinuousSplit private[ml] (override val featureIndex: Int, val threshold: Double)
extends Split {

// Left child: value <= threshold; right child: value > threshold.
override private[ml] def shouldGoLeft(features: Vector): Boolean = {
features(featureIndex) <= threshold
}

// Two continuous splits are equal when they test the same feature at the same threshold.
override def equals(o: Any): Boolean = {
o match {
case other: ContinuousSplit =>
featureIndex == other.featureIndex && threshold == other.threshold
case _ =>
false
}
}

// Legacy representation: continuous feature type with an empty category list.
override private[tree] def toOld: OldSplit = {
OldSplit(featureIndex, threshold, OldFeatureType.Continuous, List.empty[Double])
}
}

评论系统未开启,无法评论!