Deploy a model built with the PyTorch framework.
PyTorch models are deployed to Baseten the same way as scikit-learn and TensorFlow models, with one difference: you must also provide the file(s) that define the model class.
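One likely reason the class file is needed: a whole PyTorch model object saved with `torch.save` is serialized with Python's `pickle`, which records the class by module and name rather than storing the class's code. The sketch below shows that behavior with plain `pickle` and a toy `MyModel` (a hypothetical stand-in, not an `nn.Module`). It assumes, without confirmation from this page, that Baseten's loading step runs into the same constraint.

```python
import pickle


class MyModel:
    """Toy stand-in for a PyTorch nn.Module subclass (hypothetical)."""

    def __init__(self):
        self.weights = [0.25, 0.5]

    def forward(self, x):
        return [w * x for w in self.weights]


# Pickle stores a reference to the class (module + name) and the instance's
# attribute values, but not the code of the class or its methods.
blob = pickle.dumps(MyModel())
assert b"MyModel" in blob       # the class is referenced by name
assert b"forward" not in blob   # the method code is not embedded

# Simulate an environment where the class definition is not available.
_saved = MyModel
del MyModel
try:
    pickle.loads(blob)
    loaded_without_class = True
except (AttributeError, pickle.UnpicklingError):
    loaded_without_class = False

MyModel = _saved  # make the definition available again; loading now works
restored = pickle.loads(blob)
print(loaded_without_class, restored.forward(2))  # False [0.5, 1.0]
```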
Baseten officially supports torch 1.9.0 or higher. Check the version you have installed, especially in an online notebook environment such as Google Colab or a bundled distribution such as Anaconda, which may ship an older release. If your version is not supported, upgrade it with pip's `--upgrade` flag, which installs the most recent release: `pip install --upgrade torch`.
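To check the installed version from Python, the sketch below compares a version string against the 1.9.0 minimum. The helper names `parse_version` and `is_supported` are hypothetical. It compares numeric tuples rather than raw strings because, as strings, "1.10.0" sorts before "1.9.0".

```python
import re

MIN_TORCH = (1, 9, 0)  # minimum torch version Baseten officially supports


def parse_version(version: str) -> tuple:
    """Turn e.g. '2.1.0+cu118' into (2, 1, 0); pre-release tags are ignored."""
    core = version.split("+")[0]          # drop a local build suffix like '+cu118'
    parts = []
    for piece in core.split(".")[:3]:
        match = re.match(r"\d+", piece)   # keep leading digits only, e.g. '0a0' -> 0
        parts.append(int(match.group()) if match else 0)
    parts += [0] * (3 - len(parts))       # pad '2.0' to (2, 0, 0)
    return tuple(parts)


def is_supported(version: str) -> bool:
    return parse_version(version) >= MIN_TORCH


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        print("torch is not installed")
    else:
        status = "supported" if is_supported(torch.__version__) else "too old, upgrade it"
        print(f"torch {torch.__version__}: {status}")
```

Because pre-release tags are ignored, a build such as 1.9.0a0 counts as 1.9.0 here; `packaging.version.Version` gives exact PEP 440 ordering if you need it.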
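If your model class `MyModel` is defined in its own file, add a keyword argument to the `baseten.deploy` call that points to that file:

```python
import baseten

baseten.login("*** INSERT API KEY ***")  # replace with your Baseten API key
baseten.deploy(
    my_model,  # e.g. an instance of the PyTorch model class MyModel
    model_name='My pytorch model',
    # ...plus the keyword argument that points to the file defining MyModel
)
```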
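In this example, the file that defines `MyModel` contains a standard `nn.Module` subclass:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyModel(nn.Module):
    """Small CNN that maps a 1-channel image to log-probabilities over 10 classes."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)   # 1 -> 32 channels, 3x3 kernel, stride 1
        self.conv2 = nn.Conv2d(32, 64, 3, 1)  # 32 -> 64 channels, 3x3 kernel, stride 1
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)       # 9216 = 64 * 12 * 12 for a 1x28x28 input
        self.fc2 = nn.Linear(128, 10)         # one output per class

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)                # 2x2 pooling halves the spatial size
        x = self.dropout1(x)
        x = torch.flatten(x, 1)               # keep the batch dim, flatten the rest
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)      # log-probabilities over the 10 classes
        return output
```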
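The `9216` in `fc1` is the number of values left per sample after `conv2` and `max_pool2d`. The sketch below reproduces that arithmetic with PyTorch's output-size formula for convolution and pooling, assuming a 1x28x28 input (MNIST-sized; the page does not state the input size, but this is a size that yields 9216). `conv_out` and `fc1_in_features` are hypothetical helper names; dropout layers are skipped because they do not change the shape.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0, dilation: int = 1) -> int:
    """Output length along one spatial axis for Conv2d / MaxPool2d."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def fc1_in_features(height: int = 28, width: int = 28, channels_out: int = 64) -> int:
    h, w = height, width
    h, w = conv_out(h, 3), conv_out(w, 3)                      # conv1: 3x3, stride 1
    h, w = conv_out(h, 3), conv_out(w, 3)                      # conv2: 3x3, stride 1
    h, w = conv_out(h, 2, stride=2), conv_out(w, 2, stride=2)  # max_pool2d(x, 2): stride defaults to 2
    return channels_out * h * w                                # size of torch.flatten(x, 1) per sample


print(fc1_in_features())  # 9216
```

For a 32x32 input, `fc1_in_features(32, 32)` returns 12544, so `fc1` would need that many input features instead.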
`baseten.deploy` deploys your model to Baseten and prints a URL. Open it in your browser to check the deployment status and other details about the model.