update notebooks to specify model_type explicitly

This commit is contained in:
Hanzi Mao 2023-04-06 21:15:43 -07:00
parent c3b8a88a7b
commit 9e1eb9fdbc
3 changed files with 14 additions and 29 deletions

View File

@ -214,19 +214,6 @@
"To run automatic mask generation, provide a SAM model to the `SamAutomaticMaskGenerator` class. Set the path below to the SAM checkpoint. Running on CUDA and with the default model is recommended." "To run automatic mask generation, provide a SAM model to the `SamAutomaticMaskGenerator` class. Set the path below to the SAM checkpoint. Running on CUDA and with the default model is recommended."
] ]
}, },
{
"cell_type": "code",
"execution_count": 9,
"id": "17ade22d",
"metadata": {},
"outputs": [],
"source": [
"sam_checkpoint = \"sam_vit_h_4b8939.pth\"\n",
"\n",
"device = \"cuda\"\n",
"model_type = \"default\""
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 10,
@ -238,6 +225,11 @@
"sys.path.append(\"..\")\n", "sys.path.append(\"..\")\n",
"from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor\n", "from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor\n",
"\n", "\n",
"sam_checkpoint = \"sam_vit_h_4b8939.pth\"\n",
"model_type = \"vit_h\"\n",
"\n",
"device = \"cuda\"\n",
"\n",
"sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)\n", "sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)\n",
"sam.to(device=device)\n", "sam.to(device=device)\n",
"\n", "\n",
@ -446,7 +438,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.10" "version": "3.8.0"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@ -192,7 +192,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"checkpoint = \"sam_vit_h_4b8939.pth\"\n", "checkpoint = \"sam_vit_h_4b8939.pth\"\n",
"model_type = \"default\"" "model_type = \"vit_h\""
] ]
}, },
{ {
@ -766,7 +766,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.10" "version": "3.8.0"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@ -229,18 +229,6 @@
"First, load the SAM model and predictor. Change the path below to point to the SAM checkpoint. Running on CUDA and using the default model are recommended for best results." "First, load the SAM model and predictor. Change the path below to point to the SAM checkpoint. Running on CUDA and using the default model are recommended for best results."
] ]
}, },
{
"cell_type": "code",
"execution_count": 9,
"id": "17ccff22",
"metadata": {},
"outputs": [],
"source": [
"sam_checkpoint = \"sam_vit_h_4b8939.pth\"\n",
"device = \"cuda\"\n",
"model_type = \"default\""
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 10,
@ -252,6 +240,11 @@
"sys.path.append(\"..\")\n", "sys.path.append(\"..\")\n",
"from segment_anything import sam_model_registry, SamPredictor\n", "from segment_anything import sam_model_registry, SamPredictor\n",
"\n", "\n",
"sam_checkpoint = \"sam_vit_h_4b8939.pth\"\n",
"model_type = \"vit_h\"\n",
"\n",
"device = \"cuda\"\n",
"\n",
"sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)\n", "sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)\n",
"sam.to(device=device)\n", "sam.to(device=device)\n",
"\n", "\n",
@ -1015,7 +1008,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.10" "version": "3.8.0"
} }
}, },
"nbformat": 4, "nbformat": 4,