王雷 / detection
Commit 6609763d, authored May 18, 2024 by wanglei
Unify yolov5 and v8 (统一yolov5和v8)
1 parent 81d77638
Showing 12 changed files with 557 additions and 538 deletions
app/src/main/java/com/agenew/detection/CameraConnectionFragment.java
app/src/main/java/com/agenew/detection/DetectionManager.java
app/src/main/java/com/agenew/detection/MainActivity.java
app/src/main/java/com/agenew/detection/customview/OverlayView.java
app/src/main/java/com/agenew/detection/env/Utils.java
app/src/main/java/com/agenew/detection/tflite/DetectorFactory.java
app/src/main/java/com/agenew/detection/tflite/YoloClassifier.java
app/src/main/java/com/agenew/detection/tflite/YoloV5Classifier.java
app/src/main/java/com/agenew/detection/tflite/YoloV8Classifier.java
app/src/main/res/layout/tfe_od_camera_connection_fragment_tracking.xml
gradle/libs.versions.toml
gradle/wrapper/gradle-wrapper.properties
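Taken together, the change replaces the DetectorFactory/YoloV5Classifier pair with a single abstract YoloClassifier that owns the shared TFLite plumbing and dispatches to a YOLOv5 or YOLOv8 subclass based on the model filename. A minimal sketch of the new call path, assuming a hypothetical model filename (the dispatch rule and method names are taken from the diffs below):

```java
import android.content.Context;
import com.agenew.detection.tflite.YoloClassifier;
import java.io.IOException;

class DetectorSetupSketch {
    // "yolov8n-int8.tflite" is a made-up example filename; per YoloClassifier.create
    // below, any name containing "yolov8" dispatches to YoloV8Classifier,
    // everything else to YoloV5Classifier.
    static YoloClassifier createDetector(Context context) throws IOException {
        YoloClassifier detector = YoloClassifier.create(context.getAssets(), "yolov8n-int8.tflite");
        // The input size is now read from the model's input tensor,
        // rather than hard-coded to 640 as in the removed DetectorFactory.
        int inputSize = detector.getInputSize();
        return detector;
    }
}
```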
app/src/main/java/com/agenew/detection/CameraConnectionFragment.java
@@ -531,7 +531,7 @@ public class CameraConnectionFragment extends Fragment {
         @Override
         public int compare(final Size lhs, final Size rhs) {
             // We cast here to ensure the multiplications won't overflow
-            return Long.signum((long) lhs.getWidth() * lhs.getHeight() - (long) rhs.getWidth() * rhs.getHeight());
+            return Long.signum((long) rhs.getWidth() * rhs.getHeight() - (long) lhs.getWidth() * lhs.getHeight());
         }
     }
...
app/src/main/java/com/agenew/detection/DetectionManager.java
new file (mode 100755)
package com.agenew.detection;

import android.app.Activity;
import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.Canvas;
import android.graphics.Matrix;
import android.graphics.RectF;
import android.graphics.Typeface;
import android.os.SystemClock;
import android.util.Log;
import android.util.TypedValue;
import android.view.ViewGroup;
import android.widget.Toast;

import com.agenew.detection.customview.OverlayView;
import com.agenew.detection.env.BorderedText;
import com.agenew.detection.env.ImageUtils;
import com.agenew.detection.env.Utils;
import com.agenew.detection.tflite.Classifier;
import com.agenew.detection.tflite.YoloClassifier;
import com.agenew.detection.tracking.MultiBoxTracker;

import java.io.IOException;
import java.util.LinkedList;
import java.util.List;

public class DetectionManager {
    private static final String TAG = "DetectionManager";
    private static final float TEXT_SIZE_DIP = 10;
    private static final boolean MAINTAIN_ASPECT = true;
    private static final boolean SAVE_PREVIEW_BITMAP = false;

    private final Context mContext;
    private final OverlayView mOverlayView;
    private final MultiBoxTracker mTracker;
    private YoloClassifier mDetector;
    private Bitmap mCroppedBitmap = null;
    private Matrix mFrameToCropTransform;
    private Matrix mCropToFrameTransform;
    private int mWidth;
    private int mHeight;
    private int mSensorOrientation;

    public DetectionManager(Context context, OverlayView overlayView) {
        mContext = context;
        final float textSizePx = TypedValue.applyDimension(TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP,
                context.getResources().getDisplayMetrics());
        BorderedText borderedText = new BorderedText(textSizePx);
        borderedText.setTypeface(Typeface.MONOSPACE);
        mTracker = new MultiBoxTracker(context);
        mOverlayView = overlayView;
        mOverlayView.addCallback(mTracker::draw);
    }

    public void updateDetector(final String modelString, String device) {
        try {
            mDetector = YoloClassifier.create(mContext.getAssets(), modelString);
            mWidth = 0;
        } catch (final IOException e) {
            Log.e(TAG, "Exception initializing classifier! e = " + e, e);
            Toast toast = Toast.makeText(mContext.getApplicationContext(), "Classifier could not be initialized",
                    Toast.LENGTH_SHORT);
            toast.show();
            if (mContext instanceof Activity) {
                Activity activity = (Activity) mContext;
                activity.finish();
            }
        }
        switch (device) {
        case "CPU":
            mDetector.useCPU();
            break;
        case "GPU":
            mDetector.useGpu();
            break;
        case "NNAPI":
            mDetector.useNNAPI();
            break;
        }
    }

    public void updateTracker(final int width, final int height, final int sensorOrientation) {
        if (mWidth != width || mHeight != height || mSensorOrientation != sensorOrientation) {
            mWidth = width;
            mHeight = height;
            mSensorOrientation = sensorOrientation;
            int cropSize = mDetector.getInputSize();
            mCroppedBitmap = Bitmap.createBitmap(cropSize, cropSize, Bitmap.Config.ARGB_8888);
            mFrameToCropTransform = ImageUtils.getTransformationMatrix(width, height, cropSize, cropSize,
                    sensorOrientation, MAINTAIN_ASPECT);
            mCropToFrameTransform = new Matrix();
            mFrameToCropTransform.invert(mCropToFrameTransform);
            mTracker.setFrameConfiguration(width, height, sensorOrientation);
            int overlayViewHeight = mOverlayView.getHeight();
            int overlayViewWidth = width * overlayViewHeight / height;
            ViewGroup.LayoutParams layoutParams = mOverlayView.getLayoutParams();
            layoutParams.width = overlayViewWidth;
        }
    }

    public void closeDetector() {
        if (mDetector != null) {
            mDetector.close();
            mDetector = null;
        }
    }

    public void setNumThreads(int num_threads) {
        mDetector.setNumThreads(num_threads);
    }

    public void postInvalidate() {
        mOverlayView.postInvalidate();
    }

    public void drawBitmap(Bitmap rgbFrameBitmap) {
        final Canvas canvas = new Canvas(mCroppedBitmap);
        canvas.drawBitmap(rgbFrameBitmap, mFrameToCropTransform, null);
        // For examining the actual TF input.
        if (SAVE_PREVIEW_BITMAP) {
            ImageUtils.saveBitmap(mCroppedBitmap);
        }
    }

    public long recognizeImage(final long currTimestamp) {
        final long startTime = SystemClock.uptimeMillis();
        final List<Classifier.Recognition> results = mDetector.recognizeImage(mCroppedBitmap);
        long lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime;
        float minimumConfidence = Utils.MINIMUM_CONFIDENCE_TF_OD_API;
        final List<Classifier.Recognition> mappedRecognitions = new LinkedList<>();
        for (final Classifier.Recognition result : results) {
            final RectF location = result.getLocation();
            if (location != null && result.getConfidence() >= minimumConfidence) {
                mCropToFrameTransform.mapRect(location);
                result.setLocation(location);
                mappedRecognitions.add(result);
            }
        }
        mTracker.trackResults(mappedRecognitions, currTimestamp);
        mOverlayView.postInvalidate();
        return lastProcessingTimeMs;
    }

    public String getcropInfo() {
        return mCroppedBitmap.getWidth() + "x" + mCroppedBitmap.getHeight();
    }
}
\ No newline at end of file
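For orientation, a minimal sketch of the per-frame call sequence an activity is expected to follow with this class, mirroring the MainActivity changes further down (the bitmap and timestamp arguments are placeholders):

```java
import android.graphics.Bitmap;
import com.agenew.detection.DetectionManager;

class FrameLoopSketch {
    // `manager` must already be configured via updateDetector(...) and updateTracker(...).
    static long processFrame(DetectionManager manager, Bitmap rgbFrameBitmap, long timestamp) {
        manager.drawBitmap(rgbFrameBitmap);       // warp the camera frame into the model-sized crop
        return manager.recognizeImage(timestamp); // run inference + tracking; returns latency in ms
    }
}
```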
app/src/main/java/com/agenew/detection/MainActivity.java
@@ -17,33 +17,11 @@
 package com.agenew.detection;

 import android.graphics.Bitmap;
-import android.graphics.Bitmap.Config;
-import android.graphics.Canvas;
-import android.graphics.Color;
-import android.graphics.Matrix;
-import android.graphics.Paint;
-import android.graphics.Paint.Style;
-import android.graphics.RectF;
-import android.graphics.Typeface;
 import android.media.ImageReader.OnImageAvailableListener;
-import android.os.SystemClock;
-import android.util.Log;
 import android.util.Size;
-import android.util.TypedValue;
-import android.widget.Toast;
-import java.io.IOException;
-import java.util.LinkedList;
-import java.util.List;
 import com.agenew.detection.customview.OverlayView;
-import com.agenew.detection.env.BorderedText;
-import com.agenew.detection.env.ImageUtils;
 import com.agenew.detection.env.Logger;
-import com.agenew.detection.tflite.Classifier;
-import com.agenew.detection.tflite.DetectorFactory;
-import com.agenew.detection.tflite.YoloV5Classifier;
-import com.agenew.detection.tracking.MultiBoxTracker;

 /**
  * An activity that uses a TensorFlowMultiBoxDetector and ObjectTracker to
...
@@ -52,65 +30,27 @@ import com.agenew.detection.tracking.MultiBoxTracker;
 public class MainActivity extends CameraActivity implements OnImageAvailableListener {
     private static final Logger LOGGER = new Logger();
-    private static final DetectorMode MODE = DetectorMode.TF_OD_API;
-    public static final float MINIMUM_CONFIDENCE_TF_OD_API = 0.3f;
-    private static final boolean MAINTAIN_ASPECT = true;
     private static final Size DESIRED_PREVIEW_SIZE = new Size(640, 640);
-    private static final boolean SAVE_PREVIEW_BITMAP = false;
-    private static final float TEXT_SIZE_DIP = 10;

-    OverlayView trackingOverlay;
     private Integer sensorOrientation;
-    private YoloV5Classifier detector;
     private long lastProcessingTimeMs;
-    private Bitmap rgbFrameBitmap = null;
-    private Bitmap croppedBitmap = null;
-    private Bitmap cropCopyBitmap = null;
     private boolean computingDetection = false;
     private long timestamp = 0;
+    private Bitmap rgbFrameBitmap = null;
-    private Matrix frameToCropTransform;
-    private Matrix cropToFrameTransform;
-    private MultiBoxTracker tracker;
-    private BorderedText borderedText;
+    private DetectionManager mDetectionManager;

     @Override
     public void onPreviewSizeChosen(final Size size, final int rotation) {
-        final float textSizePx = TypedValue.applyDimension(TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP,
-                getResources().getDisplayMetrics());
-        borderedText = new BorderedText(textSizePx);
-        borderedText.setTypeface(Typeface.MONOSPACE);
-        tracker = new MultiBoxTracker(this);

         final int modelIndex = modelView.getCheckedItemPosition();
         final String modelString = modelStrings.get(modelIndex);
         final int deviceIndex = deviceView.getCheckedItemPosition();
         String device = deviceStrings.get(deviceIndex);

-        try {
-            detector = DetectorFactory.getDetector(getAssets(), modelString);
-        } catch (final IOException e) {
-            e.printStackTrace();
-            LOGGER.e(e, "Exception initializing classifier!");
-            Toast toast = Toast.makeText(getApplicationContext(), "Classifier could not be initialized",
-                    Toast.LENGTH_SHORT);
-            toast.show();
-            finish();
-        }
-        if (device.equals("CPU")) {
-            detector.useCPU();
-        } else if (device.equals("GPU")) {
-            detector.useGpu();
-        } else if (device.equals("NNAPI")) {
-            detector.useNNAPI();
-        }
-        int cropSize = detector.getInputSize();

         previewWidth = size.getWidth();
         previewHeight = size.getHeight();
...
@@ -119,24 +59,12 @@ public class MainActivity extends CameraActivity implements OnImageAvailableListener {
         LOGGER.i("Camera orientation relative to screen canvas: %d", sensorOrientation);
         LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight);
-        rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888);
-        croppedBitmap = Bitmap.createBitmap(cropSize, cropSize, Config.ARGB_8888);
-        frameToCropTransform = ImageUtils.getTransformationMatrix(previewWidth, previewHeight, cropSize, cropSize,
-                sensorOrientation, MAINTAIN_ASPECT);
-        cropToFrameTransform = new Matrix();
-        frameToCropTransform.invert(cropToFrameTransform);
-        trackingOverlay = findViewById(R.id.tracking_overlay);
-        trackingOverlay.addCallback(canvas -> {
-            tracker.draw(canvas);
-            if (isDebug()) {
-                tracker.drawDebug(canvas);
-            }
-        });
-        tracker.setFrameConfiguration(previewWidth, previewHeight, sensorOrientation);
+        OverlayView trackingOverlay = findViewById(R.id.tracking_overlay);
+        mDetectionManager = new DetectionManager(this, trackingOverlay);
+        rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Bitmap.Config.ARGB_8888);
+        mDetectionManager.updateDetector(modelString, device);
+        mDetectionManager.updateTracker(previewWidth, previewHeight, sensorOrientation);
     }

     protected void updateActiveModel() {
...
@@ -155,51 +83,17 @@ public class MainActivity extends CameraActivity implements OnImageAvailableListener {
             currentNumThreads = numThreads;
             // Disable classifier while updating
-            if (detector != null) {
-                detector.close();
-                detector = null;
-            }
+            mDetectionManager.closeDetector();
             // Lookup names of parameters.
             String modelString = modelStrings.get(modelIndex);
             String device = deviceStrings.get(deviceIndex);
             LOGGER.i("Changing model to " + modelString + " device " + device);
             // Try to load model.
-            try {
-                detector = DetectorFactory.getDetector(getAssets(), modelString);
-                // Customize the interpreter to the type of device we want to use.
-                if (detector == null) {
-                    return;
-                }
-            } catch (IOException e) {
-                e.printStackTrace();
-                LOGGER.e(e, "Exception in updateActiveModel()");
-                Toast toast = Toast.makeText(getApplicationContext(), "Classifier could not be initialized",
-                        Toast.LENGTH_SHORT);
-                toast.show();
-                finish();
-            }
-            if (device.equals("CPU")) {
-                detector.useCPU();
-            } else if (device.equals("GPU")) {
-                detector.useGpu();
-            } else if (device.equals("NNAPI")) {
-                detector.useNNAPI();
-            }
-            detector.setNumThreads(numThreads);
-            int cropSize = detector.getInputSize();
-            croppedBitmap = Bitmap.createBitmap(cropSize, cropSize, Config.ARGB_8888);
-            frameToCropTransform = ImageUtils.getTransformationMatrix(previewWidth, previewHeight, cropSize, cropSize,
-                    sensorOrientation, MAINTAIN_ASPECT);
-            cropToFrameTransform = new Matrix();
-            frameToCropTransform.invert(cropToFrameTransform);
+            rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Bitmap.Config.ARGB_8888);
+            mDetectionManager.updateDetector(modelString, device);
+            mDetectionManager.updateTracker(previewWidth, previewHeight, sensorOrientation);
+            mDetectionManager.setNumThreads(numThreads);
         });
     }
...
@@ -207,7 +101,7 @@ public class MainActivity extends CameraActivity implements OnImageAvailableListener {
     protected void processImage() {
         ++timestamp;
         final long currTimestamp = timestamp;
-        trackingOverlay.postInvalidate();
+        mDetectionManager.postInvalidate();

         // No mutex needed as this method is not reentrant.
         if (computingDetection) {
...
@@ -215,63 +109,18 @@ public class MainActivity extends CameraActivity implements OnImageAvailableListener {
             return;
         }
         computingDetection = true;
-        LOGGER.i("Preparing image " + currTimestamp + " for detection in bg thread.");
-        rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight);
+        LOGGER.i("processImage Preparing image " + currTimestamp + " for detection in bg thread.");
+        setPixels(getRgbBytes(), previewWidth, previewHeight);

         readyForNextImage();

-        final Canvas canvas = new Canvas(croppedBitmap);
-        canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null);
-        // For examining the actual TF input.
-        if (SAVE_PREVIEW_BITMAP) {
-            ImageUtils.saveBitmap(croppedBitmap);
-        }
+        mDetectionManager.drawBitmap(rgbFrameBitmap);

         runInBackground(() -> {
-            LOGGER.i("Running detection on image " + currTimestamp);
-            final long startTime = SystemClock.uptimeMillis();
-            final List<Classifier.Recognition> results = detector.recognizeImage(croppedBitmap);
-            lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime;
-            Log.e("CHECK", "run: " + results.size());
-            cropCopyBitmap = Bitmap.createBitmap(croppedBitmap);
-            final Canvas canvas1 = new Canvas(cropCopyBitmap);
-            final Paint paint = new Paint();
-            paint.setColor(Color.RED);
-            paint.setStyle(Style.STROKE);
-            paint.setStrokeWidth(2.0f);
-            float minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API;
-            switch (MODE) {
-            case TF_OD_API:
-                minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API;
-                break;
-            }
-            final List<Classifier.Recognition> mappedRecognitions = new LinkedList<>();
-            for (final Classifier.Recognition result : results) {
-                final RectF location = result.getLocation();
-                if (location != null && result.getConfidence() >= minimumConfidence) {
-                    canvas1.drawRect(location, paint);
-                    cropToFrameTransform.mapRect(location);
-                    result.setLocation(location);
-                    mappedRecognitions.add(result);
-                }
-            }
-            tracker.trackResults(mappedRecognitions, currTimestamp);
-            trackingOverlay.postInvalidate();
+            LOGGER.i("processImage Running detection on image " + currTimestamp);
+            lastProcessingTimeMs = mDetectionManager.recognizeImage(currTimestamp);
             computingDetection = false;
             runOnUiThread(() -> {
                 showFrameInfo(previewWidth + "x" + previewHeight);
-                showCropInfo(cropCopyBitmap.getWidth() + "x" + cropCopyBitmap.getHeight());
+                showCropInfo(mDetectionManager.getcropInfo());
                 showInference(lastProcessingTimeMs + "ms");
             });
         });
...
@@ -290,12 +139,13 @@ public class MainActivity extends CameraActivity implements OnImageAvailableListener {
     // Which detection model to use: by default uses Tensorflow Object Detection API
     // frozen
     // checkpoints.
     private enum DetectorMode {
         TF_OD_API;
     }

     @Override
     protected void setNumThreads(final int numThreads) {
-        runInBackground(() -> detector.setNumThreads(numThreads));
+        runInBackground(() -> mDetectionManager.setNumThreads(numThreads));
     }

+    private void setPixels(int[] pixels, int width, int height) {
+        rgbFrameBitmap.setPixels(pixels, 0, width, 0, 0, width, height);
+    }
 }
app/src/main/java/com/agenew/detection/customview/OverlayView.java
@@ -26,6 +26,10 @@ import java.util.List;
 public class OverlayView extends View {
     private final List<DrawCallback> callbacks = new LinkedList<>();

+    public OverlayView(final Context context) {
+        super(context);
+    }

     public OverlayView(final Context context, final AttributeSet attrs) {
         super(context, attrs);
     }
...
@@ -36,6 +40,7 @@ public class OverlayView extends View {
     @Override
     public synchronized void draw(final Canvas canvas) {
+        super.draw(canvas);
         for (final DrawCallback callback : callbacks) {
             callback.drawCallback(canvas);
         }
...
app/src/main/java/com/agenew/detection/env/Utils.java
@@ -11,6 +11,7 @@ import java.nio.channels.FileChannel;
 public class Utils {
     private static final String TAG = "Utils";
+    public static final float MINIMUM_CONFIDENCE_TF_OD_API = 0.3f;

     /**
      * Memory-map the model file in Assets.
...
app/src/main/java/com/agenew/detection/tflite/DetectorFactory.java
deleted (mode 100755 → 0)

package com.agenew.detection.tflite;

import android.content.res.AssetManager;
import java.io.IOException;

public class DetectorFactory {
    public static YoloV5Classifier getDetector(final AssetManager assetManager, final String modelFilename)
            throws IOException {
        String labelFilename = null;
        boolean isQuantized = false;
        int inputSize = 0;
        if (modelFilename.endsWith(".tflite")) {
            labelFilename = "file:///android_asset/class.txt";
            isQuantized = modelFilename.endsWith("-int8.tflite");
            inputSize = 640;
        }
        return YoloV5Classifier.create(assetManager, modelFilename, labelFilename, isQuantized, inputSize);
    }
}
app/src/main/java/com/agenew/detection/tflite/YoloClassifier.java
new file (mode 100755)

package com.agenew.detection.tflite;

import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.RectF;
import android.util.Log;

import com.agenew.detection.env.Utils;
import com.mediatek.neuropilot_S.Interpreter;
import com.mediatek.neuropilot_S.Tensor;
import com.mediatek.neuropilot_S.nnapi.NnApiDelegate;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Vector;

public abstract class YoloClassifier implements Classifier {
    private static final String TAG = "YoloClassifier";
    private static final String LABEL_FILENAME = "class_dump.txt";
    private static final int NUM_THREADS = 1;

    private final int mInputSize;
    protected int mOutputBox;
    protected final boolean mIsModelQuantized;
    /** holds a gpu delegate */
    // GpuDelegate gpuDelegate = null;
    /** holds an nnapi delegate */
    private NnApiDelegate mNnapiDelegate = null;
    private MappedByteBuffer mTfliteModel;
    private final Interpreter.Options mTfliteOptions = new Interpreter.Options();
    protected final Vector<String> mLabels = new Vector<>();
    private final int[] mIntValues;
    private final ByteBuffer mImgData;
    protected final ByteBuffer mOutData;
    private Interpreter mTfLite;
    private float mInpScale;
    private int mInpZeroPoint;
    protected float mOupScale;
    protected int mOupZeroPoint;
    protected int[] mOutputShape;

    public static YoloClassifier create(final AssetManager assetManager, final String modelFilename)
            throws IOException {
        boolean isYoloV8 = modelFilename.contains("yolov8");
        return isYoloV8 ? new YoloV8Classifier(assetManager, modelFilename)
                : new YoloV5Classifier(assetManager, modelFilename);
    }

    public YoloClassifier(final AssetManager assetManager, final String modelFilename) throws IOException {
        InputStream labelsInput = assetManager.open(LABEL_FILENAME);
        BufferedReader br = new BufferedReader(new InputStreamReader(labelsInput));
        String line;
        while ((line = br.readLine()) != null) {
            Log.w(TAG, line);
            mLabels.add(line);
        }
        br.close();
        try {
            Interpreter.Options options = (new Interpreter.Options());
            options.setNumThreads(NUM_THREADS);
            mTfliteModel = Utils.loadModelFile(assetManager, modelFilename);
            mTfLite = new Interpreter(mTfliteModel, options);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        final boolean isQuantized = modelFilename.endsWith("-int8.tflite");
        mIsModelQuantized = isQuantized;
        // Pre-allocate buffers.
        int numBytesPerChannel = isQuantized ? 1 : 4;
        Tensor inpten = mTfLite.getInputTensor(0);
        mInputSize = inpten.shape()[1];
        mImgData = ByteBuffer.allocateDirect(mInputSize * mInputSize * 3 * numBytesPerChannel);
        mImgData.order(ByteOrder.nativeOrder());
        mIntValues = new int[mInputSize * mInputSize];
        Tensor oupten = mTfLite.getOutputTensor(0);
        mOutputShape = oupten.shape();
        if (mIsModelQuantized) {
            mInpScale = inpten.quantizationParams().getScale();
            mInpZeroPoint = inpten.quantizationParams().getZeroPoint();
            mOupScale = oupten.quantizationParams().getScale();
            mOupZeroPoint = oupten.quantizationParams().getZeroPoint();
            Log.i(TAG, "WL_DEBUG YoloClassifier mInpScale = " + mInpScale + ", mInpZeroPoint = " + mInpZeroPoint
                    + ", mOupScale = " + mOupScale + ", mOupZeroPoint = " + mOupZeroPoint + ", mInputSize = "
                    + mInputSize);
        }
        int outDataCapacity = mOutputShape[1] * mOutputShape[2] * numBytesPerChannel;
        Log.i(TAG, "WL_DEBUG create mOutputShape = " + Arrays.toString(mOutputShape) + ", outDataCapacity = "
                + outDataCapacity);
        mOutData = ByteBuffer.allocateDirect(outDataCapacity);
        mOutData.order(ByteOrder.nativeOrder());
    }

    public int getInputSize() {
        return mInputSize;
    }

    @Override
    public void close() {
        mTfLite.close();
        mTfLite = null;
        /*
         * if (gpuDelegate != null) { gpuDelegate.close(); gpuDelegate = null; }
         */
        if (mNnapiDelegate != null) {
            mNnapiDelegate.close();
            mNnapiDelegate = null;
        }
        mTfliteModel = null;
    }

    public void setNumThreads(int num_threads) {
        if (mTfLite != null)
            mTfLite.setNumThreads(num_threads);
    }

    private void recreateInterpreter() {
        if (mTfLite != null) {
            mTfLite.close();
            mTfLite = new Interpreter(mTfliteModel, mTfliteOptions);
        }
    }

    public void useGpu() {
        /*
         * if (gpuDelegate == null) { gpuDelegate = new GpuDelegate();
         * tfliteOptions.addDelegate(gpuDelegate); recreateInterpreter(); }
         */
    }

    public void useCPU() {
        recreateInterpreter();
    }

    public void useNNAPI() {
        mNnapiDelegate = new NnApiDelegate();
        mTfliteOptions.addDelegate(mNnapiDelegate);
        recreateInterpreter();
    }

    @Override
    public float getObjThresh() {
        return Utils.MINIMUM_CONFIDENCE_TF_OD_API;
    }

    // non maximum suppression
    private ArrayList<Recognition> nms(ArrayList<Recognition> list) {
        ArrayList<Recognition> nmsList = new ArrayList<>();
        for (int k = 0; k < mLabels.size(); k++) {
            // 1. find max confidence per class
            PriorityQueue<Recognition> pq = new PriorityQueue<>(50, (lhs, rhs) -> {
                // Intentionally reversed to put high confidence at the head of the queue.
                return Float.compare(rhs.getConfidence(), lhs.getConfidence());
            });
            for (int i = 0; i < list.size(); ++i) {
                if (list.get(i).getDetectedClass() == k) {
                    pq.add(list.get(i));
                }
            }
            // 2. do non maximum suppression
            while (!pq.isEmpty()) {
                // insert detection with max confidence
                Recognition[] a = new Recognition[pq.size()];
                Recognition[] detections = pq.toArray(a);
                Recognition max = detections[0];
                nmsList.add(max);
                pq.clear();
                for (int j = 1; j < detections.length; j++) {
                    Recognition detection = detections[j];
                    RectF b = detection.getLocation();
                    float mNmsThresh = 0.6f;
                    if (box_iou(max.getLocation(), b) < mNmsThresh) {
                        pq.add(detection);
                    }
                }
            }
        }
        return nmsList;
    }

    private float box_iou(RectF a, RectF b) {
        return box_intersection(a, b) / box_union(a, b);
    }

    private float box_intersection(RectF a, RectF b) {
        float w = overlap((a.left + a.right) / 2, a.right - a.left, (b.left + b.right) / 2, b.right - b.left);
        float h = overlap((a.top + a.bottom) / 2, a.bottom - a.top, (b.top + b.bottom) / 2, b.bottom - b.top);
        if (w < 0 || h < 0)
            return 0;
        return w * h;
    }

    private float box_union(RectF a, RectF b) {
        float i = box_intersection(a, b);
        return (a.right - a.left) * (a.bottom - a.top) + (b.right - b.left) * (b.bottom - b.top) - i;
    }

    private float overlap(float x1, float w1, float x2, float w2) {
        float l1 = x1 - w1 / 2;
        float l2 = x2 - w2 / 2;
        float left = Math.max(l1, l2);
        float r1 = x1 + w1 / 2;
        float r2 = x2 + w2 / 2;
        float right = Math.min(r1, r2);
        return right - left;
    }

    /**
     * Writes Image data into a {@code ByteBuffer}.
     */
    private void convertBitmapToByteBuffer(Bitmap bitmap) {
        bitmap.getPixels(mIntValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
        mImgData.rewind();
        for (int i = 0; i < mInputSize; ++i) {
            for (int j = 0; j < mInputSize; ++j) {
                int pixelValue = mIntValues[i * mInputSize + j];
                // Float model
                float IMAGE_MEAN = 0;
                float IMAGE_STD = 255.0f;
                if (mIsModelQuantized) {
                    // Quantized model
                    mImgData.put((byte) ((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD / mInpScale + mInpZeroPoint));
                    mImgData.put((byte) ((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD / mInpScale + mInpZeroPoint));
                    mImgData.put((byte) (((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD / mInpScale + mInpZeroPoint));
                } else {
                    // Float model
                    mImgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
                    mImgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
                    mImgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
                }
            }
        }
    }

    public ArrayList<Recognition> recognizeImage(Bitmap bitmap) {
        convertBitmapToByteBuffer(bitmap);
        Map<Integer, Object> outputMap = new HashMap<>();
        mOutData.rewind();
        outputMap.put(0, mOutData);
        Log.d(TAG, "mObjThresh: " + getObjThresh() + ", mOutData.capacity() = " + mOutData.capacity());
        Object[] inputArray = { mImgData };
        try {
            mTfLite.runForMultipleInputsOutputs(inputArray, outputMap);
        } catch (Exception e) {
            Log.e(TAG, "WL_DEBUG recognizeImage e = " + e, e);
            return new ArrayList<>();
        }
        mOutData.rewind();
        ArrayList<Recognition> detections = new ArrayList<>();
        int outputShape1 = mOutputShape[1];
        int outputShape2 = mOutputShape[2];
        float[][][] out = new float[1][outputShape1][outputShape2];
        Log.d(TAG, "out[0] detect start");
        buildDetections(outputShape1, outputShape2, out, bitmap, detections);
        Log.d(TAG, "detect end");
        return nms(detections);
    }

    protected abstract void buildDetections(int outputShape1, int outputShape2, float[][][] out, Bitmap bitmap,
            ArrayList<Recognition> detections);
}
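As a sanity check on the box_iou/overlap helpers above, a small self-contained worked example (the box coordinates are illustrative, not from the diff):

```java
class IouSketch {
    // Two 10x10 boxes, b offset from a by (5, 5): a = (0,0)-(10,10), b = (5,5)-(15,15).
    public static void main(String[] args) {
        float w = Math.min(10f, 15f) - Math.max(0f, 5f); // overlap width  = 5
        float h = Math.min(10f, 15f) - Math.max(0f, 5f); // overlap height = 5
        float inter = w * h;                             // 25
        float union = 10 * 10 + 10 * 10 - inter;         // 175
        System.out.println(inter / union);               // ~0.143, below the 0.6 threshold,
                                                         // so nms() would keep both boxes
    }
}
```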
app/src/main/java/com/agenew/detection/tflite/YoloV5Classifier.java
@@ -20,350 +20,42 @@ import android.graphics.Bitmap;
 import android.graphics.RectF;
 import android.util.Log;
-//import org.tensorflow.lite.Interpreter;
-import com.mediatek.neuropilot_S.Interpreter;
-//import org.tensorflow.lite.Tensor;
-import com.mediatek.neuropilot_S.Tensor;
-import com.agenew.detection.MainActivity;
-import com.agenew.detection.env.Logger;
-import com.agenew.detection.env.Utils;
-//import org.tensorflow.lite.gpu.GpuDelegate;
-//import org.tensorflow.lite.nnapi.NnApiDelegate;
-import com.mediatek.neuropilot_S.nnapi.NnApiDelegate;
-import java.io.BufferedReader;
 import java.io.IOException;
-import java.io.InputStream;
-import java.io.InputStreamReader;
-import java.nio.ByteBuffer;
-import java.nio.ByteOrder;
-import java.nio.MappedByteBuffer;
 import java.util.ArrayList;
-import java.util.Comparator;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.PriorityQueue;
-import java.util.Vector;

 /**
  * Wrapper for frozen detection models trained using the Tensorflow Object
  * Detection API: -
  * https://github.com/tensorflow/models/tree/master/research/object_detection
  * where you can find the training code.
  * <p>
  * To use pretrained models in the API or convert to TF Lite models, please see
  * docs for details: -
  * https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md
  * -
  * https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/running_on_mobile_tensorflowlite.md#running-our-model-on-android
  */
-public class YoloV5Classifier implements Classifier {
+public class YoloV5Classifier extends YoloClassifier {
+    private static final String TAG = "YoloV5Classifier";

-    /**
-     * Initializes a native TensorFlow session for classifying images.
-     *
-     * @param assetManager  The asset manager to be used to load assets.
-     * @param modelFilename The filepath of the model GraphDef protocol buffer.
-     * @param labelFilename The filepath of label file for classes.
-     * @param isQuantized   Boolean representing model is quantized or not
-     */
-    public static YoloV5Classifier create(final AssetManager assetManager, final String modelFilename,
-            final String labelFilename, final boolean isQuantized, final int inputSize) throws IOException {
-        final YoloV5Classifier d = new YoloV5Classifier();
-        String actualFilename = labelFilename.split("file:///android_asset/")[1];
-        InputStream labelsInput = assetManager.open(actualFilename);
-        BufferedReader br = new BufferedReader(new InputStreamReader(labelsInput));
-        String line;
-        while ((line = br.readLine()) != null) {
-            LOGGER.w(line);
-            d.labels.add(line);
-        }
-        br.close();
-        try {
-            Interpreter.Options options = (new Interpreter.Options());
-            options.setNumThreads(NUM_THREADS);
-            d.tfliteModel = Utils.loadModelFile(assetManager, modelFilename);
-            d.tfLite = new Interpreter(d.tfliteModel, options);
-        } catch (Exception e) {
-            throw new RuntimeException(e);
-        }
-        d.isModelQuantized = isQuantized;
-        // Pre-allocate buffers.
-        int numBytesPerChannel;
-        if (isQuantized) {
-            numBytesPerChannel = 1; // Quantized
-        } else {
-            numBytesPerChannel = 4; // Floating point
-        }
-        d.INPUT_SIZE = inputSize;
-        d.imgData = ByteBuffer.allocateDirect(1 * d.INPUT_SIZE * d.INPUT_SIZE * 3 * numBytesPerChannel);
-        d.imgData.order(ByteOrder.nativeOrder());
-        d.intValues = new int[d.INPUT_SIZE * d.INPUT_SIZE];
-        d.output_box = (int) ((Math.pow((inputSize / 32), 2) + Math.pow((inputSize / 16), 2)
-                + Math.pow((inputSize / 8), 2)) * 3);
-        if (d.isModelQuantized) {
-            Tensor inpten = d.tfLite.getInputTensor(0);
-            d.inp_scale = inpten.quantizationParams().getScale();
-            d.inp_zero_point = inpten.quantizationParams().getZeroPoint();
-            Tensor oupten = d.tfLite.getOutputTensor(0);
-            d.output_box = oupten.shape()[1];
-            d.oup_scale = oupten.quantizationParams().getScale();
-            d.oup_zero_point = oupten.quantizationParams().getZeroPoint();
-        }
-        int[] shape = d.tfLite.getOutputTensor(0).shape();
-        int numClass = shape[shape.length - 1] - 5;
-        d.numClass = numClass;
-        d.outData = ByteBuffer.allocateDirect(d.output_box * (numClass + 5) * numBytesPerChannel);
-        d.outData.order(ByteOrder.nativeOrder());
-        return d;
-    }

-    public int getInputSize() {
-        return INPUT_SIZE;
-    }

-    @Override
-    public void close() {
-        tfLite.close();
-        tfLite = null;
-        /*
-         * if (gpuDelegate != null) { gpuDelegate.close(); gpuDelegate = null; }
-         */
-        if (nnapiDelegate != null) {
-            nnapiDelegate.close();
-            nnapiDelegate = null;
-        }
-        tfliteModel = null;
-    }

-    public void setNumThreads(int num_threads) {
-        if (tfLite != null)
-            tfLite.setNumThreads(num_threads);
-    }

-    private void recreateInterpreter() {
-        if (tfLite != null) {
-            tfLite.close();
-            tfLite = new Interpreter(tfliteModel, tfliteOptions);
-        }
-    }

-    public void useGpu() {
-        /*
-         * if (gpuDelegate == null) { gpuDelegate = new GpuDelegate();
-         * tfliteOptions.addDelegate(gpuDelegate); recreateInterpreter(); }
-         */
-    }

-    public void useCPU() {
-        recreateInterpreter();
-    }

-    public void useNNAPI() {
-        nnapiDelegate = new NnApiDelegate();
-        tfliteOptions.addDelegate(nnapiDelegate);
-        recreateInterpreter();
-    }
+    public YoloV5Classifier(final AssetManager assetManager, final String modelFilename) throws IOException {
+        super(assetManager, modelFilename);
+        mOutputBox = mOutputShape[1];
+    }

-    @Override
-    public float getObjThresh() {
-        return MainActivity.MINIMUM_CONFIDENCE_TF_OD_API;
-    }

-    private static final Logger LOGGER = new Logger();
-    // Float model
-    private final float IMAGE_MEAN = 0;
-    private final float IMAGE_STD = 255.0f;
-    // config yolo
-    private int INPUT_SIZE = -1;
-    private int output_box;
-    // Number of threads in the java app
-    private static final int NUM_THREADS = 1;
-    private boolean isModelQuantized;
-    /** holds a gpu delegate */
-    // GpuDelegate gpuDelegate = null;
-    /** holds an nnapi delegate */
-    NnApiDelegate nnapiDelegate = null;
-    /** The loaded TensorFlow Lite model. */
-    private MappedByteBuffer tfliteModel;
-    /** Options for configuring the Interpreter. */
-    private final Interpreter.Options tfliteOptions = new Interpreter.Options();
-    // Config values.
-    // Pre-allocated buffers.
-    private Vector<String> labels = new Vector<String>();
-    private int[] intValues;
-    private ByteBuffer imgData;
-    private ByteBuffer outData;
-    private Interpreter tfLite;
-    private float inp_scale;
-    private int inp_zero_point;
-    private float oup_scale;
-    private int oup_zero_point;
-    private int numClass;

-    private YoloV5Classifier() {
-    }

-    // non maximum suppression
-    protected ArrayList<Recognition> nms(ArrayList<Recognition> list) {
-        ArrayList<Recognition> nmsList = new ArrayList<Recognition>();
-        for (int k = 0; k < labels.size(); k++) {
-            // 1.find max confidence per class
-            PriorityQueue<Recognition> pq = new PriorityQueue<Recognition>(50, new Comparator<Recognition>() {
-                @Override
-                public int compare(final Recognition lhs, final Recognition rhs) {
-                    // Intentionally reversed to put high confidence at the head of the queue.
-                    return Float.compare(rhs.getConfidence(), lhs.getConfidence());
-                }
-            });
-            for (int i = 0; i < list.size(); ++i) {
-                if (list.get(i).getDetectedClass() == k) {
-                    pq.add(list.get(i));
-                }
-            }
-            // 2.do non maximum suppression
-            while (pq.size() > 0) {
-                // insert detection with max confidence
-                Recognition[] a = new Recognition[pq.size()];
-                Recognition[] detections = pq.toArray(a);
-                Recognition max = detections[0];
-                nmsList.add(max);
-                pq.clear();
-                for (int j = 1; j < detections.length; j++) {
-                    Recognition detection = detections[j];
-                    RectF b = detection.getLocation();
-                    if (box_iou(max.getLocation(), b) < mNmsThresh) {
-                        pq.add(detection);
-                    }
-                }
-            }
-        }
-        return nmsList;
-    }

-    protected float mNmsThresh = 0.6f;

-    protected float box_iou(RectF a, RectF b) {
-        return box_intersection(a, b) / box_union(a, b);
-    }

-    protected float box_intersection(RectF a, RectF b) {
-        float w = overlap((a.left + a.right) / 2, a.right - a.left, (b.left + b.right) / 2, b.right - b.left);
-        float h = overlap((a.top + a.bottom) / 2, a.bottom - a.top, (b.top + b.bottom) / 2, b.bottom - b.top);
-        if (w < 0 || h < 0)
-            return 0;
-        float area = w * h;
-        return area;
-    }

-    protected float box_union(RectF a, RectF b) {
-        float i = box_intersection(a, b);
-        float u = (a.right - a.left) * (a.bottom - a.top) + (b.right - b.left) * (b.bottom - b.top) - i;
-        return u;
-    }

-    protected float overlap(float x1, float w1, float x2, float w2) {
-        float l1 = x1 - w1 / 2;
-        float l2 = x2 - w2 / 2;
-        float left = l1 > l2 ? l1 : l2;
-        float r1 = x1 + w1 / 2;
-        float r2 = x2 + w2 / 2;
-        float right = r1 < r2 ? r1 : r2;
-        return right - left;
-    }

-    /**
-     * Writes Image data into a {@code ByteBuffer}.
-     */
-    protected ByteBuffer convertBitmapToByteBuffer(Bitmap bitmap) {
-        bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
-        imgData.rewind();
-        for (int i = 0; i < INPUT_SIZE; ++i) {
-            for (int j = 0; j < INPUT_SIZE; ++j) {
-                int pixelValue = intValues[i * INPUT_SIZE + j];
-                if (isModelQuantized) {
-                    // Quantized model
-                    imgData.put((byte) ((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD / inp_scale + inp_zero_point));
-                    imgData.put((byte) ((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD / inp_scale + inp_zero_point));
-                    imgData.put((byte) (((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD / inp_scale + inp_zero_point));
-                } else {
-                    // Float model
-                    imgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
-                    imgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
-                    imgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
-                }
-            }
-        }
-        return imgData;
-    }

-    public ArrayList<Recognition> recognizeImage(Bitmap bitmap) {
-        convertBitmapToByteBuffer(bitmap);
-        Map<Integer, Object> outputMap = new HashMap<Integer, Object>();
-        outData.rewind();
-        outputMap.put(0, outData);
-        Log.d("YoloV5Classifier", "mObjThresh: " + getObjThresh());
-        Object[] inputArray = { imgData };
-        tfLite.runForMultipleInputsOutputs(inputArray, outputMap);
-        ByteBuffer byteBuffer = (ByteBuffer) outputMap.get(0);
-        byteBuffer.rewind();
-        ArrayList<Recognition> detections = new ArrayList<Recognition>();
-        float[][][] out = new float[1][output_box][numClass + 5];
-        Log.d("YoloV5Classifier", "out[0] detect start");
-        for (int i = 0; i < output_box; ++i) {
-            for (int j = 0; j < numClass + 5; ++j) {
-                if (isModelQuantized) {
-                    out[0][i][j] = oup_scale * (((int) byteBuffer.get() & 0xFF) - oup_zero_point);
-                } else {
-                    out[0][i][j] = byteBuffer.getFloat();
-                }
-            }
-            // Denormalize xywh
-            for (int j = 0; j < 4; ++j) {
-                out[0][i][j] *= getInputSize();
-            }
-        }
+    protected void buildDetections(int outputShape1, int outputShape2, float[][][] out, Bitmap bitmap,
+            ArrayList<Recognition> detections) {
+        for (int i = 0; i < outputShape1; ++i) {
+            for (int j = 0; j < outputShape2; ++j) {
+                if (mIsModelQuantized) {
+                    out[0][i][j] = mOupScale * ((((int) mOutData.get() & 0xFF) - mOupZeroPoint) & 0xFF);
+                } else {
+                    out[0][i][j] = mOutData.getFloat();
+                }
+                if (j < 4) {
+                    out[0][i][j] *= getInputSize();
+                }
+            }
+        }
-        for (int i = 0; i < output_box; ++i) {
+        for (int i = 0; i < mOutputBox; ++i) {
             final int offset = 0;
             final float confidence = out[0][i][4];
             int detectedClass = -1;
             float maxClass = 0;
-            final float[] classes = new float[labels.size()];
-            for (int c = 0; c < labels.size(); ++c) {
-                classes[c] = out[0][i][5 + c];
-            }
-            for (int c = 0; c < labels.size(); ++c) {
+            final float[] classes = new float[mLabels.size()];
+            for (int c = 0; c < mLabels.size() && c < (out[0][i].length - 5); ++c) {
+                classes[c] = out[0][i][5 + c];
+            }
+            for (int c = 0; c < mLabels.size(); ++c) {
                 if (classes[c] > maxClass) {
                     detectedClass = c;
                     maxClass = classes[c];
...
@@ -377,17 +69,13 @@ public class YoloV5Classifier implements Classifier {
             final float w = out[0][i][2];
             final float h = out[0][i][3];
-            Log.d("YoloV5Classifier", Float.toString(xPos) + ',' + yPos + ',' + w + ',' + h);
+            Log.d(TAG, Float.toString(xPos) + ',' + yPos + ',' + w + ',' + h);
             final RectF rect = new RectF(Math.max(0, xPos - w / 2), Math.max(0, yPos - h / 2),
                     Math.min(bitmap.getWidth() - 1, xPos + w / 2), Math.min(bitmap.getHeight() - 1, yPos + h / 2));
-            detections.add(new Recognition("" + offset, labels.get(detectedClass), confidenceInClass, rect,
+            detections.add(new Recognition("" + offset, mLabels.get(detectedClass), confidenceInClass, rect,
                     detectedClass));
         }
     }
-    Log.d(TAG, "detect end");
-    final ArrayList<Recognition> recognitions = nms(detections);
-    return recognitions;
-    }
 }
app/src/main/java/com/agenew/detection/tflite/YoloV8Classifier.java
new file (mode 100755)

package com.agenew.detection.tflite;

import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.RectF;
import android.util.Log;

import java.io.IOException;
import java.util.ArrayList;

public class YoloV8Classifier extends YoloClassifier {
    private static final String TAG = "YoloV8Classifier";

    public YoloV8Classifier(final AssetManager assetManager, final String modelFilename) throws IOException {
        super(assetManager, modelFilename);
        mOutputBox = mOutputShape[2];
    }

    @Override
    protected void buildDetections(int outputShape1, int outputShape2, float[][][] out, Bitmap bitmap,
            ArrayList<Recognition> detections) {
        for (int i = 0; i < outputShape1; ++i) {
            for (int j = 0; j < outputShape2; ++j) {
                if (mIsModelQuantized) {
                    out[0][i][j] = mOupScale * ((((int) mOutData.get() & 0xFF) - mOupZeroPoint) & 0xFF);
                } else {
                    out[0][i][j] = mOutData.getFloat();
                }
                if (i < 4) {
                    out[0][i][j] *= getInputSize();
                }
            }
        }
        for (int i = 0; i < mOutputBox; ++i) {
            final int offset = 0;
            int detectedClass = -1;
            float maxClass = 0;
            final float[] classes = new float[mLabels.size()];
            for (int c = 0; c < mLabels.size() && c < (out[0].length - 4); ++c) {
                classes[c] = out[0][4 + c][i];
            }
            for (int c = 0; c < mLabels.size(); ++c) {
                if (classes[c] > maxClass) {
                    detectedClass = c;
                    maxClass = classes[c];
                }
            }
            final float confidenceInClass = maxClass;
            if (confidenceInClass > getObjThresh()) {
                final float xPos = out[0][0][i];
                final float yPos = out[0][1][i];
                final float w = out[0][2][i];
                final float h = out[0][3][i];
                Log.d(TAG, Float.toString(xPos) + ',' + yPos + ',' + w + ',' + h);
                final RectF rect = new RectF(Math.max(0, xPos - w / 2), Math.max(0, yPos - h / 2),
                        Math.min(bitmap.getWidth() - 1, xPos + w / 2), Math.min(bitmap.getHeight() - 1, yPos + h / 2));
                detections.add(new Recognition("" + offset, mLabels.get(detectedClass), confidenceInClass, rect,
                        detectedClass));
            }
        }
    }
}
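The two subclasses differ only in how they index the output tensor. The shapes below are illustrative assumptions for a 640x640, 80-class model, not values stated in this commit:

```java
// Shapes are assumptions for illustration (640x640 input, 80 classes), not from this commit.
class OutputLayoutSketch {
    // YOLOv5: boxes along axis 1; row i = [x, y, w, h, objectness, class scores...]
    //         hence mOutputBox = mOutputShape[1] and classes[c] = out[0][i][5 + c]
    static final int[] YOLOV5_OUTPUT_SHAPE = { 1, 25200, 85 };

    // YOLOv8: boxes along axis 2; column i = [x, y, w, h, class scores...] (no objectness score)
    //         hence mOutputBox = mOutputShape[2] and classes[c] = out[0][4 + c][i]
    static final int[] YOLOV8_OUTPUT_SHAPE = { 1, 84, 8400 };
}
```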
app/src/main/res/layout/tfe_od_camera_connection_fragment_tracking.xml
@@ -13,18 +13,19 @@
   See the License for the specific language governing permissions and
   limitations under the License.
 -->
-<FrameLayout xmlns:android="http://schemas.android.com/apk/res/android"
+<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
     android:layout_width="match_parent"
     android:layout_height="match_parent">

     <com.agenew.detection.customview.AutoFitTextureView
         android:id="@+id/texture"
         android:layout_width="wrap_content"
-        android:layout_height="wrap_content" />
+        android:layout_height="wrap_content"
+        android:layout_centerInParent="true" />

     <com.agenew.detection.customview.OverlayView
         android:id="@+id/tracking_overlay"
         android:layout_width="match_parent"
-        android:layout_height="match_parent" />
+        android:layout_height="480dp"
+        android:layout_centerInParent="true" />

-</FrameLayout>
+</RelativeLayout>
gradle/libs.versions.toml
View file @
6609763
[versions]
agp
=
"8.
3.2
"
agp
=
"8.
4.0
"
junit
=
"4.13.2"
junitVersion
=
"1.1.5"
espressoCore
=
"3.5.1"
...
...
gradle/wrapper/gradle-wrapper.properties
View file @
6609763
#Wed Apr 10 09:00:02 CST 2024
distributionBase
=
GRADLE_USER_HOME
distributionPath
=
wrapper/dists
distributionUrl
=
https
\:
//services.gradle.org/distributions/gradle-8.
4
-bin.zip
distributionUrl
=
https
\:
//services.gradle.org/distributions/gradle-8.
6
-bin.zip
zipStoreBase
=
GRADLE_USER_HOME
zipStorePath
=
wrapper/dists