from pathlib import Path

import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModelForTokenClassification

import bio  # local module that supplies the BIO-tagged training corpus

# Entity types to extract from patent text
entity_types = ["Material", "Component", "Equipment", "Process", "Organization", "Standard"]

# BIO tagging needs a B- and an I- tag per entity type, plus "O"
label_list = ["O"] + [f"{prefix}-{etype}" for etype in entity_types for prefix in ("B", "I")]
label_map = {label: i for i, label in enumerate(label_list)}
id2label = {i: label for label, i in label_map.items()}

model_name = "distilbert-base-multilingual-cased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(
    model_name,
    num_labels=len(label_list),
    id2label=id2label,
    label2id=label_map,
)

# Example training data: one character-level BIO label per character.
# "环氧树脂" (epoxy resin) is a Material; "制备方法" (preparation method) is a Process.
train_texts = ["本发明涉及一种环氧树脂制备方法。"]
train_labels = [[
    "O", "O", "O", "O", "O", "O", "O",
    "B-Material", "I-Material", "I-Material", "I-Material",
    "B-Process", "I-Process", "I-Process", "I-Process",
    "O",
]]


class PatentNERDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        labels = self.labels[idx]
        encoding = self.tokenizer(text, truncation=True, padding="max_length",
                                  max_length=self.max_length, return_tensors="pt")
        input_ids = encoding["input_ids"].squeeze()
        attention_mask = encoding["attention_mask"].squeeze()
        # Convert labels. This assumes one token per character (roughly true for
        # Chinese with this tokenizer). [CLS], [SEP], and padding positions get
        # -100 so the loss ignores them.
        label_ids = [-100] + [label_map[label] for label in labels]
        label_ids += [-100] * (self.max_length - len(label_ids))
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": torch.tensor(label_ids[:self.max_length]),
        }


# Replace the toy example with the full BIO-tagged corpus
# (assumed to use the same tag scheme)
train_texts, train_labels = bio.get_bio()

# Build the dataset
dataset = PatentNERDataset(train_texts, train_labels, tokenizer)

# Training setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = AdamW(model.parameters(), lr=5e-5)
train_loader = DataLoader(dataset, batch_size=2, shuffle=True)

# Training loop
model.train()
for epoch in range(3):
    for batch in train_loader:
        optimizer.zero_grad()
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch}, Loss: {loss.item()}")

# Save the fine-tuned model
model.save_pretrained("ner_model")
tokenizer.save_pretrained("ner_model")

# Reload the fine-tuned model for inference
model = AutoModelForTokenClassification.from_pretrained("ner_model")
tokenizer = AutoTokenizer.from_pretrained("ner_model")
model.eval()


# Inference function: returns the decoded tag sequence rather than raw indices
def predict_distilbert(text):
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)
    predictions = torch.argmax(outputs.logits, dim=-1)
    labels = [model.config.id2label[pred.item()] for pred in predictions[0]]
    return labels
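# PatentNERDataset above labels tokens by character position, which only holds
# while the tokenizer emits one token per Chinese character. A hedged sketch of
# a more robust alternative (not part of the original pipeline): align labels
# through the fast tokenizer's word_ids(). `align_labels` and `char_labels`
# are illustrative names introduced here.
def align_labels(text, char_labels, max_length=512):
    encoding = tokenizer(list(text), is_split_into_words=True, truncation=True,
                         padding="max_length", max_length=max_length,
                         return_tensors="pt")
    word_ids = encoding.word_ids(batch_index=0)  # token index -> character index
    label_ids = []
    previous_word = None
    for word_id in word_ids:
        if word_id is None:
            label_ids.append(-100)  # [CLS]/[SEP]/padding: ignored by the loss
        elif word_id != previous_word:
            label_ids.append(label_map[char_labels[word_id]])
        else:
            label_ids.append(-100)  # continuation subword: label only the first piece
        previous_word = word_id
    return encoding, torch.tensor(label_ids)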
import jieba

# Load the domain dictionary once at module level rather than on every call
jieba.load_userdict("word_dict.txt")


def predict_entities(text):
    words = list(jieba.cut(text))
    inputs = tokenizer(words, is_split_into_words=True, return_tensors="pt",
                       truncation=True)

    # Predict
    with torch.no_grad():
        outputs = model(**inputs)
    predictions = outputs.logits.argmax(dim=-1)

    # Decode
    labels = [model.config.id2label[pred.item()] for pred in predictions[0]]
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

    # Filter out special tokens, keeping labels aligned with the surviving tokens
    filtered = [(token, label) for token, label in zip(tokens, labels)
                if token not in ("[CLS]", "[SEP]")]

    # Merge BIO-tagged tokens into entities
    entities = []
    current_entity = None
    for token, label in filtered:
        if label.startswith("B-"):
            if current_entity:
                entities.append(current_entity)
            current_entity = {"type": label[2:], "tokens": [token]}
        elif label.startswith("I-") and current_entity and current_entity["type"] == label[2:]:
            current_entity["tokens"].append(token)
        else:
            if current_entity:
                entities.append(current_entity)
            current_entity = None
    if current_entity:
        entities.append(current_entity)

    # Print the results (WordPiece continuation markers "##" may remain on
    # non-CJK subwords)
    for entity in entities:
        print(f"Entity: {''.join(entity['tokens'])}, Type: {entity['type']}")
    return entities


import polars as pl

from data_preparation import demo_file_name, columns

df = pl.read_csv(str(Path(demo_file_name).expanduser()), columns=columns,
                 encoding="utf8")

# Test inference on the patent abstracts
for desc in df["摘要"]:
    output = predict_entities(desc)
    print("Predicted entities:", output)
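# A possible follow-on step (a sketch, not from the original): gather the
# extracted entities from every abstract into a polars DataFrame so they can
# be aggregated or exported. The column names and output path below are
# illustrative assumptions.
rows = []
for desc in df["摘要"]:
    for entity in predict_entities(desc):
        rows.append({
            "abstract": desc,
            "entity": "".join(entity["tokens"]),
            "type": entity["type"],
        })

entity_df = pl.DataFrame(rows)
print(entity_df.group_by("type").len())  # entity counts per type
entity_df.write_csv("patent_entities.csv")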