real time body tracking for compose multiplatform mobile apps

Compare changes

Choose any two refs to compare.

Changed files
+503 -27
posedetection
src
androidMain
kotlin
com
performancecoachlab
posedetection
com.performancecoachlab
posedetection
+3 -1
posedetection/build.gradle.kts
··· 1 1 import com.android.build.api.dsl.AaptOptions 2 2 import com.android.build.api.dsl.AndroidResources 3 3 import com.vanniktech.maven.publish.SonatypeHost 4 + import org.gradle.kotlin.dsl.implementation 4 5 import org.jetbrains.compose.ExperimentalComposeLibrary 5 6 import org.jetbrains.kotlin.gradle.plugin.KotlinSourceSetTree 6 7 ··· 90 91 implementation(libs.pose.detection) 91 92 implementation(libs.pose.detection.common) 92 93 implementation(libs.androidx.media3.common.ktx) 93 - implementation(libs.tensorflow.lite.task.vision) 94 + implementation("org.tensorflow:tensorflow-lite:2.17.0") 95 + implementation("org.tensorflow:tensorflow-lite-support:0.5.0") 94 96 } 95 97 96 98 }
+10 -11
posedetection/src/androidMain/kotlin/com/performancecoachlab/posedetection/camera/Utils.android.kt
··· 8 8 import com.google.android.gms.tasks.Tasks 9 9 import com.google.mlkit.vision.common.InputImage 10 10 import com.google.mlkit.vision.pose.PoseDetector 11 + import com.performancecoachlab.posedetection.custom.YoloTFLite 11 12 import com.performancecoachlab.posedetection.recording.AnalysisObject 12 13 import com.performancecoachlab.posedetection.recording.AnalysisResult 13 14 import com.performancecoachlab.posedetection.recording.FrameSize ··· 24 25 25 26 @OptIn(ExperimentalGetImage::class) 26 27 fun ImageProxy.process( 27 - objectDetector: org.tensorflow.lite.task.vision.detector.ObjectDetector?, 28 + objectDetector: YoloTFLite?, 28 29 poseDetector: PoseDetector?, 29 30 timestamp: Long, 30 31 focusArea: Rect?, ··· 58 59 } 59 60 60 61 fun Bitmap.process( 61 - objectDetector: org.tensorflow.lite.task.vision.detector.ObjectDetector?, 62 + objectDetector: YoloTFLite?, 62 63 poseDetector: PoseDetector, 63 64 timestamp: Long, 64 65 focusArea: Rect?, ··· 169 170 private fun process( 170 171 tensorImage: TensorImage, 171 172 mlKitImage: InputImage?, 172 - objectDetector: org.tensorflow.lite.task.vision.detector.ObjectDetector?, 173 + objectDetector: YoloTFLite?, 173 174 poseDetector: PoseDetector?, 174 175 timestamp: Long, 175 176 width: Int, ··· 177 178 bitmap: Bitmap, 178 179 onComplete: (AnalysisResult, Bitmap) -> Unit 179 180 ) { 180 - val objectsDetected = objectDetector?.detect(tensorImage)?.map { result -> 181 + val objectsDetected = objectDetector?.detect(bitmap)?.map { result -> 181 182 AnalysisObject( 182 - boundingBox = result.boundingBox.let { 183 + boundingBox = result.bbox.let { 183 184 Rect( 184 185 left = it.left, 185 186 top = it.top, ··· 188 189 ) 189 190 }, 190 191 trackingId = 0, 191 - labels = result.categories.map { category -> 192 - com.performancecoachlab.posedetection.recording.Label( 193 - category.label, 194 - category.score 195 - ) 196 - }, 192 + labels = listOf(com.performancecoachlab.posedetection.recording.Label( 193 + result.classId.toString(), 194 + result.score 195 + )), 197 196 frameSize = FrameSize( 198 197 width = width, 199 198 height = height
+489 -12
posedetection/src/androidMain/kotlin/com/performancecoachlab/posedetection/custom/CustomObjectModel.android.kt
··· 1 1 package com.performancecoachlab.posedetection.custom 2 2 3 + import android.content.Context 4 + import android.graphics.Bitmap 5 + import android.graphics.RectF 3 6 import androidx.compose.runtime.Composable 4 7 import androidx.compose.ui.platform.LocalContext 8 + import org.tensorflow.lite.Interpreter 9 + import org.tensorflow.lite.Tensor 10 + import org.tensorflow.lite.DataType 11 + import org.tensorflow.lite.support.common.FileUtil 12 + import java.nio.ByteBuffer 13 + import java.nio.ByteOrder 14 + import kotlin.math.max 15 + import kotlin.math.min 16 + import androidx.core.graphics.scale 17 + import co.touchlab.kermit.Logger 5 18 6 19 @Composable 7 20 actual fun initialiseObjectModel(modelPath: ModelPath): ObjectModel { 8 - val options = org.tensorflow.lite.task.vision.detector.ObjectDetector.ObjectDetectorOptions.builder().setMaxResults(5).setScoreThreshold(0f).build() 9 - val detector = org.tensorflow.lite.task.vision.detector.ObjectDetector.createFromFileAndOptions( 10 - LocalContext.current, 11 - modelPath.androidModelPath, 12 - options 13 - ) 21 + val detector = modelPath.androidModelPath?.let { YoloTFLite(LocalContext.current, it, confThreshold = 0.01f) } 14 22 return ObjectModel(detector) 15 23 } 16 24 17 25 actual class ObjectModel{ 18 - private var detector: org.tensorflow.lite.task.vision.detector.ObjectDetector? = null 26 + private var detector: YoloTFLite? = null 27 + 28 + constructor(detector: YoloTFLite?){ 29 + this.detector = detector } 30 + 31 + fun getDetector(): YoloTFLite? { 32 + return detector 33 + } 34 + 35 + } 36 + 37 + // ModelInfo holds the inferred layout and quantization metadata needed to build inputs and parse outputs. 38 + data class ModelInfo( 39 + val inputWidth: Int, 40 + val inputHeight: Int, 41 + val inputChannels: Int, 42 + val inputIsFloat: Boolean, 43 + val inputScale: Float?, 44 + val inputZeroPoint: Int?, 45 + val inputLayoutNHWC: Boolean, 46 + val normalizeToMinusOneOne: Boolean, 47 + val outputFeatDimIndex: Int, 48 + val outputNumBoxesIndex: Int?, 49 + val outputFeatDim: Int, 50 + val outputNumBoxes: Int, 51 + val outputIsFloat: Boolean, 52 + val outputScale: Float?, 53 + val outputZeroPoint: Int? 54 + ) 55 + 56 + class YoloTFLite( 57 + context: Context, 58 + modelPathInAssets: String, 59 + private val confThreshold: Float = 0.25f 60 + ) { 61 + private val interpreter: Interpreter 62 + private val inputTensor: Tensor 63 + private val modelInfo: ModelInfo 64 + 65 + init { 66 + val model = FileUtil.loadMappedFile(context, modelPathInAssets) 67 + interpreter = Interpreter(model, Interpreter.Options()) 68 + inputTensor = interpreter.getInputTensor(0) 69 + modelInfo = inspectModel(interpreter) 70 + Logger.d { "YoloTFLite modelInfo=$modelInfo" } 71 + } 72 + 73 + @Suppress("unused") 74 + fun close() = interpreter.close() 75 + 76 + fun detect(src: Bitmap): List<Detection> { 77 + // 1) Preprocess 78 + val resized = src.scale(modelInfo.inputWidth, modelInfo.inputHeight) 79 + val inputBuffer = makeInputBuffer(resized) 80 + 81 + // 2) Prepare output 82 + val outTensor = interpreter.getOutputTensor(0) 83 + val outShape = outTensor.shape() // flexible 84 + val featDim = modelInfo.outputFeatDim 85 + val nBoxes = modelInfo.outputNumBoxes 86 + val outIsFloat = modelInfo.outputIsFloat 87 + val outScale = modelInfo.outputScale 88 + val outZero = modelInfo.outputZeroPoint 89 + 90 + // allocate outRaw to exactly match the output tensor shape 91 + fun makeOutputArray(shape: IntArray, isFloat: Boolean): Any { 92 + fun build(dim: Int): Any { 93 + val size = shape[dim] 94 + return if (dim == shape.lastIndex) { 95 + if (isFloat) FloatArray(size) else ByteArray(size) 96 + } else { 97 + val arr = arrayOfNulls<Any>(size) 98 + for (i in 0 until size) arr[i] = build(dim + 1) 99 + arr 100 + } 101 + } 102 + return build(0) 103 + } 104 + 105 + val outRaw: Any = makeOutputArray(outShape, outIsFloat) 106 + 107 + // storage for fallback flat outputs if nested array copy fails 108 + var flatFloatOut: FloatArray? = null 109 + var flatByteOut: ByteArray? = null 110 + var usedFlatBuffer = false 111 + 112 + // 3) Inference 113 + try { 114 + interpreter.run(inputBuffer, outRaw) 115 + } catch (e: IllegalArgumentException) { 116 + // compute constructed Java array shape for logging 117 + fun constructedShape(raw: Any): List<Int> { 118 + val dims = mutableListOf<Int>() 119 + var cur: Any? = raw 120 + while (true) { 121 + when (cur) { 122 + is Array<*> -> { 123 + dims.add(cur.size) 124 + cur = if (cur.isNotEmpty()) cur[0] else break 125 + } 126 + is FloatArray -> { 127 + dims.add(cur.size) 128 + break 129 + } 130 + is ByteArray -> { 131 + dims.add(cur.size) 132 + break 133 + } 134 + else -> break 135 + } 136 + } 137 + return dims 138 + } 139 + 140 + val constructed = constructedShape(outRaw) 141 + Logger.e { "Interpreter run failed: ${e.message} -- tensorShape=${outShape.joinToString(prefix="[", postfix="]")} constructed=${constructed.joinToString(prefix="[", postfix="]")} modelInfo=$modelInfo" } 142 + 143 + // Attempt fallback: run into a direct ByteBuffer and parse the raw bytes. 144 + val totalElements = outShape.fold(1) { acc, v -> acc * v } 145 + try { 146 + if (outIsFloat) { 147 + val bb = ByteBuffer.allocateDirect(totalElements * 4).order(ByteOrder.nativeOrder()) 148 + interpreter.run(inputBuffer, bb) 149 + bb.rewind() 150 + val floats = FloatArray(totalElements) 151 + for (i in 0 until totalElements) floats[i] = bb.float 152 + flatFloatOut = floats 153 + usedFlatBuffer = true 154 + } else { 155 + val bb = ByteBuffer.allocateDirect(totalElements).order(ByteOrder.nativeOrder()) 156 + interpreter.run(inputBuffer, bb) 157 + bb.rewind() 158 + val bytes = ByteArray(totalElements) 159 + bb.get(bytes) 160 + flatByteOut = bytes 161 + usedFlatBuffer = true 162 + } 163 + } catch (e2: Exception) { 164 + Logger.e { "Fallback interpreter run also failed: ${e2.message}" } 165 + throw e // rethrow original 166 + } 167 + } 168 + 169 + // dynamic getter to read a value from nested outRaw at specified multi-dimensional indices 170 + fun getValueAt(raw: Any, indices: IntArray, isFloat: Boolean, scale: Float?, zero: Int?): Float { 171 + if (usedFlatBuffer) { 172 + val totalShape = outShape 173 + // compute linear index from multi-dimensional indices (row-major) 174 + var idx = 0 175 + var stride = 1 176 + // compute strides 177 + val rank = totalShape.size 178 + val strides = IntArray(rank) 179 + strides[rank - 1] = 1 180 + for (r in rank - 2 downTo 0) strides[r] = strides[r + 1] * totalShape[r + 1] 181 + for (r in 0 until rank) idx += indices[r] * strides[r] 182 + return if (isFloat) { 183 + flatFloatOut!![idx] 184 + } else { 185 + val v = flatByteOut!![idx].toInt() and 0xFF 186 + (v - (zero ?: 0)) * (scale ?: 1f) 187 + } 188 + } 189 + var cur: Any = raw 190 + val last = indices.lastIndex 191 + for (d in 0 until last) { 192 + @Suppress("UNCHECKED_CAST") 193 + cur = (cur as Array<*>)[indices[d]] as Any 194 + } 195 + return if (isFloat) { 196 + val arr = cur as FloatArray 197 + arr[indices[last]] 198 + } else { 199 + val arr = cur as ByteArray 200 + val v = arr[indices[last]].toInt() and 0xFF 201 + (v - (zero ?: 0)) * (scale ?: 1f) 202 + } 203 + } 204 + 205 + // helper to get a single box's feature vector (length = featDim) using configured indices 206 + fun getRow(boxIdx: Int): FloatArray { 207 + val row = FloatArray(featDim) 208 + val rank = outShape.size 209 + val indices = IntArray(rank) { 0 } 210 + // assume batch dimension is 0 and set it to 0 211 + if (rank > 0) indices[0] = 0 212 + val numBoxesIndex = modelInfo.outputNumBoxesIndex 213 + for (f in 0 until featDim) { 214 + // set feature index 215 + indices[modelInfo.outputFeatDimIndex] = f 216 + // set box index only if tensor has a separate boxes dimension 217 + if (numBoxesIndex != null) { 218 + indices[numBoxesIndex] = boxIdx 219 + } 220 + row[f] = getValueAt(outRaw, indices, outIsFloat, outScale, outZero) 221 + } 222 + return row 223 + } 224 + 225 + fun scaleAndClamp(x1: Float, y1: Float, x2: Float, y2: Float, xs: Float, ys: Float): RectF { 226 + val sx1 = (x1 * xs).coerceIn(0f, src.width.toFloat()) 227 + val sy1 = (y1 * ys).coerceIn(0f, src.height.toFloat()) 228 + val sx2 = (x2 * xs).coerceIn(0f, src.width.toFloat()) 229 + val sy2 = (y2 * ys).coerceIn(0f, src.height.toFloat()) 230 + return RectF(min(sx1, sx2), min(sy1, sy2), max(sx1, sx2), max(sy1, sy2)) 231 + } 232 + 233 + val scaleNormX = src.width.toFloat() 234 + val scaleNormY = src.height.toFloat() 235 + val scaleInputX = src.width.toFloat() / modelInfo.inputWidth 236 + val scaleInputY = src.height.toFloat() / modelInfo.inputHeight 237 + 238 + val detections = mutableListOf<Detection>() 239 + 240 + for (i in 0 until nBoxes) { 241 + val row = getRow(i) 242 + if (row.size < 6) continue 243 + val score = row[4] 244 + if (score < confThreshold) continue 245 + val clsId = row[5].toInt() 246 + 247 + val ax1 = row[0]; val ay1 = row[1]; val ax2 = row[2]; val ay2 = row[3] 248 + val cx = row[0]; val cy = row[1]; val w = row[2]; val h = row[3] 249 + val bx1 = cx - w / 2f; val by1 = cy - h / 2f; val bx2 = cx + w / 2f; val by2 = cy + h / 2f 250 + 251 + val candidates = listOf( 252 + scaleAndClamp(ax1, ay1, ax2, ay2, scaleNormX, scaleNormY) to "A_norm", 253 + scaleAndClamp(bx1, by1, bx2, by2, scaleNormX, scaleNormY) to "B_norm", 254 + scaleAndClamp(ax1, ay1, ax2, ay2, scaleInputX, scaleInputY) to "A_input", 255 + scaleAndClamp(bx1, by1, bx2, by2, scaleInputX, scaleInputY) to "B_input", 256 + scaleAndClamp(ax1, ay1, ax2, ay2, 1f, 1f) to "A_src", 257 + scaleAndClamp(bx1, by1, bx2, by2, 1f, 1f) to "B_src" 258 + ) 19 259 20 - constructor(detector: org.tensorflow.lite.task.vision.detector.ObjectDetector){ 21 - this.detector = detector 260 + val chosen = candidates.firstOrNull { (r, _) -> 261 + val wbox = r.width() 262 + val hbox = r.height() 263 + wbox > 4f && hbox > 4f 264 + }?.first 265 + 266 + if (chosen != null) { 267 + detections += Detection(chosen, score, clsId) 268 + } else { 269 + val best = candidates.maxByOrNull { it.first.width() * it.first.height() }?.first 270 + if (best != null && best.width() > 0f && best.height() > 0f) { 271 + detections += Detection(best, score, clsId) 272 + } 273 + } 274 + } 275 + Logger.d { "YoloTFLite: detected ${detections}" } 276 + return detections 22 277 } 23 278 24 - fun getDetector(): org.tensorflow.lite.task.vision.detector.ObjectDetector? { 25 - return detector 279 + private fun makeInputBuffer(bm: Bitmap): ByteBuffer { 280 + val w = modelInfo.inputWidth 281 + val h = modelInfo.inputHeight 282 + val c = modelInfo.inputChannels 283 + val pixels = IntArray(w * h) 284 + val resized = if (bm.width != w || bm.height != h) bm.scale(w, h) else bm 285 + resized.getPixels(pixels, 0, w, 0, 0, w, h) 286 + 287 + if (modelInfo.inputIsFloat) { 288 + val buf = ByteBuffer.allocateDirect(4 * w * h * c) 289 + buf.order(ByteOrder.nativeOrder()) 290 + if (modelInfo.inputLayoutNHWC) { 291 + var idx = 0 292 + for (y in 0 until h) { 293 + for (x in 0 until w) { 294 + val p = pixels[idx++] 295 + val r = ((p shr 16) and 0xFF) / 127.5f - 1f 296 + val g = ((p shr 8) and 0xFF) / 127.5f - 1f 297 + val b = (p and 0xFF) / 127.5f - 1f 298 + if (c == 1) { 299 + val gray = (0.2989f * r + 0.5870f * g + 0.1140f * b) 300 + buf.putFloat(gray) 301 + } else { 302 + buf.putFloat(r) 303 + buf.putFloat(g) 304 + buf.putFloat(b) 305 + } 306 + } 307 + } 308 + } else { 309 + // NCHW: write channel planes: R plane, G plane, B plane 310 + if (c == 1) { 311 + var idx = 0 312 + for (i in 0 until w * h) { 313 + val p = pixels[idx++] 314 + val r = ((p shr 16) and 0xFF) / 127.5f - 1f 315 + buf.putFloat(r) 316 + } 317 + } else { 318 + val rPlane = FloatArray(w * h) 319 + val gPlane = FloatArray(w * h) 320 + val bPlane = FloatArray(w * h) 321 + var idx = 0 322 + for (y in 0 until h) { 323 + for (x in 0 until w) { 324 + val p = pixels[idx] 325 + rPlane[idx] = ((p shr 16) and 0xFF) / 127.5f - 1f 326 + gPlane[idx] = ((p shr 8) and 0xFF) / 127.5f - 1f 327 + bPlane[idx] = (p and 0xFF) / 127.5f - 1f 328 + idx++ 329 + } 330 + } 331 + for (i in 0 until w * h) buf.putFloat(rPlane[i]) 332 + for (i in 0 until w * h) buf.putFloat(gPlane[i]) 333 + for (i in 0 until w * h) buf.putFloat(bPlane[i]) 334 + } 335 + } 336 + buf.rewind() 337 + return buf 338 + } else { 339 + // quantized input 340 + val buf = ByteBuffer.allocateDirect(w * h * c) 341 + buf.order(ByteOrder.nativeOrder()) 342 + val scale = modelInfo.inputScale ?: 1f 343 + val zero = modelInfo.inputZeroPoint ?: 0 344 + if (modelInfo.inputLayoutNHWC) { 345 + var idx = 0 346 + for (y in 0 until h) { 347 + for (x in 0 until w) { 348 + val p = pixels[idx++] 349 + val r = (((p shr 16) and 0xFF) / 127.5f - 1f) 350 + val g = (((p shr 8) and 0xFF) / 127.5f - 1f) 351 + val b = ((p and 0xFF) / 127.5f - 1f) 352 + if (c == 1) { 353 + val gray = (0.2989f * r + 0.5870f * g + 0.1140f * b) 354 + buf.put(floatToQuantByte(gray, scale, zero)) 355 + } else { 356 + buf.put(floatToQuantByte(r, scale, zero)) 357 + buf.put(floatToQuantByte(g, scale, zero)) 358 + buf.put(floatToQuantByte(b, scale, zero)) 359 + } 360 + } 361 + } 362 + } else { 363 + // NCHW: build channel planes then write 364 + if (c == 1) { 365 + var idx = 0 366 + for (i in 0 until w * h) { 367 + val p = pixels[idx++] 368 + val r = (((p shr 16) and 0xFF) / 127.5f - 1f) 369 + buf.put(floatToQuantByte(r, scale, zero)) 370 + } 371 + } else { 372 + val rPlane = ByteArray(w * h) 373 + val gPlane = ByteArray(w * h) 374 + val bPlane = ByteArray(w * h) 375 + var idx = 0 376 + for (y in 0 until h) { 377 + for (x in 0 until w) { 378 + val p = pixels[idx] 379 + rPlane[idx] = floatToQuantByte(((p shr 16) and 0xFF) / 127.5f - 1f, scale, zero) 380 + gPlane[idx] = floatToQuantByte(((p shr 8) and 0xFF) / 127.5f - 1f, scale, zero) 381 + bPlane[idx] = floatToQuantByte((p and 0xFF) / 127.5f - 1f, scale, zero) 382 + idx++ 383 + } 384 + } 385 + for (i in 0 until w * h) buf.put(rPlane[i]) 386 + for (i in 0 until w * h) buf.put(gPlane[i]) 387 + for (i in 0 until w * h) buf.put(bPlane[i]) 388 + } 389 + } 390 + buf.rewind() 391 + return buf 392 + } 26 393 } 27 - } 394 + 395 + private fun floatToQuantByte(f: Float, scale: Float, zero: Int): Byte { 396 + val q = (f / scale + zero).toInt() 397 + val clamped = max(0, min(255, q)) 398 + return (clamped and 0xFF).toByte() 399 + } 400 + 401 + private fun inspectModel(interpreter: Interpreter): ModelInfo { 402 + val inT = interpreter.getInputTensor(0) 403 + val inShape = inT.shape() // e.g. [1,H,W,3] or [1,3,H,W] 404 + val inDtype = inT.dataType() 405 + val inputIsFloat = inDtype == DataType.FLOAT32 406 + 407 + val inputChannels: Int 408 + val inputWidth: Int 409 + val inputHeight: Int 410 + val inputLayoutNHWC: Boolean 411 + 412 + if (inShape.size == 4) { 413 + // detect NHWC vs NCHW 414 + if (inShape[1] == 3 || inShape[1] == 1) { 415 + // NCHW 416 + inputLayoutNHWC = false 417 + inputChannels = inShape[1] 418 + inputHeight = inShape[2] 419 + inputWidth = inShape[3] 420 + } else { 421 + // NHWC 422 + inputLayoutNHWC = true 423 + inputHeight = inShape[1] 424 + inputWidth = inShape[2] 425 + inputChannels = inShape[3] 426 + } 427 + } else if (inShape.size == 3) { 428 + inputLayoutNHWC = true 429 + inputHeight = inShape[1] 430 + inputWidth = inShape[2] 431 + inputChannels = 1 432 + } else { 433 + // fallback 434 + inputLayoutNHWC = true 435 + inputHeight = if (inShape.size > 1) inShape[1] else 1 436 + inputWidth = if (inShape.size > 2) inShape[2] else 1 437 + inputChannels = if (inShape.size > 3) inShape[3] else 3 438 + } 439 + 440 + val (inScale, inZero) = if (!inputIsFloat) { 441 + val qp = inT.quantizationParams() 442 + Pair(qp.scale, qp.zeroPoint) 443 + } else Pair(null, null) 444 + 445 + // Determine output layout and quant params 446 + val outT = interpreter.getOutputTensor(0) 447 + val outShape = outT.shape() 448 + // If the output is 2D e.g. [1,25], treat it as batch x features (no boxes axis) 449 + val featDimCandidateIndex: Int 450 + val numBoxesIndexNullable: Int? 451 + val featDim: Int 452 + val numBoxes: Int 453 + if (outShape.size == 2) { 454 + // [batch, features] 455 + featDimCandidateIndex = 1 456 + numBoxesIndexNullable = null 457 + featDim = outShape[1] 458 + numBoxes = 1 459 + } else { 460 + // Heuristic: small dim like 8 is feature dim 461 + featDimCandidateIndex = when { 462 + outShape.size >= 3 && (outShape[1] <= 32) -> 1 463 + outShape.size >= 3 && (outShape.last() <= 32) -> outShape.size - 1 464 + else -> 1 465 + } 466 + val numBoxesIndex = if (featDimCandidateIndex == 1) outShape.size - 1 else 1 467 + numBoxesIndexNullable = numBoxesIndex 468 + featDim = outShape[featDimCandidateIndex] 469 + numBoxes = outShape[numBoxesIndex] 470 + } 471 + 472 + val outIsFloat = outT.dataType() == DataType.FLOAT32 473 + val outQ = outT.quantizationParams() 474 + val outScale = if (!outIsFloat) outQ.scale else null 475 + val outZero = if (!outIsFloat) outQ.zeroPoint else null 476 + 477 + // default normalization: many custom models use [-1,1] 478 + val normalizeToMinusOneOne = true 479 + 480 + return ModelInfo( 481 + inputWidth = inputWidth, 482 + inputHeight = inputHeight, 483 + inputChannels = inputChannels, 484 + inputIsFloat = inputIsFloat, 485 + inputScale = inScale, 486 + inputZeroPoint = inZero, 487 + inputLayoutNHWC = inputLayoutNHWC, 488 + normalizeToMinusOneOne = normalizeToMinusOneOne, 489 + outputFeatDimIndex = featDimCandidateIndex, 490 + outputNumBoxesIndex = numBoxesIndexNullable, 491 + outputFeatDim = featDim, 492 + outputNumBoxes = numBoxes, 493 + outputIsFloat = outIsFloat, 494 + outputScale = outScale, 495 + outputZeroPoint = outZero 496 + ) 497 + } 498 + } 499 + 500 + data class Detection( 501 + val bbox: android.graphics.RectF, // in original image coordinates 502 + val score: Float, 503 + val classId: Int 504 + )
+1 -3
posedetection/src/androidMain/kotlin/com.performancecoachlab/posedetection/camera/CameraView.android.kt
··· 53 53 import com.performancecoachlab.posedetection.recording.AnalysisObject 54 54 import com.performancecoachlab.posedetection.recording.AnalysisResult 55 55 import kotlinx.coroutines.launch 56 - import org.tensorflow.lite.support.image.TensorImage 57 - import org.tensorflow.lite.task.vision.detector.ObjectDetector 58 56 59 57 // Data class to hold recording state for each recording ID 60 58 data class RecordingSlot( ··· 264 262 modifier = Modifier 265 263 .fillMaxSize() 266 264 .scale(if (frontCamera) -1f else 1f, 1f), 267 - contentScale = ContentScale.Crop 265 + contentScale = ContentScale.Fit 268 266 ) 269 267 } 270 268 }