@@ -8,154 +8,109 @@ import android.os.Build
88import android.os.Bundle
99import android.util.Log
1010import android.widget.Toast
11+ import androidx.activity.enableEdgeToEdge
1112import androidx.activity.result.PickVisualMediaRequest
1213import androidx.activity.result.contract.ActivityResultContracts
1314import androidx.activity.result.contract.ActivityResultContracts.PickVisualMedia.Companion.isPhotoPickerAvailable
1415import androidx.appcompat.app.AlertDialog
1516import androidx.appcompat.app.AppCompatActivity
1617import androidx.core.app.ActivityCompat
1718import androidx.core.content.ContextCompat
19+ import androidx.lifecycle.lifecycleScope
1820import com.engineer.ai.databinding.ActivityTansStyleBinding
1921import com.engineer.ai.util.AndroidAssetsFileUtil
20- import kotlinx.coroutines.Dispatchers
22+ import com.engineer.ai.util.AsyncExecutor
23+ import com.engineer.ai.util.StyleTransferProcessor
24+ import com.engineer.ai.util.gone
25+ import com.engineer.ai.util.show
2126import kotlinx.coroutines.GlobalScope
27+ import kotlinx.coroutines.cancel
2228import kotlinx.coroutines.launch
23- import org.pytorch.IValue
2429import org.pytorch.LiteModuleLoader
2530import org.pytorch.Module
26- import org.pytorch.Tensor
27- import org.pytorch.torchvision.TensorImageUtils
2831
2932
3033class FastStyleTransActivity : AppCompatActivity () {
31- private val TAG = " FastStyleTransActivity "
34+ private val TAG = " FastStyleTransActivity_TAG "
3235 private lateinit var module: Module
3336 private val modelName = " mosaic.pt"
3437 private var currentBitmap: Bitmap ? = null
3538
3639 private lateinit var viewBinding: ActivityTansStyleBinding
3740 override fun onCreate (savedInstanceState : Bundle ? ) {
3841 super .onCreate(savedInstanceState)
42+ enableEdgeToEdge()
3943 viewBinding = ActivityTansStyleBinding .inflate(layoutInflater)
4044 setContentView(viewBinding.root)
4145 initModel()
4246 viewBinding.pickImg.setOnClickListener {
4347 pickImage()
4448 }
4549 viewBinding.gen.setOnClickListener {
46- GlobalScope .launch(Dispatchers .IO ) {
50+ refreshLoading(true )
51+ // lifecycleScope.launch {
52+ // genImage()
53+ // }
54+ GlobalScope .launch {
4755 genImage()
4856 }
4957 }
5058 }
5159
60+ override fun onDestroy () {
61+ super .onDestroy()
62+ lifecycleScope.cancel()
63+ }
64+
5265 private fun showBitmap (bitmap : Bitmap ) {
5366 viewBinding.pickResult.setImageBitmap(bitmap)
67+
68+ Log .i(TAG , " ori = ${bitmap.width} ,${bitmap.height} " )
5469 currentBitmap = bitmap
5570 }
5671
57- fun multiplyTensorBy255 (inputTensor : Tensor ): Tensor {
58- // 获取 Tensor 的浮点数组
59- val inputArray = inputTensor.dataAsFloatArray
60-
61- // 创建新数组并乘以255
62- val outputArray = FloatArray (inputArray.size) { i ->
63- inputArray[i] * 255.0f
64- }
65-
66- // 创建新的 Tensor(保持原始形状)
67- return Tensor .fromBlob(outputArray, inputTensor.shape())
72+ private fun refreshLoading (show : Boolean ) {
73+ if (show) viewBinding.loading.show() else viewBinding.loading.gone()
6874 }
6975
70- fun divTensorBy255 (inputTensor : Tensor ): Tensor {
71- // 获取 Tensor 的浮点数组
72- val inputArray = inputTensor.dataAsFloatArray
73-
74- // 创建新数组并乘以255
75- val outputArray = FloatArray (inputArray.size) { i ->
76- inputArray[i] * 255.0f
77- }
78-
79- // 创建新的 Tensor(保持原始形状)
80- return Tensor .fromBlob(outputArray, inputTensor.shape())
81- }
8276
8377 private fun genImage () {
8478
85- val inDims: IntArray = intArrayOf(224 , 224 , 3 )
86- val outDims: IntArray = intArrayOf(224 , 224 , 3 )
87- val bmp: Bitmap ? = null
88- var scaledBmp: Bitmap ? = null
89- val filePath = " "
9079 currentBitmap?.let {
91- scaledBmp = Bitmap .createScaledBitmap(it, inDims[0 ], inDims[1 ], true );
92-
93-
94- // Android更简洁的实现
95- // 转换为张量并归一化到[0,1]
96- val inputTensor: Tensor = TensorImageUtils .bitmapToFloat32Tensor(
97- currentBitmap, floatArrayOf(0f , 0f , 0f ), // 不减去均值
98- floatArrayOf(1f , 1f , 1f ) // 不除以标准差
99- )
100-
101- val tensor = multiplyTensorBy255(inputTensor);
102-
103- Log .i(TAG , " 1" )
104-
105- val resultTensor = module.forward(IValue .from(tensor)).toTensor()
106- val out = divTensorBy255(resultTensor)
107-
108- Log .i(TAG , " 2" )
80+ // AsyncExecutor.fromIO().execute {
81+ // StyleTransferProcessor.initModule(module)
82+ // StyleTransferProcessor.transferStyle(it, 1.0f)
83+ // }.awaitResult<Bitmap>(onSuccess = {
84+ // Log.i(TAG, "onSuccess")
85+ // refreshLoading(false)
86+ // Log.i(TAG, "output ${it.width},${it.height}")
87+ // viewBinding.transResult.setImageBitmap(it)
88+ // }, onError = {
89+ // refreshLoading(false)
90+ // Log.i(TAG, it.stackTraceToString())
91+ // })
92+
93+ AsyncExecutor .fromIO().execute {
94+ StyleTransferProcessor .initModule(module)
95+ // val it = StyleTransferProcessor.transferStyle(it, 1.0f)
96+ //
97+ // withContext(Dispatchers.Main) {
98+ // refreshLoading(false)
99+ // Log.i(TAG, "output ${it.width},${it.height}")
100+ // viewBinding.transResult.setImageBitmap(it)
101+ // }
109102
110- val outputArray = out .dataAsFloatArray
111- val width = outDims[0 ]
112- val height = outDims[1 ]
113- val outputBitmap = Bitmap .createBitmap(width, height, Bitmap .Config .ARGB_8888 )
103+ StyleTransferProcessor .transferStyleAsync(it, 0.5f ) {
104+ runOnUiThread {
105+ refreshLoading(false )
114106
115- // 将浮点数组转换为Bitmap (简化实现,实际可能需要更复杂的转换)
116- for (y in 0 until height) {
117- for (x in 0 until width) {
118- val r = (outputArray[y * width * 3 + x * 3 + 0 ] * 255 ).toInt().coerceIn(0 , 255 )
119- val g = (outputArray[y * width * 3 + x * 3 + 1 ] * 255 ).toInt().coerceIn(0 , 255 )
120- val b = (outputArray[y * width * 3 + x * 3 + 2 ] * 255 ).toInt().coerceIn(0 , 255 )
121- outputBitmap.setPixel(x, y, android.graphics.Color .rgb(r, g, b))
107+ viewBinding.transResult.setImageBitmap(it)
108+ }
122109 }
123110 }
124- Log .i(TAG , " 3" )
125- GlobalScope .launch(Dispatchers .Main ) {
126- viewBinding.transResult.setImageBitmap(outputBitmap)
127- }
128111 }
129112
130- // val zDim = intArrayOf(1, 100)
131- // val outDims = intArrayOf(64, 64, 3)
132- // Log.d(TAG, zDim.contentToString())
133- // val z = FloatArray(zDim[0] * zDim[1])
134- // Log.d(TAG, "z = ${z.contentToString()}")
135- // val rand = Random()
136- // // 生成高斯随机数
137- // for (c in 0 until zDim[0] * zDim[1]) {
138- // z[c] = rand.nextGaussian().toFloat()
139- // }
140- // Log.d(TAG, "z = ${z.contentToString()}")
141- // val shape = longArrayOf(1, 100)
142- // val tensor = Tensor.fromBlob(z, shape)
143- // Log.d(TAG, tensor.dataAsFloatArray.contentToString())
144- // val resultT = module.forward(IValue.from(tensor)).toTensor()
145- // val resultArray = resultT.dataAsFloatArray
146- // val resultImg = Array(outDims[0]) { Array(outDims[1]) { FloatArray(outDims[2]) { 0.0f } } }
147- // var index = 0
148- // // 根据输出的一维数组,解析生成的卡通图像
149- // for (j in 0 until outDims[2]) {
150- // for (k in 0 until outDims[0]) {
151- // for (m in 0 until outDims[1]) {
152- // resultImg[k][m][j] = resultArray[index] * 127.5f + 127.5f
153- // index++
154- // }
155- // }
156- // }
157- // val bitmap = Utils.getBitmap(resultImg, outDims)
158- // viewBinding.transResult.setImageBitmap(bitmap)
113+
159114 }
160115
161116 private fun initModel () {
0 commit comments