Commonly used image segmentation loss functions include binary cross-entropy, Dice loss, Tversky loss, and focal loss. In this post I reproduce these loss functions in TensorFlow and compare the results.
1. Cross Entropy
The cross-entropy loss function compares the predicted class probability with the target value on a pixel-by-pixel basis, and then averages the values over all pixels. The formula is as follows, where p is the true class value and p' is the predicted probability of belonging to class 1:

$$L_{CE} = -\left[\, p \log p' + (1 - p) \log(1 - p') \,\right]$$
Because this function weights every pixel equally regardless of its class, it is susceptible to class imbalance: when the foreground occupies only a small fraction of the image, the background term dominates the loss.
The reproduction code is as follows:
import tensorflow as tf  # TensorFlow 1.x (tf.log and tf.Session are pre-2.x APIs)

def binary_crossentropy(Y_pred, Y_gt):
    epsilon = 1.e-5
    # clip predictions to avoid log(0), then convert probabilities back to
    # logits so the numerically stable built-in op can be used
    Y_pred = tf.clip_by_value(Y_pred, epsilon, 1. - epsilon)
    logits = tf.log(Y_pred / (1 - Y_pred))
    loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=Y_gt, logits=logits)
    # average over all pixels
    loss = tf.reduce_mean(loss)
    return loss
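As a quick check, the loss can be evaluated on random tensors. This is only a sketch, assuming TensorFlow 1.x and an arbitrary 5-D shape (batch, Z, H, W, C) chosen to match the volumetric Dice code below:

import numpy as np
import tensorflow as tf  # TensorFlow 1.x

# hypothetical shape (batch, Z, H, W, C); values are random, for illustration only
y_true = tf.constant(np.random.randint(0, 2, (2, 4, 8, 8, 1)).astype(np.float32))
y_pred = tf.constant(np.random.uniform(0., 1., (2, 4, 8, 8, 1)).astype(np.float32))

loss_op = binary_crossentropy(y_pred, y_true)
with tf.Session() as sess:
    print(sess.run(loss_op))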
2. Dice Loss
Dice loss was introduced in the V-Net model. The regions of interest in medical images typically occupy a relatively small area, and increasing the weight of the foreground region reduces the impact of class imbalance. The formula is as follows, where TP, FP, and FN are the counts of true positives, false positives, and false negatives, respectively:

$$\mathrm{Dice} = \frac{2\,TP}{2\,TP + FP + FN}$$

In some papers, the Dice coefficient is also written in the following "soft" form, where p is the true class value (0 or 1) and p' is the predicted probability value (0–1); the loss then minimizes the negative of this coefficient:

$$\mathrm{Dice} = \frac{2 \sum_i p_i\, p'_i}{\sum_i p_i + \sum_i p'_i}$$
The reproduction code is as follows:
def binary_dice(Y_pred, Y_gt):
    smooth = 1.e-5
    smooth_tf = tf.constant(smooth, tf.float32)
    pred_flat = tf.cast(Y_pred, tf.float32)
    true_flat = tf.cast(Y_gt, tf.float32)
    # flatten each 3-D sample (Z, H, W, C) into one vector per batch element
    Z, H, W, C = Y_gt.get_shape().as_list()[1:]
    pred_flat = tf.reshape(pred_flat, [-1, H * W * C * Z])
    true_flat = tf.reshape(true_flat, [-1, H * W * C * Z])
    # smoothed 2|X ∩ Y| and |X| + |Y|
    intersection = 2 * tf.reduce_sum(pred_flat * true_flat, axis=-1) + smooth_tf
    denominator = tf.reduce_sum(pred_flat, axis=-1) + tf.reduce_sum(true_flat, axis=-1) + smooth_tf
    # minimize the negative Dice coefficient (range [-1, 0])
    loss = -tf.reduce_mean(intersection / denominator)
    return loss
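A minimal sanity check, assuming the same TensorFlow 1.x setup as above: feeding a binary mask as its own prediction should drive the loss to its minimum of -1.

import numpy as np
import tensorflow as tf  # TensorFlow 1.x

y = tf.constant(np.random.randint(0, 2, (1, 4, 8, 8, 1)).astype(np.float32))
loss_op = binary_dice(y, y)  # a perfect "prediction"
with tf.Session() as sess:
    print(sess.run(loss_op))  # approximately -1.0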
3. Tversky Loss
Tversky loss is a generalization of Dice loss that adds weighting factors to the false-positive and false-negative terms. The formula is as follows, where p is the true class value (0 or 1) and p' is the predicted probability value (0–1):

$$T(\beta) = \frac{\sum_i p_i\, p'_i}{\sum_i p_i\, p'_i + \beta \sum_i p'_i (1 - p_i) + (1 - \beta) \sum_i p_i (1 - p'_i)}$$

When β = 0.5, the denominator reduces to ½(Σp + Σp'), so Tversky loss becomes exactly Dice loss; a quick numerical check follows the code below.
The reproduction code is as follows:
def binary_tversky(Y_pred, Y_gt, beta):
    smooth = 1.e-5
    smooth_tf = tf.constant(smooth, tf.float32)
    pred_flat = tf.cast(Y_pred, tf.float32)
    true_flat = tf.cast(Y_gt, tf.float32)
    Z, H, W, C = Y_gt.get_shape().as_list()[1:]
    pred_flat = tf.reshape(pred_flat, [-1, H * W * C * Z])
    true_flat = tf.reshape(true_flat, [-1, H * W * C * Z])
    intersection = tf.reduce_sum(pred_flat * true_flat, axis=-1)
    # beta weights the false positives, (1 - beta) the false negatives
    denominator = intersection + tf.reduce_sum(beta * pred_flat * (1 - true_flat), axis=-1) \
                  + tf.reduce_sum((1 - beta) * true_flat * (1 - pred_flat), axis=-1)
    # minimize the negative Tversky index
    loss = -tf.reduce_mean((intersection + smooth_tf) / (denominator + smooth_tf))
    return loss
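The β = 0.5 claim above is easy to verify numerically. A minimal sketch, assuming the binary_dice and binary_tversky functions defined earlier and TensorFlow 1.x; the two printed values should agree up to the smoothing term:

import numpy as np
import tensorflow as tf  # TensorFlow 1.x

y_true = tf.constant(np.random.randint(0, 2, (1, 4, 8, 8, 1)).astype(np.float32))
y_pred = tf.constant(np.random.uniform(0., 1., (1, 4, 8, 8, 1)).astype(np.float32))
with tf.Session() as sess:
    print(sess.run(binary_dice(y_pred, y_true)))
    print(sess.run(binary_tversky(y_pred, y_true, beta=0.5)))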
4. Focal Loss
Focal loss is an improvement on the cross-entropy function that reduces the loss weight of easy samples, allowing the network to focus on the loss of hard samples. The formula is as follows, where p is the true class value, p' is the predicted probability of belonging to class 1, α balances the two classes, and γ controls how strongly easy samples are down-weighted:

$$L_{FL} = -\alpha\, (1 - p')^{\gamma}\, p \log p' - (1 - \alpha)\, p'^{\gamma}\, (1 - p) \log(1 - p')$$
The reproduction code is as follows:
def binary_focalloss(Y_pred, Y_gt, alpha=0.25, gamma=2.):
    epsilon = 1.e-5
    # pt_1: predictions on foreground pixels (1 elsewhere, so their term vanishes)
    pt_1 = tf.where(tf.equal(Y_gt, 1), Y_pred, tf.ones_like(Y_pred))
    # pt_0: predictions on background pixels (0 elsewhere, so their term vanishes)
    pt_0 = tf.where(tf.equal(Y_gt, 0), Y_pred, tf.zeros_like(Y_pred))
    # clip to prevent NaN's and Inf's
    pt_1 = tf.clip_by_value(pt_1, epsilon, 1. - epsilon)
    pt_0 = tf.clip_by_value(pt_0, epsilon, 1. - epsilon)
    loss_1 = alpha * tf.pow(1. - pt_1, gamma) * tf.log(pt_1)
    loss_0 = (1 - alpha) * tf.pow(pt_0, gamma) * tf.log(1. - pt_0)
    # sum over all pixels and negate
    loss = -tf.reduce_sum(loss_1 + loss_0)
    return loss
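The focusing effect of γ can be seen directly from the modulating factor (1 - p')^γ. A small illustration with γ = 2, again assuming TensorFlow 1.x: a well-classified foreground pixel (p' = 0.9) is scaled by 0.01, while a hard one (p' = 0.1) keeps a factor of 0.81.

import tensorflow as tf  # TensorFlow 1.x

p = tf.constant([0.1, 0.5, 0.9])   # predicted probability for true foreground pixels
factor = tf.pow(1. - p, 2.)        # modulating factor with gamma = 2
with tf.Session() as sess:
    print(sess.run(factor))        # ~[0.81, 0.25, 0.01]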
5. Cross Entropy + Dice Loss
Some works combine different loss functions to train networks. The Tencent Medical AI Lab paper “AnatomyNet: Deep Learning for Fast and Fully Automated Whole-volume Segmentation of Head and Neck Anatomy” proposes Dice loss + focal loss to address the segmentation of small organs. Here I reproduce a cross-entropy + Dice loss combination, as follows:
def binary_dicePcrossentroy(Y_pred, Y_gt):
    # step 1: binary cross-entropy
    epsilon = 1.e-5
    Y_pred = tf.clip_by_value(Y_pred, epsilon, 1. - epsilon)
    logits = tf.log(Y_pred / (1 - Y_pred))
    loss1 = tf.nn.sigmoid_cross_entropy_with_logits(labels=Y_gt, logits=logits)
    loss1 = tf.reduce_mean(loss1)
    # step 2: binary Dice, here in its 1 - Dice form so loss2 lies in [0, 1]
    smooth_tf = tf.constant(epsilon, tf.float32)
    pred_flat = tf.cast(Y_pred, tf.float32)
    true_flat = tf.cast(Y_gt, tf.float32)
    Z, H, W, C = Y_gt.get_shape().as_list()[1:]
    pred_flat = tf.reshape(pred_flat, [-1, H * W * C * Z])
    true_flat = tf.reshape(true_flat, [-1, H * W * C * Z])
    intersection = 2 * tf.reduce_sum(pred_flat * true_flat, axis=-1) + smooth_tf
    denominator = tf.reduce_sum(pred_flat, axis=-1) + tf.reduce_sum(true_flat, axis=-1) + smooth_tf
    loss2 = tf.reduce_mean(1 - intersection / denominator)
    # step 3: combine the two terms; log1p(loss2) = log(1 + loss2)
    loss = loss1 + tf.log1p(loss2)
    return loss
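Before training, the five losses can also be compared side by side on the same batch. A minimal sketch, assuming all five functions above are defined, TensorFlow 1.x, and the same hypothetical (batch, Z, H, W, C) shape; beta = 0.7 for Tversky is an arbitrary illustrative choice:

import numpy as np
import tensorflow as tf  # TensorFlow 1.x

y_true = tf.constant(np.random.randint(0, 2, (2, 4, 8, 8, 1)).astype(np.float32))
y_pred = tf.constant(np.random.uniform(0., 1., (2, 4, 8, 8, 1)).astype(np.float32))
losses = {
    'cross-entropy': binary_crossentropy(y_pred, y_true),
    'dice': binary_dice(y_pred, y_true),
    'tversky (beta=0.7)': binary_tversky(y_pred, y_true, beta=0.7),
    'focal': binary_focalloss(y_pred, y_true),
    'cross-entropy + dice': binary_dicePcrossentroy(y_pred, y_true),
}
with tf.Session() as sess:
    for name, op in losses.items():
        print(name, sess.run(op))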
I trained with each of the five loss functions above, then predicted on 10 test samples and computed their Dice scores. The results are as follows.
