# ner.py

import torch
import bio
import polars as pl
from transformers import AutoTokenizer, AutoModelForTokenClassification
from torch.utils.data import Dataset
from torch.nn import CrossEntropyLoss  # needed for the per-sample loss computation below
from data_preparation import columns, sample_file_name, clean_raw_text
from bio import entity_types
from env import env

# Build the label map: id 0 for "O", then B-/I- ids for every entity type.
label_map = {"O": 0}
index = 1
for tag in set(entity_types.values()):
    label_map[f'B-{tag}'] = index
    index += 1
    label_map[f'I-{tag}'] = index
    index += 1
label_index = {v: k for k, v in label_map.items()}
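
# Illustration of the resulting mapping (the concrete tags depend on
# bio.entity_types; "BLD" below is only a made-up example value):
#   entity_types.values() -> {"BLD"}
#   label_map   -> {"O": 0, "B-BLD": 1, "I-BLD": 2}
#   label_index -> {0: "O", 1: "B-BLD", 2: "I-BLD"}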

# model_name = "distilbert-base-multilingual-cased"
# model_name = "hfl/chinese-roberta-wwm-ext"
# tokenizer = AutoTokenizer.from_pretrained(model_name)
# model = AutoModelForTokenClassification.from_pretrained(model_name, num_labels=len(label_map))  # label_map already includes "O"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class PatentNERDataset(Dataset):
    def __init__(self, index, texts, labels, tokenizer, max_length=512):
        self.index = index
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        labels = self.labels[idx]
        encoding = self.tokenizer(text, truncation=True, padding="max_length", max_length=self.max_length,
                                  return_tensors="pt")
        input_ids = encoding["input_ids"].squeeze()
        attention_mask = encoding["attention_mask"].squeeze()
        # Convert the per-character BIO tags to label ids. With a Chinese BERT vocab
        # each character maps to one token, so prepend -100 for the [CLS] position and
        # pad the tail with -100 so those positions are ignored by the loss.
        label_ids = [-100] + [label_map[label] for label in labels[:self.max_length - 2]]
        label_ids += [-100] * (self.max_length - len(label_ids))
        return {"input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels": torch.tensor(label_ids),
                "text": text,
                "index": self.index[idx]
                }

def parse_entity(text, predict):
    entities = []
    bios = []
    entity = []
    current_categ = None
    for word, tensor in zip(text, predict):
        # Fall back to "O" for ids outside the label map (e.g. -100 padding).
        label = label_index.get(tensor.item(), 'O')
        # Close the open entity on "O", on a fresh "B-" tag, or on a category change.
        if entity and (label == 'O' or label.startswith('B-') or label[2:] != current_categ):
            entities.append(''.join(entity))
            bios.append(current_categ)
            entity = []
            current_categ = None
        if label == 'O':
            continue
        current_categ = label[2:]
        entity.append(word)
    # Flush an entity that runs to the end of the text.
    if entity:
        entities.append(''.join(entity))
        bios.append(current_categ)
    return entities, bios
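
# Sketch of the expected decoding, reusing the hypothetical "BLD" tag from above
# (label ids shown symbolically; at runtime `predict` is a tensor of ids):
#   text    = "X大厦旁"
#   predict = [B-BLD, I-BLD, I-BLD, O]
#   parse_entity(text, predict) -> (["X大厦"], ["BLD"])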

def train(model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForTokenClassification.from_pretrained(model_name,
                                                            num_labels=len(label_map))  # label_map already includes "O"
    # tokenizer = AutoTokenizer.from_pretrained("building_ner_model_bert_wwm")
    # model = AutoModelForTokenClassification.from_pretrained("building_ner_model_bert_wwm",
    #                                                         num_labels=len(label_map))
    # Training data
    train_index, train_texts, train_labels = bio.get_bio()
    # Build the dataset
    dataset = PatentNERDataset(train_index, train_texts, train_labels, tokenizer)
    from torch.utils.data import DataLoader
    from torch.optim import AdamW
    # Training setup
    model.to(device)
    optimizer = AdamW(model.parameters(), lr=5e-5)
    train_loader = DataLoader(dataset, batch_size=2, shuffle=True)
    # Loss instance for computing each sample's loss separately.
    # reduction='none' would return one loss per element without aggregating;
    # reduction='mean' is used here to average over all tokens of a single sample.
    loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='mean')
    # Training loop
    for epoch in range(4):
        for batch in train_loader:
            optimizer.zero_grad()
            input_ids, attention_mask, labels = (batch["input_ids"].to(device),
                                                 batch["attention_mask"].to(device),
                                                 batch["labels"].to(device))
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            print(f"Epoch {epoch}, Loss: {loss.item()}")
            original_texts = batch["text"]
            original_bios = batch["labels"]
            # Model logits (prediction scores)
            logits = outputs.logits  # shape: (batch_size, sequence_length, num_labels)
            # --- Compute each sample's loss individually ---
            sample_losses = []
            for i in range(logits.size(0)):  # iterate over the samples in the batch
                # Extract this sample's logits and labels
                sample_logits = logits[i]  # shape: (sequence_length, num_labels)
                sample_labels = labels[i]  # shape: (sequence_length,)
                # Make sure valid labels exist, to avoid division by zero / NaN
                if (sample_labels != -100).sum() > 0:
                    # Average loss for this sample via loss_fct;
                    # CrossEntropyLoss expects (N, C) and (N,), here N=sequence_length, C=num_labels
                    individual_loss = loss_fct(sample_logits, sample_labels)
                    sample_losses.append((individual_loss.item(), i))
                else:
                    # No valid labels (all padding / special tokens): report zero loss
                    sample_losses.append((0.0, i))
            # --- Sort by loss value, descending ---
            sample_losses.sort(key=lambda x: x[0], reverse=True)
            # --- Print the highest-loss samples ---
            for individual_loss_value, index in sample_losses:
                # Optionally print only samples above a threshold, or just the top N
                # if individual_loss_value > 0.1:
                print(f" - Sample Index in Batch: {batch['index'][index]}")
                print(f"   Individual Avg Loss: {individual_loss_value:.4f}")
                print(f"   Original Text: {original_texts[index]}")
                # print(f"   Original BIO : {' '.join(original_bios[index][:20])}...")  # part of the BIO sequence
                # Optional: decode the predictions and compare, to help analyze errors
                with torch.no_grad():
                    # Skip position 0 ([CLS]) so the ids align with the characters
                    pred_labels_ids = torch.argmax(logits[index], dim=-1)[1:]
                    pred_labels, _ = parse_entity(original_texts[index], pred_labels_ids)
                    true_labels, _ = parse_entity(original_texts[index], labels[index][1:])
                    print(f"   Predicted entities: {' '.join(pred_labels)}")
                    print(f"   True entities     : {' '.join(true_labels)}")
    # Save the fine-tuned model
    # model.save_pretrained("building_ner_model")
    # tokenizer.save_pretrained("building_ner_model")
    # model.save_pretrained("building_ner_model_bert_wwm")
    # tokenizer.save_pretrained("building_ner_model_bert_wwm")
    model.save_pretrained(env.resolve_output(f"{model_name}-building"))
    tokenizer.save_pretrained(env.resolve_output(f"{model_name}-building"))
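
    # The saved directory is a standard Hugging Face checkpoint, so it can be
    # reloaded later with from_pretrained, e.g.:
    #   AutoModelForTokenClassification.from_pretrained(env.resolve_output(f"{model_name}-building"))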

def predict_test(model_name):
    # Load the fine-tuned model
    # model = AutoModelForTokenClassification.from_pretrained("building_ner_model")
    # tokenizer = AutoTokenizer.from_pretrained("building_ner_model")
    # model = AutoModelForTokenClassification.from_pretrained("building_ner_model_bert_wwm")
    # tokenizer = AutoTokenizer.from_pretrained("building_ner_model_bert_wwm")
    model = AutoModelForTokenClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model.to(device)  # inputs are moved to `device` below, so the model must be too
    model.eval()
    # id2label = model.config.id2label
    # Inference function
    def predict_distilbert(text):
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        predictions = torch.argmax(outputs.logits, dim=-1).squeeze().tolist()
        # labels = [model.config.id2label[pred.item()] for pred in predictions[0]]
        # Parse the recognition results
        # tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"].squeeze().tolist())
        tokens = list(text)
        results = []
        # predictions[0] belongs to [CLS]; skip it so labels align with the characters
        for token, pred in zip(tokens, predictions[1:]):
            # label = id2label[pred]
            label = label_index.get(pred, "O")
            # if label != "LABEL_0":
            if label != "O":  # filter out the non-entity "O" label
                results.append((label, token))
        # Print the results
        print(text)
        print("Recognized entities:")
        for entity_type, entity in results:
            if entity_type.startswith("B-"):
                print(f" {entity}", end='')
            else:
                print(f"{entity}", end='')
        print("\n")
        return predictions
    df = pl.read_csv(str(env.resolve_data(sample_file_name)), columns=columns, encoding="utf-8")
    for description in df['摘要']:  # '摘要' is the abstract column
        description = clean_raw_text(description)
        predict_distilbert(description)

if __name__ == '__main__':
    # train('hfl/chinese-roberta-wwm-ext-large-building')
    # train('hfl/chinese-roberta-wwm-ext-large')
    predict_test('hfl/chinese-roberta-wwm-ext-large-building-building')
    # test()