2022.10.07

Android上で実行する転移学習実装方法(by Kotlin)

次世代システム研究室のT.Sです。Google Tensorのような機械学習最適化チップが登場するなどモバイル端末上でも機械学習が実行しやすくなるなかで、MLを用いたモバイルアプリというのもかなり普及が進んだように感じます。アプリ実装面でもFirebaseMLのように簡単に機械学習を組み込める仕組みがでてきており、エンジニアとしても利用面ではかなり身近になってきたなあと感じます。
しかしその一方でモバイル端末上で[学習]を行う点については、まだちょっと身近にはなってきていないかとも感じています。そこで今回はTensorflowLiteを利用して、モバイル端末上で転移学習を実施するコードをご紹介したいと思います。


前提: 参考にしたコード


今回ご紹介するコードはTensorflow LiteのGithubで公開されている転移学習のコードをベースにご紹介します。ただこのコード自体はJavaで記載されており読みづらい部分があるのと、Androidアプリ自体が近年Kotlin主体で、非同期処理もCoroutinesで制御している中ではちょっと使いにくいと思う点があったためKotlinで全面的に書き直しました。そのためコード的におかしな部分もあるかと思いますがご容赦ください

実装


サーバサイド: モデル学習


今回はAndorid上で転移学習を実行するということなので、その基礎となるモデルが必要になります。こちらについてはAndroidとは全く関係なく、従来どおりPythonなどでTensorflowを利用して実装・モデル保存していただく形なります。ここはAndroidとは基本切り離された世界ですし、慣れた技術がそのまま使えるため非常に助かりますね。

ただ2点ほど特殊なものが必要となってきます。まず1点目としてモデルを保存する際にはTensorflow Lite形式に変換する必要があります。Android上ではこれまでのモデルファイルそのままではなく、モバイル用に最適化(量子化)されたモデルを利用するため、この作業が必要となります。(量子化については以前書いたこちらのブログも是非ご参照ください)

ただ変換自体はさほど難しい作業ではありません。tf.lite.TFLiteConverterに各種Convert関数が用意されているため、これを利用するのみとなります。例えばKerasモデルを変換する際は以下のようになります。

converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

2点目の特殊なものは、@tf.functionデコレータになります。

この先Android上で転移学習や、モデルに基づいた推論を実施する必要があります。そのときに「TensorflowをJavaで書かなくちゃいけないのか… 慣れてないから嫌だなあと…」というのが頭によぎった方も多いのではないかと思います(私もその一人でした)。

しかし今回はその心配はいりません!そういったコードをすべてPython側で書いてしまいましょう!ただしその関数の頭にすべて@tf.functionデコレータをつけるようにしてください。これですべての悩みはなくなることになります!

どうやってそれをAndroid上で実行するの?という疑問の答えについてはこのあとご紹介しますが、これがTensorflowLiteが非常に使いやすいなと思える仕組みの一つでもあります

# 例1: Train
@tf.function(input_signature=[
    tf.TensorSpec([None, NUM_FEATURES], tf.float32),
    tf.TensorSpec([None, NUM_CLASSES], tf.float32),
])
def train(self, bottleneck, label):
....

従来Pythonで記述しているTensorflowの転移学習コードと異なる主な点は以上2点となります。これまで培った技術と大きな差はなく、それがそのまま使えるといったポジティブな印象を受けるのではないでしょうか?

Android: モデル読み込み


さてこれで利用するモデルファイルができたので、まずはこれをAndroid上で読み込んでみましょう
まず作成したファイルはアプリケーションプロジェクト配下のassetsフォルダ内に格納してください。


あとはAssetManagerを利用して該当ファイルをこんな感じで読み込むことになります

/**
 * 指定したfile pathからmodel(tflite)を読み込む
 */
fun loadMappedFile(filePath: String): MappedByteBuffer {
    val fileDescriptor = assetManager.openFd("$directoryName/$filePath")

    val modelByteBuffer = fileDescriptor.run {
        val fileChannel = FileInputStream(this.fileDescriptor).channel
        val startOffset = startOffset
        val declaredLength = declaredLength

        fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
    }

    return modelByteBuffer
}

assetsを取り扱うAssetManagerはContextから取得できるのでJetpackなどでは

val contex. = LocaContext.current
val assetManager: AssetManager = context.assets

こんな感じでとっていただけると。
まあModelLoader自体はアプリから切り出されてAARなどで外部に持ったほうがつかやすいし、実際はassetManagerを引数にとる関数を作ってあげるのがよいかともおもいます(context自体を持ち回るのは色々危険そうなのでassetManagerを渡す方がいいですかね?)

Android: TF Liteラッパークラス(Python定義関数呼び出し)


ボトルネック特徴量


さて実際の学習するためのコードを記載する前にTensorflowLiteのモデルファイルを取り扱うためのラッパー(ヘルパー)クラスを作成しましょう。

何をするためのラッパーなのでしょう?これがPythonで仕込んだ@tf.functionデコレータを取り扱うものになります。転移学習の訓練をするコードはPython側に記載して@tf.functionデコレータをつけよう!と前述したかと思いますが、まさにこれを取り扱うものになります。

# 例1: Train(再掲)
@tf.function(input_signature=[
    tf.TensorSpec([None, NUM_FEATURES], tf.float32),
    tf.TensorSpec([None, NUM_CLASSES], tf.float32),
])
def train(self, bottleneck, label):
....

今回はボトルネック特徴量/学習/推論の3つが@tf.functionデコレータをつけて定義されているので、これらを呼び出すラッパークラスを作成することになります。

まずボトルネック特徴量からです。転移学習ではまずFC層を除いて変換した特徴量を抽出し、これをInputとして学習/識別させることになりますが、この特徴量をボトルネック特徴量と呼んだりします。コードとしては以下のような感じになります

private val interpreter: Interpreter = Interpreter(tfLiteModel)

    fun loadBottleneck(
        image: Array<Array<FloatArray>>,
    ): FloatArray {
        // 入力は画像そのもの
        val inputs =
            mutableMapOf<String, Any>()
                .apply {
                    this["feature"] = arrayOf(image)
                }

        // Outputはボトルネック特徴量だがTFLiteに合わせ<String, Any>のMapで渡す
        val bottleneck = Array(1) { FloatArray(BOTTLENECK_SIZE) }
        val outputs =
            mutableMapOf<String, Any>()
                .apply {
                    this["bottleneck"] = bottleneck
                }

        // runSignatureはTFLite特有の概念
        // python側に事前に処理を記載し、これをrunSignature経由でCallできる
        // 何をしているかはpython側参照
        interpreter.runSignature(inputs, outputs, "load")

        return bottleneck[0]
    }

ここで重要なのが以下ですね.このrunSignatureになります。Python側でload関数で作成されているので、この名前を第3引数に与え実行すると、該当関数が実行されるということになります。便利!

interpreter.runSignature(inputs, outputs, "load")

ちなみにInput値は@tf.functionに以下のように記載されて言います。この場合はFloatの配列になっていますね

tf.TensorSpec([None, IMG_SIZE, IMG_SIZE, 3], tf.float32),

そのためKotlin側では以下のように変換して、Float配列に変換与えるようになっています

val inputs =
    mutableMapOf<String, Any>()
        .apply {
            this["feature"] = arrayOf(image)
        }

学習/推論


学習/推論も同様ですinputを要求される形に変換し、Outputを格納する箱を用意して runSignature を実行するという流れになります。参考までにコードをいかに記載いたします。

    /**
     * ボトルネック特徴量と正解ラベルを用いて学習させる(バッチ単位)
     *
     * @param bottlenecks バッチ単位(BATCH_SIZE, BOTTLENECK_SIZE)
     * @param labels バッチ単位(BATCH_SIZE, NUM_CLASSES)
     */
    fun runTraining(
        bottlenecks: Array<FloatArray>,
        labels: Array<FloatArray>,
    ): Float {
        // 入力は学習に必要なボトルネック特徴量とラベル
        val inputs =
            mutableMapOf<String, Any>()
                .apply {
                    this["bottleneck"] = bottlenecks
                    this["label"] = labels
                }

        // OutputはLoss値
        val loss = FloatBuffer.allocate(1)
        val outputs =
            mutableMapOf<String, Any>()
                .apply {
                    this["loss"] = loss
                }

        Log.d(" runSignature Start", "DEBUG")
        Log.d(" input $labels ", "DEBUG")
        interpreter.runSignature(inputs, outputs, "train")
        Log.d(" runSignature End", "DEBUG")

        return loss[0]
    }

    /**
     * 画像を与え、ラベル識別を実行する
     *
     * @param testImage カラー画像(IMG_SIZE, IMG_SIZE, 3)
     */
    fun runInference(
        testImage: Array<Array<FloatArray>>,
    ): FloatArray {
        // inputは画像そのもの
        val inputs =
            mutableMapOf<String, Any>()
                .apply {
                    this["feature"] = arrayOf(testImage)
                }

        // outputは正解ラベル
        val output = Array(1) { FloatArray(numClasses) }
        val outputs =
            mutableMapOf<String, Any>()
                .apply {
                    this["output"] = output
                }

        interpreter.runSignature(inputs, outputs, "infer")

        return output[0]
    }

完全に余談ですがKotlinにはスコープ関数(let/run/with/apply/also)というのが用意されているのですが、これらを使うとコードが美しくなるので非常に気に入っています。今回もinputで使っていますが、宣言と値格納が一つの命令行で収まるのが割と地味に綺麗さをキープしてくれます。

Android: 転移学習


さてこれで準備は整いましたので、実際に学習していくコードを書いていきましょう。ただコード全体はちょっと長くなるため、末尾に添付した上で、重要な部分のみ解説することにいたします

サンプルデータ追加


サンプルデータ自体はTrainingSampleというSampleを格納するためのdata classに格納し、これをtrainingSamples: List に格納するだけのものになっています。

ただここで重要なのはsuspend関数とwithContextによるDispather(スレッド)指定になります。これを使いたくてKotlinに書き直したと言っても過言ではありません。従来のJavaでThread処理を扱うのであれば、Thread{..}.start()のような構文で書くことが有るかと思います。

通常であればそれでもまあ事足りるのではありますが、Androidでは一画面一アプリが専用する中でこれをうまく取り扱うために様々なライフサイクル、スレッド、Activityなどが存在し、これらのなかでどうやってアプリの動きを制御するかが肝になります。Google自体をこれを制御するためにKotlin-Coroutine等の仕組みを使い、色々提供しています。その内容については今回のブログの範囲を大きく逸脱するので割愛しますが、今回の機械学習のような重い処理を扱う際にはやはりこの仕組を使いたくKotlinで実装し直したというわけになります。

    /**
     * サンプルデータ追加
     */
    suspend fun addSample(
        image: Array<Array<FloatArray>>,
        className: String,
    ) = withContext(Dispatchers.Default) {
        classIndexes
            .takeIf { it.containsKey(className) }
            ?: throw IllegalArgumentException("Class $className is not one of the classes recognized by the model")

        oneHotEncodedClass[className]?.let {
            val bottleneck = model.loadBottleneck(image)
            trainingSamples = trainingSamples + TrainingSample(bottleneck, it)
        }
            ?: throw IllegalArgumentException("Class $className is not one of the oneHotEncodedClass")
    }

学習


こちらも基本思想は変わらず、Trainingにあたる部分をDispatchers.DefaultでUIとは異なるスレッドに流し実行しています。

ただここはちょっと実装を変えた方がいいのかなあという気がしています。学習自体は転移学習とは言え実行時間はそれなりにかかることが予想されます。その際に、このようにスレッドで動かした状態で、違うアプリに遷移すると、該当Activityが裏でKillされて実行が止まるのではという心配があり。。。WorkManagerに逃がすとか他の手段を考える必要があるのかなあと漠然と考えていはいますが、実際どうするかは来月以降に検証しようかと考えています。

そういう意味でこのコードは仮だとお考えいただければ嬉しいです。(TrainigSamplesも本当はIterationにしたほうがよいとか改善点も後々修正する必要はあるかとも考えています)

        withContext(Dispatchers.Default) {
            // エポック数分実行する
            (0..numEpochs).forEach { epoch ->
                var totalLoss = 0.0f
                var numBatchesProcessed = 0

                // バッチ単位で取得できるようWrapperクラスを利用する
                val trainingSamples =
                    TrainingSamples(
                        data = trainingSamples,
                        trainBatchSize = trainBatchSize,
                    )

                // バッチ単位の学習データがある分だけ実行する
                while (trainingSamples.hasNext()) {

                    val batch = trainingSamples.next()

                    // 学習データ一つずつを詰めていく
                    batch.forEachIndexed { i, sample ->
                        trainingBatchBottlenecks[i] = sample.bottleneck
                        trainingBatchLabels[i] = sample.label
                    }

                    val loss: Float =
                        model.runTraining(trainingBatchBottlenecks, trainingBatchLabels)

                    totalLoss += loss
                    numBatchesProcessed++
                }

                val avgLoss = totalLoss / numBatchesProcessed
                lossConsumer.onLoss(epoch, avgLoss)
                Log.d("$epoch all End" , "DEBUG")
            }


参考:コード全体


最後にコード全体を記載いたします。先程言及したようにまだ仮な部分はありますので参考程度に見ていただければと

class TransferLearningModel(
    modelLoader: ModelLoader,
    private var classes: Collection<String>,
    private var classNum: Int = 4,
) : Closeable {

    data class Prediction(
        val className: String,
        val confidence: Float,
    )

    data class TrainingSample (
        var bottleneck: FloatArray,
        var label: FloatArray,
    )

    class TrainingSamples constructor(
        private var nextIndex: Int = 0,
        private var trainBatchSize: Int,
        private var data: List<TrainingSample>,
    ) {
        init {
            data = data.shuffled()
        }

        fun hasNext(): Boolean {
            return nextIndex < data.size
        }

        fun next(): List<TrainingSample> {
            val fromIndex = nextIndex
            val toIndex =
                (nextIndex + trainBatchSize)
                    .also {
                        nextIndex = it
                    }

            if (toIndex >= data.size) {
                return data.subList(
                    data.size - trainBatchSize, data.size
                )
            }
            return data.subList(fromIndex, toIndex)
        }
    }

    interface LossConsumer {
        fun onLoss(epoch: Int, loss: Float)
    }

    private var model: LiteMultipleSignatureModel

    private val classIndexes = TreeMap<String, Int>() // {className to Index}
    private val oneHotEncodedClass =
        mutableMapOf<String, FloatArray>() // {className, Array(該当Classのみ1,その他0)}

    private var trainingSamples: List<TrainingSample> = mutableListOf()

    init {
        // model load
        model =
            runCatching {
                LiteMultipleSignatureModel(
                    modelLoader.loadMappedFile(MODEL_FILE_NAME),
                    classes.size,
                )
            }.getOrNull()
                ?: throw RuntimeException("Couldn't read underlying model for TransferLearningModel")

        // classesをone hotエンコーディングに変更
        classes.forEachIndexed { index, s ->
            classIndexes[s] = index
            oneHotEncodedClass[s] = oneHotEncoding(index)
        }
    }

    /**
     * サンプルデータ追加
     */
    suspend fun addSample(
        image: Array<Array<FloatArray>>,
        className: String,
    ) = withContext(Dispatchers.Default) {
        classIndexes
            .takeIf { it.containsKey(className) }
            ?: throw IllegalArgumentException("Class $className is not one of the classes recognized by the model")

        oneHotEncodedClass[className]?.let {
            val bottleneck = model.loadBottleneck(image)
            trainingSamples = trainingSamples + TrainingSample(bottleneck, it)
        }
            ?: throw IllegalArgumentException("Class $className is not one of the oneHotEncodedClass")
    }

    suspend fun train(
        numEpochs: Int = 1,
        lossConsumer: LossConsumer
    ) {
        // batch size は最低1とする
        val trainBatchSize =
            getTrainBatchSize()
                .also {
                    if (trainingSamples.size < it) {
                        throw java.lang.RuntimeException("Too few samples to start training: need $it, got $trainingSamples.size")
                    }
                }

        // 学習に必要なボトルネック特徴量とラベルを取得する
        val trainingBatchBottlenecks =
            Array(trainBatchSize) { FloatArray(model.getNumBottleneckFeatures()) }
        val trainingBatchLabels = Array(trainBatchSize) { FloatArray(classIndexes.size) }

        withContext(Dispatchers.Default) {
            // エポック数分実行する
            (0..numEpochs).forEach { epoch ->
                var totalLoss = 0.0f
                var numBatchesProcessed = 0

                // バッチ単位で取得できるようWrapperクラスを利用する
                val trainingSamples =
                    TrainingSamples(
                        data = trainingSamples,
                        trainBatchSize = trainBatchSize,
                    )

                // バッチ単位の学習データがある分だけ実行する
                while (trainingSamples.hasNext()) {

                    val batch = trainingSamples.next()

                    // 学習データ一つずつを詰めていく
                    batch.forEachIndexed { i, sample ->
                        trainingBatchBottlenecks[i] = sample.bottleneck
                        trainingBatchLabels[i] = sample.label
                    }

                    Log.d("$epoch training Start" , "DEBUG")
                    val loss: Float =
                        model.runTraining(trainingBatchBottlenecks, trainingBatchLabels)
                    Log.d("$epoch training End" , "DEBUG")

                    totalLoss += loss
                    numBatchesProcessed++
                }

                val avgLoss = totalLoss / numBatchesProcessed
                lossConsumer.onLoss(epoch, avgLoss)
                Log.d("$epoch all End" , "DEBUG")
            }
        }
    }

    fun predict(
        image: Array<Array<FloatArray>>
    ): Array<Prediction> {

        /* TODO */
    }

    fun getTrainBatchSize(): Int {
        return minOf(
            maxOf(trainingSamples.size, 1),
            model.getExpectedBatchSize()
        )
    }

    private fun oneHotEncoding(classIdx: Int): FloatArray =
        FloatArray(classNum)
            .apply { this[classIdx] = 1.0f }

    override fun close() {
        model.close()
    }

    companion object {
        private const val MODEL_FILE_NAME = "model.tflite"
    }

Android: 実行


さてここまでお膳だてできていれば、あとは TransferLearningModel#addSample/Train を呼び出すだけになります。これ自体はsuspend関数なので、通常通り viewModelScope.launchなどなどで実行してもらえればと思います。

最後に


今回のブログではモバイル端末上で転移学習をKotlinで実装したコードのご紹介をしました。次回以降はこれを踏まえた上で、実際にAndroid上で機械学習を実行して、どういった動作をするか(主にメモリや実行時間における検証)を行っていきたいと思います。


グループ研究開発本部 AI研究開発室では、データサイエンティスト/機械学習エンジニアを募集しています。ビッグデータの解析業務などAI研究開発室にご興味を持って頂ける方がいらっしゃいましたら、ぜひ 募集職種一覧 からご応募をお願いします。皆さんのご応募をお待ちしています。

  • Twitter
  • Facebook
  • はてなブックマークに追加

グループ研究開発本部の最新情報をTwitterで配信中です。ぜひフォローください。

 
  • AI研究開発室
  • 大阪研究開発グループ

関連記事