Merge pull request #57 from Elm-Forest/main

Fixed some typos in comments
ericmintun committed 2023-04-10 08:49:32 -07:00 (via GitHub)
commit 322eebc06f
8 changed files with 11 additions and 11 deletions

scripts/export_onnx_model.py

@@ -144,7 +144,7 @@ def run_export(
         warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
         warnings.filterwarnings("ignore", category=UserWarning)
         with open(output, "wb") as f:
-            print(f"Exporing onnx model to {output}...")
+            print(f"Exporting onnx model to {output}...")
             torch.onnx.export(
                 onnx_model,
                 tuple(dummy_inputs.values()),

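For context, the corrected print sits inside a standard torch.onnx.export call. A minimal sketch of that pattern, with a toy model and output path standing in for the script's onnx_model and output (both assumptions, not from this diff):

import warnings

import torch

model = torch.nn.Linear(4, 2)    # stand-in for the script's onnx_model
dummy_input = torch.randn(1, 4)  # stand-in for dummy_inputs

with warnings.catch_warnings():
    # Silence tracer noise during export, as the script does
    warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
    warnings.filterwarnings("ignore", category=UserWarning)
    with open("model.onnx", "wb") as f:
        print("Exporting onnx model to model.onnx...")
        torch.onnx.export(
            model,
            (dummy_input,),  # inputs must be passed as a tuple
            f,               # export accepts an open binary file object
        )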
segment_anything/automatic_mask_generator.py

@@ -73,10 +73,10 @@ class SamAutomaticMaskGenerator:
             calculated the stability score.
           box_nms_thresh (float): The box IoU cutoff used by non-maximal
             suppression to filter duplicate masks.
-          crops_n_layers (int): If >0, mask prediction will be run again on
+          crop_n_layers (int): If >0, mask prediction will be run again on
             crops of the image. Sets the number of layers to run, where each
             layer has 2**i_layer number of image crops.
-          crops_nms_thresh (float): The box IoU cutoff used by non-maximal
+          crop_nms_thresh (float): The box IoU cutoff used by non-maximal
             suppression to filter duplicate masks between different crops.
           crop_overlap_ratio (float): Sets the degree to which crops overlap.
             In the first crop layer, crops will overlap by this fraction of

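The renamed names are constructor arguments of SamAutomaticMaskGenerator. A hedged usage sketch; the checkpoint path and placeholder image are assumptions, and the values shown are illustrative:

import numpy as np

from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")  # assumed path
mask_generator = SamAutomaticMaskGenerator(
    sam,
    box_nms_thresh=0.7,             # NMS across masks within one crop
    crop_n_layers=1,                # also predict on one layer of image crops
    crop_nms_thresh=0.7,            # NMS across masks from different crops
    crop_overlap_ratio=512 / 1500,  # fraction by which crops overlap
)

image = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder HxWx3 image
masks = mask_generator.generate(image)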
segment_anything/modeling/image_encoder.py

@ -198,7 +198,7 @@ class Attention(nn.Module):
Args: Args:
dim (int): Number of input channels. dim (int): Number of input channels.
num_heads (int): Number of attention heads. num_heads (int): Number of attention heads.
qkv_bias (bool: If True, add a learnable bias to query, key, value. qkv_bias (bool): If True, add a learnable bias to query, key, value.
rel_pos (bool): If True, add relative positional embeddings to the attention map. rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
input_size (int or None): Input resolution for calculating the relative positional input_size (int or None): Input resolution for calculating the relative positional
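What qkv_bias toggles, in a minimal sketch (the common ViT pattern, not the repo's exact Attention module): the query, key, and value projections are one fused Linear layer whose bias term is controlled by the flag.

import torch
import torch.nn as nn

dim, qkv_bias = 256, True
qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)  # fused q/k/v projection

x = torch.randn(2, 196, dim)       # (batch, tokens, channels)
q, k, v = qkv(x).chunk(3, dim=-1)  # split back into query, key, value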
@@ -270,7 +270,7 @@ def window_unpartition(
     """
     Window unpartition into original sequences and removing padding.
     Args:
-        x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
+        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
         window_size (int): window size.
         pad_hw (Tuple): padded height and width (Hp, Wp).
         hw (Tuple): original height and width (H, W) before padding.

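The renamed windows argument is the output of the matching window_partition helper; window_unpartition inverts it and strips the padding. A hedged round-trip sketch, assuming both helpers are importable from the module:

import torch

from segment_anything.modeling.image_encoder import (
    window_partition,
    window_unpartition,
)

x = torch.randn(1, 50, 50, 256)  # (B, H, W, C) feature map
windows, pad_hw = window_partition(x, window_size=14)  # pads 50 -> 56
restored = window_unpartition(windows, 14, pad_hw, hw=(50, 50))
assert restored.shape == x.shape  # padding removed on the way back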
segment_anything/modeling/sam.py

@@ -85,8 +85,8 @@ class Sam(nn.Module):
           (list(dict)): A list over input images, where each element is
             as dictionary with the following keys.
               'masks': (torch.Tensor) Batched binary mask predictions,
-                with shape BxCxHxW, where B is the number of input promts,
-                C is determiend by multimask_output, and (H, W) is the
+                with shape BxCxHxW, where B is the number of input prompts,
+                C is determined by multimask_output, and (H, W) is the
                 original size of the image.
               'iou_predictions': (torch.Tensor) The model's predictions
                 of mask quality, in shape BxC.

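How the corrected shapes read in practice: each element of the list returned by Sam.forward is a dict keyed as above. A hedged sketch; the checkpoint path, placeholder image tensor, and prompt values are assumptions:

import torch

from segment_anything import sam_model_registry

sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")  # assumed path
batched_input = [
    {
        "image": torch.zeros(3, 1024, 1024),               # transformed CHW image
        "original_size": (480, 640),                       # pre-transform (H, W)
        "point_coords": torch.tensor([[[320.0, 240.0]]]),  # (B, N, 2)
        "point_labels": torch.tensor([[1]]),               # (B, N), 1 = foreground
    }
]
outputs = sam(batched_input, multimask_output=True)
for out in outputs:
    masks = out["masks"]             # (B, C, H, W); B = number of prompts
    scores = out["iou_predictions"]  # (B, C) predicted mask quality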
segment_anything/modeling/transformer.py

@@ -96,7 +96,7 @@ class TwoWayTransformer(nn.Module):
                 key_pe=image_pe,
             )
 
-        # Apply the final attenion layer from the points to the image
+        # Apply the final attention layer from the points to the image
         q = queries + point_embedding
         k = keys + image_pe
         attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)

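The pattern the fixed comment describes, sketched generically (nn.MultiheadAttention stands in for the repo's internal attention block; shapes are assumptions): positional encodings are added to the queries and keys, while the values remain the raw keys.

import torch
import torch.nn as nn

d = 256
final_attn = nn.MultiheadAttention(d, num_heads=8, batch_first=True)

queries = torch.randn(1, 5, d)          # point/token embeddings
keys = torch.randn(1, 4096, d)          # flattened image embeddings
point_embedding = torch.randn(1, 5, d)  # positional encoding for the points
image_pe = torch.randn(1, 4096, d)      # positional encoding for the image

q = queries + point_embedding         # queries carry their positional encoding
k = keys + image_pe                   # keys carry the image positional encoding
attn_out, _ = final_attn(q, k, keys)  # values are the un-encoded keys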
segment_anything/predictor.py

@@ -186,7 +186,7 @@ class SamPredictor:
           point_labels (torch.Tensor or None): A BxN array of labels for the
             point prompts. 1 indicates a foreground point and 0 indicates a
             background point.
-          box (np.ndarray or None): A Bx4 array given a box prompt to the
+          boxes (np.ndarray or None): A Bx4 array given a box prompt to the
             model, in XYXY format.
           mask_input (np.ndarray): A low resolution mask input to the model, typically
             coming from a previous prediction iteration. Has form Bx1xHxW, where

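A hedged sketch of passing the renamed boxes argument to SamPredictor.predict_torch; the checkpoint path and placeholder image are assumptions. Note the boxes must first be mapped to the model's input resolution:

import numpy as np
import torch

from segment_anything import SamPredictor, sam_model_registry

sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")  # assumed path
predictor = SamPredictor(sam)

image = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder HxWx3 image
predictor.set_image(image)

raw_boxes = torch.tensor([[100, 100, 300, 300]], dtype=torch.float)  # Bx4, XYXY
boxes = predictor.transform.apply_boxes_torch(raw_boxes, image.shape[:2])
masks, iou_predictions, low_res_masks = predictor.predict_torch(
    point_coords=None,
    point_labels=None,
    boxes=boxes,
    multimask_output=False,
)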
segment_anything/utils/amg.py

@@ -162,7 +162,7 @@ def calculate_stability_score(
     the predicted mask logits at high and low values.
     """
     # One mask is always contained inside the other.
-    # Save memory by preventing unnecesary cast to torch.int64
+    # Save memory by preventing unnecessary cast to torch.int64
     intersections = (
         (masks > (mask_threshold + threshold_offset))
         .sum(-1, dtype=torch.int16)

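The computation the fixed comment belongs to, sketched from the surrounding context: the stability score is the IoU between the mask binarized at (threshold + offset) and at (threshold - offset), with int16/int32 sums avoiding an int64 intermediate.

import torch

masks = torch.randn(2, 64, 64)  # raw mask logits (placeholder values)
mask_threshold, threshold_offset = 0.0, 1.0

# The high-threshold mask is contained in the low-threshold mask, so
# their areas directly give the intersection and union.
intersections = (
    (masks > (mask_threshold + threshold_offset))
    .sum(-1, dtype=torch.int16)
    .sum(-1, dtype=torch.int32)
)
unions = (
    (masks > (mask_threshold - threshold_offset))
    .sum(-1, dtype=torch.int16)
    .sum(-1, dtype=torch.int32)
)
stability_scores = intersections / unions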
segment_anything/utils/transforms.py

@@ -15,7 +15,7 @@ from typing import Tuple
 
 class ResizeLongestSide:
     """
-    Resizes images to longest side 'target_length', as well as provides
+    Resizes images to the longest side 'target_length', as well as provides
     methods for resizing coordinates and boxes. Provides methods for
     transforming both numpy array and batched torch tensors.
     """