상세 컨텐츠

본문 제목

[Tensorflow & Android] .h5 -> .pb -> .tflite로 변환 & 안드로이드에 적용(python & java)

IT Convergence Engineering/AI 버섯 어플

by Soo_buglosschestnut 2020. 7. 28. 01:32

본문

.h5 -> .pb -> .tflite로 변환 & 안드로이드에 적용(python & java)


1. .h5 -> .pb -> .tflite로 변환

 

  • .h5 파일 -> .pb 파일로 변환하기 (코드)
from tensorflow import keras
model = keras.models.load_model('./model/my_model.h5', compile=False)

export_path = './pb'
model.save(export_path, save_format="tf")

.h5파일이 있는 directory  → multi 폴더                                                 .pb파일이 있는 directory → pb 폴더

 

model = keras.models.load_model('#######', compile=False)

######에는 .h5파일이 있는곳의 directory를 쓴다.

 

export_path = '@@@'

 @@@에는 변환된 .pb파일이 있게될 directory를 쓴다. (.pb파일을 어디에 저장할것인가)

 

 

  • .pb 파일 -> .tflite 파일로 변환하기 (코드)
import tensorflow as tf

saved_model_dir = './pb'
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
                                       tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()
open('./tf/converted_model.tflite', 'wb').write(tflite_model)

            .pb 파일의 위치 확인                                           .tflite파일이 tf 폴더에 있는지 확인                                                                       

saved_model_dir = '######'

 ######에는 .pb파일이 있는곳의 directory를 쓴다.

 

open('@@@/converted_model.tflite', 'wb').write(tflite_model)

 @@@에는 변환된 .tflite파일이 있게될 directory를 쓴다. (.tflite파일을 어디에 저장할것인가)

ps. @@@/~~~~.tflite  (~~~~ : 자기가 원하는 파일제목 쓰기)

 


2. .tflite 파일 안드로이드에 적용하기

① 안드로이드 프로젝트 생성 후, 권한 설정하기

  • 외부 저장소 이용시 추가(갤러리앱 사용시 등등)

외부저장소 이용시 추가

<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>

 

  • .tflite 사용시 build.gradle(Module: app) 업데이트 해주기 (중요!!)
aaptOptions {
        noCompress "tflite"
    }

→.tflite의 압축을 막기 위한 코드 (밑에 사진과 같이 android{~}부분에 올바르게 코드를 작성해야한다.)

 

코드를 위 사진과 같이 android 부분에 넣어야함.

 

 

implementation 'org.tensorflow:tensorflow-lite:+'

→ 위의 코드도 동일하게 dependencies{~~~} 부분에 올바르게 작성해야한다.

 

dependecies 부분에 코드를 추가

→ 코드 작성 후 Sync Now를 필히 해줘야한다!!!

 

  • .tflite가 있을 assets Directory 만들기

→ 위와 같은 경로로 들어가서 assets라는 Directory를 생성해준다.

 

 

→ assets Directory가 생성됨을 확인

 

→ assets Directory에 .tflite를 붙여넣기 한다.

 

 

  • MainActivity.java 코드

① MainActivity 내부에 tflite관련 코드 (작성 필수)

 private Interpreter getTfliteInterpreter(String modelPath) {
        try {
            return new Interpreter(loadModelFile(MainActivity.this, modelPath));
        }
        catch (Exception e) {
            e.printStackTrace();
        }
        return null;
    }




    public MappedByteBuffer loadModelFile(Activity activity, String modelPath) throws IOException {
        AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(modelPath);
        FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
        FileChannel fileChannel = inputStream.getChannel();
        long startOffset = fileDescriptor.getStartOffset();
        long declaredLength = fileDescriptor.getDeclaredLength();
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
    }

OnCreate함수 코드(주석 보면서, 자신의 파일 이름과 맞게 수정하기)

  • 버튼 클릭시 외부저장소로 폴더로 넘어감
 private static final int FROM_ALBUM = 1;    // onActivityResult 식별자
 private static final int FROM_CAMERA = 2;   // 카메라는 사용 안함

@Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        // 인텐트의 결과는 onActivityResult 함수에서 수신.
        // 여러 개의 인텐트를 동시에 사용하기 때문에 숫자를 통해 결과 식별(FROM_ALBUM 등등)
        findViewById(R.id.button_1).setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View view) {
                Intent intent = new Intent();
                intent.setType("image/*");                      // 이미지만
                intent.setAction(Intent.ACTION_GET_CONTENT);    // 카메라(ACTION_IMAGE_CAPTURE)
                startActivityForResult(intent, FROM_ALBUM);
            }
        });
    }


    // 사진첩에서 사진 파일 불러와 버섯종류분류하는 코드
    @Override
    protected void onActivityResult(int requestCode, int resultCode, Intent data) {
        // 카메라를 다루지 않기 때문에 앨범 상수에 대해서 성공한 경우에 대해서만 처리
        super.onActivityResult(requestCode, resultCode, data);
        if (requestCode != FROM_ALBUM || resultCode != RESULT_OK)
            return;

        //각 모델에 따른 input , output shape 각자 맞게 변환
        // mobilenetcheck.h5 일시 224 * 224 * 3
        float[][][][] input = new float[1][224][224][3];
        float[][] output = new float[1][5]; //tflite에 버섯 종류 5개라서 (내기준)

        try {
            int batchNum = 0;
            InputStream buf = getContentResolver().openInputStream(data.getData());
            Bitmap bitmap = BitmapFactory.decodeStream(buf);
            buf.close();

            //이미지 뷰에 선택한 사진 띄우기
            ImageView iv = findViewById(R.id.image);
            iv.setScaleType(ImageView.ScaleType.FIT_XY);
            iv.setImageBitmap(bitmap);

            // x,y 최댓값 사진 크기에 따라 달라짐 (조절 해줘야함)
            for (int x = 0; x < 224; x++) {
                for (int y = 0; y < 224; y++) {
                    int pixel = bitmap.getPixel(x, y);
                    input[batchNum][x][y][0] = Color.red(pixel) / 1.0f;
                    input[batchNum][x][y][1] = Color.green(pixel) / 1.0f;
                    input[batchNum][x][y][2] = Color.blue(pixel) / 1.0f;
                }
            }

            // 자신의 tflite 이름 써주기
            Interpreter lite = getTfliteInterpreter("converted_model_mobileNetCheck.tflite");
            lite.run(input, output);
        } catch (IOException e) {
            e.printStackTrace();
        }

        TextView tv_output = findViewById(R.id.tv_output);
        int i;

        // 텍스트뷰에 무슨 버섯인지 띄우기 but error남 ㅜㅜ 붉은 사슴뿔만 주구장창
        for (i = 0; i < 5; i++) {
            if (output[0][i] * 100 > 90) {
                if (i == 0) {
                    tv_output.setText(String.format("개나리 광대버섯  %d %.5f", i, output[0][0] * 100));
                } else if (i == 1) {
                    tv_output.setText(String.format("붉은사슴뿔버섯,%d  %.5f", i, output[0][1] * 100));
                } else if (i == 2) {
                    tv_output.setText(String.format("새송이버섯,%d, %.5f", i, output[0][2] * 100));
                } else if (i == 3) {
                    tv_output.setText(String.format("표고버섯, %d, %.5f", i, output[0][3] * 100));
                } else {
                    tv_output.setText(String.format("화경버섯, %d, %.5f", i, output[0][4] * 100));
                }
            } else
                continue;
        }
    }

③ activity_main.xml

<?xml version="1.0" encoding="utf-8"?>
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    android:orientation="vertical" >

    <ImageView
        android:id="@+id/image"
        android:layout_width="match_parent"
        android:layout_height="438dp">

    </ImageView>

    <Button
        android:id="@+id/button_1"
        android:layout_width="match_parent"
        android:layout_height="128dp"
        android:text="fds"
        ></Button>


    <TextView
        android:id="@+id/tv_output"
        android:layout_width="match_parent"
        android:layout_height="237dp"
        android:text="Hello World!" />

</LinearLayout>


 

  • 앱 실행결과

앱 실행시 어떤사진이든 붉은사슴뿔버섯이라고 뜸 (옆에 퍼센트는 사진마다 바뀜) 

visual code로 py파일 실행시 제대로된 출력이 나옴

 


더보기
더보기

youtu.be/JnhW5tQ_7Vo

- TensorFlowLite for Android(2:58~)

관련글 더보기