哪些网站有任务做,上海网站开发平台,网站首页大图的尺寸,网站引导页模板如何使用uer做多分类任务
语料集下载 找到这里点击即可 里面是这有json文件的 因此我们对此要做一些处理#xff0c;将其转为tsv格式
# -*- coding: utf-8 -*-
import json
import csv
import chardet# 检测文件编码
def detect_encoding(file_path):with open(file_path,…如何使用uer做多分类任务
语料集下载 找到这里点击即可 里面是这有json文件的 因此我们对此要做一些处理将其转为tsv格式
# -*- coding: utf-8 -*-
import json
import csv
import chardet# 检测文件编码
def detect_encoding(file_path):with open(file_path, rb) as f:raw_data f.read()return chardet.detect(raw_data)[encoding]# 输入文件名
input_file ./datasets/iflytek/train.json
# 输出文件名
output_file ./datasets/iflytek/train.tsv# 检测输入文件的编码格式
file_encoding detect_encoding(input_file)# 打开输入的 JSON 文件和输出的 TSV 文件
with open(input_file, r, encodingfile_encoding) as json_file, open(output_file, w, newline, encodingutf-8) as tsv_file:# 准备 TSV 写入器tsv_writer csv.writer(tsv_file, delimiter\t)# 写入表头列表[label, label_des, sentence]中要注意根据json文件中的键值做更换tsv_writer.writerow([label, label_des, sentence])# 逐行读取 JSON 文件for line in json_file:try:# 解析每一行的 JSON 数据json_data json.loads(line.strip())# 写入到 TSV 文件中列表[label, label_des, sentence]中要注意根据json文件中的键值做更换tsv_writer.writerow([json_data[label], json_data[label_des], json_data[sentence]])except json.JSONDecodeError as e:print(f无法解析的行: {line.strip()})print(f错误信息: {e})print(fJSON 文件已成功转换为 TSV 文件输入文件编码: {file_encoding})
接着呢要把所有tsv文件的sentence表头名改成text_a不然运行uer框架会报错原因请看源代码逻辑
def read_dataset(args, path):dataset, columns [], {}with open(path, moder, encodingutf-8) as f:for line_id, line in enumerate(f):if line_id 0:for i, column_name in enumerate(line.rstrip(\r\n).split(\t)):columns[column_name] icontinueline line.rstrip(\r\n).split(\t)tgt int(line[columns[label]])if args.soft_targets and logits in columns.keys():soft_tgt [float(value) for value in line[columns[logits]].split( )]if text_b not in columns: # Sentence classification.text_a line[columns[text_a]]src args.tokenizer.convert_tokens_to_ids([CLS_TOKEN] args.tokenizer.tokenize(text_a) [SEP_TOKEN])seg [1] * len(src)else: # Sentence-pair classification.text_a, text_b line[columns[text_a]], line[columns[text_b]]src_a args.tokenizer.convert_tokens_to_ids([CLS_TOKEN] args.tokenizer.tokenize(text_a) [SEP_TOKEN])src_b args.tokenizer.convert_tokens_to_ids(args.tokenizer.tokenize(text_b) [SEP_TOKEN])src src_a src_bseg [1] * len(src_a) [2] * len(src_b)if len(src) args.seq_length:src src[: args.seq_length]seg seg[: args.seq_length]if len(src) args.seq_length:PAD_ID args.tokenizer.convert_tokens_to_ids([PAD_TOKEN])[0]src [PAD_ID] * (args.seq_length - len(src))seg [0] * (args.seq_length - len(seg))if args.soft_targets and logits in columns.keys():dataset.append((src, tgt, seg, soft_tgt))else:dataset.append((src, tgt, seg))return dataset这里规定好了表头名只有labeltext_a,text_b 搞完之后进入训练代码我的显存只有16G因此
python finetune/run_classifier.py --pretrained_model_path models/cluecorpussmall_roberta_wwm_large_seq512_model.bin --vocab_path models/google_zh_vocab.txt --config_path models/bert/large_config.json --train_path datasets/iflytek/train.tsv --dev_path datasets/iflytek/dev.tsv --output_model_path models/iflytek_classifier_model.bin --epochs_num 3 --batch_size 16 --seq_length 128这里可以看到只有61.49的正确率其实是因为显存还不够训练不了那么大的标准的参数应该设置为batch_size32 seq_length256 有能力的可以更改参数进行训练 接着来预测
python inference/run_classifier_infer.py --load_model_path models/iflytek_classifier_model.bin --vocab_path models/google_zh_vocab.txt --config_path models/bert/large_config.json --test_path datasets/iflytek/test.tsv --prediction_path datasets/iflytek/prediction.tsv --seq_length 256 --labels_num 119最后自行查看预测效果