Merge branch 'facebookresearch:main' into main
This commit is contained in:
@@ -41,8 +41,8 @@ parser.add_argument(
|
||||
parser.add_argument(
|
||||
"--model-type",
|
||||
type=str,
|
||||
default="default",
|
||||
help="The type of model to load, in ['default', 'vit_l', 'vit_b']",
|
||||
required=True,
|
||||
help="The type of model to load, in ['default', 'vit_h', 'vit_l', 'vit_b']",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
|
@@ -6,7 +6,7 @@
|
||||
|
||||
import torch
|
||||
|
||||
from segment_anything import build_sam, build_sam_vit_b, build_sam_vit_l
|
||||
from segment_anything import sam_model_registry
|
||||
from segment_anything.utils.onnx import SamOnnxModel
|
||||
|
||||
import argparse
|
||||
@@ -34,8 +34,8 @@ parser.add_argument(
|
||||
parser.add_argument(
|
||||
"--model-type",
|
||||
type=str,
|
||||
default="default",
|
||||
help="In ['default', 'vit_b', 'vit_l']. Which type of SAM model to export.",
|
||||
required=True,
|
||||
help="In ['default', 'vit_h', 'vit_l', 'vit_b']. Which type of SAM model to export.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@@ -105,12 +105,7 @@ def run_export(
|
||||
return_extra_metrics=False,
|
||||
):
|
||||
print("Loading model...")
|
||||
if model_type == "vit_b":
|
||||
sam = build_sam_vit_b(checkpoint)
|
||||
elif model_type == "vit_l":
|
||||
sam = build_sam_vit_l(checkpoint)
|
||||
else:
|
||||
sam = build_sam(checkpoint)
|
||||
sam = sam_model_registry[model_type](checkpoint=checkpoint)
|
||||
|
||||
onnx_model = SamOnnxModel(
|
||||
model=sam,
|
||||
|
Reference in New Issue
Block a user