13 Essential Features of PyTorch You Must Know

Compiled by | ronghuaiyang
Source | Frontiers of Artificial Intelligence

PyTorch has gained a lot of attention in both academic and industrial research applications. It is a deep learning framework with great flexibility, utilizing a wealth of practical tools and functions to speed up workflows. The learning curve of PyTorch is not very steep, but achieving efficient and clean code within it can be tricky. After using it for over 2 years, here are my favorite PyTorch features that I wish I had known when I first started learning it.

1 DatasetFolder

One of the first things people do when learning PyTorch is to implement some kind of custom Dataset. This is a rookie mistake, and it’s unnecessary to waste time writing such things. Typically, datasets are either a list of data (or numpy arrays) or files on disk. Therefore, organizing your data on disk is better than writing a custom Dataset to load some weird format.

One of the most common data formats for classifiers is having a directory with subfolders, where subfolders represent classes, and files within subfolders represent samples, as shown below.

folder/class_0/file1.txt
folder/class_0/file2.txt
folder/class_0/...

folder/class_1/file3.txt
folder/class_1/file4.txt

folder/class_2/file5.txt
folder/class_2/...

There is a built-in way to load such datasets, regardless of whether your data is images, text files, or anything else, simply by using ‘DatasetFolder’. Surprisingly, this class is part of the torchvision package, not core PyTorch. This class is very comprehensive, allowing you to filter files from folders, load them using custom code, and dynamically transform raw files. Example:

from torchvision.datasets import DatasetFolder
from pathlib import Path
# I have text files in this folder
ds = DatasetFolder("/Users/marcin/Dev/tmp/my_text_dataset", 
    loader=lambda path: Path(path).read_text(),
    extensions=(".txt",), #only load .txt files
    transform=lambda text: text[:100], # only take first 100 characters
)

# Everything you need is already there
len(ds), ds.classes, ds.class_to_idx
(20, ['novels', 'thrillers'], {'novels': 0, 'thrillers': 1})

If you are dealing with images, there is also a torchvision.datasets.ImageFolder class, which is based on DatasetFolder and is pre-configured to load images.

2 Minimize the Use of .to(device), Use zeros_like/ones_like Instead

I have read a lot of PyTorch code from GitHub repositories. What annoys me the most is that almost every repo has numerous *.to(device) lines that move data between the CPU and the GPU. Such statements appear in many repos and beginner tutorials. I strongly recommend keeping such calls to a minimum and relying on built-in PyTorch functionality to place tensors on the right device automatically. Using .to(device) everywhere often leads to performance degradation and can cause exceptions such as:

Expected object of device type cuda but got device type cpu

Obviously, in some cases you cannot avoid it, but most cases (if not all) can be handled with the techniques described here. One such case is initializing a tensor of all zeros or all ones, which often comes up when computing losses in deep neural networks: the model’s output is already on cuda, and you need another tensor on cuda as well. In this case, you can use the *_like operators:

# my_output can be on any device; if it is on cuda, my_zeros will also be on cuda
my_zeros = torch.zeros_like(my_output)

This is equivalent to the following call:

my_zeros = torch.zeros(my_output.size(), dtype=my_output.dtype, layout=my_output.layout, device=my_output.device)

So all settings are correct, reducing the chance of errors in the code. Similar operations include:

torch.zeros_like()
torch.ones_like()
torch.rand_like()
torch.randn_like()
torch.randint_like()
torch.empty_like()
torch.full_like()
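
To make the "all settings are correct" point concrete, here is a minimal sketch of what the *_like functions carry over from their input. The tensor name my_output and the float64 dtype are assumptions picked only so the inheritance is visible; on a GPU machine the same code would keep everything on cuda.

```python
import torch

# Hypothetical stand-in for a model output; float64 is chosen only
# to make the inherited dtype easy to spot.
my_output = torch.rand(2, 3, dtype=torch.float64)

zeros = torch.zeros_like(my_output)      # same shape, dtype and device
half = torch.full_like(my_output, 0.5)   # constant fill, same metadata
noise = torch.randn_like(my_output)      # random values, same metadata

# Individual properties can still be overridden explicitly.
int_zeros = torch.zeros_like(my_output, dtype=torch.long)

print(zeros.dtype, zeros.device, int_zeros.dtype)
```

The explicit dtype override is handy when you need, say, an integer index tensor that lives next to a float output.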

3 Register Buffer ( nn.Module.register_buffer)

This will be my next advice for people to avoid using .to(device) everywhere. Sometimes, your model or loss function requires pre-set parameters that are used during the forward call, such as a “weight” parameter that scales the loss or some fixed tensor that does not change but is used every time. For this situation, use the nn.Module.register_buffer method, which tells PyTorch to store the value passed to it within the module and move these values along with the module. If you initialize your module and then move it to GPU, these values will also move automatically. Additionally, if you save the module’s state, buffers will also be saved!

Once registered, these values can be accessed in the forward function just like other module attributes.

from torch import nn
import torch

class ModuleWithCustomValues(nn.Module):
    def __init__(self, weights, alpha):
        super().__init__()
        self.register_buffer("weights", torch.tensor(weights))
        self.register_buffer("alpha", torch.tensor(alpha))
    
    def forward(self, x):
        return x * self.weights + self.alpha

m = ModuleWithCustomValues(
    weights=[1.0, 2.0], alpha=1e-4
)
m(torch.tensor([1.23, 4.56]))
tensor([1.2301, 9.1201])
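
The following sketch shows how a buffer differs from a parameter. The module name ScaledOutput and its bias parameter are made up for illustration. Since not everyone has a GPU at hand, a module-wide cast with .double() stands in for .to(device); both go through the same mechanism that also touches registered buffers.

```python
import torch
from torch import nn


class ScaledOutput(nn.Module):  # hypothetical module, for illustration only
    def __init__(self, weights):
        super().__init__()
        # Buffer: saved and moved with the module, but never trained.
        self.register_buffer("weights", torch.tensor(weights))
        # Parameter: saved, moved, and updated by the optimizer.
        self.bias = nn.Parameter(torch.zeros(len(weights)))

    def forward(self, x):
        return x * self.weights + self.bias


m = ScaledOutput([1.0, 2.0])
param_names = [name for name, _ in m.named_parameters()]
buffer_names = [name for name, _ in m.named_buffers()]
state_keys = sorted(m.state_dict().keys())

# A module-wide conversion also converts the buffer.
m.double()
buffer_dtype_after_cast = m.weights.dtype

print(param_names, buffer_names, state_keys, buffer_dtype_after_cast)
```

Because the buffer is not returned by parameters(), an optimizer built from m.parameters() will leave it untouched.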

4 Built-in Identity()

Sometimes, when you use transfer learning, you need to replace some layers with a 1:1 mapping. You could write a custom nn.Module that simply returns its input, but PyTorch already has this class built in: nn.Identity.

For example, if you want to get image representations from a pre-trained ResNet50 before the classification layer, here’s how to do it:

import torch
from torch import nn
from torchvision.models import resnet50
model = resnet50(pretrained=True)  # newer torchvision releases prefer the `weights=` argument
model.fc = nn.Identity()
last_layer_output = model(torch.rand((1, 3, 224, 224)))
last_layer_output.shape
torch.Size([1, 2048])
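
If you would like to try the same idea without downloading pretrained weights, here is a small sketch on a toy network. The layer sizes and the name backbone are assumptions; the last nn.Linear merely plays the role of the classification head.

```python
import torch
from torch import nn

# Toy network; the final Linear stands in for a classification head.
backbone = nn.Sequential(
    nn.Linear(10, 32),
    nn.ReLU(),
    nn.Linear(32, 5),
)

x = torch.rand(4, 10)
logits_shape = backbone(x).shape   # shape with the head in place

# Swap the head for a pass-through layer to expose the features.
backbone[2] = nn.Identity()
features = backbone(x)

print(logits_shape, features.shape)
```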

5 Pairwise distances: torch.cdist

Next time you encounter the problem of calculating the Euclidean distance (or generally the p-norm) between two tensors, remember torch.cdist. It does exactly that and automatically uses matrix multiplication when using Euclidean distance to improve performance.

points1 = torch.tensor([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
points2 = torch.tensor([[0.0, 0.0], [-1.0, -1.0], [-2.0, -2.0], [-3.0, -3.0]]) # batches don't have to be equal
torch.cdist(points1, points2, p=2.0)
tensor([[0.0000, 1.4142, 2.8284, 4.2426],
        [1.4142, 2.8284, 4.2426, 5.6569],
        [2.8284, 4.2426, 5.6569, 7.0711]])

Here is a performance comparison with and without matrix multiplication; on my machine, the mm-based version was more than twice as fast.

%%timeit
points1 = torch.rand((512, 2))
points2 = torch.rand((512, 2))
torch.cdist(points1, points2, p=2.0, compute_mode="donot_use_mm_for_euclid_dist")

867 µs ± 142 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%%timeit
points1 = torch.rand((512, 2))
points2 = torch.rand((512, 2))
torch.cdist(points1, points2, p=2.0)

417 µs ± 52.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
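
As a sanity check of what torch.cdist computes, the sketch below compares it against a hand-written broadcasting version. The tensor sizes and the names fast and manual are assumptions chosen for the illustration.

```python
import torch

torch.manual_seed(0)
a = torch.rand(5, 3)   # 5 points in 3 dimensions
b = torch.rand(7, 3)   # 7 points in 3 dimensions

# Built-in pairwise Euclidean distances, shape (5, 7).
fast = torch.cdist(a, b, p=2.0)

# Manual version: broadcast to (5, 7, 3), then take the norm
# over the coordinate dimension.
manual = (a.unsqueeze(1) - b.unsqueeze(0)).norm(p=2, dim=2)

print(fast.shape, torch.allclose(fast, manual, atol=1e-5))
```

The manual version materializes a (5, 7, 3) intermediate tensor, which is one reason to prefer torch.cdist for larger inputs.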

6 Cosine similarity: F.cosine_similarity

Similar to the previous point, calculating Euclidean distance is not always what you need. When dealing with vectors, cosine similarity is often the metric of choice. PyTorch also has a built-in implementation for cosine similarity.

import torch.nn.functional as F
vector1 = torch.tensor([0.0, 1.0])
vector2 = torch.tensor([0.05, 1.0])
print(F.cosine_similarity(vector1, vector2, dim=0))
vector3 = torch.tensor([0.0, -1.0])
print(F.cosine_similarity(vector1, vector3, dim=0))
tensor(0.9988)
tensor(-1.)

Batch Cosine Similarity in PyTorch

import torch.nn.functional as F
batch_of_vectors = torch.rand((4, 64))
similarity_matrix = F.cosine_similarity(batch_of_vectors.unsqueeze(1), batch_of_vectors.unsqueeze(0), dim=2)
similarity_matrix
tensor([[1.0000, 0.6922, 0.6480, 0.6789],
        [0.6922, 1.0000, 0.7143, 0.7172],
        [0.6480, 0.7143, 1.0000, 0.7312],
        [0.6789, 0.7172, 0.7312, 1.0000]])

7 Normalizing Vectors: F.normalize

This last vector-related point is normalization, which is typically done to improve numerical stability by changing the scale of a vector. The most common normalization is L2, which can be applied in PyTorch as follows:

vector = torch.tensor([99.0, -512.0, 123.0, 0.1, 6.66])
normalized_vector = F.normalize(vector, p=2.0, dim=0)
normalized_vector
tensor([ 1.8476e-01, -9.5552e-01,  2.2955e-01,  1.8662e-04,  1.2429e-02])

The old way of performing normalization in PyTorch was:

vector = torch.tensor([99.0, -512.0, 123.0, 0.1, 6.66])
normalized_vector = vector / torch.norm(vector, p=2.0)
normalized_vector
tensor([ 1.8476e-01, -9.5552e-01,  2.2955e-01,  1.8662e-04,  1.2429e-02])

Batch L2 Normalization in PyTorch

batch_of_vectors = torch.rand((4, 64))
normalized_batch_of_vectors = F.normalize(batch_of_vectors, p=2.0, dim=1)
normalized_batch_of_vectors.shape, torch.norm(normalized_batch_of_vectors, dim=1) # all vectors will have length of 1.0
(torch.Size([4, 64]), tensor([1.0000, 1.0000, 1.0000, 1.0000]))
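
The last two features are closely related: cosine similarity is the dot product of L2-normalized vectors. The sketch below checks that on a random batch; the names unit, via_cosine and via_matmul are made up for the illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_of_vectors = torch.rand(4, 64)

# Pairwise cosine similarity via broadcasting, shape (4, 4).
via_cosine = F.cosine_similarity(
    batch_of_vectors.unsqueeze(1), batch_of_vectors.unsqueeze(0), dim=2
)

# Same matrix: normalize each row to unit length, then multiply.
unit = F.normalize(batch_of_vectors, p=2.0, dim=1)
via_matmul = unit @ unit.T

print(torch.allclose(via_cosine, via_matmul, atol=1e-5))
```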

8 Linear Layer + Chunking Trick (torch.chunk)

This is a creative trick I recently discovered. Suppose you want to map your input to N different linear projections. You can do this by creating N nn.Linear layers. Alternatively, you can create a single linear layer, perform a forward pass, and then split the output into N chunks. This method typically yields higher performance, so it’s a trick worth remembering.

d = 1024
batch = torch.rand((8, d))
layers = nn.Linear(d, 128, bias=False), nn.Linear(d, 128, bias=False), nn.Linear(d, 128, bias=False)
one_layer = nn.Linear(d, 128 * 3, bias=False)

%%timeit
o1 = layers[0](batch)
o2 = layers[1](batch)
o3 = layers[2](batch)

289 µs ± 30.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%%timeit
o1, o2, o3 = torch.chunk(one_layer(batch), 3, dim=1)

202 µs ± 8.09 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
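
The timings above show the speed difference; the sketch below is an assumption-laden check that the two approaches can compute the same thing. It copies the weights of three small layers into one fused layer (the sizes and the name fused are made up) and compares the chunks with the separate outputs.

```python
import torch
from torch import nn

torch.manual_seed(0)
d, out = 16, 4
layers = [nn.Linear(d, out, bias=False) for _ in range(3)]
fused = nn.Linear(d, out * 3, bias=False)

# nn.Linear stores weights as (out_features, in_features), so stacking
# along dim 0 places each projection in its own block of output columns.
with torch.no_grad():
    fused.weight.copy_(torch.cat([layer.weight for layer in layers], dim=0))

batch = torch.rand(8, d)
separate = [layer(batch) for layer in layers]
chunks = torch.chunk(fused(batch), 3, dim=1)

print(all(torch.allclose(s, c, atol=1e-6) for s, c in zip(separate, chunks)))
```

In practice you would train the fused layer directly; copying weights here only serves to show the equivalence.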

9 Masked Select (torch.masked_select)

Sometimes you only need to perform calculations on a portion of the input tensor. For example, you may want to calculate the loss only on elements that meet certain conditions. To do this, you can use torch.masked_select. Note that this operation also works when gradients are needed.

data = torch.rand((3, 3)).requires_grad_()
print(data)
mask = data > data.mean()
print(mask)
torch.masked_select(data, mask)
tensor([[0.0582, 0.7170, 0.7713],
        [0.9458, 0.2597, 0.6711],
        [0.2828, 0.2232, 0.1981]], requires_grad=True)
tensor([[False,  True,  True],
        [ True, False,  True],
        [False, False, False]])
tensor([0.7170, 0.7713, 0.9458, 0.6711], grad_fn=<MaskedSelectBackward>)

Directly Apply Mask on Tensor

Similar behavior can be achieved by using the mask as an “indexer” for the input tensor.

data[mask]
tensor([0.7170, 0.7713, 0.9458, 0.6711], grad_fn=<IndexBackward>)

Sometimes, an ideal solution is to fill all False values in the mask with zeros, which can be done as follows:

data * mask
tensor([[0.0000, 0.7170, 0.7713],
        [0.9458, 0.0000, 0.6711],
        [0.0000, 0.0000, 0.0000]], grad_fn=<MulBackward0>)
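
To tie this back to the loss use case, here is a minimal sketch of a masked squared-error loss. The values of pred, target and mask are invented for the example; the point is that elements excluded by the mask receive a zero gradient.

```python
import torch

# Hypothetical predictions and targets; only the first column counts.
pred = torch.tensor([[0.5, 2.0], [1.5, -1.0]], requires_grad=True)
target = torch.zeros(2, 2)
mask = torch.tensor([[True, False], [True, False]])

squared_error = (pred - target) ** 2
# Average only over the selected elements.
loss = torch.masked_select(squared_error, mask).mean()
loss.backward()

print(loss.item())
print(pred.grad)  # zeros wherever the mask is False
```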

10 Using torch.where to Apply Conditions on Tensors

This function is useful when you want to combine two tensors based on a condition; if the condition is true, take elements from the first tensor, and if false, take from the second tensor.

x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], requires_grad=True)
y = -x
condition_or_mask = x <= 3.0
torch.where(condition_or_mask, x, y)
tensor([ 1.,  2.,  3., -4., -5.], grad_fn=<SWhereBackward>)
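
Another small illustration, under the assumption that you want to zero out negative values: combining torch.where with zeros_like reproduces ReLU, and the gradient only flows through the elements taken from x. The name relu_like is made up for this sketch.

```python
import torch

x = torch.tensor([-2.0, -0.5, 0.0, 1.5, 3.0], requires_grad=True)

# Keep positive entries, replace everything else with zeros
# created on the same device and with the same dtype as x.
relu_like = torch.where(x > 0, x, torch.zeros_like(x))
relu_like.sum().backward()

print(relu_like)
print(x.grad)  # 1 where the value came from x, 0 elsewhere
```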

11 Filling Tensor Values at Given Positions (Tensor.scatter)

The use case for this function is as follows: you want to fill a tensor with values from another tensor at given positions. One-dimensional tensors are easier to understand, so I will show that first, then move to more advanced examples.

data = torch.tensor([1, 2, 3, 4, 5])
index = torch.tensor([0, 1])
values = torch.tensor([-1, -2, -3, -4, -5])
data.scatter(0, index, values)
tensor([-1, -2,  3,  4,  5])

The above example is simple, but now let’s see what happens if we change the index to index = torch.tensor([0, 1, 4]):

data = torch.tensor([1, 2, 3, 4, 5])
index = torch.tensor([0, 1, 4])
values = torch.tensor([-1, -2, -3, -4, -5])
data.scatter(0, index, values)
tensor([-1, -2,  3,  4, -3])

Why is the last value -3? This is counterintuitive, right? This is the core idea of the PyTorch scatter function. The index variable indicates the position in the data tensor at which the i-th value of the values tensor should be placed. I hope the following simple Python version of this operation helps clarify:

data_orig = torch.tensor([1, 2, 3, 4, 5])
index = torch.tensor([0, 1, 4])
values = torch.tensor([-1, -2, -3, -4, -5])
scattered = data_orig.scatter(0, index, values)

data = data_orig.clone()
for idx_in_values, where_to_put_the_value in enumerate(index):
    what_value_to_put = values[idx_in_values]
    data[where_to_put_the_value] = what_value_to_put
data, scattered
(tensor([-1, -2,  3,  4, -3]), tensor([-1, -2,  3,  4, -3]))

PyTorch Scatter Example for 2D Data

Always remember that the shape of the index relates to the shape of the values, and the values in the index correspond to positions in the data.

data = torch.zeros((4, 4)).float()
index = torch.tensor([
    [0, 1],
    [2, 3],
    [0, 3],
    [1, 2]
])
values = torch.arange(1, 9).float().view(4, 2)
values, data.scatter(1, index, values)
(tensor([[1., 2.],
        [3., 4.],
        [5., 6.],
        [7., 8.]]),
 tensor([[1., 2., 0., 0.],
        [0., 0., 3., 4.],
        [5., 0., 0., 6.],
        [0., 7., 8., 0.]]))
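
A common practical use of the 2D case is one-hot encoding. The sketch below is one possible way to do it with scatter, compared against F.one_hot; the labels and the number of classes are assumptions made up for the example.

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([2, 0, 1, 2])  # hypothetical class ids
num_classes = 3

# Each row gets a single 1.0 written at the column given by its label.
one_hot = torch.zeros(labels.size(0), num_classes).scatter(
    1, labels.unsqueeze(1), 1.0
)

print(one_hot)
print(torch.equal(one_hot, F.one_hot(labels, num_classes).float()))
```

Here the index has shape (4, 1), so exactly one position per row of the data tensor is written.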

12 Performing Image Interpolation in Networks (F.interpolate)

When I learned PyTorch, I was surprised to find that it is indeed possible to resize images (or any intermediate tensor) during the forward pass while maintaining the gradient flow. This method is particularly useful when using CNNs and GANs.

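from PIL import Image
import torch.nn.functional as F
from torchvision.transforms.functional import to_tensor, to_pil_image

# image from https://commons.wikimedia.org/wiki/File:A_female_British_Shorthair_at_the_age_of_20_months.jpg
img = Image.open("./cat.jpg")
img

# upscale the image 2x using bilinear interpolation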
to_pil_image(
    F.interpolate(to_tensor(img).unsqueeze(0),  # batch of size 1
                  mode="bilinear", 
                  scale_factor=2.0, 
                  align_corners=False).squeeze(0) # remove batch dimension
)

Check how the gradient flow is preserved:

F.interpolate(to_tensor(img).unsqueeze(0).requires_grad_(),
                  mode="bicubic", 
                  scale_factor=2.0, 
                  align_corners=False)
tensor([[[[0.9216, 0.9216, 0.9216,  ..., 0.8361, 0.8272, 0.8219],
    [0.9214, 0.9214, 0.9214,  ..., 0.8361, 0.8272, 0.8219],
    [0.9212, 0.9212, 0.9212,  ..., 0.8361, 0.8272, 0.8219],
    ...,
    [0.9098, 0.9098, 0.9098,  ..., 0.3592, 0.3486, 0.3421],
    [0.9098, 0.9098, 0.9098,  ..., 0.3566, 0.3463, 0.3400],
    [0.9098, 0.9098, 0.9098,  ..., 0.3550, 0.3449, 0.3387]],

    [[0.6627, 0.6627, 0.6627,  ..., 0.5380, 0.5292, 0.5238],
    [0.6626, 0.6626, 0.6626,  ..., 0.5380, 0.5292, 0.5238],
    [0.6623, 0.6623, 0.6623,  ..., 0.5380, 0.5292, 0.5238],
    ...,
    [0.6196, 0.6196, 0.6196,  ..., 0.3631, 0.3525, 0.3461],
    [0.6196, 0.6196, 0.6196,  ..., 0.3605, 0.3502, 0.3439],
    [0.6196, 0.6196, 0.6196,  ..., 0.3589, 0.3488, 0.3426]],

    [[0.4353, 0.4353, 0.4353,  ..., 0.1913, 0.1835, 0.1787],
    [0.4352, 0.4352, 0.4352,  ..., 0.1913, 0.1835, 0.1787],
    [0.4349, 0.4349, 0.4349,  ..., 0.1913, 0.1835, 0.1787],
    ...,
    [0.3333, 0.3333, 0.3333,  ..., 0.3827, 0.3721, 0.3657],
    [0.3333, 0.3333, 0.3333,  ..., 0.3801, 0.3698, 0.3635],
    [0.3333, 0.3333, 0.3333,  ..., 0.3785, 0.3684, 0.3622]]]], grad_fn=<UpsampleBicubic2DBackward1>)
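
If you do not have an image file at hand, the same behavior can be checked on a random tensor. This is a minimal sketch with an assumed 8x8 three-channel input; it upsamples by a factor of 2 and then backpropagates through the result.

```python
import torch
import torch.nn.functional as F

# Hypothetical "image" batch: 1 sample, 3 channels, 8x8 pixels.
x = torch.rand(1, 3, 8, 8, requires_grad=True)

y = F.interpolate(x, scale_factor=2.0, mode="bilinear", align_corners=False)

# Backpropagate through the resize; x.grad has the input's shape.
y.sum().backward()

print(y.shape, x.grad.shape)
```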

13 Creating Image Grids (torchvision.utils.make_grid)

When using PyTorch and torchvision, there’s no need to copy and paste matplotlib (or other external library) code to display image grids. Just use torchvision.utils.make_grid.

from torchvision.utils import make_grid
from torchvision.transforms.functional import to_tensor, to_pil_image
from PIL import Image
img = Image.open("./cat.jpg")
to_pil_image(
    make_grid(
        [to_tensor(i) for i in [img, img, img]],
         nrow=2, # number of images in a single row
         padding=5 # "frame" size
     )
)

Original link: https://zablo.net/blog/post/pytorch-13-features-you-should-know/
