I am trying to replicate this tutorial, but instead of a linear regression I am using a random forest (RF). I can make the predictions, but now I would like to compute and extract the residuals of the regression (i.e., observed - predicted values). Then I would like to `cbind` the residuals with the coordinates of the data.frame, like so:
# original_df[, 1:2] contains the coordinates and rf_resids are the
# residuals of the RF regression; both must have the same number of rows
# for cbind() to succeed.
resids_df <- cbind(original_df[, 1:2], rf_resids)
The issue is that, after making the predictions, the output has 54 more values than my original data set (3792 vs 3738). This is a problem because I can't `cbind` the residuals to the coordinates when the number of rows differs.
How can I resolve this issue and get exactly the same number of residuals (observations) as my original data set?
P.S. my data set does not contain NA values.
In the example below, I used a subset of my data set, but again you can see that there is 1 more value in the predictions than there are rows in the data set.
library(tidymodels)
library(spatialsample)
library(sf)
# Placeholder input directory and the CRS assigned to the point coordinates.
wd <- "path/"
proj_ref_sys <- "EPSG:7760"

# Read the sample table and report how many observations it holds.
drought <- read.csv(paste0(wd, "block.data.csv"))
nrow(drought)
# [1] 60 !!!!!!!!!!!!!

# Turn the x/y columns into sf point geometries.
drought_sf <- st_as_sf(drought, coords = c("x", "y"), crs = proj_ref_sys)

# Fix the RNG seed, then build three spatially blocked resampling folds.
set.seed(123)
folds <- spatial_block_cv(drought_sf, v = 3)

# Random forest regression specification using the randomForest engine.
rf_spec <- set_engine(
  rand_forest(mode = "regression", mtry = 2, trees = 100),
  "randomForest"
)

# Fit the workflow on every fold, keeping the held-out predictions so they
# can be pulled out with collect_predictions().
drought_res <- fit_resamples(
  workflow(ntl ~ pop + agbh + nir, rf_spec),
  folds,
  control = control_resamples(save_pred = TRUE)
)
drought_res
collect_predictions(drought_res)
# A tibble: **61** × 5 !!!!!!!!!!!!!!!!!!!
id .pred .row ntl .config
<chr> <dbl> <int> <dbl> <chr>
1 Fold1 28.7 18 29.2 Preprocessor1_Model1
2 Fold1 27.9 19 32.8 Preprocessor1_Model1
3 Fold1 17.2 20 29.6 Preprocessor1_Model1
4 Fold1 19.6 21 28.6 Preprocessor1_Model1
5 Fold1 34.3 22 36.5 Preprocessor1_Model1
6 Fold1 48.7 28 34.8 Preprocessor1_Model1
7 Fold1 45.9 29 32.2 Preprocessor1_Model1
8 Fold1 40.1 30 28.3 Preprocessor1_Model1
9 Fold1 14.6 31 22.5 Preprocessor1_Model1
10 Fold1 9.96 32 17.1 Preprocessor1_Model1
# ℹ 51 more rows
# ℹ Use `print(n = ...)` to see more rows
The data.frame I'm using:
structure(list(x = c(995494.2549, 995924.2549, 996354.2549, 996784.2549,
997214.2549, 997644.2549, 998074.2549, 998504.2549, 998934.2549,
999364.2549, 999794.2549, 1000224.2549, 1000654.2549, 1001084.2549,
1001514.2549, 1001944.2549, 1002374.2549, 1002804.2549, 1003234.2549,
1003664.2549, 1004094.2549, 1004524.2549, 1004954.2549, 1005384.2549,
1005814.2549, 1006244.2549, 1006674.2549, 1007104.2549, 1007534.2549,
1007964.2549, 1008394.2549, 1008824.2549, 1009254.2549, 1009684.2549,
1010114.2549, 1010544.2549, 1010974.2549, 1011404.2549, 1011834.2549,
1012264.2549, 1012694.2549, 1013124.2549, 1013554.2549, 1013984.2549,
1014414.2549, 1014844.2549, 1015274.2549, 1015704.2549, 1016134.2549,
1016564.2549, 1016994.2549, 1017424.2549, 1017854.2549, 1018284.2549,
1018714.2549, 995494.2549, 995924.2549, 996354.2549, 996784.2549,
997214.2549), y = c(1019851.5842, 1019851.5842, 1019851.5842,
1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842,
1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842,
1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842,
1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842,
1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842,
1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842,
1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842,
1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842,
1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842,
1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842,
1019851.5842, 1019851.5842, 1019421.5842, 1019421.5842, 1019421.5842,
1019421.5842, 1019421.5842), ntl = c(9.14866638183594, 15.3856477737427,
16.3302040100098, 12.454291343689, 10.4823837280273, 11.394606590271,
8.1963529586792, 4.50725030899048, 3.95374751091003, 5.73203563690186,
14.3955335617065, 17.0745468139648, 14.2944135665894, 10.333722114563,
9.80743503570557, 12.5352020263672, 19.8813304901123, 29.2410221099854,
32.8321876525879, 29.575023651123, 28.5894374847412, 36.4911346435547,
49.4252128601074, 61.3118171691895, 58.6104736328125, 43.0437355041504,
28.096061706543, 34.8003845214844, 32.1936340332031, 28.3407783508301,
22.5178966522217, 17.0638084411621, 20.7549228668213, 18.3547439575195,
10.2983675003052, 7.3524694442749, 7.17788362503052, 7.06999540328979,
8.03957176208496, 12.6783542633057, 18.7537479400635, 26.1656856536865,
36.539493560791, 41.0569839477539, 25.5366401672363, 15.7820110321045,
9.87918758392334, 7.65169858932495, 6.96318626403809, 8.69833087921143,
12.1393032073975, 15.151198387146, 14.5944147109985, 9.46016979217529,
4.53868055343628, 12.8388118743896, 21.1265335083008, 19.3046970367432,
10.5719947814941, 8.08844661712646), pop = c(31.2753772735596,
55.8289375305176, 56.4003105163574, 33.795223236084, 31.0511913299561,
30.5730743408203, 13.667106628418, 7.08161020278931, 6.89333772659302,
13.9001550674438, 35.5272178649902, 42.4625587463379, 32.9688529968262,
21.4302787780762, 12.6151924133301, 17.4939270019531, 38.1474113464355,
60.8120536804199, 65.3665008544922, 53.8765907287598, 46.2705993652344,
61.42333984375, 70.8307113647461, 53.3152236938477, 31.4083557128906,
24.9810562133789, 38.3716621398926, 56.114860534668, 67.1656036376953,
60.8404235839844, 33.7796592712402, 29.8311328887939, 44.3309173583984,
31.9606342315674, 16.7053775787354, 10.1427822113037, 11.4020376205444,
10.7794933319092, 18.2773151397705, 34.2912216186523, 50.6655197143555,
52.1081962585449, 53.0502471923828, 59.4989013671875, 48.5897750854492,
41.188159942627, 27.0699615478516, 11.5318984985352, 9.09538650512695,
14.2379903793335, 24.8153190612793, 29.3468627929688, 30.5861835479736,
15.3130531311035, 9.47307205200195, 37.2332077026367, 94.2268676757812,
73.2485733032227, 26.8748569488525, 26.8519401550293), agbh = c(0.124395661056042,
0.543155550956726, 0.930405616760254, 0.176615670323372, 0.122252210974693,
1.86410081386566, 0.201039269566536, 0.00215102708898485, 0.00524011626839638,
0.0221506990492344, 1.75632297992706, 0.954743504524231, 0.373224049806595,
0.0127956680953503, 0.0007417316082865, 0.0123716788366437, 0.279229581356049,
2.30779552459717, 2.58910322189331, 1.23243260383606, 0.819948613643646,
1.74025285243988, 4.03071403503418, 2.78268098831177, 2.00978517532349,
0.700970351696014, 0.196071043610573, 2.19463133811951, 4.83159875869751,
2.20620393753052, 0.321354597806931, 0.00308413081802428, 1.737912774086,
0.468539208173752, 0.0156131321564317, 0.00116395147051662, 0.0145542966201901,
0.000892410753294826, 0.0419198162853718, 2.84171080589294, 3.22121715545654,
2.73401832580566, 2.47091150283813, 2.10038590431213, 1.15651941299438,
0.490403175354004, 0.0419915802776814, 0.101970501244068, 0.00181114906445146,
0.0132269319146872, 0.212756171822548, 0.111757233738899, 1.2169703245163,
0.129767879843712, 0, 0.582266986370087, 2.96843385696411, 1.16728830337524,
0.0494964420795441, 0.0664984136819839), nir = c(0.261590600013733,
0.250058531761169, 0.238313049077988, 0.246726274490356, 0.241509333252907,
0.215491861104965, 0.25552836060524, 0.26755028963089, 0.283316373825073,
0.2645283639431, 0.2347122579813, 0.250579416751862, 0.272739976644516,
0.26601967215538, 0.260071456432343, 0.283827364444733, 0.270996034145355,
0.229571804404259, 0.228905484080315, 0.240774929523468, 0.22843000292778,
0.201068416237831, 0.174168020486832, 0.187955036759377, 0.235188364982605,
0.226306527853012, 0.197943985462189, 0.192345812916756, 0.18694880604744,
0.203041225671768, 0.24348683655262, 0.264572501182556, 0.234625786542892,
0.252681404352188, 0.252072751522064, 0.241365790367126, 0.228045880794525,
0.252986639738083, 0.261032313108444, 0.233464851975441, 0.235829710960388,
0.235184907913208, 0.212146639823914, 0.204127430915833, 0.216947212815285,
0.225598230957985, 0.231632620096207, 0.224976778030396, 0.219116434454918,
0.255260914564133, 0.241265594959259, 0.237798929214478, 0.241482153534889,
0.240964710712433, 0.252938002347946, 0.258243441581726, 0.211435839533806,
0.217503502964973, 0.237074509263039, 0.237700119614601)), row.names = c(NA,
60L), class = "data.frame")