I am learning to perform image segmentation in R (using RStudio) rather than in Python. I am using a script adapted from this article: https://blogs.rstudio.com/ai/posts/2019-08-23-unet/
My code runs without problems until I try to fit the model, at which point I get this error:
Error in py_call_impl(callable, dots$args, dots$keywords) :
ValueError: in user code:
/Users/mayasamuels-fair/Library/r-miniconda/envs/r-reticulate/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py:1478 predict_function *
return step_function(self, iterator)
/Users/mayasamuels-fair/Library/r-miniconda/envs/r-reticulate/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py:1468 step_function **
outputs = model.distribute_strategy.run(run_step, args=(data,))
/Users/mayasamuels-fair/Library/r-miniconda/envs/r-reticulate/lib/python3.6/site-packages/tensorflow/python/distribute/distribute_lib.py:1259 run
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
/Users/mayasamuels-fair/Library/r-miniconda/envs/r-reticulate/lib/python3.6/site-packages/tensorflow/python/distribute/distribute_lib.py:2730 call_for_each_replica
return self._call_for_each_replica(fn, args, kwargs)
/Users/mayasamuels-fair/Library/r-miniconda/envs/r-reticula
Here is my code:
# Build a table with one row per image/mask pair and split it 80/20.
# Images and masks are paired purely by their position in the two directory
# listings, so both folders are assumed to list matching files in the same
# order (TODO confirm the file names sort identically).
# Note that `data` ends up holding the split object, not the tibble itself.
data <- tibble(
  img = "/Volumes/Maya\'sDrive/CurrentProjects/OstracodSegmentation/train_png" %>%
    here::here() %>%
    list.files(full.names = TRUE),
  mask = "/Volumes/Maya\'sDrive/CurrentProjects/OstracodSegmentation/train_masks" %>%
    here::here() %>%
    list.files(full.names = TRUE)
) %>%
  initial_split(prop = 0.8)
# Exploratory input pipeline over the training split: read each PNG, convert
# it to float32 and resize it to 128 x 128.
training_dataset <- training(data) %>%
  tensor_slices_dataset() %>%
  # Images are decoded as 3-channel RGB. The mask is decoded as a single
  # channel so that it has shape (height, width, 1): a binary segmentation
  # target must match the model's one-channel output, and a 3-channel mask
  # makes the loss fail with a shape mismatch.
  # NOTE(review): assumes the masks are binary (foreground/background) PNGs.
  dataset_map(~ .x %>% list_modify(
    img = tf$image$decode_png(tf$io$read_file(.x$img), channels = 3),
    mask = tf$image$decode_png(tf$io$read_file(.x$mask), channels = 1)
  )) %>%
  # Convert the decoded integer pixels to float32.
  dataset_map(~ .x %>% list_modify(
    img = tf$image$convert_image_dtype(.x$img, dtype = tf$float32),
    mask = tf$image$convert_image_dtype(.x$mask, dtype = tf$float32)
  )) %>%
  # Bring every image and mask to the 128 x 128 size the model is built for.
  dataset_map(~ .x %>% list_modify(
    img = tf$image$resize(.x$img, size = shape(128, 128)),
    mask = tf$image$resize(.x$mask, size = shape(128, 128))
  ))
# Randomly perturb the brightness, contrast and saturation of an image
# tensor, then clamp the result.
#
# Args:
#   img: image tensor; presumably float values in [0, 1] with 3 colour
#     channels (TODO confirm against the caller).
#
# Returns: the augmented image tensor with every value clipped to [0, 1].
random_bsh <- function(img) {
  brightened <- tf$image$random_brightness(img, max_delta = 0.3)
  contrasted <- tf$image$random_contrast(brightened, lower = 0.5, upper = 0.7)
  saturated <- tf$image$random_saturation(contrasted, lower = 0.5, upper = 0.7)
  # The adjustments above can push values outside [0, 1]; clip them back.
  tf$clip_by_value(saturated, 0, 1)
}
# Augment only the image of each record; the mask is passed through unchanged
# because random_bsh() alters colours, not geometry.
training_dataset <- dataset_map(
  training_dataset,
  function(record) list_modify(record, img = random_bsh(record$img))
)
# Build a batched tf.data input pipeline from a data frame of PNG file paths.
#
# Args:
#   data: data frame with character columns `img` and `mask` holding the
#     paths of the image files and of their segmentation masks.
#   train: if TRUE, apply the random_bsh() colour augmentation to the images
#     and shuffle the examples; if FALSE, keep them unmodified and in order.
#   batch_size: number of examples per batch (default 30). The shuffle buffer
#     is `batch_size * 128` elements.
#
# Returns: a dataset of unnamed (image, mask) batches, where each image is
#   128 x 128 x 3 and each mask is 128 x 128 x 1, both float32.
create_dataset <- function(data, train, batch_size = 30) {
  dataset <- data %>%
    tensor_slices_dataset() %>%
    # Images are decoded as 3-channel RGB. The mask is decoded as a single
    # channel: the model is trained with binary cross-entropy against a
    # one-channel output, so a 3-channel mask causes a shape mismatch when
    # the loss is computed.
    # NOTE(review): assumes the masks are binary (foreground/background) PNGs
    # and that unet() is used with its single-class output - confirm.
    dataset_map(~ .x %>% list_modify(
      img = tf$image$decode_png(tf$io$read_file(.x$img), channels = 3),
      mask = tf$image$decode_png(tf$io$read_file(.x$mask), channels = 1)
    )) %>%
    # Convert the decoded integer pixels to float32.
    dataset_map(~ .x %>% list_modify(
      img = tf$image$convert_image_dtype(.x$img, dtype = tf$float32),
      mask = tf$image$convert_image_dtype(.x$mask, dtype = tf$float32)
    )) %>%
    # Resize to the fixed 128 x 128 input size.
    dataset_map(~ .x %>% list_modify(
      img = tf$image$resize(.x$img, size = shape(128, 128)),
      mask = tf$image$resize(.x$mask, size = shape(128, 128))
    ))

  # Augmentation and shuffling apply to training data only, so validation
  # examples stay unmodified and in a fixed order.
  if (train) {
    dataset <- dataset %>%
      dataset_map(~ .x %>% list_modify(
        img = random_bsh(.x$img)
      )) %>%
      dataset_shuffle(buffer_size = batch_size * 128)
  }

  dataset %>%
    dataset_batch(batch_size) %>%
    dataset_map(unname) # Keras needs an unnamed output.
}
# Final input pipelines: the training split is built with train = TRUE
# (augmentation and shuffling), the held-out split with train = FALSE.
training_dataset <- data %>%
  training() %>%
  create_dataset(train = TRUE)
validation_dataset <- data %>%
  testing() %>%
  create_dataset(train = FALSE)

# U-Net whose input shape matches the 128 x 128 RGB images built above.
model <- unet(input_shape = c(128, 128, 3))
# Dice coefficient metric, computed over all elements of the flattened
# tensors: (2 * sum(truth * pred) + smooth) / (sum(truth) + sum(pred) + smooth).
# `smooth` keeps the ratio defined when both tensors sum to zero.
dice <- custom_metric("dice", function(y_true, y_pred, smooth = 1.0) {
  truth <- k_flatten(y_true)
  pred <- k_flatten(y_pred)
  overlap <- k_sum(truth * pred)
  numerator <- 2 * overlap + smooth
  denominator <- k_sum(truth) + k_sum(pred) + smooth
  numerator / denominator
})
# Configure training: RMSprop with a learning rate of 1e-5, binary
# cross-entropy loss, and the Dice coefficient plus binary accuracy as metrics.
compile(
  model,
  optimizer = optimizer_rmsprop(lr = 1e-5),
  loss = "binary_crossentropy",
  metrics = list(dice, metric_binary_accuracy)
)

# Train for 50 epochs, evaluating on the validation batches after each epoch.
fit(
  model,
  training_dataset,
  epochs = 50,
  validation_data = validation_dataset
)
Does anyone know what the problem is? I can't make sense of the error message.