Hello World Model
In this tutorial, you will create a "Hello World" model that takes a string as input and returns a string as output. You will also learn how to export it as a TorchScript model that can be loaded with the PlayTorch SDK for on-device inference.
Create PyTorch Model
Let's begin by creating a PyTorch model. Here, we are going to create a simple "Hello World" model by subclassing torch.nn.Module, the base class PyTorch uses to represent neural networks (hence the namespace nn).
The model defines a forward function with one argument, name. The forward function "performs" the computation; in later tutorials, for example, it will perform inference on an image.
The model constructor has one argument, prefix, which the forward function uses to prefix the name argument.
More details on PyTorch modules at https://pytorch.org/docs/stable/notes/modules.html
import torch
from torch import nn

class Model(nn.Module):
    def __init__(self, prefix: str):
        super().__init__()
        self.prefix = prefix

    def forward(self, name: str) -> str:
        return f"{self.prefix} {name}!"
Create an instance of the model
Next, let's create an instance of the model and perform a computation.
model = Model("Hello")
model("Roman")
Hello Roman!
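Calling model("Roman") works because nn.Module.__call__ dispatches to forward. The minimal sketch below shows why you usually call the module rather than forward directly: __call__ also runs any registered hooks. The shout hook is hypothetical and exists only for illustration; register_forward_hook is the standard nn.Module method, and a forward hook that returns a value replaces the module's output.

```python
from torch import nn


class Model(nn.Module):
    """Same "Hello World" model as above."""

    def __init__(self, prefix: str):
        super().__init__()
        self.prefix = prefix

    def forward(self, name: str) -> str:
        return f"{self.prefix} {name}!"


def shout(module: nn.Module, inputs: tuple, output: str) -> str:
    # Hypothetical forward hook: returning a value replaces the module's output.
    return output.upper()


model = Model("Hello")
handle = model.register_forward_hook(shout)

print(model("Roman"))          # HELLO ROMAN!  (__call__ runs forward, then the hook)
print(model.forward("Roman"))  # Hello Roman!  (calling forward directly skips hooks)

handle.remove()                # unregister the hook again
print(model("Roman"))          # Hello Roman!
```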
Export Model for Mobile
Now that we have a model, let's export the model to use on mobile. To do that, we need to script the model (i.e., create a TorchScript representation) as follows:
scripted_model = torch.jit.script(model)
scripted_model("Lindsay")
Hello Lindsay!
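Scripting compiles forward into TorchScript and keeps the type annotations from the Python source. The sketch below, which reuses the same Model class, prints the generated TorchScript code and shows that, unlike the eager model, the scripted model rejects a non-string name. The exact text of the generated code and of the error message may differ between PyTorch versions.

```python
import torch
from torch import nn


class Model(nn.Module):
    def __init__(self, prefix: str):
        super().__init__()
        self.prefix = prefix

    def forward(self, name: str) -> str:
        return f"{self.prefix} {name}!"


scripted_model = torch.jit.script(Model("Hello"))

# TorchScript source generated for forward(); scripting captured `prefix`
# as a module attribute, which is still readable from Python.
print(scripted_model.code)
print(scripted_model.prefix)  # Hello

# The eager model accepts anything an f-string can format...
print(Model("Hello")(42))  # Hello 42!

# ...but the scripted model enforces the `name: str` annotation.
try:
    scripted_model(42)
except RuntimeError as err:
    print("TorchScript rejected the int argument:", err)
```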
The torch.jit.script function is the recommended way to create a TorchScript model because it can capture control flow, but it might fail in some cases. If that happens, consult the PyTorch TorchScript documentation for solutions.
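To illustrate what "capture control flow" means, here is a sketch comparing torch.jit.script with torch.jit.trace, the other way to create a TorchScript model. SignAwareModel is a hypothetical model with a data-dependent if statement. Tracing runs the model once on an example input and records only the branch that executed (PyTorch emits a TracerWarning about this), while scripting compiles the source and keeps both branches.

```python
import torch
from torch import nn


class SignAwareModel(nn.Module):
    """Hypothetical model whose output depends on a data-dependent branch."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.sum() > 0:
            return x * 2
        return -x


model = SignAwareModel()
positive = torch.tensor([1.0, 2.0])
negative = torch.tensor([-1.0, -2.0])

# Tracing with a positive example bakes in the `x * 2` branch.
traced = torch.jit.trace(model, positive)

# Scripting compiles the Python source, so the `if` survives.
scripted = torch.jit.script(model)

print(model(negative))     # tensor([1., 2.])
print(scripted(negative))  # tensor([1., 2.])
print(traced(negative))    # tensor([-2., -4.])  (wrong branch)
```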
PyTorch offers the optimize_for_mobile utility function to run a list of optimizations on the model (e.g., Conv2D + BatchNorm fusion, dropout removal). It's recommended to optimize the model with this utility before exporting it for mobile.
More details on the optimize_for_mobile utility at: https://pytorch.org/docs/stable/mobile_optimizer.html
from torch.utils.mobile_optimizer import optimize_for_mobile
optimized_model = optimize_for_mobile(scripted_model)
optimized_model("Kodo")
Hello Kodo!
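The "Hello World" model has no layers for optimize_for_mobile to fuse, so the sketch below uses a hypothetical ConvBlock (Conv2d followed by BatchNorm2d) to exercise the Conv2D + BatchNorm fusion mentioned above. It checks that the optimized module produces numerically close results and prints whether an aten::batch_norm op remains in each graph. The exact optimized graph depends on the PyTorch version and on whether XNNPACK is available in your build.

```python
import torch
from torch import nn
from torch.utils.mobile_optimizer import optimize_for_mobile


class ConvBlock(nn.Module):
    """Hypothetical image block with the Conv2d -> BatchNorm2d pattern."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.bn(self.conv(x)))


torch.manual_seed(0)
block = ConvBlock().eval()  # mobile optimizations target inference (eval) mode
scripted_block = torch.jit.script(block)
optimized_block = optimize_for_mobile(scripted_block)

x = torch.randn(1, 3, 16, 16)
# Folding BatchNorm into the convolution should not change the result
# beyond floating-point rounding.
print(torch.allclose(block(x), optimized_block(x), atol=1e-4))

# Compare the graphs before and after optimization.
print("aten::batch_norm" in str(scripted_block.inlined_graph))
print("aten::batch_norm" in str(optimized_block.graph))
```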
Great! Now, let's export the model for mobile. This is done by saving the model for the lite interpreter. The _save_for_lite_interpreter function will create a hello_world.ptl file, which we will be able to load with the PlayTorch SDK.
optimized_model._save_for_lite_interpreter("hello_world.ptl")
More details on the lite interpreter at: https://pytorch.org/tutorials/prototype/lite_interpreter.html
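Before moving to the Snack, you can sanity-check the exported file from Python. This sketch assumes a full desktop PyTorch install, which ships torch.jit.mobile._load_for_lite_interpreter; note that it is a private (underscore-prefixed) helper and may change between releases. It loads hello_world.ptl with the lite interpreter and runs it like the original model.

```python
import torch
from torch import nn
from torch.jit.mobile import _load_for_lite_interpreter
from torch.utils.mobile_optimizer import optimize_for_mobile


class Model(nn.Module):
    def __init__(self, prefix: str):
        super().__init__()
        self.prefix = prefix

    def forward(self, name: str) -> str:
        return f"{self.prefix} {name}!"


# Re-create and export the model exactly as in the steps above.
optimized_model = optimize_for_mobile(torch.jit.script(Model("Hello")))
optimized_model._save_for_lite_interpreter("hello_world.ptl")

# Load the .ptl file with the lite interpreter runtime and call forward().
lite_model = _load_for_lite_interpreter("hello_world.ptl")
print(lite_model("Kodo"))  # Hello Kodo!
```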
Create Mobile UI and Load Model on Mobile
Next, let's create a PlayTorch Snack by following the link http://snack.playtorch.dev/. Then, drag and drop the hello_world.ptl file onto the newly created PlayTorch Snack; this will import the model into the Snack.
Replace the source code in App.js with the React Native source code below.
It creates a user interface with a text input, a button, and a text element. When the button is pressed, the app loads the hello_world.ptl model and calls the model's forward function with the text input value as its argument. The returned model output is then displayed below the button.
import * as React from 'react';
import { useState } from 'react';
import {
  Button,
  SafeAreaView,
  StyleSheet,
  Text,
  TextInput,
  View,
} from 'react-native';
import { torch, MobileModel } from 'react-native-pytorch-core';

export default function App() {
  const [modelInput, setModelInput] = useState('');
  const [modelOutput, setModelOutput] = useState('');

  async function handleModelInput() {
    // Get a local file path for the bundled model, then load it
    const filePath = await MobileModel.download(require('./hello_world.ptl'));
    const model = await torch.jit._loadForMobile(filePath);
    // Run the model's forward function on the text input value
    const output = await model.forward(modelInput);
    setModelOutput(output);
  }

  return (
    <SafeAreaView style={StyleSheet.absoluteFill}>
      <View style={styles.container}>
        <TextInput
          value={modelInput}
          onChangeText={setModelInput}
          placeholder="Write your name"
        />
        <Button title="Let's Go" onPress={handleModelInput} />
        <Text>{modelOutput}</Text>
      </View>
    </SafeAreaView>
  );
}

const styles = StyleSheet.create({
  container: {
    flex: 1,
    justifyContent: 'center',
    backgroundColor: '#fff',
    padding: 20,
  },
});