CV_AG/Test_logic_track.py

import cv2
from ultralytics import YOLO
# Input and output video paths
# video_path = r'D:\AIM\pecan\OneDrive_2_2024-11-7\G5 Flex 01 8-5-2024, 1.20.12pm EDT - 8-5-2024, 1.23.22pm EDT.mp4'
# video_path_out = r'D:\AIM\pecan\G5 Flex 01 8-5-2024_out.mp4'
video_path = r'D:\AIM\pecan\GH014359.mp4'
video_path_out = r'D:\AIM\pecan\GH014359_out.mp4'
# Load the YOLO model
model = YOLO(r"D:\AIM\pecan\runs\detect\train2\weights\best.pt")  # load the custom-trained model
# Initialize VideoWriter to save the output video
cap = cv2.VideoCapture(video_path)
fps = int(cap.get(cv2.CAP_PROP_FPS))
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(video_path_out, fourcc, fps, (width, height))
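# Note: CAP_PROP_FPS can report 0 for some containers/codecs; in that case a manually
# chosen fps fallback would be needed for the VideoWriter (assumption, not handled here).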
# Dictionary used to track the state of each walnut
walnut_states = {}  # format: {track ID: "state"}
# Class labels
class_labels = {
0: "circumferential",
1: "cracked",
2: "crushed",
3: "longitudinal",
4: "open",
5: "uncracked"
}
# Classes that should be assigned a tracking ID
id_tracked_classes = ["cracked", "uncracked"]
# Run object tracking with BoT-SORT
results = model.track(source=video_path, conf=0.5, tracker='botsort.yaml', show=False)
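# With the default stream=False, model.track() runs through the whole video and returns
# a list of per-frame Results objects; BoT-SORT keeps track IDs persistent across frames.
# For long videos, stream=True would avoid holding every frame's results in memory.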
for result in results:
    frame = result.orig_img  # current frame
    detections = result.boxes  # detection boxes

    # Process each detection box
    for box in detections:
        x1, y1, x2, y2 = map(int, box.xyxy[0])  # box coordinates
        obj_id = int(box.id) if box.id is not None else -1  # tracking ID
        class_id = int(box.cls)  # class ID
        score = float(box.conf)  # confidence score

        # Look up the class label for this detection
        label = class_labels.get(class_id, "unknown")

        # Update the walnut state only for classes that carry a tracking ID
        if label in id_tracked_classes:
            if obj_id not in walnut_states:
                walnut_states[obj_id] = label
            else:
                # Once a walnut has been seen as "cracked", keep it "cracked"
                if walnut_states[obj_id] != "cracked":
                    walnut_states[obj_id] = label
            display_text = f"ID {obj_id} | {walnut_states[obj_id]}"
        else:
            # Non-tracked classes only display their class label
            display_text = label

        # Draw the detection box and label
        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(frame, display_text, (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)

    # Write the processed frame to the output video
    out.write(frame)
# Release resources
cap.release()
out.release()
print("Video processing complete; results saved.")