Add mini web demo
This commit is contained in:
124
demo/src/App.tsx
Normal file
124
demo/src/App.tsx
Normal file
@@ -0,0 +1,124 @@
|
||||
import { InferenceSession, Tensor } from "onnxruntime-web";
|
||||
import React, { useContext, useEffect, useState } from "react";
|
||||
import "./assets/scss/App.scss";
|
||||
import { handleImageScale } from "./components/helpers/scaleHelper";
|
||||
import { modelScaleProps } from "./components/helpers/Interfaces";
|
||||
import { onnxMaskToImage } from "./components/helpers/maskUtils";
|
||||
import { modelData } from "./components/helpers/onnxModelAPI";
|
||||
import Stage from "./components/Stage";
|
||||
import AppContext from "./components/hooks/createContext";
|
||||
const ort = require("onnxruntime-web");
|
||||
/* @ts-ignore */
|
||||
import npyjs from "npyjs";
|
||||
|
||||
// Define image, embedding and model paths
|
||||
const IMAGE_PATH = "/assets/data/dogs.jpg";
|
||||
const IMAGE_EMBEDDING = "/assets/data/dogs_embedding.npy";
|
||||
const MODEL_DIR = "/model/sam_onnx_quantized_example.onnx";
|
||||
|
||||
const App = () => {
|
||||
const {
|
||||
clicks: [clicks],
|
||||
image: [, setImage],
|
||||
maskImg: [, setMaskImg],
|
||||
} = useContext(AppContext)!;
|
||||
const [model, setModel] = useState<InferenceSession | null>(null); // ONNX model
|
||||
const [tensor, setTensor] = useState<Tensor | null>(null); // Image embedding tensor
|
||||
|
||||
// The ONNX model expects the input to be rescaled to 1024.
|
||||
// The modelScale state variable keeps track of the scale values.
|
||||
const [modelScale, setModelScale] = useState<modelScaleProps | null>(null);
|
||||
|
||||
// Initialize the ONNX model. load the image, and load the SAM
|
||||
// pre-computed image embedding
|
||||
useEffect(() => {
|
||||
// Initialize the ONNX model
|
||||
const initModel = async () => {
|
||||
try {
|
||||
if (MODEL_DIR === undefined) return;
|
||||
const URL: string = MODEL_DIR;
|
||||
const model = await InferenceSession.create(URL);
|
||||
setModel(model);
|
||||
} catch (e) {
|
||||
console.log(e);
|
||||
}
|
||||
};
|
||||
initModel();
|
||||
|
||||
// Load the image
|
||||
const url = new URL(IMAGE_PATH, location.origin);
|
||||
loadImage(url);
|
||||
|
||||
// Load the Segment Anything pre-computed embedding
|
||||
Promise.resolve(loadNpyTensor(IMAGE_EMBEDDING, "float32")).then(
|
||||
(embedding) => setTensor(embedding)
|
||||
);
|
||||
}, []);
|
||||
|
||||
const loadImage = async (url: URL) => {
|
||||
try {
|
||||
const img = new Image();
|
||||
img.src = url.href;
|
||||
img.onload = () => {
|
||||
const { height, width, samScale } = handleImageScale(img);
|
||||
setModelScale({
|
||||
height: height, // original image height
|
||||
width: width, // original image width
|
||||
samScale: samScale, // scaling factor for image which has been resized to longest side 1024
|
||||
});
|
||||
img.width = width;
|
||||
img.height = height;
|
||||
setImage(img);
|
||||
};
|
||||
} catch (error) {
|
||||
console.log(error);
|
||||
}
|
||||
};
|
||||
|
||||
// Decode a Numpy file into a tensor.
|
||||
const loadNpyTensor = async (tensorFile: string, dType: string) => {
|
||||
let npLoader = new npyjs();
|
||||
const npArray = await npLoader.load(tensorFile);
|
||||
const tensor = new ort.Tensor(dType, npArray.data, npArray.shape);
|
||||
return tensor;
|
||||
};
|
||||
|
||||
// Run the ONNX model every time clicks has changed
|
||||
useEffect(() => {
|
||||
runONNX();
|
||||
}, [clicks]);
|
||||
|
||||
const runONNX = async () => {
|
||||
try {
|
||||
if (
|
||||
model === null ||
|
||||
clicks === null ||
|
||||
tensor === null ||
|
||||
modelScale === null
|
||||
)
|
||||
return;
|
||||
else {
|
||||
// Preapre the model input in the correct format for SAM.
|
||||
// The modelData function is from onnxModelAPI.tsx.
|
||||
const feeds = modelData({
|
||||
clicks,
|
||||
tensor,
|
||||
modelScale,
|
||||
});
|
||||
if (feeds === undefined) return;
|
||||
// Run the SAM ONNX model with the feeds returned from modelData()
|
||||
const results = await model.run(feeds);
|
||||
const output = results[model.outputNames[0]];
|
||||
// The predicted mask returned from the ONNX model is an array which is
|
||||
// rendered as an HTML image using onnxMaskToImage() from maskUtils.tsx.
|
||||
setMaskImg(onnxMaskToImage(output.data, output.dims[2], output.dims[3]));
|
||||
}
|
||||
} catch (e) {
|
||||
console.log(e);
|
||||
}
|
||||
};
|
||||
|
||||
return <Stage />;
|
||||
};
|
||||
|
||||
export default App;
|
BIN
demo/src/assets/data/dogs.jpg
Normal file
BIN
demo/src/assets/data/dogs.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 438 KiB |
18
demo/src/assets/index.html
Normal file
18
demo/src/assets/index.html
Normal file
@@ -0,0 +1,18 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en" dir="ltr" prefix="og: https://ogp.me/ns#" class="w-full h-full">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<meta
|
||||
name="viewport"
|
||||
content="width=device-width, initial-scale=1, shrink-to-fit=no"
|
||||
/>
|
||||
<title>Segment Anything Demo</title>
|
||||
|
||||
<!-- Meta Tags -->
|
||||
<meta property="og:type" content="website" />
|
||||
<meta property="og:title" content="Segment Anything Demo" />
|
||||
</head>
|
||||
<body class="w-full h-full">
|
||||
<div id="root" class="w-full h-full"></div>
|
||||
</body>
|
||||
</html>
|
3
demo/src/assets/scss/App.scss
Normal file
3
demo/src/assets/scss/App.scss
Normal file
@@ -0,0 +1,3 @@
|
||||
@tailwind base;
|
||||
@tailwind components;
|
||||
@tailwind utilities;
|
43
demo/src/components/Stage.tsx
Normal file
43
demo/src/components/Stage.tsx
Normal file
@@ -0,0 +1,43 @@
|
||||
import React, { useContext } from "react";
|
||||
import * as _ from "underscore";
|
||||
import Tool from "./Tool";
|
||||
import { modelInputProps } from "./helpers/Interfaces";
|
||||
import AppContext from "./hooks/createContext";
|
||||
|
||||
const Stage = () => {
|
||||
const {
|
||||
clicks: [, setClicks],
|
||||
image: [image],
|
||||
} = useContext(AppContext)!;
|
||||
|
||||
const getClick = (x: number, y: number): modelInputProps => {
|
||||
const clickType = 1;
|
||||
return { x, y, clickType };
|
||||
};
|
||||
|
||||
// Get mouse position and scale the (x, y) coordinates back to the natural
|
||||
// scale of the image. Update the state of clicks with setClicks to trigger
|
||||
// the ONNX model to run and generate a new mask via a useEffect in App.tsx
|
||||
const handleMouseMove = _.throttle((e: any) => {
|
||||
let el = e.nativeEvent.target;
|
||||
const rect = el.getBoundingClientRect();
|
||||
let x = e.clientX - rect.left;
|
||||
let y = e.clientY - rect.top;
|
||||
const imageScale = image ? image.width / el.offsetWidth : 1;
|
||||
x *= imageScale;
|
||||
y *= imageScale;
|
||||
const click = getClick(x, y);
|
||||
if (click) setClicks([click]);
|
||||
}, 15);
|
||||
|
||||
const flexCenterClasses = "flex items-center justify-center";
|
||||
return (
|
||||
<div className={`${flexCenterClasses} w-full h-full`}>
|
||||
<div className={`${flexCenterClasses} relative w-[90%] h-[90%]`}>
|
||||
<Tool handleMouseMove={handleMouseMove} />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default Stage;
|
67
demo/src/components/Tool.tsx
Normal file
67
demo/src/components/Tool.tsx
Normal file
@@ -0,0 +1,67 @@
|
||||
import React, { useContext, useEffect, useState } from "react";
|
||||
import AppContext from "./hooks/createContext";
|
||||
import { ToolProps } from "./helpers/Interfaces";
|
||||
import * as _ from "underscore";
|
||||
|
||||
const Tool = ({ handleMouseMove }: ToolProps) => {
|
||||
const {
|
||||
image: [image],
|
||||
maskImg: [maskImg, setMaskImg],
|
||||
} = useContext(AppContext)!;
|
||||
|
||||
// Determine if we should shrink or grow the images to match the
|
||||
// width or the height of the page and setup a ResizeObserver to
|
||||
// monitor changes in the size of the page
|
||||
const [shouldFitToWidth, setShouldFitToWidth] = useState(true);
|
||||
const bodyEl = document.body;
|
||||
const fitToPage = () => {
|
||||
if (!image) return;
|
||||
const imageAspectRatio = image.width / image.height;
|
||||
const screenAspectRatio = window.innerWidth / window.innerHeight;
|
||||
setShouldFitToWidth(imageAspectRatio > screenAspectRatio);
|
||||
};
|
||||
const resizeObserver = new ResizeObserver((entries) => {
|
||||
for (const entry of entries) {
|
||||
if (entry.target === bodyEl) {
|
||||
fitToPage();
|
||||
}
|
||||
}
|
||||
});
|
||||
useEffect(() => {
|
||||
fitToPage();
|
||||
resizeObserver.observe(bodyEl);
|
||||
return () => {
|
||||
resizeObserver.unobserve(bodyEl);
|
||||
};
|
||||
}, [image]);
|
||||
|
||||
const imageClasses = "";
|
||||
const maskImageClasses = `absolute opacity-40 pointer-events-none`;
|
||||
|
||||
// Render the image and the predicted mask image on top
|
||||
return (
|
||||
<>
|
||||
{image && (
|
||||
<img
|
||||
onMouseMove={handleMouseMove}
|
||||
onMouseOut={() => _.defer(() => setMaskImg(null))}
|
||||
onTouchStart={handleMouseMove}
|
||||
src={image.src}
|
||||
className={`${
|
||||
shouldFitToWidth ? "w-full" : "h-full"
|
||||
} ${imageClasses}`}
|
||||
></img>
|
||||
)}
|
||||
{maskImg && (
|
||||
<img
|
||||
src={maskImg.src}
|
||||
className={`${
|
||||
shouldFitToWidth ? "w-full" : "h-full"
|
||||
} ${maskImageClasses}`}
|
||||
></img>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default Tool;
|
23
demo/src/components/helpers/Interfaces.tsx
Normal file
23
demo/src/components/helpers/Interfaces.tsx
Normal file
@@ -0,0 +1,23 @@
|
||||
import { Tensor } from "onnxruntime-web";
|
||||
|
||||
export interface modelScaleProps {
|
||||
samScale: number;
|
||||
height: number;
|
||||
width: number;
|
||||
}
|
||||
|
||||
export interface modelInputProps {
|
||||
x: number;
|
||||
y: number;
|
||||
clickType: number;
|
||||
}
|
||||
|
||||
export interface modeDataProps {
|
||||
clicks?: Array<modelInputProps>;
|
||||
tensor: Tensor;
|
||||
modelScale: modelScaleProps;
|
||||
}
|
||||
|
||||
export interface ToolProps {
|
||||
handleMouseMove: (e: any) => void;
|
||||
}
|
43
demo/src/components/helpers/maskUtils.tsx
Normal file
43
demo/src/components/helpers/maskUtils.tsx
Normal file
@@ -0,0 +1,43 @@
|
||||
// Functions for handling mask output from the ONNX model
|
||||
|
||||
// Convert the onnx model mask prediction to ImageData
|
||||
function arrayToImageData(input: any, width: number, height: number) {
|
||||
const [r, g, b, a] = [0, 114, 189, 255]; // the masks's blue color
|
||||
const arr = new Uint8ClampedArray(4 * width * height).fill(0);
|
||||
for (let i = 0; i < input.length; i++) {
|
||||
|
||||
// Threshold the onnx model mask prediction at 0.0
|
||||
// This is equivalent to thresholding the mask using predictor.model.mask_threshold
|
||||
// in python
|
||||
if (input[i] > 0.0) {
|
||||
arr[4 * i + 0] = r;
|
||||
arr[4 * i + 1] = g;
|
||||
arr[4 * i + 2] = b;
|
||||
arr[4 * i + 3] = a;
|
||||
}
|
||||
}
|
||||
return new ImageData(arr, height, width);
|
||||
}
|
||||
|
||||
// Use a Canvas element to produce an image from ImageData
|
||||
function imageDataToImage(imageData: ImageData) {
|
||||
const canvas = imageDataToCanvas(imageData);
|
||||
const image = new Image();
|
||||
image.src = canvas.toDataURL();
|
||||
return image;
|
||||
}
|
||||
|
||||
// Canvas elements can be created from ImageData
|
||||
function imageDataToCanvas(imageData: ImageData) {
|
||||
const canvas = document.createElement("canvas");
|
||||
const ctx = canvas.getContext("2d");
|
||||
canvas.width = imageData.width;
|
||||
canvas.height = imageData.height;
|
||||
ctx?.putImageData(imageData, 0, 0);
|
||||
return canvas;
|
||||
}
|
||||
|
||||
// Convert the onnx model mask output to an HTMLImageElement
|
||||
export function onnxMaskToImage(input: any, width: number, height: number) {
|
||||
return imageDataToImage(arrayToImageData(input, width, height));
|
||||
}
|
65
demo/src/components/helpers/onnxModelAPI.tsx
Normal file
65
demo/src/components/helpers/onnxModelAPI.tsx
Normal file
@@ -0,0 +1,65 @@
|
||||
import { Tensor } from "onnxruntime-web";
|
||||
import { modeDataProps } from "./Interfaces";
|
||||
|
||||
const modelData = ({ clicks, tensor, modelScale }: modeDataProps) => {
|
||||
const imageEmbedding = tensor;
|
||||
let pointCoords;
|
||||
let pointLabels;
|
||||
let pointCoordsTensor;
|
||||
let pointLabelsTensor;
|
||||
|
||||
// Check there are input click prompts
|
||||
if (clicks) {
|
||||
let n = clicks.length;
|
||||
|
||||
// If there is no box input, a single padding point with
|
||||
// label -1 and coordinates (0.0, 0.0) should be concatenated
|
||||
// so initialize the array to support (n + 1) points.
|
||||
pointCoords = new Float32Array(2 * (n + 1));
|
||||
pointLabels = new Float32Array(n + 1);
|
||||
|
||||
// Add clicks and scale to what SAM expects
|
||||
for (let i = 0; i < n; i++) {
|
||||
pointCoords[2 * i] = clicks[i].x * modelScale.samScale;
|
||||
pointCoords[2 * i + 1] = clicks[i].y * modelScale.samScale;
|
||||
pointLabels[i] = clicks[i].clickType;
|
||||
}
|
||||
|
||||
// Add in the extra point/label when only clicks and no box
|
||||
// The extra point is at (0, 0) with label -1
|
||||
pointCoords[2 * n] = 0.0;
|
||||
pointCoords[2 * n + 1] = 0.0;
|
||||
pointLabels[n] = -1.0;
|
||||
|
||||
// Create the tensor
|
||||
pointCoordsTensor = new Tensor("float32", pointCoords, [1, n + 1, 2]);
|
||||
pointLabelsTensor = new Tensor("float32", pointLabels, [1, n + 1]);
|
||||
}
|
||||
const imageSizeTensor = new Tensor("float32", [
|
||||
modelScale.height,
|
||||
modelScale.width,
|
||||
]);
|
||||
|
||||
if (pointCoordsTensor === undefined || pointLabelsTensor === undefined)
|
||||
return;
|
||||
|
||||
// There is no previous mask, so default to an empty tensor
|
||||
const maskInput = new Tensor(
|
||||
"float32",
|
||||
new Float32Array(256 * 256),
|
||||
[1, 1, 256, 256]
|
||||
);
|
||||
// There is no previous mask, so default to 0
|
||||
const hasMaskInput = new Tensor("float32", [0]);
|
||||
|
||||
return {
|
||||
image_embeddings: imageEmbedding,
|
||||
point_coords: pointCoordsTensor,
|
||||
point_labels: pointLabelsTensor,
|
||||
orig_im_size: imageSizeTensor,
|
||||
mask_input: maskInput,
|
||||
has_mask_input: hasMaskInput,
|
||||
};
|
||||
};
|
||||
|
||||
export { modelData };
|
12
demo/src/components/helpers/scaleHelper.tsx
Normal file
12
demo/src/components/helpers/scaleHelper.tsx
Normal file
@@ -0,0 +1,12 @@
|
||||
|
||||
// Helper function for handling image scaling needed for SAM
|
||||
const handleImageScale = (image: HTMLImageElement) => {
|
||||
// Input images to SAM must be resized so the longest side is 1024
|
||||
const LONG_SIDE_LENGTH = 1024;
|
||||
let w = image.naturalWidth;
|
||||
let h = image.naturalHeight;
|
||||
const samScale = LONG_SIDE_LENGTH / Math.max(h, w);
|
||||
return { height: h, width: w, samScale };
|
||||
};
|
||||
|
||||
export { handleImageScale };
|
25
demo/src/components/hooks/context.tsx
Normal file
25
demo/src/components/hooks/context.tsx
Normal file
@@ -0,0 +1,25 @@
|
||||
import React, { useState } from "react";
|
||||
import { modelInputProps } from "../helpers/Interfaces";
|
||||
import AppContext from "./createContext";
|
||||
|
||||
const AppContextProvider = (props: {
|
||||
children: React.ReactElement<any, string | React.JSXElementConstructor<any>>;
|
||||
}) => {
|
||||
const [clicks, setClicks] = useState<Array<modelInputProps> | null>(null);
|
||||
const [image, setImage] = useState<HTMLImageElement | null>(null);
|
||||
const [maskImg, setMaskImg] = useState<HTMLImageElement | null>(null);
|
||||
|
||||
return (
|
||||
<AppContext.Provider
|
||||
value={{
|
||||
clicks: [clicks, setClicks],
|
||||
image: [image, setImage],
|
||||
maskImg: [maskImg, setMaskImg],
|
||||
}}
|
||||
>
|
||||
{props.children}
|
||||
</AppContext.Provider>
|
||||
);
|
||||
};
|
||||
|
||||
export default AppContextProvider;
|
21
demo/src/components/hooks/createContext.tsx
Normal file
21
demo/src/components/hooks/createContext.tsx
Normal file
@@ -0,0 +1,21 @@
|
||||
import { createContext } from "react";
|
||||
import { modelInputProps } from "../helpers/Interfaces";
|
||||
|
||||
interface contextProps {
|
||||
clicks: [
|
||||
clicks: modelInputProps[] | null,
|
||||
setClicks: (e: modelInputProps[] | null) => void
|
||||
];
|
||||
image: [
|
||||
image: HTMLImageElement | null,
|
||||
setImage: (e: HTMLImageElement | null) => void
|
||||
];
|
||||
maskImg: [
|
||||
maskImg: HTMLImageElement | null,
|
||||
setMaskImg: (e: HTMLImageElement | null) => void
|
||||
];
|
||||
}
|
||||
|
||||
const AppContext = createContext<contextProps | null>(null);
|
||||
|
||||
export default AppContext;
|
11
demo/src/index.tsx
Normal file
11
demo/src/index.tsx
Normal file
@@ -0,0 +1,11 @@
|
||||
import * as React from "react";
|
||||
import { createRoot } from "react-dom/client";
|
||||
import AppContextProvider from "./components/hooks/context";
|
||||
import App from "./App";
|
||||
const container = document.getElementById("root");
|
||||
const root = createRoot(container!);
|
||||
root.render(
|
||||
<AppContextProvider>
|
||||
<App/>
|
||||
</AppContextProvider>
|
||||
);
|
Reference in New Issue
Block a user