Move batched NMS indices to correct device (closes #17)

This commit is contained in:
Louis Maddox 2023-04-06 11:07:51 +01:00 committed by ericmintun
parent 2780a301de
commit 6325eb8512

View File

@ -214,7 +214,7 @@ class SamAutomaticMaskGenerator:
keep_by_nms = batched_nms( keep_by_nms = batched_nms(
data["boxes"].float(), data["boxes"].float(),
scores, scores,
torch.zeros(len(data["boxes"])), # categories torch.zeros_like(data["boxes"][:,0]), # categories
iou_threshold=self.crop_nms_thresh, iou_threshold=self.crop_nms_thresh,
) )
data.filter(keep_by_nms) data.filter(keep_by_nms)
@ -251,7 +251,7 @@ class SamAutomaticMaskGenerator:
keep_by_nms = batched_nms( keep_by_nms = batched_nms(
data["boxes"].float(), data["boxes"].float(),
data["iou_preds"], data["iou_preds"],
torch.zeros(len(data["boxes"])), # categories torch.zeros_like(data["boxes"][:,0]), # categories
iou_threshold=self.box_nms_thresh, iou_threshold=self.box_nms_thresh,
) )
data.filter(keep_by_nms) data.filter(keep_by_nms)