How can I fix this loop to correctly track and classify objects?

I have a loop that is supposed to:

  • detect the objects and their bounding boxes (bb)
  • track them to assign a unique ID to each box
  • check whether an object stops
  • if an object stops, run the classification on its bb and save the frame together with the bounding box information.

I am using a fine-tuned YOLO detector, a classifier, and the centroidtracker implementation that I found on GitHub. This is my while loop:

stationary_objects = {}  # Dictionary to store stationary objects
previous_centroids = {}

# Set to track saved images for each object id
saved_images_set = set()

time_threshold = 2.0
stationary_threshold = 5.0

while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break

    results = detector(frame)
    detections = results[0].boxes
    det = []

    for box in detections:
        x1, y1, x2, y2 = box.xyxy[0].tolist()
        bbox = [x1, y1, x2, y2]
        det.append(bbox)

    objects = tracker.update(det)  # Use the tracker on the bbox list detected
    current_time = time.time()
    video_time = cap.get(cv2.CAP_PROP_POS_MSEC) / 1000.0
    timestamp = str(timedelta(seconds=int(video_time)))

    for (objectID, centroid) in objects.items():

        if objectID in previous_centroids:
            previous_centroid = previous_centroids[objectID]
            distance = euclidean_distance(centroid, previous_centroid)
        else:
            distance = float('inf')

        previous_centroids[objectID] = centroid

        # Check if the object is stationary
        if distance < stationary_threshold:

            if objectID not in stationary_objects:
                stationary_objects[objectID] = {
                    "stationary_start_time": current_time,
                    "classified": False,
                    "classification": None
                }
            elif (current_time - stationary_objects[objectID]["stationary_start_time"]) >= time_threshold:
                if not stationary_objects[objectID]["classified"]:
                    if objectID < len(det):
                        x1, y1, x2, y2 = det[objectID]
                        roi = frame[int(y1):int(y2), int(x1):int(x2)]
                        roi = Image.fromarray(roi)
                        image = transform(roi)
                        image = torch.unsqueeze(image, 0).to(DEVICE)

                        with torch.no_grad():
                            outputs = model(image)

                        pred_idx = torch.argmax(outputs, dim=1).item()
                        pred_class_name = class_names[pred_idx]
                        prob = torch.softmax(outputs, dim=1)[0, pred_idx].item() * 100
                        label = f'{pred_class_name}, id:{objectID}'

                        cv2.putText(frame, label, (int(x1), int(y1) - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
                        cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)

                        # Mark the object as classified
                        stationary_objects[objectID]["classified"] = True
                        stationary_objects[objectID]["classification"] = pred_class_name

                        if objectID not in saved_images_set:
                            cv2.putText(frame, timestamp, (width - 150, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)
                            image_name = f"{name_video.strip('/').replace('.mp4', '')}_obj_{timestamp}.jpg"
                            cv2.imwrite(os.path.join(result_img_path, image_name), frame)
                            saved_images_set.add(objectID)
        else:
            if objectID in stationary_objects:
                del stationary_objects[objectID]

    out.write(frame)  # Write the frame to the output video

The only problem is that the code does save frames, but it seems to save the first frames in which each bounding box appears. Put differently, the stationarity check is not working correctly, so frames are saved even while the objects in the video are still moving. I think this is a logic problem, but I can't see where it is.
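Two details of the loop above may contribute to this symptom. First, `distance` is the displacement between two consecutive frames, so an object that moves fewer than `stationary_threshold` pixels per frame can count as stationary while it keeps drifting. Second, `time.time()` measures processing time rather than video time, so `time_threshold` may elapse after only a few frames if inference is slow. Below is a minimal sketch of a stricter check, not a definitive fix. It assumes centroids are `(x, y)` pairs and that `video_time` comes from `cap.get(cv2.CAP_PROP_POS_MSEC) / 1000.0`, as in the loop. The class `StationarityMonitor` and its parameters are hypothetical names introduced only for illustration.

```python
import math


class StationarityMonitor:
    """Hypothetical helper: an ID is stationary once its centroid has stayed
    within `radius` pixels of an anchor point for `min_seconds` of video time."""

    def __init__(self, radius=5.0, min_seconds=2.0):
        self.radius = radius
        self.min_seconds = min_seconds
        self._anchors = {}  # object_id -> (anchor_x, anchor_y, anchor_time)

    def update(self, object_id, centroid, video_time):
        """Return True if the object has been stationary long enough."""
        x, y = float(centroid[0]), float(centroid[1])
        anchor = self._anchors.get(object_id)
        if anchor is None or math.hypot(x - anchor[0], y - anchor[1]) > self.radius:
            # First sighting, or the object left the anchor area: restart the clock.
            self._anchors[object_id] = (x, y, video_time)
            return False
        return video_time - anchor[2] >= self.min_seconds

    def forget(self, object_id):
        """Drop the state of an ID that the tracker no longer reports."""
        self._anchors.pop(object_id, None)


monitor = StationarityMonitor(radius=5.0, min_seconds=2.0)

# An object creeping 3 px per frame at 25 fps: every single step is below
# 5 px, but it never stays near one anchor for 2 seconds of video.
moving = [monitor.update(0, (3.0 * i, 100.0), i / 25.0) for i in range(100)]

# An object jittering by 1 px around a fixed point for 4 seconds of video.
parked = [monitor.update(1, (200.0 + (i % 2), 50.0), i / 25.0) for i in range(100)]
```

In the loop, `monitor.update(objectID, centroid, video_time)` could take the place of the `distance < stationary_threshold` test and the `time.time()` comparison. Comparing against an anchor rather than the previous frame is a design choice: it makes slow drift add up until the object leaves the radius.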

The tracker is the following: tracker
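Since the loop iterates over `objects.items()` as ID and centroid pairs, while `det` is ordered by detection within the current frame, `det[objectID]` is not guaranteed to be the box of that object. The sketch below shows one way to recover the box, assuming the tracker exposes only centroids. The helper `box_for_centroid` and its `max_dist` parameter are hypothetical, not part of the tracker.

```python
import math


def box_for_centroid(centroid, boxes, max_dist=50.0):
    """Hypothetical helper: return the [x1, y1, x2, y2] box whose centre is
    closest to `centroid`, or None if no box centre is within `max_dist` px."""
    best_box, best_dist = None, max_dist
    for x1, y1, x2, y2 in boxes:
        cx, cy = (x1 + x2) / 2.0, (y1 + y2) / 2.0
        dist = math.hypot(cx - centroid[0], cy - centroid[1])
        if dist <= best_dist:
            best_box, best_dist = [x1, y1, x2, y2], dist
    return best_box


# Detections in frame order, and tracked centroids keyed by unrelated IDs.
det = [[0, 0, 10, 10], [100, 100, 140, 160]]
objects = {7: (120, 130), 3: (5, 5)}

matched = {oid: box_for_centroid(c, det) for oid, c in objects.items()}
lost = box_for_centroid((500, 500), det)  # no detection near this centroid
```

With this lookup, the `if objectID < len(det)` guard would become a check that the returned box is not `None`. That also skips IDs the tracker still reports although nothing was detected near them in the current frame.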
