diff --git a/segment_anything/predictor.py b/segment_anything/predictor.py index 6661d2a..8a6e6d8 100644 --- a/segment_anything/predictor.py +++ b/segment_anything/predictor.py @@ -160,10 +160,10 @@ class SamPredictor: return_logits=return_logits, ) - masks = masks[0].detach().cpu().numpy() - iou_predictions = iou_predictions[0].detach().cpu().numpy() - low_res_masks = low_res_masks[0].detach().cpu().numpy() - return masks, iou_predictions, low_res_masks + masks_np = masks[0].detach().cpu().numpy() + iou_predictions_np = iou_predictions[0].detach().cpu().numpy() + low_res_masks_np = low_res_masks[0].detach().cpu().numpy() + return masks_np, iou_predictions_np, low_res_masks_np @torch.no_grad() def predict_torch(