[Tensorflow & Android] .h5 -> .pb -> .tflite로 변환 & 안드로이드에 적용(python & java)
from tensorflow import keras
model = keras.models.load_model('./model/my_model.h5', compile=False)
export_path = './pb'
model.save(export_path, save_format="tf")
model = keras.models.load_model('#######', compile=False)
→ ######에는 .h5파일이 있는곳의 directory를 쓴다.
export_path = '@@@'
→ @@@에는 변환된 .pb파일이 있게될 directory를 쓴다. (.pb파일을 어디에 저장할것인가)
import tensorflow as tf
saved_model_dir = './pb'
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()
open('./tf/converted_model.tflite', 'wb').write(tflite_model)
saved_model_dir = '######'
→ ######에는 .pb파일이 있는곳의 directory를 쓴다.
open('@@@/converted_model.tflite', 'wb').write(tflite_model)
→ @@@에는 변환된 .tflite파일이 있게될 directory를 쓴다. (.tflite파일을 어디에 저장할것인가)
ps. @@@/~~~~.tflite (~~~~ : 자기가 원하는 파일제목 쓰기)
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
aaptOptions {
noCompress "tflite"
}
→.tflite의 압축을 막기 위한 코드 (밑에 사진과 같이 android{~}부분에 올바르게 코드를 작성해야한다.)
implementation 'org.tensorflow:tensorflow-lite:+'
→ 위의 코드도 동일하게 dependencies{~~~} 부분에 올바르게 작성해야한다.
→ 위와 같은 경로로 들어가서 assets라는 Directory를 생성해준다.
→ assets Directory가 생성됨을 확인
→ assets Directory에 .tflite를 붙여넣기 한다.
① MainActivity 내부에 tflite관련 코드 (작성 필수)
private Interpreter getTfliteInterpreter(String modelPath) {
try {
return new Interpreter(loadModelFile(MainActivity.this, modelPath));
}
catch (Exception e) {
e.printStackTrace();
}
return null;
}
public MappedByteBuffer loadModelFile(Activity activity, String modelPath) throws IOException {
AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(modelPath);
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
long declaredLength = fileDescriptor.getDeclaredLength();
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
② OnCreate함수 코드(주석 보면서, 자신의 파일 이름과 맞게 수정하기)
private static final int FROM_ALBUM = 1; // onActivityResult 식별자
private static final int FROM_CAMERA = 2; // 카메라는 사용 안함
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
// 인텐트의 결과는 onActivityResult 함수에서 수신.
// 여러 개의 인텐트를 동시에 사용하기 때문에 숫자를 통해 결과 식별(FROM_ALBUM 등등)
findViewById(R.id.button_1).setOnClickListener(new View.OnClickListener() {
@Override
public void onClick(View view) {
Intent intent = new Intent();
intent.setType("image/*"); // 이미지만
intent.setAction(Intent.ACTION_GET_CONTENT); // 카메라(ACTION_IMAGE_CAPTURE)
startActivityForResult(intent, FROM_ALBUM);
}
});
}
// 사진첩에서 사진 파일 불러와 버섯종류분류하는 코드
@Override
protected void onActivityResult(int requestCode, int resultCode, Intent data) {
// 카메라를 다루지 않기 때문에 앨범 상수에 대해서 성공한 경우에 대해서만 처리
super.onActivityResult(requestCode, resultCode, data);
if (requestCode != FROM_ALBUM || resultCode != RESULT_OK)
return;
//각 모델에 따른 input , output shape 각자 맞게 변환
// mobilenetcheck.h5 일시 224 * 224 * 3
float[][][][] input = new float[1][224][224][3];
float[][] output = new float[1][5]; //tflite에 버섯 종류 5개라서 (내기준)
try {
int batchNum = 0;
InputStream buf = getContentResolver().openInputStream(data.getData());
Bitmap bitmap = BitmapFactory.decodeStream(buf);
buf.close();
//이미지 뷰에 선택한 사진 띄우기
ImageView iv = findViewById(R.id.image);
iv.setScaleType(ImageView.ScaleType.FIT_XY);
iv.setImageBitmap(bitmap);
// x,y 최댓값 사진 크기에 따라 달라짐 (조절 해줘야함)
for (int x = 0; x < 224; x++) {
for (int y = 0; y < 224; y++) {
int pixel = bitmap.getPixel(x, y);
input[batchNum][x][y][0] = Color.red(pixel) / 1.0f;
input[batchNum][x][y][1] = Color.green(pixel) / 1.0f;
input[batchNum][x][y][2] = Color.blue(pixel) / 1.0f;
}
}
// 자신의 tflite 이름 써주기
Interpreter lite = getTfliteInterpreter("converted_model_mobileNetCheck.tflite");
lite.run(input, output);
} catch (IOException e) {
e.printStackTrace();
}
TextView tv_output = findViewById(R.id.tv_output);
int i;
// 텍스트뷰에 무슨 버섯인지 띄우기 but error남 ㅜㅜ 붉은 사슴뿔만 주구장창
for (i = 0; i < 5; i++) {
if (output[0][i] * 100 > 90) {
if (i == 0) {
tv_output.setText(String.format("개나리 광대버섯 %d %.5f", i, output[0][0] * 100));
} else if (i == 1) {
tv_output.setText(String.format("붉은사슴뿔버섯,%d %.5f", i, output[0][1] * 100));
} else if (i == 2) {
tv_output.setText(String.format("새송이버섯,%d, %.5f", i, output[0][2] * 100));
} else if (i == 3) {
tv_output.setText(String.format("표고버섯, %d, %.5f", i, output[0][3] * 100));
} else {
tv_output.setText(String.format("화경버섯, %d, %.5f", i, output[0][4] * 100));
}
} else
continue;
}
}
③ activity_main.xml
<?xml version="1.0" encoding="utf-8"?>
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:orientation="vertical" >
<ImageView
android:id="@+id/image"
android:layout_width="match_parent"
android:layout_height="438dp">
</ImageView>
<Button
android:id="@+id/button_1"
android:layout_width="match_parent"
android:layout_height="128dp"
android:text="fds"
></Button>
<TextView
android:id="@+id/tv_output"
android:layout_width="match_parent"
android:layout_height="237dp"
android:text="Hello World!" />
</LinearLayout>
→ 앱 실행시 어떤사진이든 붉은사슴뿔버섯이라고 뜸 (옆에 퍼센트는 사진마다 바뀜)
→ visual code로 py파일 실행시 제대로된 출력이 나옴
[Firebase + Android + Tensorflow] tflite모델을 이용해 Camera로 찍은 사진을 분류, Firebase Cloud Storage에 Upload하기(JAVA) (2) | 2020.08.14 |
---|---|
[Firebase + Android] Android App으로 Firebase Cloud Storage 이용하기(JAVA) (0) | 2020.08.09 |
[SplashActivity] How to Create a Splash Screen(Java) (0) | 2020.07.31 |
[Image] Image resize & convert(python) (0) | 2020.07.30 |
[Web crawling] Image web crawling (python) (0) | 2020.07.30 |