diff --git a/README.md b/README.md
index 6faa840..410ee9a 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@
 [Alexander Kirillov](https://alexander-kirillov.github.io/), [Eric Mintun](https://ericmintun.github.io/), [Nikhila Ravi](https://nikhilaravi.com/), [Hanzi Mao](https://hanzimao.me/), Chloe Rolland, Laura Gustafson, [Tete Xiao](https://tetexiao.com), [Spencer Whitehead](https://www.spencerwhitehead.com/), Alex Berg, Wan-Yen Lo, [Piotr Dollar](https://pdollar.github.io/), [Ross Girshick](https://www.rossgirshick.info/)
 
-[[`Paper`](https://ai.facebook.com/research/publications/segment-anything/)] [[`Project`](https://segment-anything.com/)] [[`Demo`](https://segment-anything.com/demo)] [[`Dataset`](https://segment-anything.com/dataset/index.html)] [[`Blog`](https://ai.facebook.com/blog/segment-anything-foundation-model-image-segmentation/)]
+[[`Paper`](https://ai.facebook.com/research/publications/segment-anything/)] [[`Project`](https://segment-anything.com/)] [[`Demo`](https://segment-anything.com/demo)] [[`Dataset`](https://segment-anything.com/dataset/index.html)] [[`Blog`](https://ai.facebook.com/blog/segment-anything-foundation-model-image-segmentation/)] [[`BibTeX`](#citing-segment-anything)]
 
 ![SAM design](assets/model_diagram.png?raw=true)
@@ -43,8 +43,9 @@ pip install opencv-python pycocotools matplotlib onnxruntime onnx
 First download a [model checkpoint](#model-checkpoints). Then the model can be used in just a few lines to get masks from a given prompt:
 
 ```
-from segment_anything import build_sam, SamPredictor
-predictor = SamPredictor(build_sam(checkpoint="<path/to/checkpoint>"))
+from segment_anything import SamPredictor, sam_model_registry
+sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>")
+predictor = SamPredictor(sam)
 predictor.set_image(<your_image>)
 masks, _, _ = predictor.predict(<input_prompts>)
 ```
@@ -52,15 +53,16 @@ masks, _, _ = predictor.predict(<input_prompts>)
 
 or generate masks for an entire image:
 ```
-from segment_anything import build_sam, SamAutomaticMaskGenerator
-mask_generator = SamAutomaticMaskGenerator(build_sam(checkpoint="<path/to/checkpoint>"))
+from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
+sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>")
+mask_generator = SamAutomaticMaskGenerator(sam)
 masks = mask_generator.generate(<your_image>)
 ```
 
 Additionally, masks can be generated for images from the command line:
 
 ```
-python scripts/amg.py --checkpoint <path/to/checkpoint> --input <image_or_folder> --output <path/to/output>
+python scripts/amg.py --checkpoint <path/to/checkpoint> --model-type <model_type> --input <image_or_folder> --output <path/to/output>
 ```
 
 See the examples notebooks on [using SAM with prompts](/notebooks/predictor_example.ipynb) and [automatically generating masks](/notebooks/automatic_mask_generator_example.ipynb) for more details.
@@ -75,7 +77,7 @@ See the examples notebooks on [using SAM with prompts](/notebooks/predictor_exam
 SAM's lightweight mask decoder can be exported to ONNX format so that it can be run in any environment that supports ONNX runtime, such as in-browser as showcased in the [demo](https://segment-anything.com/demo). Export the model with
 
 ```
-python scripts/export_onnx_model.py --checkpoint <path/to/checkpoint> --output <path/to/output>
+python scripts/export_onnx_model.py --checkpoint <path/to/checkpoint> --model-type <model_type> --output <path/to/output>
 ```
 
 See the [example notebook](https://github.com/facebookresearch/segment-anything/blob/main/notebooks/onnx_model_example.ipynb) for details on how to combine image preprocessing via SAM's backbone with mask prediction using the ONNX model. It is recommended to use the latest stable version of PyTorch for ONNX export.
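(Editor's aside, not part of the patch: the Getting Started hunks above leave the prompts as placeholders. Below is a minimal runnable sketch of the new `sam_model_registry` + `SamPredictor` flow, assuming the ViT-H checkpoint `sam_vit_h_4b8939.pth` has been downloaded, `truck.jpg` is any local image, and the point coordinates are arbitrary.)

```python
import cv2
import numpy as np

from segment_anything import SamPredictor, sam_model_registry

# Assumptions: checkpoint downloaded to the working directory, CUDA available
# (drop the .to() call to stay on CPU), and "truck.jpg" is a stand-in image.
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam.to(device="cuda")

image = cv2.cvtColor(cv2.imread("truck.jpg"), cv2.COLOR_BGR2RGB)

predictor = SamPredictor(sam)
predictor.set_image(image)

# One foreground point prompt (label 1 = foreground, 0 = background).
masks, scores, low_res_logits = predictor.predict(
    point_coords=np.array([[500, 375]]),
    point_labels=np.array([1]),
    multimask_output=True,
)
print(masks.shape)  # (3, H, W) boolean masks, one per candidate mask
```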
@@ -85,14 +87,55 @@ See the [example notebook](https://github.com/facebookresearch/segment-anything/
 Three model versions of the model are available with different backbone sizes. These models can be instantiated by running
 
 ```
 from segment_anything import sam_model_registry
-sam = sam_model_registry["<model_name>"](checkpoint="<path/to/checkpoint>")
+sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>")
 ```
-Click the links below to download the checkpoint for the corresponding model name. The default model in bold can also be instantiated with `build_sam`, as in the examples in [Getting Started](#getting-started).
+Click the links below to download the checkpoint for the corresponding model type.
 
 * **`default` or `vit_h`: [ViT-H SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth)**
 * `vit_l`: [ViT-L SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth)
 * `vit_b`: [ViT-B SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth)
 
+## Dataset
+
+See [here](https://ai.facebook.com/datasets/segment-anything/) for an overview of the dataset. The dataset can be downloaded [here](https://ai.facebook.com/datasets/segment-anything-downloads/). By downloading the dataset you agree that you have read and accepted the terms of the SA-1B Dataset Research License.
+
+We save masks per image as a JSON file. It can be loaded as a dictionary in Python in the format below.
+
+```python
+{
+    "image"           : image_info,
+    "annotations"     : [annotation],
+}
+
+image_info {
+    "image_id"        : int,          # Image id
+    "width"           : int,          # Image width
+    "height"          : int,          # Image height
+    "file_name"       : str,          # Image filename
+}
+
+annotation {
+    "id"              : int,          # Annotation id
+    "segmentation"    : dict,         # Mask saved in COCO RLE format.
+    "bbox"            : [x, y, w, h], # The box around the mask, in XYWH format
+    "area"            : int,          # The area in pixels of the mask
+    "predicted_iou"   : float,        # The model's own prediction of the mask's quality
+    "stability_score" : float,        # A measure of the mask's quality
+    "crop_box"        : [x, y, w, h], # The crop of the image used to generate the mask, in XYWH format
+    "point_coords"    : [[x, y]],     # The point coordinates input to the model to generate the mask
+}
+```
+
+Image ids can be found in sa_images_ids.txt, which can also be downloaded from the above [link](https://ai.facebook.com/datasets/segment-anything-downloads/).
+
+To decode a mask in COCO RLE format into binary:
+```
+from pycocotools import mask as mask_utils
+mask = mask_utils.decode(annotation["segmentation"])
+```
+See [here](https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/mask.py) for more instructions to manipulate masks stored in RLE format.
+
 ## License
 
 The model is licensed under the [Apache 2.0 license](LICENSE).
@@ -105,3 +148,16 @@ See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md
 The Segment Anything project was made possible with the help of many contributors (alphabetical):
 
 Aaron Adcock, Vaibhav Aggarwal, Morteza Behrooz, Cheng-Yang Fu, Ashley Gabriel, Ahuva Goldstand, Allen Goodman, Sumanth Gurram, Jiabo Hu, Somya Jain, Devansh Kukreja, Robert Kuo, Joshua Lane, Yanghao Li, Lilian Luong, Jitendra Malik, Mallika Malhotra, William Ngan, Omkar Parkhi, Nikhil Raina, Dirk Rowe, Neil Sejoor, Vanessa Stark, Bala Varadarajan, Bram Wasti, Zachary Winstrom
+
+## Citing Segment Anything
+
+If you use SAM or SA-1B in your research, please use the following BibTeX entry.
+
+```
+@article{kirillov2023segany,
+  title={Segment Anything},
+  author={Kirillov, Alexander and Mintun, Eric and Ravi, Nikhila and Mao, Hanzi and Rolland, Chloe and Gustafson, Laura and Xiao, Tete and Whitehead, Spencer and Berg, Alexander C. and Lo, Wan-Yen and Doll{\'a}r, Piotr and Girshick, Ross},
+  journal={arXiv:2304.02643},
+  year={2023}
+}
+```
diff --git a/notebooks/automatic_mask_generator_example.ipynb b/notebooks/automatic_mask_generator_example.ipynb
index 261323d..946d45e 100644
--- a/notebooks/automatic_mask_generator_example.ipynb
+++ b/notebooks/automatic_mask_generator_example.ipynb
@@ -214,19 +214,6 @@
    "To run automatic mask generation, provide a SAM model to the `SamAutomaticMaskGenerator` class. Set the path below to the SAM checkpoint. Running on CUDA and with the default model is recommended."
   ]
  },
- {
-  "cell_type": "code",
-  "execution_count": 9,
-  "id": "17ade22d",
-  "metadata": {},
-  "outputs": [],
-  "source": [
-   "sam_checkpoint = \"sam_vit_h_4b8939.pth\"\n",
-   "\n",
-   "device = \"cuda\"\n",
-   "model_type = \"default\""
-  ]
- },
  {
   "cell_type": "code",
   "execution_count": 10,
@@ -238,6 +225,11 @@
    "sys.path.append(\"..\")\n",
    "from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor\n",
    "\n",
+   "sam_checkpoint = \"sam_vit_h_4b8939.pth\"\n",
+   "model_type = \"vit_h\"\n",
+   "\n",
+   "device = \"cuda\"\n",
+   "\n",
    "sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)\n",
    "sam.to(device=device)\n",
    "\n",
@@ -446,7 +438,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.10"
+   "version": "3.8.0"
   }
  },
  "nbformat": 4,
diff --git a/notebooks/onnx_model_example.ipynb b/notebooks/onnx_model_example.ipynb
index 155dd27..b76b4f4 100644
--- a/notebooks/onnx_model_example.ipynb
+++ b/notebooks/onnx_model_example.ipynb
@@ -192,7 +192,7 @@
   "outputs": [],
   "source": [
    "checkpoint = \"sam_vit_h_4b8939.pth\"\n",
-   "model_type = \"default\""
+   "model_type = \"vit_h\""
   ]
  },
  {
@@ -766,7 +766,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.10"
+   "version": "3.8.0"
   }
  },
  "nbformat": 4,
diff --git a/notebooks/predictor_example.ipynb b/notebooks/predictor_example.ipynb
index 8374c4d..239d336 100644
--- a/notebooks/predictor_example.ipynb
+++ b/notebooks/predictor_example.ipynb
@@ -229,18 +229,6 @@
    "First, load the SAM model and predictor. Change the path below to point to the SAM checkpoint. Running on CUDA and using the default model are recommended for best results."
   ]
  },
- {
-  "cell_type": "code",
-  "execution_count": 9,
-  "id": "17ccff22",
-  "metadata": {},
-  "outputs": [],
-  "source": [
-   "sam_checkpoint = \"sam_vit_h_4b8939.pth\"\n",
-   "device = \"cuda\"\n",
-   "model_type = \"default\""
-  ]
- },
  {
   "cell_type": "code",
   "execution_count": 10,
@@ -252,6 +240,11 @@
    "sys.path.append(\"..\")\n",
    "from segment_anything import sam_model_registry, SamPredictor\n",
    "\n",
+   "sam_checkpoint = \"sam_vit_h_4b8939.pth\"\n",
+   "model_type = \"vit_h\"\n",
+   "\n",
+   "device = \"cuda\"\n",
+   "\n",
    "sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)\n",
    "sam.to(device=device)\n",
    "\n",
@@ -1015,7 +1008,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.10"
+   "version": "3.8.0"
   }
  },
  "nbformat": 4,
diff --git a/scripts/amg.py b/scripts/amg.py
index 3cae6ff..f2dbf67 100644
--- a/scripts/amg.py
+++ b/scripts/amg.py
@@ -41,8 +41,8 @@ parser.add_argument(
 parser.add_argument(
     "--model-type",
     type=str,
-    default="default",
-    help="The type of model to load, in ['default', 'vit_l', 'vit_b']",
+    required=True,
+    help="The type of model to load, in ['default', 'vit_h', 'vit_l', 'vit_b']",
 )
 
 parser.add_argument(
diff --git a/scripts/export_onnx_model.py b/scripts/export_onnx_model.py
index 15d51f0..a109722 100644
--- a/scripts/export_onnx_model.py
+++ b/scripts/export_onnx_model.py
@@ -6,7 +6,7 @@
 
 import torch
 
-from segment_anything import build_sam, build_sam_vit_b, build_sam_vit_l
+from segment_anything import sam_model_registry
 from segment_anything.utils.onnx import SamOnnxModel
 
 import argparse
@@ -34,8 +34,8 @@ parser.add_argument(
 parser.add_argument(
     "--model-type",
     type=str,
-    default="default",
-    help="In ['default', 'vit_b', 'vit_l']. Which type of SAM model to export.",
+    required=True,
+    help="In ['default', 'vit_h', 'vit_l', 'vit_b']. Which type of SAM model to export.",
 )
 
 parser.add_argument(
@@ -105,12 +105,7 @@ def run_export(
     return_extra_metrics=False,
 ):
     print("Loading model...")
-    if model_type == "vit_b":
-        sam = build_sam_vit_b(checkpoint)
-    elif model_type == "vit_l":
-        sam = build_sam_vit_l(checkpoint)
-    else:
-        sam = build_sam(checkpoint)
+    sam = sam_model_registry[model_type](checkpoint=checkpoint)
 
     onnx_model = SamOnnxModel(
         model=sam,
diff --git a/segment_anything/build_sam.py b/segment_anything/build_sam.py
index 07abfca..37cd245 100644
--- a/segment_anything/build_sam.py
+++ b/segment_anything/build_sam.py
@@ -45,8 +45,8 @@ def build_sam_vit_b(checkpoint=None):
 
 
 sam_model_registry = {
-    "default": build_sam,
-    "vit_h": build_sam,
+    "default": build_sam_vit_h,
+    "vit_h": build_sam_vit_h,
     "vit_l": build_sam_vit_l,
     "vit_b": build_sam_vit_b,
 }
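(Editor's aside, not part of the patch: the README's new Dataset section documents one JSON of annotations per image, with masks stored in COCO RLE. Below is a small sketch of loading and decoding one such file; `sa_1.json` is a hypothetical filename, and the keys follow the schema shown in the diff.)

```python
import json

from pycocotools import mask as mask_utils

# Hypothetical filename; substitute any annotation file from the SA-1B download.
with open("sa_1.json") as f:
    data = json.load(f)

info = data["image"]
print(info["image_id"], info["file_name"], info["width"], info["height"])

for ann in data["annotations"]:
    # Decode COCO RLE into an (H, W) uint8 array of 0s and 1s.
    binary_mask = mask_utils.decode(ann["segmentation"])
    print(ann["id"], ann["area"], round(ann["predicted_iou"], 3), binary_mask.shape)
```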
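(Editor's aside: with `--model-type` now required by `scripts/amg.py`, a hedged sketch of driving the same `SamAutomaticMaskGenerator` directly from Python follows; the checkpoint is assumed to be the downloaded ViT-H weights and `dog.jpg` is a stand-in image.)

```python
import cv2

from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam.to(device="cuda")  # optional; omit to run on CPU

image = cv2.cvtColor(cv2.imread("dog.jpg"), cv2.COLOR_BGR2RGB)

mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(image)

# Each entry is a dict whose keys mirror the per-annotation fields documented in
# the Dataset section: segmentation, bbox, area, predicted_iou, stability_score, ...
print(len(masks), sorted(masks[0].keys()))
```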