Initial commit
This commit is contained in:
238
scripts/amg.py
Normal file
238
scripts/amg.py
Normal file
@@ -0,0 +1,238 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import cv2 # type: ignore
|
||||
|
||||
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description=(
|
||||
"Runs automatic mask generation on an input image or directory of images, "
|
||||
"and outputs masks as either PNGs or COCO-style RLEs. Requires open-cv, "
|
||||
"as well as pycocotools if saving in RLE format."
|
||||
)
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--input",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to either a single input image or folder of images.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
required=True,
|
||||
help=(
|
||||
"Path to the directory where masks will be output. Output will be either a folder "
|
||||
"of PNGs per image or a single json with COCO-style masks."
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model-type",
|
||||
type=str,
|
||||
default="default",
|
||||
help="The type of model to load, in ['default', 'vit_l', 'vit_b']",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The path to the SAM checkpoint to use for mask generation.",
|
||||
)
|
||||
|
||||
parser.add_argument("--device", type=str, default="cuda", help="The device to run generation on.")
|
||||
|
||||
parser.add_argument(
|
||||
"--convert-to-rle",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Save masks as COCO RLEs in a single json instead of as a folder of PNGs. "
|
||||
"Requires pycocotools."
|
||||
),
|
||||
)
|
||||
|
||||
amg_settings = parser.add_argument_group("AMG Settings")
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--points-per-side",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Generate masks by sampling a grid over the image with this many points to a side.",
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--points-per-batch",
|
||||
type=int,
|
||||
default=None,
|
||||
help="How many input points to process simultaneously in one batch.",
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--pred-iou-thresh",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Exclude masks with a predicted score from the model that is lower than this threshold.",
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--stability-score-thresh",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Exclude masks with a stability score lower than this threshold.",
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--stability-score-offset",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Larger values perturb the mask more when measuring stability score.",
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--box-nms-thresh",
|
||||
type=float,
|
||||
default=None,
|
||||
help="The overlap threshold for excluding a duplicate mask.",
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--crop-n-layers",
|
||||
type=int,
|
||||
default=None,
|
||||
help=(
|
||||
"If >0, mask generation is run on smaller crops of the image to generate more masks. "
|
||||
"The value sets how many different scales to crop at."
|
||||
),
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--crop-nms-thresh",
|
||||
type=float,
|
||||
default=None,
|
||||
help="The overlap threshold for excluding duplicate masks across different crops.",
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--crop-overlap-ratio",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Larger numbers mean image crops will overlap more.",
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--crop-n-points-downscale-factor",
|
||||
type=int,
|
||||
default=None,
|
||||
help="The number of points-per-side in each layer of crop is reduced by this factor.",
|
||||
)
|
||||
|
||||
amg_settings.add_argument(
|
||||
"--min-mask-region-area",
|
||||
type=int,
|
||||
default=None,
|
||||
help=(
|
||||
"Disconnected mask regions or holes with area smaller than this value "
|
||||
"in pixels are removed by postprocessing."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def write_masks_to_folder(masks: List[Dict[str, Any]], path: str) -> None:
|
||||
header = "id,area,bbox_x0,bbox_y0,bbox_w,bbox_h,point_input_x,point_input_y,predicted_iou,stability_score,crop_box_x0,crop_box_y0,crop_box_w,crop_box_h" # noqa
|
||||
metadata = [header]
|
||||
for i, mask_data in enumerate(masks):
|
||||
mask = mask_data["segmentation"]
|
||||
filename = f"{i}.png"
|
||||
cv2.imwrite(os.path.join(path, filename), mask * 255)
|
||||
mask_metadata = [
|
||||
str(i),
|
||||
str(mask_data["area"]),
|
||||
*[str(x) for x in mask_data["bbox"]],
|
||||
*[str(x) for x in mask_data["point_coords"][0]],
|
||||
str(mask_data["predicted_iou"]),
|
||||
str(mask_data["stability_score"]),
|
||||
*[str(x) for x in mask_data["crop_box"]],
|
||||
]
|
||||
row = ",".join(mask_metadata)
|
||||
metadata.append(row)
|
||||
metadata_path = os.path.join(path, "metadata.csv")
|
||||
with open(metadata_path, "w") as f:
|
||||
f.write("\n".join(metadata))
|
||||
|
||||
return
|
||||
|
||||
|
||||
def get_amg_kwargs(args):
|
||||
amg_kwargs = {
|
||||
"points_per_side": args.points_per_side,
|
||||
"points_per_batch": args.points_per_batch,
|
||||
"pred_iou_thresh": args.pred_iou_thresh,
|
||||
"stability_score_thresh": args.stability_score_thresh,
|
||||
"stability_score_offset": args.stability_score_offset,
|
||||
"box_nms_thresh": args.box_nms_thresh,
|
||||
"crop_n_layers": args.crop_n_layers,
|
||||
"crop_nms_thresh": args.crop_nms_thresh,
|
||||
"crop_overlap_ratio": args.crop_overlap_ratio,
|
||||
"crop_n_points_downscale_factor": args.crop_n_points_downscale_factor,
|
||||
"min_mask_region_area": args.min_mask_region_area,
|
||||
}
|
||||
amg_kwargs = {k: v for k, v in amg_kwargs.items() if v is not None}
|
||||
return amg_kwargs
|
||||
|
||||
|
||||
def main(args: argparse.Namespace) -> None:
|
||||
print("Loading model...")
|
||||
sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
|
||||
_ = sam.to(device=args.device)
|
||||
output_mode = "coco_rle" if args.convert_to_rle else "binary_mask"
|
||||
amg_kwargs = get_amg_kwargs(args)
|
||||
generator = SamAutomaticMaskGenerator(sam, output_mode=output_mode, **amg_kwargs)
|
||||
|
||||
if not os.path.isdir(args.input):
|
||||
targets = [args.input]
|
||||
else:
|
||||
targets = [
|
||||
f for f in os.listdir(args.input) if not os.path.isdir(os.path.join(args.input, f))
|
||||
]
|
||||
targets = [os.path.join(args.input, f) for f in targets]
|
||||
|
||||
os.makedirs(args.output, exist_ok=True)
|
||||
|
||||
for t in targets:
|
||||
print(f"Processing '{t}'...")
|
||||
image = cv2.imread(t)
|
||||
if image is None:
|
||||
print(f"Could not load '{t}' as an image, skipping...")
|
||||
continue
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
masks = generator.generate(image)
|
||||
|
||||
base = os.path.basename(t)
|
||||
base = os.path.splitext(base)[0]
|
||||
save_base = os.path.join(args.output, base)
|
||||
if output_mode == "binary_mask":
|
||||
os.makedirs(save_base, exist_ok=False)
|
||||
write_masks_to_folder(masks, save_base)
|
||||
else:
|
||||
save_file = save_base + ".json"
|
||||
with open(save_file, "w") as f:
|
||||
json.dump(masks, f)
|
||||
print("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
main(args)
|
204
scripts/export_onnx_model.py
Normal file
204
scripts/export_onnx_model.py
Normal file
@@ -0,0 +1,204 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
|
||||
from segment_anything import build_sam, build_sam_vit_b, build_sam_vit_l
|
||||
from segment_anything.utils.onnx import SamOnnxModel
|
||||
|
||||
import argparse
|
||||
import warnings
|
||||
|
||||
try:
|
||||
import onnxruntime # type: ignore
|
||||
|
||||
onnxruntime_exists = True
|
||||
except ImportError:
|
||||
onnxruntime_exists = False
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Export the SAM prompt encoder and mask decoder to an ONNX model."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint", type=str, required=True, help="The path to the SAM model checkpoint."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output", type=str, required=True, help="The filename to save the ONNX model to."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model-type",
|
||||
type=str,
|
||||
default="default",
|
||||
help="In ['default', 'vit_b', 'vit_l']. Which type of SAM model to export.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--return-single-mask",
|
||||
action="store_true",
|
||||
help=(
|
||||
"If true, the exported ONNX model will only return the best mask, "
|
||||
"instead of returning multiple masks. For high resolution images "
|
||||
"this can improve runtime when upscaling masks is expensive."
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--opset",
|
||||
type=int,
|
||||
default=17,
|
||||
help="The ONNX opset version to use. Must be >=11",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--quantize-out",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"If set, will quantize the model and save it with this name. "
|
||||
"Quantization is performed with quantize_dynamic from onnxruntime.quantization.quantize."
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--gelu-approximate",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Replace GELU operations with approximations using tanh. Useful "
|
||||
"for some runtimes that have slow or unimplemented erf ops, used in GELU."
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-stability-score",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Replaces the model's predicted mask quality score with the stability "
|
||||
"score calculated on the low resolution masks using an offset of 1.0. "
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--return-extra-metrics",
|
||||
action="store_true",
|
||||
help=(
|
||||
"The model will return five results: (masks, scores, stability_scores, "
|
||||
"areas, low_res_logits) instead of the usual three. This can be "
|
||||
"significantly slower for high resolution outputs."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def run_export(
|
||||
model_type: str,
|
||||
checkpoint: str,
|
||||
output: str,
|
||||
opset: int,
|
||||
return_single_mask: bool,
|
||||
gelu_approximate: bool = False,
|
||||
use_stability_score: bool = False,
|
||||
return_extra_metrics=False,
|
||||
):
|
||||
print("Loading model...")
|
||||
if model_type == "vit_b":
|
||||
sam = build_sam_vit_b(checkpoint)
|
||||
elif model_type == "vit_l":
|
||||
sam = build_sam_vit_l(checkpoint)
|
||||
else:
|
||||
sam = build_sam(checkpoint)
|
||||
|
||||
onnx_model = SamOnnxModel(
|
||||
model=sam,
|
||||
return_single_mask=return_single_mask,
|
||||
use_stability_score=use_stability_score,
|
||||
return_extra_metrics=return_extra_metrics,
|
||||
)
|
||||
|
||||
if gelu_approximate:
|
||||
for n, m in onnx_model.named_modules():
|
||||
if isinstance(m, torch.nn.GELU):
|
||||
m.approximate = "tanh"
|
||||
|
||||
dynamic_axes = {
|
||||
"point_coords": {1: "num_points"},
|
||||
"point_labels": {1: "num_points"},
|
||||
}
|
||||
|
||||
embed_dim = sam.prompt_encoder.embed_dim
|
||||
embed_size = sam.prompt_encoder.image_embedding_size
|
||||
mask_input_size = [4 * x for x in embed_size]
|
||||
dummy_inputs = {
|
||||
"image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
|
||||
"point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
|
||||
"point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
|
||||
"mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
|
||||
"has_mask_input": torch.tensor([1], dtype=torch.float),
|
||||
"orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
|
||||
}
|
||||
|
||||
_ = onnx_model(**dummy_inputs)
|
||||
|
||||
output_names = ["masks", "iou_predictions", "low_res_masks"]
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
with open(output, "wb") as f:
|
||||
print(f"Exporing onnx model to {output}...")
|
||||
torch.onnx.export(
|
||||
onnx_model,
|
||||
tuple(dummy_inputs.values()),
|
||||
f,
|
||||
export_params=True,
|
||||
verbose=False,
|
||||
opset_version=opset,
|
||||
do_constant_folding=True,
|
||||
input_names=list(dummy_inputs.keys()),
|
||||
output_names=output_names,
|
||||
dynamic_axes=dynamic_axes,
|
||||
)
|
||||
|
||||
if onnxruntime_exists:
|
||||
ort_inputs = {k: to_numpy(v) for k, v in dummy_inputs.items()}
|
||||
ort_session = onnxruntime.InferenceSession(output)
|
||||
_ = ort_session.run(None, ort_inputs)
|
||||
print("Model has successfully been run with ONNXRuntime.")
|
||||
|
||||
|
||||
def to_numpy(tensor):
|
||||
return tensor.cpu().numpy()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
run_export(
|
||||
model_type=args.model_type,
|
||||
checkpoint=args.checkpoint,
|
||||
output=args.output,
|
||||
opset=args.opset,
|
||||
return_single_mask=args.return_single_mask,
|
||||
gelu_approximate=args.gelu_approximate,
|
||||
use_stability_score=args.use_stability_score,
|
||||
return_extra_metrics=args.return_extra_metrics,
|
||||
)
|
||||
|
||||
if args.quantize_out is not None:
|
||||
assert onnxruntime_exists, "onnxruntime is required to quantize the model."
|
||||
from onnxruntime.quantization import QuantType # type: ignore
|
||||
from onnxruntime.quantization.quantize import quantize_dynamic # type: ignore
|
||||
|
||||
print(f"Quantizing model and writing to {args.quantize_out}...")
|
||||
quantize_dynamic(
|
||||
model_input=args.output,
|
||||
model_output=args.quantize_out,
|
||||
optimize_model=True,
|
||||
per_channel=False,
|
||||
reduce_range=False,
|
||||
weight_type=QuantType.QUInt8,
|
||||
)
|
||||
print("Done!")
|
Reference in New Issue
Block a user