您当前的位置:首页 > IT编程 > 情感分类
| C语言 | Java | VB | VC | python | Android | TensorFlow | C++ | oracle | 学术与代码 | cnn卷积神经网络 | gnn | 图像修复 | Keras | 数据集 | Neo4j | 自然语言处理 | 深度学习 | 医学CAD | 医学影像 | 超参数 | pointnet | pytorch | 异常检测 | Transformers | 情感分类 |

自学教程:python - 如何在 bert 模型上添加 Bi-LSTM 层?

51自学网 2023-12-16 11:31:27
  情感分类
这篇教程python - 如何在 bert 模型上添加 Bi-LSTM 层?写得很实用,希望能帮到您。

python - 如何在 bert 模型上添加 Bi-LSTM 层?

 

 

我正在使用 pytorch 并且我正在使用基础 pretrained bert 对仇恨言论的句子进行分类。 我想实现一个 Bi-LSTM 层,将最新的所有输出作为输入 来自 bert 模型的变压器编码器作为新模型(实现 nn.Module 的类),我对 nn.LSTM 参数感到困惑。 我使用

标记了数据
bert = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=int(data['class'].nunique()),output_attentions=False,output_hidden_states=False)

我的数据集有 2 列:类(标签)、句子。 有人可以帮我弄这个吗? 提前谢谢你。

编辑: 此外,在处理完 bi-lstm 中的输入后,网络将最终隐藏状态发送到使用 softmax 激活函数执行分类的全连接网络。我该怎么做?

 

最佳答案

 

你可以这样做:

from transformers import BertModel
class CustomBERTModel(nn.Module):
    def __init__(self):
          super(CustomBERTModel, self).__init__()
          self.bert = BertModel.from_pretrained("bert-base-uncased")
          ### New layers:
          self.lstm = nn.LSTM(768, 256, batch_first=True,bidirectional=True)
          self.linear = nn.Linear(256*2, <number_of_classes>)
          

    def forward(self, ids, mask):
          sequence_output, pooled_output = self.bert(
               ids, 
               attention_mask=mask)

          # sequence_output has the following shape: (batch_size, sequence_length, 768)
          lstm_output, (h,c) = self.lstm(sequence_output) ## extract the 1st token's embeddings
          hidden = torch.cat((lstm_output[:,-1, :256],lstm_output[:,0, 256:]),dim=-1)
          linear_output = self.linear(hidden.view(-1,256*2)) ### assuming that you are only using the output of the last LSTM cell to perform classification

          return linear_output

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
model = CustomBERTModel()


返回列表
虚假新闻检测,来自美团NLP团队方案
51自学网自学EXCEL、自学PS、自学CAD、自学C语言、自学css3实例,是一个通过网络自主学习工作技能的自学平台,网友喜欢的软件自学网站。
京ICP备13026421号-1