1+ package com.engineer.ai
2+
3+ import android.os.Bundle
4+ import android.util.Log
5+ import androidx.appcompat.app.AppCompatActivity
6+ import com.engineer.ai.databinding.ActivityGanBinding
7+ import com.engineer.ai.databinding.ActivityTansStyleBinding
8+ import com.engineer.ai.util.AndroidAssetsFileUtil
9+ import com.engineer.ai.util.Utils
10+ import org.pytorch.IValue
11+ import org.pytorch.LiteModuleLoader
12+ import org.pytorch.Module
13+ import org.pytorch.Tensor
14+ import java.util.Random
15+
16+
17+ class FastStyleTransActivity : AppCompatActivity () {
18+ private val TAG = " FastStyleTransActivity"
19+ private lateinit var module: Module
20+ private val modelName = " dcgan.pt"
21+
22+ private lateinit var viewBinding: ActivityTansStyleBinding
23+ override fun onCreate (savedInstanceState : Bundle ? ) {
24+ super .onCreate(savedInstanceState)
25+ viewBinding = ActivityTansStyleBinding .inflate(layoutInflater)
26+ setContentView(viewBinding.root)
27+ initModel()
28+ viewBinding.gen.setOnClickListener {
29+ genImage()
30+ }
31+ }
32+
33+ private fun genImage () {
34+ val zDim = intArrayOf(1 , 100 )
35+ val outDims = intArrayOf(64 , 64 , 3 )
36+ Log .d(TAG , zDim.contentToString())
37+ val z = FloatArray (zDim[0 ] * zDim[1 ])
38+ Log .d(TAG , " z = ${z.contentToString()} " )
39+ val rand = Random ()
40+ // 生成高斯随机数
41+ for (c in 0 until zDim[0 ] * zDim[1 ]) {
42+ z[c] = rand.nextGaussian().toFloat()
43+ }
44+ Log .d(TAG , " z = ${z.contentToString()} " )
45+ val shape = longArrayOf(1 , 100 )
46+ val tensor = Tensor .fromBlob(z, shape)
47+ Log .d(TAG , tensor.dataAsFloatArray.contentToString())
48+ val resultT = module.forward(IValue .from(tensor)).toTensor()
49+ val resultArray = resultT.dataAsFloatArray
50+ val resultImg = Array (outDims[0 ]) { Array (outDims[1 ]) { FloatArray (outDims[2 ]) { 0.0f } } }
51+ var index = 0
52+ // 根据输出的一维数组,解析生成的卡通图像
53+ for (j in 0 until outDims[2 ]) {
54+ for (k in 0 until outDims[0 ]) {
55+ for (m in 0 until outDims[1 ]) {
56+ resultImg[k][m][j] = resultArray[index] * 127.5f + 127.5f
57+ index++
58+ }
59+ }
60+ }
61+ val bitmap = Utils .getBitmap(resultImg, outDims)
62+ viewBinding.transResult.setImageBitmap(bitmap)
63+ }
64+
65+ private fun initModel () {
66+ module = LiteModuleLoader .load(AndroidAssetsFileUtil .assetFilePath(this , modelName))
67+ }
68+
69+
70+ }
0 commit comments