CoT Training with MLX: LoRA SFT Fine-Tuning

A training example is available on GitHub: https://github.com/jbarnes850/deepseek-r1-finetune

Training data: https://hf-mirror.com/datasets/FreedomIntelligence/medical-o1-reasoning-SFT

Raw record format:

{ "Question": "the patient's specific medical question", "Complex_CoT": "the detailed step-by-step reasoning", "Response": "the final answer" }

Format after processing:

Please reason step by step:

Question: {the sample's Question field}

Let's solve this step by step:
{the sample's Complex_CoT field}

Final Answer: {the sample's Response field}

An example:

{
  "Question": "A 45-year-old patient presents with sudden onset chest pain, shortness of breath, and anxiety. The pain is described as sharp and worsens with deep breathing. What is the most likely diagnosis and what immediate tests should be ordered?",
  "Complex_CoT": "The patient's symptoms suggest possible acute coronary syndrome, pulmonary embolism, or pneumothorax. Given the sharp chest pain worsened by deep breathing, pulmonary embolism is a strong consideration. Immediate tests should include ECG, troponin, D-dimer, and chest X-ray.",
  "Response": "The most likely diagnosis is pulmonary embolism. Immediate tests should include ECG, troponin, D-dimer, and chest X-ray."
}

# After processing
Please reason step by step:

Question: A 45-year-old patient presents with sudden onset chest pain, shortness of breath, and anxiety. The pain is described as sharp and worsens with deep breathing. What is the most likely diagnosis and what immediate tests should be ordered?

Let's solve this step by step:
The patient's symptoms suggest possible acute coronary syndrome, pulmonary embolism, or pneumothorax. Given the sharp chest pain worsened by deep breathing, pulmonary embolism is a strong consideration. Immediate tests should include ECG, troponin, D-dimer, and chest X-ray.

Final Answer: The most likely diagnosis is pulmonary embolism. Immediate tests should include ECG, troponin, D-dimer, and chest X-ray.

The processed data is a Hugging Face Dataset object.
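After the split and formatting done by the script below, its internal structure looks roughly like this (row counts depend on the split sizes and are elided here):

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: ...
    })
    test: Dataset({
        features: ['text'],
        num_rows: ...
    })
})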

Exporting it gives a plain-text LoRA JSONL file, with each record reduced to a single text field.

For example:

{
"text": "Please reason step by step:\n\nQuestion: {Question}\n\nLet's solve this step by step:\n{Complex_CoT}\n\nFinal Answer: {Response}"
}

The file is plain text, one record per line.
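Each line parses independently as one JSON object. A quick sanity check, assuming the file names used by the export script below:

import json

# read the first record back and confirm the "text" field survived intact
with open("medical_train.jsonl", encoding="utf-8") as f:
    first = json.loads(next(f))
print(first["text"][:120])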

Related: MLX data formats, https://el.psy.congroo.com/wp-admin/post.php?post=983

Code to convert the SFT data above to JSONL (untested):

import os

from datasets import load_dataset


def prepare_dataset(tokenizer):
    """Prepare the medical reasoning dataset and export to JSONL."""
    # Load the raw dataset (English config)
    dataset = load_dataset("FreedomIntelligence/medical-o1-reasoning-SFT", "en")
    
    # Split dataset (5% for training, 1% for testing)
    dataset = dataset["train"].train_test_split(
        train_size=0.05, 
        test_size=0.01, 
        seed=42
    )

    # Define formatting function
    def format_instruction(sample):
        return f"""Please reason step by step:

Question: {sample['Question']}

Let's solve this step by step:
{sample['Complex_CoT']}

Final Answer: {sample['Response']}"""

    # Create formatted text datasets
    text_train = dataset["train"].map(
        lambda x: {"text": format_instruction(x)},
        remove_columns=dataset["train"].column_names,
        num_proc=os.cpu_count()
    )
    
    text_test = dataset["test"].map(
        lambda x: {"text": format_instruction(x)},
        remove_columns=dataset["test"].column_names,
        num_proc=os.cpu_count()
    )

    # Export to JSONL (the key new step)
    text_train.to_json(
        "medical_train.jsonl",
        orient="records",
        lines=True,
        force_ascii=False  # keep non-ASCII characters (e.g. Chinese) as-is
    )
    
    text_test.to_json(
        "medical_test.jsonl",
        orient="records",
        lines=True,
        force_ascii=False
    )

    # Tokenization (kept from the original pipeline)
    train_dataset = text_train.map(
        lambda x: tokenizer(
            x["text"],
            truncation=True,
            padding="max_length",
            max_length=1024,
            return_tensors=None,
        ),
        remove_columns=["text"],
        num_proc=os.cpu_count()
    )

    print("\nJSONL 文件已生成:")
    print(f"- medical_train.jsonl ({len(text_train)} 个样本)")
    print(f"- medical_test.jsonl ({len(text_test)} 个样本)")
    
    return train_dataset
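A minimal usage sketch, assuming a Hugging Face tokenizer; the model name below is only an example:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # padding="max_length" requires a pad token
train_dataset = prepare_dataset(tokenizer)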

MLX CLI training command for SFT, with the supervision function specified explicitly:

mlx-cli train \
    --stage sft \
    --do_train \
    --model_name_or_path /path/to/pretrained/model \
    --dataset your_dataset_name \
    --finetuning_type lora \
    --output_dir ./output \
    --learning_rate 5e-5 \
    --num_train_epochs 3 \
    --per_device_train_batch_size 8 \
    --loss_function cross_entropy

Here --stage sft selects supervised fine-tuning, --do_train starts training, --finetuning_type lora enables LoRA, and the remaining flags set the pretrained model path, the SFT dataset name or path, the output directory, and the usual hyperparameters (learning rate 5e-5, 3 epochs, per-device batch size 8).

In this command, the --loss_function flag specifies the supervision function, ensuring the training process is explicitly supervised.
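For reference, cross-entropy here is the standard next-token objective of SFT. A minimal PyTorch sketch (the function name and tensor shapes are my own, for illustration only):

import torch.nn.functional as F

def sft_cross_entropy(logits, labels, ignore_index=-100):
    """Next-token cross-entropy: the logit at position t predicts token t+1."""
    shift_logits = logits[:, :-1, :].contiguous()  # (batch, seq_len - 1, vocab)
    shift_labels = labels[:, 1:].contiguous()      # (batch, seq_len - 1)
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=ignore_index,  # padded positions contribute nothing to the loss
    )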

The earlier quick MLX LoRA fine-tune

LoRA fine-tuning also works directly on the SFT data, but without an explicit supervision function:

mlx_lm.lora --model ../../qwen2.5-0.5B --train --data ./data
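mlx_lm.lora expects the --data directory to contain train.jsonl and valid.jsonl. A small sketch to lay out the files exported above, reusing the test split as the validation set (the file placement is my assumption, not part of the original script):

import os
import shutil

os.makedirs("data", exist_ok=True)
shutil.copy("medical_train.jsonl", "data/train.jsonl")
# mlx_lm.lora looks for valid.jsonl for evaluation during training
shutil.copy("medical_test.jsonl", "data/valid.jsonl")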
