Commit 6609763d by wanglei

统一yolov5和v8 (Unify the YOLOv5 and YOLOv8 detection pipelines)

1 parent 81d77638
......@@ -531,7 +531,7 @@ public class CameraConnectionFragment extends Fragment {
@Override
public int compare(final Size lhs, final Size rhs) {
    // Cast to long so the width * height products cannot overflow int.
    // rhs comes first, so larger-area sizes sort to the front (descending order).
    // The stale ascending-order return statement that shadowed this one was removed.
    return Long.signum((long) rhs.getWidth() * rhs.getHeight() - (long) lhs.getWidth() * lhs.getHeight());
}
}
......
package com.agenew.detection;
import android.app.Activity;
import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.Canvas;
import android.graphics.Matrix;
import android.graphics.RectF;
import android.graphics.Typeface;
import android.os.SystemClock;
import android.util.Log;
import android.util.TypedValue;
import android.view.ViewGroup;
import android.widget.Toast;
import com.agenew.detection.customview.OverlayView;
import com.agenew.detection.env.BorderedText;
import com.agenew.detection.env.ImageUtils;
import com.agenew.detection.env.Utils;
import com.agenew.detection.tflite.Classifier;
import com.agenew.detection.tflite.YoloClassifier;
import com.agenew.detection.tracking.MultiBoxTracker;
import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
/**
 * Glues the camera pipeline to the YOLO TFLite classifier: owns the detector,
 * the frame-to-crop transforms, and the overlay/tracker used to draw results.
 */
public class DetectionManager {
    private static final String TAG = "DetectionManager";
    private static final float TEXT_SIZE_DIP = 10;
    private static final boolean MAINTAIN_ASPECT = true;
    // When true, every cropped model input is dumped to storage for debugging.
    private static final boolean SAVE_PREVIEW_BITMAP = false;

    private final Context mContext;
    private final OverlayView mOverlayView;
    private final MultiBoxTracker mTracker;

    // Null until updateDetector() succeeds; every user below must tolerate that.
    private YoloClassifier mDetector;
    private Bitmap mCroppedBitmap = null;
    private Matrix mFrameToCropTransform;
    private Matrix mCropToFrameTransform;
    // Cached frame configuration; updateTracker() only rebuilds when one changes.
    private int mWidth;
    private int mHeight;
    private int mSensorOrientation;

    public DetectionManager(Context context, OverlayView overlayView) {
        mContext = context;
        final float textSizePx = TypedValue.applyDimension(TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP,
                context.getResources().getDisplayMetrics());
        // NOTE(review): borderedText is configured but never stored or used below —
        // it looks like leftover code; kept to avoid changing construction side effects.
        BorderedText borderedText = new BorderedText(textSizePx);
        borderedText.setTypeface(Typeface.MONOSPACE);
        mTracker = new MultiBoxTracker(context);
        mOverlayView = overlayView;
        mOverlayView.addCallback(mTracker::draw);
    }

    /**
     * (Re)creates the classifier for the given model asset and applies the
     * requested execution device ("CPU", "GPU" or "NNAPI").
     */
    public void updateDetector(final String modelString, String device) {
        try {
            mDetector = YoloClassifier.create(mContext.getAssets(), modelString);
            // Force updateTracker() to rebuild transforms for the new model's input size.
            mWidth = 0;
        } catch (final IOException e) {
            Log.e(TAG, "Exception initializing classifier! e = " + e, e);
            Toast toast = Toast.makeText(mContext.getApplicationContext(), "Classifier could not be initialized",
                    Toast.LENGTH_SHORT);
            toast.show();
            if (mContext instanceof Activity) {
                Activity activity = (Activity) mContext;
                activity.finish();
            }
            // BUG FIX: previously fell through to the switch below and dereferenced
            // the null mDetector, crashing with an NPE on top of the init failure.
            return;
        }
        switch (device) {
            case "CPU":
                mDetector.useCPU();
                break;
            case "GPU":
                mDetector.useGpu();
                break;
            case "NNAPI":
                mDetector.useNNAPI();
                break;
        }
    }

    /**
     * Rebuilds crop bitmaps/transforms and resizes the overlay when the preview
     * geometry changes. No-op while the detector is unavailable.
     */
    public void updateTracker(final int width, final int height, final int sensorOrientation) {
        if (mDetector == null) {
            return; // detector init failed or not yet performed
        }
        if (mWidth != width || mHeight != height || mSensorOrientation != sensorOrientation) {
            mWidth = width;
            mHeight = height;
            mSensorOrientation = sensorOrientation;
            int cropSize = mDetector.getInputSize();
            mCroppedBitmap = Bitmap.createBitmap(cropSize, cropSize, Bitmap.Config.ARGB_8888);
            mFrameToCropTransform = ImageUtils.getTransformationMatrix(width, height, cropSize, cropSize,
                    sensorOrientation, MAINTAIN_ASPECT);
            mCropToFrameTransform = new Matrix();
            mFrameToCropTransform.invert(mCropToFrameTransform);
            mTracker.setFrameConfiguration(width, height, sensorOrientation);
            // Match the overlay's aspect ratio to the frame's.
            // NOTE(review): getHeight() is 0 before the first layout pass — verify callers.
            int overlayViewHeight = mOverlayView.getHeight();
            int overlayViewWidth = width * overlayViewHeight / height;
            ViewGroup.LayoutParams layoutParams = mOverlayView.getLayoutParams();
            layoutParams.width = overlayViewWidth;
            // BUG FIX: mutating the LayoutParams object alone does not schedule a
            // relayout; setLayoutParams() triggers requestLayout() so the new width applies.
            mOverlayView.setLayoutParams(layoutParams);
        }
    }

    /** Releases the TFLite interpreter and its delegates. Safe to call repeatedly. */
    public void closeDetector() {
        if (mDetector != null) {
            mDetector.close();
            mDetector = null;
        }
    }

    public void setNumThreads(int num_threads) {
        // Guard against a failed/absent detector instead of crashing.
        if (mDetector != null) {
            mDetector.setNumThreads(num_threads);
        }
    }

    /** Asks the overlay to redraw on the next frame. */
    public void postInvalidate() {
        mOverlayView.postInvalidate();
    }

    /** Crops/rotates the camera frame into the model-input bitmap. */
    public void drawBitmap(Bitmap rgbFrameBitmap) {
        final Canvas canvas = new Canvas(mCroppedBitmap);
        canvas.drawBitmap(rgbFrameBitmap, mFrameToCropTransform, null);
        // For examining the actual TF input.
        if (SAVE_PREVIEW_BITMAP) {
            ImageUtils.saveBitmap(mCroppedBitmap);
        }
    }

    /**
     * Runs inference on the current cropped bitmap, maps boxes back into frame
     * coordinates, feeds the tracker, and returns the inference time in ms.
     */
    public long recognizeImage(final long currTimestamp) {
        final long startTime = SystemClock.uptimeMillis();
        final List<Classifier.Recognition> results = mDetector.recognizeImage(mCroppedBitmap);
        long lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime;
        float minimumConfidence = Utils.MINIMUM_CONFIDENCE_TF_OD_API;
        final List<Classifier.Recognition> mappedRecognitions = new LinkedList<>();
        for (final Classifier.Recognition result : results) {
            final RectF location = result.getLocation();
            if (location != null && result.getConfidence() >= minimumConfidence) {
                // Map from model-input (crop) space back to camera-frame space.
                mCropToFrameTransform.mapRect(location);
                result.setLocation(location);
                mappedRecognitions.add(result);
            }
        }
        mTracker.trackResults(mappedRecognitions, currTimestamp);
        mOverlayView.postInvalidate();
        return lastProcessingTimeMs;
    }

    /** Returns the model-input crop dimensions as "WxH" for the debug UI. */
    public String getcropInfo() {
        return mCroppedBitmap.getWidth() + "x" + mCroppedBitmap.getHeight();
    }
}
\ No newline at end of file
......@@ -26,6 +26,10 @@ import java.util.List;
public class OverlayView extends View {
private final List<DrawCallback> callbacks = new LinkedList<>();
public OverlayView(final Context context) {
super(context);
}
public OverlayView(final Context context, final AttributeSet attrs) {
super(context, attrs);
}
......@@ -36,6 +40,7 @@ public class OverlayView extends View {
@Override
public synchronized void draw(final Canvas canvas) {
super.draw(canvas);
for (final DrawCallback callback : callbacks) {
callback.drawCallback(canvas);
}
......
......@@ -11,6 +11,7 @@ import java.nio.channels.FileChannel;
public class Utils {
private static final String TAG = "Utils";
public static final float MINIMUM_CONFIDENCE_TF_OD_API = 0.3f;
/**
* Memory-map the model file in Assets.
......
package com.agenew.detection.tflite;
import android.content.res.AssetManager;
import java.io.IOException;
/** Builds a YoloV5Classifier for a given model asset file name. */
public class DetectorFactory {
    /**
     * Derives label file, quantization flag and input size from the model file
     * name, then delegates to {@link YoloV5Classifier#create}.
     */
    public static YoloV5Classifier getDetector(final AssetManager assetManager, final String modelFilename)
            throws IOException {
        final boolean isTfliteModel = modelFilename.endsWith(".tflite");
        // Non-.tflite names fall through with null/false/0, exactly as before.
        final String labelFilename = isTfliteModel ? "file:///android_asset/class.txt" : null;
        final boolean isQuantized = isTfliteModel && modelFilename.endsWith("-int8.tflite");
        final int inputSize = isTfliteModel ? 640 : 0;
        return YoloV5Classifier.create(assetManager, modelFilename, labelFilename, isQuantized, inputSize);
    }
}
package com.agenew.detection.tflite;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.RectF;
import android.util.Log;
import com.agenew.detection.env.Utils;
import com.mediatek.neuropilot_S.Interpreter;
import com.mediatek.neuropilot_S.Tensor;
import com.mediatek.neuropilot_S.nnapi.NnApiDelegate;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Vector;
/**
 * Shared TFLite plumbing for YOLO family detectors: loads the model and labels,
 * manages delegates/threads, converts bitmaps to input tensors, and applies
 * per-class non-maximum suppression. Subclasses decode the raw output tensor
 * via {@link #buildDetections}.
 */
public abstract class YoloClassifier implements Classifier {
    private static final String TAG = "YoloClassifier";
    private static final String LABEL_FILENAME = "class_dump.txt";
    private static final int NUM_THREADS = 1;
    // IoU threshold: same-class boxes overlapping more than this are suppressed.
    private static final float NMS_THRESH = 0.6f;

    // Side length (px) of the square model input, read from the input tensor shape.
    private final int mInputSize;
    // Number of candidate boxes in the raw output; set by the concrete subclass.
    protected int mOutputBox;
    protected final boolean mIsModelQuantized;

    /** holds a gpu delegate */
    // GpuDelegate gpuDelegate = null;
    /** holds an nnapi delegate */
    private NnApiDelegate mNnapiDelegate = null;

    private MappedByteBuffer mTfliteModel;
    private final Interpreter.Options mTfliteOptions = new Interpreter.Options();
    // Class labels, one per line of LABEL_FILENAME.
    protected final Vector<String> mLabels = new Vector<>();

    // Scratch buffers, pre-allocated once per model.
    private final int[] mIntValues;
    private final ByteBuffer mImgData;
    protected final ByteBuffer mOutData;

    private Interpreter mTfLite;

    // Affine quantization parameters; only meaningful when mIsModelQuantized.
    private float mInpScale;
    private int mInpZeroPoint;
    protected float mOupScale;
    protected int mOupZeroPoint;
    protected int[] mOutputShape;

    /** Factory: picks the YOLOv8 or YOLOv5 implementation from the model file name. */
    public static YoloClassifier create(final AssetManager assetManager, final String modelFilename) throws IOException {
        boolean isYoloV8 = modelFilename.contains("yolov8");
        return isYoloV8 ? new YoloV8Classifier(assetManager, modelFilename) : new YoloV5Classifier(assetManager, modelFilename);
    }

    public YoloClassifier(final AssetManager assetManager, final String modelFilename) throws IOException {
        // FIX: try-with-resources so the label asset stream is closed even if
        // readLine() throws (the old code leaked it on that path).
        try (BufferedReader br = new BufferedReader(new InputStreamReader(assetManager.open(LABEL_FILENAME)))) {
            String line;
            while ((line = br.readLine()) != null) {
                Log.w(TAG, line);
                mLabels.add(line);
            }
        }
        try {
            Interpreter.Options options = (new Interpreter.Options());
            options.setNumThreads(NUM_THREADS);
            mTfliteModel = Utils.loadModelFile(assetManager, modelFilename);
            mTfLite = new Interpreter(mTfliteModel, options);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        // Quantization is signalled purely by the file-name convention.
        final boolean isQuantized = modelFilename.endsWith("-int8.tflite");
        mIsModelQuantized = isQuantized;
        // Pre-allocate buffers: 1 byte/channel for int8 models, 4 for float.
        int numBytesPerChannel = isQuantized ? 1 : 4;
        Tensor inpten = mTfLite.getInputTensor(0);
        // assumes input shape is [1, size, size, 3] — size taken from dim 1.
        mInputSize = inpten.shape()[1];
        mImgData = ByteBuffer.allocateDirect(mInputSize * mInputSize * 3 * numBytesPerChannel);
        mImgData.order(ByteOrder.nativeOrder());
        mIntValues = new int[mInputSize * mInputSize];
        Tensor oupten = mTfLite.getOutputTensor(0);
        mOutputShape = oupten.shape();
        if (mIsModelQuantized) {
            mInpScale = inpten.quantizationParams().getScale();
            mInpZeroPoint = inpten.quantizationParams().getZeroPoint();
            mOupScale = oupten.quantizationParams().getScale();
            mOupZeroPoint = oupten.quantizationParams().getZeroPoint();
            Log.i(TAG, "WL_DEBUG YoloClassifier mInpScale = " + mInpScale + ", mInpZeroPoint = " + mInpZeroPoint + ", mOupScale = " + mOupScale + ", mOupZeroPoint = " + mOupZeroPoint + ", mInputSize = " + mInputSize);
        }
        // assumes output batch dim (mOutputShape[0]) is 1 — capacity ignores it.
        int outDataCapacity = mOutputShape[1] * mOutputShape[2] * numBytesPerChannel;
        Log.i(TAG, "WL_DEBUG create mOutputShape = " + Arrays.toString(mOutputShape) + ", outDataCapacity = " + outDataCapacity);
        mOutData = ByteBuffer.allocateDirect(outDataCapacity);
        mOutData.order(ByteOrder.nativeOrder());
    }

    public int getInputSize() {
        return mInputSize;
    }

    @Override
    public void close() {
        // Null-safe so a double close() cannot NPE.
        if (mTfLite != null) {
            mTfLite.close();
            mTfLite = null;
        }
        /*
         * if (gpuDelegate != null) { gpuDelegate.close(); gpuDelegate = null; }
         */
        if (mNnapiDelegate != null) {
            mNnapiDelegate.close();
            mNnapiDelegate = null;
        }
        mTfliteModel = null;
    }

    public void setNumThreads(int num_threads) {
        if (mTfLite != null)
            mTfLite.setNumThreads(num_threads);
    }

    // Rebuilds the interpreter so newly added delegates/options take effect.
    private void recreateInterpreter() {
        if (mTfLite != null) {
            mTfLite.close();
            mTfLite = new Interpreter(mTfliteModel, mTfliteOptions);
        }
    }

    public void useGpu() {
        /*
         * if (gpuDelegate == null) { gpuDelegate = new GpuDelegate();
         * tfliteOptions.addDelegate(gpuDelegate); recreateInterpreter(); }
         */
    }

    public void useCPU() {
        recreateInterpreter();
    }

    public void useNNAPI() {
        // FIX: guard like the (commented) GPU path — the old code leaked a fresh
        // NnApiDelegate and re-added it to the options on every call.
        if (mNnapiDelegate == null) {
            mNnapiDelegate = new NnApiDelegate();
            mTfliteOptions.addDelegate(mNnapiDelegate);
            recreateInterpreter();
        }
    }

    @Override
    public float getObjThresh() {
        return Utils.MINIMUM_CONFIDENCE_TF_OD_API;
    }

    // non maximum suppression: per class, greedily keep the highest-confidence
    // box and drop any remaining box whose IoU with it exceeds NMS_THRESH.
    private ArrayList<Recognition> nms(ArrayList<Recognition> list) {
        ArrayList<Recognition> nmsList = new ArrayList<>();
        for (int k = 0; k < mLabels.size(); k++) {
            // 1.find max confidence per class
            PriorityQueue<Recognition> pq = new PriorityQueue<>(50, (lhs, rhs) -> {
                // Intentionally reversed to put high confidence at the head of the queue.
                return Float.compare(rhs.getConfidence(), lhs.getConfidence());
            });
            for (int i = 0; i < list.size(); ++i) {
                if (list.get(i).getDetectedClass() == k) {
                    pq.add(list.get(i));
                }
            }
            // 2.do non maximum suppression
            while (!pq.isEmpty()) {
                // insert detection with max confidence
                Recognition[] a = new Recognition[pq.size()];
                Recognition[] detections = pq.toArray(a);
                Recognition max = detections[0];
                nmsList.add(max);
                pq.clear();
                for (int j = 1; j < detections.length; j++) {
                    Recognition detection = detections[j];
                    RectF b = detection.getLocation();
                    if (box_iou(max.getLocation(), b) < NMS_THRESH) {
                        pq.add(detection);
                    }
                }
            }
        }
        return nmsList;
    }

    // Intersection-over-union of two boxes in (left, top, right, bottom) form.
    private float box_iou(RectF a, RectF b) {
        return box_intersection(a, b) / box_union(a, b);
    }

    private float box_intersection(RectF a, RectF b) {
        // overlap() works on (center, extent) pairs along each axis.
        float w = overlap((a.left + a.right) / 2, a.right - a.left, (b.left + b.right) / 2, b.right - b.left);
        float h = overlap((a.top + a.bottom) / 2, a.bottom - a.top, (b.top + b.bottom) / 2, b.bottom - b.top);
        if (w < 0 || h < 0)
            return 0;
        return w * h;
    }

    private float box_union(RectF a, RectF b) {
        float i = box_intersection(a, b);
        return (a.right - a.left) * (a.bottom - a.top) + (b.right - b.left) * (b.bottom - b.top) - i;
    }

    // 1-D overlap of two segments given as (center, width); negative when disjoint.
    private float overlap(float x1, float w1, float x2, float w2) {
        float l1 = x1 - w1 / 2;
        float l2 = x2 - w2 / 2;
        float left = Math.max(l1, l2);
        float r1 = x1 + w1 / 2;
        float r2 = x2 + w2 / 2;
        float right = Math.min(r1, r2);
        return right - left;
    }

    /**
     * Writes Image data into a {@code ByteBuffer}.
     */
    private void convertBitmapToByteBuffer(Bitmap bitmap) {
        bitmap.getPixels(mIntValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
        mImgData.rewind();
        for (int i = 0; i < mInputSize; ++i) {
            for (int j = 0; j < mInputSize; ++j) {
                int pixelValue = mIntValues[i * mInputSize + j];
                // Normalization constants: pixels scaled from [0,255] to [0,1].
                float IMAGE_MEAN = 0;
                float IMAGE_STD = 255.0f;
                if (mIsModelQuantized) {
                    // Quantized model: normalize, then re-quantize with input scale/zero-point.
                    mImgData.put((byte) ((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD / mInpScale
                            + mInpZeroPoint));
                    mImgData.put((byte) ((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD / mInpScale
                            + mInpZeroPoint));
                    mImgData.put((byte) (((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD / mInpScale + mInpZeroPoint));
                } else { // Float model
                    mImgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
                    mImgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
                    mImgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
                }
            }
        }
    }

    /**
     * Runs the interpreter on the bitmap, lets the subclass decode the raw output,
     * then applies NMS. Returns an empty list if inference itself fails.
     */
    public ArrayList<Recognition> recognizeImage(Bitmap bitmap) {
        convertBitmapToByteBuffer(bitmap);
        Map<Integer, Object> outputMap = new HashMap<>();
        mOutData.rewind();
        outputMap.put(0, mOutData);
        Log.d(TAG, "mObjThresh: " + getObjThresh() + ", mOutData.capacity() = " + mOutData.capacity());
        Object[] inputArray = { mImgData };
        try {
            mTfLite.runForMultipleInputsOutputs(inputArray, outputMap);
        } catch (Exception e) {
            Log.e(TAG, "WL_DEBUG recognizeImage e = " + e, e);
            return new ArrayList<>();
        }
        mOutData.rewind();
        ArrayList<Recognition> detections = new ArrayList<>();
        int outputShape1 = mOutputShape[1];
        int outputShape2 = mOutputShape[2];
        float[][][] out = new float[1][outputShape1][outputShape2];
        Log.d(TAG, "out[0] detect start");
        buildDetections(outputShape1, outputShape2, out, bitmap, detections);
        Log.d(TAG, "detect end");
        return nms(detections);
    }

    /**
     * Decodes the raw output tensor in {@code mOutData} into {@code detections};
     * implemented per model family (v5 vs v8 output layouts differ).
     */
    protected abstract void buildDetections(int outputShape1, int outputShape2, float[][][] out, Bitmap bitmap, ArrayList<Recognition> detections);
}
package com.agenew.detection.tflite;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.RectF;
import android.util.Log;
import java.io.IOException;
import java.util.ArrayList;
/**
 * YOLOv8 output decoder. The raw tensor layout is [1, 4 + numClasses, numBoxes]:
 * rows 0-3 are normalized cx, cy, w, h; the remaining rows are per-class scores.
 */
public class YoloV8Classifier extends YoloClassifier {
    private static final String TAG = "YoloV8Classifier";

    public YoloV8Classifier(final AssetManager assetManager, final String modelFilename) throws IOException {
        super(assetManager, modelFilename);
        // Boxes run along the last output dimension for v8 (unlike v5).
        mOutputBox = mOutputShape[2];
    }

    @Override
    protected void buildDetections(int outputShape1, int outputShape2, float[][][] out, Bitmap bitmap, ArrayList<Recognition> detections) {
        // Pass 1: copy/dequantize the raw tensor into out[0][channel][box].
        for (int i = 0; i < outputShape1; ++i) {
            for (int j = 0; j < outputShape2; ++j) {
                if (mIsModelQuantized) {
                    // Affine dequantization: real = scale * (q - zeroPoint).
                    // BUG FIX: the old code masked (q - zeroPoint) with 0xFF, which
                    // corrupted every value where the quantized byte < zero point
                    // (negative differences wrapped to large positives).
                    out[0][i][j] = mOupScale * (((int) mOutData.get() & 0xFF) - mOupZeroPoint);
                } else {
                    out[0][i][j] = mOutData.getFloat();
                }
                // Rows 0-3 are normalized box coordinates — scale to input pixels.
                if (i < 4) {
                    out[0][i][j] *= getInputSize();
                }
            }
        }
        // Pass 2: per box, pick the best class and keep it if above threshold.
        for (int i = 0; i < mOutputBox; ++i) {
            final int offset = 0;
            int detectedClass = -1;
            float maxClass = 0;
            final float[] classes = new float[mLabels.size()];
            // Guard against a label file longer than the model's class rows.
            for (int c = 0; c < mLabels.size() && c < (out[0].length - 4); ++c) {
                classes[c] = out[0][4 + c][i];
            }
            for (int c = 0; c < mLabels.size(); ++c) {
                if (classes[c] > maxClass) {
                    detectedClass = c;
                    maxClass = classes[c];
                }
            }
            final float confidenceInClass = maxClass;
            if (confidenceInClass > getObjThresh()) {
                final float xPos = out[0][0][i];
                final float yPos = out[0][1][i];
                final float w = out[0][2][i];
                final float h = out[0][3][i];
                Log.d(TAG, Float.toString(xPos) + ',' + yPos + ',' + w + ',' + h);
                // Convert center/size to a corner rect, clamped to the bitmap bounds.
                final RectF rect = new RectF(Math.max(0, xPos - w / 2), Math.max(0, yPos - h / 2),
                        Math.min(bitmap.getWidth() - 1, xPos + w / 2), Math.min(bitmap.getHeight() - 1, yPos + h / 2));
                detections.add(new Recognition("" + offset, mLabels.get(detectedClass), confidenceInClass, rect,
                        detectedClass));
            }
        }
    }
}
......@@ -13,18 +13,19 @@
See the License for the specific language governing permissions and
limitations under the License.
-->
<FrameLayout xmlns:android="http://schemas.android.com/apk/res/android"
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
android:layout_width="match_parent"
android:layout_height="match_parent">
<com.agenew.detection.customview.AutoFitTextureView
android:id="@+id/texture"
android:layout_width="wrap_content"
android:layout_height="wrap_content" />
android:layout_height="wrap_content"
android:layout_centerInParent="true"/>
<com.agenew.detection.customview.OverlayView
android:id="@+id/tracking_overlay"
android:layout_width="match_parent"
android:layout_height="match_parent" />
android:layout_height="480dp"
android:layout_centerInParent="true"/>
</FrameLayout>
</RelativeLayout>
[versions]
agp = "8.3.2"
agp = "8.4.0"
junit = "4.13.2"
junitVersion = "1.1.5"
espressoCore = "3.5.1"
......
#Wed Apr 10 09:00:02 CST 2024
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-8.4-bin.zip
distributionUrl=https\://services.gradle.org/distributions/gradle-8.6-bin.zip
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!