Fix mypy thinking output of SamPredictor.predict is a tensor.
This commit is contained in:
parent
d4ecc68dea
commit
dec4c12940
@ -160,10 +160,10 @@ class SamPredictor:
|
||||
return_logits=return_logits,
|
||||
)
|
||||
|
||||
masks = masks[0].detach().cpu().numpy()
|
||||
iou_predictions = iou_predictions[0].detach().cpu().numpy()
|
||||
low_res_masks = low_res_masks[0].detach().cpu().numpy()
|
||||
return masks, iou_predictions, low_res_masks
|
||||
masks_np = masks[0].detach().cpu().numpy()
|
||||
iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
|
||||
low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
|
||||
return masks_np, iou_predictions_np, low_res_masks_np
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_torch(
|
||||
|
Loading…
x
Reference in New Issue
Block a user