Extending Parsnip

Hi All,

Thanks in advance for any help.

I have been trying to figure out how to add a new model to {parsnip}. {neuralnet} seemed like a pretty straightforward model to build, so I figured I'd give it a go.

Fitting and predicting with {neuralnet} directly is very straightforward, yay.


library(neuralnet)

fit <- neuralnet(Species ~ ., data = iris, hidden = 3)

predict(fit, newdata = iris)

After running the script below, I managed to fit my new "neural_net" {parsnip} model with no problems, but running predict() fails with "Error: $ operator is invalid for atomic vectors". I tried tweaking the encoding and prediction sections but can't find a way to predict on new data. Does anyone have pointers as to what I am doing wrong?

Any help will be greatly appreciated.


library(parsnip)

# Register model and arguments
parsnip::set_new_model("neural_net")
parsnip::set_model_mode(model = "neural_net", mode = "classification")
parsnip::set_model_engine("neural_net", mode = "classification", eng = "neuralnet")
parsnip::set_dependency("neural_net", eng = "neuralnet", pkg = "neuralnet")

# Arguments: hidden
parsnip::set_model_arg(
   model = "neural_net",
   eng = "neuralnet",
   parsnip = "hidden",
   original = "hidden",
   func = list(pkg = "dials", fun = "hidden_units"),
   has_submodel = FALSE
)

# Encoding
parsnip::set_encoding(
   model = "neural_net",
   eng = "neuralnet",
   mode = "classification",
   options = list(
      predictor_indicators = "traditional",
      compute_intercept = TRUE,
      remove_intercept = TRUE,
      allow_sparse_x = FALSE
   )
)

# Fit
parsnip::set_fit(
   model = "neural_net",
   eng = "neuralnet",
   mode = "classification",
   value = list(
      interface = "formula",
      protect = c("formula", "data"),
      func = c(pkg = "neuralnet", fun = "neuralnet"),
      defaults = list()
   )
)

# Predict
parsnip::set_pred(
   model = "neural_net",
   eng = "neuralnet",
   mode = "classification",
   type = "class",
   value = list(
      pre = NULL,
      post = NULL,
      func = c(fun = "predict"),
      args = list(
         object = quote(object$fit),
         newdata = quote(new_data)
      )
   )
)

# Model function
neural_net <- function(mode = "classification", hidden = 1) {
   # Check correct mode
   if (mode != "classification") {
      stop("`mode` should be 'classification'", call. = FALSE)
   }
   # Capture arguments
   args <- list(
      hidden = rlang::enquo(hidden)
   )
   # Model specs / slots
   parsnip::new_model_spec(
      "neural_net",
      args = args,
      mode = mode,
      eng_args = NULL,
      method = NULL,
      engine = NULL
   )
}

# Try it out
nn_spec <- neural_net(hidden = 3) %>%
   set_engine("neuralnet")

nn_fit <- nn_spec %>% 
   fit(Species ~ ., data = iris)

predict(nn_fit, new_data = iris)

Running traceback() after the error gives:

> traceback()
5: factor(as.character(res$values), levels = object$lvl)
4: predict_class.model_fit(object = object, new_data = new_data,
3: predict_class(object = object, new_data = new_data, ...)
2: predict.model_fit(nn_fit, new_data = iris[, -5])
1: predict(nn_fit, new_data = iris[, -5])

Looking at ?predict.nn, it says that it always returns a matrix of predictions (class probabilities, in this case). You need to convert those to the predicted levels (perhaps using something like apply(pred, 1, which.max); maybe see this code). This should happen in a function that is called by the post argument to parsnip::set_pred(). See "Step 4" in the documentation.
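A minimal sketch of such post-processing functions follows. The names nn_class_post and nn_prob_post are hypothetical, and the sketch assumes that parsnip calls post(results, object) with results being whatever predict.nn() returned and object being the fitted model_fit (which stores the outcome levels in object$lvl). It also assumes the output columns from {neuralnet} follow the order of the factor levels. The toy matrix stands in for real predict.nn() output so the helpers can be checked without fitting a network.

```r
# Hypothetical post-processor for type = "class":
# turn the one-column-per-level matrix into a factor of predicted classes.
nn_class_post <- function(results, object) {
  # Index of the largest output in each row (assumed to match level order)
  idx <- apply(results, 1, which.max)
  factor(object$lvl[idx], levels = object$lvl)
}

# Hypothetical post-processor for type = "prob":
# return a data frame with one column per level, named after the levels.
nn_prob_post <- function(results, object) {
  res <- as.data.frame(results)
  names(res) <- object$lvl
  res
}

# Toy stand-ins: a fake model_fit carrying only the levels,
# and a fake prediction matrix shaped like predict.nn() output
fake_fit <- list(lvl = c("setosa", "versicolor", "virginica"))
fake_pred <- matrix(
  c(0.9, 0.1, 0.0,
    0.2, 0.7, 0.1,
    0.0, 0.3, 0.8),
  ncol = 3, byrow = TRUE
)

nn_class_post(fake_pred, fake_fit)
nn_prob_post(fake_pred, fake_fit)
```

To use them, the class set_pred() call in the question would pass post = nn_class_post instead of post = NULL; a second set_pred() call with type = "prob" and post = nn_prob_post could expose probabilities, since parsnip adds the .pred_ prefix to the level names itself.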

Some other comments though:

  • parsnip::mlp() is the canonical function for single-layer, feed-forward neural networks. You might just want to add an engine for that.

  • We have standardized argument names in tidymodels. You might want to use hidden_units instead of hidden when you call parsnip::set_model_arg().
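Combining both suggestions, a rough sketch of registering {neuralnet} as an engine for parsnip::mlp() could look like the block below. "neuralnet" is not a built-in mlp engine; the engine name and the choice of mapping the standardized hidden_units argument onto neuralnet's hidden are assumptions for illustration. The registration calls mirror the ones in the question, with the class post-processing done inline.

```r
library(parsnip)

# Hypothetical: add "neuralnet" as a classification engine for mlp()
set_model_engine("mlp", mode = "classification", eng = "neuralnet")
set_dependency("mlp", eng = "neuralnet", pkg = "neuralnet")

# Map the standardized tidymodels name onto neuralnet's own argument
set_model_arg(
  model = "mlp",
  eng = "neuralnet",
  parsnip = "hidden_units",
  original = "hidden",
  func = list(pkg = "dials", fun = "hidden_units"),
  has_submodel = FALSE
)

set_encoding(
  model = "mlp",
  eng = "neuralnet",
  mode = "classification",
  options = list(
    predictor_indicators = "traditional",
    compute_intercept = TRUE,
    remove_intercept = TRUE,
    allow_sparse_x = FALSE
  )
)

set_fit(
  model = "mlp",
  eng = "neuralnet",
  mode = "classification",
  value = list(
    interface = "formula",
    protect = c("formula", "data"),
    func = c(pkg = "neuralnet", fun = "neuralnet"),
    defaults = list()
  )
)

set_pred(
  model = "mlp",
  eng = "neuralnet",
  mode = "classification",
  type = "class",
  value = list(
    pre = NULL,
    # predict.nn() returns a matrix; keep the highest-scoring level per row
    # (assumes columns are in the same order as the outcome levels)
    post = function(results, object) {
      idx <- apply(results, 1, which.max)
      factor(object$lvl[idx], levels = object$lvl)
    },
    func = c(fun = "predict"),
    args = list(
      object = quote(object$fit),
      newdata = quote(new_data)
    )
  )
)
```

With that registered, mlp(hidden_units = 3) %>% set_mode("classification") %>% set_engine("neuralnet") %>% fit(Species ~ ., data = iris) should in principle fit and predict like the custom model, though details may vary between parsnip versions.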

Thanks, Max. I greatly appreciate your response. I eventually figured out that the issue was with the kind of object being returned, but didn't know how to go about resolving it. Your suggestions are appreciated.

I didn't truly overlook the two extra points you raised. I am doing this as a self-improvement tutorial, so I was looking for something fairly straightforward to start with. I kept separate names as mental markers to focus on.

Thanks again for the help and, hopefully, I will have some for-real parsnip extensions to contribute in the not-too-distant future.
