Few-Shot NER with Dual-Tower BERT Model


Author | SinGaln

Source | PaperWeekly

This paper from ACL 2022 builds on meta-learning and uses a dual-tower BERT model: one BERT encodes the text tokens, the other encodes the labels, and each token is classified from the dot product of the two representations. The paper is not complicated and contains very few formulas, so the authors' reasoning is easy to follow. Approaching NER as sequence labeling in this way is a good idea.


Paper Title:

Label Semantics for Few Shot Named Entity Recognition

Paper Link:

https://arxiv.org/pdf/2203.08985.pdf


Model

1.1 Architecture


Figure 1. Overall architecture of the model

As Figure 1 shows, the authors use a dual-tower BERT to encode the text tokens and the labels separately. The reasoning is simple: in a few-shot task there is not enough training data, so the authors expect the labels to provide additional semantic information for the tokens.
For meta-learning, the authors adopt a metric-based method. Intuitively, the model first computes a vector representation for every token in a sample and then measures its similarity to each label representation, which is the dot product shown in the figure. Softmax then normalizes the resulting similarity matrix ([batch_size, sequence_length, num_labels]) over the label dimension, and argmax takes the index of the largest value in the last dimension; looking that index up in the label list gives the label of the current token.
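A minimal pure-Python sketch of this metric-based decoding for a single sentence (batch dimension omitted). The vectors are made-up toy values and the helper names are hypothetical; in the model, the token and label vectors would come from the two BERT towers.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def decode(token_vecs, label_vecs, label_names):
    """Score each token against every label by dot product, normalize, take argmax."""
    predictions = []
    for h in token_vecs:
        # One row of the [sequence_length, num_labels] similarity matrix
        probs = softmax([dot(h, l) for l in label_vecs])
        best = max(range(len(probs)), key=probs.__getitem__)
        predictions.append(label_names[best])  # map the index back to the label list
    return predictions


# Toy 2-d vectors, made up for illustration
label_names = ["O", "B-PER", "I-PER"]
label_vecs = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.5]]
token_vecs = [[2.0, 0.1], [0.2, 3.0], [-2.0, 1.0]]
print(decode(token_vecs, label_vecs, label_names))  # ['O', 'B-PER', 'I-PER']
```

Because softmax is monotonic, the argmax over probabilities equals the argmax over the raw dot products; the softmax matters for the training loss rather than for decoding.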

1.2 Details

When representing labels, the authors also convert each label into natural language in three steps:
1. Expand abbreviated entity types into natural-language words, e.g. PER → person, ORG → organization, LOC → location.
2. Convert the position prefixes of the tagging scheme into words; for example, in the BIO scheme B, I, and O become begin, inside, and other. Other tagging schemes are handled similarly.
3. Combine the results of the first two steps, e.g. B-PER → begin person, I-PER → inside person.
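A minimal sketch of the three conversion steps for BIO tags. The mapping tables and the helper name are hypothetical and cover only the examples mentioned above; each dataset would need its own table.

```python
# Hypothetical tables covering only the examples mentioned in the text
TYPE_NAMES = {"PER": "person", "ORG": "organization", "LOC": "location"}
PREFIX_NAMES = {"B": "begin", "I": "inside"}


def label_to_text(tag, type_names=TYPE_NAMES, prefix_names=PREFIX_NAMES):
    """Convert a BIO tag such as 'B-PER' into a phrase such as 'begin person'."""
    if tag == "O":
        return "other"  # step 2 for the O tag
    prefix, entity_type = tag.split("-", 1)
    # Step 1: expand the entity type; step 2: expand the prefix; step 3: combine
    return f"{prefix_names[prefix]} {type_names[entity_type]}"


for tag in ["B-PER", "I-PER", "B-LOC", "I-ORG", "O"]:
    print(tag, "->", label_to_text(tag))
```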
Because the task is few-shot NER, the authors first train the model on several source datasets and then evaluate it, both with and without fine-tuning, on several unseen few-shot target datasets.
When encoding tokens, the token tower (BERT) maps the input sentence x_1, ..., x_n to one vector per token:

$$ h_1, h_2, \dots, h_n = \mathrm{BERT}_{\mathrm{token}}(x_1, x_2, \dots, x_n) $$

Note that each token's vector is taken from BERT's last_hidden_state, i.e. the final-layer output at that token's position.
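When encoding labels, every label in the label set (in its natural-language form) is encoded by the label tower, the [CLS] output of each label is taken as that label's vector, and all label vectors are collected into a label-vector set. Finally, the score between each token vector h_i and each label vector l_j is their dot product, normalized over the labels with softmax:

$$ s(x_i, y_j) = h_i^\top l_j, \qquad p(y_j \mid x_i) = \frac{\exp(h_i^\top l_j)}{\sum_{k=1}^{|Y|} \exp(h_i^\top l_k)} $$

where |Y| is the number of labels.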
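A toy worked instance of the dot-product scoring and softmax normalization, using made-up two-dimensional vectors (not values from the paper) and a label set with just O and B-PER:

```latex
h_i = (1, 2), \quad l_{\mathrm{O}} = (1, 0), \quad l_{\mathrm{B\text{-}PER}} = (0, 1) \\
h_i^\top l_{\mathrm{O}} = 1, \qquad h_i^\top l_{\mathrm{B\text{-}PER}} = 2 \\
p(\mathrm{B\text{-}PER} \mid x_i) = \frac{e^{2}}{e^{1} + e^{2}} = \frac{1}{1 + e^{-1}} \approx 0.731, \qquad p(\mathrm{O} \mid x_i) \approx 0.269
```

Taking the argmax therefore assigns B-PER to this token.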

Because labels are represented by encodings of their names, the model, unlike conventional NER models, does not need a newly initialized top-layer classifier when it meets new data and new labels: the new label names are simply encoded by the label tower. This is what makes the approach suitable for few-shot transfer.
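To illustrate why no new classifier is needed, the sketch below scores the same token vectors against two label sets of different sizes (toy vectors; the helper name is hypothetical). The output width simply follows the number of label vectors, whereas a conventional head such as nn.Linear(hidden_size, num_labels) has a weight matrix tied to one fixed label count.

```python
def score_matrix(token_vecs, label_vecs):
    """[sequence_length, dim] x [num_labels, dim] -> [sequence_length, num_labels] dot-product scores."""
    return [[sum(a * b for a, b in zip(h, l)) for l in label_vecs] for h in token_vecs]


token_vecs = [[1.0, 2.0], [0.5, -1.0]]
source_label_vecs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # e.g. 3 labels in a source domain
target_label_vecs = [[0.5, 0.5], [2.0, -1.0]]              # e.g. 2 labels in a new target domain

print(score_matrix(token_vecs, source_label_vecs))  # [[1.0, 2.0, 3.0], [0.5, -1.0, -0.5]]
print(score_matrix(token_vecs, target_label_vecs))  # [[1.5, 0.0], [-0.25, 2.0]]
```

Nothing in score_matrix depends on the number of labels, so a new target domain only requires encoding its label names.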

1.3 Label Conversion

The paper also lists the label conversion table used for the experimental datasets; part of it is shown below:


Figure 2. Label conversion for the experimental datasets

1.4 Support Set Sampling Algorithm

The sampling pseudocode is as follows:


Figure 3. Sampling Pseudocode
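As a hedged illustration of support-set sampling, here is a generic greedy K-shot sampler of the kind often used to build few-shot NER support sets. It is an assumption and may differ in detail from the authors' algorithm (some variants, for example, also cap how often a type may appear). Sentences are represented as (tokens, entity_types) pairs, a hypothetical input format.

```python
import random
from collections import Counter


def greedy_support_sample(dataset, k, seed=0):
    """Greedily pick sentences until every entity type occurs in at least k of them.

    dataset: list of (tokens, entity_types) pairs; entity_types lists the entity
             types mentioned in that sentence (hypothetical input format).
    Returns the sampled support set and the per-type sentence counts.
    """
    rng = random.Random(seed)
    # Number of sentences mentioning each type (a type counts once per sentence)
    freq = Counter(t for _, types in dataset for t in set(types))
    counts = Counter()
    support, used = [], set()
    # Visit rarer types first; sentences picked for them often cover frequent types too
    for etype in sorted(freq, key=lambda t: (freq[t], t)):
        candidates = [i for i, (_, types) in enumerate(dataset) if etype in types and i not in used]
        rng.shuffle(candidates)
        while counts[etype] < k and candidates:
            i = candidates.pop()
            used.add(i)
            support.append(dataset[i])
            # A sampled sentence counts toward every type it contains
            counts.update(set(dataset[i][1]))
    return support, counts


toy = [
    (["Alice", "visited", "Paris"], ["PER", "LOC"]),
    (["Bob", "works", "at", "Acme"], ["PER", "ORG"]),
    (["Berlin", "is", "big"], ["LOC"]),
    (["Acme", "hired", "Carol"], ["ORG", "PER"]),
    (["Dave", "slept"], ["PER"]),
]
support, counts = greedy_support_sample(toy, k=1)
print(len(support), dict(counts))
```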


Experimental Results


Figure 4. Some Experimental Results

The experimental results show that the method does well in few-shot settings: it outperforms the other models from 1 to 50 shots, which indicates that label semantics are effective. With the full training data, however, the method shows some disadvantages, suggesting that the more data there is, the less the model relies on label semantics. I suspect that with full data, introducing label semantics may slightly shift the original text semantics. The same shift also occurs in the few-shot setting, but there it is a positive shift that improves the model's generalization, whereas with full data it feels more like an excess that interferes with the original semantics.
Dual-tower BERT implementation (this code does not use the metric-based method; it combines the two towers with an element-wise product followed by a linear classifier):
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @Time    : 2022/5/23 13:49
# @Author  : SinGaln

import torch
import torch.nn as nn
from transformers import BertModel, BertPreTrainedModel


class SinusoidalPositionEmbedding(nn.Module):
    """Define Sin-Cos Position Embedding"""

    def __init__(self, output_dim, merge_mode='add'):
        super(SinusoidalPositionEmbedding, self).__init__()
        self.output_dim = output_dim
        self.merge_mode = merge_mode

    def forward(self, inputs):
        input_shape = inputs.shape
        batch_size, seq_len = input_shape[0], input_shape[1]
        position_ids = torch.arange(seq_len, dtype=torch.float)[None]
        indices = torch.arange(self.output_dim // 2, dtype=torch.float)
        indices = torch.pow(10000.0, -2 * indices / self.output_dim)
        embeddings = torch.einsum('bn,d->bnd', position_ids, indices)
        embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)
        embeddings = embeddings.repeat((batch_size, *([1] * len(embeddings.shape))))
        embeddings = torch.reshape(embeddings, (batch_size, seq_len, self.output_dim))
        if self.merge_mode == 'add':
            return inputs + embeddings.to(inputs.device)
        elif self.merge_mode == 'mul':
            return inputs * (embeddings + 1.0).to(inputs.device)
        elif self.merge_mode == 'zero':
            return embeddings.to(inputs.device)


class DoubleTownNER(BertPreTrainedModel):
    def __init__(self, config, num_labels, position=False):
        super(DoubleTownNER, self).__init__(config)
        self.position = position
        self.num_labels = num_labels
        # Note: one BertModel encodes both the text and the labels, so the two "towers" share weights here
        self.bert = BertModel(config=config)
        self.fc = nn.Linear(config.hidden_size, self.num_labels)

        if self.position:
            self.sinposembed = SinusoidalPositionEmbedding(config.hidden_size, "add")

        # Initialize weights that are not loaded from the checkpoint (e.g. self.fc)
        self.post_init()

    def forward(self, sequence_input_ids, sequence_attention_mask, sequence_token_type_ids, label_input_ids,
                label_attention_mask, label_token_type_ids):
        # Encode the text: [batch_size, sequence_length, embed_dim]
        sequence_outputs = self.bert(input_ids=sequence_input_ids, attention_mask=sequence_attention_mask,
                                     token_type_ids=sequence_token_type_ids).last_hidden_state
        # Encode the label input(s): [label_batch_size, embed_dim]
        label_outputs = self.bert(input_ids=label_input_ids, attention_mask=label_attention_mask,
                                  token_type_ids=label_token_type_ids).pooler_output
        # [label_batch_size, 1, embed_dim]; label_batch_size must be 1 or batch_size to broadcast
        label_outputs = label_outputs.unsqueeze(1)

        # Optionally add sinusoidal position vectors to the token representations
        if self.position:
            sequence_outputs = self.sinposembed(sequence_outputs)
        # Element-wise (Hadamard) interaction, not a dot product: [batch_size, sequence_length, embed_dim]
        interactive_output = sequence_outputs * label_outputs
        # Fully connected classifier: [batch_size, sequence_length, num_labels]
        outputs = self.fc(interactive_output)
        return outputs


if __name__ == "__main__":
    from transformers import BertConfig

    pretrain_path = "../bert_model"  # local directory containing a pretrained BERT checkpoint

    # Dummy text batch: 32 sequences of 128 token ids
    token_input_ids = torch.randint(1, 100, (32, 128))
    token_attention_mask = torch.ones_like(token_input_ids)
    token_token_type_ids = torch.zeros_like(token_input_ids)

    # Dummy label input: one sequence of 10 token ids
    label_input_ids = torch.randint(1, 10, (1, 10))
    label_attention_mask = torch.ones_like(label_input_ids)
    label_token_type_ids = torch.zeros_like(label_input_ids)

    config = BertConfig.from_pretrained(pretrain_path)
    model = DoubleTownNER.from_pretrained(pretrain_path, config=config, num_labels=10, position=True)

    outs = model(sequence_input_ids=token_input_ids, sequence_attention_mask=token_attention_mask,
                 sequence_token_type_ids=token_token_type_ids, label_input_ids=label_input_ids,
                 label_attention_mask=label_attention_mask, label_token_type_ids=label_token_type_ids)
    print(outs, outs.size())  # size: torch.Size([32, 128, 10])
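For reference, SinusoidalPositionEmbedding above computes the standard sinusoidal encoding with sine and cosine interleaved along the feature dimension: feature 2i of position pos is sin(pos / 10000^(2i/d)) and feature 2i+1 is the cosine of the same angle. The pure-Python sketch below re-derives one row so the tensor version can be checked by hand; the helper name is hypothetical.

```python
import math


def sinusoidal_row(pos, dim):
    """Position vector for one position, interleaving sin/cos like SinusoidalPositionEmbedding."""
    row = []
    for i in range(dim // 2):
        # Same frequency as torch.pow(10000.0, -2 * i / dim) in the class above
        angle = pos * 10000.0 ** (-2 * i / dim)
        row.extend([math.sin(angle), math.cos(angle)])
    return row


print(sinusoidal_row(0, 4))  # [0.0, 1.0, 0.0, 1.0]
print(sinusoidal_row(1, 4))
```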
