From 6325eb851263b93dfa7698c154922478dd27902b Mon Sep 17 00:00:00 2001 From: Louis Maddox Date: Thu, 6 Apr 2023 11:07:51 +0100 Subject: [PATCH] Move batched NMS indices to correct device (closes #17) --- segment_anything/automatic_mask_generator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/segment_anything/automatic_mask_generator.py b/segment_anything/automatic_mask_generator.py index da944ed..9c04d76 100644 --- a/segment_anything/automatic_mask_generator.py +++ b/segment_anything/automatic_mask_generator.py @@ -214,7 +214,7 @@ class SamAutomaticMaskGenerator: keep_by_nms = batched_nms( data["boxes"].float(), scores, - torch.zeros(len(data["boxes"])), # categories + torch.zeros_like(data["boxes"][:,0]), # categories iou_threshold=self.crop_nms_thresh, ) data.filter(keep_by_nms) @@ -251,7 +251,7 @@ class SamAutomaticMaskGenerator: keep_by_nms = batched_nms( data["boxes"].float(), data["iou_preds"], - torch.zeros(len(data["boxes"])), # categories + torch.zeros_like(data["boxes"][:,0]), # categories iou_threshold=self.box_nms_thresh, ) data.filter(keep_by_nms)