I'm trying to follow along with the book "Deep Learning with PyTorch" using the new R packages `torch` and `torchvision`.
On page 173 (section 7.2.1), I'm not sure how to filter this dataset so that it includes only labels 1 and 3 (which correspond to labels 0 and 2 in the book, since the R labels start at 1).
Below is my code. I'd like to filter `transformed_cifar10` the same way the book does: keep only the samples whose label in `transformed_cifar10$y` is 1 or 3, and then remap {1, 3} to {1, 2}.
```r
library(dplyr)
library(torch)
library(torchvision)

data_path <- "./ch7/data" # need to change this?

train_transforms <- function(img) {
  img %>%
    transform_to_tensor() %>%
    transform_normalize(mean = c(0.4915, 0.4823, 0.4468),
                        std = c(0.2470, 0.2435, 0.2616))
}

transformed_cifar10 <- cifar10_dataset(data_path,
                                       train = TRUE,
                                       download = TRUE,
                                       transform = train_transforms)
```
This is the Python code from the book:
```python
# `cifar10` is the CIFAR-10 training set defined earlier in the chapter
label_map = {0: 0, 2: 1}
class_names = ['airplane', 'bird']
cifar2 = [(img, label_map[label])
          for img, label in cifar10
          if label in [0, 2]]
```
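For comparison, the label logic of that list comprehension can be sketched in base R on a plain label vector. This is a minimal sketch, assuming 1-based integer labels as in R's `cifar10_dataset()`; the vector `y` is a made-up toy stand-in, not real CIFAR-10 data. `%in%` plays the role of `if label in [0, 2]`, and `match()` plays the role of `label_map`.

```r
# Toy stand-in for a CIFAR-10 label vector (hypothetical values, 1-based classes)
y <- c(1L, 3L, 5L, 3L, 10L, 1L)

class_names <- c("airplane", "bird")
keep_labels <- c(1L, 3L)                 # R equivalents of the book's labels 0 and 2

keep  <- which(y %in% keep_labels)       # positions to keep, like `if label in [0, 2]`
new_y <- match(y[keep], keep_labels)     # remap 1 -> 1 and 3 -> 2, like `label_map`

print(keep)   # [1] 1 2 4 6
print(new_y)  # [1] 1 2 2 1
```

`match()` returns each label's position in `keep_labels`, so the remapping follows the order in which the labels are listed.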
My first idea was something like this, but it clearly doesn't work. Any ideas?

```r
transformed_cifar10[transformed_cifar10$y == 1]
```
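One possible approach is a small wrapper built with torch's `dataset()` that stores the kept indices and the remapped labels, then delegates to the original dataset's `.getitem()`. This is a sketch, not an official torchvision API: `remapped_subset` and `toy_ds` are hypothetical names, and the wrapper assumes the wrapped dataset's `.getitem()` returns a list with an `x` element. That holds for the toy dataset below and should also hold for `cifar10_dataset()`, but it is worth checking against the installed version.

```r
library(torch)

# Hypothetical wrapper: keeps samples whose label is in `keep_labels`
# and relabels them 1..length(keep_labels) in the order given.
remapped_subset <- dataset(
  name = "remapped_subset",
  initialize = function(source_ds, labels, keep_labels) {
    self$source_ds <- source_ds
    self$idx <- which(labels %in% keep_labels)          # rows to keep
    self$new_y <- match(labels[self$idx], keep_labels)  # e.g. 1 -> 1, 3 -> 2
  },
  .getitem = function(i) {
    item <- self$source_ds$.getitem(self$idx[i])  # fetch the original sample
    list(x = item$x, y = self$new_y[i])           # swap in the remapped label
  },
  .length = function() {
    length(self$idx)
  }
)

# Hypothetical toy dataset standing in for CIFAR-10 (random "images", 1-based labels)
toy_ds <- dataset(
  name = "toy_ds",
  initialize = function() {
    self$x <- torch_randn(6, 3)
    self$y <- c(1L, 3L, 5L, 3L, 10L, 1L)
  },
  .getitem = function(i) list(x = self$x[i, ], y = self$y[i]),
  .length = function() length(self$y)
)

src <- toy_ds()
cifar2 <- remapped_subset(src, src$y, keep_labels = c(1L, 3L))
cifar2$.length()       # 4
cifar2$.getitem(2)$y   # 2 (originally label 3)
```

With the data from the question, the call would look like `remapped_subset(transformed_cifar10, transformed_cifar10$y, c(1, 3))`. If only filtering is needed and the labels may stay 1 and 3, torch's `dataset_subset(transformed_cifar10, which(transformed_cifar10$y %in% c(1, 3)))` should be enough.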