模型描述:模型基于yolov8的基础预训练版本 数据:1000张安全衣与普通衣服标注数据
模型预测代码
import cv2
from yolo_net import YOLO
##加载训练保存好的最优权重的模型
model = YOLO("runs/detect/train10/weights/last.pt")
input_video = cv2.VideoCapture('predict_img_video/test_video.mp4')
output_video = cv2.VideoWriter('output_video.mp4', cv2.VideoWriter_fourcc(*'mp4v'), input_video.get(cv2.CAP_PROP_FPS), (int(input_video.get(cv2.CAP_PROP_FRAME_WIDTH)), int(input_video.get(cv2.CAP_PROP_FRAME_HEIGHT))))
while True:
ret, frame = input_video.read()
if not ret:
break
## 对帧进行预测
predictions = model.predict(frame)
# 将预测结果绘制到帧上并保存到输出视频对象中
output_frame = frame.copy()
for detection in predictions:
output_frame = detection.plot()
output_video.write(output_frame)
cv2.imshow('Output Video', output_frame)
if cv2.waitKey(1) & 0xFF == ord('q'): # 按 'q' 键退出处理过程
break
input_video.release() # 释放输入视频对象资源
output_video.release() # 释放输出视频对象资源
cv2.destroyAllWindows() # 关闭所有OpenCV窗口
评论