やること
Suzuです|・ω・)ノ(大事なことなので2回言いました)
洗濯槽内に混入したティッシュを機械学習で検出するシリーズ、前回はティッシュ画像の撮影とアノテーションを行いました。
今回は学習を行います。
学習
データセットのzipとdataset.yamlをGoogle Colaboratoryの左側のところにアップロードします。
アップしたデータセットを展開します。
!unzip datasets.zip
YOLOv8はPythonライブラリ「ultralytics」に入っているのでインストールします。PyTorchも必要ですがColabにプリインされています。
!pip install ultralytics
学習はこれだけです、かんたん!
from ultralytics import YOLO
#モデル指定
model = YOLO('yolov8n.pt')
#データセットを学習
results = model.train(data='dataset.yaml', epochs=100)
・
・
・
Epoch GPU_mem box_loss cls_loss dfl_loss Instances Size
100/100 0G 0.3883 0.3886 0.8251 3 640: 100%|██████████| 4/4 [00:25<00:00, 6.41s/it]
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 1/1 [00:02<00:00, 2.87s/it] all 23 23 0.995 1 0.995 0.801
100 epochs completed in 0.806 hours.
Optimizer stripped from runs/detect/train43/weights/last.pt, 6.2MB
Optimizer stripped from runs/detect/train43/weights/best.pt, 6.2MB
Validating runs/detect/train43/weights/best.pt...
Ultralytics YOLOv8.1.18 🚀 Python-3.10.12 torch-2.1.0+cu121 CPU (AMD EPYC 7B12)
Model summary (fused): 168 layers, 3005843 parameters, 0 gradients, 8.1 GFLOPs
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 1/1 [00:02<00:00, 2.71s/it]
all 23 23 0.99 1 0.995 0.841
Speed: 1.0ms preprocess, 99.9ms inference, 0.0ms loss, 0.3ms postprocess per image
Results saved to runs/detect/train43
出力に書かれているように、
runs/detect/train43/weights/last.pt
runs/detect/train43/weights/best.pt
に学習済みモデルが自動保存されるのでどちらかをローカルにダウンロードしておきます。「last.pt」は最後のモデル、「best.pt」は学習中にもっとも精度が高かったモデルです。
検証
このまま検証(バリデーション)も実行してみます。
#モデルの精度を検証
results = model.val()
Ultralytics YOLOv8.1.18 🚀 Python-3.10.12 torch-2.1.0+cu121 CPU (AMD EPYC 7B12)
Model summary (fused): 168 layers, 3005843 parameters, 0 gradients, 8.1 GFLOPs
val: Scanning /content/datasets/suzu/labels/val.cache... 23 images, 0 backgrounds, 0 corrupt: 100%|██████████| 23/23 [00:00<?, ?it/s]
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 2/2 [00:03<00:00, 1.66s/it]
all 23 23 0.99 1 0.995 0.841
Speed: 1.0ms preprocess, 117.2ms inference, 0.0ms loss, 0.3ms postprocess per image
Results saved to runs/detect/train432
valデータは23枚でこれらは学習には使われていません。Precision=0.99、Recall=1.0、mAP50=0.901、mAP50-95=0.597という精度でした。精度の指標についてはChatGPTに聞いてください。
判定を試してみる
新しい写真を1枚アップして検出してみます。
#新しい画像で検出
results = model.predict('test_tissue.jpg')
検出できました!
続く
③リアルタイム検出編に続きます!