Exploring 7 Core Functions of torch.utils.data in PyTorch

This article is approximately 1,800 words and takes about 5 minutes to read. It takes a deep look at the 7 core functions of the torch.utils.data module in PyTorch, which can help you better manage and manipulate data.
In machine learning and deep learning projects, data processing is a crucial step. PyTorch, as a powerful deep learning framework, provides flexible and efficient data processing tools. This article explains each of the seven tools in detail, with code examples that demonstrate their usage.

1. Dataset Class

The Dataset class is the foundation of data processing in PyTorch. By inheriting this class, you can create custom datasets that adapt to various types of data, such as images, text, or time-series data.
To create a custom dataset, you need to implement two key methods:
  • __len__ method: Returns the size of the dataset
  • __getitem__ method: Retrieves a sample based on the given index
This flexibility allows the Dataset class to handle various data formats and sources.
Code Example:
import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

# Create a simple dataset
data = torch.randn(100, 5)  # 100 samples, each with 5 features
labels = torch.randint(0, 2, (100,))  # Binary classification labels

dataset = CustomDataset(data, labels)
print(f"Dataset size: {len(dataset)}")
print(f"First sample: {dataset[0]}")
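The same pattern extends naturally to data that needs per-sample preprocessing, such as the images or text mentioned above. Below is a minimal sketch of a Dataset with an optional transform hook; the transform callable is an illustrative assumption (in practice it might be, e.g., a torchvision transform), not part of the original example.

# A hedged sketch: the Dataset pattern with an optional per-sample transform.
# The `transform` callable is a hypothetical placeholder (e.g., a torchvision
# transform for images); everything else reuses objects defined above.
class TransformDataset(Dataset):
    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        if self.transform is not None:
            sample = self.transform(sample)  # applied lazily, one sample at a time
        return sample, self.labels[idx]

# Example: standardize each sample on the fly
normalize = lambda x: (x - x.mean()) / (x.std() + 1e-8)
transformed = TransformDataset(data, labels, transform=normalize)
print(f"Transformed sample mean: {transformed[0][0].mean():.4f}")  # close to 0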

2. DataLoader

The DataLoader is an extremely important tool that wraps a dataset in an iterable. It simplifies batch loading, data shuffling, and parallel data loading, making it key to feeding data efficiently into model training and evaluation.
The main functions of DataLoader include:
  • Batch loading data
  • Automatic shuffling of data
  • Multi-process data loading for improved efficiency
  • Custom data sampling strategies
Code Example:
from torch.utils.data import DataLoader

# Using the previously created dataset
batch_size = 16
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

for batch_data, batch_labels in dataloader:
    print(f"Batch data shape: {batch_data.shape}")
    print(f"Batch labels shape: {batch_labels.shape}")
    break  # Only print the first batch
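Beyond these options, DataLoader also exposes a collate_fn hook that controls how individual samples are assembled into a batch. As a supplementary sketch (the variable-length sequences and padding below are illustrative assumptions, not part of the article's fixed-size dataset):

from torch.nn.utils.rnn import pad_sequence

# Toy variable-length sequences; a plain list of (sample, label) pairs is a
# valid map-style dataset since it supports __getitem__ and __len__
seqs = [torch.randn(n) for n in (3, 5, 2, 4)]
seq_labels = torch.tensor([0, 1, 0, 1])
var_dataset = list(zip(seqs, seq_labels))

def pad_collate(batch):
    # batch is a list of (sample, label) tuples; pad samples to a common length
    samples, labels = zip(*batch)
    return pad_sequence(samples, batch_first=True), torch.stack(labels)

var_loader = DataLoader(var_dataset, batch_size=2, collate_fn=pad_collate)
for padded, lbls in var_loader:
    print(f"Padded batch shape: {padded.shape}")  # padded to the longest sequence in the batch
    break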

3. Subset

The Subset class lets you carve a smaller, specific subset out of a larger dataset. This is particularly useful in the following scenarios:
  • Experimenting with a data subset
  • Splitting the dataset into training, validation, and test sets
By specifying indices, you can easily create the desired data subset.
Code Example:
from torch.utils.data import Subset
import numpy as np

# Create a subset containing a random 20% of the original dataset
dataset_size = len(dataset)
subset_size = int(0.2 * dataset_size)
subset_indices = np.random.choice(dataset_size, subset_size, replace=False)

subset = Subset(dataset, subset_indices)
print(f"Subset size: {len(subset)}")

# Create a new DataLoader using the subset
subset_loader = DataLoader(subset, batch_size=8, shuffle=True)
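For the train/validation/test split mentioned above, here is a minimal sketch built from Subset and a shuffled index array; the 70/15/15 ratio is an arbitrary assumption. PyTorch's torch.utils.data.random_split achieves the same result in a single call.

# Shuffle the indices once, then carve out contiguous slices
indices = np.random.permutation(dataset_size)
train_end = int(0.7 * dataset_size)
val_end = int(0.85 * dataset_size)

train_set = Subset(dataset, indices[:train_end].tolist())
val_set = Subset(dataset, indices[train_end:val_end].tolist())
test_set = Subset(dataset, indices[val_end:].tolist())
print(f"Train/val/test sizes: {len(train_set)}/{len(val_set)}/{len(test_set)}")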

4. ConcatDataset

The ConcatDataset is used to combine multiple datasets into a single dataset. This tool is very useful when you have multiple datasets that need to be used together. It can:
  • Merge data from different sources
  • Create larger and more diverse training sets
Code Example:
from torch.utils.data import ConcatDataset

# Create two simple datasets
dataset1 = CustomDataset(torch.randn(50, 5), torch.randint(0, 2, (50,)))
dataset2 = CustomDataset(torch.randn(30, 5), torch.randint(0, 2, (30,)))

# Merge the datasets
combined_dataset = ConcatDataset([dataset1, dataset2])
print(f"Combined dataset size: {len(combined_dataset)}")

# Create a DataLoader using the combined dataset
combined_loader = DataLoader(combined_dataset, batch_size=16, shuffle=True)
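One detail worth noting: ConcatDataset maps each global index to the corresponding underlying dataset in order, so samples keep their position within each source. A quick check:

# Indices 0-49 come from dataset1, indices 50-79 from dataset2
first_sample, _ = combined_dataset[0]
boundary_sample, _ = combined_dataset[50]  # the first element of dataset2
print(torch.equal(first_sample, dataset1[0][0]))     # True
print(torch.equal(boundary_sample, dataset2[0][0]))  # True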

5. TensorDataset

The TensorDataset is very useful when the data already exists in tensor form. It wraps tensors into a dataset object, making it simple to handle pre-processed features and labels.
The main advantages of TensorDataset are:
  • Directly using tensor data
  • Simplifying the usage process of already pre-processed data
Code Example:
from torch.utils.data import TensorDataset

# Create feature and label tensors
features = torch.randn(1000, 10)  # 1000 samples, each with 10 features
labels = torch.randint(0, 5, (1000,))  # 5-class problem

# Create the TensorDataset
tensor_dataset = TensorDataset(features, labels)

# Create a DataLoader using the TensorDataset
tensor_loader = DataLoader(tensor_dataset, batch_size=32, shuffle=True)

for batch_features, batch_labels in tensor_loader:
    print(f"Feature shape: {batch_features.shape}, Label shape: {batch_labels.shape}")
    break
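Under the hood, TensorDataset simply indexes every wrapped tensor along the first dimension, which is why all tensors must share the same length. A quick sanity check:

# Each sample is a tuple with one entry per wrapped tensor
x0, y0 = tensor_dataset[0]
print(torch.equal(x0, features[0]), torch.equal(y0, labels[0]))  # True True

# Tensors with mismatched first dimensions are rejected at construction time
try:
    TensorDataset(torch.randn(10, 3), torch.randn(9))
except AssertionError as e:
    print(f"Rejected: {e}")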

6. RandomSampler

The RandomSampler is used to randomly sample elements from a dataset. This tool is especially important when using training methods that require random sampling, such as Stochastic Gradient Descent (SGD). It can help:
  • Increase randomness in training
  • Reduce the risk of model overfitting
Code Example:
from torch.utils.data import RandomSampler

# Using the previously created dataset
random_sampler = RandomSampler(dataset, replacement=True, num_samples=50)

# Create a DataLoader using the RandomSampler
random_loader = DataLoader(dataset, batch_size=10, sampler=random_sampler)

for batch_data, batch_labels in random_loader:
    print(f"Randomly sampled batch size: {batch_data.shape[0]}")
    break
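For comparison, with the default replacement=False a RandomSampler yields every index exactly once per epoch (a random permutation), which is what DataLoader's shuffle=True sets up internally. With replacement=True, indices may repeat among the draws:

# Default: a permutation of all indices, each visited exactly once
epoch_indices = list(RandomSampler(dataset))
print(len(epoch_indices), len(set(epoch_indices)))  # 100 100

# With replacement, duplicates are possible among the 50 draws
drawn = list(random_sampler)
print(len(drawn), len(set(drawn)))  # 50, usually fewer unique indices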

7. WeightedRandomSampler

The WeightedRandomSampler draws samples according to specified per-sample weights (with replacement by default). This is particularly useful when dealing with imbalanced datasets, as it can:
  • Sample minority classes more frequently
  • Balance class distribution and improve the model’s sensitivity to minority classes

Code Example:

from torch.utils.data import WeightedRandomSampler

# Assume we have an imbalanced dataset
imbalanced_labels = torch.tensor([0, 0, 0, 0, 1, 1, 2])
class_sample_count = torch.bincount(imbalanced_labels)  # samples per class: [4, 2, 1]
weight = 1.0 / class_sample_count.float()               # inverse class frequency
samples_weight = weight[imbalanced_labels]              # weight for each individual sample

# Create the WeightedRandomSampler (it samples with replacement by default)
weighted_sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

# Create a simple dataset
imbalanced_dataset = TensorDataset(torch.randn(7, 5), imbalanced_labels)

# Create a DataLoader using the WeightedRandomSampler
weighted_loader = DataLoader(imbalanced_dataset, batch_size=3, sampler=weighted_sampler)

for batch_data, batch_labels in weighted_loader:
    print(f"Sampled labels: {batch_labels}")
    break
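To see the rebalancing effect, one can count how often each class appears over many draws; with inverse-frequency weights the counts should come out roughly uniform despite the raw 4/2/1 imbalance. A minimal check (the num_samples value of 3000 is an arbitrary choice):

from collections import Counter

# Draw many samples to observe the long-run class frequencies
check_sampler = WeightedRandomSampler(samples_weight, num_samples=3000, replacement=True)
check_loader = DataLoader(imbalanced_dataset, batch_size=100, sampler=check_sampler)

counts = Counter()
for _, lbls in check_loader:
    counts.update(lbls.tolist())
print(counts)  # roughly 1000 draws per class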

Conclusion

The torch.utils.data module in PyTorch provides these powerful and flexible tools that make data processing simple and efficient. By using them skillfully, you can manage your data pipeline more effectively and build more powerful, efficient machine learning models.
