I'm trying to understand how the `jit_trace` and `jit_compile` functions work in torch.
I want to be able to do the equivalent of `pnorm` and `qnorm` on the GPU. There are built-in `erf` and `erfinv` methods, so I define my versions as follows:
```r
cpt_pnorm <- function(x)
  torch_div(x, sqrt(2))$erf_()$add_(1)$div_(2)
torch_pnorm <- jit_trace(cpt_pnorm, torch_tensor(c(-.67, 0, .67)))

cpt_qnorm <- function(x)
  torch_mul(x, 2)$sub_(1)$erfinv_()$mul_(sqrt(2))
torch_qnorm <- jit_trace(cpt_qnorm, torch_tensor(c(.25, .5, .75)))
```
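For what it's worth, this is the sanity check I ran against base R (only a sketch of my testing; since the tensors default to float32, I only expect agreement to roughly 1e-7):

```r
# Compare the traced functions with base R pnorm/qnorm
p <- c(.25, .5, .75)
as.numeric(torch_qnorm(torch_tensor(p)))  # should be close to qnorm(p)

x <- c(-.67, 0, .67)
as.numeric(torch_pnorm(torch_tensor(x)))  # should be close to pnorm(x)
```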
This seems to work, but it is the next part that confuses me. When I look at the graph, I get:
```
> torch_qnorm$graph
graph(%0 : Float(3, strides=[1], requires_grad=0, device=cpu)):
  %1 : float = prim::Constant[value=2.]()
  %2 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::mul(%0, %1)
  %3 : float = prim::Constant[value=1.]()
  %4 : int = prim::Constant[value=1]()
  %5 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::sub_(%2, %3, %4)
  %6 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::erfinv_(%5)
  %7 : float = prim::Constant[value=1.4142135623730951]()
  %8 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::mul_(%6, %7)
  return (%8)
```
So here are my questions:

- It says `requires_grad=0`, but I might want to use this in a context in which I need a gradient. Is this OK?
- It says `device=cpu`; will it still run on the GPU?
- Finally, I want to put these functions in a package. Do I need to save the compiled version and figure out how to load it when the package loads, or will it autoload with the package?
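To make the first two questions concrete, this is the kind of usage I have in mind (an untested sketch; I don't know yet whether a trace made from a CPU tensor without gradients will accept inputs like these):

```r
# Hypothetical usage: GPU input with gradient tracking,
# even though tracing was done on a plain CPU tensor
x <- torch_tensor(c(.1, .9), requires_grad = TRUE, device = "cuda")
y <- torch_qnorm(x)   # does the traced graph run on the GPU?
y$sum()$backward()    # and do gradients flow back to x?
```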
If there is a better forum for questions about torch, please point me to it.