Spark rewrite: cardiovascular disease prediction
Python version of this project: https://www.kesci.com/home/project/5da974e9c83fb400420f77d3
package dataclear

/**
  * @CreateUser: eshter
  * @CreateDate: 2019/10/23
  * @UpdateUser:
  */
import utils.session.IgnoreErrorAndINFO
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{StandardScaler, VectorAssembler, _}
import utils.metrics.Metrics
import org.apache.spark.ml.Pipeline
object cardioTrainLr {
  /*
   * Notes:
   * 1. The label column is "cardio".
   * 2. StandardScaler only accepts vector-typed input (org.apache.spark.ml.linalg.Vector).
   * 3. Continuous features: Array("age", "height", "weight", "ap_hi", "ap_lo").
   * 4. Categorical features: Array("gender", "cholesterol", "gluc", "smoke", "alco").
   */
  new IgnoreErrorAndINFO().ignoreErrorAndInfo()

  // Random train/validation split with a fixed seed.
  def splitData(df: DataFrame, splitRate: Double) = {
    val dfTmp = df.randomSplit(Array(splitRate, 1 - splitRate), seed = 2)
    List(dfTmp(0), dfTmp(1))
  }

  // Index the categorical columns, standardize the continuous ones, and assemble
  // everything into a single "features" vector.
  // (featureCols is accepted but not used; the column lists are hard-coded below.)
  def featureHandleTest(dfTrain: DataFrame, dfValid: DataFrame, featureCols: Array[String]) = {
    val scale_col = Array("age", "height", "weight", "ap_hi", "ap_lo")
    val onehot_col = Array("gender", "cholesterol", "gluc", "smoke", "alco")
    val onehot_colToInt = onehot_col.map(col => col + "ToInt")
    val standardIndex = onehot_col.map { line =>
      new StringIndexer().setInputCol(line).setOutputCol(line + "ToInt")
    }
    val vectorScale = new VectorAssembler().setInputCols(scale_col).setOutputCol("feaScale")
    val scale = new StandardScaler().setInputCol("feaScale").setOutputCol("sfea")
    val pipeline = new Pipeline().setStages(Array(vectorScale, scale))
    val model = pipeline.fit(dfTrain)
    val scaledfTrain = model.transform(dfTrain)
    val scaleDfTest = model.transform(dfValid)
    val vectorAssembler = new VectorAssembler()
      .setInputCols(onehot_colToInt ++ Array("sfea"))
      .setOutputCol("features")
    val pipelineFinal = new Pipeline().setStages(standardIndex ++ Array(vectorAssembler))
    val modelFinal = pipelineFinal.fit(scaledfTrain)
    val scaledfTrain1 = modelFinal.transform(scaledfTrain)
    val scaleDfTest1 = modelFinal.transform(scaleDfTest)
    List(scaledfTrain1, scaleDfTest1)
  }

  // Train a logistic regression model and report metrics on both splits.
  def modelTrainLr(dfTrain: DataFrame, dfValid: DataFrame, featureCol: String, label: String): Unit = {
    val lr = new LogisticRegression().setLabelCol(label).setFeaturesCol(featureCol).setMaxIter(50)
    val LrModel = lr.fit(dfTrain)
    val predTrain = LrModel.transform(dfTrain)
    val mer = new Metrics(predTrain, label, "prediction")
    mer.metricFunc()
    // Score the validation split.
    val predTest = LrModel.transform(dfValid)
    val mert = new Metrics(predTest, label, "prediction")
    mert.metricFunc()
  }

  def main(args: Array[String]): Unit = {
    // cardio_train.csv
    val spark = SparkSession.builder().master("local[2]").appName("cardio_train").getOrCreate()
    var src_train = spark.read.format("csv")
      .option("header", true)
      .option("inferSchema", true)
      //.option("multiLine", true)
      .option("delimiter", ";")
      .load("/Users/eshter/Desktop/cc_data/cardio_train.csv")
    src_train.show(2)
    val label_col = "cardio"
    // Drop the id column.
    src_train = src_train.drop("id")
    // Print summary statistics of the columns.
    src_train.summary().show()
    val splitRate = 0.85
    val df = splitData(src_train, splitRate)
    val dfTrain = df(0)
    val dfValid = df(1)
    // println(dfTrain.show(100))
    println(dfTrain.stat.corr("gender", label_col, "pearson").toString)
    // Keep only columns whose Pearson correlation with the label exceeds 0.1 (excluding the label itself).
    val featCols = dfTrain.columns
      .filter(dfTrain.stat.corr(_, label_col, "pearson").abs > 0.1)
      .filter(_ != "cardio")
    val scaleData = featureHandleTest(dfTrain, dfValid, featCols)
    modelTrainLr(scaleData(0), scaleData(1), "features", label_col)
  }
}
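The two helper classes imported above, utils.session.IgnoreErrorAndINFO and utils.metrics.Metrics, are project-local utilities and are not shown in the post. Below is a minimal sketch of what they might look like, assuming IgnoreErrorAndINFO simply raises the log4j level and Metrics wraps Spark's built-in evaluators; only the class and method names come from the code above, everything else is an assumption.

// Sketch of utils/session/IgnoreErrorAndINFO.scala (assumed implementation, not from the original post).
package utils.session

import org.apache.log4j.{Level, Logger}

class IgnoreErrorAndINFO {
  // Silence Spark's INFO/WARN chatter so the metric output stays readable.
  def ignoreErrorAndInfo(): Unit = {
    Logger.getLogger("org").setLevel(Level.ERROR)
    Logger.getLogger("akka").setLevel(Level.ERROR)
  }
}

// Sketch of utils/metrics/Metrics.scala (assumed implementation, not from the original post).
package utils.metrics

import org.apache.spark.sql.DataFrame
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, MulticlassClassificationEvaluator}

class Metrics(pred: DataFrame, labelCol: String, predictionCol: String) {
  // Prints accuracy and AUC for a prediction DataFrame; the real class may report more metrics.
  def metricFunc(): Unit = {
    val accuracy = new MulticlassClassificationEvaluator()
      .setLabelCol(labelCol)
      .setPredictionCol(predictionCol)
      .setMetricName("accuracy")
      .evaluate(pred)
    val auc = new BinaryClassificationEvaluator()
      .setLabelCol(labelCol)
      .setRawPredictionCol("rawPrediction")
      .setMetricName("areaUnderROC")
      .evaluate(pred)
    println(s"accuracy = $accuracy, AUC = $auc")
  }
}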
