Update onnx.py

Fix break in computation graph due to typecast to python `int` instead of torch type.
This commit is contained in:
William Woof
2023-04-06 18:13:24 +01:00
committed by GitHub
parent f2557f7780
commit e9f2d58094

View File

@@ -81,8 +81,8 @@ class SamOnnxModel(nn.Module):
align_corners=False,
)
prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size)
masks = masks[..., : int(prepadded_size[0]), : int(prepadded_size[1])]
prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64)
masks = masks[..., : prepadded_size[0], : prepadded_size[1]]
orig_im_size = orig_im_size.to(torch.int64)
h, w = orig_im_size[0], orig_im_size[1]