Introduction to Image Segmentation Loss Functions in TensorFlow

Commonly used image segmentation loss functions include binary cross-entropy, Dice loss, Tversky loss, and focal loss. In this post, I reproduce these loss functions in TensorFlow and compare their results.

1. Cross Entropy

The cross-entropy loss function compares the predicted probability with the target label pixel by pixel, and then averages the result over all pixels. The formula is as follows, where p is the true class value (0 or 1) and p' is the predicted probability of belonging to class 1.

$$L_{BCE} = -\frac{1}{N}\sum_{i=1}^{N}\left[p_i \log p'_i + (1 - p_i)\log(1 - p'_i)\right]$$
Here N is the number of pixels and i indexes them.

Because every pixel (and therefore each class) is weighted equally, the loss is dominated by the majority class, usually the background, when the classes are imbalanced.
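
To make the imbalance point concrete, here is a small pure-Python sketch of the formula above. It is not the author's TensorFlow code. The 100-pixel mask and both predictions are hypothetical. A prediction that misses every foreground pixel still gets a low average cross-entropy, because the 98 well-predicted background pixels dominate the mean.

```python
import math

def bce(p_true, p_pred, eps=1e-5):
    """Mean binary cross-entropy over a flattened mask (pure-Python sketch)."""
    total = 0.0
    for p, q in zip(p_true, p_pred):
        q = min(max(q, eps), 1.0 - eps)  # clip like the TensorFlow version
        total += -(p * math.log(q) + (1 - p) * math.log(1 - q))
    return total / len(p_true)

# Hypothetical 100-pixel mask: 2 foreground pixels, 98 background pixels
gt = [1, 1] + [0] * 98

# A degenerate prediction that calls every pixel background (p' = 0.01)
all_background = [0.01] * 100
# An uninformative prediction (p' = 0.5 everywhere) for comparison
uninformative = [0.5] * 100

print(bce(gt, all_background))  # about 0.10, although both foreground pixels are missed
print(bce(gt, uninformative))   # ln 2, about 0.69
```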

The reproduction code is as follows:

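```python
import tensorflow as tf

def binary_crossentropy(Y_pred, Y_gt):
    # Y_pred: predicted foreground probabilities in [0, 1]; Y_gt: binary ground-truth mask
    epsilon = 1.e-5
    # Clip to avoid log(0) and division by zero below
    Y_pred = tf.clip_by_value(Y_pred, epsilon, 1. - epsilon)
    # Convert probabilities back to logits (log-odds), since the op below expects logits
    logits = tf.math.log(Y_pred / (1 - Y_pred))
    # labels must have the same dtype as logits
    Y_gt = tf.cast(Y_gt, Y_pred.dtype)
    loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=Y_gt, logits=logits)
    # Average over all pixels and the batch
    loss = tf.reduce_mean(loss)
    return loss
```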
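
The function above converts p' back to log-odds only because `tf.nn.sigmoid_cross_entropy_with_logits` expects logits. Since sigmoid(log(p'/(1-p'))) = p', the result is still the plain cross-entropy formula. The sketch below checks this equivalence in pure Python. It assumes the numerically stable expression given in the TensorFlow documentation for that op, `max(x, 0) - x * z + log(1 + exp(-|x|))`.

```python
import math

EPS = 1e-5

def sigmoid_ce_with_logits(z, x):
    """Stable logistic loss for label z and logit x, using the formula given in
    the tf.nn.sigmoid_cross_entropy_with_logits documentation."""
    return max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x)))

def bce_via_logits(p, q):
    """Mirrors binary_crossentropy for one pixel: clip, convert to log-odds, apply the logits loss."""
    q = min(max(q, EPS), 1.0 - EPS)
    logit = math.log(q / (1.0 - q))
    return sigmoid_ce_with_logits(p, logit)

def bce_direct(p, q):
    """The cross-entropy formula applied directly to the clipped probability."""
    q = min(max(q, EPS), 1.0 - EPS)
    return -(p * math.log(q) + (1 - p) * math.log(1 - q))

# (label, predicted probability) pairs, including one that needs clipping
cases = [(1, 0.9), (0, 0.9), (1, 0.2), (0, 1e-9)]
for p, q in cases:
    print(p, q, bce_via_logits(p, q), bce_direct(p, q))
```

Both columns agree to floating-point precision. The clipping step matters: without it, p' = 0 or 1 would make the log-odds infinite.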

2. Dice Loss

Dice loss is used in the V-Net model. There, the regions of interest typically occupy a relatively small area, and Dice loss effectively increases the weight of the foreground region, which reduces the impact of class imbalance. The formula is as follows, where TP, FP, and FN are the numbers of true positives, false positives, and false negatives, respectively.

$$\text{Dice} = \frac{2\,TP}{2\,TP + FP + FN}$$

In some papers, the Dice coefficient is instead written as follows, where p is the true class value (0 or 1) and p' is the predicted probability (between 0 and 1).

$$\text{Dice} = \frac{2\sum_i p_i\, p'_i}{\sum_i p_i + \sum_i p'_i}$$
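
The two forms agree whenever p' is a hard 0/1 prediction. The sum of p·p' counts the true positives, and the sum of p plus the sum of p' equals 2TP + FP + FN. The soft form also accepts probabilities, which keeps the loss differentiable. A minimal pure-Python check follows, using a hypothetical 8-pixel mask:

```python
def dice_from_counts(tp, fp, fn):
    """Dice = 2TP / (2TP + FP + FN)."""
    return 2 * tp / (2 * tp + fp + fn)

def soft_dice(gt, pred):
    """Dice = 2 * sum(p * p') / (sum(p) + sum(p')), with p' allowed to be a probability."""
    inter = sum(p * q for p, q in zip(gt, pred))
    return 2 * inter / (sum(gt) + sum(pred))

gt = [1, 1, 1, 0, 0, 0, 0, 0]
hard_pred = [1, 1, 0, 1, 0, 0, 0, 0]  # hypothetical thresholded prediction

tp = sum(1 for p, q in zip(gt, hard_pred) if p == 1 and q == 1)
fp = sum(1 for p, q in zip(gt, hard_pred) if p == 0 and q == 1)
fn = sum(1 for p, q in zip(gt, hard_pred) if p == 1 and q == 0)
print(tp, fp, fn)                                              # 2 1 1
print(dice_from_counts(tp, fp, fn), soft_dice(gt, hard_pred))  # both 2/3

# The soft form also works on probabilities (hypothetical values)
soft_pred = [0.9, 0.8, 0.3, 0.6, 0.1, 0.0, 0.0, 0.0]
print(soft_dice(gt, soft_pred))
```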

The reproduction code is as follows:

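```python
import tensorflow as tf

def binary_dice(Y_pred, Y_gt):
    # Expects 5-D tensors [batch, Z, H, W, C] (3-D volumes)
    smooth = 1.e-5
    smooth_tf = tf.constant(smooth, tf.float32)
    pred_flat = tf.cast(Y_pred, tf.float32)
    true_flat = tf.cast(Y_gt, tf.float32)
    Z, H, W, C = Y_gt.get_shape().as_list()[1:]
    # Flatten each sample so the sums below are computed per sample
    pred_flat = tf.reshape(pred_flat, [-1, H * W * C * Z])
    true_flat = tf.reshape(true_flat, [-1, H * W * C * Z])
    # Numerator 2 * sum(p * p') and denominator sum(p') + sum(p), both smoothed
    intersection = 2 * tf.reduce_sum(pred_flat * true_flat, axis=-1) + smooth_tf
    denominator = tf.reduce_sum(pred_flat, axis=-1) + tf.reduce_sum(true_flat, axis=-1) + smooth_tf
    # Negative mean Dice over the batch: minimizing it maximizes Dice
    loss = -tf.reduce_mean(intersection / denominator)
    return loss
```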
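
Note the sign convention in `binary_dice`: it returns the negative mean Dice. The loss therefore runs from about 0 (no overlap) down to -1 (perfect overlap). It differs from 1 - Dice only by a constant, so the gradients are the same. The `smooth` term also decides what happens when a sample has an empty ground truth and an empty prediction. Instead of 0/0, the Dice becomes smooth/smooth = 1, a perfect score.

The pure-Python mirror below is a sketch for inspection only. It works on already-flattened, hypothetical masks rather than 5-D tensors.

```python
SMOOTH = 1e-5

def binary_dice_loss(batch_pred, batch_gt, smooth=SMOOTH):
    """Pure-Python mirror of binary_dice: smoothed Dice per flattened sample,
    then the negative mean over the batch."""
    scores = []
    for pred, gt in zip(batch_pred, batch_gt):
        numerator = 2 * sum(p * g for p, g in zip(pred, gt)) + smooth
        denominator = sum(pred) + sum(gt) + smooth
        scores.append(numerator / denominator)
    return -sum(scores) / len(scores)

mask = [1.0, 1.0, 0.0, 0.0]
empty = [0.0, 0.0, 0.0, 0.0]
shifted = [0.0, 0.0, 1.0, 1.0]  # no overlap with mask

print(binary_dice_loss([mask], [mask]))     # perfect match: -1.0
print(binary_dice_loss([empty], [empty]))   # both empty: smooth / smooth, so also -1.0
print(binary_dice_loss([shifted], [mask]))  # no overlap: close to 0
```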

3. Tversky Loss

Tversky loss is a generalization of Dice loss that adds weight factors to the false positive and false negative terms: β weights the false positives and 1 - β weights the false negatives. The formula is as follows, where p is the true class value (0 or 1) and p' is the predicted probability (between 0 and 1). When β = 0.5, the Tversky index reduces to the Dice coefficient, so Tversky loss becomes Dice loss.

$$T = \frac{\sum_i p_i\, p'_i}{\sum_i p_i\, p'_i + \beta\sum_i (1 - p_i)\, p'_i + (1 - \beta)\sum_i p_i\,(1 - p'_i)}$$

The reproduction code is as follows:

```python
import tensorflow as tf

def binary_tversky(Y_pred, Y_gt, beta):
    # Expects 5-D tensors [batch, Z, H, W, C]; beta weights FP, (1 - beta) weights FN
    smooth = 1.e-5
    smooth_tf = tf.constant(smooth, tf.float32)
    pred_flat = tf.cast(Y_pred, tf.float32)
    true_flat = tf.cast(Y_gt, tf.float32)
    Z, H, W, C = Y_gt.get_shape().as_list()[1:]
    pred_flat = tf.reshape(pred_flat, [-1, H * W * C * Z])
    true_flat = tf.reshape(true_flat, [-1, H * W * C * Z])
    # Soft true positives: sum(p * p')
    intersection = tf.reduce_sum(pred_flat * true_flat, axis=-1)
    # TP + beta * FP + (1 - beta) * FN, with FP = sum(p' * (1 - p)) and FN = sum(p * (1 - p'))
    denominator = intersection + tf.reduce_sum(beta * pred_flat * (1 - true_flat), axis=-1) + tf.reduce_sum(
        (1 - beta) * true_flat * (1 - pred_flat), axis=-1)
    # Negative mean Tversky index over the batch
    loss = -tf.reduce_mean((intersection + smooth_tf) / (denominator + smooth_tf))
    return loss
```
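
The effect of `beta` can be seen on two hypothetical predictions: one over-segments (2 false positives) and one under-segments (2 false negatives). The pure-Python sketch below computes the Tversky index as in `binary_tversky`, but drops the smoothing term so that the β = 0.5 case matches Dice exactly.

```python
def tversky_index(gt, pred, beta):
    """Tversky index as in binary_tversky (smoothing term omitted for clarity):
    beta weights false positives, (1 - beta) weights false negatives."""
    tp = sum(g * p for g, p in zip(gt, pred))
    fp = sum((1 - g) * p for g, p in zip(gt, pred))
    fn = sum(g * (1 - p) for g, p in zip(gt, pred))
    return tp / (tp + beta * fp + (1 - beta) * fn)

def dice(gt, pred):
    return 2 * sum(g * p for g, p in zip(gt, pred)) / (sum(gt) + sum(pred))

gt    = [1, 1, 1, 1, 0, 0, 0, 0]
over  = [1, 1, 1, 1, 1, 1, 0, 0]  # hypothetical over-segmentation: 2 false positives
under = [1, 1, 0, 0, 0, 0, 0, 0]  # hypothetical under-segmentation: 2 false negatives

for beta in (0.3, 0.5, 0.7):
    print(beta, tversky_index(gt, over, beta), tversky_index(gt, under, beta))
print("dice", dice(gt, over), dice(gt, under))
```

A larger β lowers the score of the over-segmentation, so the loss penalizes false positives more. A smaller β penalizes false negatives more, which favors recall on small structures.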

4. Focal Loss

Focal loss is an improvement on cross-entropy that reduces the loss weight of easy samples, so that the network focuses more on hard samples. The formula is as follows, where p is the true class value and p' is the predicted probability of belonging to class 1. α balances the two classes, and γ controls how strongly easy samples are down-weighted.

$$L_{FL} = -\sum_i\left[\alpha\, p_i\,(1 - p'_i)^{\gamma}\log p'_i + (1 - \alpha)(1 - p_i)\,{p'_i}^{\gamma}\log(1 - p'_i)\right]$$

The reproduction code is as follows:

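```python
import tensorflow as tf

def binary_focalloss(Y_pred, Y_gt, alpha=0.25, gamma=2.):
    epsilon = 1.e-5
    # pt_1: predictions at foreground pixels (1 elsewhere, so those terms are negligible after clipping)
    pt_1 = tf.where(tf.equal(Y_gt, 1), Y_pred, tf.ones_like(Y_pred))
    # pt_0: predictions at background pixels (0 elsewhere, so those terms are negligible after clipping)
    pt_0 = tf.where(tf.equal(Y_gt, 0), Y_pred, tf.zeros_like(Y_pred))
    # clip to prevent NaN's and Inf's
    pt_1 = tf.clip_by_value(pt_1, epsilon, 1. - epsilon)
    pt_0 = tf.clip_by_value(pt_0, epsilon, 1. - epsilon)
    loss_1 = alpha * tf.pow(1. - pt_1, gamma) * tf.math.log(pt_1)
    loss_0 = (1 - alpha) * tf.pow(pt_0, gamma) * tf.math.log(1. - pt_0)
    # Summed (not averaged) over all pixels and the batch
    loss = -tf.reduce_sum(loss_1 + loss_0)
    return loss
```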
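
The down-weighting can be checked per pixel. For a foreground pixel, focal loss equals cross-entropy multiplied by α(1 - p')^γ. A confidently correct pixel is therefore scaled down far more than a poorly predicted one. The two pixel probabilities below are hypothetical, and the defaults α = 0.25 and γ = 2 are those of `binary_focalloss`.

```python
import math

def ce(p, q):
    """Per-pixel binary cross-entropy."""
    return -(p * math.log(q) + (1 - p) * math.log(1 - q))

def focal(p, q, alpha=0.25, gamma=2.0):
    """Per-pixel focal loss, matching the two branches of binary_focalloss."""
    if p == 1:
        return -alpha * (1 - q) ** gamma * math.log(q)
    return -(1 - alpha) * q ** gamma * math.log(1 - q)

# Hypothetical foreground pixels: one easy (confident, correct), one hard
easy_q, hard_q = 0.95, 0.30
for name, q in [("easy", easy_q), ("hard", hard_q)]:
    print(name, ce(1, q), focal(1, q), focal(1, q) / ce(1, q))
```

Here the easy pixel's loss is scaled by 1/1600 and the hard pixel's by roughly 1/8, so hard pixels dominate the total. With γ = 0, focal loss reduces to an α-weighted cross-entropy.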

5. Cross Entropy + Dice Loss

Some papers combine different loss functions to train networks. For example, the paper "AnatomyNet: Deep Learning for Fast and Fully Automated Whole-volume Segmentation of Head and Neck Anatomy" from Tencent Medical AI Lab proposed combining Dice loss and focal loss to address the segmentation of small organs. Here, I reproduce a combined cross-entropy + Dice loss, as follows:

```python
import tensorflow as tf

def binary_dicePcrossentroy(Y_pred, Y_gt):
    # step 1, calculate binary cross-entropy
    epsilon = 1.e-5
    Y_pred = tf.clip_by_value(Y_pred, epsilon, 1. - epsilon)
    logits = tf.math.log(Y_pred / (1 - Y_pred))
    loss1 = tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.cast(Y_gt, Y_pred.dtype), logits=logits)
    loss1 = tf.reduce_mean(loss1)
    # step 2, calculate binary dice (expects 5-D tensors [batch, Z, H, W, C])
    smooth_tf = tf.constant(epsilon, tf.float32)
    pred_flat = tf.cast(Y_pred, tf.float32)
    true_flat = tf.cast(Y_gt, tf.float32)
    Z, H, W, C = Y_gt.get_shape().as_list()[1:]
    pred_flat = tf.reshape(pred_flat, [-1, H * W * C * Z])
    true_flat = tf.reshape(true_flat, [-1, H * W * C * Z])
    intersection = 2 * tf.reduce_sum(pred_flat * true_flat, axis=-1) + smooth_tf
    denominator = tf.reduce_sum(pred_flat, axis=-1) + tf.reduce_sum(true_flat, axis=-1) + smooth_tf
    loss2 = tf.reduce_mean(1 - intersection / denominator)  # mean of 1 - Dice, in [0, 1]
    # step 3, calculate all loss: cross-entropy + log(1 + (1 - Dice))
    loss = loss1 + tf.math.log1p(loss2)
    return loss
```
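
In step 3, the Dice part enters as log1p(1 - Dice). Because 1 - Dice lies in [0, 1], this term stays between 0 and ln 2 (about 0.693), whereas the cross-entropy term is unbounded. The sketch below mirrors the three steps for a single flattened sample in plain Python so that the two terms can be inspected separately. The sample values are hypothetical, and the TensorFlow version additionally averages the Dice term over the batch.

```python
import math

EPS = 1e-5

def dice_plus_ce(pred, gt, eps=EPS):
    """Pure-Python mirror of binary_dicePcrossentroy for a single flattened sample.
    Returns (cross-entropy term, log1p(1 - Dice) term, total)."""
    pred = [min(max(q, eps), 1 - eps) for q in pred]  # same clipping as step 1
    ce = sum(-(g * math.log(q) + (1 - g) * math.log(1 - q))
             for q, g in zip(pred, gt)) / len(gt)
    numerator = 2 * sum(q * g for q, g in zip(pred, gt)) + eps
    denominator = sum(pred) + sum(gt) + eps
    dice_term = math.log1p(1 - numerator / denominator)
    return ce, dice_term, ce + dice_term

gt = [1.0, 1.0, 0.0, 0.0]
good = [0.9, 0.8, 0.1, 0.2]  # hypothetical mostly-correct prediction
bad = [0.0, 0.0, 1.0, 1.0]   # hypothetical completely wrong prediction

print(dice_plus_ce(good, gt))
print(dice_plus_ce(bad, gt))
```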

I trained the network with each of the five loss functions above, then ran prediction on 10 test samples and computed the Dice value for each. The results are as follows.

[Figure: Dice values on the 10 test samples for the five loss functions]
To help everyone learn, I have shared the complete project code on GitHub:
https://github.com/junqiangchen/Image-Segmentation-Loss-Functions
If you find this project useful, I hope you will give it a Star and Fork so that more people can learn from it. If you run into any problems, feel free to leave a message and I will try to answer.
