下面是整理后的代码,将各个步骤清晰地分段处理,并更新了模型权重的路径为本地文件路径:
# Import required libraries
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
import cv2
# Load the image
im = cv2.imread("./tmp/input.jpg")
# Configure the model
cfg = get_cfg()
cfg.merge_from_file("configs/COCO-Detection/faster_rcnn_R_101_FPN_3x.yaml")
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # Set the threshold for object detection
cfg.MODEL.WEIGHTS = "models/model_final_f6e8b1.pkl" # Path to the local model file
# Create predictor
predictor = DefaultPredictor(cfg)
# Make prediction
outputs = predictor(im)
# Visualize the output (optional)
# Uncomment the following lines if you want to visualize and save the result
v = Visualizer(im[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2)
out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
# cv2.imshow("Predictions", out.get_image()[:, :, ::-1])
# cv2.waitKey(0) # Display until a key is pressed
cv2.imwrite("./tmp/output.jpg", out.get_image()[:, :, ::-1]) # Save the output
说明
-
cfg.MODEL.WEIGHTS 已更新为
models/model_final_f6e8b1.pkl
,请确保此路径和文件名与本地文件匹配。 - 如果需要可视化预测结果,可取消最后几行的注释。
- 确保在代码运行时已安装并配置好
detectron2
环境。
识别效果: