Brief intro
I apologize if this is not the correct forum for this question, but any help would be greatly appreciated!
I'm trying to build a Wasserstein GAN with gradient penalty, following this paper. Essentially, I'm trying to recreate their Python code in R. All was going well until I tried to implement their gradient penalty. In the original code it is defined as a custom layer class, and K.gradients is used to calculate the gradients:
class GradientPenalty(Layer):
    def __init__(self, **kwargs):
        super(GradientPenalty, self).__init__(**kwargs)

    def call(self, inputs):
        target, wrt = inputs
        grad = K.gradients(target, wrt)[0]
        return K.sqrt(K.sum(K.batch_flatten(K.square(grad)), axis=1, keepdims=True)) - 1

    def compute_output_shape(self, input_shapes):
        return (input_shapes[1][0], 1)
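As I understand it, the layer flattens each sample's gradient, takes its L2 norm, and subtracts 1 (the penalty itself comes from squaring this elsewhere in the loss). A plain-Python sketch of that per-sample arithmetic, with lists standing in for tensors, just to show the formula I'm after:

```python
import math

def penalty_term(grad_rows):
    """For each sample's flattened gradient, compute ||grad||_2 - 1.

    Mirrors K.sqrt(K.sum(K.square(grad), axis=1)) - 1 from the layer
    above, with plain lists standing in for a (batch, features) tensor.
    """
    return [math.sqrt(sum(g * g for g in row)) - 1 for row in grad_rows]

# A batch of two flattened gradients:
print(penalty_term([[3.0, 4.0], [1.0, 0.0]]))  # -> [4.0, 0.0]
```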
I generally try to avoid classes in R and write things as functions instead. I think the rest is best explained by what I hope works as a minimal reproducible example.
Session info
R version 4.0.2 (2020-06-22)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 18.04.6 LTS
Matrix products: default
BLAS: /usr/lib/x86_64-linux-gnu/atlas/libblas.so.3.10.3
LAPACK: /usr/lib/x86_64-linux-gnu/atlas/liblapack.so.3.10.3
locale:
[1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C LC_TIME=nb_NO.UTF-8 LC_COLLATE=en_US.UTF-8
[5] LC_MONETARY=nb_NO.UTF-8 LC_MESSAGES=en_US.UTF-8 LC_PAPER=nb_NO.UTF-8 LC_NAME=C
[9] LC_ADDRESS=C LC_TELEPHONE=C LC_MEASUREMENT=nb_NO.UTF-8 LC_IDENTIFICATION=C
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] keras_2.9.0 tensorflow_2.9.0
loaded via a namespace (and not attached):
[1] Rcpp_1.0.8 here_0.1 lattice_0.20-41 png_0.1-7 rprojroot_1.3-2 zeallot_0.1.0 rappdirs_0.3.3
[8] grid_4.0.2 R6_2.5.1 backports_1.1.10 jsonlite_1.8.0 magrittr_2.0.2 rlang_0.4.11 tfruns_1.4
[15] whisker_0.4 Matrix_1.2-18 reticulate_1.25 generics_0.1.2 tools_4.0.2 compiler_4.0.2 base64enc_0.1-3
Reproducible example
library(tensorflow)
library(keras)
tf$executing_eagerly() # should be true, because eager execution is needed in a different part of the code
# these are not initialized here in the real thing, but the resulting tensors are the same:
disc_out_avg <- layer_input(shape = list(1))
disc_in_avg <- layer_input(shape = list(NULL, NULL, 1))
# my first attempt at translating gradient penalty:
gradient_penalty <- function(inputs){
  c(target, wrt) %<-% inputs
  grad <- k_gradients(loss = target, variables = wrt)[[1]]  # [[1]] is R's equivalent of Python's [0]
  return(k_sqrt(k_sum(k_batch_flatten(k_square(grad)), axis = 1, keepdims = TRUE)) - 1)
}
disc_gp <- gradient_penalty(list(disc_out_avg, disc_in_avg))
This produces the error
Error in py_call_impl(callable, dots$args, dots$keywords) :
RuntimeError: tf.gradients is not supported when eager execution is enabled. Use tf.GradientTape instead.
I won't try to hide the fact that I don't quite understand what GradientTape is, which might explain why I can't get the following to work. However, I've been googling it for weeks and I'm not getting any wiser!
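From what I've read, GradientTape records the operations executed inside its context so they can be replayed backward with the chain rule; if nothing was recorded, there's nothing to differentiate. Here's a toy pure-Python analogue of that record-then-walk-backward idea, which is just my mental model and not the TF API:

```python
class Var:
    """A value that remembers which operations produced it."""
    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # pairs of (input Var, local derivative)
        self.grad = 0.0

def mul(a, b):
    # d(a*b)/da = b, d(a*b)/db = a
    return Var(a.value * b.value, [(a, b.value), (b, a.value)])

def add(a, b):
    # d(a+b)/da = d(a+b)/db = 1
    return Var(a.value + b.value, [(a, 1.0), (b, 1.0)])

def backward(out):
    """Walk the recorded graph backward, accumulating chain-rule products."""
    stack = [(out, 1.0)]
    while stack:
        v, g = stack.pop()
        v.grad += g
        for parent, local in v.parents:
            stack.append((parent, g * local))

x = Var(3.0)
y = add(mul(x, x), x)  # y = x^2 + x, recorded as it is computed
backward(y)
print(x.grad)  # dy/dx = 2x + 1 = 7.0
```

If this model is right, my problem might be that the forward computation has to happen *inside* the tape context for anything to be recorded, which my second attempt below doesn't do.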
library(tensorflow)
library(keras)
tf$executing_eagerly() # should be true, because eager execution is needed in a different part of the code
# these are not initialized here in the real thing, but the resulting tensors are the same:
disc_out_avg <- layer_input(shape = list(1))
disc_in_avg <- layer_input(shape = list(NULL, NULL, 1))
# my second attempt at translating gradient penalty:
gradient_penalty <- function(inputs){
  c(target, wrt) %<-% inputs
  with(tf$GradientTape() %as% tape, {
    # tape$watch(wrt)
  })
  grad <- tape$gradient(target, wrt)
  return(k_sqrt(k_sum(k_batch_flatten(k_square(grad)), axis = 1, keepdims = TRUE)) - 1)
}
disc_gp <- gradient_penalty(list(disc_out_avg, disc_in_avg))
Error in py_call_impl(callable, dots$args, dots$keywords) :
AttributeError: 'KerasTensor' object has no attribute '_id'
Called from: py_call_impl(callable, dots$args, dots$keywords)
If I uncomment the tape$watch() call and "watch" one tensor or the other, I get:
Error in py_call_impl(callable, dots$args, dots$keywords) :
ValueError: Passed in object of type <class 'keras.engine.keras_tensor.KerasTensor'>, not tf.Tensor
Called from: py_call_impl(callable, dots$args, dots$keywords)
Any clues as to what I might be doing wrong?
Edit: I have also posted this issue on GitHub.