Move batched NMS indices to correct device (closes #17)
This commit is contained in:
parent
2780a301de
commit
6325eb8512
@ -214,7 +214,7 @@ class SamAutomaticMaskGenerator:
|
||||
keep_by_nms = batched_nms(
|
||||
data["boxes"].float(),
|
||||
scores,
|
||||
torch.zeros(len(data["boxes"])), # categories
|
||||
torch.zeros_like(data["boxes"][:,0]), # categories
|
||||
iou_threshold=self.crop_nms_thresh,
|
||||
)
|
||||
data.filter(keep_by_nms)
|
||||
@ -251,7 +251,7 @@ class SamAutomaticMaskGenerator:
|
||||
keep_by_nms = batched_nms(
|
||||
data["boxes"].float(),
|
||||
data["iou_preds"],
|
||||
torch.zeros(len(data["boxes"])), # categories
|
||||
torch.zeros_like(data["boxes"][:,0]), # categories
|
||||
iou_threshold=self.box_nms_thresh,
|
||||
)
|
||||
data.filter(keep_by_nms)
|
||||
|
Loading…
x
Reference in New Issue
Block a user