From 06bd20da89b5ec4b0f0256b3197596d4cd5c2f8a Mon Sep 17 00:00:00 2001
From: Eric Mintun
Date: Mon, 10 Apr 2023 09:21:25 -0700
Subject: [PATCH] Fix lint.

---
 segment_anything/modeling/image_encoder.py | 8 ++++----
 segment_anything/utils/onnx.py             | 2 +-
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/segment_anything/modeling/image_encoder.py b/segment_anything/modeling/image_encoder.py
index 626623a..755ff4f 100644
--- a/segment_anything/modeling/image_encoder.py
+++ b/segment_anything/modeling/image_encoder.py
@@ -144,8 +144,8 @@ class Block(nn.Module):
             rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
             window_size (int): Window size for window attention blocks. If it equals 0, then
                 use global attention.
-            input_size (tuple(int, int) or None): Input resolution for calculating the relative positional
-                parameter size.
+            input_size (tuple(int, int) or None): Input resolution for calculating the relative
+                positional parameter size.
         """
         super().__init__()
         self.norm1 = norm_layer(dim)
@@ -201,8 +201,8 @@ class Attention(nn.Module):
             qkv_bias (bool): If True, add a learnable bias to query, key, value.
             rel_pos (bool): If True, add relative positional embeddings to the attention map.
             rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
-            input_size (tuple(int, int) or None): Input resolution for calculating the relative positional
-                parameter size.
+            input_size (tuple(int, int) or None): Input resolution for calculating the relative
+                positional parameter size.
         """
         super().__init__()
         self.num_heads = num_heads
diff --git a/segment_anything/utils/onnx.py b/segment_anything/utils/onnx.py
index 493950a..3196bdf 100644
--- a/segment_anything/utils/onnx.py
+++ b/segment_anything/utils/onnx.py
@@ -82,7 +82,7 @@ class SamOnnxModel(nn.Module):
         )
 
         prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64)
-        masks = masks[..., : prepadded_size[0], : prepadded_size[1]]
+        masks = masks[..., : prepadded_size[0], : prepadded_size[1]]  # type: ignore
 
         orig_im_size = orig_im_size.to(torch.int64)
         h, w = orig_im_size[0], orig_im_size[1]