物体検出プログラム(SSD)で小さい物体を検出するプログラムを作ったのでソースコードを解説します
物体検出プログラム
先日の記事で小さな物体を検出する方法を紹介しました。
今日はそのソースコードを解説します。
なお、プログラム作成にあたり以下の本を参考にしました。
import os
import cv2
import sys
import glob
import torch
import numpy as np
import torch.nn as nn
import torch.backends.cudnn as cudnn
from PIL import Image
from natsort import natsorted
from matplotlib import pyplot as plt
# GPUの接続
torch.cuda.is_available()
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
sys.path.append(module_path)
if torch.cuda.is_available():
torch.set_default_tensor_type('torch.cuda.FloatTensor')
from ssd import build_ssd
##### パラメータ設定 #####
cn = 9 # クラス数
bbox = 8 # バインティングボックスの色指定(5:赤, 8:青, 10:ライトグリーン, 15:緑, 21:黄色, 30:薄いオレンジ, 40:オレンジ)
# バックグラウンドを0としてプラス8種類で合計クラス数は9
labels = ( # always index 0
'odairisama', 'ohinasama', 'kanjo_hisage', 'kanjo_sanpou', 'kanjo_nagae', 'taiko', 'tuzumi', 'fue'
)
# 引数testでネットワークを定義し、学習済みモデルをロード
# ネットワークの定義
# 引数が'test'だと、推論結果に対してクラスDetectで後処理を実行
ssd_net = build_ssd('test', 300, cn) # initialize SSD
# GPUの場合、deviceに'cuda'を設定
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# GPUへの転送
net = ssd_net.to(device)
# 学習済みモデルのロード
net.load_weights('./weights/ohina_200211_3000_1.pth')
# 物体検出させたい画像
image = cv2.imread('./blog/018_4480px.jpg')
### 画像分割セクション ###
h, w = image.shape[:2] # 画像の縦と横のサイズを取得
n = 10 # 画像分割数
y0 = h // n
x0 = w // n
print("y0=", y0)
print("x0=", x0)
# 分割した画像をリスト化
c = [image[x0*x:x0*(x+1), y0*y:y0*(y+1)] for x in range(n) for y in range(n)]
# enumerate関数でインデックス番号と画像情報を取得
# インデックス番号はjpgのファイル番号に使う
# 画像はsepaフォルダに保存する
for i, img in enumerate(c):
cv2.imwrite(os.path.join('./blog/sepa', '{}.jpg'.format(i)), img)
### ここまでが画像分割セクション ###
n0 = n*n
f = 0
for i in range(n0):
rgb_image = cv2.cvtColor(c[i], cv2.COLOR_BGR2RGB)
# 畳み込み計算できるよう前処理を実行
# 画像のサイズを300×300に変更
x = cv2.resize(c[i], (300, 300)).astype(np.float32)
# 平均のRGBを引く(色情報の標準化)
x -= (104.0, 117.0, 123.0)
x = x.astype(np.float32)
# cv2のチャンネルの順番はBGR(青、緑、赤)なので、RGB(赤、緑、青)に入れ替える
x = x[:, :, ::-1].copy()
# HWCの形状[300, 300, 3]をCHWの形状[3, 300, 300]に変更
x = torch.from_numpy(x).permute(2, 0, 1)
# SSDネットワークの順伝播計算を実行
# 0次元目にバッチサイズの次元を追加
# [3, 300, 300] → [1, 3, 300, 300]
xx = x.unsqueeze(0)
# GPUへの転送
xx = xx.to(device)
# dropoutを実行しない
net.eval()
# 計算グラフを作成しない
with torch.no_grad():
# 順伝播を実行し、推論結果を出力
y = net(xx)
# バウンディングボックスの出力
#from data import VOC_CLASSES as labels
plt.figure(figsize=(10,10))
colors = plt.cm.hsv(np.linspace(0, 1, bbox)).tolist()
plt.axis('off')
plt.imshow(rgb_image)
currentAxis = plt.gca()
#plt.show()
# 推論結果をdetectionsに格納
detections = y.data
scale = torch.Tensor(rgb_image.shape[1::-1]).repeat(2)
for i in range(detections.size(1)):
j = 0
# 確信度confが0.6以上のボックスを表示
# jは確信度上位200件のボックスのインデックス
while detections[0,i,j,0] >= 0.6:
score = detections[0,i,j,0]
label_name = labels[i-1]
display_txt = '%s: %.2f'%(label_name, score)
pt = (detections[0,i,j,1:]*scale).cpu().numpy()
coords = (pt[0], pt[1]), pt[2]-pt[0]+1, pt[3]-pt[1]+1
color = colors[i]
# fillをFalseにすると枠線だけ書く
currentAxis.add_patch(plt.Rectangle(*coords, fill=False, edgecolor=color, linewidth=2))
currentAxis.text(pt[0], pt[1], display_txt, bbox={'facecolor':color, 'alpha':0.5})
j+=1
plt.imshow(rgb_image)
# detectフォルダに検出枠が画像に上書きされた図として保存される
plt.savefig('./blog/detect/' + '{}.jpg'.format(f))
# cv2.imwriteでも保存は出来るが検出枠が上書きされない
#cv2.imwrite('./blog/detect/' + '{}.jpg'.format(f), rgb_image)
f += 1
files = glob.glob("./blog/detect/" + "*.jpg") # detectフォルダから画像を読み込む
d = []
# files からnatsortedでファイル番号の若い順(0.jpg, 1.jpg, ・・・)に読み込む
for i in natsorted(files):
img = Image.open(i)
img = np.asarray(img)
img = img[120:890, 128:898] # savefigで保存した画像は余白が邪魔なのでトリミング
d.append(img)
# 4480px * 4480px の画像を縦横10分割したので計100枚を合体させる
# 先にnp.hstack で横方向を10個合体させる
# 次にnp.vstack で縦方向を10個合体させる
img_x = np.vstack((np.hstack(d[0:10]),
np.hstack(d[10:20]),
np.hstack(d[20:30]),
np.hstack(d[30:40]),
np.hstack(d[40:50]),
np.hstack(d[50:60]),
np.hstack(d[60:70]),
np.hstack(d[70:80]),
np.hstack(d[80:90]),
np.hstack(d[90:100])
))
img_x = cv2.cvtColor(img_x, cv2.COLOR_BGR2RGB) # 色をBGRからRGBに変換
cv2.imshow('img_1', img_x)
cv2.imwrite('./blog/result.jpg', img_x)
cv2.waitKey(0)
cv2.destroyAllWindows()
ソースコード解説
ソースコード番号順にポイントとなる部分を開設して行きます。
まず、コード1~12で必要なライブラリを読み込みます。
48行目で学習済みモデルをロードします。このとき、pthファイルを読み込んでいますが、これは検知させたい物体を事前に学習させ自作する必要があります。pthファイルの作り方はまた別の機会にお話しします。
51行目で cv2.imread() で 物体検出させたい画像を読み込みます。
今回は、4480×4480ピクセルサイズのjpgファイルを読み込ませています。検証のためにかなり大きな画像を使用しましたが、もっと小さな画像でもプログラムの動作原理は同じです。
リスト内包表記
最初のポイントは64行目
分割した画像をリスト化していますが内包表記を使っています。
ご存じの方も多いと思いますが、for ループを2回使うため内包表記を使わないと処理速度がかなり遅くなります。理由は計算量がO(n^2)(Oはランダウの記号)になるためです。計算量がデータ量nの2乗で多くなるため、その分処理に時間がかかります。
ここをリスト内包表記を使えば数倍~数十倍処理が速くなります。
計算量と処理時間の関係についてはこちらの記事をどうぞ。
64行目でリスト化したものを、69-70行目で分割画像に番号を付けてフォルダに保存します。
enumerate 関数
enumerate() 関数を使うことで、インデックス番号( i )と画像情報(img)を同時に取得できます。
インデックス番号は 0 番からはじまるので、最初の分割画像のファイル名は 0.jpg になります。
74行目以降で分割画像を1枚づつ推論して行きます。
77行目は、先ほど作ったリストcから分割画像を1枚読み込んで、色情報をBGRからRGBに変換します。
さらに80行目では、先ほど作ったリストcから分割画像を1枚読み込んでリサイズし、推論用の画像を用意します。
これは、cv2.imread()で読み込んだ画像は色の順番がBGRの順番になっているからです。
推論するためには、RGBに変換しておく必要があるため cv2.cvtColor(c[i], cv2.COLOR_BGR2RGB)で色の順番を変更します。
色情報の標準化
なにげに大事なのが84行目です。ここで色情報の標準化をおこなっています。
RGBの各色の平均値(ここでは(104.0, 117.0, 123.0))を引くことで、物体特有の色情報(特徴)が明確になります。
以下の画像を見てください。
左がオリジナル画像です。右が色情報を標準化した画像です。
右の方が物体の特徴がはっきりします。
次に88行目で色の順番を逆にします。[:, :, : :-1] でBGR が RGBに代わります。numpy ではおなじみの操作です。
102行目で net を評価モードにして107行目で推論します。
なお推論結果はいったん y に代入し、119行目で推論結果を detections に格納します。
127行目以降、確信度が0.6以上である限り while 文でループを回し続けます。
もし0.6以上の確率で推論されれば、129行目の label_nameに odairisama などの名前が代入されます。
代入される名前は、32行目でタプル(タプルとはカッコでくくられたもののこと)であらかじめ指定された名前の何れかです。
130行目のdisplay_txt には、枠の左上の方に出ている名前と推論確率(0.99とか)が代入されます。
131行目の pt は枠の位置情報です。pt[0] はx方向、pt[1] はy方向です。
136行目の plt.Rectangle は グラフに四角の枠を表示する関数です。
137行目はそのグラフに130行目の display_txt の内容を表示させる指定です。
画像保存はplt.savefig
140行目でいったん plt.imshow() で rgb_image の画像を表示させ、143行目で plt.savefig() 関数で画像を保存します。この画像には四角の枠、名前、推論確率が上書きされた状態になっています。
ちなみに、画像を保存するだけなら cv2.imwrite() 関数でもできますが、四角の枠、名前、推論確率は上書きされません。
150行目で glob.glob() 関数で detect フォルダに先ほど保存した画像を読み込みます。
また、” *.jpg ” このダブルクォーテーションで囲んだ部分は、「*」がワイルドカードというもので画像ファイルの拡張子が jpg のものをすべて読み込みなさいという指示です。
読み込んだ画像は、files という「リスト」にいったん入ります。
natsorted で昇順読み込み
155行目で前回記事でも説明しましたが、natsorted() 関数を使うことで files リストからファイル番号の若い順(昇順:0.jpg、1.jpg、・・・)に読み込みます。
画像をNumPyで扱うにはndarray化が必須
156行目の Image.open() 関数で読み込まれた画像は ndarray 形式ではないので、157行目の np.asarray() 関数で ndarray 形式に変換します。ndarray 形式にすることで numpy で扱えるようになります。
162行目で savefigで保存した画像の余白をトリミングします。余白を小さくする方法はいろいろありますが、私はこの方法がシンプルなので好きです。
168行目で画像を結合して行きます。
画像結合には np.hstack , np.vstack
始めに、np.hstack() 関数で水平方向(x方向)の画像を結合します。
結合する画像枚数は、d[0:10] の部分で指定しています。この場合は、0.jpg から 9.jpg まで10枚の画像を結合することになります。
同様に、d[10:20] の部分では 10.jpg から19.jpg まで10枚の画像を結合し、以下同様に水平方向の画像を結合して行きます。
それが終わると、np.vstack() 関数で垂直方向(y方向)の画像を結合して行きます。
全部の画像結合が終わると、img_x に代入されます。
180行目で 色をBGRからRGBに変換します。
182行目で cv2.imshow() 関数で img_x 画像を表示します。
183行目で img_x 画像を blog フォルダに result.jpg というファイル名で保存します。
画像を表示させ続けたいならcv2.waitKey(0)を忘れずに
ちなみに、185行目に cv2.waitKey(0) を指定しておかないと img_x の画像表示が一瞬で消えます。
185行目はなにかキーを押すまで img_x が消えないようにしておくための指示です。
186行目は画面をクリアするための指示です。
つかれました。簡単ですが解説を終わります。
まとめ
- 大量のデータを処理する際はリスト内包表記を使えば数倍~数十倍処理が速くなる。
- enumerate() 関数を使えばインデックス番号( i )と画像情報(img)を同時に取得できる。
- 画像を学習させる際は色情報の標準化が大事。
- 画像保存はplt.savefig() 関数を使う。
- 画像をNumPyで扱うにはndarray化が必須。
- 画像結合には np.hstack , np.vstackを使う。
- 画像を表示させ続けたいならcv2.waitKey(0)を忘れずに書くこと。
ディスカッション
コメント一覧
まだ、コメントがありません