From dec4c129404d8f2a78cfe5c0768920669381f25e Mon Sep 17 00:00:00 2001 From: Eric Mintun Date: Wed, 12 Apr 2023 09:48:07 -0700 Subject: [PATCH] Fix mypy thinking output of SamPredictor.predict is a tensor. --- segment_anything/predictor.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/segment_anything/predictor.py b/segment_anything/predictor.py index 6661d2a..8a6e6d8 100644 --- a/segment_anything/predictor.py +++ b/segment_anything/predictor.py @@ -160,10 +160,10 @@ class SamPredictor: return_logits=return_logits, ) - masks = masks[0].detach().cpu().numpy() - iou_predictions = iou_predictions[0].detach().cpu().numpy() - low_res_masks = low_res_masks[0].detach().cpu().numpy() - return masks, iou_predictions, low_res_masks + masks_np = masks[0].detach().cpu().numpy() + iou_predictions_np = iou_predictions[0].detach().cpu().numpy() + low_res_masks_np = low_res_masks[0].detach().cpu().numpy() + return masks_np, iou_predictions_np, low_res_masks_np @torch.no_grad() def predict_torch(