Spark Rを使ってみたら、どんな問題がでてきたか
こんにちは、次世代システム研究室のA.Zです。
現在、広告解析のプロジェクトで、SparkRを利用しています。SparkRを利用することにあたって、発見した問題または制限と解決方法を紹介したいと思います。
はじめに
前回のブログでは簡単に紹介したSparkRについて、本格的にプロジェクトで利用することになりました。実際のプロジェクトに利用するときに出てきた問題や、SparkRの制約などの解決方法を簡単に紹介したいと思います。初めてSparkRを触った方や、実際のプロジェクトにSparkRの導入を検討している方に役に立つと思います。
発見した問題また制限
サポートしているMLlibのアルゴリズムが少ない
SparkRが現在、標準でサポートしている機械学習アルゴリズムはglmのみです。それ以外のMLlibアルゴリズムを使いたいときはカスタマイズが必要です。SparkRは、Javaまたはscalaのメソッドを呼び出すことができます。標準でサポートしていないアルゴリズムを使いたいときはWrapper classを作成することで、SparkRからMLlibのアルゴリズムを使うことができます。Wrapper Classの例:
package com.foo.bar; public class Wrappers { public static PipelineModel trainRandomForest(DataFrame dataFrame, String label, String[] features, int ntrees, int maxdepth) { // change label to string indexer StringIndexerModel labelIndexer = new StringIndexer() .setInputCol(label) .setOutputCol("label") .fit(dataFrame); VectorAssembler featureAssembler = new VectorAssembler() .setInputCols(features.toArray(new String[0])) .setOutputCol("features"); // RandomForest model. RandomForestClassifier rfc = new RandomForestClassifier() .setLabelCol(labelIndexer.getOutputCol()) .setFeaturesCol(featureAssembler.getOutputCol()) .setMaxDepth(maxdepth) .setNumTrees(ntrees); List<PipelineStage> pipelineStages = new ArrayList(); pipelineStages.add(labelIndexer); pipelineStages.add(featureAssembler); pipelineStages.add(rfc); Pipeline pipeline = new Pipeline() .setStages(pipelineStages.toArray(new PipelineStage[0])); PipelineModel model = pipeline.fit(dataFrame); return model; }wrapper classをjarに固め、SparkR init するときに読み込みます。
jarPath<-"/path/to/wrapper-jar" sc <- sparkR.init(sparkJars = paste0(jarPath))SparkR から、以下のように呼び出すことができます。
rf.model <- SparkR:::callJStatic( "com.foo.bar.Wrappers", "trainRandomForest", train_data@sdf, label, features, numTree, treeDepth ) rf.model <- new("PipelineModel", model = rf.model)
R and spark 接続タイムアウトの調整
以下のスライドによると、SparkRはsocket connectionを利用して、RとSparkのJVMのコミュニケーションを行うようです。http://www.slideshare.net/SparkSummit/07-venkataraman-sun
しかし、以下のsocket connectionのmax aliveは最大6000sに設定されており、解析処理には短すぎます(現在の解析処理は10時間以上かかります)。
source1
source2
以下の方法で、SparkR のconnection timeoutを変更することができます。
#最大24時間設定する sparkr_conn_timeout <- 86400 #because sparkR connection default timeout is 6000s, we need to increase it for long time process port <- get("backendPort",envir=SparkR:::.sparkREnv) conn <- get(".sparkRCon",envir=SparkR:::.sparkREnv) #close the connection first conn <- close(conn) # recreate connection with longer timeout conn <- socketConnection(host = "localhost", port = port, server = FALSE, blocking = TRUE, open = "wb", timeout = sparkr_conn_timeout) #re-assign to sparkEnv assign(".sparkRCon", conn, envir = SparkR:::.sparkREnv)
PipelineModel 保存する問題について
SparkのPipelineModelを利用することで、解析処理のworkflowを簡単に管理・保存・ロードすることができます。http://spark.apache.org/docs/1.6.1/ml-guide.html#main-concepts-in-pipelines
標準のSparkRで、modelまたはpipeline modelの保存する機能がありませんが、JavaまたはScala wrapperを利用すれば、SparkRからでもSpark modelを保存することができます。
しかし、pipelineの中に一つでもsaveできないstageがあった場合、pipeline全体が保存できなくなります。SparkはJavaObjectを保存できる機能を持つので、これとpipelineのstage save機能と組合せることで、全てのステージが保存できない場合でも全体のpipelineは保存・ロードできるようになります。
保存する処理の例:
int counter = 0; //こちらのmodelはPipelineModelです。 for (Transformer tr : model.stages()) { //保存先のファイル名に、stageのindexとstageのクラス名を記載する。ロードするときに、stage順番とクラスの生成に使う String outPath = path + "/" + STAGE_FILE_PREFIX + STAGE_FILE_SEPARATOR + counter + STAGE_FILE_SEPARATOR + tr.getClass().getSimpleName(); //save できないpipeline stageをobjectファイルとして保存する if (tr.getClass().getSimpleName().equalsIgnoreCase("OneVsRestModel")) { OneVsRestModel newModel= (OneVsRestModel) tr; List<OneVsRestModel> rfcList=new ArrayList(); rfcList.add(newModel); jsc.parallelize(rfcList,1).saveAsObjectFile(outPath); } //save できるpipeline stageをsave methodを呼び出す else { ((MLWritable) tr).write().overwrite().save(outPath); } counter++; }ロードする処理の例:
List<Transformer> stages = new ArrayList<>(); // object factory PipelineStageFactory psf = new PipelineStageFactory(); //保存先のファイルリストをloopする for (String fileFullPath: paths) { //ファイル名をparseする String fname = fileFullPath.replace(basePath,""); String[] fnameSplit = fname.split(STAGE_FILE_SEPARATOR); if(!fnameSplit[0].equalsIgnoreCase(STAGE_FILE_PREFIX)) continue; //stage indexとクラス名を探す int stageNum = Integer.parseInt(fnameSplit[1]); String featureType = fnameSplit[2]; //saveしたtransformer内容から、オブジェクトを作成する Transformer transformer = psf.create(featureType, fileFullPath, sc); if (transformer != null) { stages.add(stageNum,transformer); } } PipelineModel pm = new PipelineModel(uid, stages.toArray(new Transformer[0]));
DataframeのVector datatypeについての問題
MLlibの分類アルゴリズム(RandomForest, LogisticRegressionなど)の結果はdataframeに収められます。dataframeで確率が格納されているカラムは、基本的にVectorUDTのタイプになっておりSparkRでは直接アクセスできません。SparkRでアクセスするための一つの方法は、VectorUDTのタイプをArrayに変換することです。変換するためのtransformerサンプルコードは以下で公開しています。https://gist.github.com/zufri/3a5d23afe8dd1c3952e17c8325b6f425
こちらのtransformerをpipeline stageに登録すれば、確率配列のカラムが作成されます。そして、SparkRでは以下のように確率カラムにアクセスすることができます。
#dfはmodelのpredictionした結果のdataframe prob_df <- select(df,c(expr("prob_col_name[0] as prob_1"),expr("prob_col_name[1] as prob_2")))
まとめ
現時点で、SparkRは他のSparkの言語インターフェース(scala, python)に比べて、まだまだ未熟だと思います。Scalaやpythonで標準機能でできることでも、SparkRでは拡張が必要なことが多いです。今後のSpark 2.0、SparkRの機能やサポートがより充実することを期待しています。https://databricks.com/blog/2016/05/11/apache-spark-2-0-technical-preview-easier-faster-and-smarter.html
最後に
次世代システム研究室では、アプリケーション開発や設計を行うアーキテクトを募集しています。アプリケーション開発者の方、次世代システム研究室にご興味を持って頂ける方がいらっしゃいましたら、ぜひ 募集職種一覧 からご応募をお願いします。皆さんのご応募をお待ちしています。