@@ -14,22 +14,14 @@ limitations under the License.
1414package com.engineer.ai.util
1515
1616import android.content.Context
17- import android.content.res.AssetManager
1817import android.graphics.Bitmap
1918import android.util.Log
19+ import androidx.core.graphics.scale
2020import com.google.android.gms.tasks.Task
2121import com.google.android.gms.tasks.TaskCompletionSource
22- import com.google.android.gms.tflite.java.TfLite
2322import org.tensorflow.lite.InterpreterApi
24- import org.tensorflow.lite.TensorFlowLite
25- import java.io.FileInputStream
26- import java.io.IOException
27- import java.nio.ByteBuffer
28- import java.nio.ByteOrder
29- import java.nio.channels.FileChannel
3023import java.util.concurrent.ExecutorService
3124import java.util.concurrent.Executors
32- import androidx.core.graphics.scale
3325
3426class DigitClassifier (private val context : Context ) {
3527
@@ -45,60 +37,33 @@ class DigitClassifier(private val context: Context) {
4537 private var inputImageHeight: Int = 0 // will be inferred from TF Lite model.
4638 private var modelInputSize: Int = 0 // will be inferred from TF Lite model.
4739
48- private val initializeTask: Task <Void > by lazy { TfLite .initialize(context) }
4940 private var interpreter: InterpreterApi ? = null
5041
5142 fun initialize (cb : (Boolean ) -> Unit ) {
52- val assetManager = context.assets
53- val model = loadModelFile(assetManager, " mnist.tflite" )
54-
55- initializeTask.addOnSuccessListener {
56- val interpreterOption =
57- InterpreterApi .Options ().setRuntime(InterpreterApi .Options .TfLiteRuntime .FROM_SYSTEM_ONLY )
58- interpreter = InterpreterApi .create(model, interpreterOption)
59-
60- Log .d(TAG , " ver ${TensorFlowLite .schemaVersion(InterpreterApi .Options .TfLiteRuntime .FROM_SYSTEM_ONLY )} " )
61- Log .d(TAG , " ver ${TensorFlowLite .runtimeVersion(InterpreterApi .Options .TfLiteRuntime .FROM_SYSTEM_ONLY )} " )
62- // Read input shape from model file
63- interpreter?.let {
64- val inputShape = it.getInputTensor(0 ).shape()
65- inputImageWidth = inputShape[1 ]
66- inputImageHeight = inputShape[2 ]
67- modelInputSize = FLOAT_TYPE_SIZE * inputImageWidth * inputImageHeight * PIXEL_SIZE
68-
69-
70- isInitialized = true
71- Log .d(TAG , " Initialized TFLite interpreter." )
72- cb(true )
73- } ? : run {
74- Log .d(TAG , " Initialized TFLite fail." )
43+ TensorFlowLiteHelper .init (context) {
44+ cb(it)
45+ if (it) {
46+ interpreter = TensorFlowLiteHelper .createInterpreterApi(context, " mnist.tflite" )
47+ // Read input shape from model file
48+ interpreter?.let { inter ->
49+ val inputShape = inter.getInputTensor(0 ).shape()
50+ inputImageWidth = inputShape[1 ]
51+ inputImageHeight = inputShape[2 ]
52+ modelInputSize = FLOAT_TYPE_SIZE * inputImageWidth * inputImageHeight * PIXEL_SIZE
53+ isInitialized = true
54+ }
7555 }
76-
77- }.addOnFailureListener { e ->
78- cb(false )
79- Log .e(TAG , " Cannot initialize interpreter" , e)
8056 }
81-
82-
8357 }
8458
85- @Throws(IOException ::class )
86- private fun loadModelFile (assetManager : AssetManager , filename : String ): ByteBuffer {
87- val fileDescriptor = assetManager.openFd(filename)
88- val inputStream = FileInputStream (fileDescriptor.fileDescriptor)
89- val fileChannel = inputStream.channel
90- val startOffset = fileDescriptor.startOffset
91- val declaredLength = fileDescriptor.declaredLength
92- return fileChannel.map(FileChannel .MapMode .READ_ONLY , startOffset, declaredLength)
93- }
9459
9560 private fun classify (bitmap : Bitmap ): String {
9661 check(isInitialized) { " TF Lite Interpreter is not initialized yet." }
97-
98-
9962 // Preprocessing: resize the input image to match the model input shape.
10063 val resizedImage = bitmap.scale(inputImageWidth, inputImageHeight)
101- val byteBuffer = convertBitmapToByteBuffer(resizedImage)
64+ val config = TensorFlowLiteHelper .Config (modelInputSize, inputImageWidth, inputImageHeight)
65+ val byteBuffer = TensorFlowLiteHelper .convertBitmapToByteBuffer(config, resizedImage)
66+
10267 // Define an array to store the model output.
10368 val output = Array (1 ) { FloatArray (OUTPUT_CLASSES_COUNT ) }
10469
@@ -131,25 +96,6 @@ class DigitClassifier(private val context: Context) {
13196 }
13297 }
13398
134- private fun convertBitmapToByteBuffer (bitmap : Bitmap ): ByteBuffer {
135- val byteBuffer = ByteBuffer .allocateDirect(modelInputSize)
136- byteBuffer.order(ByteOrder .nativeOrder())
137-
138- val pixels = IntArray (inputImageWidth * inputImageHeight)
139- bitmap.getPixels(pixels, 0 , bitmap.width, 0 , 0 , bitmap.width, bitmap.height)
140-
141- for (pixelValue in pixels) {
142- val r = (pixelValue shr 16 and 0xFF )
143- val g = (pixelValue shr 8 and 0xFF )
144- val b = (pixelValue and 0xFF )
145-
146- // Convert RGB to grayscale and normalize pixel value to [0..1].
147- val normalizedPixelValue = (r + g + b) / 3.0f / 255.0f
148- byteBuffer.putFloat(normalizedPixelValue)
149- }
150-
151- return byteBuffer
152- }
15399
154100 companion object {
155101 private const val TAG = " DigitClassifier"
0 commit comments