Unet++ Implementation in PyTorch

Unet++ Network

Dense Connection

Unet++ inherits the structure of Unet and borrows the dense connections of DenseNet (the many branches in Figure 1).

Like DenseNet, the authors link the layers with dense connections. Each node on a skip pathway receives the outputs of all earlier nodes at the same level, together with the upsampled output of the node one level below, so every module can "see" the others. This feature sharing improves the segmentation results.
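
The wiring described above can be sketched in plain Python. This minimal sketch assumes the common UNet++ node notation X[i,j]: i is the downsampling level (0 = full resolution) and j is the position along the skip pathway. The helper name `node_inputs` and the string labels are hypothetical. They only list which feature maps are concatenated at each node.

```python
def node_inputs(i, j):
    """List the feature maps concatenated at node X[i,j] (hypothetical helper).

    i: downsampling level (0 = full resolution)
    j: position along the skip pathway (0 = encoder/backbone node)
    """
    if j == 0:
        # Encoder nodes: the input image at level 0, otherwise the pooled node above.
        return ["input"] if i == 0 else [f"down(X[{i - 1},0])"]
    # Dense skip connection: every earlier node on the same level ...
    same_level = [f"X[{i},{k}]" for k in range(j)]
    # ... plus the upsampled output of the node one level below.
    return same_level + [f"up(X[{i + 1},{j - 1}])"]


if __name__ == "__main__":
    depth = 4  # four downsampling steps, matching the depth-1..4 sub-networks
    for i in range(depth + 1):
        for j in range(depth + 1 - i):
            print(f"X[{i},{j}] <- {', '.join(node_inputs(i, j))}")
```

For instance, X[0,2] concatenates X[0,0], X[0,1] and the upsampled X[1,1]. The input to a node therefore grows wider as j increases along a pathway.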

In segmentation, repeated downsampling inevitably loses some fine detail. Unet recovers these details with skip connections, but can this be done better? Unet++ offers an answer. Its dense connections try to preserve as much detail and global information as possible at every level. They build bridges between layers and finally pass everything to the last layer, so that both global and local information is retained and reconstructed.

Deep Supervision

Densely connecting the modules already gives good results. Notice also that a Unet++ is really several Unets of different depths nested together. Can each depth then produce its own loss? It can.

The authors therefore propose deep supervision: the output of each depth is supervised, and the individual losses are combined in some way (for example, by weighting). The result is a weighted loss over the sub-networks of depths 1, 2, 3, and 4 (Figure 2 shows how the sub-networks of different depths are fused).
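
The weighted loss can be sketched in PyTorch. This minimal sketch assumes the model returns one logits tensor per supervised depth, ordered shallowest first (in the usual UNet++ node labels, the outputs of X[0,1] to X[0,4]). `deep_supervision_loss` is a hypothetical helper. Cross-entropy and the equal default weights are stand-ins for whatever loss and weighting you actually train with.

```python
import torch
import torch.nn.functional as F


def deep_supervision_loss(outputs, target, weights=None):
    """Weighted sum of per-depth losses (hypothetical helper).

    outputs: list of logits, each [N, C, H, W], shallowest depth first
    target:  class-index map [N, H, W] (int64)
    weights: one float per output; equal weights if None (an assumption)
    """
    if weights is None:
        weights = [1.0 / len(outputs)] * len(outputs)
    if len(weights) != len(outputs):
        raise ValueError("expected one weight per output")
    # Cross-entropy is a stand-in; swap in the loss you actually train with.
    losses = [F.cross_entropy(logits, target) for logits in outputs]
    total = sum(w * loss for w, loss in zip(weights, losses))
    return total, losses


if __name__ == "__main__":
    torch.manual_seed(0)
    outputs = [torch.randn(2, 12, 32, 32, requires_grad=True) for _ in range(4)]
    target = torch.randint(0, 12, (2, 32, 32))
    total, per_depth = deep_supervision_loss(outputs, target, weights=[0.1, 0.2, 0.3, 0.4])
    total.backward()  # gradients flow back into every depth's output
    print(total.item(), [round(l.item(), 4) for l in per_depth])
```

Returning the per-depth losses as well makes it easy to log them separately. The logs show how close the shallower sub-networks come to the deepest one.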

So what is deep supervision good for? Pruning.

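Unet++ is built from nested sub-networks of different depths, and the output at a given depth never depends on nodes deeper than that depth. Removing the deeper sub-networks therefore does not change the forward pass of the shallower ones. If you find that the output of the depth-3 sub-network is about as good as that of the depth-4 one, you can safely remove the depth-4 part at inference time. For example, cutting away the brown part in Figure 3 prunes the network and leaves a more lightweight model. The pruned layers are still useful during training, because their gradients flow back into the shared shallower layers.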
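
The claim that deeper nodes never affect shallower outputs can be checked with a small dependency walk over the node grid. This plain-Python sketch assumes the usual UNet++ wiring: with nodes labelled X[i,j] (i = level, j = column), a node X[i,j] with j > 0 takes X[i,0..j-1] plus the upsampled X[i+1,j-1]. `ancestors` is a hypothetical helper name. The walk shows that the depth-L output X[0,L] depends on exactly the nodes with i + j <= L, which are the nodes kept when pruning to depth L.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def ancestors(i, j):
    """Return every node (level, column) that X[i,j] depends on, itself included."""
    deps = {(i, j)}
    if j == 0:
        if i > 0:
            deps |= ancestors(i - 1, 0)  # encoder: pooled from the node above
        return frozenset(deps)
    for k in range(j):
        deps |= ancestors(i, k)  # dense skip inputs on the same level
    deps |= ancestors(i + 1, j - 1)  # upsampled node from the level below
    return frozenset(deps)


if __name__ == "__main__":
    for depth in range(1, 5):
        kept = sorted(ancestors(0, depth))
        print(f"depth {depth}: output X[0,{depth}] needs {len(kept)} nodes: {kept}")
```

Pruning to depth 3 keeps 10 of the 15 nodes. Node X[4,0], the bottom of the encoder, is among the nodes removed.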

Model Reproduction

Unet++

To make the code easier to follow, I've named the variables in the code after the corresponding nodes in the network diagram.
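
The structure can be sketched as a compact PyTorch module. This is a minimal sketch, not the author's code. The class names `ConvBlock` and `UNetPlusPlus`, the channel widths and the conv-BatchNorm-ReLU block are all assumptions. The sketch uses bilinear upsampling, with transposed convolution as a common alternative. Node X[i,j] (level i, column j) is stored under the key `x{i}_{j}`. Input height and width must be divisible by 16 (four poolings) so that upsampled maps line up with their skip inputs.

```python
import torch
import torch.nn as nn


class ConvBlock(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) layers, applied at every node."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.body(x)


class UNetPlusPlus(nn.Module):
    """Nested Unet sketch; node X[i,j] lives under the key "x{i}_{j}"."""

    def __init__(self, in_ch, num_classes, widths=(32, 64, 128, 256, 512),
                 deep_supervision=False):
        super().__init__()
        self.depth = len(widths) - 1  # number of downsampling steps
        self.deep_supervision = deep_supervision
        self.pool = nn.MaxPool2d(2)
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.nodes = nn.ModuleDict()
        for i in range(self.depth + 1):
            for j in range(self.depth + 1 - i):
                if j == 0:
                    # Encoder node: raw input at level 0, pooled X[i-1,0] otherwise.
                    cin = in_ch if i == 0 else widths[i - 1]
                else:
                    # j same-level nodes X[i,0..j-1] plus the upsampled X[i+1,j-1].
                    cin = j * widths[i] + widths[i + 1]
                self.nodes[f"x{i}_{j}"] = ConvBlock(cin, widths[i])
        # 1x1 heads: one per depth with deep supervision, else only on X[0,depth].
        n_heads = self.depth if deep_supervision else 1
        self.heads = nn.ModuleList(
            nn.Conv2d(widths[0], num_classes, kernel_size=1) for _ in range(n_heads)
        )

    def forward(self, x):
        feats = {}
        # Column j = 0: the encoder (backbone) path.
        for i in range(self.depth + 1):
            inp = x if i == 0 else self.pool(feats[(i - 1, 0)])
            feats[(i, 0)] = self.nodes[f"x{i}_0"](inp)
        # Columns j >= 1: each node needs X[i+1,j-1], computed in the previous column.
        for j in range(1, self.depth + 1):
            for i in range(self.depth + 1 - j):
                skips = [feats[(i, k)] for k in range(j)]
                below = self.up(feats[(i + 1, j - 1)])
                feats[(i, j)] = self.nodes[f"x{i}_{j}"](torch.cat(skips + [below], dim=1))
        if self.deep_supervision:
            # Logits from X[0,1] ... X[0,depth], shallowest first.
            return [head(feats[(0, j)]) for j, head in enumerate(self.heads, start=1)]
        return self.heads[0](feats[(0, self.depth)])


if __name__ == "__main__":
    # num_classes=12 is only an illustrative value; use your dataset's class count.
    model = UNetPlusPlus(in_ch=3, num_classes=12, deep_supervision=True)
    outs = model(torch.randn(1, 3, 96, 128))  # H and W divisible by 16
    print([tuple(o.shape) for o in outs])
```

With `deep_supervision=True` the model returns one output per depth, ready for a weighted loss. With `False` only the deepest output X[0,4] gets a head. Computing the grid column by column guarantees that every node's inputs already exist when it runs.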

Dataset Preparation

The model is trained on the CamVid dataset. For how to build and load it, see the post "CamVid dataset creation and usage (PyTorch)":

https://blog.csdn.net/yumaomi/article/details/124786867
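
As a stand-in for the linked construction method, a paired image/label `Dataset` can be sketched as follows. This minimal sketch assumes a hypothetical layout. `img_dir` contains only RGB images, and each has a label PNG with the same filename in `label_dir`, whose pixels already store class indices. CamVid releases differ in how labels are stored (some use color-coded masks), so adapt the loading step to the version you built. The class name `PairedSegDataset` and the optional `size` parameter are assumptions. Unet-style models with four poolings usually need height and width divisible by 16, hence the optional resize. Labels use nearest-neighbour interpolation so no new class values are invented.

```python
import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset


class PairedSegDataset(Dataset):
    """Hypothetical loader: img_dir/<name> pairs with label_dir/<name> (class-index PNG)."""

    def __init__(self, img_dir, label_dir, size=None):
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.size = size  # optional (height, width); None keeps the original size
        self.names = sorted(os.listdir(img_dir))

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        name = self.names[idx]
        img = Image.open(os.path.join(self.img_dir, name)).convert("RGB")
        label = Image.open(os.path.join(self.label_dir, name))
        if self.size is not None:
            h, w = self.size
            img = img.resize((w, h), Image.BILINEAR)  # PIL expects (width, height)
            label = label.resize((w, h), Image.NEAREST)  # nearest keeps valid class ids
        # HWC uint8 -> CHW float32 in [0, 1]
        x = torch.from_numpy(np.asarray(img, dtype=np.float32) / 255.0).permute(2, 0, 1)
        # HW class indices as int64, the dtype cross-entropy expects for targets
        y = torch.from_numpy(np.asarray(label, dtype=np.int64))
        return x, y
```

Wrapped in a `torch.utils.data.DataLoader`, this yields image batches of shape [N, 3, H, W] and target batches of shape [N, H, W].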

Training Results

Original article:

https://blog.csdn.net/yumaomi/article/details/124823392