I would like to create a Shiny app in which a user can train a model on their own data with keras.
However, training can take a long time and, depending on the settings, the fit sometimes does not converge. So I would like the user to be able to interrupt training smoothly on demand (e.g. by clicking a button) without having to close the app.
It is possible to stop training from a callback (Training Callbacks • keras).
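For reference, here is a minimal sketch of that mechanism in which the stop condition is an arbitrary R predicate rather than an epoch counter. It assumes the abort request is signalled through a flag file on disk; `make_abort_flag()` and `make_stop_callback()` are hypothetical helper names, not part of keras. Something still has to create the flag file while `fit()` is running, which is exactly the difficulty in a single-process Shiny app.

```r
# Hypothetical helper: a file on disk used as an abort signal that can be
# checked cheaply from inside a keras callback.
make_abort_flag <- function(path = tempfile("abort_")) {
  list(
    path      = path,
    request   = function() invisible(file.create(path)),
    requested = function() file.exists(path),
    reset     = function() invisible(if (file.exists(path)) file.remove(path))
  )
}

# Hypothetical helper: a callback that asks keras to stop as soon as the
# predicate `should_stop()` returns TRUE. Depending on the Keras version,
# `stop_training` is honoured after the current batch or only at the end
# of the current epoch.
make_stop_callback <- function(model, should_stop) {
  keras::callback_lambda(
    on_batch_end = function(batch, logs) {
      if (should_stop()) model$stop_training <- TRUE
    }
  )
}

# Usage (not run here):
# flag <- make_abort_flag()
# model %>% fit(x_train, y_train, epochs = epochs,
#               callbacks = list(make_stop_callback(model, flag$requested)))
```

Checking per batch rather than per epoch keeps the reaction time short when epochs are long; the cost is one `file.exists()` call per batch.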
However, my problem is that once training has started, the Abort button cannot be observed until training has ended.
So how can I detect that the user has clicked the Abort button?
One workaround:
- I know it is also possible to run some JavaScript during a callback (Live ploting training/validation history in shiny app · Issue #978 · rstudio/keras · GitHub).
Is there a way to inspect the DOM to see whether the Abort button has been clicked?
For instance, by checking for the `.clicked` class on the Abort button.
Another solution:
- I also investigated using future and promises, but I did not find a working solution (I got the error "'what' must be a function or character string"; see tensorflow - Parallelizing keras models in R using doParallel - Stack Overflow).
Besides, keras is already multi-threaded, and I fear that running it in parallel will lead to unexpected behaviour (seeds, ...). It might work with a GPU backend for keras, but I don't want to rely on that because not all users have a GPU.
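One route that avoids both the blocked observer and threading inside the Shiny process is to run `fit()` in a separate R process. The sketch below swaps in callr (`callr::r_bg()`) instead of future/promises and assumes callr is installed; `train_fn()` and `run_app()` are hypothetical names, and the small dense model inside `train_fn()` is only a placeholder for the real one. Because the Shiny process is never blocked, the Abort observer fires normally and writes a flag file that a callback in the worker polls, so training stops at a batch or epoch boundary (depending on the Keras version) rather than being killed. Keras keeps its own multi-threading inside the worker; any seed would have to be set inside `train_fn()`.

```r
# Hypothetical sketch: train in a background R process started with callr,
# so the Shiny process stays free to react to the Abort button.

# Runs in a fresh R process, so it loads its own packages and data.
# The small dense model here is only a placeholder for the real one.
train_fn <- function(abort_path, epochs) {
  library(keras)
  mnist <- dataset_mnist()
  x_train <- array_reshape(mnist$train$x / 255, c(nrow(mnist$train$x), 784))
  y_train <- to_categorical(mnist$train$y, 10)
  model <- keras_model_sequential() %>%
    layer_dense(units = 16, activation = "relu", input_shape = 784) %>%
    layer_dense(units = 10, activation = "softmax")
  model %>% compile(loss = "categorical_crossentropy",
                    optimizer = "adam", metrics = "accuracy")
  # Stop smoothly once the Shiny process has created the flag file
  stop_cb <- callback_lambda(on_batch_end = function(batch, logs) {
    if (file.exists(abort_path)) model$stop_training <- TRUE
  })
  history <- model %>% fit(x_train, y_train, epochs = epochs,
                           validation_split = 0.2,
                           callbacks = list(stop_cb), verbose = 0)
  as.data.frame(history)
}

run_app <- function() {
  ui <- shiny::fluidPage(
    shiny::actionButton("start", "Start training"),
    shiny::actionButton("abort", "Abort training"),
    shiny::textOutput("status")
  )
  server <- function(input, output, session) {
    abort_path <- tempfile("abort_")
    proc <- shiny::reactiveVal(NULL)

    shiny::observeEvent(input$start, {
      p <- proc()
      shiny::req(is.null(p) || !p$is_alive())  # ignore clicks while training
      if (file.exists(abort_path)) file.remove(abort_path)
      # Output is discarded so an unread pipe cannot block the worker
      proc(callr::r_bg(train_fn,
                       args = list(abort_path = abort_path, epochs = 5),
                       stdout = NULL, stderr = NULL))
    })

    # Fires while training runs, because the Shiny process is not blocked
    shiny::observeEvent(input$abort, file.create(abort_path))

    output$status <- shiny::renderText({
      p <- proc()
      if (is.null(p)) return("idle")
      if (p$is_alive()) {
        shiny::invalidateLater(1000)  # poll the worker once per second
        return("training...")
      }
      "finished (or aborted)"
    })

    # Do not leave an orphaned worker behind when the browser tab closes
    session$onSessionEnded(function() {
      p <- shiny::isolate(proc())
      if (!is.null(p)) p$kill()
    })
  }
  shiny::shinyApp(ui, server)
}

if (interactive()) run_app()
```

If a hard stop is acceptable, the Abort observer could call `proc()$kill()` instead of writing the flag file; the flag file is what makes the interruption "smooth", since Keras finishes the current step and `fit()` returns its history normally.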
Here is a reproducible example (adapted from the keras MNIST CNN example), in which the user's abort request is simulated by stopping from the third epoch on:

```r
library(keras)
library(tensorflow)
library(shiny)

batch_size <- 128
num_classes <- 10
epochs <- 5

# adapted from https://keras.rstudio.com/articles/examples/mnist_cnn.html
# Input image dimensions
img_rows <- 28
img_cols <- 28

# The data, shuffled and split between train and test sets
mnist <- dataset_mnist()
x_train <- mnist$train$x
y_train <- mnist$train$y
x_test <- mnist$test$x
y_test <- mnist$test$y

# Redefine dimension of train/test inputs
x_train <- array_reshape(x_train, c(nrow(x_train), img_rows, img_cols, 1))
x_test <- array_reshape(x_test, c(nrow(x_test), img_rows, img_cols, 1))
input_shape <- c(img_rows, img_cols, 1)

# Rescale grayscale pixel values into the [0, 1] range
x_train <- x_train / 255
x_test <- x_test / 255

# Convert class vectors to binary class matrices
y_train <- to_categorical(y_train, num_classes)
y_test <- to_categorical(y_test, num_classes)

model <- keras_model_sequential() %>%
  layer_conv_2d(filters = 8, kernel_size = c(3, 3), activation = 'relu',
                input_shape = input_shape) %>%
  layer_conv_2d(filters = 8, kernel_size = c(3, 3), activation = 'relu') %>%
  layer_max_pooling_2d(pool_size = c(2, 2)) %>%
  layer_dropout(rate = 0.5) %>%
  layer_flatten() %>%
  layer_dense(units = 16, activation = 'relu') %>%
  layer_dropout(rate = 0.5) %>%
  layer_dense(units = num_classes, activation = 'softmax')

model %>% compile(
  loss = loss_categorical_crossentropy,
  optimizer = optimizer_adadelta(),
  metrics = c('accuracy')
)

ui <- fluidPage(
  tags$div(
    actionButton("start", "Start training"),
    actionButton("abort", "Abort training", onclick = "this.classList.add('clicked')")
  )
)

server <- function(input, output, session) {
  observeEvent(input$abort, {  # this observer cannot be triggered while the model is training
    str(input$abort)
  })

  observeEvent(input$start, {
    keras::k_clear_session()
    # adapted from https://keras.rstudio.com/articles/training_callbacks.html
    LossHistory <- R6::R6Class("LossHistory",
      inherit = KerasCallback,
      public = list(
        losses = NULL,
        on_epoch_begin = function(epoch, logs = list()) {
          cat("callback epoch:", epoch + 1, "\n")
          # stand-in for the user's abort request (epoch is 0-based,
          # so this fires when the 3rd epoch begins)
          if (epoch >= 2) {
            self$model$stop_training <- TRUE  # stops training before all epochs have run
            cat("training has been interrupted by user\nplease wait till current epoch stops\n")
          }
        }
      ))
    mycallback <- LossHistory$new()

    model %>%
      fit(
        x_train, y_train,
        batch_size = batch_size,
        epochs = epochs,
        validation_split = 0.2,
        callbacks = list(mycallback),
        verbose = 1
      )
  })
}

shinyApp(ui = ui, server = server)
```