Using the torch JIT compiler

I'm trying to understand how the jit_trace and jit_compile functions work in torch.

I want to be able to do the equivalent of pnorm and qnorm on the GPU. There are built-in erf and erfinv functions, so I define my versions as follows:

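cpt_pnorm <- function (x) {
  (torch_erf(x / sqrt(2)) + 1) / 2  # standard normal CDF via erf
}
torch_pnorm <- jit_trace(cpt_pnorm, torch_tensor(c(-.67, 0, .67)))

cpt_qnorm <- function (x) {
  x$mul(2)$sub_(1)$erfinv_()$mul_(sqrt(2))  # standard normal quantile via erfinv
}
torch_qnorm <- jit_trace(cpt_qnorm, torch_tensor(c(.25, .5, .75)))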
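A quick way to confirm that the two versions compute the right thing is to compare them with base R's pnorm() and qnorm(). The sketch below is self-contained, so it repeats the definitions, written here in plain functional form using the identities Phi(x) = (1 + erf(x / sqrt(2))) / 2 and Phi^-1(p) = sqrt(2) * erfinv(2p - 1). The test points and the 1e-5 tolerance are illustrative choices; the tolerance is loose because torch computes in float32 by default.

```r
library(torch)

# Standard normal CDF via the error function
cpt_pnorm <- function(x) (torch_erf(x / sqrt(2)) + 1) / 2
# Standard normal quantile via the inverse error function
cpt_qnorm <- function(p) torch_erfinv(2 * p - 1) * sqrt(2)

torch_pnorm <- jit_trace(cpt_pnorm, torch_tensor(c(-.67, 0, .67)))
torch_qnorm <- jit_trace(cpt_qnorm, torch_tensor(c(.25, .5, .75)))

x <- c(-1.5, 0, 2)
p <- c(0.1, 0.5, 0.9)

# as_array() turns the (float32) results back into R numeric vectors
err_p <- max(abs(as_array(torch_pnorm(torch_tensor(x))) - pnorm(x)))
err_q <- max(abs(as_array(torch_qnorm(torch_tensor(p))) - qnorm(p)))

print(c(err_p = err_p, err_q = err_q))
```

Both errors should be on the order of float32 rounding, well below the tolerance.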

This seems to work, but the next part confuses me. When I look at the graph, I get:

> torch_qnorm$graph
graph(%0 : Float(3, strides=[1], requires_grad=0, device=cpu)):
  %1 : float = prim::Constant[value=2.]()
  %2 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::mul(%0, %1)
  %3 : float = prim::Constant[value=1.]()
  %4 : int = prim::Constant[value=1]()
  %5 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::sub_(%2, %3, %4)
  %6 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::erfinv_(%5)
  %7 : float = prim::Constant[value=1.4142135623730951]()
  %8 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::mul_(%6, %7)
  return (%8)

So here are my questions:

  1. It says requires_grad=0, but I might want to use this in a context in which I need a grad. Is this OK?

  2. It says device=cpu; will it still run on the GPU?

  3. Finally, I want to put these functions in a package. Do I need to save the compiled version and figure out how to load it when the package loads, or will it load automatically with the package?

If there is a better forum for questions about torch, please point me to it.


  1. Yes, you should be able to backpropagate through traced functions, even though they were traced without setting requires_grad = TRUE.

  2. Execution depends on the device of the inputs, so you can assume the traced functions will run on the GPU if you call them with tensors that are on the GPU, even if you traced them with CPU tensors.

  3. For packages, I'd delay creating the traced functions until their first use, so I'd do something like this in the package code:

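.jit_functions <- new.env()  # the package namespace is locked after loading, so cache in a separate env

cpt_pnorm <- function (x) {
  (torch_erf(x / sqrt(2)) + 1) / 2
}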
torch_pnorm <- function(x) {
  if (!exists("pnorm", envir = .jit_functions)) {
    .jit_functions[["pnorm"]] <- jit_trace(cpt_pnorm,torch_tensor(c(-.67,0,.67)))