80 lines
2.6 KiB
Python
80 lines
2.6 KiB
Python
import os
|
||
import cv2
|
||
from ultralytics import YOLO
|
||
|
||
# 输入和输出视频路径
|
||
# video_path = r'D:\AIM\pecan\OneDrive_2_2024-11-7\G5 Flex 01 8-5-2024, 1.20.12pm EDT - 8-5-2024, 1.23.22pm EDT.mp4'
|
||
# video_path_out = r'D:\AIM\pecan\G5 Flex 01 8-5-2024_out.mp4'
|
||
video_path = r'D:\AIM\pecan\GH014359.mp4'
|
||
video_path_out = r'D:\AIM\pecan\GH014359_out.mp4'
|
||
|
||
# 加载 YOLO 模型
|
||
model = YOLO(r"D:\AIM\pecan\runs\detect\train2\weights\best.pt") # 加载自定义模型
|
||
|
||
# 初始化 VideoWriter 用于保存输出视频
|
||
cap = cv2.VideoCapture(video_path)
|
||
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
||
out = cv2.VideoWriter(video_path_out, fourcc, fps, (width, height))
|
||
|
||
# 字典,用于跟踪每个核桃的状态
|
||
walnut_states = {} # 格式: {ID: "状态"}
|
||
|
||
# 定义类别标签
|
||
class_labels = {
|
||
0: "circumferential",
|
||
1: "cracked",
|
||
2: "crushed",
|
||
3: "longitudinal",
|
||
4: "open",
|
||
5: "uncracked"
|
||
}
|
||
|
||
# 需要分配 ID 的类别
|
||
id_tracked_classes = ["cracked", "uncracked"]
|
||
|
||
# 使用 BoT-SORT 进行目标跟踪
|
||
results = model.track(source=video_path, conf=0.5, tracker='botsort.yaml', show=False)
|
||
|
||
for result in results:
|
||
frame = result.orig_img # 当前帧
|
||
detections = result.boxes # 检测框信息
|
||
|
||
# 处理每个检测框
|
||
for box in detections:
|
||
x1, y1, x2, y2 = map(int, box.xyxy[0]) # 检测框坐标
|
||
obj_id = int(box.id) if box.id is not None else -1 # 跟踪目标ID
|
||
class_id = int(box.cls) # 类别ID
|
||
score = box.conf # 置信度
|
||
|
||
# 获取检测框对应的类别标签
|
||
label = class_labels.get(class_id, "unknown")
|
||
|
||
# 仅对需要分配ID的类别更新核桃状态
|
||
if label in id_tracked_classes:
|
||
if obj_id not in walnut_states:
|
||
walnut_states[obj_id] = label
|
||
else:
|
||
# 一旦检测到“cracked”,状态保持为“cracked”
|
||
if walnut_states[obj_id] != "cracked":
|
||
walnut_states[obj_id] = label
|
||
|
||
display_text = f"ID {obj_id} | {walnut_states[obj_id]}"
|
||
else:
|
||
# 非分配ID的类别,仅显示类别标签
|
||
display_text = label
|
||
|
||
# 绘制检测框和标签
|
||
cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
||
cv2.putText(frame, display_text, (x1, y1 - 10),
|
||
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
|
||
|
||
# 将处理好的帧写入输出视频
|
||
out.write(frame)
|
||
|
||
# 释放资源
|
||
cap.release()
|
||
out.release()
|
||
print("视频处理完成,结果已保存。") |