posedetection/build.gradle.kts (+3 -1)
···
 import com.android.build.api.dsl.AaptOptions
 import com.android.build.api.dsl.AndroidResources
 import com.vanniktech.maven.publish.SonatypeHost
+import org.gradle.kotlin.dsl.implementation
 import org.jetbrains.compose.ExperimentalComposeLibrary
 import org.jetbrains.kotlin.gradle.plugin.KotlinSourceSetTree

···
         implementation(libs.pose.detection)
         implementation(libs.pose.detection.common)
         implementation(libs.androidx.media3.common.ktx)
-        implementation(libs.tensorflow.lite.task.vision)
+        implementation("org.tensorflow:tensorflow-lite:2.17.0")
+        implementation("org.tensorflow:tensorflow-lite-support:0.5.0")
     }

 }
posedetection/src/androidMain/kotlin/com/performancecoachlab/posedetection/camera/Utils.android.kt (+10 -11)
···
 import com.google.android.gms.tasks.Tasks
 import com.google.mlkit.vision.common.InputImage
 import com.google.mlkit.vision.pose.PoseDetector
+import com.performancecoachlab.posedetection.custom.YoloTFLite
 import com.performancecoachlab.posedetection.recording.AnalysisObject
 import com.performancecoachlab.posedetection.recording.AnalysisResult
 import com.performancecoachlab.posedetection.recording.FrameSize
···

 @OptIn(ExperimentalGetImage::class)
 fun ImageProxy.process(
-    objectDetector: org.tensorflow.lite.task.vision.detector.ObjectDetector?,
+    objectDetector: YoloTFLite?,
     poseDetector: PoseDetector?,
     timestamp: Long,
     focusArea: Rect?,
···
 }

 fun Bitmap.process(
-    objectDetector: org.tensorflow.lite.task.vision.detector.ObjectDetector?,
+    objectDetector: YoloTFLite?,
     poseDetector: PoseDetector,
     timestamp: Long,
     focusArea: Rect?,
···
 private fun process(
     tensorImage: TensorImage,
     mlKitImage: InputImage?,
-    objectDetector: org.tensorflow.lite.task.vision.detector.ObjectDetector?,
+    objectDetector: YoloTFLite?,
     poseDetector: PoseDetector?,
     timestamp: Long,
     width: Int,
···
     bitmap: Bitmap,
     onComplete: (AnalysisResult, Bitmap) -> Unit
 ) {
-    val objectsDetected = objectDetector?.detect(tensorImage)?.map { result ->
+    val objectsDetected = objectDetector?.detect(bitmap)?.map { result ->
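+        // YoloTFLite takes the Bitmap directly and returns Detection(bbox, score, classId)
+        // instead of ML Kit-style categories.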
         AnalysisObject(
-            boundingBox = result.boundingBox.let {
+            boundingBox = result.bbox.let {
                 Rect(
                     left = it.left,
                     top = it.top,
···
                 )
             },
             trackingId = 0,
-            labels = result.categories.map { category ->
-                com.performancecoachlab.posedetection.recording.Label(
-                    category.label,
-                    category.score
-                )
-            },
+            labels = listOf(com.performancecoachlab.posedetection.recording.Label(
+                result.classId.toString(),
+                result.score
+            )),
             frameSize = FrameSize(
                 width = width,
                 height = height
posedetection/src/androidMain/kotlin/com/performancecoachlab/posedetection/custom/CustomObjectModel.android.kt (+489 -12)
···
 package com.performancecoachlab.posedetection.custom

+import android.content.Context
+import android.graphics.Bitmap
+import android.graphics.RectF
 import androidx.compose.runtime.Composable
 import androidx.compose.ui.platform.LocalContext
+import org.tensorflow.lite.Interpreter
+import org.tensorflow.lite.Tensor
+import org.tensorflow.lite.DataType
+import org.tensorflow.lite.support.common.FileUtil
+import java.nio.ByteBuffer
+import java.nio.ByteOrder
+import kotlin.math.max
+import kotlin.math.min
+import androidx.core.graphics.scale
+import co.touchlab.kermit.Logger

 @Composable
 actual fun initialiseObjectModel(modelPath: ModelPath): ObjectModel {
-    val options = org.tensorflow.lite.task.vision.detector.ObjectDetector.ObjectDetectorOptions.builder().setMaxResults(5).setScoreThreshold(0f).build()
-    val detector = org.tensorflow.lite.task.vision.detector.ObjectDetector.createFromFileAndOptions(
-        LocalContext.current,
-        modelPath.androidModelPath,
-        options
-    )
+    val detector = modelPath.androidModelPath?.let { YoloTFLite(LocalContext.current, it, confThreshold = 0.01f) }
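+    // The low confThreshold mirrors the old setScoreThreshold(0f): keep nearly every
+    // candidate and leave score filtering to downstream consumers.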
     return ObjectModel(detector)
 }

 actual class ObjectModel{
-    private var detector: org.tensorflow.lite.task.vision.detector.ObjectDetector? = null
+    private var detector: YoloTFLite? = null
+
+    constructor(detector: YoloTFLite?) {
+        this.detector = detector
+    }
+
+    fun getDetector(): YoloTFLite? {
+        return detector
+    }
+
+}
+
+// ModelInfo holds the inferred layout and quantization metadata needed to build inputs and parse outputs.
+data class ModelInfo(
+    val inputWidth: Int,
+    val inputHeight: Int,
+    val inputChannels: Int,
+    val inputIsFloat: Boolean,
+    val inputScale: Float?,
+    val inputZeroPoint: Int?,
+    val inputLayoutNHWC: Boolean,
+    val normalizeToMinusOneOne: Boolean,
+    val outputFeatDimIndex: Int,
+    val outputNumBoxesIndex: Int?,
+    val outputFeatDim: Int,
+    val outputNumBoxes: Int,
+    val outputIsFloat: Boolean,
+    val outputScale: Float?,
+    val outputZeroPoint: Int?
+)
+
+class YoloTFLite(
+    context: Context,
+    modelPathInAssets: String,
+    private val confThreshold: Float = 0.25f
+) {
+    private val interpreter: Interpreter
+    private val inputTensor: Tensor
+    private val modelInfo: ModelInfo
+
+    init {
+        val model = FileUtil.loadMappedFile(context, modelPathInAssets)
+        interpreter = Interpreter(model, Interpreter.Options())
+        inputTensor = interpreter.getInputTensor(0)
+        modelInfo = inspectModel(interpreter)
+        Logger.d { "YoloTFLite modelInfo=$modelInfo" }
+    }
+
+    @Suppress("unused")
+    fun close() = interpreter.close()
+
+    fun detect(src: Bitmap): List<Detection> {
+        // 1) Preprocess
+        val resized = src.scale(modelInfo.inputWidth, modelInfo.inputHeight)
+        val inputBuffer = makeInputBuffer(resized)
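+        // makeInputBuffer re-checks dimensions itself, so the pre-scale above is redundant but harmless.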
+
+        // 2) Prepare output
+        val outTensor = interpreter.getOutputTensor(0)
+        val outShape = outTensor.shape() // e.g. [1, featDim, numBoxes] or [1, numBoxes, featDim]
+        val featDim = modelInfo.outputFeatDim
+        val nBoxes = modelInfo.outputNumBoxes
+        val outIsFloat = modelInfo.outputIsFloat
+        val outScale = modelInfo.outputScale
+        val outZero = modelInfo.outputZeroPoint
+
+        // allocate outRaw to exactly match the output tensor shape
+        fun makeOutputArray(shape: IntArray, isFloat: Boolean): Any {
+            fun build(dim: Int): Any {
+                val size = shape[dim]
+                return if (dim == shape.lastIndex) {
+                    if (isFloat) FloatArray(size) else ByteArray(size)
+                } else {
+                    val arr = arrayOfNulls<Any>(size)
+                    for (i in 0 until size) arr[i] = build(dim + 1)
+                    arr
+                }
+            }
+            return build(0)
+        }
+
+        val outRaw: Any = makeOutputArray(outShape, outIsFloat)
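+        // Note: arrayOfNulls<Any> builds Object[] rows, which TFLite's reflective output copy
+        // can reject for rank >= 3 tensors; the ByteBuffer fallback below handles that case.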
+
+        // storage for fallback flat outputs if nested array copy fails
+        var flatFloatOut: FloatArray? = null
+        var flatByteOut: ByteArray? = null
+        var usedFlatBuffer = false
+
+        // 3) Inference
+        try {
+            interpreter.run(inputBuffer, outRaw)
+        } catch (e: IllegalArgumentException) {
+            // compute constructed Java array shape for logging
+            fun constructedShape(raw: Any): List<Int> {
+                val dims = mutableListOf<Int>()
+                var cur: Any? = raw
+                while (true) {
+                    when (cur) {
+                        is Array<*> -> {
+                            dims.add(cur.size)
+                            cur = if (cur.isNotEmpty()) cur[0] else break
+                        }
+                        is FloatArray -> {
+                            dims.add(cur.size)
+                            break
+                        }
+                        is ByteArray -> {
+                            dims.add(cur.size)
+                            break
+                        }
+                        else -> break
+                    }
+                }
+                return dims
+            }
+
+            val constructed = constructedShape(outRaw)
+            Logger.e { "Interpreter run failed: ${e.message} -- tensorShape=${outShape.joinToString(prefix = "[", postfix = "]")} constructed=${constructed.joinToString(prefix = "[", postfix = "]")} modelInfo=$modelInfo" }
+
+            // Attempt fallback: run into a direct ByteBuffer and parse the raw bytes.
+            val totalElements = outShape.fold(1) { acc, v -> acc * v }
+            try {
+                inputBuffer.rewind() // the failed attempt may have advanced the input buffer
+                if (outIsFloat) {
+                    val bb = ByteBuffer.allocateDirect(totalElements * 4).order(ByteOrder.nativeOrder())
+                    interpreter.run(inputBuffer, bb)
+                    bb.rewind()
+                    val floats = FloatArray(totalElements)
+                    for (i in 0 until totalElements) floats[i] = bb.float
+                    flatFloatOut = floats
+                    usedFlatBuffer = true
+                } else {
+                    val bb = ByteBuffer.allocateDirect(totalElements).order(ByteOrder.nativeOrder())
+                    interpreter.run(inputBuffer, bb)
+                    bb.rewind()
+                    val bytes = ByteArray(totalElements)
+                    bb.get(bytes)
+                    flatByteOut = bytes
+                    usedFlatBuffer = true
+                }
+            } catch (e2: Exception) {
+                Logger.e { "Fallback interpreter run also failed: ${e2.message}" }
+                throw e // rethrow original
+            }
+        }
+
+        // dynamic getter to read a value from nested outRaw at specified multi-dimensional indices
+        fun getValueAt(raw: Any, indices: IntArray, isFloat: Boolean, scale: Float?, zero: Int?): Float {
+            if (usedFlatBuffer) {
+                // compute the row-major linear index from the multi-dimensional indices
+                val rank = outShape.size
+                val strides = IntArray(rank)
+                strides[rank - 1] = 1
+                for (r in rank - 2 downTo 0) strides[r] = strides[r + 1] * outShape[r + 1]
+                var idx = 0
+                for (r in 0 until rank) idx += indices[r] * strides[r]
+                return if (isFloat) {
+                    flatFloatOut!![idx]
+                } else {
+                    val v = flatByteOut!![idx].toInt() and 0xFF
+                    (v - (zero ?: 0)) * (scale ?: 1f)
+                }
+            }
+            var cur: Any = raw
+            val last = indices.lastIndex
+            for (d in 0 until last) {
+                cur = (cur as Array<*>)[indices[d]] as Any
+            }
+            return if (isFloat) {
+                val arr = cur as FloatArray
+                arr[indices[last]]
+            } else {
+                val arr = cur as ByteArray
+                val v = arr[indices[last]].toInt() and 0xFF
+                (v - (zero ?: 0)) * (scale ?: 1f)
+            }
+        }
+
+        // helper to get a single box's feature vector (length = featDim) using configured indices
+        fun getRow(boxIdx: Int): FloatArray {
+            val row = FloatArray(featDim)
+            val rank = outShape.size
+            val indices = IntArray(rank) { 0 }
+            // assume batch dimension is 0 and set it to 0
+            if (rank > 0) indices[0] = 0
+            val numBoxesIndex = modelInfo.outputNumBoxesIndex
+            for (f in 0 until featDim) {
+                // set feature index
+                indices[modelInfo.outputFeatDimIndex] = f
+                // set box index only if tensor has a separate boxes dimension
+                if (numBoxesIndex != null) {
+                    indices[numBoxesIndex] = boxIdx
+                }
+                row[f] = getValueAt(outRaw, indices, outIsFloat, outScale, outZero)
+            }
+            return row
+        }
+
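+        // Scale a candidate box into src coordinates and clamp it to the image bounds.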
+        fun scaleAndClamp(x1: Float, y1: Float, x2: Float, y2: Float, xs: Float, ys: Float): RectF {
+            val sx1 = (x1 * xs).coerceIn(0f, src.width.toFloat())
+            val sy1 = (y1 * ys).coerceIn(0f, src.height.toFloat())
+            val sx2 = (x2 * xs).coerceIn(0f, src.width.toFloat())
+            val sy2 = (y2 * ys).coerceIn(0f, src.height.toFloat())
+            return RectF(min(sx1, sx2), min(sy1, sy2), max(sx1, sx2), max(sy1, sy2))
+        }
+
+        val scaleNormX = src.width.toFloat()
+        val scaleNormY = src.height.toFloat()
+        val scaleInputX = src.width.toFloat() / modelInfo.inputWidth
+        val scaleInputY = src.height.toFloat() / modelInfo.inputHeight
+
+        val detections = mutableListOf<Detection>()
+
+        for (i in 0 until nBoxes) {
+            val row = getRow(i)
+            if (row.size < 6) continue
+            val score = row[4]
+            if (score < confThreshold) continue
+            val clsId = row[5].toInt()
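+            // Row layout assumed above: [box(4), score, classId, ...], i.e. a post-NMS export.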
+
+            val ax1 = row[0]; val ay1 = row[1]; val ax2 = row[2]; val ay2 = row[3]
+            val cx = row[0]; val cy = row[1]; val w = row[2]; val h = row[3]
+            val bx1 = cx - w / 2f; val by1 = cy - h / 2f; val bx2 = cx + w / 2f; val by2 = cy + h / 2f
+
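+            // The export's box format and scale are unknown, so try corner (A) and centre+size (B)
+            // decodings at normalised, input-pixel, and source-pixel scales.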
+            val candidates = listOf(
+                scaleAndClamp(ax1, ay1, ax2, ay2, scaleNormX, scaleNormY) to "A_norm",
+                scaleAndClamp(bx1, by1, bx2, by2, scaleNormX, scaleNormY) to "B_norm",
+                scaleAndClamp(ax1, ay1, ax2, ay2, scaleInputX, scaleInputY) to "A_input",
+                scaleAndClamp(bx1, by1, bx2, by2, scaleInputX, scaleInputY) to "B_input",
+                scaleAndClamp(ax1, ay1, ax2, ay2, 1f, 1f) to "A_src",
+                scaleAndClamp(bx1, by1, bx2, by2, 1f, 1f) to "B_src"
+            )

-    constructor(detector: org.tensorflow.lite.task.vision.detector.ObjectDetector){
-        this.detector = detector
+            val chosen = candidates.firstOrNull { (r, _) ->
+                val wbox = r.width()
+                val hbox = r.height()
+                wbox > 4f && hbox > 4f
+            }?.first
+
+            if (chosen != null) {
+                detections += Detection(chosen, score, clsId)
+            } else {
+                val best = candidates.maxByOrNull { it.first.width() * it.first.height() }?.first
+                if (best != null && best.width() > 0f && best.height() > 0f) {
+                    detections += Detection(best, score, clsId)
+                }
+            }
+        }
+        Logger.d { "YoloTFLite: detected $detections" }
+        return detections
     }

-    fun getDetector(): org.tensorflow.lite.task.vision.detector.ObjectDetector? {
-        return detector
+    private fun makeInputBuffer(bm: Bitmap): ByteBuffer {
+        val w = modelInfo.inputWidth
+        val h = modelInfo.inputHeight
+        val c = modelInfo.inputChannels
+        val pixels = IntArray(w * h)
+        val resized = if (bm.width != w || bm.height != h) bm.scale(w, h) else bm
+        resized.getPixels(pixels, 0, w, 0, 0, w, h)
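+        // getPixels yields packed ARGB ints; the loops below unpack R/G/B and normalise to [-1, 1].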
+
+        if (modelInfo.inputIsFloat) {
+            val buf = ByteBuffer.allocateDirect(4 * w * h * c)
+            buf.order(ByteOrder.nativeOrder())
+            if (modelInfo.inputLayoutNHWC) {
+                var idx = 0
+                for (y in 0 until h) {
+                    for (x in 0 until w) {
+                        val p = pixels[idx++]
+                        val r = ((p shr 16) and 0xFF) / 127.5f - 1f
+                        val g = ((p shr 8) and 0xFF) / 127.5f - 1f
+                        val b = (p and 0xFF) / 127.5f - 1f
+                        if (c == 1) {
+                            val gray = (0.2989f * r + 0.5870f * g + 0.1140f * b)
+                            buf.putFloat(gray)
+                        } else {
+                            buf.putFloat(r)
+                            buf.putFloat(g)
+                            buf.putFloat(b)
+                        }
+                    }
+                }
+            } else {
+                // NCHW: write channel planes: R plane, G plane, B plane
+                if (c == 1) {
+                    var idx = 0
+                    for (i in 0 until w * h) {
+                        val p = pixels[idx++]
+                        val r = ((p shr 16) and 0xFF) / 127.5f - 1f
+                        buf.putFloat(r)
+                    }
+                } else {
+                    val rPlane = FloatArray(w * h)
+                    val gPlane = FloatArray(w * h)
+                    val bPlane = FloatArray(w * h)
+                    var idx = 0
+                    for (y in 0 until h) {
+                        for (x in 0 until w) {
+                            val p = pixels[idx]
+                            rPlane[idx] = ((p shr 16) and 0xFF) / 127.5f - 1f
+                            gPlane[idx] = ((p shr 8) and 0xFF) / 127.5f - 1f
+                            bPlane[idx] = (p and 0xFF) / 127.5f - 1f
+                            idx++
+                        }
+                    }
+                    for (i in 0 until w * h) buf.putFloat(rPlane[i])
+                    for (i in 0 until w * h) buf.putFloat(gPlane[i])
+                    for (i in 0 until w * h) buf.putFloat(bPlane[i])
+                }
+            }
+            buf.rewind()
+            return buf
+        } else {
+            // quantized input
+            val buf = ByteBuffer.allocateDirect(w * h * c)
+            buf.order(ByteOrder.nativeOrder())
+            val scale = modelInfo.inputScale ?: 1f
+            val zero = modelInfo.inputZeroPoint ?: 0
+            if (modelInfo.inputLayoutNHWC) {
+                var idx = 0
+                for (y in 0 until h) {
+                    for (x in 0 until w) {
+                        val p = pixels[idx++]
+                        val r = (((p shr 16) and 0xFF) / 127.5f - 1f)
+                        val g = (((p shr 8) and 0xFF) / 127.5f - 1f)
+                        val b = ((p and 0xFF) / 127.5f - 1f)
+                        if (c == 1) {
+                            val gray = (0.2989f * r + 0.5870f * g + 0.1140f * b)
+                            buf.put(floatToQuantByte(gray, scale, zero))
+                        } else {
+                            buf.put(floatToQuantByte(r, scale, zero))
+                            buf.put(floatToQuantByte(g, scale, zero))
+                            buf.put(floatToQuantByte(b, scale, zero))
+                        }
+                    }
+                }
+            } else {
+                // NCHW: build channel planes then write
+                if (c == 1) {
+                    var idx = 0
+                    for (i in 0 until w * h) {
+                        val p = pixels[idx++]
+                        val r = (((p shr 16) and 0xFF) / 127.5f - 1f)
+                        buf.put(floatToQuantByte(r, scale, zero))
+                    }
+                } else {
+                    val rPlane = ByteArray(w * h)
+                    val gPlane = ByteArray(w * h)
+                    val bPlane = ByteArray(w * h)
+                    var idx = 0
+                    for (y in 0 until h) {
+                        for (x in 0 until w) {
+                            val p = pixels[idx]
+                            rPlane[idx] = floatToQuantByte(((p shr 16) and 0xFF) / 127.5f - 1f, scale, zero)
+                            gPlane[idx] = floatToQuantByte(((p shr 8) and 0xFF) / 127.5f - 1f, scale, zero)
+                            bPlane[idx] = floatToQuantByte((p and 0xFF) / 127.5f - 1f, scale, zero)
+                            idx++
+                        }
+                    }
+                    for (i in 0 until w * h) buf.put(rPlane[i])
+                    for (i in 0 until w * h) buf.put(gPlane[i])
+                    for (i in 0 until w * h) buf.put(bPlane[i])
+                }
+            }
+            buf.rewind()
+            return buf
+        }
     }
-}
+
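+    // Quantise a [-1, 1] float into the model's uint8 domain; int8 models (negative zero
+    // point) would need signed clamping instead.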
+    private fun floatToQuantByte(f: Float, scale: Float, zero: Int): Byte {
+        val q = (f / scale + zero).toInt()
+        val clamped = max(0, min(255, q))
+        return (clamped and 0xFF).toByte()
+    }
+
+    private fun inspectModel(interpreter: Interpreter): ModelInfo {
+        val inT = interpreter.getInputTensor(0)
+        val inShape = inT.shape() // e.g. [1,H,W,3] or [1,3,H,W]
+        val inDtype = inT.dataType()
+        val inputIsFloat = inDtype == DataType.FLOAT32
+
+        val inputChannels: Int
+        val inputWidth: Int
+        val inputHeight: Int
+        val inputLayoutNHWC: Boolean
+
+        if (inShape.size == 4) {
+            // detect NHWC vs NCHW
+            if (inShape[1] == 3 || inShape[1] == 1) {
+                // NCHW
+                inputLayoutNHWC = false
+                inputChannels = inShape[1]
+                inputHeight = inShape[2]
+                inputWidth = inShape[3]
+            } else {
+                // NHWC
+                inputLayoutNHWC = true
+                inputHeight = inShape[1]
+                inputWidth = inShape[2]
+                inputChannels = inShape[3]
+            }
+        } else if (inShape.size == 3) {
+            inputLayoutNHWC = true
+            inputHeight = inShape[1]
+            inputWidth = inShape[2]
+            inputChannels = 1
+        } else {
+            // fallback
+            inputLayoutNHWC = true
+            inputHeight = if (inShape.size > 1) inShape[1] else 1
+            inputWidth = if (inShape.size > 2) inShape[2] else 1
+            inputChannels = if (inShape.size > 3) inShape[3] else 3
+        }
+
+        val (inScale, inZero) = if (!inputIsFloat) {
+            val qp = inT.quantizationParams()
+            Pair(qp.scale, qp.zeroPoint)
+        } else Pair(null, null)
+
+        // Determine output layout and quant params
+        val outT = interpreter.getOutputTensor(0)
+        val outShape = outT.shape()
+        // If the output is 2D e.g. [1,25], treat it as batch x features (no boxes axis)
+        val featDimCandidateIndex: Int
+        val numBoxesIndexNullable: Int?
+        val featDim: Int
+        val numBoxes: Int
+        if (outShape.size == 2) {
+            // [batch, features]
+            featDimCandidateIndex = 1
+            numBoxesIndexNullable = null
+            featDim = outShape[1]
+            numBoxes = 1
+        } else {
+            // Heuristic: a small dimension (like 8) is the feature dim
+            featDimCandidateIndex = when {
+                outShape.size >= 3 && (outShape[1] <= 32) -> 1
+                outShape.size >= 3 && (outShape.last() <= 32) -> outShape.size - 1
+                else -> 1
+            }
+            val numBoxesIndex = if (featDimCandidateIndex == 1) outShape.size - 1 else 1
+            numBoxesIndexNullable = numBoxesIndex
+            featDim = outShape[featDimCandidateIndex]
+            numBoxes = outShape[numBoxesIndex]
+        }
+
+        val outIsFloat = outT.dataType() == DataType.FLOAT32
+        val outQ = outT.quantizationParams()
+        val outScale = if (!outIsFloat) outQ.scale else null
+        val outZero = if (!outIsFloat) outQ.zeroPoint else null
+
+        // default normalization: many custom models use [-1,1]
+        val normalizeToMinusOneOne = true
+
+        return ModelInfo(
+            inputWidth = inputWidth,
+            inputHeight = inputHeight,
+            inputChannels = inputChannels,
+            inputIsFloat = inputIsFloat,
+            inputScale = inScale,
+            inputZeroPoint = inZero,
+            inputLayoutNHWC = inputLayoutNHWC,
+            normalizeToMinusOneOne = normalizeToMinusOneOne,
+            outputFeatDimIndex = featDimCandidateIndex,
+            outputNumBoxesIndex = numBoxesIndexNullable,
+            outputFeatDim = featDim,
+            outputNumBoxes = numBoxes,
+            outputIsFloat = outIsFloat,
+            outputScale = outScale,
+            outputZeroPoint = outZero
+        )
+    }
+}
+
+data class Detection(
+    val bbox: android.graphics.RectF, // in original image coordinates
+    val score: Float,
+    val classId: Int
+)
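+
+// Usage sketch (hypothetical caller; "model.tflite" stands in for the real asset path):
+//   val yolo = YoloTFLite(context, "model.tflite", confThreshold = 0.25f)
+//   val boxes = yolo.detect(bitmap) // List<Detection> in bitmap coordinates
+//   yolo.close()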
posedetection/src/androidMain/kotlin/com.performancecoachlab/posedetection/camera/CameraView.android.kt (+1 -3)
···
 import com.performancecoachlab.posedetection.recording.AnalysisObject
 import com.performancecoachlab.posedetection.recording.AnalysisResult
 import kotlinx.coroutines.launch
-import org.tensorflow.lite.support.image.TensorImage
-import org.tensorflow.lite.task.vision.detector.ObjectDetector

 // Data class to hold recording state for each recording ID
 data class RecordingSlot(
···
         modifier = Modifier
             .fillMaxSize()
             .scale(if (frontCamera) -1f else 1f, 1f),
-        contentScale = ContentScale.Crop
+        contentScale = ContentScale.Fit
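+        // Fit letterboxes rather than crops, keeping the whole analysed frame visible.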
     )
 }
 }