ACL2022 | Dual Tower BERT Model with Label Semantics


Source: Alchemy Notes

Author: SinGaln
This is a paper from ACL 2022. The overall idea is a meta-learning, dual-tower BERT model that encodes text tokens and their corresponding labels separately, then performs classification on the output of their dot product. The paper is not overly complex and involves few formulas, so the authors' line of thought is easy to follow. Approaching few-shot NER as plain sequence labeling in this way is a nice idea.


Paper Title:

Label Semantics for Few Shot Named Entity Recognition

Paper Link:

https://arxiv.org/pdf/2203.08985.pdf

Model

1.1 Architecture


Figure 1. Overall Model Architecture

As the figure shows, the author uses two BERT towers to separately encode the text tokens and the label corresponding to each token. The reasoning is simple: in a few-shot task there is not enough data, so the author lets each token's label provide additional semantic information for that token.
The meta-learning part adopts a metric-based method. Intuitively, the model first computes a vector representation for each token of a sample and then measures its similarity to the pre-computed label representations; this is exactly the dot product shown in the figure. The resulting similarity tensor (of shape [batch_size, sequence_length, num_labels]) is normalized with softmax, and argmax over the last dimension gives an index into the label list, i.e. the predicted label for the current token.
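
A rough tensor-level sketch of this scoring step (shapes and tensors are made up for illustration; this is not the authors' code):

import torch

batch_size, seq_len, embed_dim, num_labels = 2, 16, 768, 9

# token representations from the token encoder: [batch_size, seq_len, embed_dim]
token_reps = torch.randn(batch_size, seq_len, embed_dim)
# one vector per label from the label encoder: [num_labels, embed_dim]
label_reps = torch.randn(num_labels, embed_dim)

# dot product of every token with every label: [batch_size, seq_len, num_labels]
scores = torch.einsum("bld,cd->blc", token_reps, label_reps)
probs = torch.softmax(scores, dim=-1)   # normalize over the label set
pred_label_ids = probs.argmax(dim=-1)   # index into the label list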

1.2 Detail

In addition, before encoding the labels, the author preprocesses each label in the following three steps (a small sketch of the resulting mapping is given after the list):
1. Convert the abbreviated label types to natural language, for example PER -> person, ORG -> organization, LOC -> location, etc.;
2. Convert the position markers of the annotation scheme to natural language, for example for BIO tagging B -> begin, I -> inside, O -> other, and similarly for other schemes;
3. Combine the two conversions, for example B-PER -> begin person, I-PER -> inside person.
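
A minimal sketch of this label-to-natural-language mapping (the exact wording of the mapped phrases is illustrative, not necessarily the paper's):

# step 1: abbreviated entity types -> natural language words
TYPE_WORDS = {"PER": "person", "ORG": "organization", "LOC": "location", "MISC": "miscellaneous"}
# step 2: BIO position markers -> natural language words
PREFIX_WORDS = {"B": "begin", "I": "inside", "O": "other"}

def verbalize(tag: str) -> str:
    """Turn a BIO tag such as 'B-PER' into a label description such as 'begin person'."""
    if tag == "O":
        return PREFIX_WORDS["O"]
    prefix, entity_type = tag.split("-", 1)
    return f"{PREFIX_WORDS[prefix]} {TYPE_WORDS.get(entity_type, entity_type.lower())}"

print(verbalize("B-PER"))  # begin person
print(verbalize("I-ORG"))  # inside organization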
Since this is a few-shot NER task, the author trains the model on multiple source datasets and then evaluates it, both with and without fine-tuning, on multiple unseen few-shot target datasets.
When encoding tokens, the corresponding vectors can be obtained through the BERT model, as shown below:

h_i = BERT_token(x_1, ..., x_n)[i], i.e. the i-th token's contextual vector from the token encoder

Note that the last_hidden_state output of the BERT model is used as the vector for each token.
When encoding the labels, every label in the label set is encoded, each complete label yielding its own vector, and all of these label vectors together form a set. Finally, the dot product between each token vector h_i and each label vector z_c is computed, as shown below:

p(y_i = c | x) = softmax_c(h_i · z_c), where z_c is the label encoder's vector for label c and the softmax runs over the label set

Because labels are represented through an encoder rather than a fixed classification layer, the model does not need to initialize a new top-level classifier when it meets new data and new labels, unlike other NER methods; this is what makes the few-shot setting workable.
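
To make the "no new classifier" point concrete, here is a hedged sketch: for an unseen label set it is enough to re-run the label encoder on the new label descriptions (the checkpoint name and the [CLS] pooling below are assumptions for illustration, not the authors' exact setup):

import torch
from transformers import AutoTokenizer, AutoModel

# any BERT-style checkpoint can stand in as the label tower for this sketch
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
label_encoder = AutoModel.from_pretrained("bert-base-uncased")

def encode_labels(label_descriptions):
    """One vector per label description, taken here from the [CLS] position."""
    batch = tokenizer(label_descriptions, padding=True, return_tensors="pt")
    with torch.no_grad():
        return label_encoder(**batch).last_hidden_state[:, 0]  # [num_labels, hidden]

# switching to a new domain only requires new label descriptions, no new parameters
new_labels = ["other", "begin person", "inside person", "begin organization", "inside organization"]
label_reps = encode_labels(new_labels)
print(label_reps.shape)  # torch.Size([5, 768])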

1.3 Label Transfer

In the paper, the author also lists the label conversion table for the experimental datasets, part of which is shown below:


Figure 2. Label Transfer for Experimental Datasets

1.4 Support Set Sampling Algorithm

The sampling pseudocode is shown below:


Figure 3. Sampling Pseudocode
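
As a rough illustration of what such a sampler does, here is a generic greedy K-shot support-set sampling sketch (an approximation in the same spirit, not the paper's exact pseudocode):

import random
from collections import Counter

def greedy_support_sampling(sentences, entity_classes, k):
    """Greedily add sentences until every entity class appears at least k times.

    `sentences` is a list of (tokens, BIO_tags) pairs; details such as class
    ordering and upper bounds may differ from the paper's pseudocode.
    """
    counts = Counter({c: 0 for c in entity_classes})
    support = []
    for tokens, tags in random.sample(sentences, len(sentences)):
        types_here = {t.split("-", 1)[1] for t in tags if t != "O"} & set(entity_classes)
        # keep the sentence only if it helps an under-represented class
        if any(counts[c] < k for c in types_here):
            support.append((tokens, tags))
            for c in types_here:
                counts[c] += 1
        if all(counts[c] >= k for c in entity_classes):
            break
    return support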

Experimental Results

Figure 4. Partial Experimental Results

From the experimental results it is clear that the method performs well in few-shot scenarios: in the 1-50 shot range the model outperforms the other models, which speaks to the effectiveness of label semantics. Under full data, however, the method falls somewhat behind, suggesting that the more data is available, the less the model relies on label semantics. My own take is that introducing label semantics slightly shifts the original text semantics. This also happens in the few-shot setting, but there the shift is a beneficial one that improves the model's generalization, whereas with full data the shift feels excessive.
Below is an implementation of the dual-tower BERT code (not using the metric-based method):
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @Time    : 2022/5/23 13:49
# @Author  : SinGaln

import torch
import torch.nn as nn
from transformers import BertModel, BertPreTrainedModel


class SinusoidalPositionEmbedding(nn.Module):
    """Define Sin-Cos Position Embedding"""

    def __init__(self, output_dim, merge_mode='add'):
        super(SinusoidalPositionEmbedding, self).__init__()
        self.output_dim = output_dim
        self.merge_mode = merge_mode

    def forward(self, inputs):
        input_shape = inputs.shape
        batch_size, seq_len = input_shape[0], input_shape[1]
        position_ids = torch.arange(seq_len, dtype=torch.float)[None]
        indices = torch.arange(self.output_dim // 2, dtype=torch.float)
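        # angular frequencies 1/10000^(2i/output_dim), as in the Transformer sinusoidal position encoding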
        indices = torch.pow(10000.0, -2 * indices / self.output_dim)
        embeddings = torch.einsum('bn,d->bnd', position_ids, indices)
        embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)
        embeddings = embeddings.repeat((batch_size, *([1] * len(embeddings.shape))))
        embeddings = torch.reshape(embeddings, (batch_size, seq_len, self.output_dim))
        if self.merge_mode == 'add':
            return inputs + embeddings.to(inputs.device)
        elif self.merge_mode == 'mul':
            return inputs * (embeddings + 1.0).to(inputs.device)
        elif self.merge_mode == 'zero':
            return embeddings.to(inputs.device)


class DoubleTownNER(BertPreTrainedModel):
    def __init__(self, config, num_labels, position=False):
        super(DoubleTownNER, self).__init__(config)
        self.position = position
        self.num_labels = num_labels
        self.bert = BertModel(config=config)
        self.fc = nn.Linear(config.hidden_size, self.num_labels)

        if self.position:
            self.sinposembed = SinusoidalPositionEmbedding(config.hidden_size, "add")

    def forward(self, sequence_input_ids, sequence_attention_mask, sequence_token_type_ids, label_input_ids,
                label_attention_mask, label_token_type_ids):
        # Get the encoding of text and labels
        # [batch_size, sequence_length, embed_dim]
        sequence_outputs = self.bert(input_ids=sequence_input_ids, attention_mask=sequence_attention_mask,
                                     token_type_ids=sequence_token_type_ids).last_hidden_state
        # [batch_size, embed_dim]
        label_outputs = self.bert(input_ids=label_input_ids, attention_mask=label_attention_mask,
                                  token_type_ids=label_token_type_ids).pooler_output
        label_outputs = label_outputs.unsqueeze(1)
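        # now [batch_size, 1, embed_dim], so the label vector broadcasts over the sequence length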

        # Position vector
        if self.position:
            sequence_outputs = self.sinposembed(sequence_outputs)
        # Interaction: elementwise product, broadcast over the sequence dimension
        interactive_output = sequence_outputs * label_outputs
        # Fully connected classification layer
        outputs = self.fc(interactive_output)
        return outputs

if __name__ == "__main__":
    pretrain_path = "../bert_model"
    from transformers import BertConfig

    token_input_ids = torch.randint(1, 100, (32, 128))
    token_attention_mask = torch.ones_like(token_input_ids)
    token_token_type_ids = torch.zeros_like(token_input_ids)

    label_input_ids = torch.randint(1, 10, (1, 10))
    label_attention_mask = torch.ones_like(label_input_ids)
    label_token_type_ids = torch.zeros_like(label_input_ids)
    config = BertConfig.from_pretrained(pretrain_path)
    model = DoubleTownNER.from_pretrained(pretrain_path, config=config, num_labels=10, position=True)

    outs = model(sequence_input_ids=token_input_ids, sequence_attention_mask=token_attention_mask, sequence_token_type_ids=token_token_type_ids, label_input_ids=label_input_ids,
                label_attention_mask=label_attention_mask, label_token_type_ids=label_token_type_ids)
    print(outs, outs.size())
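
With the toy inputs above, outs is a logits tensor of shape [32, 128, 10]: one score per token for each of the 10 labels, which can then be fed to a token-level cross-entropy loss.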
