Add mini web demo
This commit is contained in:
		
							
								
								
									
										5
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										5
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							@@ -35,3 +35,8 @@ model_ts*.txt
 | 
				
			|||||||
.idea
 | 
					.idea
 | 
				
			||||||
.vscode
 | 
					.vscode
 | 
				
			||||||
_darcs
 | 
					_darcs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# demo 
 | 
				
			||||||
 | 
					**/node_modules
 | 
				
			||||||
 | 
					yarn.lock
 | 
				
			||||||
 | 
					package-lock.json
 | 
				
			||||||
							
								
								
									
										
											BIN
										
									
								
								assets/minidemo.gif
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								assets/minidemo.gif
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 1.9 MiB  | 
							
								
								
									
										95
									
								
								demo/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										95
									
								
								demo/README.md
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,95 @@
 | 
				
			|||||||
 | 
					## Segment Anything Simple Web demo
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					This **front-end only** demo shows how to load a fixed image and a `.npy` file of the SAM image embedding, and run the SAM ONNX model in the browser using WebAssembly with multithreading enabled by `SharedArrayBuffer`, Web Worker, and SIMD128.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					<img src="https://github.com/facebookresearch/segment-anything/raw/main/assets/minidemo.gif" width="500"/>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## Run the app
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					yarn && yarn start
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Navigate to [`http://localhost:8081/`](http://localhost:8081/)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Move your cursor around to see the mask prediction update in real time.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## Change the image, embedding and ONNX model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					In the [ONNX Model Example notebook](https://github.com/facebookresearch/segment-anything/blob/main/notebooks/onnx_model_example.ipynb), upload the image of your choice, then generate and save the corresponding embedding.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Initialize the predictor
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```python
 | 
				
			||||||
 | 
					checkpoint = "sam_vit_h_4b8939.pth"
 | 
				
			||||||
 | 
					model_type = "vit_h"
 | 
				
			||||||
 | 
					sam = sam_model_registry[model_type](checkpoint=checkpoint)
 | 
				
			||||||
 | 
					sam.to(device='cuda')
 | 
				
			||||||
 | 
					predictor = SamPredictor(sam)
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Set the new image and export the embedding
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```python
 | 
				
			||||||
 | 
					image = cv2.imread('src/assets/dogs.jpg')
 | 
				
			||||||
 | 
					predictor.set_image(image)
 | 
				
			||||||
 | 
					image_embedding = predictor.get_image_embedding().cpu().numpy()
 | 
				
			||||||
 | 
					np.save("dogs_embedding.npy", image_embedding)
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Save the new image and embedding in `/assets/data` and update the following paths to the files at the top of `App.tsx`:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```js
 | 
				
			||||||
 | 
					const IMAGE_PATH = "/assets/data/dogs.jpg";
 | 
				
			||||||
 | 
					const IMAGE_EMBEDDING = "/assets/data/dogs_embedding.npy";
 | 
				
			||||||
 | 
					const MODEL_DIR = "/model/sam_onnx_quantized_example.onnx";
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Optionally you can also export the ONNX model. Currently the example ONNX model from the notebook is saved at `/model/sam_onnx_quantized_example.onnx`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**NOTE: if you change the ONNX model by using a new checkpoint you need to also re-export the embedding.**
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## ONNX multithreading with SharedArrayBuffer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					To use multithreading, the appropriate headers need to be set to create a cross origin isolation state which will enable use of `SharedArrayBuffer` (see this [blog post](https://cloudblogs.microsoft.com/opensource/2021/09/02/onnx-runtime-web-running-your-machine-learning-model-in-browser/) for more details).
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					The headers below are set in `configs/webpack/dev.js`:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```js
 | 
				
			||||||
 | 
					headers: {
 | 
				
			||||||
 | 
					    "Cross-Origin-Opener-Policy": "same-origin",
 | 
				
			||||||
 | 
					    "Cross-Origin-Embedder-Policy": "credentialless",
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## Structure of the app
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**`App.tsx`**
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					- Initializes ONNX model
 | 
				
			||||||
 | 
					- Loads image embedding and image
 | 
				
			||||||
 | 
					- Runs the ONNX model based on input prompts
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**`Stage.tsx`**
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					- Handles mouse move interaction to update the ONNX model prompt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**`Tool.tsx`**
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					- Renders the image and the mask prediction
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**`helpers/maskUtils.tsx`**
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					- Conversion of ONNX model output from array to an HTMLImageElement
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**`helpers/onnxModelAPI.tsx`**
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					- Formats the inputs for the ONNX model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**`helpers/scaleHelper.tsx`**
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					- Handles image scaling logic for SAM (longest size 1024)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**`hooks/`**
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					- Handle shared state for the app
 | 
				
			||||||
							
								
								
									
										78
									
								
								demo/configs/webpack/common.js
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										78
									
								
								demo/configs/webpack/common.js
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,78 @@
 | 
				
			|||||||
 | 
					const { resolve } = require("path");
 | 
				
			||||||
 | 
					const HtmlWebpackPlugin = require("html-webpack-plugin");
 | 
				
			||||||
 | 
					const FriendlyErrorsWebpackPlugin = require("friendly-errors-webpack-plugin");
 | 
				
			||||||
 | 
					const CopyPlugin = require("copy-webpack-plugin");
 | 
				
			||||||
 | 
					const webpack = require("webpack");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					module.exports = {
 | 
				
			||||||
 | 
					  entry: "./src/index.tsx",
 | 
				
			||||||
 | 
					  resolve: {
 | 
				
			||||||
 | 
					    extensions: [".js", ".jsx", ".ts", ".tsx"],
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  output: {
 | 
				
			||||||
 | 
					    path: resolve(__dirname, "dist"),
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  module: {
 | 
				
			||||||
 | 
					    rules: [
 | 
				
			||||||
 | 
					      {
 | 
				
			||||||
 | 
					        test: /\.mjs$/,
 | 
				
			||||||
 | 
					        include: /node_modules/,
 | 
				
			||||||
 | 
					        type: "javascript/auto",
 | 
				
			||||||
 | 
					        resolve: {
 | 
				
			||||||
 | 
					          fullySpecified: false,
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					      },
 | 
				
			||||||
 | 
					      {
 | 
				
			||||||
 | 
					        test: [/\.jsx?$/, /\.tsx?$/],
 | 
				
			||||||
 | 
					        use: ["ts-loader"],
 | 
				
			||||||
 | 
					        exclude: /node_modules/,
 | 
				
			||||||
 | 
					      },
 | 
				
			||||||
 | 
					      {
 | 
				
			||||||
 | 
					        test: /\.css$/,
 | 
				
			||||||
 | 
					        use: ["style-loader", "css-loader"],
 | 
				
			||||||
 | 
					      },
 | 
				
			||||||
 | 
					      {
 | 
				
			||||||
 | 
					        test: /\.(scss|sass)$/,
 | 
				
			||||||
 | 
					        use: ["style-loader", "css-loader", "postcss-loader"],
 | 
				
			||||||
 | 
					      },
 | 
				
			||||||
 | 
					      {
 | 
				
			||||||
 | 
					        test: /\.(jpe?g|png|gif|svg)$/i,
 | 
				
			||||||
 | 
					        use: [
 | 
				
			||||||
 | 
					          "file-loader?hash=sha512&digest=hex&name=img/[contenthash].[ext]",
 | 
				
			||||||
 | 
					          "image-webpack-loader?bypassOnDebug&optipng.optimizationLevel=7&gifsicle.interlaced=false",
 | 
				
			||||||
 | 
					        ],
 | 
				
			||||||
 | 
					      },
 | 
				
			||||||
 | 
					      {
 | 
				
			||||||
 | 
					        test: /\.(woff|woff2|ttf)$/,
 | 
				
			||||||
 | 
					        use: {
 | 
				
			||||||
 | 
					          loader: "url-loader",
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					      },
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  plugins: [
 | 
				
			||||||
 | 
					    new CopyPlugin({
 | 
				
			||||||
 | 
					      patterns: [
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					          from: "node_modules/onnxruntime-web/dist/*.wasm",
 | 
				
			||||||
 | 
					          to: "[name][ext]",
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					          from: "model",
 | 
				
			||||||
 | 
					          to: "model",
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					          from: "src/assets",
 | 
				
			||||||
 | 
					          to: "assets",
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					      ],
 | 
				
			||||||
 | 
					    }),
 | 
				
			||||||
 | 
					    new HtmlWebpackPlugin({
 | 
				
			||||||
 | 
					      template: "./src/assets/index.html",
 | 
				
			||||||
 | 
					    }),
 | 
				
			||||||
 | 
					    new FriendlyErrorsWebpackPlugin(),
 | 
				
			||||||
 | 
					    new webpack.ProvidePlugin({
 | 
				
			||||||
 | 
					      process: "process/browser",
 | 
				
			||||||
 | 
					    }),
 | 
				
			||||||
 | 
					  ],
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
							
								
								
									
										19
									
								
								demo/configs/webpack/dev.js
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								demo/configs/webpack/dev.js
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,19 @@
 | 
				
			|||||||
 | 
					// development config
 | 
				
			||||||
 | 
					const { merge } = require("webpack-merge");
 | 
				
			||||||
 | 
					const commonConfig = require("./common");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					module.exports = merge(commonConfig, {
 | 
				
			||||||
 | 
					  mode: "development",
 | 
				
			||||||
 | 
					  devServer: {
 | 
				
			||||||
 | 
					    hot: true, // enable HMR on the server
 | 
				
			||||||
 | 
					    open: true,
 | 
				
			||||||
 | 
					    // These headers enable the cross origin isolation state
 | 
				
			||||||
 | 
					    // needed to enable use of SharedArrayBuffer for ONNX 
 | 
				
			||||||
 | 
					    // multithreading. 
 | 
				
			||||||
 | 
					    headers: {
 | 
				
			||||||
 | 
					      "Cross-Origin-Opener-Policy": "same-origin",
 | 
				
			||||||
 | 
					      "Cross-Origin-Embedder-Policy": "credentialless",
 | 
				
			||||||
 | 
					    },
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  devtool: "cheap-module-source-map",
 | 
				
			||||||
 | 
					});
 | 
				
			||||||
							
								
								
									
										16
									
								
								demo/configs/webpack/prod.js
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								demo/configs/webpack/prod.js
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,16 @@
 | 
				
			|||||||
 | 
					// production config
 | 
				
			||||||
 | 
					const { merge } = require("webpack-merge");
 | 
				
			||||||
 | 
					const { resolve } = require("path");
 | 
				
			||||||
 | 
					const Dotenv = require("dotenv-webpack");
 | 
				
			||||||
 | 
					const commonConfig = require("./common");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Where and how the production bundle is written.
					const prodOutput = {
					  filename: "js/bundle.[contenthash].min.js",
					  path: resolve(__dirname, "../../dist"),
					  publicPath: "/",
					};

					// Settings layered on top of the shared config for production builds.
					const prodOverrides = {
					  mode: "production",
					  output: prodOutput,
					  devtool: "source-map",
					  plugins: [new Dotenv()],
					};

					module.exports = merge(commonConfig, prodOverrides);
 | 
				
			||||||
							
								
								
									
										64
									
								
								demo/package.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										64
									
								
								demo/package.json
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,64 @@
 | 
				
			|||||||
 | 
					{
 | 
				
			||||||
 | 
					  "name": "se-demo",
 | 
				
			||||||
 | 
					  "version": "0.1.0",
 | 
				
			||||||
 | 
					  "license": "MIT",
 | 
				
			||||||
 | 
					  "scripts": {
 | 
				
			||||||
 | 
					    "build": "yarn run clean-dist && webpack --config=configs/webpack/prod.js && mv dist/*.wasm dist/js && cp -R dataset dist",
 | 
				
			||||||
 | 
					    "clean-dist": "rimraf dist/*",
 | 
				
			||||||
 | 
					    "lint": "eslint './src/**/*.{js,ts,tsx}' --quiet",
 | 
				
			||||||
 | 
					    "start": "yarn run start-dev",
 | 
				
			||||||
 | 
					    "test": "yarn run start-model-test",
 | 
				
			||||||
 | 
					    "start-dev": "webpack serve --config=configs/webpack/dev.js"
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  "devDependencies": {
 | 
				
			||||||
 | 
					    "@babel/core": "^7.18.13",
 | 
				
			||||||
 | 
					    "@babel/preset-env": "^7.18.10",
 | 
				
			||||||
 | 
					    "@babel/preset-react": "^7.18.6",
 | 
				
			||||||
 | 
					    "@babel/preset-typescript": "^7.18.6",
 | 
				
			||||||
 | 
					    "@pmmmwh/react-refresh-webpack-plugin": "^0.5.7",
 | 
				
			||||||
 | 
					    "@testing-library/react": "^13.3.0",
 | 
				
			||||||
 | 
					    "@types/node": "^18.7.13",
 | 
				
			||||||
 | 
					    "@types/react": "^18.0.17",
 | 
				
			||||||
 | 
					    "@types/react-dom": "^18.0.6",
 | 
				
			||||||
 | 
					    "@types/underscore": "^1.11.4",
 | 
				
			||||||
 | 
					    "@typescript-eslint/eslint-plugin": "^5.35.1",
 | 
				
			||||||
 | 
					    "@typescript-eslint/parser": "^5.35.1",
 | 
				
			||||||
 | 
					    "babel-loader": "^8.2.5",
 | 
				
			||||||
 | 
					    "copy-webpack-plugin": "^11.0.0",
 | 
				
			||||||
 | 
					    "css-loader": "^6.7.1",
 | 
				
			||||||
 | 
					    "dotenv": "^16.0.2",
 | 
				
			||||||
 | 
					    "dotenv-webpack": "^8.0.1",
 | 
				
			||||||
 | 
					    "eslint": "^8.22.0",
 | 
				
			||||||
 | 
					    "eslint-plugin-react": "^7.31.0",
 | 
				
			||||||
 | 
					    "file-loader": "^6.2.0",
 | 
				
			||||||
 | 
					    "fork-ts-checker-webpack-plugin": "^7.2.13",
 | 
				
			||||||
 | 
					    "friendly-errors-webpack-plugin": "^1.7.0",
 | 
				
			||||||
 | 
					    "html-webpack-plugin": "^5.5.0",
 | 
				
			||||||
 | 
					    "image-webpack-loader": "^8.1.0",
 | 
				
			||||||
 | 
					    "postcss-loader": "^7.0.1",
 | 
				
			||||||
 | 
					    "postcss-preset-env": "^7.8.0",
 | 
				
			||||||
 | 
					    "process": "^0.11.10",
 | 
				
			||||||
 | 
					    "rimraf": "^3.0.2",
 | 
				
			||||||
 | 
					    "sass": "^1.54.5",
 | 
				
			||||||
 | 
					    "sass-loader": "^13.0.2",
 | 
				
			||||||
 | 
					    "style-loader": "^3.3.1",
 | 
				
			||||||
 | 
					    "tailwindcss": "^3.1.8",
 | 
				
			||||||
 | 
					    "ts-loader": "^9.3.1",
 | 
				
			||||||
 | 
					    "typescript": "^4.8.2",
 | 
				
			||||||
 | 
					    "webpack": "^5.74.0",
 | 
				
			||||||
 | 
					    "webpack-cli": "^4.10.0",
 | 
				
			||||||
 | 
					    "webpack-dev-server": "^4.10.0",
 | 
				
			||||||
 | 
					    "webpack-dotenv-plugin": "^2.1.0",
 | 
				
			||||||
 | 
					    "webpack-merge": "^5.8.0"
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  "dependencies": {
 | 
				
			||||||
 | 
					    "konva": "^8.3.12",
 | 
				
			||||||
 | 
					    "npyjs": "^0.4.0",
 | 
				
			||||||
 | 
					    "onnxruntime-web": "^1.14.0",
 | 
				
			||||||
 | 
					    "react": "^18.2.0",
 | 
				
			||||||
 | 
					    "react-dom": "^18.2.0",
 | 
				
			||||||
 | 
					    "react-konva": "^18.2.1",
 | 
				
			||||||
 | 
					    "underscore": "^1.13.6",
 | 
				
			||||||
 | 
					    "react-refresh": "^0.14.0"
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										4
									
								
								demo/postcss.config.js
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								demo/postcss.config.js
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,4 @@
 | 
				
			|||||||
 | 
					const tailwindcss = require("tailwindcss");
 | 
				
			||||||
 | 
					// PostCSS plugins, listed in the same order as before.
					const plugins = ["postcss-preset-env", "tailwindcss/nesting", tailwindcss];

					module.exports = { plugins };
 | 
				
			||||||
							
								
								
									
										124
									
								
								demo/src/App.tsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										124
									
								
								demo/src/App.tsx
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,124 @@
 | 
				
			|||||||
 | 
					import { InferenceSession, Tensor } from "onnxruntime-web";
 | 
				
			||||||
 | 
					import React, { useContext, useEffect, useState } from "react";
 | 
				
			||||||
 | 
					import "./assets/scss/App.scss";
 | 
				
			||||||
 | 
					import { handleImageScale } from "./components/helpers/scaleHelper";
 | 
				
			||||||
 | 
					import { modelScaleProps } from "./components/helpers/Interfaces";
 | 
				
			||||||
 | 
					import { onnxMaskToImage } from "./components/helpers/maskUtils";
 | 
				
			||||||
 | 
					import { modelData } from "./components/helpers/onnxModelAPI";
 | 
				
			||||||
 | 
					import Stage from "./components/Stage";
 | 
				
			||||||
 | 
					import AppContext from "./components/hooks/createContext";
 | 
				
			||||||
 | 
					const ort = require("onnxruntime-web");
 | 
				
			||||||
 | 
					/* @ts-ignore */
 | 
				
			||||||
 | 
					import npyjs from "npyjs";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Define image, embedding and model paths
 | 
				
			||||||
 | 
					const IMAGE_PATH = "/assets/data/dogs.jpg";
 | 
				
			||||||
 | 
					const IMAGE_EMBEDDING = "/assets/data/dogs_embedding.npy";
 | 
				
			||||||
 | 
					const MODEL_DIR = "/model/sam_onnx_quantized_example.onnx";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Root component. On mount it creates the ONNX inference session, loads
					// the image, and loads the pre-computed SAM image embedding. Whenever
					// `clicks` changes it runs the model and stores the predicted mask image
					// in the shared context. Rendering is delegated to <Stage />.
					const App = () => {
					  const {
					    clicks: [clicks],
					    image: [, setImage],
					    maskImg: [, setMaskImg],
					  } = useContext(AppContext)!;
					  const [model, setModel] = useState<InferenceSession | null>(null); // ONNX model
					  const [tensor, setTensor] = useState<Tensor | null>(null); // Image embedding tensor

					  // The ONNX model expects the input to be rescaled to 1024.
					  // The modelScale state variable keeps track of the scale values.
					  const [modelScale, setModelScale] = useState<modelScaleProps | null>(null);

					  // Initialize the ONNX model, load the image, and load the SAM
					  // pre-computed image embedding
					  useEffect(() => {
					    // Initialize the ONNX model
					    const initModel = async () => {
					      try {
					        const session = await InferenceSession.create(MODEL_DIR);
					        setModel(session);
					      } catch (e) {
					        console.log(e);
					      }
					    };
					    initModel();

					    // Load the image
					    const url = new URL(IMAGE_PATH, location.origin);
					    loadImage(url);

					    // Load the Segment Anything pre-computed embedding. A failed fetch or
					    // decode is logged instead of surfacing as an unhandled rejection.
					    loadNpyTensor(IMAGE_EMBEDDING, "float32")
					      .then((embedding) => setTensor(embedding))
					      .catch((e) => console.log(e));
					  }, []);

					  // Load the image at `url`, record its scale values for the model, and
					  // publish the image element through the context.
					  const loadImage = (url: URL) => {
					    const img = new Image();
					    img.onload = () => {
					      const { height, width, samScale } = handleImageScale(img);
					      setModelScale({
					        height: height, // original image height
					        width: width, // original image width
					        samScale: samScale, // scaling factor for image which has been resized to longest side 1024
					      });
					      img.width = width;
					      img.height = height;
					      setImage(img);
					    };
					    // Load failures are reported asynchronously through onerror; a
					    // try/catch around the assignment below would never see them.
					    img.onerror = () => {
					      console.log(`Failed to load image: ${url.href}`);
					    };
					    // Attach the handlers before starting the load.
					    img.src = url.href;
					  };

					  // Decode a Numpy file into a tensor.
					  const loadNpyTensor = async (tensorFile: string, dType: string) => {
					    const npLoader = new npyjs();
					    const npArray = await npLoader.load(tensorFile);
					    return new ort.Tensor(dType, npArray.data, npArray.shape);
					  };

					  // Run the ONNX model every time clicks has changed
					  useEffect(() => {
					    runONNX();
					  }, [clicks]);

					  // Run the model for the current clicks and store the resulting mask.
					  // Does nothing until the model, clicks, embedding and scale are all set.
					  const runONNX = async () => {
					    try {
					      if (
					        model === null ||
					        clicks === null ||
					        tensor === null ||
					        modelScale === null
					      ) {
					        return;
					      }
					      // Prepare the model input in the correct format for SAM.
					      // The modelData function is from onnxModelAPI.tsx.
					      const feeds = modelData({
					        clicks,
					        tensor,
					        modelScale,
					      });
					      if (feeds === undefined) return;
					      // Run the SAM ONNX model with the feeds returned from modelData()
					      const results = await model.run(feeds);
					      const output = results[model.outputNames[0]];
					      // The predicted mask returned from the ONNX model is an array which is
					      // rendered as an HTML image using onnxMaskToImage() from maskUtils.tsx.
					      setMaskImg(onnxMaskToImage(output.data, output.dims[2], output.dims[3]));
					    } catch (e) {
					      console.log(e);
					    }
					  };

					  return <Stage />;
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export default App;
 | 
				
			||||||
							
								
								
									
										
											BIN
										
									
								
								demo/src/assets/data/dogs.jpg
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								demo/src/assets/data/dogs.jpg
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 438 KiB  | 
							
								
								
									
										18
									
								
								demo/src/assets/index.html
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								demo/src/assets/index.html
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,18 @@
 | 
				
			|||||||
 | 
					<!DOCTYPE html>
 | 
				
			||||||
 | 
					<html lang="en" dir="ltr" prefix="og: https://ogp.me/ns#" class="w-full h-full">
 | 
				
			||||||
 | 
					  <head>
 | 
				
			||||||
 | 
					    <meta charset="utf-8" />
 | 
				
			||||||
 | 
					    <meta
 | 
				
			||||||
 | 
					      name="viewport"
 | 
				
			||||||
 | 
					      content="width=device-width, initial-scale=1, shrink-to-fit=no"
 | 
				
			||||||
 | 
					    />
 | 
				
			||||||
 | 
					    <title>Segment Anything Demo</title>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    <!--  Meta Tags -->
 | 
				
			||||||
 | 
					    <meta property="og:type" content="website" />
 | 
				
			||||||
 | 
					    <meta property="og:title" content="Segment Anything Demo" />
 | 
				
			||||||
 | 
					  </head>
 | 
				
			||||||
 | 
					  <body class="w-full h-full">
 | 
				
			||||||
 | 
					    <div id="root" class="w-full h-full"></div>
 | 
				
			||||||
 | 
					  </body>
 | 
				
			||||||
 | 
					</html>
 | 
				
			||||||
							
								
								
									
										3
									
								
								demo/src/assets/scss/App.scss
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								demo/src/assets/scss/App.scss
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,3 @@
 | 
				
			|||||||
 | 
					@tailwind base;
 | 
				
			||||||
 | 
					@tailwind components;
 | 
				
			||||||
 | 
					@tailwind utilities;
 | 
				
			||||||
							
								
								
									
										43
									
								
								demo/src/components/Stage.tsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										43
									
								
								demo/src/components/Stage.tsx
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,43 @@
 | 
				
			|||||||
 | 
					import React, { useContext } from "react";
 | 
				
			||||||
 | 
					import * as _ from "underscore";
 | 
				
			||||||
 | 
					import Tool from "./Tool";
 | 
				
			||||||
 | 
					import { modelInputProps } from "./helpers/Interfaces";
 | 
				
			||||||
 | 
					import AppContext from "./hooks/createContext";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Layout wrapper around <Tool />. It turns mouse movement over the
					// rendered element into a single click prompt stored via setClicks.
					const Stage = () => {
					  const {
					    clicks: [, setClicks],
					    image: [image],
					  } = useContext(AppContext)!;

					  // Get mouse position and scale the (x, y) coordinates back to the natural
					  // scale of the image. Update the state of clicks with setClicks to trigger
					  // the ONNX model to run and generate a new mask via a useEffect in App.tsx
					  //
					  // The throttled function is memoized so the same instance (and its
					  // internal timer state) is reused across renders. Creating it inline
					  // would build a fresh throttle on every render and lose the 15 ms limit.
					  const handleMouseMove = React.useMemo(
					    () =>
					      _.throttle((e: any) => {
					        const el = e.nativeEvent.target;
					        const rect = el.getBoundingClientRect();
					        // Ratio between the image's width and the on-screen element width.
					        const imageScale = image ? image.width / el.offsetWidth : 1;
					        const x = (e.clientX - rect.left) * imageScale;
					        const y = (e.clientY - rect.top) * imageScale;
					        const click: modelInputProps = { x, y, clickType: 1 };
					        setClicks([click]);
					      }, 15),
					    [image, setClicks]
					  );

					  const flexCenterClasses = "flex items-center justify-center";
					  return (
					    <div className={`${flexCenterClasses} w-full h-full`}>
					      <div className={`${flexCenterClasses} relative w-[90%] h-[90%]`}>
					        <Tool handleMouseMove={handleMouseMove} />
					      </div>
					    </div>
					  );
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export default Stage;
 | 
				
			||||||
							
								
								
									
										67
									
								
								demo/src/components/Tool.tsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										67
									
								
								demo/src/components/Tool.tsx
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,67 @@
 | 
				
			|||||||
 | 
					import React, { useContext, useEffect, useState } from "react";
 | 
				
			||||||
 | 
					import AppContext from "./hooks/createContext";
 | 
				
			||||||
 | 
					import { ToolProps } from "./helpers/Interfaces";
 | 
				
			||||||
 | 
					import * as _ from "underscore";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/**
					 * Tool renders the source image with the predicted mask image layered on
					 * top of it, and forwards pointer events on the image to `handleMouseMove`.
					 */
					const Tool = ({ handleMouseMove }: ToolProps) => {
					  const {
					    image: [image],
					    maskImg: [maskImg, setMaskImg],
					  } = useContext(AppContext)!;

					  // Determine if we should shrink or grow the images to match the
					  // width or the height of the page.
					  const [shouldFitToWidth, setShouldFitToWidth] = useState(true);

					  // The ResizeObserver lives entirely inside the effect so that exactly one
					  // observer exists per loaded image. Creating it in the render body would
					  // allocate a new, never-used observer on every re-render (for example on
					  // every mask update).
					  useEffect(() => {
					    const bodyEl = document.body;

					    const fitToPage = () => {
					      if (!image) return;
					      const imageAspectRatio = image.width / image.height;
					      const screenAspectRatio = window.innerWidth / window.innerHeight;
					      setShouldFitToWidth(imageAspectRatio > screenAspectRatio);
					    };

					    const resizeObserver = new ResizeObserver((entries) => {
					      for (const entry of entries) {
					        if (entry.target === bodyEl) {
					          fitToPage();
					        }
					      }
					    });

					    fitToPage();
					    resizeObserver.observe(bodyEl);
					    return () => {
					      // Only the body is observed, so disconnect() fully releases the observer.
					      resizeObserver.disconnect();
					    };
					  }, [image]);

					  const imageClasses = "";
					  const maskImageClasses = `absolute opacity-40 pointer-events-none`;

					  // Render the image and the predicted mask image on top
					  return (
					    <>
					      {image && (
					        <img
					          onMouseMove={handleMouseMove}
					          // Clear the mask once the pointer leaves the image; deferred so it
					          // runs after the current call stack has finished.
					          onMouseOut={() => _.defer(() => setMaskImg(null))}
					          onTouchStart={handleMouseMove}
					          src={image.src}
					          className={`${
					            shouldFitToWidth ? "w-full" : "h-full"
					          } ${imageClasses}`}
					        />
					      )}
					      {maskImg && (
					        <img
					          src={maskImg.src}
					          className={`${
					            shouldFitToWidth ? "w-full" : "h-full"
					          } ${maskImageClasses}`}
					        />
					      )}
					    </>
					  );
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export default Tool;
 | 
				
			||||||
							
								
								
									
										23
									
								
								demo/src/components/helpers/Interfaces.tsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								demo/src/components/helpers/Interfaces.tsx
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,23 @@
 | 
				
			|||||||
 | 
					import { Tensor } from "onnxruntime-web";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/**
					 * Size information for the displayed image together with the factor used
					 * to map image coordinates into the model's coordinate space.
					 */
					export interface modelScaleProps {
					  /** Image width, reported to the model as part of `orig_im_size`. */
					  width: number;
					  /** Image height, reported to the model as part of `orig_im_size`. */
					  height: number;
					  /** Multiplier applied to click coordinates before they are sent to the model. */
					  samScale: number;
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/** A single point prompt, expressed in image coordinates. */
					export interface modelInputProps {
					  /** Horizontal position of the point. */
					  x: number;
					  /** Vertical position of the point. */
					  y: number;
					  /**
					   * Label forwarded to the model for this point. Stage always emits 1;
					   * presumably other values encode other prompt kinds (TODO confirm).
					   */
					  clickType: number;
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/** Arguments accepted by `modelData` when assembling the ONNX model feeds. */
					export interface modeDataProps {
					  /** Point prompts; when omitted no model input is produced. */
					  clicks?: modelInputProps[];
					  /** Tensor passed through to the model as `image_embeddings`. */
					  tensor: Tensor;
					  /** Image size and coordinate scale factor. */
					  modelScale: modelScaleProps;
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/** Props accepted by the Tool component. */
					export interface ToolProps {
					  /** Handler attached to the image's mouse-move and touch-start events. */
					  handleMouseMove: (e: any) => void;
					}
 | 
				
			||||||
							
								
								
									
										43
									
								
								demo/src/components/helpers/maskUtils.tsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										43
									
								
								demo/src/components/helpers/maskUtils.tsx
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,43 @@
 | 
				
			|||||||
 | 
					// Functions for handling mask output from the ONNX model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/**
					 * Convert the ONNX model mask prediction into RGBA ImageData.
					 *
					 * Every entry of `input` greater than 0.0 becomes an opaque blue pixel;
					 * all other pixels stay fully transparent.
					 *
					 * NOTE(review): the ImageData is built as (height, width), so the value
					 * passed as `height` becomes the ImageData width and vice versa.
					 * Presumably the caller supplies the tensor dimensions in the matching
					 * order. Confirm against the call site before changing the order here.
					 */
					function arrayToImageData(input: any, width: number, height: number) {
					  // The mask's blue color, fully opaque
					  const red = 0;
					  const green = 114;
					  const blue = 189;
					  const alpha = 255;

					  // Typed arrays start zero-filled, i.e. transparent black
					  const pixels = new Uint8ClampedArray(4 * width * height);

					  for (let px = 0; px < input.length; px++) {
					    // Threshold the onnx model mask prediction at 0.0
					    // This is equivalent to thresholding the mask using
					    // predictor.model.mask_threshold in python
					    if (!(input[px] > 0.0)) continue;

					    const base = 4 * px;
					    pixels[base] = red;
					    pixels[base + 1] = green;
					    pixels[base + 2] = blue;
					    pixels[base + 3] = alpha;
					  }

					  return new ImageData(pixels, height, width);
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/**
					 * Produce an HTMLImageElement from ImageData by drawing it onto a
					 * canvas and using the canvas's data URL as the image source.
					 */
					function imageDataToImage(imageData: ImageData) {
					  const dataUrl = imageDataToCanvas(imageData).toDataURL();
					  const result = new Image();
					  result.src = dataUrl;
					  return result;
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/**
					 * Create a canvas with the same dimensions as `imageData` and paint the
					 * pixel data at its origin. If no 2D context is available, the canvas is
					 * returned blank.
					 */
					function imageDataToCanvas(imageData: ImageData) {
					  const { width, height } = imageData;
					  const canvas = document.createElement("canvas");
					  const context2d = canvas.getContext("2d");
					  canvas.width = width;
					  canvas.height = height;
					  if (context2d) {
					    context2d.putImageData(imageData, 0, 0);
					  }
					  return canvas;
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/**
					 * Convert the onnx model mask output to an HTMLImageElement by first
					 * building ImageData from the raw values and then rendering it to an image.
					 */
					export function onnxMaskToImage(input: any, width: number, height: number) {
					  const maskData = arrayToImageData(input, width, height);
					  return imageDataToImage(maskData);
					}
 | 
				
			||||||
							
								
								
									
										65
									
								
								demo/src/components/helpers/onnxModelAPI.tsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										65
									
								
								demo/src/components/helpers/onnxModelAPI.tsx
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,65 @@
 | 
				
			|||||||
 | 
					import { Tensor } from "onnxruntime-web";
 | 
				
			||||||
 | 
					import { modeDataProps } from "./Interfaces";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/**
					 * Assemble the input feeds for the SAM ONNX model from the click prompts.
					 *
					 * Returns undefined when `clicks` is not provided. Otherwise returns an
					 * object keyed by the model's input names.
					 */
					const modelData = ({ clicks, tensor, modelScale }: modeDataProps) => {
					  let pointCoordsTensor: Tensor | undefined;
					  let pointLabelsTensor: Tensor | undefined;

					  // Check there are input click prompts
					  if (clicks) {
					    // If there is no box input, a single padding point with
					    // label -1 and coordinates (0.0, 0.0) should be concatenated,
					    // so size the arrays for one point more than the clicks.
					    const total = clicks.length + 1;
					    const coords = new Float32Array(2 * total);
					    const labels = new Float32Array(total);

					    // Add clicks and scale to what SAM expects
					    clicks.forEach(({ x, y, clickType }, idx) => {
					      coords[2 * idx] = x * modelScale.samScale;
					      coords[2 * idx + 1] = y * modelScale.samScale;
					      labels[idx] = clickType;
					    });

					    // The padding point's coordinates are already (0, 0) because typed
					    // arrays are zero-initialized; only its label needs to be written.
					    labels[total - 1] = -1.0;

					    // Create the tensors
					    pointCoordsTensor = new Tensor("float32", coords, [1, total, 2]);
					    pointLabelsTensor = new Tensor("float32", labels, [1, total]);
					  }

					  const { height, width } = modelScale;
					  const imageSizeTensor = new Tensor("float32", [height, width]);

					  if (!pointCoordsTensor || !pointLabelsTensor) {
					    return;
					  }

					  return {
					    image_embeddings: tensor,
					    point_coords: pointCoordsTensor,
					    point_labels: pointLabelsTensor,
					    orig_im_size: imageSizeTensor,
					    // There is no previous mask, so default to an empty tensor
					    mask_input: new Tensor(
					      "float32",
					      new Float32Array(256 * 256),
					      [1, 1, 256, 256]
					    ),
					    // There is no previous mask, so default to 0
					    has_mask_input: new Tensor("float32", [0]),
					  };
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export { modelData };
 | 
				
			||||||
							
								
								
									
										12
									
								
								demo/src/components/helpers/scaleHelper.tsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								demo/src/components/helpers/scaleHelper.tsx
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,12 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					/**
					 * Helper for the image scaling needed by SAM.
					 *
					 * Returns the image's natural height and width along with `samScale`,
					 * the factor that maps the longest side of the image to 1024.
					 */
					const handleImageScale = (image: HTMLImageElement) => {
					  // Input images to SAM must be resized so the longest side is 1024
					  const LONG_SIDE_LENGTH = 1024;
					  const { naturalWidth: width, naturalHeight: height } = image;
					  const longestSide = Math.max(height, width);
					  return { height, width, samScale: LONG_SIDE_LENGTH / longestSide };
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export { handleImageScale };
 | 
				
			||||||
							
								
								
									
										25
									
								
								demo/src/components/hooks/context.tsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										25
									
								
								demo/src/components/hooks/context.tsx
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,25 @@
 | 
				
			|||||||
 | 
					import React, { useState } from "react";
 | 
				
			||||||
 | 
					import { modelInputProps } from "../helpers/Interfaces";
 | 
				
			||||||
 | 
					import AppContext from "./createContext";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/**
					 * Owns the app-wide state (clicks, image, mask image) and exposes each
					 * piece as a [value, setter] pair through AppContext.
					 */
					const AppContextProvider = ({
					  children,
					}: {
					  children: React.ReactElement<any, string | React.JSXElementConstructor<any>>;
					}) => {
					  const clicksState = useState<Array<modelInputProps> | null>(null);
					  const imageState = useState<HTMLImageElement | null>(null);
					  const maskImgState = useState<HTMLImageElement | null>(null);

					  const contextValue = {
					    clicks: clicksState,
					    image: imageState,
					    maskImg: maskImgState,
					  };

					  return (
					    <AppContext.Provider value={contextValue}>
					      {children}
					    </AppContext.Provider>
					  );
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export default AppContextProvider;
 | 
				
			||||||
							
								
								
									
										21
									
								
								demo/src/components/hooks/createContext.tsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								demo/src/components/hooks/createContext.tsx
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,21 @@
 | 
				
			|||||||
 | 
					import { createContext } from "react";
 | 
				
			||||||
 | 
					import { modelInputProps } from "../helpers/Interfaces";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/**
					 * Shape of the value carried by AppContext: each entry is a
					 * [current value, setter] tuple.
					 */
					interface contextProps {
					  /** Current point prompts and their setter. */
					  clicks: [
					    clicks: modelInputProps[] | null,
					    setClicks: (e: modelInputProps[] | null) => void
					  ];
					  /** The loaded source image and its setter. */
					  image: [
					    image: HTMLImageElement | null,
					    setImage: (e: HTMLImageElement | null) => void
					  ];
					  /** The rendered mask image and its setter. */
					  maskImg: [
					    maskImg: HTMLImageElement | null,
					    setMaskImg: (e: HTMLImageElement | null) => void
					  ];
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Shared application context. It defaults to null, so consumers only get
					// a usable value when rendered beneath a provider.
					const AppContext = createContext<contextProps | null>(
					  null
					);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export default AppContext;
 | 
				
			||||||
							
								
								
									
										11
									
								
								demo/src/index.tsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								demo/src/index.tsx
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,11 @@
 | 
				
			|||||||
 | 
					import * as React from "react";
 | 
				
			||||||
 | 
					import { createRoot } from "react-dom/client";
 | 
				
			||||||
 | 
					import AppContextProvider from "./components/hooks/context";
 | 
				
			||||||
 | 
					import App from "./App";
 | 
				
			||||||
 | 
					// Mount the application into the #root element, wrapping it in the
					// context provider so every component can reach the shared state.
					createRoot(document.getElementById("root")!).render(
					  <AppContextProvider>
					    <App />
					  </AppContextProvider>
					);
 | 
				
			||||||
							
								
								
									
										6
									
								
								demo/tailwind.config.js
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								demo/tailwind.config.js
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,6 @@
 | 
				
			|||||||
 | 
					/** @type {import('tailwindcss').Config} */
					const tailwindConfig = {
					  // Source files listed for Tailwind under `content`
					  content: ["./src/**/*.{html,js,tsx}"],
					  // No theme overrides and no plugins
					  theme: {},
					  plugins: [],
					};

					module.exports = tailwindConfig;
 | 
				
			||||||
							
								
								
									
										24
									
								
								demo/tsconfig.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										24
									
								
								demo/tsconfig.json
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,24 @@
 | 
				
			|||||||
 | 
					{
 | 
				
			||||||
 | 
					  "compilerOptions": {
 | 
				
			||||||
 | 
					    "lib": ["dom", "dom.iterable", "esnext"],
 | 
				
			||||||
 | 
					    "allowJs": true,
 | 
				
			||||||
 | 
					    "skipLibCheck": true,
 | 
				
			||||||
 | 
					    "strict": true,
 | 
				
			||||||
 | 
					    "forceConsistentCasingInFileNames": true,
 | 
				
			||||||
 | 
					    "noEmit": false,
 | 
				
			||||||
 | 
					    "esModuleInterop": true,
 | 
				
			||||||
 | 
					    "module": "esnext",
 | 
				
			||||||
 | 
					    "moduleResolution": "node",
 | 
				
			||||||
 | 
					    "resolveJsonModule": true,
 | 
				
			||||||
 | 
					    "isolatedModules": true,
 | 
				
			||||||
 | 
					    "jsx": "react",
 | 
				
			||||||
 | 
					    "incremental": true,
 | 
				
			||||||
 | 
					    "target": "ESNext",
 | 
				
			||||||
 | 
					    "useDefineForClassFields": true,
 | 
				
			||||||
 | 
					    "allowSyntheticDefaultImports": true,
 | 
				
			||||||
 | 
					    "outDir": "./dist/",
 | 
				
			||||||
 | 
					    "sourceMap": true
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  "include": ["next-env.d.ts", "**/*.ts", "**/*.tsx", "src"],
 | 
				
			||||||
 | 
					  "exclude": ["node_modules"]
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
		Reference in New Issue
	
	Block a user