I'm trying to deploy a PyTorch model with vetiver to host an API endpoint on our Posit Connect server. Everything runs smoothly when I debug locally, and the deployment to the server also completes without any errors.
All endpoints of the deployed API work except `/predict`, which returns "Internal Server Error" when I call it with the dummy data. How should I go about debugging this and finding the cause?
This is my deployment script:
import pandas as pd
import torch
import os
import vetiver
import pins
from mlp_model import MLPModel
from rsconnect.api import RSConnectServer
# Silence vetiver's model_card warning.
vetiver.utils.modelcard_options.quiet = True

# Saved weights (a state dict); the architecture itself comes from MLPModel.
state_dict_path = "mlp_model_state_dict.pt"

# Rebuild the architecture with the same hyper-parameters used in training,
# then load the trained weights onto the CPU.
input_length, feats, hidden_sizes, output_size = 378, 1, (64, 8), 1
mlp = MLPModel(input_length, feats, hidden_sizes, output_size)

# A state dict is plain tensors, so use the restricted (non-pickle) loader:
# it refuses arbitrary objects and matches the torch>=2.6 default explicitly
# (older versions warn when weights_only is not given).
state_dict = torch.load(state_dict_path, map_location="cpu", weights_only=True)
mlp.load_state_dict(state_dict)

# Switch to inference mode before the module is wrapped and pinned.
mlp.eval()
# Prototype input: a single all-zero row with columns x0 .. x{INPUT_LEN-1},
# cast to float32. vetiver uses it to describe the /predict input schema.
# NOTE(review): the prototype is stored with the pin; if the served model
# rebuilds it with a different dtype (e.g. float64), the torch model would be
# fed a double tensor. Confirm by loading the model back with
# VetiverModel.from_pin and calling handler_predict on this frame.
INPUT_LEN = 378
FEATURES = ["x" + str(idx) for idx in range(INPUT_LEN)]
zero_row = {column: [0.0] for column in FEATURES}
X_proto = pd.DataFrame(zero_row).astype("float32")
pin_name = "mlp_model"

# Wrap the torch module for serving. No handler object is passed:
# VetiverModel has no handler parameter (as of vetiver 0.2.x; confirm for
# your version). It builds the handler itself from `model` and
# `prototype_data` (a TorchHandler for a torch nn.Module).
v = vetiver.VetiverModel(
    model=mlp,
    model_name=pin_name,
    prototype_data=X_proto,
    versioned=True,
    description="Test of Torch model",
)

# Smoke-test the in-memory model by calling the handler the way the API's
# /predict does, with prototype checking enabled.
y_vetiver = v.handler_predict(X_proto, check_prototype=True)
print("OK vetiver:", y_vetiver)
# Connect credentials, each read once from the environment (KeyError if unset).
connect_url = os.environ["CONNECT_SERVER"]
connect_key = os.environ["CONNECT_API_KEY"]

# rsconnect-python server handle; deploy_connect uses it to publish the API.
connect_server = RSConnectServer(url=connect_url, api_key=connect_key)

# pins board on the same Connect server. allow_pickle_read permits reading
# back pins stored in a pickle-based format.
board = pins.board_connect(
    server_url=connect_url,
    api_key=connect_key,
    allow_pickle_read=True,
)
# Write the vetiver model (model object plus metadata) as a new pin version.
vetiver.vetiver_pin_write(board=board, model=v)

# Reproduce the server-side load path before deploying. vetiver's generated
# app.py reads the model back with VetiverModel.from_pin, so the handler
# behind /predict is rebuilt from the pin, not from the in-memory `v`
# (verify in the generated bundle). Anything that breaks only after that
# round trip -- e.g. a stored prototype whose dtype no longer matches the
# model's (presumably float32) weights -- raises here with a full traceback
# instead of a bare "Internal Server Error". For failures that occur only on
# the server, open the deployed content in the Connect dashboard and read its
# Logs panel, which shows the Python traceback behind the 500.
v_from_pin = vetiver.VetiverModel.from_pin(board, pin_name)
y_from_pin = v_from_pin.handler_predict(X_proto, check_prototype=True)
print("OK from pin:", y_from_pin)

# Deploy the API. mlp_model.py is shipped so the pinned MLPModel can be
# unpickled on the server; requirements.txt should pin torch, vetiver and
# pins to the versions that passed the checks above.
vetiver.deploy_connect(
    connect_server=connect_server,
    board=board,
    pin_name=pin_name,
    title="VetiverTorchTest",
    extra_files=["requirements.txt", "mlp_model.py"],
    new=True,
    app_id=None,
    python=None,
    force_generate=False,
    log_callback=None,
)