Fix mypy thinking output of SamPredictor.predict is a tensor.
This commit is contained in:
parent
d4ecc68dea
commit
dec4c12940
@ -160,10 +160,10 @@ class SamPredictor:
|
|||||||
return_logits=return_logits,
|
return_logits=return_logits,
|
||||||
)
|
)
|
||||||
|
|
||||||
masks = masks[0].detach().cpu().numpy()
|
masks_np = masks[0].detach().cpu().numpy()
|
||||||
iou_predictions = iou_predictions[0].detach().cpu().numpy()
|
iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
|
||||||
low_res_masks = low_res_masks[0].detach().cpu().numpy()
|
low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
|
||||||
return masks, iou_predictions, low_res_masks
|
return masks_np, iou_predictions_np, low_res_masks_np
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def predict_torch(
|
def predict_torch(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user