Commit 23a0f7ce by wanglei

init

0 parents
*.iml
.gradle
/local.properties
/.idea/caches
/.idea/libraries
/.idea/modules.xml
/.idea/workspace.xml
/.idea/navEditor.xml
/.idea/assetWizardSettings.xml
.DS_Store
/build
/captures
.externalNativeBuild
.cxx
local.properties
# TensorFlow Keras训练Mnist示例
## 模型的训练
运行python文件夹下的digit_classifier.py,生成tflite模型,得到mnist.tflite
## 使用模型
将训练生成的模型文件mnist.tflite拷贝到assets文件夹,供android读取
/build
\ No newline at end of file
plugins {
id 'com.android.application'
}
android {
namespace 'com.agenew.mnist'
compileSdk 34
defaultConfig {
applicationId "com.agenew.mnist"
minSdk 24
targetSdk 34
versionCode 1
versionName "1.0"
}
buildTypes {
release {
minifyEnabled false
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
}
}
compileOptions {
sourceCompatibility JavaVersion.VERSION_1_8
targetCompatibility JavaVersion.VERSION_1_8
}
}
dependencies {
implementation fileTree(include: ['*.jar', '*.aar'], dir: 'libs')
}
\ No newline at end of file
No preview for this file type
# Add project specific ProGuard rules here.
# You can control the set of applied configuration files using the
# proguardFiles setting in build.gradle.
#
# For more details, see
# http://developer.android.com/guide/developing/tools/proguard.html
# If your project uses WebView with JS, uncomment the following
# and specify the fully qualified class name to the JavaScript interface
# class:
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
# public *;
#}
# Uncomment this to preserve the line number information for
# debugging stack traces.
#-keepattributes SourceFile,LineNumberTable
# If you keep the line number information, uncomment this to
# hide the original source file name.
#-renamesourcefileattribute SourceFile
\ No newline at end of file
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools">
<application
android:allowBackup="true"
android:icon="@drawable/ic_launcher"
android:label="@string/app_name" >
<activity
android:name=".MainActivity"
android:exported="true" >
<intent-filter>
<action android:name="android.intent.action.MAIN" />
<category android:name="android.intent.category.LAUNCHER" />
</intent-filter>
</activity>
</application>
</manifest>
\ No newline at end of file
No preview for this file type
package com.agenew.mnist;
import android.annotation.SuppressLint;
import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Paint;
import android.graphics.Path;
import android.graphics.drawable.Drawable;
import android.util.AttributeSet;
import android.view.MotionEvent;
import android.view.View;
public class FingerPaintView extends View {
private Path mPath;
private Bitmap mDrawingBitmap;
private Canvas mDrawingCanvas;
private Paint drawingPaint;
private float mPenX = 0.0f;
private float mPenY = 0.0f;
private Paint mPaint;
private boolean mIsEmpty = true;
public FingerPaintView(Context context) {
super(context);
init();
}
public FingerPaintView(Context context, AttributeSet attrs) {
super(context, attrs);
init();
}
public FingerPaintView(Context context, AttributeSet attrs, int defStyleAttr) {
super(context, attrs, defStyleAttr);
init();
}
private void init() {
drawingPaint = new Paint(Paint.DITHER_FLAG);
mPath = new Path();
mPaint = new Paint();
mPaint.setAntiAlias(true);
mPaint.setDither(true);
mPaint.setColor(Color.BLACK);
mPaint.setStyle(Paint.Style.STROKE);
mPaint.setStrokeCap(Paint.Cap.ROUND);
mPaint.setStrokeJoin(Paint.Join.ROUND);
mPaint.setStrokeWidth(36f);
}
@Override
protected void onSizeChanged(int w, int h, int oldw, int oldh) {
super.onSizeChanged(w, h, oldw, oldh);
mDrawingBitmap = Bitmap.createBitmap(w, h, Bitmap.Config.ARGB_8888);
mPaint.setStrokeWidth(w * 2 / 28f);// mPaint.setStrokeWidth(w*3/28f);
mDrawingCanvas = new Canvas(mDrawingBitmap);
}
@Override
protected void onDraw(Canvas canvas) {
super.onDraw(canvas);
if (canvas != null) {
canvas.drawBitmap(mDrawingBitmap, 0f, 0f, drawingPaint);
canvas.drawPath(mPath, mPaint);
}
}
@SuppressLint("ClickableViewAccessibility")
@Override
public boolean onTouchEvent(MotionEvent event) {
if (event == null)
return false;
mIsEmpty = false;
float x = event.getX();
float y = event.getY();
switch (event.getAction()) {
case MotionEvent.ACTION_DOWN:
mPath.reset();
mPath.moveTo(x, y);
mPenX = x;
mPenY = y;
invalidate();
break;
case MotionEvent.ACTION_MOVE:
float dx = Math.abs(x - mPenX);
float dy = Math.abs(y - mPenY);
float touchTolerance = 4f;
if (dx >= touchTolerance || dy >= touchTolerance) {
mPath.quadTo(mPenX, mPenY, (x + mPenX) / 2, (y + mPenY) / 2);
mPenX = x;
mPenY = y;
}
invalidate();
break;
case MotionEvent.ACTION_UP:
mPath.lineTo(mPenX, mPenY);
mDrawingCanvas.drawPath(mPath, mPaint);
mPath.reset();
performClick();
invalidate();
break;
}
super.onTouchEvent(event);
return true;
}
public void clear() {
mPath.reset();
mDrawingBitmap = Bitmap.createBitmap(mDrawingBitmap.getWidth(), mDrawingBitmap.getHeight(),
Bitmap.Config.ARGB_8888);
mDrawingCanvas = new Canvas(mDrawingBitmap);
mIsEmpty = true;
invalidate();
}
public boolean isEmpty() {
return mIsEmpty;
}
public Bitmap exportToBitmap(int width, int height) {
Bitmap rawBitmap = Bitmap.createBitmap(getWidth(), getHeight(), Bitmap.Config.ARGB_8888);
Canvas canvas = new Canvas(rawBitmap);
Drawable bgDrawable = getBackground();
if (bgDrawable != null) {
bgDrawable.draw(canvas);
} else {
canvas.drawColor(Color.WHITE);
}
draw(canvas);
Bitmap scaledBitmap = Bitmap.createScaledBitmap(rawBitmap, width, height, true);
rawBitmap.recycle();
return scaledBitmap;
}
}
package com.agenew.mnist;
import android.content.Context;
import android.graphics.Bitmap;
import java.io.IOException;
import com.mediatek.neuropilot_S.Interpreter;
public class KerasTFLite {
private Interpreter mInterpreter;
public KerasTFLite(Context context) throws IOException {
mInterpreter = new Interpreter(Utils.loadModelFile(context));
}
public int run(Bitmap bitmap) {
float pixels[] = getPixelData(bitmap);
// should be same format with train
for (int i = 0; i < pixels.length; i++) {
pixels[i] = pixels[i] / 255;
}
float[][] labelProbArray = new float[1][10];
mInterpreter.run(new float[][] { pixels }, labelProbArray);
return getMax(labelProbArray[0]);
}
public void release() {
mInterpreter.close();
}
private int getMax(float[] results) {
int maxID = 0;
float maxValue = results[maxID];
for (int i = 1; i < results.length; i++) {
if (results[i] > maxValue) {
maxID = i;
maxValue = results[maxID];
}
}
return maxID;
}
/**
* Get 28x28 pixel data for tensorflow input.
*/
private float[] getPixelData(Bitmap bitmap) {
if (bitmap == null) {
return null;
}
int width = bitmap.getWidth();
int height = bitmap.getHeight();
// Get 28x28 pixel data from bitmap
int[] pixels = new int[width * height];
bitmap.getPixels(pixels, 0, width, 0, 0, width, height);
float[] retPixels = new float[pixels.length];
for (int i = 0; i < pixels.length; ++i) {
// Set 0 for white and 255 for black pixel
int pix = pixels[i];
int b = pix & 0xff;
retPixels[i] = 0xff - b;
}
return retPixels;
}
}
package com.agenew.mnist;
import android.app.Activity;
import android.graphics.Bitmap;
import android.os.Bundle;
import android.view.View;
import android.widget.TextView;
import android.widget.Toast;
import java.io.IOException;
public class MainActivity extends Activity implements View.OnClickListener {
private KerasTFLite mTFLite;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
findViewById(R.id.buttonDetect).setOnClickListener(this);
findViewById(R.id.buttonClear).setOnClickListener(this);
}
@Override
protected void onResume() {
super.onResume();
if (mTFLite == null) {
try {
mTFLite = new KerasTFLite(this);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
findViewById(R.id.buttonDetect).setVisibility(View.VISIBLE);
}
@Override
protected void onPause() {
if (mTFLite != null) {
mTFLite.release();
mTFLite = null;
}
super.onPause();
}
@Override
public void onClick(View v) {
final int id = v.getId();
if (id == R.id.buttonDetect) {
onDetectClicked();
} else if (id == R.id.buttonClear) {
onClearClicked();
}
}
private void onDetectClicked() {
FingerPaintView fingerPaintView = findViewById(R.id.finger_paint_view);
if (fingerPaintView.isEmpty()) {
Toast.makeText(this, R.string.toast, Toast.LENGTH_SHORT).show();
return;
}
Bitmap bitmap = fingerPaintView.exportToBitmap(Utils.PIXEL_SIZE, Utils.PIXEL_SIZE);
int result = mTFLite.run(bitmap);
String value = getString(R.string.text_pref) + result;
TextView resultText = findViewById(R.id.textResult);
resultText.setText(value);
}
private void onClearClicked() {
FingerPaintView fingerPaintView = findViewById(R.id.finger_paint_view);
fingerPaintView.clear();
TextView resultText = findViewById(R.id.textResult);
resultText.setText("");
}
}
package com.agenew.mnist;
import android.content.Context;
import android.content.res.AssetFileDescriptor;
import android.content.res.AssetManager;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
public class Utils {
public static final int PIXEL_SIZE = 28;
private static final String MODEL_FILE = "mnist.tflite";
public static MappedByteBuffer loadModelFile(Context context) throws IOException {
AssetManager assets = context.getAssets();
AssetFileDescriptor fileDescriptor = assets.openFd(MODEL_FILE);
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
long declaredLength = fileDescriptor.getDeclaredLength();
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
}
<?xml version="1.0" encoding="utf-8"?>
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:orientation="vertical"
android:padding="16dp"
tools:context=".MainActivity" >
<com.agenew.mnist.FingerPaintView
android:id="@+id/finger_paint_view"
android:layout_width="200dp"
android:layout_height="200dp"
android:layout_gravity="center"
android:layout_margin="16dp"
android:background="#dddddd" />
<TextView
android:id="@+id/textResult"
android:layout_width="match_parent"
android:layout_height="wrap_content" />
<LinearLayout
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:gravity="center_horizontal"
android:orientation="horizontal" >
<Button
android:id="@+id/buttonClear"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="@string/clear" />
<Button
android:id="@+id/buttonDetect"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_marginLeft="8dp"
android:text="@string/detect"
android:visibility="invisible" />
</LinearLayout>
</LinearLayout>
\ No newline at end of file
<?xml version="1.0" encoding="utf-8"?>
<resources>
<string name="app_name">mnist</string>
<string name="toast">请写上一个数字</string>
<string name="text_pref">数字是: </string>
<string name="clear">清除</string>
<string name="detect">识别</string>
</resources>
\ No newline at end of file
// Top-level build file where you can add configuration options common to all sub-projects/modules.
plugins {
id 'com.android.application' version '8.3.2' apply false
}
\ No newline at end of file
# Project-wide Gradle settings.
# IDE (e.g. Android Studio) users:
# Gradle settings configured through the IDE *will override*
# any settings specified in this file.
# For more details on how to configure your build environment visit
# http://www.gradle.org/docs/current/userguide/build_environment.html
# Specifies the JVM arguments used for the daemon process.
# The setting is particularly useful for tweaking memory settings.
org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8
# When configured, Gradle will run in incubating parallel mode.
# This option should only be used with decoupled projects. More details, visit
# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects
# org.gradle.parallel=true
# AndroidX package structure to make it clearer which packages are bundled with the
# Android operating system, and which are packaged with your app's APK
# https://developer.android.com/topic/libraries/support-library/androidx-rn
android.useAndroidX=true
# Enables namespacing of each library's R class so that its R class includes only the
# resources declared in the library itself and none from the library's dependencies,
# thereby reducing the size of the R class for that library
android.nonTransitiveRClass=true
\ No newline at end of file
No preview for this file type
#Mon Jan 08 16:37:10 CST 2024
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-8.4-bin.zip
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists
#!/usr/bin/env sh
#
# Copyright 2015 the original author or authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
##############################################################################
##
## Gradle start up script for UN*X
##
##############################################################################
# Attempt to set APP_HOME
# Resolve links: $0 may be a link
PRG="$0"
# Need this for relative symlinks.
while [ -h "$PRG" ] ; do
ls=`ls -ld "$PRG"`
link=`expr "$ls" : '.*-> \(.*\)$'`
if expr "$link" : '/.*' > /dev/null; then
PRG="$link"
else
PRG=`dirname "$PRG"`"/$link"
fi
done
SAVED="`pwd`"
cd "`dirname \"$PRG\"`/" >/dev/null
APP_HOME="`pwd -P`"
cd "$SAVED" >/dev/null
APP_NAME="Gradle"
APP_BASE_NAME=`basename "$0"`
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
# Use the maximum available, or set MAX_FD != -1 to use that value.
MAX_FD="maximum"
warn () {
echo "$*"
}
die () {
echo
echo "$*"
echo
exit 1
}
# OS specific support (must be 'true' or 'false').
cygwin=false
msys=false
darwin=false
nonstop=false
case "`uname`" in
CYGWIN* )
cygwin=true
;;
Darwin* )
darwin=true
;;
MINGW* )
msys=true
;;
NONSTOP* )
nonstop=true
;;
esac
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
# Determine the Java command to use to start the JVM.
if [ -n "$JAVA_HOME" ] ; then
if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
# IBM's JDK on AIX uses strange locations for the executables
JAVACMD="$JAVA_HOME/jre/sh/java"
else
JAVACMD="$JAVA_HOME/bin/java"
fi
if [ ! -x "$JAVACMD" ] ; then
die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
else
JAVACMD="java"
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
# Increase the maximum file descriptors if we can.
if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then
MAX_FD_LIMIT=`ulimit -H -n`
if [ $? -eq 0 ] ; then
if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
MAX_FD="$MAX_FD_LIMIT"
fi
ulimit -n $MAX_FD
if [ $? -ne 0 ] ; then
warn "Could not set maximum file descriptor limit: $MAX_FD"
fi
else
warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
fi
fi
# For Darwin, add options to specify how the application appears in the dock
if $darwin; then
GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
fi
# For Cygwin or MSYS, switch paths to Windows format before running java
if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then
APP_HOME=`cygpath --path --mixed "$APP_HOME"`
CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
JAVACMD=`cygpath --unix "$JAVACMD"`
# We build the pattern for arguments to be converted via cygpath
ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
SEP=""
for dir in $ROOTDIRSRAW ; do
ROOTDIRS="$ROOTDIRS$SEP$dir"
SEP="|"
done
OURCYGPATTERN="(^($ROOTDIRS))"
# Add a user-defined pattern to the cygpath arguments
if [ "$GRADLE_CYGPATTERN" != "" ] ; then
OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
fi
# Now convert the arguments - kludge to limit ourselves to /bin/sh
i=0
for arg in "$@" ; do
CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option
if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
else
eval `echo args$i`="\"$arg\""
fi
i=`expr $i + 1`
done
case $i in
0) set -- ;;
1) set -- "$args0" ;;
2) set -- "$args0" "$args1" ;;
3) set -- "$args0" "$args1" "$args2" ;;
4) set -- "$args0" "$args1" "$args2" "$args3" ;;
5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
esac
fi
# Escape application args
save () {
for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
echo " "
}
APP_ARGS=`save "$@"`
# Collect all arguments for the java command, following the shell quoting and substitution rules
eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
exec "$JAVACMD" "$@"
@rem
@rem Copyright 2015 the original author or authors.
@rem
@rem Licensed under the Apache License, Version 2.0 (the "License");
@rem you may not use this file except in compliance with the License.
@rem You may obtain a copy of the License at
@rem
@rem https://www.apache.org/licenses/LICENSE-2.0
@rem
@rem Unless required by applicable law or agreed to in writing, software
@rem distributed under the License is distributed on an "AS IS" BASIS,
@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@rem See the License for the specific language governing permissions and
@rem limitations under the License.
@rem
@if "%DEBUG%" == "" @echo off
@rem ##########################################################################
@rem
@rem Gradle startup script for Windows
@rem
@rem ##########################################################################
@rem Set local scope for the variables with windows NT shell
if "%OS%"=="Windows_NT" setlocal
set DIRNAME=%~dp0
if "%DIRNAME%" == "" set DIRNAME=.
set APP_BASE_NAME=%~n0
set APP_HOME=%DIRNAME%
@rem Resolve any "." and ".." in APP_HOME to make it shorter.
for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi
@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m"
@rem Find java.exe
if defined JAVA_HOME goto findJavaFromJavaHome
set JAVA_EXE=java.exe
%JAVA_EXE% -version >NUL 2>&1
if "%ERRORLEVEL%" == "0" goto execute
echo.
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.
goto fail
:findJavaFromJavaHome
set JAVA_HOME=%JAVA_HOME:"=%
set JAVA_EXE=%JAVA_HOME%/bin/java.exe
if exist "%JAVA_EXE%" goto execute
echo.
echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.
goto fail
:execute
@rem Setup the command line
set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
@rem Execute Gradle
"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %*
:end
@rem End local scope for the variables with windows NT shell
if "%ERRORLEVEL%"=="0" goto mainEnd
:fail
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
rem the _cmd.exe /c_ return code!
if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
exit /b 1
:mainEnd
if "%OS%"=="Windows_NT" endlocal
:omega
# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras
# Helper libraries
import numpy as np
import matplotlib.pyplot as plt
import math
print(tf.__version__)
# Helper function to display digit images
def show_sample(images, labels, sample_count=25):
# Create a square with can fit {sample_count} images
grid_count = math.ceil(math.ceil(math.sqrt(sample_count)))
grid_count = min(grid_count, len(images), len(labels))
plt.figure(figsize=(2*grid_count, 2*grid_count))
for i in range(sample_count):
plt.subplot(grid_count, grid_count, i+1)
plt.xticks([])
plt.yticks([])
plt.grid(False)
plt.imshow(images[i], cmap=plt.cm.gray)
plt.xlabel(labels[i])
plt.show()
# Download MNIST dataset.
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# If you can't download the MNIST dataset from Keras, please try again with an alternative method below
# path = keras.utils.get_file('mnist.npz',
# origin='https://s3.amazonaws.com/img-datasets/mnist.npz',
# file_hash='8a61469f7ea1b51cbae51d4f78837e45')
# with np.load(path, allow_pickle=True) as f:
# train_images, train_labels = f['x_train'], f['y_train']
# test_images, test_labels = f['x_test'], f['y_test']
# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images = test_images / 255.0
# Show the first 25 images in the training dataset.
show_sample(train_images,
['Label: %s' % label for label in train_labels])
# Define the model architecture
model = keras.Sequential([
# keras.layers.Flatten(input_shape=(28, 28)),
# keras.layers.Dense(128, activation=tf.nn.relu),
# Optional: You can replace the dense layer above with the convolution layers below to get higher accuracy.
keras.layers.Reshape(target_shape=(28, 28, 1)),
keras.layers.Conv2D(filters=32, kernel_size=(3, 3), activation=tf.nn.relu),
keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation=tf.nn.relu),
keras.layers.MaxPooling2D(pool_size=(2, 2)),
keras.layers.Dropout(0.25),
keras.layers.Flatten(input_shape=(28, 28)),
keras.layers.Dense(128, activation=tf.nn.relu),
keras.layers.Dropout(0.5),
keras.layers.Dense(10)
])
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
# Train the digit classification model
model.fit(train_images, train_labels, epochs=5)
# Evaluate the model using test dataset.
test_loss, test_acc = model.evaluate(test_images, test_labels)
print('Test accuracy:', test_acc)
# Predict the labels of digit images in our test dataset.
predictions = model.predict(test_images)
# Then plot the first 25 test images and their predicted labels.
show_sample(test_images,
['Predicted: %d' % np.argmax(result) for result in predictions])
# Convert Keras model to TF Lite format.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
# Save the TF Lite model as file
f = open('mnist.tflite', "wb")
f.write(tflite_model)
f.close()
# Download the digit classification model if you're using Colab,
# or print the model's local path if you're not using Colab.
try:
from google.colab import files
files.download('mnist.tflite')
except ImportError:
import os
print('TF Lite model:', os.path.join(os.getcwd(), 'mnist.tflite'))
# Download a test image
zero_img_path = keras.utils.get_file(
'zero.png',
'https://storage.googleapis.com/khanhlvg-public.appspot.com/digit-classifier/zero.png'
)
image = keras.preprocessing.image.load_img(
zero_img_path,
color_mode = 'grayscale',
target_size=(28, 28),
interpolation='bilinear'
)
# Pre-process the image: Adding batch dimension and normalize the pixel value to [0..1]
# In training, we feed images in a batch to the model to improve training speed, making the model input shape to be (BATCH_SIZE, 28, 28).
# For inference, we still need to match the input shape with training, so we expand the input dimensions to (1, 28, 28) using np.expand_dims
input_image = np.expand_dims(np.array(image, dtype=np.float32) / 255.0, 0)
# Show the pre-processed input image
show_sample(input_image, ['Input Image'], 1)
# Run inference with TensorFlow Lite
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
interpreter.set_tensor(interpreter.get_input_details()[0]["index"], input_image)
interpreter.invoke()
output = interpreter.tensor(interpreter.get_output_details()[0]["index"])()[0]
# Print the model's classification result
digit = np.argmax(output)
print('Predicted Digit: %d\nConfidence: %f' % (digit, output[digit]))
pluginManagement {
repositories {
google()
mavenCentral()
gradlePluginPortal()
}
}
dependencyResolutionManagement {
repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
repositories {
google()
mavenCentral()
}
}
rootProject.name = "tensorflow-lite-keras-mnist-android"
include ':app'
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!