# distilbert.py

from pathlib import Path

import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModelForTokenClassification
import bio  # local module providing the BIO-tagged training corpus

# BIO scheme used by the annotations; the wider target entity types
# (Material, Component, Equipment, Process, Organization, Standard) would
# each need their own B-/I- pair, but the corpus only tags MAT and TECH.
label_list = ["O", "B-MAT", "I-MAT", "B-TECH", "I-TECH"]
label_map = {label: i for i, label in enumerate(label_list)}
model_name = "distilbert-base-multilingual-cased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(
    model_name, num_labels=len(label_list),
    id2label=dict(enumerate(label_list)), label2id=label_map)

# Sample training data: one BIO tag per character
# ("环氧树脂" = epoxy resin -> MAT, "制备方法" = preparation method -> TECH)
train_texts = ["本发明涉及一种环氧树脂制备方法。"]
train_labels = [["O"] * 7 + ["B-MAT", "I-MAT", "I-MAT", "I-MAT",
                             "B-TECH", "I-TECH", "I-TECH", "I-TECH", "O"]]
class PatentNERDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        labels = self.labels[idx]
        encoding = self.tokenizer(text, truncation=True, padding="max_length",
                                  max_length=self.max_length, return_tensors="pt")
        input_ids = encoding["input_ids"].squeeze()
        attention_mask = encoding["attention_mask"].squeeze()
        # Convert tags to ids, offset by one for [CLS]; special tokens and
        # padding get -100 so the loss ignores them. Assumes the tokenizer
        # yields one token per Chinese character, which usually holds here.
        label_ids = [label_map[label] for label in labels[: self.max_length - 2]]
        label_ids = [-100] + label_ids + [-100] * (self.max_length - len(label_ids) - 1)
        return {"input_ids": input_ids, "attention_mask": attention_mask,
                "labels": torch.tensor(label_ids)}
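
# A more robust alternative (sketch only; align_labels_with_tokens is a
# hypothetical helper, not used below): let the fast tokenizer report which
# character each sub-token comes from via word_ids(), instead of assuming a
# strict one-to-one character/token mapping.
def align_labels_with_tokens(text, labels, tokenizer, max_length=512):
    encoding = tokenizer(list(text), is_split_into_words=True, truncation=True,
                         padding="max_length", max_length=max_length)
    # word_ids() maps each token to its source character index
    # (None for [CLS]/[SEP]/padding)
    label_ids = [-100 if wid is None else label_map[labels[wid]]
                 for wid in encoding.word_ids()]
    return encoding, label_ids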

# Swap in the full BIO-tagged corpus and build the dataset
train_texts, train_labels = bio.get_bio()
dataset = PatentNERDataset(train_texts, train_labels, tokenizer)
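
# Quick shape check on one example: both tensors should have max_length entries
sample = dataset[0]
print(sample["input_ids"].shape, sample["labels"].shape)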

# Training setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = AdamW(model.parameters(), lr=5e-5)
train_loader = DataLoader(dataset, batch_size=2, shuffle=True)

# Training loop
model.train()
for epoch in range(3):
    for batch in train_loader:
        optimizer.zero_grad()
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch}, Loss: {loss.item()}")
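
# Minimal evaluation sketch (assumption: no held-out split is available here,
# so this reuses the training loader; in practice use a validation set)
model.eval()
correct = total = 0
with torch.no_grad():
    for batch in train_loader:
        logits = model(batch["input_ids"].to(device),
                       attention_mask=batch["attention_mask"].to(device)).logits
        preds = logits.argmax(dim=-1).cpu()
        mask = batch["labels"] != -100  # skip special tokens and padding
        correct += (preds[mask] == batch["labels"][mask]).sum().item()
        total += mask.sum().item()
print(f"Token-level accuracy: {correct / total:.3f}")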

# Save the fine-tuned model
model.save_pretrained("ner_model")
tokenizer.save_pretrained("ner_model")

# Reload the fine-tuned model for inference
model = AutoModelForTokenClassification.from_pretrained("ner_model")
tokenizer = AutoTokenizer.from_pretrained("ner_model")
model.eval()

# Inference: returns one BIO tag per token (special tokens included)
def predict_distilbert(text):
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)
    predictions = torch.argmax(outputs.logits, dim=-1)
    return [model.config.id2label[pred.item()] for pred in predictions[0]]
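
# Example: tag the sample sentence with the reloaded model
print(predict_distilbert("本发明涉及一种环氧树脂制备方法。"))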

import jieba

# Load the domain dictionary once (word_dict.txt is expected in the working
# directory) so jieba segments patent terms as single words
jieba.load_userdict("word_dict.txt")

def predict_entities(text):
    words = list(jieba.cut(text))
    inputs = tokenizer(words, is_split_into_words=True, return_tensors="pt",
                       truncation=True)
    # Predict
    with torch.no_grad():
        outputs = model(**inputs)
    predictions = outputs.logits.argmax(dim=-1)
    # Decode label ids and tokens
    labels = [model.config.id2label[pred.item()] for pred in predictions[0]]
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    # Filter out special tokens
    filtered_tokens = [token for token in tokens if token not in ["[CLS]", "[SEP]"]]
    filtered_labels = [label for label, token in zip(labels, tokens)
                       if token not in ["[CLS]", "[SEP]"]]
    # Merge consecutive B-/I- tags into entities
    entities = []
    current_entity = None
    for token, label in zip(filtered_tokens, filtered_labels):
        if label.startswith("B-"):
            if current_entity:
                entities.append(current_entity)
            current_entity = {"type": label[2:], "tokens": [token]}
        elif label.startswith("I-") and current_entity and current_entity["type"] == label[2:]:
            current_entity["tokens"].append(token)
        else:
            if current_entity:
                entities.append(current_entity)
            current_entity = None
    if current_entity:
        entities.append(current_entity)
    # Print the result and hand the entities back to the caller
    for entity in entities:
        print(f"Entity: {' '.join(entity['tokens'])}, Type: {entity['type']}")
    return entities
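
# The multilingual tokenizer may split jieba words into WordPiece pieces
# prefixed with "##"; join_wordpieces is a hypothetical helper for rejoining
# an entity's tokens into a readable string
def join_wordpieces(tokens):
    return "".join(t[2:] if t.startswith("##") else t for t in tokens)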

import polars as pl
from data_preparation import demo_file_name, columns

df = pl.read_csv(str(Path(demo_file_name).expanduser()), columns=columns, encoding="utf8")
# Test inference on every patent abstract (the "摘要" column)
for desc in df["摘要"]:
    entities = predict_entities(desc)
    print("Predicted entities:", entities)