Deploying PyTorch model to Posit Connect with Vetiver

I'm trying to deploy a PyTorch model to Connect as an API with vetiver and getting stuck. I think the issue is that my deployed API doesn't 'have' the class from which the model was created. This is all being done with the Bank Marketing Dataset from Kaggle.

The model itself is created from a class that inherits from torch.nn.Module, which seems to be a common way to do this.

from torch import nn


class BankModel(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        self.linear_stack = nn.Sequential(
            nn.Linear(input_size, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.linear_stack(x)
        return x

I then save the trained model's state_dict to disk with torch.save. In the deployment script I read the model back from disk, pin it to Connect, and then deploy an API as per the vetiver docs. This is the relevant section of the code:

bank_model = BankModel(input_size=input_size).to("cpu")
bank_model.load_state_dict(torch.load(model_filepath))
v = vetiver.VetiverModel(
    bank_model,
    model_name=model_name,
    versioned=True
)
vetiver.vetiver_pin_write(board, v)
# take the most recent pin version (sorted() is ascending, so use the last element)
latest_version = sorted(
    board.pin_versions(model_name)["version"].to_list())[-1]
app_id = os.getenv("APP_ID")
app_id = None if app_id is None else int(app_id)
vetiver.deploy_rsconnect(
    connect_server=connect_server,
    board=board,
    pin_name=model_name,
    version=latest_version,
    extra_files=[req_filepath, cert_filepath],
    new=False,
    app_id=app_id,
    title=api_title
)

Everything deploys, but when I try to access the API I get this error: Unexpected error while running Python API: Can't get attribute 'BankModel' on <module '__main__' from '/opt/rstudio-connect/python/connect_fastapi_runtime.py'>.

I'm much less familiar with Python than R, so it's entirely possible that I'm making an obvious mistake. As far as I can tell, though, I'm sending the model up to Connect without the custom class that tells it how to use the model object, and I can't figure out how to provide that class.

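For what it's worth, I think the same failure can be reproduced locally with plain pickle, since pickle stores a class by reference rather than by value. This is just a minimal sketch of my understanding (assuming the BankModel class defined above), not code from my project:

# Hypothetical illustration: pickle records the class as __main__.BankModel,
# so a fresh process that never defines BankModel fails with the same
# "Can't get attribute 'BankModel'" error that Connect reports.
import pickle
import subprocess
import sys

model = BankModel(input_size=4)   # BankModel as defined above
with open("bank_model.pkl", "wb") as f:
    pickle.dump(model, f)

# Load the pickle in a separate interpreter that has no BankModel definition:
subprocess.run(
    [sys.executable, "-c",
     "import pickle; pickle.load(open('bank_model.pkl', 'rb'))"]
)
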
Here is the full script that I'm using to train the model:

import click
from dotenv import find_dotenv, load_dotenv
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset


class BankDataset(Dataset):
    def __init__(self, data):
        self.all = torch.as_tensor(data)
        self.features = self.all[:, :-1]
        self.target = self.all[:, -1].reshape(-1, 1)

    def __len__(self):
        return len(self.target)

    def __getitem__(self, idx):
        x = self.features[idx]
        y = self.target[idx]
        return x, y


class BankModel(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        self.linear_stack = nn.Sequential(
            nn.Linear(input_size, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.linear_stack(x)
        return x


def train(dataloader, model, loss_fn, optimizer, device):
    model.train()
    train_loss = 0.0

    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_loss += loss.item() * y.size(0)

    train_loss /= len(dataloader.dataset)
    return train_loss


def val(dataloader, model, loss_fn, device):
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            val_loss += loss_fn(pred, y).item() * y.size(0)
    val_loss /= len(dataloader.dataset)
    return val_loss


@click.command()
@click.argument('dtrain_path', type=click.Path(exists=True))
@click.argument('dval_path', type=click.Path(exists=True))
@click.argument('model_path', type=click.Path())
def main(dtrain_path, dval_path, model_path):
    """
    Train a PyTorch model using the training data and validation data that were
    already processed earlier. These are saved to dtrain_path and dval_path as
    numpy arrays, where the rightmost column is the outcome variable. The
    outcome variable is a binary variable.

    The PyTorch model has one hidden layer with 16 nodes. During training the
    model performance is evaluated using the validation data, based on
    the log-loss metric.

    The function writes out the best-performing model out to disk.
    """
    # Convert data to PyTorch Datasets
    dtrain = BankDataset(np.load(dtrain_path, allow_pickle=True))
    dval = BankDataset(np.load(dval_path, allow_pickle=True))

    # Create DataLoader for training and validation datasets
    batch_size = 64
    torch.manual_seed(1)
    train_loader = DataLoader(dtrain, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(dval, batch_size=batch_size)

    # Initialize the model
    input_size = dtrain.features.shape[1]
    # Get cpu, gpu or mps device for training.
    device = (
        "cuda"
        if torch.cuda.is_available()
        else "mps"
        if torch.backends.mps.is_available()
        else "cpu"
    )
    print(f"Using {device} device")
    model = BankModel(input_size).to(device)

    # Define the loss function and optimizer
    learning_rate = 0.001
    loss_fn = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Training loop from the book
    best_val_loss = float('inf')
    epochs = 10
    log_epochs = 10
    for t in range(epochs):
        train_loss = train(train_loader, model, loss_fn, optimizer, device)
        val_loss = val(val_loader, model, loss_fn, device)
        if (t + 1) % log_epochs == 0:
            print(f"Epoch {t+1}\n-------------------------------")
            print(f"Train loss: {train_loss:.5f}\n" +
                  f"Validation loss: {val_loss:.5f}")
        # Save model if it's the best performing
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), model_path)


if __name__ == '__main__':
    # find .env automagically by walking up directories until it's found, then
    # load up the .env entries as environment variables
    load_dotenv(find_dotenv())

    main()

And here is the full deployment script:

import click
import torch
import numpy as np
from torch import nn
from torch.utils.data import DataLoader, Dataset
from dotenv import find_dotenv, load_dotenv
import os
import pins
import vetiver
import json
import rsconnect.api
from rsconnect.api import RSConnectServer


class BankDataset(Dataset):
    def __init__(self, data):
        self.all = torch.as_tensor(data)
        self.features = self.all[:, :-1]
        self.target = self.all[:, -1].reshape(-1, 1)

    def __len__(self):
        return len(self.target)

    def __getitem__(self, idx):
        x = self.features[idx]
        y = self.target[idx]
        return x, y


class BankModel(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        self.linear_stack = nn.Sequential(
            nn.Linear(input_size, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.linear_stack(x)
        return x


@click.command()
@click.argument('model_filepath', type=click.Path(exists=True))
@click.argument('dtrain_filepath', type=click.Path(exists=True))
@click.argument('req_filepath', type=click.Path(exists=True))
@click.argument('cert_filepath', type=click.Path(exists=True))
@click.argument('output_filepath', type=click.Path())
def main(model_filepath, dtrain_filepath, req_filepath, cert_filepath,
         output_filepath):
    """ Pin the trained model to RStudio Connect, use that pin to deploy an
        API, and capture the info for the deployed model as YAML.
    """
    rsc_url = os.getenv("CONNECT_SERVER")
    api_key = os.getenv("CONNECT_API_KEY")
    model_name = os.getenv("MODEL_NAME")
    api_title = model_name + "_vetiver"
    os.environ["REQUESTS_CA_BUNDLE"] = cert_filepath
    connect_server = RSConnectServer(url=rsc_url, api_key=api_key)
    board = pins.board_connect(
        server_url=rsc_url,
        api_key=api_key,
        allow_pickle_read=True
    )
    dtrain = BankDataset(np.load(dtrain_filepath, allow_pickle=True))
    input_size = dtrain.features.shape[1]
    bank_model = BankModel(input_size=input_size).to("cpu")
    bank_model.load_state_dict(torch.load(model_filepath))
    v = vetiver.VetiverModel(
        bank_model,
        model_name=model_name,
        versioned=True
    )
    vetiver.vetiver_pin_write(board, v)
    # take the most recent pin version (sorted() is ascending, so use the last element)
    latest_version = sorted(
        board.pin_versions(model_name)["version"].to_list())[-1]
    app_id = os.getenv("APP_ID")
    app_id = None if app_id is None else int(app_id)
    vetiver.deploy_rsconnect(
        connect_server=connect_server,
        board=board,
        pin_name=model_name,
        version=latest_version,
        extra_files=[req_filepath, cert_filepath],
        new=False,
        app_id=app_id,
        title=api_title
    )
    if app_id is None:
        all_apps = rsconnect.api.retrieve_matching_apps(connect_server)
        possible_ids = [x["id"] for x in all_apps if x["title"] == api_title]
        app_id = sorted(possible_ids)[-1]
    app_info = rsconnect.api.get_app_info(connect_server, app_id)
    with open(output_filepath, "w", encoding="utf-8") as file:
        json.dump(app_info, file)


if __name__ == '__main__':
    # find .env automagically by walking up directories until it's found, then
    # load up the .env entries as environment variables
    load_dotenv(find_dotenv())

    main()

Hello there and welcome @hamedbh! This is a great question -- thank you for asking! You are exactly right in your assumption that you are "sending the model up to Connect without the custom class that tells it how to use the model object." This is a tricky space -- vetiver is pretty clever when it comes to serializing and deserializing models, but when you add custom components without their source code, it unfortunately will not have the context to recreate your BankModel object. There are a couple of ways to get around this:

  1. Use the linear_stack as your model in the VetiverModel and then add the forward method as a POST endpoint on your API using VetiverAPI.vetiver_post() (see the first sketch below).
  2. Create a small Python package for your BankModel, which will allow your script to load the source code necessary to remake your custom class (see the second sketch below).
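
Here is a rough, untested sketch of option 1, adapted from your deployment script. The key change is pinning bank_model.linear_stack, which is a plain nn.Sequential built only from stock torch layers, so Connect can unpickle it without any custom class definitions. The predict_proba helper and its endpoint name are made up for illustration:

import numpy as np
import torch
import vetiver

bank_model = BankModel(input_size=input_size).to("cpu")
bank_model.load_state_dict(torch.load(model_filepath))

v = vetiver.VetiverModel(
    bank_model.linear_stack,   # plain nn.Sequential, no custom class required
    model_name=model_name,
    versioned=True
)
vetiver.vetiver_pin_write(board, v)

# Extra behaviour can be exposed as a custom POST endpoint on the API object:
api = vetiver.VetiverAPI(v)

def predict_proba(data):
    # run the pinned Sequential and return probabilities as a plain list
    x = torch.as_tensor(np.asarray(data), dtype=torch.float32)
    with torch.no_grad():
        return v.model(x).numpy().tolist()

api.vetiver_post(predict_proba, "predict_proba")

One thing to keep in mind: deploy_rsconnect() generates its own app file from the pin, so if you want the extra endpoint on Connect you would likely need to generate the app file yourself (vetiver.write_app() gives you a starting point), add the vetiver_post() call there, and deploy that file with rsconnect.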

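And here is a rough outline of option 2. The package name (bankmodel) and the layout are made up; the important part is that BankModel lives in an importable module rather than in your script's __main__, and that the package gets installed both locally and in the environment Connect builds from your requirements file:

# Hypothetical layout for a tiny installable package holding the class:
#
#   bankmodel/
#       pyproject.toml
#       src/bankmodel/
#           __init__.py
#           model.py        # BankModel defined here, exactly as you have it
#
# The training and deployment scripts then import the class instead of
# defining it inline:
import torch

from bankmodel.model import BankModel   # hypothetical package import

bank_model = BankModel(input_size=input_size).to("cpu")
bank_model.load_state_dict(torch.load(model_filepath))

Install the package locally while you develop (for example with pip install -e .) and list it in the requirements file you pass as req_filepath so that Connect installs it too; that way the pinned pickle's reference to bankmodel.model.BankModel resolves when the API loads the model.
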
Let me know if either of these make sense for your use case!