Commit 6609763d by wanglei

Unify yolov5 and v8

1 parent 81d77638
......@@ -531,7 +531,7 @@ public class CameraConnectionFragment extends Fragment {
@Override
public int compare(final Size lhs, final Size rhs) {
// We cast here to ensure the multiplications won't overflow
return Long.signum((long) lhs.getWidth() * lhs.getHeight() - (long) rhs.getWidth() * rhs.getHeight());
return Long.signum((long) rhs.getWidth() * rhs.getHeight() - (long) lhs.getWidth() * lhs.getHeight());
}
}
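The hunk above swaps lhs and rhs in the area subtraction, which flips the comparator from ascending to descending order by pixel area. A minimal sketch of the effect, using hypothetical Size values that are not part of this commit:

import android.util.Size;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

class PreviewSizeDemo {
    static void demo() {
        // Same comparison as the new code: cast to long so width * height cannot overflow an int.
        Comparator<Size> byAreaDesc = (lhs, rhs) ->
                Long.signum((long) rhs.getWidth() * rhs.getHeight()
                        - (long) lhs.getWidth() * lhs.getHeight());
        List<Size> sizes = Arrays.asList(new Size(640, 480), new Size(1920, 1080), new Size(1280, 720));
        sizes.sort(byAreaDesc);
        // sizes is now [1920x1080, 1280x720, 640x480]; any call site that still uses
        // Collections.min(...) with this comparator now gets the largest area instead of the smallest.
    }
}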
......
package com.agenew.detection;
import android.app.Activity;
import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.Canvas;
import android.graphics.Matrix;
import android.graphics.RectF;
import android.graphics.Typeface;
import android.os.SystemClock;
import android.util.Log;
import android.util.TypedValue;
import android.view.ViewGroup;
import android.widget.Toast;
import com.agenew.detection.customview.OverlayView;
import com.agenew.detection.env.BorderedText;
import com.agenew.detection.env.ImageUtils;
import com.agenew.detection.env.Utils;
import com.agenew.detection.tflite.Classifier;
import com.agenew.detection.tflite.YoloClassifier;
import com.agenew.detection.tracking.MultiBoxTracker;
import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
public class DetectionManager {
private static final String TAG = "DetectionManager";
private static final float TEXT_SIZE_DIP = 10;
private static final boolean MAINTAIN_ASPECT = true;
private static final boolean SAVE_PREVIEW_BITMAP = false;
private final Context mContext;
private final OverlayView mOverlayView;
private final MultiBoxTracker mTracker;
private YoloClassifier mDetector;
private Bitmap mCroppedBitmap = null;
private Matrix mFrameToCropTransform;
private Matrix mCropToFrameTransform;
private int mWidth;
private int mHeight;
private int mSensorOrientation;
public DetectionManager(Context context, OverlayView overlayView) {
mContext = context;
final float textSizePx = TypedValue.applyDimension(TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP,
context.getResources().getDisplayMetrics());
BorderedText borderedText = new BorderedText(textSizePx);
borderedText.setTypeface(Typeface.MONOSPACE);
mTracker = new MultiBoxTracker(context);
mOverlayView = overlayView;
mOverlayView.addCallback(mTracker::draw);
}
public void updateDetector(final String modelString, String device) {
try {
mDetector = YoloClassifier.create(mContext.getAssets(), modelString);
mWidth = 0;
} catch (final IOException e) {
Log.e(TAG, "Exception initializing classifier! e = " + e, e);
Toast toast = Toast.makeText(mContext.getApplicationContext(), "Classifier could not be initialized",
Toast.LENGTH_SHORT);
toast.show();
if (mContext instanceof Activity) {
Activity activity = (Activity) mContext;
activity.finish();
}
}
switch (device) {
case "CPU":
mDetector.useCPU();
break;
case "GPU":
mDetector.useGpu();
break;
case "NNAPI":
mDetector.useNNAPI();
break;
}
}
public void updateTracker(final int width, final int height, final int sensorOrientation) {
if (mWidth != width || mHeight != height || mSensorOrientation != sensorOrientation) {
mWidth = width;
mHeight = height;
mSensorOrientation = sensorOrientation;
int cropSize = mDetector.getInputSize();
mCroppedBitmap = Bitmap.createBitmap(cropSize, cropSize, Bitmap.Config.ARGB_8888);
mFrameToCropTransform = ImageUtils.getTransformationMatrix(width, height, cropSize, cropSize,
sensorOrientation, MAINTAIN_ASPECT);
mCropToFrameTransform = new Matrix();
mFrameToCropTransform.invert(mCropToFrameTransform);
mTracker.setFrameConfiguration(width, height, sensorOrientation);
int overlayViewHeight = mOverlayView.getHeight();
int overlayViewWidth = width*overlayViewHeight/height;
ViewGroup.LayoutParams layoutParams = mOverlayView.getLayoutParams();
layoutParams.width = overlayViewWidth;
}
}
public void closeDetector() {
if (mDetector != null) {
mDetector.close();
mDetector = null;
}
}
public void setNumThreads(int num_threads) {
mDetector.setNumThreads(num_threads);
}
public void postInvalidate() {
mOverlayView.postInvalidate();
}
public void drawBitmap(Bitmap rgbFrameBitmap) {
final Canvas canvas = new Canvas(mCroppedBitmap);
canvas.drawBitmap(rgbFrameBitmap, mFrameToCropTransform, null);
// For examining the actual TF input.
if (SAVE_PREVIEW_BITMAP) {
ImageUtils.saveBitmap(mCroppedBitmap);
}
}
public long recognizeImage(final long currTimestamp) {
final long startTime = SystemClock.uptimeMillis();
final List<Classifier.Recognition> results = mDetector.recognizeImage(mCroppedBitmap);
long lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime;
float minimumConfidence = Utils.MINIMUM_CONFIDENCE_TF_OD_API;
final List<Classifier.Recognition> mappedRecognitions = new LinkedList<>();
for (final Classifier.Recognition result : results) {
final RectF location = result.getLocation();
if (location != null && result.getConfidence() >= minimumConfidence) {
mCropToFrameTransform.mapRect(location);
result.setLocation(location);
mappedRecognitions.add(result);
}
}
mTracker.trackResults(mappedRecognitions, currTimestamp);
mOverlayView.postInvalidate();
return lastProcessingTimeMs;
}
public String getcropInfo() {
return mCroppedBitmap.getWidth() + "x" + mCroppedBitmap.getHeight();
}
}
\ No newline at end of file
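DetectionManager now concentrates the classifier, the crop transforms and the MultiBoxTracker that MainActivity used to hold directly. A minimal usage sketch from an activity; the model file name and the frame source are assumptions for illustration, not values taken from this commit:

// Hypothetical wiring from an Activity, mirroring what MainActivity does below.
OverlayView overlay = findViewById(R.id.tracking_overlay);
DetectionManager dm = new DetectionManager(this, overlay);
dm.updateDetector("yolov5s-int8.tflite", "CPU");        // asset name is an example
dm.updateTracker(previewWidth, previewHeight, sensorOrientation);

Bitmap rgbFrame = Bitmap.createBitmap(previewWidth, previewHeight, Bitmap.Config.ARGB_8888);
// ... fill rgbFrame with the current camera pixels ...
dm.drawBitmap(rgbFrame);                                // scales/rotates the frame into the model input
long inferenceMs = dm.recognizeImage(frameTimestamp);   // runs the model, maps boxes back, updates the tracker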
......@@ -17,33 +17,11 @@
package com.agenew.detection;
import android.graphics.Bitmap;
import android.graphics.Bitmap.Config;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Matrix;
import android.graphics.Paint;
import android.graphics.Paint.Style;
import android.graphics.RectF;
import android.graphics.Typeface;
import android.media.ImageReader.OnImageAvailableListener;
import android.os.SystemClock;
import android.util.Log;
import android.util.Size;
import android.util.TypedValue;
import android.widget.Toast;
import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
import com.agenew.detection.customview.OverlayView;
import com.agenew.detection.env.BorderedText;
import com.agenew.detection.env.ImageUtils;
import com.agenew.detection.env.Logger;
import com.agenew.detection.tflite.Classifier;
import com.agenew.detection.tflite.DetectorFactory;
import com.agenew.detection.tflite.YoloV5Classifier;
import com.agenew.detection.tracking.MultiBoxTracker;
/**
* An activity that uses a TensorFlowMultiBoxDetector and ObjectTracker to
......@@ -52,65 +30,27 @@ import com.agenew.detection.tracking.MultiBoxTracker;
public class MainActivity extends CameraActivity implements OnImageAvailableListener {
private static final Logger LOGGER = new Logger();
private static final DetectorMode MODE = DetectorMode.TF_OD_API;
public static final float MINIMUM_CONFIDENCE_TF_OD_API = 0.3f;
private static final boolean MAINTAIN_ASPECT = true;
private static final Size DESIRED_PREVIEW_SIZE = new Size(640, 640);
private static final boolean SAVE_PREVIEW_BITMAP = false;
private static final float TEXT_SIZE_DIP = 10;
OverlayView trackingOverlay;
private Integer sensorOrientation;
private YoloV5Classifier detector;
private long lastProcessingTimeMs;
private Bitmap rgbFrameBitmap = null;
private Bitmap croppedBitmap = null;
private Bitmap cropCopyBitmap = null;
private boolean computingDetection = false;
private long timestamp = 0;
private Bitmap rgbFrameBitmap = null;
private Matrix frameToCropTransform;
private Matrix cropToFrameTransform;
private MultiBoxTracker tracker;
private BorderedText borderedText;
private DetectionManager mDetectionManager;
@Override
public void onPreviewSizeChosen(final Size size, final int rotation) {
final float textSizePx = TypedValue.applyDimension(TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP,
getResources().getDisplayMetrics());
borderedText = new BorderedText(textSizePx);
borderedText.setTypeface(Typeface.MONOSPACE);
tracker = new MultiBoxTracker(this);
final int modelIndex = modelView.getCheckedItemPosition();
final String modelString = modelStrings.get(modelIndex);
final int deviceIndex = deviceView.getCheckedItemPosition();
String device = deviceStrings.get(deviceIndex);
try {
detector = DetectorFactory.getDetector(getAssets(), modelString);
} catch (final IOException e) {
e.printStackTrace();
LOGGER.e(e, "Exception initializing classifier!");
Toast toast = Toast.makeText(getApplicationContext(), "Classifier could not be initialized",
Toast.LENGTH_SHORT);
toast.show();
finish();
}
if (device.equals("CPU")) {
detector.useCPU();
} else if (device.equals("GPU")) {
detector.useGpu();
} else if (device.equals("NNAPI")) {
detector.useNNAPI();
}
int cropSize = detector.getInputSize();
previewWidth = size.getWidth();
previewHeight = size.getHeight();
......@@ -119,24 +59,12 @@ public class MainActivity extends CameraActivity implements OnImageAvailableList
LOGGER.i("Camera orientation relative to screen canvas: %d", sensorOrientation);
LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight);
rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888);
croppedBitmap = Bitmap.createBitmap(cropSize, cropSize, Config.ARGB_8888);
frameToCropTransform = ImageUtils.getTransformationMatrix(previewWidth, previewHeight, cropSize, cropSize,
sensorOrientation, MAINTAIN_ASPECT);
cropToFrameTransform = new Matrix();
frameToCropTransform.invert(cropToFrameTransform);
trackingOverlay = findViewById(R.id.tracking_overlay);
trackingOverlay.addCallback(canvas -> {
tracker.draw(canvas);
if (isDebug()) {
tracker.drawDebug(canvas);
}
});
tracker.setFrameConfiguration(previewWidth, previewHeight, sensorOrientation);
OverlayView trackingOverlay = findViewById(R.id.tracking_overlay);
mDetectionManager = new DetectionManager(this, trackingOverlay);
rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Bitmap.Config.ARGB_8888);
mDetectionManager.updateDetector(modelString, device);
mDetectionManager.updateTracker(previewWidth, previewHeight, sensorOrientation);
}
protected void updateActiveModel() {
......@@ -155,51 +83,17 @@ public class MainActivity extends CameraActivity implements OnImageAvailableList
currentNumThreads = numThreads;
// Disable classifier while updating
if (detector != null) {
detector.close();
detector = null;
}
mDetectionManager.closeDetector();
// Lookup names of parameters.
String modelString = modelStrings.get(modelIndex);
String device = deviceStrings.get(deviceIndex);
LOGGER.i("Changing model to " + modelString + " device " + device);
// Try to load model.
try {
detector = DetectorFactory.getDetector(getAssets(), modelString);
// Customize the interpreter to the type of device we want to use.
if (detector == null) {
return;
}
} catch (IOException e) {
e.printStackTrace();
LOGGER.e(e, "Exception in updateActiveModel()");
Toast toast = Toast.makeText(getApplicationContext(), "Classifier could not be initialized",
Toast.LENGTH_SHORT);
toast.show();
finish();
}
if (device.equals("CPU")) {
detector.useCPU();
} else if (device.equals("GPU")) {
detector.useGpu();
} else if (device.equals("NNAPI")) {
detector.useNNAPI();
}
detector.setNumThreads(numThreads);
int cropSize = detector.getInputSize();
croppedBitmap = Bitmap.createBitmap(cropSize, cropSize, Config.ARGB_8888);
frameToCropTransform = ImageUtils.getTransformationMatrix(previewWidth, previewHeight, cropSize,
cropSize, sensorOrientation, MAINTAIN_ASPECT);
cropToFrameTransform = new Matrix();
frameToCropTransform.invert(cropToFrameTransform);
rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Bitmap.Config.ARGB_8888);
mDetectionManager.updateDetector(modelString, device);
mDetectionManager.updateTracker(previewWidth, previewHeight, sensorOrientation);
mDetectionManager.setNumThreads(numThreads);
});
}
......@@ -207,7 +101,7 @@ public class MainActivity extends CameraActivity implements OnImageAvailableList
protected void processImage() {
++timestamp;
final long currTimestamp = timestamp;
trackingOverlay.postInvalidate();
mDetectionManager.postInvalidate();
// No mutex needed as this method is not reentrant.
if (computingDetection) {
......@@ -215,63 +109,18 @@ public class MainActivity extends CameraActivity implements OnImageAvailableList
return;
}
computingDetection = true;
LOGGER.i("Preparing image " + currTimestamp + " for detection in bg thread.");
rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight);
LOGGER.i("processImage Preparing image " + currTimestamp + " for detection in bg thread.");
setPixels(getRgbBytes(), previewWidth, previewHeight);
readyForNextImage();
final Canvas canvas = new Canvas(croppedBitmap);
canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null);
// For examining the actual TF input.
if (SAVE_PREVIEW_BITMAP) {
ImageUtils.saveBitmap(croppedBitmap);
}
mDetectionManager.drawBitmap(rgbFrameBitmap);
runInBackground(() -> {
LOGGER.i("Running detection on image " + currTimestamp);
final long startTime = SystemClock.uptimeMillis();
final List<Classifier.Recognition> results = detector.recognizeImage(croppedBitmap);
lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime;
Log.e("CHECK", "run: " + results.size());
cropCopyBitmap = Bitmap.createBitmap(croppedBitmap);
final Canvas canvas1 = new Canvas(cropCopyBitmap);
final Paint paint = new Paint();
paint.setColor(Color.RED);
paint.setStyle(Style.STROKE);
paint.setStrokeWidth(2.0f);
float minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API;
switch (MODE) {
case TF_OD_API:
minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API;
break;
}
final List<Classifier.Recognition> mappedRecognitions = new LinkedList<>();
for (final Classifier.Recognition result : results) {
final RectF location = result.getLocation();
if (location != null && result.getConfidence() >= minimumConfidence) {
canvas1.drawRect(location, paint);
cropToFrameTransform.mapRect(location);
result.setLocation(location);
mappedRecognitions.add(result);
}
}
tracker.trackResults(mappedRecognitions, currTimestamp);
trackingOverlay.postInvalidate();
LOGGER.i("processImage Running detection on image " + currTimestamp);
lastProcessingTimeMs = mDetectionManager.recognizeImage(currTimestamp);
computingDetection = false;
runOnUiThread(() -> {
showFrameInfo(previewWidth + "x" + previewHeight);
showCropInfo(cropCopyBitmap.getWidth() + "x" + cropCopyBitmap.getHeight());
showCropInfo(mDetectionManager.getcropInfo());
showInference(lastProcessingTimeMs + "ms");
});
});
......@@ -290,12 +139,13 @@ public class MainActivity extends CameraActivity implements OnImageAvailableList
// Which detection model to use: by default uses Tensorflow Object Detection API
// frozen
// checkpoints.
private enum DetectorMode {
TF_OD_API;
}
@Override
protected void setNumThreads(final int numThreads) {
runInBackground(() -> detector.setNumThreads(numThreads));
runInBackground(() -> mDetectionManager.setNumThreads(numThreads));
}
private void setPixels(int[] pixels, int width, int height) {
rgbFrameBitmap.setPixels(pixels, 0, width, 0, 0, width, height);
}
}
......@@ -26,6 +26,10 @@ import java.util.List;
public class OverlayView extends View {
private final List<DrawCallback> callbacks = new LinkedList<>();
public OverlayView(final Context context) {
super(context);
}
public OverlayView(final Context context, final AttributeSet attrs) {
super(context, attrs);
}
......@@ -36,6 +40,7 @@ public class OverlayView extends View {
@Override
public synchronized void draw(final Canvas canvas) {
super.draw(canvas);
for (final DrawCallback callback : callbacks) {
callback.drawCallback(canvas);
}
......
......@@ -11,6 +11,7 @@ import java.nio.channels.FileChannel;
public class Utils {
private static final String TAG = "Utils";
public static final float MINIMUM_CONFIDENCE_TF_OD_API = 0.3f;
/**
* Memory-map the model file in Assets.
......
package com.agenew.detection.tflite;
import android.content.res.AssetManager;
import java.io.IOException;
public class DetectorFactory {
public static YoloV5Classifier getDetector(final AssetManager assetManager, final String modelFilename)
throws IOException {
String labelFilename = null;
boolean isQuantized = false;
int inputSize = 0;
if (modelFilename.endsWith(".tflite")) {
labelFilename = "file:///android_asset/class.txt";
isQuantized = modelFilename.endsWith("-int8.tflite");
inputSize = 640;
}
return YoloV5Classifier.create(assetManager, modelFilename, labelFilename, isQuantized, inputSize);
}
}
package com.agenew.detection.tflite;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.RectF;
import android.util.Log;
import com.agenew.detection.env.Utils;
import com.mediatek.neuropilot_S.Interpreter;
import com.mediatek.neuropilot_S.Tensor;
import com.mediatek.neuropilot_S.nnapi.NnApiDelegate;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Vector;
public abstract class YoloClassifier implements Classifier {
private static final String TAG = "YoloClassifier";
private static final String LABEL_FILENAME = "class_dump.txt";
private static final int NUM_THREADS = 1;
private final int mInputSize;
protected int mOutputBox;
protected final boolean mIsModelQuantized;
/** holds a gpu delegate */
// GpuDelegate gpuDelegate = null;
/** holds an nnapi delegate */
private NnApiDelegate mNnapiDelegate = null;
private MappedByteBuffer mTfliteModel;
private final Interpreter.Options mTfliteOptions = new Interpreter.Options();
protected final Vector<String> mLabels = new Vector<>();
private final int[] mIntValues;
private final ByteBuffer mImgData;
protected final ByteBuffer mOutData;
private Interpreter mTfLite;
private float mInpScale;
private int mInpZeroPoint;
protected float mOupScale;
protected int mOupZeroPoint;
protected int[] mOutputShape;
public static YoloClassifier create(final AssetManager assetManager, final String modelFilename) throws IOException {
boolean isYoloV8 = modelFilename.contains("yolov8");
return isYoloV8 ? new YoloV8Classifier(assetManager, modelFilename) : new YoloV5Classifier(assetManager, modelFilename);
}
public YoloClassifier(final AssetManager assetManager, final String modelFilename) throws IOException {
InputStream labelsInput = assetManager.open(LABEL_FILENAME);
BufferedReader br = new BufferedReader(new InputStreamReader(labelsInput));
String line;
while ((line = br.readLine()) != null) {
Log.w(TAG, line);
mLabels.add(line);
}
br.close();
try {
Interpreter.Options options = (new Interpreter.Options());
options.setNumThreads(NUM_THREADS);
mTfliteModel = Utils.loadModelFile(assetManager, modelFilename);
mTfLite = new Interpreter(mTfliteModel, options);
} catch (Exception e) {
throw new RuntimeException(e);
}
final boolean isQuantized = modelFilename.endsWith("-int8.tflite");
mIsModelQuantized = isQuantized;
// Pre-allocate buffers.
int numBytesPerChannel = isQuantized ? 1 : 4;
Tensor inpten = mTfLite.getInputTensor(0);
mInputSize = inpten.shape()[1];
mImgData = ByteBuffer.allocateDirect(mInputSize * mInputSize * 3 * numBytesPerChannel);
mImgData.order(ByteOrder.nativeOrder());
mIntValues = new int[mInputSize * mInputSize];
Tensor oupten = mTfLite.getOutputTensor(0);
mOutputShape = oupten.shape();
if (mIsModelQuantized) {
mInpScale = inpten.quantizationParams().getScale();
mInpZeroPoint = inpten.quantizationParams().getZeroPoint();
mOupScale = oupten.quantizationParams().getScale();
mOupZeroPoint = oupten.quantizationParams().getZeroPoint();
Log.i(TAG, "WL_DEBUG YoloClassifier mInpScale = " + mInpScale + ", mInpZeroPoint = " + mInpZeroPoint + ", mOupScale = " + mOupScale + ", mOupZeroPoint = " + mOupZeroPoint + ", mInputSize = " + mInputSize);
}
int outDataCapacity = mOutputShape[1] * mOutputShape[2] * numBytesPerChannel;
Log.i(TAG, "WL_DEBUG create mOutputShape = " + Arrays.toString(mOutputShape) + ", outDataCapacity = " + outDataCapacity);
mOutData = ByteBuffer.allocateDirect(outDataCapacity);
mOutData.order(ByteOrder.nativeOrder());
}
public int getInputSize() {
return mInputSize;
}
@Override
public void close() {
mTfLite.close();
mTfLite = null;
/*
* if (gpuDelegate != null) { gpuDelegate.close(); gpuDelegate = null; }
*/
if (mNnapiDelegate != null) {
mNnapiDelegate.close();
mNnapiDelegate = null;
}
mTfliteModel = null;
}
public void setNumThreads(int num_threads) {
if (mTfLite != null)
mTfLite.setNumThreads(num_threads);
}
private void recreateInterpreter() {
if (mTfLite != null) {
mTfLite.close();
mTfLite = new Interpreter(mTfliteModel, mTfliteOptions);
}
}
public void useGpu() {
/*
* if (gpuDelegate == null) { gpuDelegate = new GpuDelegate();
* tfliteOptions.addDelegate(gpuDelegate); recreateInterpreter(); }
*/
}
public void useCPU() {
recreateInterpreter();
}
public void useNNAPI() {
mNnapiDelegate = new NnApiDelegate();
mTfliteOptions.addDelegate(mNnapiDelegate);
recreateInterpreter();
}
@Override
public float getObjThresh() {
return Utils.MINIMUM_CONFIDENCE_TF_OD_API;
}
// non maximum suppression
private ArrayList<Recognition> nms(ArrayList<Recognition> list) {
ArrayList<Recognition> nmsList = new ArrayList<>();
for (int k = 0; k < mLabels.size(); k++) {
// 1.find max confidence per class
PriorityQueue<Recognition> pq = new PriorityQueue<>(50, (lhs, rhs) -> {
// Intentionally reversed to put high confidence at the head of the queue.
return Float.compare(rhs.getConfidence(), lhs.getConfidence());
});
for (int i = 0; i < list.size(); ++i) {
if (list.get(i).getDetectedClass() == k) {
pq.add(list.get(i));
}
}
// 2.do non maximum suppression
while (!pq.isEmpty()) {
// insert detection with max confidence
Recognition[] a = new Recognition[pq.size()];
Recognition[] detections = pq.toArray(a);
Recognition max = detections[0];
nmsList.add(max);
pq.clear();
for (int j = 1; j < detections.length; j++) {
Recognition detection = detections[j];
RectF b = detection.getLocation();
float mNmsThresh = 0.6f;
if (box_iou(max.getLocation(), b) < mNmsThresh) {
pq.add(detection);
}
}
}
}
return nmsList;
}
private float box_iou(RectF a, RectF b) {
return box_intersection(a, b) / box_union(a, b);
}
private float box_intersection(RectF a, RectF b) {
float w = overlap((a.left + a.right) / 2, a.right - a.left, (b.left + b.right) / 2, b.right - b.left);
float h = overlap((a.top + a.bottom) / 2, a.bottom - a.top, (b.top + b.bottom) / 2, b.bottom - b.top);
if (w < 0 || h < 0)
return 0;
return w * h;
}
private float box_union(RectF a, RectF b) {
float i = box_intersection(a, b);
return (a.right - a.left) * (a.bottom - a.top) + (b.right - b.left) * (b.bottom - b.top) - i;
}
private float overlap(float x1, float w1, float x2, float w2) {
float l1 = x1 - w1 / 2;
float l2 = x2 - w2 / 2;
float left = Math.max(l1, l2);
float r1 = x1 + w1 / 2;
float r2 = x2 + w2 / 2;
float right = Math.min(r1, r2);
return right - left;
}
/**
* Writes Image data into a {@code ByteBuffer}.
*/
private void convertBitmapToByteBuffer(Bitmap bitmap) {
bitmap.getPixels(mIntValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
mImgData.rewind();
for (int i = 0; i < mInputSize; ++i) {
for (int j = 0; j < mInputSize; ++j) {
int pixelValue = mIntValues[i * mInputSize + j];
// Float model
float IMAGE_MEAN = 0;
float IMAGE_STD = 255.0f;
if (mIsModelQuantized) {
// Quantized model
mImgData.put((byte) ((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD / mInpScale
+ mInpZeroPoint));
mImgData.put((byte) ((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD / mInpScale
+ mInpZeroPoint));
mImgData.put((byte) (((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD / mInpScale + mInpZeroPoint));
} else { // Float model
mImgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
mImgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
mImgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
}
}
}
}
public ArrayList<Recognition> recognizeImage(Bitmap bitmap) {
convertBitmapToByteBuffer(bitmap);
Map<Integer, Object> outputMap = new HashMap<>();
mOutData.rewind();
outputMap.put(0, mOutData);
Log.d(TAG, "mObjThresh: " + getObjThresh() + ", mOutData.capacity() = " + mOutData.capacity());
Object[] inputArray = { mImgData };
try {
mTfLite.runForMultipleInputsOutputs(inputArray, outputMap);
}catch (Exception e) {
Log.e(TAG, "WL_DEBUG recognizeImage e = " + e, e);
return new ArrayList<>();
}
mOutData.rewind();
ArrayList<Recognition> detections = new ArrayList<>();
int outputShape1 = mOutputShape[1];
int outputShape2 = mOutputShape[2];
float[][][] out = new float[1][outputShape1][outputShape2];
Log.d(TAG, "out[0] detect start");
buildDetections(outputShape1, outputShape2, out, bitmap, detections);
Log.d(TAG, "detect end");
return nms(detections);
}
protected abstract void buildDetections(int outputShape1, int outputShape2, float[][][] out, Bitmap bitmap, ArrayList<Recognition> detections);
}
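YoloClassifier is the new shared base: create() picks the subclass from the model file name ("yolov8" selects YoloV8Classifier, anything else YoloV5Classifier), and the input size, quantization parameters and output buffer are read from the interpreter's tensors rather than hard-coded as in the old DetectorFactory. A minimal call-path sketch; the asset file name is an assumption:

// Hypothetical usage; the asset name is illustrative only.
YoloClassifier detector = YoloClassifier.create(getAssets(), "yolov8n-int8.tflite"); // -> YoloV8Classifier
detector.setNumThreads(2);
List<Classifier.Recognition> results = detector.recognizeImage(croppedBitmap);
// recognizeImage() runs the interpreter, lets the subclass decode the tensor layout in
// buildDetections(), then applies per-class NMS with a 0.6 IoU threshold.
detector.close();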
......@@ -20,350 +20,42 @@ import android.graphics.Bitmap;
import android.graphics.RectF;
import android.util.Log;
//import org.tensorflow.lite.Interpreter;
import com.mediatek.neuropilot_S.Interpreter;
//import org.tensorflow.lite.Tensor;
import com.mediatek.neuropilot_S.Tensor;
import com.agenew.detection.MainActivity;
import com.agenew.detection.env.Logger;
import com.agenew.detection.env.Utils;
//import org.tensorflow.lite.gpu.GpuDelegate;
//import org.tensorflow.lite.nnapi.NnApiDelegate;
import com.mediatek.neuropilot_S.nnapi.NnApiDelegate;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Vector;
/**
* Wrapper for frozen detection models trained using the Tensorflow Object
* Detection API: -
* https://github.com/tensorflow/models/tree/master/research/object_detection
* where you can find the training code.
* <p>
* To use pretrained models in the API or convert to TF Lite models, please see
* docs for details: -
* https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md
* -
* https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/running_on_mobile_tensorflowlite.md#running-our-model-on-android
*/
public class YoloV5Classifier implements Classifier {
public class YoloV5Classifier extends YoloClassifier {
private static final String TAG = "YoloV5Classifier";
/**
* Initializes a native TensorFlow session for classifying images.
*
* @param assetManager The asset manager to be used to load assets.
* @param modelFilename The filepath of the model GraphDef protocol buffer.
* @param labelFilename The filepath of label file for classes.
* @param isQuantized Boolean representing model is quantized or not
*/
public static YoloV5Classifier create(final AssetManager assetManager, final String modelFilename,
final String labelFilename, final boolean isQuantized, final int inputSize) throws IOException {
final YoloV5Classifier d = new YoloV5Classifier();
String actualFilename = labelFilename.split("file:///android_asset/")[1];
InputStream labelsInput = assetManager.open(actualFilename);
BufferedReader br = new BufferedReader(new InputStreamReader(labelsInput));
String line;
while ((line = br.readLine()) != null) {
LOGGER.w(line);
d.labels.add(line);
}
br.close();
try {
Interpreter.Options options = (new Interpreter.Options());
options.setNumThreads(NUM_THREADS);
d.tfliteModel = Utils.loadModelFile(assetManager, modelFilename);
d.tfLite = new Interpreter(d.tfliteModel, options);
} catch (Exception e) {
throw new RuntimeException(e);
}
d.isModelQuantized = isQuantized;
// Pre-allocate buffers.
int numBytesPerChannel;
if (isQuantized) {
numBytesPerChannel = 1; // Quantized
} else {
numBytesPerChannel = 4; // Floating point
}
d.INPUT_SIZE = inputSize;
d.imgData = ByteBuffer.allocateDirect(1 * d.INPUT_SIZE * d.INPUT_SIZE * 3 * numBytesPerChannel);
d.imgData.order(ByteOrder.nativeOrder());
d.intValues = new int[d.INPUT_SIZE * d.INPUT_SIZE];
d.output_box = (int) ((Math.pow((inputSize / 32), 2) + Math.pow((inputSize / 16), 2)
+ Math.pow((inputSize / 8), 2)) * 3);
if (d.isModelQuantized) {
Tensor inpten = d.tfLite.getInputTensor(0);
d.inp_scale = inpten.quantizationParams().getScale();
d.inp_zero_point = inpten.quantizationParams().getZeroPoint();
Tensor oupten = d.tfLite.getOutputTensor(0);
d.output_box = oupten.shape()[1];
d.oup_scale = oupten.quantizationParams().getScale();
d.oup_zero_point = oupten.quantizationParams().getZeroPoint();
}
int[] shape = d.tfLite.getOutputTensor(0).shape();
int numClass = shape[shape.length - 1] - 5;
d.numClass = numClass;
d.outData = ByteBuffer.allocateDirect(d.output_box * (numClass + 5) * numBytesPerChannel);
d.outData.order(ByteOrder.nativeOrder());
return d;
}
public int getInputSize() {
return INPUT_SIZE;
}
@Override
public void close() {
tfLite.close();
tfLite = null;
/*
* if (gpuDelegate != null) { gpuDelegate.close(); gpuDelegate = null; }
*/
if (nnapiDelegate != null) {
nnapiDelegate.close();
nnapiDelegate = null;
}
tfliteModel = null;
}
public void setNumThreads(int num_threads) {
if (tfLite != null)
tfLite.setNumThreads(num_threads);
}
private void recreateInterpreter() {
if (tfLite != null) {
tfLite.close();
tfLite = new Interpreter(tfliteModel, tfliteOptions);
}
}
public void useGpu() {
/*
* if (gpuDelegate == null) { gpuDelegate = new GpuDelegate();
* tfliteOptions.addDelegate(gpuDelegate); recreateInterpreter(); }
*/
}
public void useCPU() {
recreateInterpreter();
}
public void useNNAPI() {
nnapiDelegate = new NnApiDelegate();
tfliteOptions.addDelegate(nnapiDelegate);
recreateInterpreter();
public YoloV5Classifier(final AssetManager assetManager, final String modelFilename) throws IOException {
super(assetManager, modelFilename);
mOutputBox = mOutputShape[1];
}
@Override
public float getObjThresh() {
return MainActivity.MINIMUM_CONFIDENCE_TF_OD_API;
}
private static final Logger LOGGER = new Logger();
// Float model
private final float IMAGE_MEAN = 0;
private final float IMAGE_STD = 255.0f;
// config yolo
private int INPUT_SIZE = -1;
private int output_box;
// Number of threads in the java app
private static final int NUM_THREADS = 1;
private boolean isModelQuantized;
/** holds a gpu delegate */
// GpuDelegate gpuDelegate = null;
/** holds an nnapi delegate */
NnApiDelegate nnapiDelegate = null;
/** The loaded TensorFlow Lite model. */
private MappedByteBuffer tfliteModel;
/** Options for configuring the Interpreter. */
private final Interpreter.Options tfliteOptions = new Interpreter.Options();
// Config values.
// Pre-allocated buffers.
private Vector<String> labels = new Vector<String>();
private int[] intValues;
private ByteBuffer imgData;
private ByteBuffer outData;
private Interpreter tfLite;
private float inp_scale;
private int inp_zero_point;
private float oup_scale;
private int oup_zero_point;
private int numClass;
private YoloV5Classifier() {
}
// non maximum suppression
protected ArrayList<Recognition> nms(ArrayList<Recognition> list) {
ArrayList<Recognition> nmsList = new ArrayList<Recognition>();
for (int k = 0; k < labels.size(); k++) {
// 1.find max confidence per class
PriorityQueue<Recognition> pq = new PriorityQueue<Recognition>(50, new Comparator<Recognition>() {
@Override
public int compare(final Recognition lhs, final Recognition rhs) {
// Intentionally reversed to put high confidence at the head of the queue.
return Float.compare(rhs.getConfidence(), lhs.getConfidence());
}
});
for (int i = 0; i < list.size(); ++i) {
if (list.get(i).getDetectedClass() == k) {
pq.add(list.get(i));
}
}
// 2.do non maximum suppression
while (pq.size() > 0) {
// insert detection with max confidence
Recognition[] a = new Recognition[pq.size()];
Recognition[] detections = pq.toArray(a);
Recognition max = detections[0];
nmsList.add(max);
pq.clear();
for (int j = 1; j < detections.length; j++) {
Recognition detection = detections[j];
RectF b = detection.getLocation();
if (box_iou(max.getLocation(), b) < mNmsThresh) {
pq.add(detection);
}
}
}
}
return nmsList;
}
protected float mNmsThresh = 0.6f;
protected float box_iou(RectF a, RectF b) {
return box_intersection(a, b) / box_union(a, b);
}
protected float box_intersection(RectF a, RectF b) {
float w = overlap((a.left + a.right) / 2, a.right - a.left, (b.left + b.right) / 2, b.right - b.left);
float h = overlap((a.top + a.bottom) / 2, a.bottom - a.top, (b.top + b.bottom) / 2, b.bottom - b.top);
if (w < 0 || h < 0)
return 0;
float area = w * h;
return area;
}
protected float box_union(RectF a, RectF b) {
float i = box_intersection(a, b);
float u = (a.right - a.left) * (a.bottom - a.top) + (b.right - b.left) * (b.bottom - b.top) - i;
return u;
}
protected float overlap(float x1, float w1, float x2, float w2) {
float l1 = x1 - w1 / 2;
float l2 = x2 - w2 / 2;
float left = l1 > l2 ? l1 : l2;
float r1 = x1 + w1 / 2;
float r2 = x2 + w2 / 2;
float right = r1 < r2 ? r1 : r2;
return right - left;
}
/**
* Writes Image data into a {@code ByteBuffer}.
*/
protected ByteBuffer convertBitmapToByteBuffer(Bitmap bitmap) {
bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
imgData.rewind();
for (int i = 0; i < INPUT_SIZE; ++i) {
for (int j = 0; j < INPUT_SIZE; ++j) {
int pixelValue = intValues[i * INPUT_SIZE + j];
if (isModelQuantized) {
// Quantized model
imgData.put((byte) ((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD / inp_scale
+ inp_zero_point));
imgData.put((byte) ((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD / inp_scale
+ inp_zero_point));
imgData.put((byte) (((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD / inp_scale + inp_zero_point));
} else { // Float model
imgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
imgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
imgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
}
}
}
return imgData;
}
public ArrayList<Recognition> recognizeImage(Bitmap bitmap) {
convertBitmapToByteBuffer(bitmap);
Map<Integer, Object> outputMap = new HashMap<Integer, Object>();
outData.rewind();
outputMap.put(0, outData);
Log.d("YoloV5Classifier", "mObjThresh: " + getObjThresh());
Object[] inputArray = { imgData };
tfLite.runForMultipleInputsOutputs(inputArray, outputMap);
ByteBuffer byteBuffer = (ByteBuffer) outputMap.get(0);
byteBuffer.rewind();
ArrayList<Recognition> detections = new ArrayList<Recognition>();
float[][][] out = new float[1][output_box][numClass + 5];
Log.d("YoloV5Classifier", "out[0] detect start");
for (int i = 0; i < output_box; ++i) {
for (int j = 0; j < numClass + 5; ++j) {
if (isModelQuantized) {
out[0][i][j] = oup_scale * (((int) byteBuffer.get() & 0xFF) - oup_zero_point);
protected void buildDetections(int outputShape1, int outputShape2, float[][][] out, Bitmap bitmap, ArrayList<Recognition> detections) {
for (int i = 0; i < outputShape1; ++i) {
for (int j = 0; j < outputShape2; ++j) {
if (mIsModelQuantized) {
out[0][i][j] = mOupScale * ((((int) mOutData.get() & 0xFF) - mOupZeroPoint)&0xFF);
} else {
out[0][i][j] = byteBuffer.getFloat();
out[0][i][j] = mOutData.getFloat();
}
if (j < 4) {
out[0][i][j] *= getInputSize();
}
}
// Denormalize xywh
for (int j = 0; j < 4; ++j) {
out[0][i][j] *= getInputSize();
}
}
for (int i = 0; i < output_box; ++i) {
for (int i = 0; i < mOutputBox; ++i) {
final int offset = 0;
final float confidence = out[0][i][4];
int detectedClass = -1;
float maxClass = 0;
final float[] classes = new float[labels.size()];
for (int c = 0; c < labels.size(); ++c) {
final float[] classes = new float[mLabels.size()];
for (int c = 0; c < mLabels.size() && c < (out[0][i].length-5); ++c) {
classes[c] = out[0][i][5 + c];
}
for (int c = 0; c < labels.size(); ++c) {
for (int c = 0; c < mLabels.size(); ++c) {
if (classes[c] > maxClass) {
detectedClass = c;
maxClass = classes[c];
......@@ -377,17 +69,13 @@ public class YoloV5Classifier implements Classifier {
final float w = out[0][i][2];
final float h = out[0][i][3];
Log.d("YoloV5Classifier", Float.toString(xPos) + ',' + yPos + ',' + w + ',' + h);
Log.d(TAG, Float.toString(xPos) + ',' + yPos + ',' + w + ',' + h);
final RectF rect = new RectF(Math.max(0, xPos - w / 2), Math.max(0, yPos - h / 2),
Math.min(bitmap.getWidth() - 1, xPos + w / 2), Math.min(bitmap.getHeight() - 1, yPos + h / 2));
detections.add(new Recognition("" + offset, labels.get(detectedClass), confidenceInClass, rect,
detections.add(new Recognition("" + offset, mLabels.get(detectedClass), confidenceInClass, rect,
detectedClass));
}
}
Log.d(TAG, "detect end");
final ArrayList<Recognition> recognitions = nms(detections);
return recognitions;
}
}
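For quantized models the subclasses decode raw uint8 output bytes using the scale and zero point read from the output tensor. As a reference point, the conventional TFLite affine dequantization is real = scale * (q - zeroPoint); a minimal sketch of just that formula, kept separate from the additional masking done in buildDetections() above:

// Standard affine dequantization for an unsigned 8-bit TFLite tensor.
static float dequantize(byte raw, float scale, int zeroPoint) {
    int q = raw & 0xFF;              // reinterpret the byte as unsigned 0..255
    return scale * (q - zeroPoint);  // may be negative when q < zeroPoint
}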
package com.agenew.detection.tflite;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.RectF;
import android.util.Log;
import java.io.IOException;
import java.util.ArrayList;
public class YoloV8Classifier extends YoloClassifier {
private static final String TAG = "YoloV8Classifier";
public YoloV8Classifier(final AssetManager assetManager, final String modelFilename) throws IOException {
super(assetManager, modelFilename);
mOutputBox = mOutputShape[2];
}
@Override
protected void buildDetections(int outputShape1, int outputShape2, float[][][] out, Bitmap bitmap, ArrayList<Recognition> detections) {
for (int i = 0; i < outputShape1; ++i) {
for (int j = 0; j < outputShape2; ++j) {
if (mIsModelQuantized) {
out[0][i][j] = mOupScale * ((((int) mOutData.get() & 0xFF) - mOupZeroPoint)&0xFF);
} else {
out[0][i][j] = mOutData.getFloat();
}
if (i < 4) {
out[0][i][j] *= getInputSize();
}
}
}
for (int i = 0; i < mOutputBox; ++i) {
final int offset = 0;
int detectedClass = -1;
float maxClass = 0;
final float[] classes = new float[mLabels.size()];
for (int c = 0; c < mLabels.size() && c < (out[0].length-4); ++c) {
classes[c] = out[0][4+c][i];
}
for (int c = 0; c < mLabels.size(); ++c) {
if (classes[c] > maxClass) {
detectedClass = c;
maxClass = classes[c];
}
}
final float confidenceInClass = maxClass;
if (confidenceInClass > getObjThresh()) {
final float xPos = out[0][0][i];
final float yPos = out[0][1][i];
final float w = out[0][2][i];
final float h = out[0][3][i];
Log.d(TAG, Float.toString(xPos) + ',' + yPos + ',' + w + ',' + h);
final RectF rect = new RectF(Math.max(0, xPos - w / 2), Math.max(0, yPos - h / 2),
Math.min(bitmap.getWidth() - 1, xPos + w / 2), Math.min(bitmap.getHeight() - 1, yPos + h / 2));
detections.add(new Recognition("" + offset, mLabels.get(detectedClass), confidenceInClass, rect,
detectedClass));
}
}
}
}
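The two subclasses differ mainly in tensor layout: YOLOv5 models export [1, numBoxes, 5 + numClasses] with an objectness score at index 4, while YOLOv8 models export [1, 4 + numClasses, numBoxes] with no objectness column, which is why YoloV8Classifier reads boxes column-wise and takes mOutputBox from mOutputShape[2]. A minimal indexing sketch with illustrative variable names:

// YOLOv5 layout: out[0][box][0..3] = x,y,w,h; out[0][box][4] = objectness; out[0][box][5 + c] = class c score.
float v5ClassScore = out[0][box][5 + c];          // typically combined with the objectness at out[0][box][4]
// YOLOv8 layout: out[0][0..3][box] = x,y,w,h; out[0][4 + c][box] = class c score (no objectness term).
float v8ClassScore = out[0][4 + c][box];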
......@@ -13,18 +13,19 @@
See the License for the specific language governing permissions and
limitations under the License.
-->
<FrameLayout xmlns:android="http://schemas.android.com/apk/res/android"
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
android:layout_width="match_parent"
android:layout_height="match_parent">
<com.agenew.detection.customview.AutoFitTextureView
android:id="@+id/texture"
android:layout_width="wrap_content"
android:layout_height="wrap_content" />
android:layout_height="wrap_content"
android:layout_centerInParent="true"/>
<com.agenew.detection.customview.OverlayView
android:id="@+id/tracking_overlay"
android:layout_width="match_parent"
android:layout_height="match_parent" />
android:layout_height="480dp"
android:layout_centerInParent="true"/>
</FrameLayout>
</RelativeLayout>
[versions]
agp = "8.3.2"
agp = "8.4.0"
junit = "4.13.2"
junitVersion = "1.1.5"
espressoCore = "3.5.1"
......
#Wed Apr 10 09:00:02 CST 2024
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-8.4-bin.zip
distributionUrl=https\://services.gradle.org/distributions/gradle-8.6-bin.zip
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists