基于sd1.5的交通事故图lora模型
目录
软硬件依赖
硬件环境
- 配备RTX 3090 GPU,显存24GB
软件环境
- PyTorch 2.0.1, TensorBoard
- CUDA版本11.8
训练参数调节
底模选择
- 选择通用性强的底模stable diffusion v1.5。
参数设置
- train batch size=1
- epoch=30,repeat=10,steps=300
- text encoder learning rate=1.00E-4,Unet learning rate=1.00E-5
- network rank(dim)=32, network Alpha=32
- Lr scheduler:linear
- Lr optimizer:AdamW8bit
- saved precision:fp16
数据集
数据来源
- 交通行业发展报告、事故统计等文档资料。
- 来自世界卫生组织(WHO)、中华人民共和国公安部、中国国家统计局、美国国家公路交通安全管理局(NHSTA)等国际与国内政府部门、研究机构、高校及企业的相关数据
数据整合与对齐
- 分类汇总不同类型路段和碰撞起因。
- 利用Visio绘制事故路段碰撞示意图,构建图像-文本样本训练集。
模型训练
- 使用模型训练器对模型参数进行调节
- 训练过程使用tensorboard实时记录模型loss函数值
- 生成模型后使用检验训练集对模型泛华和拟合指标进行检验和对比
- 参考对比结果和loss函数下降值对模型参数进行调整
模型评估
在原有训练集基础上对输入prompt进行调整
- 路段类型(intersection/T-intersection/straight road)
- 事故车辆类型、数目
- 事故原因
- 车辆行走轨迹
训练模型生图结果对比
当参数为上述值时,可减少大量色块、多余元素出现的过拟合情况以及生成特征过于简单、特征和元素难以对应的不拟合情况
评论