from pathlib import Path

from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch
import polars as pl

import bio
from bio import entity_types
from data_preparation import demo_file_name, columns, desc_file_name

# Build the BIO label vocabulary: "O" plus a B-/I- pair for every entity type
# declared in bio.entity_types (e.g. Material, Component, Equipment, ...).
label_map = {"O": 0}
index = 1
for tag in set(entity_types.values()):
    label_map[f'B-{tag}'] = index
    index += 1
    label_map[f'I-{tag}'] = index
    index += 1
# Reverse lookup: label id -> BIO tag string, used at inference time.
label_index = {v: k for k, v in label_map.items()}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from torch.utils.data import Dataset


class PatentNERDataset(Dataset):
    """Token-classification dataset pairing raw text with per-character BIO labels.

    `texts` is a sequence of strings and `labels` the parallel sequence of
    per-character BIO tag lists (as produced by bio.get_bio()).
    """

    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        labels = self.labels[idx]
        encoding = self.tokenizer(text, truncation=True, padding="max_length",
                                  max_length=self.max_length, return_tensors="pt")
        input_ids = encoding["input_ids"].squeeze()
        attention_mask = encoding["attention_mask"].squeeze()
        # BUG FIX: the tokenizer prepends [CLS], so character labels must be
        # shifted right by one position or every prediction is off by one.
        # Special/padding positions get -100 so the loss ignores them
        # (previously they were labelled 0 == "O" and polluted training).
        # NOTE(review): assumes the tokenizer is character-level for Chinese
        # text (one token per character) -- confirm for the chosen model.
        label_ids = [-100] + [label_map[label] for label in labels]
        label_ids = label_ids[:self.max_length]          # guard against overlong texts
        label_ids += [-100] * (self.max_length - len(label_ids))
        return {"input_ids": input_ids, "attention_mask": attention_mask,
                "labels": torch.tensor(label_ids)}


def train(model_name):
    """Fine-tune `model_name` for token classification on the BIO data.

    Saves the resulting model and tokenizer to "<model_name>-building".
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # BUG FIX: label_map already contains "O", so the class count is simply
    # len(label_map); the old `len(set(label_map)) + 1` created an extra
    # unused class whose id had no entry in label_index (KeyError risk).
    model = AutoModelForTokenClassification.from_pretrained(
        model_name, num_labels=len(label_map))

    # Training data: parallel lists of texts and per-character BIO tags.
    train_texts, train_labels = bio.get_bio()
    dataset = PatentNERDataset(train_texts, train_labels, tokenizer)

    from torch.utils.data import DataLoader
    from torch.optim import AdamW

    model.to(device)
    model.train()
    optimizer = AdamW(model.parameters(), lr=5e-5)
    train_loader = DataLoader(dataset, batch_size=2, shuffle=True)

    # Plain training loop; prints the last batch's loss once per epoch.
    for epoch in range(4):
        for batch in train_loader:
            optimizer.zero_grad()
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch}, Loss: {loss.item()}")

    # Persist the fine-tuned model next to the base name.
    model.save_pretrained(f"{model_name}-building")
    tokenizer.save_pretrained(f"{model_name}-building")


def test(model_name):
    """Load a fine-tuned model and print predicted entities for every abstract."""
    model = AutoModelForTokenClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # BUG FIX: inputs are moved to `device` below, so the model must be too;
    # otherwise inference crashes on any machine where CUDA is available.
    model.to(device)
    model.eval()

    def predict_distilbert(text):
        # Predict a BIO tag per character of `text` and print the entities.
        inputs = tokenizer(text, return_tensors="pt",
                           padding=True, truncation=True).to(device)
        with torch.no_grad():  # inference only -- skip autograd bookkeeping
            outputs = model(**inputs)
        predictions = torch.argmax(outputs.logits, dim=-1).squeeze().tolist()
        tokens = list(text)
        # BUG FIX: position 0 of the model output is [CLS]; skip it (and any
        # trailing [SEP]/padding) so predictions line up with the characters.
        char_preds = predictions[1:1 + len(tokens)]
        results = []
        for token, pred in zip(tokens, char_preds):
            label = label_index[pred]
            if label != "O":  # drop non-entity "O" tags
                results.append((label, token))
        print(text)
        print("识别出的实体类型及词语:")
        for entity_type, entity in results:
            if entity_type.startswith("B-"):
                # Start of a new entity: separate it with a leading space.
                print(f" {entity}", end='')
            else:
                print(f"{entity}", end='')
        print("\n")
        return predictions

    df = pl.read_csv(str(Path(desc_file_name).expanduser()),
                     columns=columns, encoding="utf-8")
    for description in df['摘要']:
        # NOTE(review): the raw strings r'\r' / r'\n' / r'\t' match the literal
        # two-character sequences "\r", "\n", "\t" in the CSV text, NOT real
        # control characters -- confirm that is intended for this data.
        description = (description.replace(r'\r', ' ').replace(r'\n', ' ')
                       .replace(r'\t', '').replace(' ', ''))
        predict_distilbert(description)


if __name__ == '__main__':
    train('hfl/chinese-roberta-wwm-ext-large-building')
    # train('hfl/chinese-roberta-wwm-ext-large')
    # test('hfl/chinese-roberta-wwm-ext-large-building')