MNIST Digit Classification
In this tutorial we will use a model trained on the MNIST dataset of handwritten digits to predict the number that the user draws.
There are several pieces to this tutorial, so please follow each step carefully. If you get lost, completed examples of each step can be found here.
If you haven't installed the PyTorch Live CLI yet, please follow this tutorial to get started.
Create a new React Native project
We will start by creating a new React Native project with the PyTorch Live (PTL) template using the CLI. Run the following command:
npx torchlive-cli init MNISTClassifier
Once that is done, let's go into our newly created project and run it!
cd MNISTClassifier
On Android:
npx torchlive-cli run-android
On iOS (Simulator):
npx torchlive-cli run-ios
Adding Basic UI
The aim of this tutorial is to help you become more familiar with PTL core components, so we will not spend time on how to style UI, but rather provide the layout and styles from the start.
Go ahead and start by copying the following code into the file src/demos/MyDemos.tsx:
import React, {useState} from 'react';
import {StyleSheet, Text, View} from 'react-native';
import {Canvas, CanvasRenderingContext2D} from 'react-native-pytorch-core';
import {useSafeAreaInsets} from 'react-native-safe-area-context';
export default function MNISTDemo() {
// Get safe area insets to account for notches, etc.
const insets = useSafeAreaInsets();
const [canvasSize, setCanvasSize] = useState<number>(0);
// `ctx` is drawing context to draw shapes
const [ctx, setCtx] = useState<CanvasRenderingContext2D>();
return (
<View
style={styles.container}
onLayout={event => {
const {layout} = event.nativeEvent;
setCanvasSize(Math.min(layout?.width || 0, layout?.height || 0));
}}>
<View style={[styles.instruction, {marginTop: insets.top}]}>
<Text style={styles.label}>Write a number</Text>
<Text style={styles.label}>Let's test the MNIST model</Text>
</View>
<Canvas
style={{
height: canvasSize,
width: canvasSize,
}}
onContext2D={setCtx}
/>
<View style={[styles.resultView]} pointerEvents="none">
<Text style={[styles.label, styles.secondary]}>
Highest confidence will go here
</Text>
<Text style={[styles.label, styles.secondary]}>
Second highest will go here
</Text>
</View>
</View>
);
}
const styles = StyleSheet.create({
container: {
height: '100%',
width: '100%',
backgroundColor: '#180b3b',
justifyContent: 'center',
alignItems: 'center',
},
resultView: {
position: 'absolute',
bottom: 0,
alignSelf: 'flex-start',
flexDirection: 'column',
padding: 15,
},
instruction: {
position: 'absolute',
top: 0,
alignSelf: 'flex-start',
flexDirection: 'column',
padding: 15,
},
label: {
fontSize: 16,
color: '#ffffff',
},
secondary: {
color: '#ffffff99',
},
});
Run the app again and you should see a UI that looks exactly like the screenshot below.
On Android:
npx torchlive-cli run-android
On iOS (Simulator):
npx torchlive-cli run-ios
Before we add more code, let's take a second to discuss some of what the above code does.
The PyTorch Live Canvas Component
We'll be using the PTL canvas in this tutorial to let the user draw numbers that we will try to classify.
Just like the name suggests, a canvas is a surface that we can programmatically draw on.
To draw things on a canvas, we use its canvas context, which is the ctx state variable in this case.
Note that we haven't used the context to draw anything yet, so our canvas is essentially invisible.
...
export default function MNISTDemo() {
...
const [ctx, setCtx] = useState<CanvasRenderingContext2D>();
...
<Canvas
style={{
height: canvasSize,
width: canvasSize,
}}
onContext2D={setCtx}
/>
...
The onLayout Prop
In our code, we use the onLayout prop on the container view to get the dimensions of the screen space we are working with.
Once we have those dimensions, we take the smaller of the width and the height and use it as both the width and the height of our canvas.
This makes sure that our canvas is square and fits within the bounds of our screen in both portrait and landscape.
...
export default function MNISTDemo() {
// Get safe area insets to account for notches, etc.
const insets = useSafeAreaInsets();
const [canvasSize, setCanvasSize] = useState<number>(0);
...
return (
<View
style={styles.container}
onLayout={event => {
const {layout} = event.nativeEvent;
setCanvasSize(Math.min(layout?.width || 0, layout?.height || 0));
}}>
...
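The following sketch shows why this keeps the canvas square in either orientation. It pulls the sizing expression out into a plain function. The squareCanvasSize name and the LayoutLike type are hypothetical and exist only for this sketch. The Math.min expression itself is the one used in onLayout above.

```typescript
// Hypothetical stand-in for event.nativeEvent.layout; only width and height matter here.
type LayoutLike = {width?: number; height?: number} | undefined;

// Side length of the largest square that fits in the layout.
// Same expression as the tutorial: Math.min(layout?.width || 0, layout?.height || 0)
function squareCanvasSize(layout: LayoutLike): number {
  return Math.min(layout?.width || 0, layout?.height || 0);
}

console.log(squareCanvasSize({width: 390, height: 844})); // portrait: limited by width
console.log(squareCanvasSize({width: 844, height: 390})); // landscape: limited by height
console.log(squareCanvasSize(undefined)); // no layout yet: falls back to 0
```

The `|| 0` fallbacks mean the canvas starts at size 0 until the first layout event arrives.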
Results placeholders
Note that for now we just have placeholder text where we will put our model results. Later on, after we run the model, we will update the text there to display the results.
...
<View style={[styles.resultView]} pointerEvents="none">
<Text style={[styles.label, styles.secondary]}>
Highest confidence will go here
</Text>
<Text style={[styles.label, styles.secondary]}>
Second highest will go here
</Text>
</View>
...
Filling the Canvas
As we mentioned in the previous section, our canvas is currently completely blank.
Let's change that and make a clear surface for users to draw on.
Here's a short summary of the changes we're introducing:
- Import useCallback and useEffect from React.
- Define a color for our canvas background (COLOR_CANVAS_BACKGROUND). We'll use a lighter purple color to distinguish the canvas from the rest of the screen.
- Create a draw function that will fill in our background. We create it with useCallback so that the function is recreated every time the context or the size of the canvas changes. Inside draw, we:
  - Check to make sure the context is not null, so we have something to draw with.
  - Set the context's fill style to our canvas background purple (essentially choosing which marker to work with).
  - Fill in a rectangle that starts at the origin coordinate (0, 0) on our canvas (the top-left corner) and ends at the bottom-right corner of our canvas, so it covers the whole thing.
  - Call the invalidate function on our canvas context to let the screen know that we have drawn new things for it to show.
- Trigger draw any time it changes with the useEffect block. Remember that draw changes every time the canvas context or size changes, so this useEffect effectively runs every time the canvas changes.
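You can try the steps above without a device. This sketch uses a hypothetical BackgroundContext interface that has only the three members this step uses. It also uses a RecordingContext test double that logs calls instead of drawing pixels. Neither is part of react-native-pytorch-core, and the real context has many more members.

```typescript
// Hypothetical subset of the canvas context used in this step.
interface BackgroundContext {
  fillStyle: string;
  fillRect(x: number, y: number, width: number, height: number): void;
  invalidate(): void;
}

// Test double: records each call as a string instead of drawing.
class RecordingContext implements BackgroundContext {
  fillStyle = '';
  calls: string[] = [];
  fillRect(x: number, y: number, width: number, height: number): void {
    this.calls.push(`fillRect(${x}, ${y}, ${width}, ${height}) with ${this.fillStyle}`);
  }
  invalidate(): void {
    this.calls.push('invalidate()');
  }
}

// Mirrors the draw callback: null check, pick the "marker",
// cover the whole square canvas, then ask the screen to refresh.
function fillBackground(
  ctx: BackgroundContext | undefined,
  size: number,
  color: string,
): void {
  if (ctx == null) {
    return; // no context yet, so there is nothing to draw with
  }
  ctx.fillStyle = color;
  ctx.fillRect(0, 0, size, size);
  ctx.invalidate();
}

const recorder = new RecordingContext();
fillBackground(recorder, 390, '#4F25C6');
fillBackground(undefined, 390, '#4F25C6'); // safely does nothing
console.log(recorder.calls);
```

The null check matters in the app because ctx starts out undefined until the Canvas calls onContext2D.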
The useCallback and useEffect functions that we imported, as well as the useState function we already had imported, are examples of React Hooks. Hooks allow React function components, like our MNISTDemo function component, to remember things between renders.
You'll notice that at the end of useCallback and useEffect we pass an array, such as [ctx, canvasSize]. This array is the list of "dependencies" for that hook. For useCallback, the hook holds onto the function we give it until one of the dependencies changes, at which point it remembers the new function instead. For useEffect, the effect runs after the first render and then again after any render in which one of the dependencies changed.
For more information on React Hooks, head over to the React docs where you can read or watch explanations.
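React compares each dependency with its previous value using Object.is. The sketch below mimics that rule outside React to show when useCallback would hand back a new draw function. depsChanged and makeMemoizer are hypothetical helpers, not React APIs. Real hooks keep this state per component rather than in a closure.

```typescript
// True if any dependency differs from the previous render's value (per Object.is).
// React expects the dependency array to keep the same length between renders.
function depsChanged(prev: readonly unknown[] | null, next: readonly unknown[]): boolean {
  if (prev === null) {
    return true; // first render: nothing to compare against
  }
  return next.some((dep, i) => !Object.is(dep, prev[i]));
}

// Hypothetical imitation of useCallback: keep the old function until deps change.
function makeMemoizer<F>() {
  let lastDeps: readonly unknown[] | null = null;
  let lastFn: F | undefined;
  return (fn: F, deps: readonly unknown[]): F => {
    if (depsChanged(lastDeps, deps)) {
      lastFn = fn;
      lastDeps = deps;
    }
    return lastFn as F;
  };
}

const memo = makeMemoizer<() => number>();
const fakeCtx = {}; // stands in for the canvas context object
const f1 = memo(() => 1, [fakeCtx, 390]);
const f2 = memo(() => 2, [fakeCtx, 390]); // same deps: f1 is returned again
const f3 = memo(() => 3, [fakeCtx, 844]); // canvasSize changed: new function
console.log(f1 === f2, f1 === f3, f3()); // true false 3
```

Objects such as the context are compared by identity, so a new context object counts as a change even if it looks the same.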
Changes (diff):
@@ -1,8 +1,10 @@
-import React, {useState} from 'react';
+import React, {useCallback, useEffect, useState} from 'react';
import {StyleSheet, Text, View} from 'react-native';
import {Canvas, CanvasRenderingContext2D} from 'react-native-pytorch-core';
import {useSafeAreaInsets} from 'react-native-safe-area-context';
+const COLOR_CANVAS_BACKGROUND = '#4F25C6';
+
export default function MNISTDemo() {
// Get safe area insets to account for notches, etc.
const insets = useSafeAreaInsets();
@@ -10,6 +12,20 @@
// `ctx` is drawing context to draw shapes
const [ctx, setCtx] = useState<CanvasRenderingContext2D>();
+ const draw = useCallback(() => {
+ if (ctx != null) {
+ // fill background by drawing a rect
+ ctx.fillStyle = COLOR_CANVAS_BACKGROUND;
+ ctx.fillRect(0, 0, canvasSize, canvasSize);
+
+ ctx.invalidate();
+ }
+ }, [ctx, canvasSize]);
+
+ useEffect(() => {
+ draw();
+ }, [draw]);
+
return (
<View
style={styles.container}
Entire file:
import React, {useCallback, useEffect, useState} from 'react';
import {StyleSheet, Text, View} from 'react-native';
import {Canvas, CanvasRenderingContext2D} from 'react-native-pytorch-core';
import {useSafeAreaInsets} from 'react-native-safe-area-context';
const COLOR_CANVAS_BACKGROUND = '#4F25C6';
export default function MNISTDemo() {
// Get safe area insets to account for notches, etc.
const insets = useSafeAreaInsets();
const [canvasSize, setCanvasSize] = useState<number>(0);
// `ctx` is drawing context to draw shapes
const [ctx, setCtx] = useState<CanvasRenderingContext2D>();
const draw = useCallback(() => {
if (ctx != null) {
// fill background by drawing a rect
ctx.fillStyle = COLOR_CANVAS_BACKGROUND;
ctx.fillRect(0, 0, canvasSize, canvasSize);
ctx.invalidate();
}
}, [ctx, canvasSize]);
useEffect(() => {
draw();
}, [draw]);
return (
<View
style={styles.container}
onLayout={event => {
const {layout} = event.nativeEvent;
setCanvasSize(Math.min(layout?.width || 0, layout?.height || 0));
}}>
<View style={[styles.instruction, {marginTop: insets.top}]}>
<Text style={styles.label}>Write a number</Text>
<Text style={styles.label}>Let's test the MNIST model</Text>
</View>
<Canvas
style={{
height: canvasSize,
width: canvasSize,
}}
onContext2D={setCtx}
/>
<View style={[styles.resultView]} pointerEvents="none">
<Text style={[styles.label, styles.secondary]}>
Highest confidence will go here
</Text>
<Text style={[styles.label, styles.secondary]}>
Second highest will go here
</Text>
</View>
</View>
);
}
const styles = StyleSheet.create({
container: {
height: '100%',
width: '100%',
backgroundColor: '#180b3b',
justifyContent: 'center',
alignItems: 'center',
},
resultView: {
position: 'absolute',
bottom: 0,
alignSelf: 'flex-start',
flexDirection: 'column',
padding: 15,
},
instruction: {
position: 'absolute',
top: 0,
alignSelf: 'flex-start',
flexDirection: 'column',
padding: 15,
},
label: {
fontSize: 16,
color: '#ffffff',
},
secondary: {
color: '#ffffff99',
},
});
Once you run your app again, the My Demos screen should now look like this.
On Android:
npx torchlive-cli run-android
On iOS (Simulator):
npx torchlive-cli run-ios
I know that was a lot of new code just to paint our canvas light purple, but it gives us a good foundation for drawing more on the canvas.
Drawing with Touch Input
Now that we have a clear area for the user to draw on, let's make it so they can draw!
Let's go over what we will change to make drawing possible:
- Import useRef from React.
- Define a color for the trail of the user's touch (COLOR_TRAIL_STROKE). We'll use white to make it stand out.
- Define a TrailPoint type to keep our data safe, error free, and easy to use.
- Create a ref to a list of TrailPoints called trailRef and set it to an empty list.
- Keep track of whether the user has finished drawing with the drawingDone state variable, and initialize it to false.
- Add support for drawing the trail to our draw function:
  - Create a variable called trail and set it to the current value of our trailRef. This is purely so we don't have to write trailRef.current every time we need the trail.
  - Check to make sure the trail isn't null.
  - Draw our background to cover anything previously drawn.
  - Check to make sure our trail has at least 1 point.
  - Set the context's strokeStyle. You can think of it as picking the marker color we'll draw lines with.
  - Set the context's line drawing style parameters (lineWidth, lineJoin, lineCap, and miterLimit).
  - Tell the context to start a line at the first point in the trail.
  - Loop through the points of the trail to add them to the line we are drawing.
  - Tell the context, via the stroke method, to actually draw the line that we constructed.
  - Use the invalidate method to tell the screen we have updates ready to draw.
- Create functions for handling when a user touches the canvas (handleStart, handleMove, and handleEnd):
  - handleStart is called when the user first touches the canvas. It is a simple function that does the following:
    - Set the drawingDone variable to false.
    - Reset trailRef.current to an empty list.
  - handleMove is called each time the device detects that the touch has changed position since the starting touch. It does the following:
    - Get the coordinates of the new touch location and store them in the position variable.
    - If there are already points in the trail, only add the new position if it is more than 5 pixels away from the last position (this avoids keeping unnecessary points that slow down the app).
    - If there are no points in the trail, add the new position.
    - Trigger the draw function to display the newly updated trail.
  - handleEnd is called when the user's touch is no longer detected on the screen. It simply sets the drawingDone state variable to true.
- Set the onTouchStart, onTouchMove, and onTouchEnd props on our <Canvas /> component to handleStart, handleMove, and handleEnd respectively.
- Only show the results placeholders once drawingDone is true, and change the second instruction line to "Let's see if the AI model will get it right".
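The 5-pixel filter in handleMove compares squared distances, so it never needs a square root. Here it is as a standalone function. addTrailPoint and MIN_POINT_DISTANCE are names introduced for this sketch. The threshold mirrors the dx * dx + dy * dy > 25 check in the tutorial code.

```typescript
type TrailPoint = {x: number; y: number};

// Minimum distance, in the units of locationX/locationY, between stored points.
const MIN_POINT_DISTANCE = 5;

// Appends `position` unless it lies within MIN_POINT_DISTANCE of the last point.
// Returns true if the point was kept.
function addTrailPoint(trail: TrailPoint[], position: TrailPoint): boolean {
  if (trail.length > 0) {
    const last = trail[trail.length - 1];
    const dx = position.x - last.x;
    const dy = position.y - last.y;
    // Squared distance vs squared threshold: same result as comparing distances.
    if (dx * dx + dy * dy <= MIN_POINT_DISTANCE * MIN_POINT_DISTANCE) {
      return false;
    }
  }
  trail.push(position);
  return true;
}

const trail: TrailPoint[] = [];
addTrailPoint(trail, {x: 10, y: 10}); // first point: always kept
addTrailPoint(trail, {x: 13, y: 14}); // exactly 5 away: skipped (must be more than 5)
addTrailPoint(trail, {x: 16, y: 14}); // about 7.2 away: kept
console.log(trail.length); // 2
```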
Changes (diff):
@@ -1,26 +1,88 @@
-import React, {useCallback, useEffect, useState} from 'react';
+import React, {useCallback, useEffect, useState, useRef} from 'react';
import {StyleSheet, Text, View} from 'react-native';
import {Canvas, CanvasRenderingContext2D} from 'react-native-pytorch-core';
import {useSafeAreaInsets} from 'react-native-safe-area-context';
const COLOR_CANVAS_BACKGROUND = '#4F25C6';
+const COLOR_TRAIL_STROKE = '#FFFFFF';
+
+type TrailPoint = {
+ x: number;
+ y: number;
+};
export default function MNISTDemo() {
// Get safe area insets to account for notches, etc.
const insets = useSafeAreaInsets();
const [canvasSize, setCanvasSize] = useState<number>(0);
+
// `ctx` is drawing context to draw shapes
const [ctx, setCtx] = useState<CanvasRenderingContext2D>();
+ const trailRef = useRef<TrailPoint[]>([]);
+ const [drawingDone, setDrawingDone] = useState(false);
+
const draw = useCallback(() => {
if (ctx != null) {
- // fill background by drawing a rect
- ctx.fillStyle = COLOR_CANVAS_BACKGROUND;
- ctx.fillRect(0, 0, canvasSize, canvasSize);
+ const trail = trailRef.current;
+ if (trail != null) {
+ // fill background by drawing a rect
+ ctx.fillStyle = COLOR_CANVAS_BACKGROUND;
+ ctx.fillRect(0, 0, canvasSize, canvasSize);
+
+ // Draw the trail
+
+ if (trail.length > 0) {
+ ctx.strokeStyle = COLOR_TRAIL_STROKE;
+ ctx.lineWidth = 25;
+ ctx.lineJoin = 'round';
+ ctx.lineCap = 'round';
+ ctx.miterLimit = 1;
+ ctx.beginPath();
+ ctx.moveTo(trail[0].x, trail[0].y);
+ for (let i = 1; i < trail.length; i++) {
+ ctx.lineTo(trail[i].x, trail[i].y);
+ }
+ ctx.stroke();
+ }
- ctx.invalidate();
+ ctx.invalidate();
+ }
}
- }, [ctx, canvasSize]);
+ }, [ctx, canvasSize, trailRef]);
+
+ // handlers for touch events
+ const handleMove = useCallback(
+ async event => {
+ const position: TrailPoint = {
+ x: event.nativeEvent.locationX,
+ y: event.nativeEvent.locationY,
+ };
+ const trail = trailRef.current;
+ if (trail.length > 0) {
+ const lastPosition = trail[trail.length - 1];
+ const dx = position.x - lastPosition.x;
+ const dy = position.y - lastPosition.y;
+ // add a point to trail if distance from last point > 5
+ if (dx * dx + dy * dy > 25) {
+ trail.push(position);
+ }
+ } else {
+ trail.push(position);
+ }
+ draw();
+ },
+ [trailRef, draw],
+ );
+
+ const handleStart = useCallback(() => {
+ setDrawingDone(false);
+ trailRef.current = [];
+ }, [trailRef, setDrawingDone]);
+
+ const handleEnd = useCallback(() => {
+ setDrawingDone(true);
+ }, [setDrawingDone]);
useEffect(() => {
draw();
@@ -35,7 +97,9 @@
}}>
<View style={[styles.instruction, {marginTop: insets.top}]}>
<Text style={styles.label}>Write a number</Text>
- <Text style={styles.label}>Let's test the MNIST model</Text>
+ <Text style={styles.label}>
+ Let's see if the AI model will get it right
+ </Text>
</View>
<Canvas
style={{
@@ -43,15 +107,20 @@
width: canvasSize,
}}
onContext2D={setCtx}
+ onTouchMove={handleMove}
+ onTouchStart={handleStart}
+ onTouchEnd={handleEnd}
/>
- <View style={[styles.resultView]} pointerEvents="none">
- <Text style={[styles.label, styles.secondary]}>
- Highest confidence will go here
- </Text>
- <Text style={[styles.label, styles.secondary]}>
- Second highest will go here
- </Text>
- </View>
+ {drawingDone && (
+ <View style={[styles.resultView]} pointerEvents="none">
+ <Text style={[styles.label, styles.secondary]}>
+ Highest confidence will go here
+ </Text>
+ <Text style={[styles.label, styles.secondary]}>
+ Second highest will go here
+ </Text>
+ </View>
+ )}
</View>
);
}
Entire file:
import React, {useCallback, useEffect, useState, useRef} from 'react';
import {StyleSheet, Text, View} from 'react-native';
import {Canvas, CanvasRenderingContext2D} from 'react-native-pytorch-core';
import {useSafeAreaInsets} from 'react-native-safe-area-context';
const COLOR_CANVAS_BACKGROUND = '#4F25C6';
const COLOR_TRAIL_STROKE = '#FFFFFF';
type TrailPoint = {
x: number;
y: number;
};
export default function MNISTDemo() {
// Get safe area insets to account for notches, etc.
const insets = useSafeAreaInsets();
const [canvasSize, setCanvasSize] = useState<number>(0);
// `ctx` is drawing context to draw shapes
const [ctx, setCtx] = useState<CanvasRenderingContext2D>();
const trailRef = useRef<TrailPoint[]>([]);
const [drawingDone, setDrawingDone] = useState(false);
const draw = useCallback(() => {
if (ctx != null) {
const trail = trailRef.current;
if (trail != null) {
// fill background by drawing a rect
ctx.fillStyle = COLOR_CANVAS_BACKGROUND;
ctx.fillRect(0, 0, canvasSize, canvasSize);
// Draw the trail
if (trail.length > 0) {
ctx.strokeStyle = COLOR_TRAIL_STROKE;
ctx.lineWidth = 25;
ctx.lineJoin = 'round';
ctx.lineCap = 'round';
ctx.miterLimit = 1;
ctx.beginPath();
ctx.moveTo(trail[0].x, trail[0].y);
for (let i = 1; i < trail.length; i++) {
ctx.lineTo(trail[i].x, trail[i].y);
}
ctx.stroke();
}
ctx.invalidate();
}
}
}, [ctx, canvasSize, trailRef]);
// handlers for touch events
const handleMove = useCallback(
async event => {
const position: TrailPoint = {
x: event.nativeEvent.locationX,
y: event.nativeEvent.locationY,
};
const trail = trailRef.current;
if (trail.length > 0) {
const lastPosition = trail[trail.length - 1];
const dx = position.x - lastPosition.x;
const dy = position.y - lastPosition.y;
// add a point to trail if distance from last point > 5
if (dx * dx + dy * dy > 25) {
trail.push(position);
}
} else {
trail.push(position);
}
draw();
},
[trailRef, draw],
);
const handleStart = useCallback(() => {
setDrawingDone(false);
trailRef.current = [];
}, [trailRef, setDrawingDone]);
const handleEnd = useCallback(() => {
setDrawingDone(true);
}, [setDrawingDone]);
useEffect(() => {
draw();
}, [draw]);
return (
<View
style={styles.container}
onLayout={event => {
const {layout} = event.nativeEvent;
setCanvasSize(Math.min(layout?.width || 0, layout?.height || 0));
}}>
<View style={[styles.instruction, {marginTop: insets.top}]}>
<Text style={styles.label}>Write a number</Text>
<Text style={styles.label}>
Let's see if the AI model will get it right
</Text>
</View>
<Canvas
style={{
height: canvasSize,
width: canvasSize,
}}
onContext2D={setCtx}
onTouchMove={handleMove}
onTouchStart={handleStart}
onTouchEnd={handleEnd}
/>
{drawingDone && (
<View style={[styles.resultView]} pointerEvents="none">
<Text style={[styles.label, styles.secondary]}>
Highest confidence will go here
</Text>
<Text style={[styles.label, styles.secondary]}>
Second highest will go here
</Text>
</View>
)}
</View>
);
}
const styles = StyleSheet.create({
container: {
height: '100%',
width: '100%',
backgroundColor: '#180b3b',
justifyContent: 'center',
alignItems: 'center',
},
resultView: {
position: 'absolute',
bottom: 0,
alignSelf: 'flex-start',
flexDirection: 'column',
padding: 15,
},
instruction: {
position: 'absolute',
top: 0,
alignSelf: 'flex-start',
flexDirection: 'column',
padding: 15,
},
label: {
fontSize: 16,
color: '#ffffff',
},
secondary: {
color: '#ffffff99',
},
});
Run this code and you should now be able to draw, as you can see in the video below.
You may notice that the drawing glitches at times, especially as the trail gets longer and longer. Let's fix that next.
React Refs
A ref in React holds a value between renders, like state does, but changing a ref does not cause the component to re-render.
You can get or set the value of a ref via its .current property.
In our code, we access the trail with trailRef.current. In our handleStart function, we reset the trail to an empty list with trailRef.current = [].
Avoiding Excessive Re-rendering
The glitchiness we see in our code as it stands happens because draw runs, and asks the canvas to redraw, on every single touch move event, even when the screen is not ready for a new frame.
Mobile screens typically refresh 60 times per second (though some newer phones refresh twice as often). When we display things with React, it takes care of matching our device's refresh rate.
While we are using React to render our <Canvas />, we handle what we draw on the canvas ourselves. Lucky for us, there is a simple way to make sure we don't render too often.
To address this, we will make a few updates to our code, mainly in the draw function:
- Create a ref called animationHandleRef that can be a number or null, and set it to null. We will use this ref to check whether a render is already in progress.
- Use animationHandleRef in the draw function to control how often we re-render:
  - Start the function by checking whether animationHandleRef is set to a non-null value. If it is, return early, because a frame is already scheduled to draw.
  - Wrap our drawing code in an inline function that we pass to requestAnimationFrame, and set animationHandleRef's value to what it returns. (Read more about this function in the note following the code.)
  - After telling our canvas we are ready for a re-render with ctx.invalidate(), clear animationHandleRef by setting its value to null.
  - Add animationHandleRef to the draw function's callback dependencies list.
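Here is the same "at most one pending frame" idea with React removed. The sketch uses a hypothetical createThrottledDraw helper that takes the frame scheduler as a parameter. That lets a manual frame queue stand in for requestAnimationFrame. The local handle variable plays the role of animationHandleRef.current.

```typescript
type FrameCallback = (timestamp: number) => void;
type RequestFrame = (callback: FrameCallback) => number;

// Returns a draw() that keeps at most one frame scheduled at a time.
function createThrottledDraw(requestFrame: RequestFrame, render: () => void): () => void {
  let handle: number | null = null; // like animationHandleRef.current
  return () => {
    if (handle != null) {
      return; // a frame is already scheduled and will render the latest data
    }
    handle = requestFrame(() => {
      render();
      handle = null; // let the next draw() schedule a new frame
    });
  };
}

// Manual frame queue so the example runs outside React Native.
let queued: FrameCallback[] = [];
const requestFrame: RequestFrame = cb => {
  queued.push(cb);
  return queued.length;
};
function runFrame(): void {
  const due = queued;
  queued = [];
  due.forEach(cb => cb(0));
}

let renders = 0;
const draw = createThrottledDraw(requestFrame, () => {
  renders++;
});

// Ten touch-move events arrive before the screen refreshes...
for (let i = 0; i < 10; i++) {
  draw();
}
runFrame();
console.log(renders); // 1

draw();
runFrame();
console.log(renders); // 2
```

In the tutorial, the scheduled callback reads trailRef.current only when it runs. Skipping the extra draw() calls therefore loses no points, because the single frame renders the whole up-to-date trail.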
Changes (diff):
@@ -21,35 +21,40 @@
const trailRef = useRef<TrailPoint[]>([]);
const [drawingDone, setDrawingDone] = useState(false);
+ const animationHandleRef = useRef<number | null>(null);
const draw = useCallback(() => {
+ if (animationHandleRef.current != null) return;
if (ctx != null) {
- const trail = trailRef.current;
- if (trail != null) {
- // fill background by drawing a rect
- ctx.fillStyle = COLOR_CANVAS_BACKGROUND;
- ctx.fillRect(0, 0, canvasSize, canvasSize);
-
- // Draw the trail
+ animationHandleRef.current = requestAnimationFrame(() => {
+ const trail = trailRef.current;
+ if (trail != null) {
+ // fill background by drawing a rect
+ ctx.fillStyle = COLOR_CANVAS_BACKGROUND;
+ ctx.fillRect(0, 0, canvasSize, canvasSize);
- if (trail.length > 0) {
+ // Draw the trail
ctx.strokeStyle = COLOR_TRAIL_STROKE;
ctx.lineWidth = 25;
ctx.lineJoin = 'round';
ctx.lineCap = 'round';
ctx.miterLimit = 1;
- ctx.beginPath();
- ctx.moveTo(trail[0].x, trail[0].y);
- for (let i = 1; i < trail.length; i++) {
- ctx.lineTo(trail[i].x, trail[i].y);
+
+ if (trail.length > 0) {
+ ctx.beginPath();
+ ctx.moveTo(trail[0].x, trail[0].y);
+ for (let i = 1; i < trail.length; i++) {
+ ctx.lineTo(trail[i].x, trail[i].y);
+ }
}
ctx.stroke();
+ // Need to include this at the end, for now.
+ ctx.invalidate();
+ animationHandleRef.current = null;
}
-
- ctx.invalidate();
- }
+ });
}
- }, [ctx, canvasSize, trailRef]);
+ }, [animationHandleRef, ctx, canvasSize, trailRef]);
// handlers for touch events
const handleMove = useCallback(
Entire file:
import React, {useCallback, useEffect, useState, useRef} from 'react';
import {StyleSheet, Text, View} from 'react-native';
import {Canvas, CanvasRenderingContext2D} from 'react-native-pytorch-core';
import {useSafeAreaInsets} from 'react-native-safe-area-context';
const COLOR_CANVAS_BACKGROUND = '#4F25C6';
const COLOR_TRAIL_STROKE = '#FFFFFF';
type TrailPoint = {
x: number;
y: number;
};
export default function MNISTDemo() {
// Get safe area insets to account for notches, etc.
const insets = useSafeAreaInsets();
const [canvasSize, setCanvasSize] = useState<number>(0);
// `ctx` is drawing context to draw shapes
const [ctx, setCtx] = useState<CanvasRenderingContext2D>();
const trailRef = useRef<TrailPoint[]>([]);
const [drawingDone, setDrawingDone] = useState(false);
const animationHandleRef = useRef<number | null>(null);
const draw = useCallback(() => {
if (animationHandleRef.current != null) return;
if (ctx != null) {
animationHandleRef.current = requestAnimationFrame(() => {
const trail = trailRef.current;
if (trail != null) {
// fill background by drawing a rect
ctx.fillStyle = COLOR_CANVAS_BACKGROUND;
ctx.fillRect(0, 0, canvasSize, canvasSize);
// Draw the trail
ctx.strokeStyle = COLOR_TRAIL_STROKE;
ctx.lineWidth = 25;
ctx.lineJoin = 'round';
ctx.lineCap = 'round';
ctx.miterLimit = 1;
if (trail.length > 0) {
ctx.beginPath();
ctx.moveTo(trail[0].x, trail[0].y);
for (let i = 1; i < trail.length; i++) {
ctx.lineTo(trail[i].x, trail[i].y);
}
}
ctx.stroke();
// Need to include this at the end, for now.
ctx.invalidate();
animationHandleRef.current = null;
}
});
}
}, [animationHandleRef, ctx, canvasSize, trailRef]);
// handlers for touch events
const handleMove = useCallback(
async event => {
const position: TrailPoint = {
x: event.nativeEvent.locationX,
y: event.nativeEvent.locationY,
};
const trail = trailRef.current;
if (trail.length > 0) {
const lastPosition = trail[trail.length - 1];
const dx = position.x - lastPosition.x;
const dy = position.y - lastPosition.y;
// add a point to trail if distance from last point > 5
if (dx * dx + dy * dy > 25) {
trail.push(position);
}
} else {
trail.push(position);
}
draw();
},
[trailRef, draw],
);
const handleStart = useCallback(() => {
setDrawingDone(false);
trailRef.current = [];
}, [trailRef, setDrawingDone]);
const handleEnd = useCallback(() => {
setDrawingDone(true);
}, [setDrawingDone]);
useEffect(() => {
draw();
}, [draw]);
return (
<View
style={styles.container}
onLayout={event => {
const {layout} = event.nativeEvent;
setCanvasSize(Math.min(layout?.width || 0, layout?.height || 0));
}}>
<View style={[styles.instruction, {marginTop: insets.top}]}>
<Text style={styles.label}>Write a number</Text>
<Text style={styles.label}>
Let's see if the AI model will get it right
</Text>
</View>
<Canvas
style={{
height: canvasSize,
width: canvasSize,
}}
onContext2D={setCtx}
onTouchMove={handleMove}
onTouchStart={handleStart}
onTouchEnd={handleEnd}
/>
{drawingDone && (
<View style={[styles.resultView]} pointerEvents="none">
<Text style={[styles.label, styles.secondary]}>
Highest confidence will go here
</Text>
<Text style={[styles.label, styles.secondary]}>
Second highest will go here
</Text>
</View>
)}
</View>
);
}
const styles = StyleSheet.create({
container: {
height: '100%',
width: '100%',
backgroundColor: '#180b3b',
justifyContent: 'center',
alignItems: 'center',
},
resultView: {
position: 'absolute',
bottom: 0,
alignSelf: 'flex-start',
flexDirection: 'column',
padding: 15,
},
instruction: {
position: 'absolute',
top: 0,
alignSelf: 'flex-start',
flexDirection: 'column',
padding: 15,
},
label: {
fontSize: 16,
color: '#ffffff',
},
secondary: {
color: '#ffffff99',
},
});
What does requestAnimationFrame do?
requestAnimationFrame is a utility function that lets us run code when the screen is ready for its next refresh.
Input: a callback function, which it runs when the screen next refreshes.
Output: a number that serves as an ID for the callback. You can pass that number to cancelAnimationFrame to cancel the callback if you later decide you don't want to run the code. (We don't need that feature here.)
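The sketch below simulates this input/output contract to make it concrete. It uses a hypothetical FakeFrameScheduler class that queues callbacks and runs them when flushFrame is called. This is not React Native's implementation, since the real functions are globals provided by the runtime. The real callback also receives a timestamp, which the simulation imitates with a number argument.

```typescript
type FrameCallback = (timestamp: number) => void;

// Illustration only: queues callbacks and runs them on the next simulated frame.
class FakeFrameScheduler {
  private nextId = 1;
  private queue: {id: number; callback: FrameCallback}[] = [];

  requestAnimationFrame(callback: FrameCallback): number {
    const id = this.nextId++;
    this.queue.push({id, callback});
    return id; // the ID can later be passed to cancelAnimationFrame
  }

  cancelAnimationFrame(id: number): void {
    this.queue = this.queue.filter(entry => entry.id !== id);
  }

  // Stands in for a screen refresh: runs every callback still queued.
  flushFrame(timestamp: number): void {
    const due = this.queue;
    this.queue = []; // callbacks requested during this frame wait for the next one
    due.forEach(entry => entry.callback(timestamp));
  }
}

const scheduler = new FakeFrameScheduler();
const ran: string[] = [];
const keepId = scheduler.requestAnimationFrame(() => ran.push('draw'));
const dropId = scheduler.requestAnimationFrame(() => ran.push('stale draw'));
scheduler.cancelAnimationFrame(dropId); // changed our mind about this one
scheduler.flushFrame(16);
console.log(keepId, dropId, ran); // 1 2 [ 'draw' ]
```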
Once you have those changes in your code, go ahead and refresh the app and see how much smoother drawing is.
On Android:
npx torchlive-cli run-android
On iOS (Simulator):
npx torchlive-cli run-ios
With silky smooth drawing in place, we are now ready to start working with the MNIST model.
Running the Model
We'll start by creating a React hook that provides a function for running inference on an input image. We'll follow React hooks naming conventions and call ours useMNISTModel.
Let's summarize the changes we're making:
- Import Image and MobileModel from react-native-pytorch-core.
- Load the model file with the require function and call it mnistModel.
- Create a type called MNISTResult with the following properties:
  - num: a digit from 0 to 9.
  - score: the confidence the model has that the input image is the given num.
- Define a function called useMNISTModel that does the following:
  - Creates a React callback async function called processImage that takes an Image as a parameter and does the following:
    - Uses the MobileModel API to execute the mnistModel we loaded, with a set of parameters that tell the model how much of the image to use (crop_width, crop_height), what size to scale it to (28 by 28 pixels, the size of MNIST images), and what the foreground and background colors are.
    - Transforms the raw scores into MNISTResult objects.
    - Sorts the results by score, highest first.
    - Returns the sorted results.
  - Returns an object containing the processImage function we just created.
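The ranking step in processImage does not depend on the model, so you can check it on its own. The scores below are invented for illustration and are not real model output. The sketch assumes, as the tutorial code does, that scores[i] is the model's score for digit i. rankScores is a hypothetical name for the map-and-sort logic.

```typescript
// Same shape as the tutorial's MNISTResult.
type MNISTResult = {num: number; score: number};

// Pairs each score with its digit (the array index) and sorts best-first.
// map() builds a new array, so sort() never reorders the caller's scores.
function rankScores(scores: number[]): MNISTResult[] {
  return scores
    .map((score, index) => ({score, num: index}))
    .sort((a, b) => b.score - a.score);
}

// Made-up scores for digits 0 through 9.
const fakeScores = [0.1, -2.3, 0.4, 7.9, -1.0, 2.2, 0.0, 1.5, -0.7, 3.1];
const [best, second] = rankScores(fakeScores);
console.log(best.num, second.num); // 3 9
```

Sorting only needs the scores to be ordered consistently. It therefore gives the same top digits whether the model outputs probabilities or raw logits.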
Changes (diff):
@@ -1,6 +1,11 @@
import React, {useCallback, useEffect, useState, useRef} from 'react';
import {StyleSheet, Text, View} from 'react-native';
-import {Canvas, CanvasRenderingContext2D} from 'react-native-pytorch-core';
+import {
+ Canvas,
+ CanvasRenderingContext2D,
+ Image,
+ MobileModel,
+} from 'react-native-pytorch-core';
import {useSafeAreaInsets} from 'react-native-safe-area-context';
const COLOR_CANVAS_BACKGROUND = '#4F25C6';
@@ -11,6 +16,44 @@
y: number;
};
+// This is the custom model you have trained. See the tutorial for more on preparing a PyTorch model for mobile.
+const mnistModel = require('../../models/mnist.ptl');
+
+type MNISTResult = {
+ num: number;
+ score: number;
+};
+
+/**
+ * The React hook provides MNIST model inference on an input image.
+ */
+function useMNISTModel() {
+ const processImage = useCallback(async (image: Image) => {
+ // Runs model inference on input image
+ const {
+ result: {scores},
+ } = await MobileModel.execute<{scores: number[]}>(mnistModel, {
+ image,
+ crop_width: 1,
+ crop_height: 1,
+ scale_width: 28,
+ scale_height: 28,
+ colorBackground: COLOR_CANVAS_BACKGROUND,
+ colorForeground: COLOR_TRAIL_STROKE,
+ });
+
+ // Get the score of each number (index), and sort the array by the most likely first.
+ const sortedScore: MNISTResult[] = scores
+ .map((score, index) => ({score: score, num: index}))
+ .sort((a, b) => b.score - a.score);
+ return sortedScore;
+ }, []);
+
+ return {
+ processImage,
+ };
+}
+
export default function MNISTDemo() {
// Get safe area insets to account for notches, etc.
const insets = useSafeAreaInsets();
Entire file:
import React, {useCallback, useEffect, useState, useRef} from 'react';
import {StyleSheet, Text, View} from 'react-native';
import {
Canvas,
CanvasRenderingContext2D,
Image,
MobileModel,
} from 'react-native-pytorch-core';
import {useSafeAreaInsets} from 'react-native-safe-area-context';
const COLOR_CANVAS_BACKGROUND = '#4F25C6';
const COLOR_TRAIL_STROKE = '#FFFFFF';
type TrailPoint = {
x: number;
y: number;
};
// This is the custom model you have trained. See the tutorial for more on preparing a PyTorch model for mobile.
const mnistModel = require('../../models/mnist.ptl');
type MNISTResult = {
num: number;
score: number;
};
/**
* The React hook provides MNIST model inference on an input image.
*/
function useMNISTModel() {
const processImage = useCallback(async (image: Image) => {
// Runs model inference on input image
const {
result: {scores},
} = await MobileModel.execute<{scores: number[]}>(mnistModel, {
image,
crop_width: 1,
crop_height: 1,
scale_width: 28,
scale_height: 28,
colorBackground: COLOR_CANVAS_BACKGROUND,
colorForeground: COLOR_TRAIL_STROKE,
});
// Get the score of each number (index), and sort the array by the most likely first.
const sortedScore: MNISTResult[] = scores
.map((score, index) => ({score: score, num: index}))
.sort((a, b) => b.score - a.score);
return sortedScore;
}, []);
return {
processImage,
};
}
export default function MNISTDemo() {
// Get safe area insets to account for notches, etc.
const insets = useSafeAreaInsets();
const [canvasSize, setCanvasSize] = useState<number>(0);
// `ctx` is drawing context to draw shapes
const [ctx, setCtx] = useState<CanvasRenderingContext2D>();
const trailRef = useRef<TrailPoint[]>([]);
const [drawingDone, setDrawingDone] = useState(false);
const animationHandleRef = useRef<number | null>(null);
const draw = useCallback(() => {
if (animationHandleRef.current != null) return;
if (ctx != null) {
animationHandleRef.current = requestAnimationFrame(() => {
const trail = trailRef.current;
if (trail != null) {
// fill background by drawing a rect
ctx.fillStyle = COLOR_CANVAS_BACKGROUND;
ctx.fillRect(0, 0, canvasSize, canvasSize);
// Draw the trail
ctx.strokeStyle = COLOR_TRAIL_STROKE;
ctx.lineWidth = 25;
ctx.lineJoin = 'round';
ctx.lineCap = 'round';
ctx.miterLimit = 1;
if (trail.length > 0) {
ctx.beginPath();
ctx.moveTo(trail[0].x, trail[0].y);
for (let i = 1; i < trail.length; i++) {
ctx.lineTo(trail[i].x, trail[i].y);
}
}
ctx.stroke();
// Need to include this at the end, for now.
ctx.invalidate();
animationHandleRef.current = null;
}
});
}
}, [animationHandleRef, ctx, canvasSize, trailRef]);
// handlers for touch events
const handleMove = useCallback(
async event => {
const position: TrailPoint = {
x: event.nativeEvent.locationX,
y: event.nativeEvent.locationY,
};
const trail = trailRef.current;
if (trail.length > 0) {
const lastPosition = trail[trail.length - 1];
const dx = position.x - lastPosition.x;
const dy = position.y - lastPosition.y;
// add a point to trail if distance from last point > 5
if (dx * dx + dy * dy > 25) {
trail.push(position);
}
} else {
trail.push(position);
}
draw();
},
[trailRef, draw],
);
const handleStart = useCallback(() => {
setDrawingDone(false);
trailRef.current = [];
}, [trailRef, setDrawingDone]);
const handleEnd = useCallback(() => {
setDrawingDone(true);
}, [setDrawingDone]);
useEffect(() => {
draw();
}, [draw]);
return (
<View
style={styles.container}
onLayout={event => {
const {layout} = event.nativeEvent;
setCanvasSize(Math.min(layout?.width || 0, layout?.height || 0));
}}>
<View style={[styles.instruction, {marginTop: insets.top}]}>
<Text style={styles.label}>Write a number</Text>
<Text style={styles.label}>
Let's see if the AI model will get it right
</Text>
</View>
<Canvas
style={{
height: canvasSize,
width: canvasSize,
}}
onContext2D={setCtx}
onTouchMove={handleMove}
onTouchStart={handleStart}
onTouchEnd={handleEnd}
/>
{drawingDone && (
<View style={[styles.resultView]} pointerEvents="none">
<Text style={[styles.label, styles.secondary]}>
Highest confidence will go here
</Text>
<Text style={[styles.label, styles.secondary]}>
Second highest will go here
</Text>
</View>
)}
</View>
);
}
const styles = StyleSheet.create({
container: {
height: '100%',
width: '100%',
backgroundColor: '#180b3b',
justifyContent: 'center',
alignItems: 'center',
},
resultView: {
position: 'absolute',
bottom: 0,
alignSelf: 'flex-start',
flexDirection: 'column',
padding: 15,
},
instruction: {
position: 'absolute',
top: 0,
alignSelf: 'flex-start',
flexDirection: 'column',
padding: 15,
},
label: {
fontSize: 16,
color: '#ffffff',
},
secondary: {
color: '#ffffff99',
},
});
An even shorter summary: it takes in an Image and gives back a sorted list of results.
But we don't have any Image objects, just a trail on a canvas.
In the next section, we'll learn how to create an Image from the contents of our canvas that we can pass to the model.
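The ranking step inside useMNISTModel does not depend on the model at all, so it can be pulled out and checked on its own. The following is a minimal sketch that uses a hypothetical helper named rankScores. That helper is not part of the tutorial or of react-native-pytorch-core, and the input scores are made up.

```typescript
// Same shape as the tutorial's MNISTResult type.
type MNISTResult = {
  num: number;
  score: number;
};

/**
 * Hypothetical standalone version of the ranking in useMNISTModel:
 * pairs each score with its index (the digit it stands for) and sorts
 * the pairs so that the most likely digit comes first.
 */
function rankScores(scores: readonly number[]): MNISTResult[] {
  return (
    scores
      .map((score, index) => ({score, num: index}))
      // Descending order: a higher score means a more likely digit.
      .sort((a, b) => b.score - a.score)
  );
}

// Example with made-up scores for the digits 0, 1 and 2.
const ranked = rankScores([0.1, 0.7, 0.2]);
console.log(ranked.map(r => r.num).join(',')); // 1,2,0
```

Because map returns a new array, the sort never reorders the caller's scores array.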
Creating an Image from our Canvas
We are going to create another hook called useMNISTCanvasInference that uses the hook we just created (useMNISTModel).
This hook takes in the canvasSize and gives us back two things:
- result - a state variable that holds the sorted list of MNISTResult objects from the model.
- classify - a function that takes in the canvas context, extracts an image from it, processes the image, and then updates the result state variable.
In our classify callback, we use some of the PTL core components, including the newly imported ImageUtil object. The ImageUtil object lets us take the imageData we pull from the canvas and turn it into an Image that our model can use.
You'll also see that we call the release function on both the imageData and image variables as soon as we are done using them. This is a vital step to make sure we don't run out of memory holding on to images we no longer need.
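In classify, each release call only runs on the success path. If ImageUtil.fromImageData or processImage threw, the object created just before that call would never be released. A try/finally wrapper is one way to guard against this. The following is a minimal sketch. It assumes only that the objects expose a release() method, which may return a Promise. The names withRelease and Releasable are hypothetical and are not part of react-native-pytorch-core.

```typescript
// Hypothetical minimal interface for anything that must be released
// after use, such as the imageData and image objects in classify.
interface Releasable {
  release(): void | Promise<void>;
}

/**
 * Hypothetical helper: runs `fn` with `resource`, then releases the
 * resource whether `fn` resolves or rejects.
 */
async function withRelease<T extends Releasable, R>(
  resource: T,
  fn: (resource: T) => Promise<R>,
): Promise<R> {
  try {
    return await fn(resource);
  } finally {
    // Runs on both paths; an error thrown by fn still reaches the caller.
    await resource.release();
  }
}

// Sketch of how classify could use it (not run here):
//   const image = await withRelease(imageData, d => ImageUtil.fromImageData(d));
//   const result = await withRelease(image, img => processImage(img));

// Demo with a fake resource whose "inference" fails.
let fakeReleased = false;
const fake: Releasable = {
  release: () => {
    fakeReleased = true;
  },
};
withRelease(fake, async () => {
  throw new Error('inference failed');
}).catch(() => console.log(fakeReleased)); // true
```

The finally block has no return or throw of its own. The original error therefore still propagates to the caller after the resource is released.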
Changes (diff):
@@ -4,6 +4,7 @@
Canvas,
CanvasRenderingContext2D,
Image,
+ ImageUtil,
MobileModel,
} from 'react-native-pytorch-core';
import {useSafeAreaInsets} from 'react-native-safe-area-context';
@@ -54,6 +55,48 @@
};
}
+/**
+ * The React hook provides MNIST inference using the image data extracted from
+ * a canvas.
+ *
+ * @param canvasSize The size of the square canvas
+ */
+function useMNISTCanvasInference(canvasSize: number) {
+ const [result, setResult] = useState<MNISTResult[]>();
+ const {processImage} = useMNISTModel();
+ const classify = useCallback(
+ async (ctx: CanvasRenderingContext2D) => {
+ // Return immediately if canvas is size 0
+ if (canvasSize === 0) {
+ return null;
+ }
+
+ // Get the image data for the entire (square) canvas
+ const imageData = await ctx.getImageData(0, 0, canvasSize, canvasSize);
+
+ // Convert image data to image.
+ const image: Image = await ImageUtil.fromImageData(imageData);
+
+ // Release image data to free memory
+ imageData.release();
+
+ // Run MNIST inference on the image
+ const result = await processImage(image);
+
+ // Release image to free memory
+ image.release();
+
+ // Set result state to force re-render of component that uses this hook
+ setResult(result);
+ },
+ [canvasSize, processImage, setResult],
+ );
+ return {
+ result,
+ classify,
+ };
+}
+
export default function MNISTDemo() {
// Get safe area insets to account for notches, etc.
const insets = useSafeAreaInsets();
Entire file:
import React, {useCallback, useEffect, useState, useRef} from 'react';
import {StyleSheet, Text, View} from 'react-native';
import {
Canvas,
CanvasRenderingContext2D,
Image,
ImageUtil,
MobileModel,
} from 'react-native-pytorch-core';
import {useSafeAreaInsets} from 'react-native-safe-area-context';
const COLOR_CANVAS_BACKGROUND = '#4F25C6';
const COLOR_TRAIL_STROKE = '#FFFFFF';
type TrailPoint = {
x: number;
y: number;
};
// This is the custom model you have trained. See the tutorial for more on preparing a PyTorch model for mobile.
const mnistModel = require('../../models/mnist.ptl');
type MNISTResult = {
num: number;
score: number;
};
/**
* The React hook provides MNIST model inference on an input image.
*/
function useMNISTModel() {
const processImage = useCallback(async (image: Image) => {
// Runs model inference on input image
const {
result: {scores},
} = await MobileModel.execute<{scores: number[]}>(mnistModel, {
image,
crop_width: 1,
crop_height: 1,
scale_width: 28,
scale_height: 28,
colorBackground: COLOR_CANVAS_BACKGROUND,
colorForeground: COLOR_TRAIL_STROKE,
});
// Get the score of each number (index), and sort the array by the most likely first.
const sortedScore: MNISTResult[] = scores
.map((score, index) => ({score: score, num: index}))
.sort((a, b) => b.score - a.score);
return sortedScore;
}, []);
return {
processImage,
};
}
/**
* The React hook provides MNIST inference using the image data extracted from
* a canvas.
*
* @param canvasSize The size of the square canvas
*/
function useMNISTCanvasInference(canvasSize: number) {
const [result, setResult] = useState<MNISTResult[]>();
const {processImage} = useMNISTModel();
const classify = useCallback(
async (ctx: CanvasRenderingContext2D) => {
// Return immediately if canvas is size 0
if (canvasSize === 0) {
return null;
}
// Get the image data for the entire (square) canvas
const imageData = await ctx.getImageData(0, 0, canvasSize, canvasSize);
// Convert image data to image.
const image: Image = await ImageUtil.fromImageData(imageData);
// Release image data to free memory
imageData.release();
// Run MNIST inference on the image
const result = await processImage(image);
// Release image to free memory
image.release();
// Set result state to force re-render of component that uses this hook
setResult(result);
},
[canvasSize, processImage, setResult],
);
return {
result,
classify,
};
}
export default function MNISTDemo() {
// Get safe area insets to account for notches, etc.
const insets = useSafeAreaInsets();
const [canvasSize, setCanvasSize] = useState<number>(0);
// `ctx` is drawing context to draw shapes
const [ctx, setCtx] = useState<CanvasRenderingContext2D>();
const trailRef = useRef<TrailPoint[]>([]);
const [drawingDone, setDrawingDone] = useState(false);
const animationHandleRef = useRef<number | null>(null);
const draw = useCallback(() => {
if (animationHandleRef.current != null) return;
if (ctx != null) {
animationHandleRef.current = requestAnimationFrame(() => {
const trail = trailRef.current;
if (trail != null) {
// fill background by drawing a rect
ctx.fillStyle = COLOR_CANVAS_BACKGROUND;
ctx.fillRect(0, 0, canvasSize, canvasSize);
// Draw the trail
ctx.strokeStyle = COLOR_TRAIL_STROKE;
ctx.lineWidth = 25;
ctx.lineJoin = 'round';
ctx.lineCap = 'round';
ctx.miterLimit = 1;
if (trail.length > 0) {
ctx.beginPath();
ctx.moveTo(trail[0].x, trail[0].y);
for (let i = 1; i < trail.length; i++) {
ctx.lineTo(trail[i].x, trail[i].y);
}
}
ctx.stroke();
// Need to include this at the end, for now.
ctx.invalidate();
animationHandleRef.current = null;
}
});
}
}, [animationHandleRef, ctx, canvasSize, trailRef]);
// handlers for touch events
const handleMove = useCallback(
async event => {
const position: TrailPoint = {
x: event.nativeEvent.locationX,
y: event.nativeEvent.locationY,
};
const trail = trailRef.current;
if (trail.length > 0) {
const lastPosition = trail[trail.length - 1];
const dx = position.x - lastPosition.x;
const dy = position.y - lastPosition.y;
// add a point to trail if distance from last point > 5
if (dx * dx + dy * dy > 25) {
trail.push(position);
}
} else {
trail.push(position);
}
draw();
},
[trailRef, draw],
);
const handleStart = useCallback(() => {
setDrawingDone(false);
trailRef.current = [];
}, [trailRef, setDrawingDone]);
const handleEnd = useCallback(() => {
setDrawingDone(true);
}, [setDrawingDone]);
useEffect(() => {
draw();
}, [draw]);
return (
<View
style={styles.container}
onLayout={event => {
const {layout} = event.nativeEvent;
setCanvasSize(Math.min(layout?.width || 0, layout?.height || 0));
}}>
<View style={[styles.instruction, {marginTop: insets.top}]}>
<Text style={styles.label}>Write a number</Text>
<Text style={styles.label}>
Let's see if the AI model will get it right
</Text>
</View>
<Canvas
style={{
height: canvasSize,
width: canvasSize,
}}
onContext2D={setCtx}
onTouchMove={handleMove}
onTouchStart={handleStart}
onTouchEnd={handleEnd}
/>
{drawingDone && (
<View style={[styles.resultView]} pointerEvents="none">
<Text style={[styles.label, styles.secondary]}>
Highest confidence will go here
</Text>
<Text style={[styles.label, styles.secondary]}>
Second highest will go here
</Text>
</View>
)}
</View>
);
}
const styles = StyleSheet.create({
container: {
height: '100%',
width: '100%',
backgroundColor: '#180b3b',
justifyContent: 'center',
alignItems: 'center',
},
resultView: {
position: 'absolute',
bottom: 0,
alignSelf: 'flex-start',
flexDirection: 'column',
padding: 15,
},
instruction: {
position: 'absolute',
top: 0,
alignSelf: 'flex-start',
flexDirection: 'column',
padding: 15,
},
label: {
fontSize: 16,
color: '#ffffff',
},
secondary: {
color: '#ffffff99',
},
});
With this second hook, we are ready to run our model on the user-created drawings. Let's hook it up in the next section.
Running the Model & Displaying Results
While we add a fair number of lines in this section, they are all simple changes.
Let's cut to the summary:
- Create a type called NumberLabelSet so we know what data we have access to for each number.
- Create a list of NumberLabelSet objects and call it numLabels.
- Get the classify method and the result state variable by calling useMNISTCanvasInference from within our demo component.
- Update the handleEnd function to check for a canvas context and then trigger the model.
- Add classify and ctx as dependencies of the handleEnd callback.
- Change the text in the results section to show the top two numbers from the model output.
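The last change builds two strings from the first two entries of the sorted result. Each entry's num is used as an index into numLabels. The sketch below pulls that formatting into a hypothetical pure function, describeTopTwo, which is not part of the tutorial, so it can be checked without a device. It assumes result is already sorted most-likely-first, which is how useMNISTModel returns it. The label list is shortened and the scores are made up.

```typescript
// Same shapes as the tutorial's MNISTResult and NumberLabelSet types.
type MNISTResult = {num: number; score: number};
type NumberLabelSet = {english: string; asciiSymbol: string};

/**
 * Hypothetical helper that builds the two result lines shown under the canvas.
 * Assumes `result` is sorted with the most likely digit first and that each
 * `num` is a valid index into `labels`. Returns an empty array when there is
 * no result yet (it is undefined before the first inference finishes) or when
 * there are fewer than two entries.
 */
function describeTopTwo(
  result: MNISTResult[] | undefined,
  labels: NumberLabelSet[],
): string[] {
  if (result == null || result.length < 2) {
    return [];
  }
  const first = labels[result[0].num];
  const second = labels[result[1].num];
  return [
    `${first.asciiSymbol} it looks like ${first.english}`,
    `${second.asciiSymbol} or it might be ${second.english}`,
  ];
}

// Example with a shortened label list and made-up scores.
const labels: NumberLabelSet[] = [
  {english: 'zero', asciiSymbol: '🄌'},
  {english: 'one', asciiSymbol: '➊'},
  {english: 'two', asciiSymbol: '➋'},
];
const lines = describeTopTwo(
  [
    {num: 2, score: 0.9},
    {num: 0, score: 0.07},
    {num: 1, score: 0.03},
  ],
  labels,
);
console.log(lines[0]); // ➋ it looks like two
console.log(lines[1]); // 🄌 or it might be zero
```

The early return mirrors the result && guard in the JSX. It also protects against indexing result[1] when there is only one entry.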
Changes (diff):
@@ -97,6 +97,54 @@
};
}
+type NumberLabelSet = {
+ english: string;
+ asciiSymbol: string;
+};
+
+const numLabels: NumberLabelSet[] = [
+ {
+ english: 'zero',
+ asciiSymbol: '🄌',
+ },
+ {
+ english: 'one',
+ asciiSymbol: '➊',
+ },
+ {
+ english: 'two',
+ asciiSymbol: '➋',
+ },
+ {
+ english: 'three',
+ asciiSymbol: '➌',
+ },
+ {
+ english: 'four',
+ asciiSymbol: '➍',
+ },
+ {
+ english: 'five',
+ asciiSymbol: '➎',
+ },
+ {
+ english: 'six',
+ asciiSymbol: '➏',
+ },
+ {
+ english: 'seven',
+ asciiSymbol: '➐',
+ },
+ {
+ english: 'eight',
+ asciiSymbol: '➑',
+ },
+ {
+ english: 'nine',
+ asciiSymbol: '➒',
+ },
+];
+
export default function MNISTDemo() {
// Get safe area insets to account for notches, etc.
const insets = useSafeAreaInsets();
@@ -105,6 +153,8 @@
// `ctx` is drawing context to draw shapes
const [ctx, setCtx] = useState<CanvasRenderingContext2D>();
+ const {classify, result} = useMNISTCanvasInference(canvasSize);
+
const trailRef = useRef<TrailPoint[]>([]);
const [drawingDone, setDrawingDone] = useState(false);
const animationHandleRef = useRef<number | null>(null);
@@ -173,7 +223,8 @@
const handleEnd = useCallback(() => {
setDrawingDone(true);
- }, [setDrawingDone]);
+ if (ctx != null) classify(ctx);
+ }, [setDrawingDone, classify, ctx]);
useEffect(() => {
draw();
@@ -205,10 +256,16 @@
{drawingDone && (
<View style={[styles.resultView]} pointerEvents="none">
<Text style={[styles.label, styles.secondary]}>
- Highest confidence will go here
+ {result &&
+ `${numLabels[result[0].num].asciiSymbol} it looks like ${
+ numLabels[result[0].num].english
+ }`}
</Text>
<Text style={[styles.label, styles.secondary]}>
- Second highest will go here
+ {result &&
+ `${numLabels[result[1].num].asciiSymbol} or it might be ${
+ numLabels[result[1].num].english
+ }`}
</Text>
</View>
)}
Entire file:
import React, {useCallback, useEffect, useState, useRef} from 'react';
import {StyleSheet, Text, View} from 'react-native';
import {
Canvas,
CanvasRenderingContext2D,
Image,
ImageUtil,
MobileModel,
} from 'react-native-pytorch-core';
import {useSafeAreaInsets} from 'react-native-safe-area-context';
const COLOR_CANVAS_BACKGROUND = '#4F25C6';
const COLOR_TRAIL_STROKE = '#FFFFFF';
type TrailPoint = {
x: number;
y: number;
};
// This is the custom model you have trained. See the tutorial for more on preparing a PyTorch model for mobile.
const mnistModel = require('../../models/mnist.ptl');
type MNISTResult = {
num: number;
score: number;
};
/**
* The React hook provides MNIST model inference on an input image.
*/
function useMNISTModel() {
const processImage = useCallback(async (image: Image) => {
// Runs model inference on input image
const {
result: {scores},
} = await MobileModel.execute<{scores: number[]}>(mnistModel, {
image,
crop_width: 1,
crop_height: 1,
scale_width: 28,
scale_height: 28,
colorBackground: COLOR_CANVAS_BACKGROUND,
colorForeground: COLOR_TRAIL_STROKE,
});
// Get the score of each number (index), and sort the array by the most likely first.
const sortedScore: MNISTResult[] = scores
.map((score, index) => ({score: score, num: index}))
.sort((a, b) => b.score - a.score);
return sortedScore;
}, []);
return {
processImage,
};
}
/**
* The React hook provides MNIST inference using the image data extracted from
* a canvas.
*
* @param canvasSize The size of the square canvas
*/
function useMNISTCanvasInference(canvasSize: number) {
const [result, setResult] = useState<MNISTResult[]>();
const {processImage} = useMNISTModel();
const classify = useCallback(
async (ctx: CanvasRenderingContext2D) => {
// Return immediately if canvas is size 0
if (canvasSize === 0) {
return null;
}
// Get the image data for the entire (square) canvas
const imageData = await ctx.getImageData(0, 0, canvasSize, canvasSize);
// Convert image data to image.
const image: Image = await ImageUtil.fromImageData(imageData);
// Release image data to free memory
imageData.release();
// Run MNIST inference on the image
const result = await processImage(image);
// Release image to free memory
image.release();
// Set result state to force re-render of component that uses this hook
setResult(result);
},
[canvasSize, processImage, setResult],
);
return {
result,
classify,
};
}
type NumberLabelSet = {
english: string;
asciiSymbol: string;
};
const numLabels: NumberLabelSet[] = [
{
english: 'zero',
asciiSymbol: '🄌',
},
{
english: 'one',
asciiSymbol: '➊',
},
{
english: 'two',
asciiSymbol: '➋',
},
{
english: 'three',
asciiSymbol: '➌',
},
{
english: 'four',
asciiSymbol: '➍',
},
{
english: 'five',
asciiSymbol: '➎',
},
{
english: 'six',
asciiSymbol: '➏',
},
{
english: 'seven',
asciiSymbol: '➐',
},
{
english: 'eight',
asciiSymbol: '➑',
},
{
english: 'nine',
asciiSymbol: '➒',
},
];
export default function MNISTDemo() {
// Get safe area insets to account for notches, etc.
const insets = useSafeAreaInsets();
const [canvasSize, setCanvasSize] = useState<number>(0);
// `ctx` is drawing context to draw shapes
const [ctx, setCtx] = useState<CanvasRenderingContext2D>();
const {classify, result} = useMNISTCanvasInference(canvasSize);
const trailRef = useRef<TrailPoint[]>([]);
const [drawingDone, setDrawingDone] = useState(false);
const animationHandleRef = useRef<number | null>(null);
const draw = useCallback(() => {
if (animationHandleRef.current != null) return;
if (ctx != null) {
animationHandleRef.current = requestAnimationFrame(() => {
const trail = trailRef.current;
if (trail != null) {
// fill background by drawing a rect
ctx.fillStyle = COLOR_CANVAS_BACKGROUND;
ctx.fillRect(0, 0, canvasSize, canvasSize);
// Draw the trail
ctx.strokeStyle = COLOR_TRAIL_STROKE;
ctx.lineWidth = 25;
ctx.lineJoin = 'round';
ctx.lineCap = 'round';
ctx.miterLimit = 1;
if (trail.length > 0) {
ctx.beginPath();
ctx.moveTo(trail[0].x, trail[0].y);
for (let i = 1; i < trail.length; i++) {
ctx.lineTo(trail[i].x, trail[i].y);
}
}
ctx.stroke();
// Need to include this at the end, for now.
ctx.invalidate();
animationHandleRef.current = null;
}
});
}
}, [animationHandleRef, ctx, canvasSize, trailRef]);
// handlers for touch events
const handleMove = useCallback(
async event => {
const position: TrailPoint = {
x: event.nativeEvent.locationX,
y: event.nativeEvent.locationY,
};
const trail = trailRef.current;
if (trail.length > 0) {
const lastPosition = trail[trail.length - 1];
const dx = position.x - lastPosition.x;
const dy = position.y - lastPosition.y;
// add a point to trail if distance from last point > 5
if (dx * dx + dy * dy > 25) {
trail.push(position);
}
} else {
trail.push(position);
}
draw();
},
[trailRef, draw],
);
const handleStart = useCallback(() => {
setDrawingDone(false);
trailRef.current = [];
}, [trailRef, setDrawingDone]);
const handleEnd = useCallback(() => {
setDrawingDone(true);
if (ctx != null) classify(ctx);
}, [setDrawingDone, classify, ctx]);
useEffect(() => {
draw();
}, [draw]);
return (
<View
style={styles.container}
onLayout={event => {
const {layout} = event.nativeEvent;
setCanvasSize(Math.min(layout?.width || 0, layout?.height || 0));
}}>
<View style={[styles.instruction, {marginTop: insets.top}]}>
<Text style={styles.label}>Write a number</Text>
<Text style={styles.label}>
Let's see if the AI model will get it right
</Text>
</View>
<Canvas
style={{
height: canvasSize,
width: canvasSize,
}}
onContext2D={setCtx}
onTouchMove={handleMove}
onTouchStart={handleStart}
onTouchEnd={handleEnd}
/>
{drawingDone && (
<View style={[styles.resultView]} pointerEvents="none">
<Text style={[styles.label, styles.secondary]}>
{result &&
`${numLabels[result[0].num].asciiSymbol} it looks like ${
numLabels[result[0].num].english
}`}
</Text>
<Text style={[styles.label, styles.secondary]}>
{result &&
`${numLabels[result[1].num].asciiSymbol} or it might be ${
numLabels[result[1].num].english
}`}
</Text>
</View>
)}
</View>
);
}
const styles = StyleSheet.create({
container: {
height: '100%',
width: '100%',
backgroundColor: '#180b3b',
justifyContent: 'center',
alignItems: 'center',
},
resultView: {
position: 'absolute',
bottom: 0,
alignSelf: 'flex-start',
flexDirection: 'column',
padding: 15,
},
instruction: {
position: 'absolute',
top: 0,
alignSelf: 'flex-start',
flexDirection: 'column',
padding: 15,
},
label: {
fontSize: 16,
color: '#ffffff',
},
secondary: {
color: '#ffffff99',
},
});
When you run the app, you should see the results displayed in the bottom-left corner, as in the screen recording below.
Android:
npx torchlive-cli run-android
iOS (Simulator):
npx torchlive-cli run-ios
And with that we have a working MNIST classifier!