Transformer model in R using the torch package

I'm working on replicating the Transformer architecture for a time series forecasting task, but the model's performance is quite poor. Can anyone offer suggestions on how to improve it? This could include changes to the model structure, additional layers, or other enhancements.

Here is my code for the model structure:

library(torch)

#### Transform input----------------------------------------------------------------
# Convert to tensor format
x_data <- torch_tensor(as.matrix(supervised_data[, 1:n_lags]), dtype = torch_float())    # Features (lags)
y_data <- torch_tensor(as.matrix(supervised_data[, n_lags + 1]), dtype = torch_float())  # Target

# Reshape x_data to match (batch_size, seq_leng, feature_size)
x_data <- x_data$view(c(nrow(x_data), n_lags, 1))  # (batch_size, seq_leng, feature_size)

# Split the data into training and testing sets (80% for training, 20% for testing)
train_size <- round(0.8 * nrow(supervised_data))

x_train <- x_data[1:train_size, , drop = FALSE]  # Ensure it's a tensor with the correct shape
y_train <- y_data[1:train_size]

x_test <- x_data[(train_size + 1):nrow(supervised_data), , drop = FALSE]
y_test <- y_data[(train_size + 1):nrow(supervised_data)]
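
# Quick shape check, just for sanity (supervised_data and n_lags come from
# earlier preprocessing that is not shown here):
cat("x_train:", paste(x_train$shape, collapse = " x "), "\n")
cat("y_train:", paste(y_train$shape, collapse = " x "), "\n")
cat("x_test :", paste(x_test$shape, collapse = " x "), "\n")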

#### Build components of model----------------------------------------------------------------
### Positional encoding:
positional_encoding <- function(seq_leng, d_model, n = 10000) {
  if (missing(seq_leng) || missing(d_model)) {
    stop("'seq_leng' and 'd_model' must be provided.")
  }
  
  P <- matrix(0, nrow = seq_leng, ncol = d_model)  
  
  for (k in 1:seq_leng) {
    for (i in 0:(d_model / 2 - 1)) {
      denominator <- n^(2 * i / d_model)
      P[k, 2 * i + 1] <- sin(k / denominator)
      P[k, 2 * i + 2] <- cos(k / denominator)
    }
  }
  
  return(P)
}

pos_enc <- positional_encoding(seq_leng = n_lags, d_model = 64)
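
# Quick check of the encoding itself (nothing model-specific; note that the loop
# above assumes d_model is even so the sin/cos pairs fill every column):
stopifnot(nrow(pos_enc) == n_lags, ncol(pos_enc) == 64)
print(range(pos_enc))   # values should stay within [-1, 1]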

### Mask function:
gen_trg_mask <- function(length, device) {
  mask <- torch_tril(torch_ones(length, length, device = device)) == 1
  mask <- mask$to(dtype = torch_float())
  mask[mask == 0] <- -Inf
  mask[mask == 1] <- 0.0
  return(mask)
}

device <- torch_device("cpu")
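
# For illustration only: the mask for a length-4 sequence has 0 on and below the
# diagonal and -Inf above it, so each position can only attend to itself and to
# earlier positions.
print(gen_trg_mask(4, device))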

mask_self_attention <- nn_module(
  initialize = function(embed_dim, num_heads) {
    self$embed_dim <- embed_dim
    self$num_heads <- num_heads
    self$head_dim <- embed_dim / num_heads
    
    if (self$head_dim %% 1 != 0) {
      stop("embed_dim must be divisible by num_heads")
    }
    
    # Linear layers for Q, K, V
    self$query <- nn_linear(embed_dim, embed_dim, bias = FALSE)
    self$key <- nn_linear(embed_dim, embed_dim, bias = FALSE)
    self$value <- nn_linear(embed_dim, embed_dim, bias = FALSE)
    
    # Final linear layer after concatenating heads
    self$out <- nn_linear(embed_dim, embed_dim, bias = FALSE)
    self$mask <- gen_trg_mask(n_lags, device)  # precomputed causal mask (n_lags x n_lags)
  },
  
  forward = function(x, mask = NULL) {
    batch_size <- x$size(1)
    seq_leng <- x$size(2)
    
    # Linear projections for Q, K, V
    Q <- self$query(x)  # (batch_size, seq_leng, embed_dim)
    K <- self$key(x)
    V <- self$value(x)
    
    # Reshape to separate heads: (batch_size, num_heads, seq_leng, head_dim)
    Q <- Q$view(c(batch_size, seq_leng, self$num_heads, self$head_dim))$transpose(2, 3)
    K <- K$view(c(batch_size, seq_leng, self$num_heads, self$head_dim))$transpose(2, 3)
    V <- V$view(c(batch_size, seq_leng, self$num_heads, self$head_dim))$transpose(2, 3)
    
    # Compute attention scores
    d_k <- self$head_dim
    attention_scores <- torch_matmul(Q, torch_transpose(K, -1, -2)) / sqrt(d_k)
    
    # Apply mask if provided
    if (!is.null(mask)) {
      attention_scores <- attention_scores + mask
    }
    
    # Compute attention weights
    weights <- nnf_softmax(attention_scores, dim = -1)
    
    # Apply weights to V
    attn_output <- torch_matmul(weights, V)  # (batch_size, num_heads, seq_leng, head_dim)
    
    # Merge the heads back together: (batch_size, seq_leng, embed_dim)
    attn_output <- attn_output$transpose(2, 3)$contiguous()$view(c(batch_size, seq_leng, self$embed_dim))
    
    # Final linear projection
    output <- self$out(attn_output)
    return(output)
  }
)
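
# Throwaway shape check for the attention block above, using the same
# d_model = 64 / num_heads = 8 setup as the full model further down:
test_attn <- mask_self_attention(embed_dim = 64, num_heads = 8)
test_out  <- test_attn(torch_randn(2, n_lags, 64), gen_trg_mask(n_lags, device))
print(test_out$shape)   # expect: 2, n_lags, 64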


### Encoder block:
encoder_layer <- nn_module(
  "TransformerEncoderLayer",
  
  initialize = function(d_model, num_heads, d_ff, dropout = 0.1) {
    # Multi-Head Attention
    self$multihead_attention <- nn_multihead_attention(embed_dim = d_model, num_heads = num_heads)
    
    # Feedforward Network (Fully Connected)
    self$feed_forward <- nn_sequential(
      nn_linear(d_model, d_ff),
      nn_relu(),
      nn_linear(d_ff, d_model)
    )
    
    # Dropout for regularization
    self$dropout <- nn_dropout(dropout)
    
    # Layer Normalization
    self$layer_norm_1 <- nn_layer_norm(d_model)
    self$layer_norm_2 <- nn_layer_norm(d_model)
  },
  
  forward = function(x) {
    attn_output <- self$multihead_attention(x, x, x)  # (Q, K, V)
    x <- x + self$dropout(attn_output[[1]])  # first element of the returned list is the attention output
    x <- self$layer_norm_1(x)  
    
    # Feedforward network with residual connection
    ff_output <- self$feed_forward(x)
    x <- x + self$dropout(ff_output)  
    x <- self$layer_norm_2(x)  
    
    return(x)
  }
)
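
# Same kind of throwaway check for a single encoder layer. One thing I am unsure
# about: as far as I can tell, nn_multihead_attention expects input shaped
# (seq_len, batch, embed_dim) by default, while my tensors are
# (batch, seq_leng, embed_dim).
test_enc <- encoder_layer(d_model = 64, num_heads = 8, d_ff = 128, dropout = 0.1)
print(test_enc(torch_randn(2, n_lags, 64))$shape)   # expect: 2, n_lags, 64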

### Decoder Layer
decoder_layer <- nn_module(
  "TransformerDecoderLayer",
  
  initialize = function(d_model, num_heads, d_ff, dropout = 0.1) {
    self$mask_self_attention <- mask_self_attention(embed_dim = d_model, num_heads = num_heads)
    self$multihead_attention <- nn_multihead_attention(embed_dim = d_model, num_heads = num_heads)
    self$feed_forward <- nn_sequential(
      nn_linear(d_model, d_ff),
      nn_relu(),
      nn_linear(d_ff, d_model)
    )
    self$dropout <- nn_dropout(dropout)
    self$layer_norm_1 <- nn_layer_norm(d_model)
    self$layer_norm_2 <- nn_layer_norm(d_model)
    self$layer_norm_3 <- nn_layer_norm(d_model)
  },
  
  forward = function(x, encoder_output, mask = NULL) {
    # Masked Self-Attention
    attn_output <- self$mask_self_attention(x, mask)
    x <- x + self$dropout(attn_output)
    x <- self$layer_norm_1(x)
    
    # Encoder-Decoder Multi-Head Attention
    attn_output <- self$multihead_attention(x, encoder_output, encoder_output)
    x <- x + self$dropout(attn_output[[1]])
    x <- self$layer_norm_2(x)
    
    # Feedforward Network
    ff_output <- self$feed_forward(x)
    x <- x + self$dropout(ff_output)
    x <- self$layer_norm_3(x)
    
    return(x)
  }
)
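
# And for a single decoder layer, feeding it a dummy "encoder output" of the
# same shape:
test_dec <- decoder_layer(d_model = 64, num_heads = 8, d_ff = 128, dropout = 0.1)
print(test_dec(torch_randn(2, n_lags, 64), torch_randn(2, n_lags, 64))$shape)   # expect: 2, n_lags, 64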


### Final transformer model: 
transformer <- nn_module(
  "Transformer",
  
  initialize = function(d_model, seq_leng, num_heads, d_ff, num_encoder_layers, num_decoder_layers, dropout = 0.1, pos_enc) {
    self$d_model <- d_model
    self$num_heads <- num_heads
    self$d_ff <- d_ff
    self$num_encoder_layers <- num_encoder_layers
    self$num_decoder_layers <- num_decoder_layers
    self$seq_leng <- seq_leng
    self$dropout <- dropout
    self$pos_enc <- pos_enc  
    
    # Encoder layers
    self$encoder_layers <- nn_module_list(
      lapply(1:num_encoder_layers, function(i) {
        encoder_layer(d_model, num_heads, d_ff, dropout)
      })
    )
    
    # Decoder layers
    self$decoder_layers <- nn_module_list(
      lapply(1:num_decoder_layers, function(i) {
        decoder_layer(d_model, num_heads, d_ff, dropout)
      })
    )
    
    self$conv1d <- nn_conv1d(in_channels = seq_leng, 
                             out_channels = 1, 
                             kernel_size = 3, 
                             padding = 1)
    
    # Final output layer
    self$output_layer <- nn_linear(d_model, 1)  # Output layer to predict a single value
 
  },
  
  forward = function(src, trg) {
    # Add positional encoding to the input sequences
    src <- src + self$pos_enc[1:src$size(2), , drop = FALSE]  # Add positional encoding to src
    trg <- trg + self$pos_enc[1:trg$size(2), , drop = FALSE]  # Add positional encoding to trg
    
    # Encoder forward pass
    encoder_output <- src
    for (i in 1:self$num_encoder_layers) {
      encoder_output <- self$encoder_layers[[i]](encoder_output)
    }
    
    # Decoder forward pass
    decoder_output <- trg
    for (i in 1:self$num_decoder_layers) {
      decoder_output <- self$decoder_layers[[i]](decoder_output, encoder_output)
    }
    
    # Collapse the sequence dimension (treated as channels) to length 1 with a 1-D convolution
    decoder_output <- self$conv1d(decoder_output)
    
    # Apply final output layer
    output <- self$output_layer(decoder_output)
    
    return(output)
  }
)

#### Build model----------------------------------------------------------------
model <- transformer(
  d_model = 64,         # Embedding dimension
  seq_leng = n_lags,        # Sequence length
  num_heads = 8,        # Number of heads
  d_ff = n_lags,           # Dimension of the feedforward layer
  num_encoder_layers = 6, 
  num_decoder_layers = 6, 
  dropout = 0.3,
  pos_enc = pos_enc  # Pass positional encoding matrix to model
)
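
# Smoke test before training: one forward pass on a small slice, which should
# return a tensor of shape (batch, 1, 1). The positional-encoding addition relies
# on broadcasting the (n_lags, 64) matrix against the (batch, n_lags, 1) input.
with_no_grad({
  print(model(x_train[1:4, , ], x_train[1:4, , ])$shape)
})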


#### Training----------------------------------------------------------------
epochs <- 200
loss_fn <- nn_mse_loss()
optimizer <- optim_adam(model$parameters, lr = 1e-3)

for (epoch in 1:epochs) {
  model$train()
  
  optimizer$zero_grad()
  
  # Forward pass
  y_pred <- model(x_train, x_train)  # same sequence passed as both encoder (src) and decoder (trg) input for now
  
  # Compute the loss
  loss <- loss_fn(y_pred[, , 1], y_train)  # drop the trailing dimension so the prediction matches y_train's shape
  
  # Backpropagation and optimization
  loss$backward()
  optimizer$step()
  
  if (epoch %% 10 == 0) {
    cat("Epoch: ", epoch, " Loss: ", loss$item(), "\n")
  }
}
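
# Evaluation on the held-out 20%, mirroring the training call (the same tensor
# is passed as both src and trg):
model$eval()
with_no_grad({
  y_hat     <- model(x_test, x_test)
  test_loss <- loss_fn(y_hat[, , 1], y_test)
  cat("Test MSE:", test_loss$item(), "\n")
})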