# ner.py

from pathlib import Path

import polars as pl
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModelForTokenClassification

import bio
# entity_types = ["Material", "Component", "Equipment", "Process", "Organization", "Standard"]
from bio import entity_types
from data_preparation import columns, desc_file_name
# Build BIO label ids: "O" is 0, then a B-/I- pair for every entity tag.
# sorted() keeps the ids stable across runs (set iteration order is not).
label_map = {"O": 0}
index = 1
for tag in sorted(set(entity_types.values())):
    label_map[f'B-{tag}'] = index
    index += 1
    label_map[f'I-{tag}'] = index
    index += 1
label_index = {v: k for k, v in label_map.items()}
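
# For illustration only (hypothetical values): if entity_types contained just
# the tags "Material" and "Component", the maps built above would be
#   label_map   = {"O": 0, "B-Component": 1, "I-Component": 2,
#                  "B-Material": 3, "I-Material": 4}
#   label_index = {0: "O", 1: "B-Component", 2: "I-Component", ...}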

# model_name = "distilbert-base-multilingual-cased"
# model_name = "hfl/chinese-roberta-wwm-ext"
# tokenizer = AutoTokenizer.from_pretrained(model_name)
# model = AutoModelForTokenClassification.from_pretrained(model_name, num_labels=len(label_map))  # "O" is already included

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class PatentNERDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        labels = self.labels[idx]
        encoding = self.tokenizer(text, truncation=True, padding="max_length",
                                  max_length=self.max_length, return_tensors="pt")
        input_ids = encoding["input_ids"].squeeze()
        attention_mask = encoding["attention_mask"].squeeze()
        # Convert the labels to ids. The leading 0 accounts for [CLS]; this
        # assumes the tokenizer emits one token per character of Chinese text.
        label_ids = [0] + [label_map[label] for label in labels]
        # Pad with "O" (0) up to max_length, truncating if the text is longer.
        label_ids = (label_ids + [0] * self.max_length)[:self.max_length]
        return {"input_ids": input_ids, "attention_mask": attention_mask,
                "labels": torch.tensor(label_ids)}
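
# A minimal smoke test, assuming bio.get_bio() returns parallel lists of raw
# strings and per-character BIO tag lists (this sample is hypothetical):
#   texts  = ["混凝土墙体"]
#   labels = [["B-Material", "I-Material", "I-Material", "B-Component", "I-Component"]]
#   ds = PatentNERDataset(texts, labels, AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext"))
#   item = ds[0]  # dict of input_ids / attention_mask / labels, each of length 512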


def train(model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # "O" is already in label_map, so len(label_map) covers every label id
    model = AutoModelForTokenClassification.from_pretrained(model_name,
                                                            num_labels=len(label_map))
    # To resume from a previously saved checkpoint instead:
    # tokenizer = AutoTokenizer.from_pretrained("building_ner_model_bert_wwm")
    # model = AutoModelForTokenClassification.from_pretrained("building_ner_model_bert_wwm",
    #                                                         num_labels=len(label_map))

    # Training data
    train_texts, train_labels = bio.get_bio()
    # Build the dataset
    dataset = PatentNERDataset(train_texts, train_labels, tokenizer)

    # Training setup
    model.to(device)
    optimizer = AdamW(model.parameters(), lr=5e-5)
    train_loader = DataLoader(dataset, batch_size=2, shuffle=True)
    # Training loop
    model.train()
    for epoch in range(4):
        for batch in train_loader:
            optimizer.zero_grad()
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch}, Loss: {loss.item()}")

    # Save the fine-tuned model and tokenizer
    # model.save_pretrained("building_ner_model")
    # tokenizer.save_pretrained("building_ner_model")
    # model.save_pretrained("building_ner_model_bert_wwm")
    # tokenizer.save_pretrained("building_ner_model_bert_wwm")
    model.save_pretrained(f"{model_name}-building")
    tokenizer.save_pretrained(f"{model_name}-building")


def test(model_name):
    # Load the fine-tuned model
    # model = AutoModelForTokenClassification.from_pretrained("building_ner_model")
    # tokenizer = AutoTokenizer.from_pretrained("building_ner_model")
    # model = AutoModelForTokenClassification.from_pretrained("building_ner_model_bert_wwm")
    # tokenizer = AutoTokenizer.from_pretrained("building_ner_model_bert_wwm")
    model = AutoModelForTokenClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model.to(device)  # inputs below are moved to `device`, so the model must be too
    model.eval()
    # id2label = model.config.id2label

    # Inference function
    def predict_distilbert(text):
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        predictions = torch.argmax(outputs.logits, dim=-1).squeeze().tolist()
        # labels = [model.config.id2label[pred.item()] for pred in predictions[0]]
        # Parse the predictions
        # tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"].squeeze().tolist())
        tokens = list(text)
        # Drop the [CLS]/[SEP] predictions so positions line up with the characters
        # (assumes one tokenizer token per character, matching the training labels)
        predictions = predictions[1:-1]
        results = []
        for token, pred in zip(tokens, predictions):
            # label = id2label[pred]
            label = label_index[pred]
            # if label != "LABEL_0":
            if label != "O":  # filter out non-entity "O" tags
                results.append((label, token))
        # Print the results; a leading space marks the start of each new entity
        print(text)
        print("Recognized entity types and terms:")
        for entity_type, entity in results:
            if entity_type.startswith("B-"):
                print(f" {entity}", end='')
            else:
                print(f"{entity}", end='')
        print("\n")
        return predictions

    df = pl.read_csv(str(Path(desc_file_name).expanduser()), columns=columns, encoding="utf-8")
    for description in df['摘要']:  # '摘要' = the abstract column
        # Remove literal "\r"/"\n"/"\t" escape sequences and spaces left in the CSV export
        description = description.replace(r'\r', ' ').replace(r'\n', ' ').replace(r'\t', '').replace(' ', '')
        predict_distilbert(description)
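

# Illustration only (not called anywhere above): a sketch of how the
# (label, token) pairs collected by predict_distilbert could be merged into
# whole entity spans. `group_entities` is a hypothetical helper name.
def group_entities(results):
    """Merge e.g. [("B-Material", "混"), ("I-Material", "凝"), ("I-Material", "土")]
    into [("Material", "混凝土")]."""
    spans = []
    for label, token in results:
        if label.startswith("B-") or not spans:
            spans.append([label[2:], token])  # start a new entity span
        else:
            spans[-1][1] += token  # extend the current span with this character
    return [tuple(span) for span in spans]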


if __name__ == '__main__':
    # Note: train() saves to f"{model_name}-building", so this call resumes from
    # the checkpoint written by an earlier train('hfl/chinese-roberta-wwm-ext-large')
    # run (and will save its own output under "...-building-building").
    train('hfl/chinese-roberta-wwm-ext-large-building')
    # train('hfl/chinese-roberta-wwm-ext-large')
    # test('hfl/chinese-roberta-wwm-ext-large-building')