博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Spark快速获得CrossValidator的最佳模型参数
阅读量:5214 次
发布时间:2019-06-14

本文共 1296 字,大约阅读时间需要 4 分钟。

Spark提供了便利的Pipeline模型,可以轻松的创建自己的学习模型。

但是大部分模型都是需要提供参数的,如果不提供就是默认参数,那么怎么选择参数就是一个比较常见的问题。Spark提供在org.apache.spark.ml.tuning包下提供了模型选择器,可以替换参数然后比较模型输出。

目前有CrossValidator和TrainValidationSplit两种,比如一个文本情感预测模型。

Pipeline只有三步,第一步切词,第二步HashingTF,第三步NB分类

Pipeline pipeline = new Pipeline()                .setStages(new PipelineStage[]{tokenizer, hashingTF, naiveBayes});ParamMap[] paramMaps = new ParamGridBuilder()                .addGrid(hashingTF.numFeatures(), new int[]{
10000, 100000, 500000, 1000000}) .build();CrossValidator cv = new CrossValidator() .setEstimator(pipeline) .setEvaluator(new BinaryClassificationEvaluator()) .setEstimatorParamMaps(paramMaps);

其中HashingTF的参数选择非常重要,我们这里就随便尝试几种,然后放在CrossValidator中去。

最后我们会获得一个CrossValidatorModel类,这里有两种选择。

第一种是自己手动获取其中的参数,因为bestModel的参数就是我们最后选择的参数

Pipeline bestPipeline = (Pipeline) model.bestModel().parent();PipelineStage stage = bestPipeline.getStages()[1];stage.extractParamMap().get(stage.getParam("numFeatures"));

这种方法可以获得值,但是需要根据你模型情况修改获取的位置。

如果你只是想知道最佳参数是多少,并不是需要在上下文中使用,那还有一个更简单的方法。

修改log4j的配置,添加

log4j.logger.org.apache.spark.ml.tuning.TrainValidationSplit=INFOlog4j.logger.org.apache.spark.ml.tuning.CrossValidator=INFO

效果如下:

 

转载于:https://www.cnblogs.com/itboys/p/9827567.html

你可能感兴趣的文章
知识不是来炫耀的,而是来分享的-----现在的人们却…似乎开始变味了…
查看>>
CSS背景颜色、背景图片、平铺、定位、固定
查看>>
口胡:[HNOI2011]数学作业
查看>>
我的第一个python web开发框架(29)——定制ORM(五)
查看>>
Combination Sum III -- leetcode
查看>>
中国剩余定理
查看>>
基础笔记一
查看>>
uva 10137 The trip
查看>>
spring 解决中文乱码问题
查看>>
hdu 4268
查看>>
启动tomcat时cmd窗口一闪而过
查看>>
两个有序数列,求中间值 Median of Two Sorted Arrays
查看>>
vue路由的实现原理
查看>>
Java核心技术:Java异常处理
查看>>
Python 学习笔记一
查看>>
引入列表,将对话分类添加到对应列表中
查看>>
回文子串
查看>>
Count Numbers
查看>>
React——JSX
查看>>
编写高质量代码改善C#程序的157个建议——建议110:用类来代替enum
查看>>