update notebooks to specify model_type explicitly

This commit is contained in:
Hanzi Mao 2023-04-06 21:15:43 -07:00
parent c3b8a88a7b
commit 9e1eb9fdbc
3 changed files with 14 additions and 29 deletions

View File

@ -214,19 +214,6 @@
"To run automatic mask generation, provide a SAM model to the `SamAutomaticMaskGenerator` class. Set the path below to the SAM checkpoint. Running on CUDA and with the default model is recommended."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "17ade22d",
"metadata": {},
"outputs": [],
"source": [
"sam_checkpoint = \"sam_vit_h_4b8939.pth\"\n",
"\n",
"device = \"cuda\"\n",
"model_type = \"default\""
]
},
{
"cell_type": "code",
"execution_count": 10,
@ -238,6 +225,11 @@
"sys.path.append(\"..\")\n",
"from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor\n",
"\n",
"sam_checkpoint = \"sam_vit_h_4b8939.pth\"\n",
"model_type = \"vit_h\"\n",
"\n",
"device = \"cuda\"\n",
"\n",
"sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)\n",
"sam.to(device=device)\n",
"\n",
@ -446,7 +438,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
"version": "3.8.0"
}
},
"nbformat": 4,

View File

@ -192,7 +192,7 @@
"outputs": [],
"source": [
"checkpoint = \"sam_vit_h_4b8939.pth\"\n",
"model_type = \"default\""
"model_type = \"vit_h\""
]
},
{
@ -766,7 +766,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
"version": "3.8.0"
}
},
"nbformat": 4,

View File

@ -229,18 +229,6 @@
"First, load the SAM model and predictor. Change the path below to point to the SAM checkpoint. Running on CUDA and using the default model are recommended for best results."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "17ccff22",
"metadata": {},
"outputs": [],
"source": [
"sam_checkpoint = \"sam_vit_h_4b8939.pth\"\n",
"device = \"cuda\"\n",
"model_type = \"default\""
]
},
{
"cell_type": "code",
"execution_count": 10,
@ -252,6 +240,11 @@
"sys.path.append(\"..\")\n",
"from segment_anything import sam_model_registry, SamPredictor\n",
"\n",
"sam_checkpoint = \"sam_vit_h_4b8939.pth\"\n",
"model_type = \"vit_h\"\n",
"\n",
"device = \"cuda\"\n",
"\n",
"sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)\n",
"sam.to(device=device)\n",
"\n",
@ -1015,7 +1008,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
"version": "3.8.0"
}
},
"nbformat": 4,