@@ -16,29 +16,45 @@ import java.nio.channels.FileChannel
1616
1717object TensorFlowLiteHelper {
1818 private const val TAG = " TensorFlowLiteHelper"
19- private lateinit var initializeTask: Task <Void >
20- private var interpreter: InterpreterApi ? = null
2119
20+ /* *
21+ * Try to init Google Play Services TFLite (dynamite module).
22+ *
23+ * This will fail on devices without Google Play Services (e.g. many CN ROMs).
24+ * We treat failure as non-fatal and fall back to bundled TFLite runtime.
25+ */
2226 fun init (context : Context , cb : (Boolean ) -> Unit ) {
23- initializeTask = TfLite .initialize(context)
24- initializeTask.addOnSuccessListener {
25- Log .d(TAG , " Initialized TFLite interpreter." )
26- Log .d(TAG , " ver ${TensorFlowLite .schemaVersion(InterpreterApi .Options .TfLiteRuntime .FROM_SYSTEM_ONLY )} " )
27- Log .d(TAG , " ver ${TensorFlowLite .runtimeVersion(InterpreterApi .Options .TfLiteRuntime .FROM_SYSTEM_ONLY )} " )
28- cb(true )
29- }.addOnFailureListener {
30- Log .d(TAG , " Initialized TFLite fail" )
31- cb(false )
32- Log .e(TAG , " error " , it)
33- }
27+ TfLite .initialize(context)
28+ .addOnSuccessListener {
29+ Log .d(TAG , " Initialized Play Services TFLite." )
30+ try {
31+ Log .d(
32+ TAG ,
33+ " schema=${TensorFlowLite .schemaVersion(InterpreterApi .Options .TfLiteRuntime .FROM_SYSTEM_ONLY )} runtime=${TensorFlowLite .runtimeVersion(InterpreterApi .Options .TfLiteRuntime .FROM_SYSTEM_ONLY )} "
34+ )
35+ } catch (t: Throwable ) {
36+ Log .w(TAG , " Unable to query system-only TFLite version." , t)
37+ }
38+ cb(true )
39+ }
40+ .addOnFailureListener { e ->
41+ Log .w(TAG , " Play Services TFLite init failed; will fall back to bundled runtime." , e)
42+ cb(false )
43+ }
3444 }
3545
36- fun createInterpreterApi (context : Context , modelName : String ): InterpreterApi ? {
46+ fun createInterpreterApi (context : Context , modelName : String , preferPlayServices : Boolean ): InterpreterApi {
3747 val model = loadModelFile(context.assets, modelName)
38- val interpreterOption =
39- InterpreterApi .Options ().setRuntime(InterpreterApi .Options .TfLiteRuntime .FROM_SYSTEM_ONLY )
40- interpreter = InterpreterApi .create(model, interpreterOption)
41- return interpreter
48+
49+ val runtime = if (preferPlayServices) {
50+ InterpreterApi .Options .TfLiteRuntime .FROM_SYSTEM_ONLY
51+ } else {
52+ // Bundled runtime provided by org.tensorflow:tensorflow-lite
53+ InterpreterApi .Options .TfLiteRuntime .FROM_APPLICATION_ONLY
54+ }
55+
56+ val options = InterpreterApi .Options ().setRuntime(runtime)
57+ return InterpreterApi .create(model, options)
4258 }
4359
4460 @Throws(IOException ::class )
0 commit comments