Hi Edgar,
This is possible if you are able to jit_trace
your model. The model is then serialized as a ScriptModule
and can be loaded in Python without any additional dependency. Since torchserve uses Python in the backend, this should work.
The main caveat here is that you wouldn't be able to use models that are not traceable, i.e. models that need the --model-file
parameter in torchserve,
because R nn_modules
can't be read by Python. In that case you'd be better off using plumber or something else that can load an R process and execute R logic.
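If you go the plumber route, a minimal sketch could look something like this (the file path, the /predict endpoint and the JSON input format are placeholder assumptions you'd adapt to your model):
# plumber.R -- hypothetical sketch for serving a non-traceable R torch model
library(torch)

# load an nn_module previously saved with torch_save(); the path is an assumption
net <- torch_load("models/model.rds")
net$eval()

#* Run a forward pass on input posted as a JSON array
#* @post /predict
function(req) {
  x <- torch_tensor(jsonlite::fromJSON(req$postBody))
  as_array(net(x))
}
You would then serve it with plumber::plumb("plumber.R")$run(port = 8000).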
Also, for now, you need to wrap your model in a function and detach all of its parameters before tracing, but we will implement support for jit_trace-ing modules directly in the near future.
I did this quick example and it worked as expected. First, I jit_traced an R torch model to save it as a ScriptModule. The model I traced is a pretrained AlexNet from torchvision
:
library(torch)
net <- torchvision::model_alexnet(pretrained = TRUE)
# currently we need to detach all parameters in order to
# JIT compile. We need to support modules to avoid that.
for (p in net$parameters) {
  p$detach_()
}
# currently we can only JIT functions, not nn_modules, so we wrap
# the model into a function.
# this will be implemented soon
fn <- function(x) {
  net(x)
}
input <- torch_randn(100, 3, 224, 224)
out <- fn(input)
tr_fn <- jit_trace(fn, input)
jit_save(tr_fn, "models/model.pt")
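If you want to sanity check the export before archiving it, the ScriptModule can be reloaded with jit_load() and compared against the eager output (just a quick check, not part of the deployment):
# reload the exported ScriptModule and compare with the eager result
reloaded <- jit_load("models/model.pt")
out2 <- reloaded(input)
torch_allclose(out, out2) # should be TRUE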
Now, after installing torch-model-archiver
with pip install torch-model-archiver
, I 'archived' the model into a .mar
file with:
torch-model-archiver --model-name mynet \
  --version 1.0 \
  --serialized-file models/model.pt \
  --export-path model-store \
  --handler image_classifier \
  --force
Next, I started torchserve
using the official Docker image:
docker run \
  --rm --shm-size=1g \
  --ulimit memlock=-1 \
  --ulimit stack=67108864 \
  -p8080:8080 \
  -p8081:8081 \
  -p8082:8082 \
  -p7070:7070 \
  -p7071:7071 \
  --mount type=bind,source=/home/dfalbel/torchserve/model-store/,target=/tmp/models \
  pytorch/torchserve:latest \
  torchserve --model-store=/tmp/models --models mynet.mar
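Once the container is running, you can confirm that the model was registered by querying torchserve's management API on port 8081, e.g. from R with httr (curl works just as well):
# list the models currently registered with torchserve
httr::content(httr::GET("http://127.0.0.1:8081/models"))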
Finally, I could run a prediction with:
curl -O https://raw.githubusercontent.com/pytorch/serve/master/docs/images/kitten_small.jpg
curl http://127.0.0.1:8080/predictions/mynet -T kitten_small.jpg
{
  "281": 0.5944757461547852,
  "285": 0.3166409432888031,
  "287": 0.052945494651794434,
  "282": 0.028301434591412544,
  "286": 0.004156926181167364
}
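If you prefer to stay in R on the client side too, the same request can be sent with httr (just a sketch; any HTTP client will do):
# same prediction call as the curl command above, but from R
resp <- httr::POST(
  "http://127.0.0.1:8080/predictions/mynet",
  body = httr::upload_file("kitten_small.jpg")
)
httr::content(resp)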
Hope this helps! And let me know if you have additional questions.