1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
| import org.apache.spark.ml.tree._
/**
 * A single condition on a root-to-leaf path, read as `col op value`.
 *
 * @param col   feature (column) name the condition tests
 * @param op    comparison operator text
 * @param value right-hand side of the comparison, already rendered as text
 */
case class RuleItem(
    col: String,
    op: String,
    value: String
)
/**
 * One extracted rule: the conjunction of `nodes` leads to a leaf whose
 * prediction and impurity are recorded alongside it.
 *
 * Note: `nodes` is an `Array`, so the generated `equals`/`hashCode` compare it
 * by reference, not element-wise.
 *
 * @param nodes      conditions along the path
 * @param prediction leaf prediction
 * @param impurity   leaf impurity
 */
case class RuleExpr(
    nodes: Array[RuleItem],
    prediction: Double,
    impurity: Double
)
// Flatten the root-to-leaf rules of every tree in the ensemble into one array.
// `model` is presumably a tree-ensemble model exposing `trees` (e.g. a
// RandomForestClassificationModel) and is defined elsewhere -- TODO confirm.
// `flatMap` replaces `map(...).reduce(_ ++ _)`: reduce threw on an empty
// `model.trees`, and repeated `++` re-copied the accumulated array per tree.
// NOTE(review): `visit` is defined after this line; in a line-by-line
// spark-shell session define it first (or paste the snippet with :paste).
val rules: Array[RuleExpr] = model.trees.flatMap(tree => visit(tree.rootNode, featureInfo))
/**
 * Enumerate every root-to-leaf path of a Spark ML decision tree as a rule.
 *
 * Each returned [[RuleExpr]] holds the split conditions along one path, in
 * root-to-leaf order, plus the leaf's prediction and impurity. The emitted
 * operators mirror how each split routes a row:
 *   - continuous split: "<=" threshold on the left, ">" threshold on the right;
 *   - categorical split: "in" the left or right category set, rendered as "{c1,c2}".
 *
 * @param root        root node of the tree (e.g. a tree model's `rootNode`)
 * @param FeatureInfo feature index -> column name used in the rule text; an index
 *                    missing from the map is rendered as `feature_<index>`
 * @return one rule per leaf, in depth-first, left-before-right order
 */
def visit(root: Node, FeatureInfo: Map[Int, String]): Array[RuleExpr] = {
  // Resolve names through this parameter rather than any outer `featureInfo`.
  def featureName(index: Int): String = FeatureInfo.getOrElse(index, s"feature_$index")

  def categories(values: Array[Double]): String = values.mkString("{", ",", "}")

  // Rules for every leaf under `node`; `path` holds the conditions already
  // taken on the way from the root down to `node`.
  def visitTree(node: Node, path: Vector[RuleItem]): Vector[RuleExpr] = node match {
    case leaf: LeafNode =>
      Vector(RuleExpr(path.toArray, leaf.prediction, leaf.impurity))
    case inner: InternalNode =>
      val (goLeft, goRight) = inner.split match {
        case s: ContinuousSplit =>
          val name = featureName(s.featureIndex)
          val threshold = s.threshold.toString
          // Spark routes `value <= threshold` left, so the right branch is the
          // strict complement `>`, not `>=`.
          (RuleItem(name, "<=", threshold), RuleItem(name, ">", threshold))
        case s: CategoricalSplit =>
          // Must be handled explicitly; otherwise every leaf below a
          // categorical split would be missing from the output.
          val name = featureName(s.featureIndex)
          (RuleItem(name, "in", categories(s.leftCategories)),
           RuleItem(name, "in", categories(s.rightCategories)))
      }
      visitTree(inner.leftChild, path :+ goLeft) ++
        visitTree(inner.rightChild, path :+ goRight)
  }

  visitTree(root, Vector.empty).toArray
}
|
评论系统未开启,无法评论!