<麻雀牌検出器作成記②>モデルのトレーニング【オブジェクト検出】

AI・機械学習

はじめに

昨年、麻雀点数計算アプリをリリースした。

麻雀サポーター - Google Play のアプリ
麻雀の手牌をリアルタイムカメラから検出し、点数計算・多面待ち検出・受入最多の打牌検出する、初心者向け麻雀サポートアプリ

リアルタイムカメラに手牌をかざすと、手牌の麻雀牌をAIが認識し、自動で点数計算する等の便利な機能満載のアプリである。

本アプリ開発の最初のステップとして、麻雀牌検出器の生成があった。

今回は麻雀牌検出器生成記録の第二弾として、モデルをトレーニングするフェーズについて記載する。

前回の記事はこちらである。

トレーニングの準備

麻雀牌検出器を生成するため、モデルをバンバントレーニングしていく。
学習プログラムについては、これもRoboflowがGoogle Colabに最強の学習環境を用意してくれている。

Google Colaboratory

このソースコードやconfigについて、必要なところのみカスタムし、後はパラメータを調整しながら学習を繰り返すのみである。

筆者は時間節約のために、Google ColabのGPU契約をした。

カスタムしたところ

ベースモデルの選定

以下のサイトから、転移学習に使用するベースとするモデルを選定する。

https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md

転移学習とは、モデルを0から学習するのではなく、別の用途で生成された学習済みモデルをベースにして学習させていくやり方である。
学習精度が向上するといわれている。

今回、麻雀牌検出器に使用したモデルは「SSD MobileNet V2 FPNLite 640×640」である。
最初はできるだけ軽量なモデルにしようと思い、「SSD MobileNet V2 FPNLite 320×320」にしていたが、認識精度をどうしても上げることができず、サイズを大きくした。

この変更に伴い、「Configure Custom TensorFlow2 Object Detection Training Configuration」章のMODELS_CONFIGを以下のようにした。

##change chosen model to deploy different models available in the TF2 object detection zoo
MODELS_CONFIG = {
    'efficientdet-d0': {
        'model_name': 'efficientdet_d0_coco17_tpu-32',
        'base_pipeline_file': 'ssd_efficientdet_d0_512x512_coco17_tpu-8.config',
        'pretrained_checkpoint': 'efficientdet_d0_coco17_tpu-32.tar.gz',
        'batch_size': 16
    },
    'efficientdet-d1': {
        'model_name': 'efficientdet_d1_coco17_tpu-32',
        'base_pipeline_file': 'ssd_efficientdet_d1_640x640_coco17_tpu-8.config',
        'pretrained_checkpoint': 'efficientdet_d1_coco17_tpu-32.tar.gz',
        'batch_size': 16
    },
    'efficientdet-d2': {
        'model_name': 'efficientdet_d2_coco17_tpu-32',
        'base_pipeline_file': 'ssd_efficientdet_d2_768x768_coco17_tpu-8.config',
        'pretrained_checkpoint': 'efficientdet_d2_coco17_tpu-32.tar.gz',
        'batch_size': 16
    },
    'efficientdet-d3': {
        'model_name': 'efficientdet_d3_coco17_tpu-32',
        'base_pipeline_file': 'ssd_efficientdet_d3_896x896_coco17_tpu-32.config',
        'pretrained_checkpoint': 'efficientdet_d3_coco17_tpu-32.tar.gz',
        'batch_size': 16
    },
    # ★★★以下を追加★★★
    'mobilenet-v2': {
        'model_name': 'ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8',
        'base_pipeline_file': 'ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8.config',
        'pretrained_checkpoint': 'ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8.tar.gz',
        'batch_size': 64
    },
}

#in this tutorial we implement the lightweight, smallest state of the art efficientdet model
#if you want to scale up tot larger efficientdet models you will likely need more compute!
chosen_model = 'mobilenet-v2'

パイプラインコンフィグの変更

選択したモデルの、パイプラインコンフィグがダウンロードされる。
今回は「ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8.config」である。

こちらのコンフィグファイルの中身を、「#download base training configuration file」のスクリプトで取得した後、以下のように変更する。

# SSD with Mobilenet v2 FPN-lite (go/fpn-lite) feature extractor, shared box
# predictor and focal loss (a mobile version of Retinanet).
# Retinanet: see Lin et al, https://arxiv.org/abs/1708.02002
# Trained on COCO, initialized from Imagenet classification checkpoint
# Train on TPU-8
#
# Achieves 28.2 mAP on COCO17 Val

model {
  ssd {
    inplace_batchnorm_update: true
    freeze_batchnorm: false
    num_classes: 34 # ★★★ここをラベル数に変更★★★
    box_coder {
      faster_rcnn_box_coder {
        y_scale: 10.0
        x_scale: 10.0
        height_scale: 5.0
        width_scale: 5.0
      }
    }
    matcher {
      argmax_matcher {
        matched_threshold: 0.5
        unmatched_threshold: 0.5
        ignore_thresholds: false
        negatives_lower_than_unmatched: true
        force_match_for_each_row: true
        use_matmul_gather: true
      }
    }
    similarity_calculator {
      iou_similarity {
      }
    }
    encode_background_as_zeros: true
    anchor_generator {
      multiscale_anchor_generator {
        min_level: 3
        max_level: 7
        anchor_scale: 4.0
        aspect_ratios: [1.0, 2.0, 0.5]
        scales_per_octave: 2
      }
    }
    image_resizer {
      fixed_shape_resizer {
        height: 640
        width: 640
      }
    }
    box_predictor {
      weight_shared_convolutional_box_predictor {
        depth: 128
        class_prediction_bias_init: -4.6
        conv_hyperparams {
          activation: RELU_6,
          regularizer {
            l2_regularizer {
              weight: 0.00004
            }
          }
          initializer {
            random_normal_initializer {
              stddev: 0.01
              mean: 0.0
            }
          }
          batch_norm {
            scale: true,
            decay: 0.997,
            epsilon: 0.001,
          }
        }
        num_layers_before_predictor: 4
        share_prediction_tower: true
        use_depthwise: true
        kernel_size: 3
      }
    }
    feature_extractor {
      type: 'ssd_mobilenet_v2_fpn_keras'
      use_depthwise: true
      fpn {
        min_level: 3
        max_level: 7
        additional_layer_depth: 128
      }
      min_depth: 16
      depth_multiplier: 1.0
      conv_hyperparams {
        activation: RELU_6,
        regularizer {
          l2_regularizer {
            weight: 0.00004
          }
        }
        initializer {
          random_normal_initializer {
            stddev: 0.01
            mean: 0.0
          }
        }
        batch_norm {
          scale: true,
          decay: 0.997,
          epsilon: 0.001,
        }
      }
      override_base_feature_extractor_hyperparams: true
    }
    loss {
      classification_loss {
        weighted_sigmoid_focal {
          alpha: 0.25
          gamma: 2.0
        }
      }
      localization_loss {
        weighted_smooth_l1 {
        }
      }
      classification_weight: 1.0
      localization_weight: 1.0
    }
    normalize_loss_by_num_matches: true
    normalize_loc_loss_by_codesize: true
    post_processing {
      batch_non_max_suppression {
        score_threshold: 1e-8
        iou_threshold: 0.6
        max_detections_per_class: 100
        max_total_detections: 100
      }
      score_converter: SIGMOID
    }
  }
}

train_config: {
  fine_tune_checkpoint_version: V2
  # ★★★以下をcheckpointのパスに変更★★★
  fine_tune_checkpoint: "/content/models/research/deploy/ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8/checkpoint/ckpt-0"
  fine_tune_checkpoint_type: "detection"
  batch_size: 16 # ★★★バッチサイズパラメータ調整★★★
  sync_replicas: true
  startup_delay_steps: 0
  replicas_to_aggregate: 8
  num_steps: 80000 # ★★★学習ステップ数パラメータ調整★★★
  # ★★★以下4つ、データ拡張オプション設定★★★
  data_augmentation_options {
    random_rotation90 {
    }
  }
  data_augmentation_options {
    random_adjust_brightness {
      max_delta: 0.1
    }
  }
  data_augmentation_options {
    random_adjust_contrast {
      min_delta: 0.9
      max_delta: 1.1
    }
  }
  data_augmentation_options {
    random_image_scale {
      min_scale_ratio: 0.5
      max_scale_ratio: 2.0
    }
  }
  optimizer {
    momentum_optimizer: {
      # ★★★学習率パラメータ調整★★★
      learning_rate: {
        cosine_decay_learning_rate {
          learning_rate_base: .02
          total_steps: 90000
          warmup_learning_rate: .0066666
          warmup_steps: 2000
        }
      }
      momentum_optimizer_value: 0.9
    }
    use_moving_average: false
  }
  max_number_of_boxes: 100
  unpad_groundtruth_tensors: false
}

train_input_reader: {
  # ★★★以下をlabel_mapのパスに変更★★★
  label_map_path: "/content/train/mahjong-tile_label_map.pbtxt"
  tf_record_input_reader {
    # ★★★以下を学習用tfrecordファイルのパスに変更★★★
    input_path: "/content/train/mahjong-tile.tfrecord"
  }
}

eval_config: {
  metrics_set: "coco_detection_metrics"
  use_moving_averages: false
}

eval_input_reader: {
  # ★★★以下をlabel_mapのパスに変更★★★
  label_map_path: "/content/train/mahjong-tile_label_map.pbtxt"
  shuffle: false
  num_epochs: 1
  tf_record_input_reader {
    # ★★★以下を評価用tfrecordファイルのパスに変更★★★
    input_path: "/content/valid/mahjong-tile.tfrecord"
  }
}

データ拡張オプションの変更

Tensorflow2 Object Detectionでは、上述したパイプラインコンフィグにて、データ拡張オプションをしてすることができる。
今回は以下のようにした。

  #★★★90°回転★★★
  data_augmentation_options {
    random_rotation90 {
    }
  }
  #★★★照度変更★★★
  data_augmentation_options {
    random_adjust_brightness {
      max_delta: 0.1
    }
  }
  #★★★コントラスト変更★★★
  data_augmentation_options {
    random_adjust_contrast {
      min_delta: 0.9
      max_delta: 1.1
    }
  }
  #★★★画像拡大/縮小★★★
  data_augmentation_options {
    random_image_scale {
      min_scale_ratio: 0.5
      max_scale_ratio: 2.0
    }
  }

パラメータの調整

学習のパラメータを調整する。
上述したパイプラインコンフィグの内容を変更する。
これは最初からベストな値は分からない。
何回も学習してみながら、ベストな値を模索・調整していく。

今回、最終的に用いたパラメータ値は以下である。

  • num_steps(総ステップ数) : 80000
  • batch_size(バッチサイズ) : 16
  • learning_rate(学習率)
    • learning_rate_base : .02
    • total_steps : 90000
    • warmup_learning_rate : .0066666
    • warmup_steps : 2000

感触としては、バッチサイズの値に合わせて学習率の値を調整してやる必要があった。
バッチサイズはメモリの関係上、今回は16が限界であった。
この16に合う学習率を模索した結果、上記のようになった。

トレーニングの終了

モデルの評価

学習を進めていくうちに、Lossが指数関数的に下がっていくことが確認できると思う。
Lossが指数関数的に下がっていない場合は学習が上手くいっていない、Lossが途中から大きく上がりだした場合は過学習となっている可能性がある。
この場合、パラメータ値や教師データを見直す必要がある。
まずはパラメータを調整して、再Tryしてみるのが良い。

学習の指標としては、評価用データの結果をチェックポイントごとに確認してみるのが良い。
ただし、学習スクリプトと並行して動かせないため、工夫する必要がある。
例えばチェックポイントをGoogle Driveに出力し、共有し、評価スクリプトはローカル環境で動かす、等。
APが指数関数的に増加してるか等、評価の指標となるものはたくさんある。
学習が上手くいってると、ほんとにきれいな曲線になるため、そうなっていない場合はやはりパラメータなどを見直す必要がある。

そして良さそうなモデルができたら、学習に使用していないデータで結果を確認してみる。
自分が納得いくモデルになるまで、根気強く頑張ろう!

モデルの出力

今回、モデルはスマートフォンアプリ用に「tflite」モデルとして出力する。
そこで、「Exporting a Trained Inference Graph」章の「#run conversion script」スクリプトを以下のように変更する必要がある。

#run conversion script
import re
import numpy as np

output_directory = '/content/fine_tuned_model'

#place the model weights you would like to export here
last_model_path = '/content/training/'
print(last_model_path)
# ★★★以下を変更★★★
!python /content/models/research/object_detection/export_tflite_graph_tf2.py \
    --trained_checkpoint_dir {last_model_path} \
    --output_directory {output_directory} \
    --pipeline_config_path {pipeline_file} \
    --max_detections 40

max_detectionsの指定がないと、デフォルトで1度に10個しかオブジェクト検出してくれない。

これでできるのが「saved_model」であるが、これを実際にtfliteモデルに変換するため、以下のスクリプトを追加する。

!pip install -q tflite_support
_TFLITE_MODEL_PATH = "/content/tf_lite/model.tflite"

# convert to tflite
converter = tf.lite.TFLiteConverter.from_saved_model('/content/fine_tuned_model/saved_model')
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
with open(_TFLITE_MODEL_PATH, 'wb') as f:
  f.write(tflite_model)

# create label_map
_ODT_LABEL_MAP_PATH = label_map_pbtxt_fname
_TFLITE_LABEL_PATH = '/content/tf_lite/label_map.pbtxt'

category_index = label_map_util.create_category_index_from_labelmap(
    _ODT_LABEL_MAP_PATH)
f = open(_TFLITE_LABEL_PATH, 'w')
for class_id in range(1, 91):
  if class_id not in category_index:
    f.write('???\n')
    continue
  name = category_index[class_id]['name']
  f.write(name+'\n')
f.close()

# create tflite metadeta
from tflite_support.metadata_writers import object_detector
from tflite_support.metadata_writers import writer_utils
_TFLITE_MODEL_WITH_METADATA_PATH = "/content/tf_lite/model_with_metadata.tflite"

writer = object_detector.MetadataWriter.create_for_inference(
    writer_utils.load_file(_TFLITE_MODEL_PATH), input_norm_mean=[127.5],
    input_norm_std=[127.5], label_file_paths=[_TFLITE_LABEL_PATH])
writer_utils.save_file(writer.populate(), _TFLITE_MODEL_WITH_METADATA_PATH)

これで、Flutterに組み込める麻雀牌認識モデルが出来上がった。

最後に

今回は、麻雀牌検出器が生成できた。
次回は麻雀サポーター生成記録の次のステップとして、麻雀牌検出カメラの作成を記載していく。
お楽しみに!

コメント

タイトルとURLをコピーしました