Merge branch 'facebookresearch:main' into main

This commit is contained in:
Elm Forest
2023-04-09 19:16:56 +08:00
committed by GitHub
7 changed files with 87 additions and 51 deletions

View File

@@ -41,8 +41,8 @@ parser.add_argument(
parser.add_argument(
"--model-type",
type=str,
default="default",
help="The type of model to load, in ['default', 'vit_l', 'vit_b']",
required=True,
help="The type of model to load, in ['default', 'vit_h', 'vit_l', 'vit_b']",
)
parser.add_argument(

View File

@@ -6,7 +6,7 @@
import torch
from segment_anything import build_sam, build_sam_vit_b, build_sam_vit_l
from segment_anything import sam_model_registry
from segment_anything.utils.onnx import SamOnnxModel
import argparse
@@ -34,8 +34,8 @@ parser.add_argument(
parser.add_argument(
"--model-type",
type=str,
default="default",
help="In ['default', 'vit_b', 'vit_l']. Which type of SAM model to export.",
required=True,
help="In ['default', 'vit_h', 'vit_l', 'vit_b']. Which type of SAM model to export.",
)
parser.add_argument(
@@ -105,12 +105,7 @@ def run_export(
return_extra_metrics=False,
):
print("Loading model...")
if model_type == "vit_b":
sam = build_sam_vit_b(checkpoint)
elif model_type == "vit_l":
sam = build_sam_vit_l(checkpoint)
else:
sam = build_sam(checkpoint)
sam = sam_model_registry[model_type](checkpoint=checkpoint)
onnx_model = SamOnnxModel(
model=sam,