南宁建设信息网站,龙华区网站建设,北京电力建设公司官网,微信支付 公司网站前言
随着多模态大模型的发展#xff0c;其不仅限于文字处理#xff0c;更能够在图像、视频、音频方面进行识别与理解。医疗领域中#xff0c;医生们往往需要对各种医学图像进行处理#xff0c;以辅助诊断和治疗。如果将多模态大模型与图像诊断相结合#xff0c;那么这会…前言
随着多模态大模型的发展其不仅限于文字处理更能够在图像、视频、音频方面进行识别与理解。医疗领域中医生们往往需要对各种医学图像进行处理以辅助诊断和治疗。如果将多模态大模型与图像诊断相结合那么这会极大地提升诊断效率。
项目目标
训练一个医疗多模态大模型用于图像诊断。 刚好家里老爷子近期略感头疼去医院做了脑部CT诊断患有垂体瘤我将尝试使用多模态大模型进行进一步诊断。 实现过程
1. 数据集准备
为了训练模型需要准备大量的医学图像数据。通过搜索我们找到以下训练数据
数据名称MedTrinity-25M 数据地址https://github.com/UCSC-VLAA/MedTrinity-25M 数据简介MedTrinity-25M数据集是一个用于医学图像分析和计算机视觉研究的大型数据集。 数据来源该数据集由加州大学圣克鲁兹分校UCSC提供旨在促进医学图像处理和分析的研究。 数据量MedTrinity-25M包含约2500万条医学图像数据涵盖多种医学成像技术如CT、MRI和超声等。 数据内容 该数据集有两份分别是 25Mdemo 和 25Mfull 。
25Mdemo (约162,000条)数据集内容如下
25Mfull (约24,800,000条)数据集内容如下
2. 数据下载
2.1 安装Hugging Face的Datasets库
pip install datasets2.2 下载数据集
from datasets import load_dataset# 加载数据集
ds load_dataset(UCSC-VLAA/MedTrinity-25M, 25M_demo, cache_dircache)执行结果
说明
以上方法是使用HuggingFace的Datasets库下载数据集下载的路径为当前脚本所在路径下的cache文件夹。使用HuggingFace下载需要能够访问https://huggingface.co/ 并且在网站上申请数据集读取权限才可以。如果没有权限访问HuggingFace可以关注一起AI技术公众号后回复 “MedTrinity”获取百度网盘下载地址。
2.3 预览数据集
# 查看训练集的前1个样本
print(ds[train][:1]) 运行结果
{image: [PIL.JpegImagePlugin.JpegImageFile image modeRGB size512x512 at 0x15DD6D06530], id: [8031efe0-1b5c-11ef-8929-000066532cad], caption: [The image is a non-contrasted computed tomography (CT) scan of the brain, showing the cerebral structures without any medical devices present. The region of interest, located centrally and in the middle of the image, exhibits an area of altered density, which is indicative of a brain hemorrhage. This area is distinct from the surrounding brain tissue, suggesting a possible hematoma or bleeding within the brain parenchyma. The location and characteristics of this abnormality may suggest a relationship with the surrounding brain tissue, potentially causing a mass effect or contributing to increased intracranial pressure.]
}使用如下命令对数据集的图片进行可视化查看
# 可视化image内容
from PIL import Image
import matplotlib.pyplot as pltimage ds[train][0][image] # 获取第一张图像plt.imshow(image)
plt.axis(off) # 不显示坐标轴
plt.show()运行结果
3. 数据预处理
由于后续我们要通过LLama Factory进行多模态大模型微调所以我们需要对上述的数据集进行预处理以符合LLama Factory的要求。
3.1 LLama Factory数据格式
查看LLama Factory的多模态数据格式要求如下
[{messages: [{content: image他们是谁,role: user},{content: 他们是拜仁慕尼黑的凯恩和格雷茨卡。,role: assistant},{content: 他们在做什么,role: user},{content: 他们在足球场上庆祝。,role: assistant}],images: [mllm_demo_data/1.jpg]}
]3.2 实现数据格式转换脚本
from datasets import load_dataset
import os
import json
from PIL import Imagedef save_images_and_json(ds, output_dirmllm_data):将数据集中的图像和对应的 JSON 信息保存到指定目录。参数:ds: 数据集对象包含图像和标题。output_dir: 输出目录默认为 mllm_data。# 创建输出目录if not os.path.exists(output_dir):os.makedirs(output_dir)# 创建一个列表来存储所有的消息和图像信息all_data []# 遍历数据集中的每个项目for item in ds:img_path f{output_dir}/{item[id]}.jpg # 图像保存路径image item[image] # 假设这里是一个 PIL 图像对象# 将图像对象保存为文件image.save(img_path) # 使用 PIL 的 save 方法# 添加消息和图像信息到列表中all_data.append({messages: [{content: image图片中的诊断结果是怎样?,role: user,},{content: item[caption], # 从数据集中获取的标题role: assistant,},],images: [img_path], # 图像文件路径})# 创建 JSON 文件json_file_path f{output_dir}/mllm_data.jsonwith open(json_file_path, w, encodingutf-8) as f:json.dump(all_data, f, ensure_asciiFalse) # 确保中文字符正常显示if __name__ __main__:# 加载数据集ds load_dataset(UCSC-VLAA/MedTrinity-25M, 25M_demo, cache_dircache)# 保存数据集中的图像和 JSON 信息save_images_and_json(ds[train])运行结果
4. 模型下载
本次微调我们使用阿里最新发布的多模态大模型Qwen2-VL-2B-Instruct 作为底座模型。 模型说明地址https://modelscope.cn/models/Qwen/Qwen2-VL-2B-Instruct
使用如下命令下载模型
git lfs install
# 下载模型
git clone https://www.modelscope.cn/Qwen/Qwen2-VL-2B-Instruct.git5. 环境准备
5.1 机器环境
硬件
显卡4080 Super显存16GB
软件
系统Ubuntu 20.04 LTSpython3.10pytorch2.1.2 cuda12.1
5.2 准备虚拟环境
# 创建python3.10版本虚拟环境
conda create --name train_env python3.10# 激活环境
conda activate train_env# 安装依赖包
pip install streamlit torch torchvision# 安装Qwen2建议的transformers版本
pip install githttps://github.com/huggingface/transformers6. 准备训练框架
下载并安装LLamaFactory框架的具体步骤请见【课程总结】day24上大模型三阶段训练方法(LLaMa Factory)中 准备训练框架 部分内容本章不再赘述。
6.1 修改LLaMaFactory源码以适配transformer
由于Qwen2-VL使用的transformer的版本为4.47.0.dev0,LLamaFactory还不支持所以需要修改LLaMaFactory的代码具体方法如下
第一步在 llamafactory 源码中找到 check_dependencies() 函数这个函数位于 src/llamafactory/extras/misc.py 文件的第 82 行。
第二步修改 check_dependencies() 函数并保存
# 原始代码
require_version(transformers4.41.2,4.45.2, To fix: pip install transformers4.41.2,4.45.2)
# 修改后代码
require_version(transformers4.41.2,4.47.0, To fix: pip install transformers4.41.2,4.47.0)第三步重新启动LLaMaFactory服务
llamafactory-cli webui这个过程可能会提示 ImportError: accelerate0.34.0 is required for a normal functioning of this module, but found accelerate0.32.0. 如遇到上述问题可以重新安装accelerate如下 # 卸载旧的 accelerate
pip uninstall accelerate# 安装新的 accelerate
pip install accelerate0.34.07. 测试当前模型
第一步启动LLaMa Factory后访问http://0.0.0.0:7860
第二步在web页面配置模型路径为 4.步骤 下载的模型路径并点击加载模型
第三步上传一张CT图片并输入问题:“请使用中文描述下这个图像并给出你的诊断结果”
由上图可以看到模型能够识别到这是一个CT图像显示了大概的位置以及相应的器官但是并不能给出是否存在诊断结果。
8. 模型训练
8.1 数据准备
第一步将 3.2步骤 生成的mllm_data文件拷贝到LLaMaFactory的data目录下
第二步将 4.步骤 下载的底座模型Qwen2-VL 拷贝到LLaMaFactory的model目录下
第三步修改 LLaMaFactory data目录下的dataset_info.json增加自定义数据集 mllm_med: {file_name: mllm_data/mllm_data.json,formatting: sharegpt,columns: {messages: messages,images: images},tags: {role_tag: role,content_tag: content,user_tag: user,assistant_tag: assistant}},8.2 配置训练参数
访问LLaMaFactory的web页面配置微调的训练参数
Model name: Qwen2-VL-2B-InstructModel path: models/Qwen2-VL-2B-InstructFinetuning method: loraStage : Supervised Fine-TuningDataset: mllm_medOutput dir: saves/Qwen2-VL/lora/Qwen2-VL-sft-demo1 配置参数中最好将 save_steps 设置大一点否则训练过程会生成非常多的训练日志导致硬盘空间不足而训练终止。 点击Preview Command预览命令行无误后点击Run按钮开始训练。 训练参数
llamafactory-cli train \--do_train True \--model_name_or_path models/Qwen2-VL-2B-Instruct \--preprocessing_num_workers 16 \--finetuning_type lora \--template qwen2_vl \--flash_attn auto \--dataset_dir data \--dataset mllm_med \--cutoff_len 1024 \--learning_rate 5e-05 \--num_train_epochs 3.0 \--max_samples 100000 \--per_device_train_batch_size 2 \--gradient_accumulation_steps 8 \--lr_scheduler_type cosine \--max_grad_norm 1.0 \--logging_steps 5 \--save_steps 3000 \--warmup_steps 0 \--optim adamw_torch \--packing False \--report_to none \--output_dir saves/Qwen2-VL-2B/full/Qwen2-VL-sft-demo1 \--bf16 True \--plot_loss True \--ddp_timeout 180000000 \--include_num_input_tokens_seen True \--lora_rank 8 \--lora_alpha 16 \--lora_dropout 0 \--lora_target all训练过程 训练的过程中可以通过 watch -n 1 nvidia-smi 实时查看GPU显存的消耗情况。 经过35小时的训练模型训练完成损失函数如下 损失函数一般降低至1.2左右太低会导致模型过拟合。 8.3 合并导出模型
接下来我们将 Lora补丁 与 原始模型 合并导出
切换到 Expert 标签下Model path: 选择Qwen2-VL的基座模型即models/Qwen2-VL-2B-InstructCheckpoint path: 选择lora微调的输出路径即 saves/Qwen2-VL/lora/Qwen2-VL-sft-demo1Export path设置一个新的路径例如Qwen2-VL-sft-final点击 开始导出 按钮 导出完毕后会在LLaMaFactory的根目录下生成一个 Qwen2-VL-sft-final 的文件夹。
9. 模型验证
9.1 模型效果对比
第一步在LLaMa Factory中卸载之前的模型
第二步在LLaMa Factory中加载导出的模型并配置模型路径为 Qwen2-VL-sft-final
第三步加载模型并上传之前的CT图片提问同样的问题 可以看到经过微调后的模型可以给出具体区域存在的可能异常问题。
9.2 实际诊断
接下来我将使用微调后的模型为家里老爷子的CT片做诊断看看模型给出的诊断与大夫的异同点。 我总计测试了CT片上的52张局部结果其中具有代表性的为上述三张可以看到模型还是比较准确地诊断出脑部有垂体瘤可能会影响到眼部。这与大夫给出的诊断和后续检查方案一致。
不足之处
训练集
多模态本次训练只是采用了MedTrinity-25Mdemo数据集如果使用MedTrinity-25Mfull数据集效果应该会更好。中英文本次训练集中使用的MedTrinity-25Mdemo数据集只包含了英文数据如果将英文标注翻译为中文提供中英文双文数据集相信效果会更好。对话数据集本次训练只是使用了多模态数据集如果增加中文对话(如中文医疗对话数据-Chinese-medical-dialogue)相信效果会更好。
前端页面
前端页面本次实践曾使用streamlit构建前端页面以便图片上传和问题提出但是在加载微调后的模型时会出现ValueError: No chat template is set for this processor 问题所以转而使用LLaMaFactory的web页面进行展示。多个图片推理在Qwen2-VL的官方指导文档中提供了 Multi image inference 方法本次未进行尝试相信将多个图片交给大模型进行推理效果会更好。
内容小结
Qwen2-VL-2B作为多模态大模型具备有非常强的多模态处理能力除了能够识别图片内容还可以进行相关的推理。我们可以通过 LLaMaFactory 对模型进行微调使得其具备医疗方面的处理能力。微调数据集采用开源的MedTrinity-25M数据集该数据集有两个版本25Mdemo和25Mfull。训练前需要对数据集进行预处理使得其适配LLaMaFactory的微调格式。经过微调后的多模态大模型不但可以详细地描述图片中的内容还可以给出可能的诊断结果。