diff --git a/segment_anything/utils/transforms.py b/segment_anything/utils/transforms.py index 97a682a..c08ba1e 100644 --- a/segment_anything/utils/transforms.py +++ b/segment_anything/utils/transforms.py @@ -59,7 +59,7 @@ class ResizeLongestSide: the transformation expected by the model. """ # Expects an image in BCHW format. May not exactly match apply_image. - target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) + target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) return F.interpolate( image, target_size, mode="bilinear", align_corners=False, antialias=True )