Fix mypy thinking output of SamPredictor.predict is a tensor.

This commit is contained in:
Eric Mintun 2023-04-12 09:48:07 -07:00
parent d4ecc68dea
commit dec4c12940

View File

@ -160,10 +160,10 @@ class SamPredictor:
return_logits=return_logits, return_logits=return_logits,
) )
masks = masks[0].detach().cpu().numpy() masks_np = masks[0].detach().cpu().numpy()
iou_predictions = iou_predictions[0].detach().cpu().numpy() iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
low_res_masks = low_res_masks[0].detach().cpu().numpy() low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
return masks, iou_predictions, low_res_masks return masks_np, iou_predictions_np, low_res_masks_np
@torch.no_grad() @torch.no_grad()
def predict_torch( def predict_torch(