4/14(日) 足・靴・木型研究会「第2回研究集会」を開催します☆彡

1-11. 洗濯槽内のティッシュを画像処理で検出してみた(②学習編)

やること

Mayoです|・ω・)ノ(大事なことなので2回言いました)

洗濯槽内に混入したティッシュを機械学習で検出するシリーズ、前回はティッシュ画像の撮影とアノテーションを行いました。

今回は学習を行います。

学習

データセットのzipとdataset.yamlをGoogle Colaboratoryの左側のところにアップロードします。

アップしたデータセットを展開します。

!unzip datasets.zip

YOLOv8はPythonライブラリ「ultralytics」に入っているのでインストールします。PyTorchも必要ですがColabにプリインされています。

!pip install ultralytics

学習はこれだけです、かんたん!

from ultralytics import YOLO

#モデル指定
model = YOLO('yolov8n.pt')

#データセットを学習
results = model.train(data='dataset.yaml', epochs=100)
・
・
・
      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
    100/100         0G     0.3883     0.3886     0.8251          3        640: 100%|██████████| 4/4 [00:25<00:00,  6.41s/it]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 1/1 [00:02<00:00,  2.87s/it]                   all         23         23      0.995          1      0.995      0.801

100 epochs completed in 0.806 hours.
Optimizer stripped from runs/detect/train43/weights/last.pt, 6.2MB
Optimizer stripped from runs/detect/train43/weights/best.pt, 6.2MB

Validating runs/detect/train43/weights/best.pt...
Ultralytics YOLOv8.1.18 🚀 Python-3.10.12 torch-2.1.0+cu121 CPU (AMD EPYC 7B12)
Model summary (fused): 168 layers, 3005843 parameters, 0 gradients, 8.1 GFLOPs
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 1/1 [00:02<00:00,  2.71s/it]
                   all         23         23       0.99          1      0.995      0.841
Speed: 1.0ms preprocess, 99.9ms inference, 0.0ms loss, 0.3ms postprocess per image
Results saved to runs/detect/train43

出力に書かれているように、

runs/detect/train43/weights/last.pt
runs/detect/train43/weights/best.pt

に学習済みモデルが自動保存されるのでどちらかをローカルにダウンロードしておきます。「last.pt」は最後のモデル、「best.pt」は学習中にもっとも精度が高かったモデルです。

検証

このまま検証(バリデーション)も実行してみます。

#モデルの精度を検証
results = model.val()
Ultralytics YOLOv8.1.18 🚀 Python-3.10.12 torch-2.1.0+cu121 CPU (AMD EPYC 7B12)
Model summary (fused): 168 layers, 3005843 parameters, 0 gradients, 8.1 GFLOPs
val: Scanning /content/datasets/suzukamayo/labels/val.cache... 23 images, 0 backgrounds, 0 corrupt: 100%|██████████| 23/23 [00:00<?, ?it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 2/2 [00:03<00:00,  1.66s/it]
                   all         23         23       0.99          1      0.995      0.841
Speed: 1.0ms preprocess, 117.2ms inference, 0.0ms loss, 0.3ms postprocess per image
Results saved to runs/detect/train432

valデータは23枚でこれらは学習には使われていません。Precision=0.99、Recall=1.0、mAP50=0.901、mAP50-95=0.597という精度でした。精度の指標についてはChatGPTに聞いてください。

判定を試してみる

新しい写真を1枚アップして検出してみます。

#新しい画像で検出
results = model.predict('test_tissue.jpg')

検出できました!

続く

③リアルタイム検出編に続きます!

タイトルとURLをコピーしました