This commit is contained in:
Eric Mintun 2023-04-10 12:02:02 -07:00
parent b028d54358
commit 7fa17d78c4

View File

@ -214,7 +214,7 @@ class SamAutomaticMaskGenerator:
keep_by_nms = batched_nms(
data["boxes"].float(),
scores,
torch.zeros_like(data["boxes"][:,0]), # categories
torch.zeros_like(data["boxes"][:, 0]), # categories
iou_threshold=self.crop_nms_thresh,
)
data.filter(keep_by_nms)
@ -251,7 +251,7 @@ class SamAutomaticMaskGenerator:
keep_by_nms = batched_nms(
data["boxes"].float(),
data["iou_preds"],
torch.zeros_like(data["boxes"][:,0]), # categories
torch.zeros_like(data["boxes"][:, 0]), # categories
iou_threshold=self.box_nms_thresh,
)
data.filter(keep_by_nms)
@ -357,7 +357,7 @@ class SamAutomaticMaskGenerator:
keep_by_nms = batched_nms(
boxes.float(),
torch.as_tensor(scores),
torch.zeros_like(boxes[:,0]), # categories
torch.zeros_like(boxes[:, 0]), # categories
iou_threshold=nms_thresh,
)