
LLaVA Mistral Multiple Images SFT

LLaVA, proposed in April 2023, is a chatbot model for multimodal scenarios that supports multi-turn question answering over images and text. It achieves this simply by passing the 1024-dimensional CLIP visual features through a projector and concatenating the result with the language model's token embeddings.

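To make the idea concrete, here is a minimal sketch (not LLaVA's actual mm_projector code; the dimensions and patch counts are illustrative) of projecting CLIP features into the LLM embedding space and concatenating them with text embeddings:

import torch
import torch.nn as nn

class SimpleProjector(nn.Module):
    """Illustrative projector: CLIP feature dim -> LLM embedding dim."""
    def __init__(self, clip_dim: int = 1024, llm_dim: int = 4096):
        super().__init__()
        self.proj = nn.Linear(clip_dim, llm_dim)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        return self.proj(image_features)

# Toy shapes: 576 visual patches from the vision tower, 32 text tokens.
image_features = torch.randn(1, 576, 1024)   # CLIP vision-tower output
text_embeds = torch.randn(1, 32, 4096)       # LLM token embeddings
inputs_embeds = torch.cat([SimpleProjector()(image_features), text_embeds], dim=1)
print(inputs_embeds.shape)                   # torch.Size([1, 608, 4096])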

However, in the original paper the model is trained on single-image QA only. If you want to build a task that takes multiple input images, how should the architecture be adapted and the training data constructed? Let's walk through it.

Code Structure

Launch command

bash llava/scripts/v1_5/finetune.sh

Training entry point

llava/train/train.py

Training framework

  1. Training is driven by the Hugging Face Trainer, a trainer tailored to Transformer architectures. Looking into the script, you can see the author uses DeepSpeed underneath; we won't go into detail here, but a minimal sketch of the setup follows.
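As a rough illustration (the argument values and the DeepSpeed config path below are assumptions, not the repo's actual defaults), the setup boils down to something like:

from transformers import Trainer, TrainingArguments

def build_trainer(model, train_dataset, data_collator, tokenizer):
    """Hypothetical minimal setup; the real train.py adds LLaVA-specific model/data
    arguments, LoRA options, and passes a DeepSpeed JSON config via `deepspeed=...`."""
    args = TrainingArguments(
        output_dir="./checkpoints/llava-multi-image",  # illustrative path
        per_device_train_batch_size=16,
        num_train_epochs=1,
        bf16=True,
        deepspeed="scripts/zero3.json",                # assumed config path
    )
    return Trainer(model=model, args=args, train_dataset=train_dataset,
                   data_collator=data_collator, tokenizer=tokenizer)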

Overall fine-tuning workflow

(Figure: overall LLaVA fine-tuning workflow)

Key Code

Multi-turn conversation preprocessing

Image tokens are inserted between the tokenized chunks of the prompt text. Here is a breakdown of how this works:

  1. The prompt is split on the <image> token, producing a list of chunks.
  2. Each chunk is tokenized with the provided tokenizer, yielding a list of token IDs.
  3. The insert_separator function interleaves the tokenized chunks with the image_token_index (the token that stands in for an image).
  4. The input_ids list is built as follows:
    • If the first tokenized chunk starts with a beginning-of-sequence (BOS) token, the first element of input_ids is set to that BOS token.
    • The remaining elements of input_ids are filled by iterating over the interleaved list of tokenized chunks and image_token_index values.

The resulting input_ids list therefore has the following structure:

[BOS_token (if present), tokens_from_chunk1, image_token, tokens_from_chunk2, image_token, ..., tokens_from_last_chunk]

An image_token_index is inserted between each pair of consecutive chunks of the original prompt.

For example, if the prompt is "This is <image> a sample <image> prompt" and image_token_index is 1234, the resulting input_ids list might look like:

[101, 1010, 2003, 1015, 1234, 2034, 3076, 1234, 2001, 1028, 102]

Here the other token IDs represent the tokenized words, and the value 1234 is the image_token_index inserted between the chunks.
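This logic lives in tokenizer_image_token (llava/mm_utils.py). Below is a simplified, self-contained sketch of the split-and-interleave step (the toy tokenizer and the function itself are illustrations, not the repo's exact implementation, which additionally carries the BOS token through an offset):

from typing import Callable, List

IMAGE_TOKEN_INDEX = -200  # LLaVA's placeholder id for <image>

def interleave_image_tokens(prompt: str,
                            tokenize: Callable[[str], List[int]],
                            image_token_index: int = IMAGE_TOKEN_INDEX) -> List[int]:
    """Split on <image>, tokenize each chunk, and insert the image id between chunks."""
    chunks = [tokenize(chunk) for chunk in prompt.split('<image>')]
    input_ids: List[int] = []
    for i, chunk in enumerate(chunks):
        if i > 0:
            input_ids.append(image_token_index)  # one image id between consecutive chunks
        input_ids.extend(chunk)
    return input_ids

# Toy word-level "tokenizer" for demonstration only.
toy_tokenize = lambda s: [abs(hash(w)) % 1000 for w in s.split()]
print(interleave_image_tokens("This is <image> a sample <image> prompt", toy_tokenize))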

Overview of the changes

To adapt the model for multi-image training, first decide whether your task needs images and text interleaved, or the images kept separate (all up front).

  1. Data preprocessing: make sure the image tokens in the input conversation are replaced correctly;
  2. Model forward: make sure the input embeddings are concatenated in the expected order (a sketch of the two layouts follows the note below).

Note that LLaVA's own SFT puts all image embeddings at the very front (this is done in the conversation preprocessing), so switching to an interleaved layout may shift the input distribution away from what the released SFT checkpoint was aligned to.
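To illustrate point 2 above, here is a schematic sketch (not LLaVA's actual prepare_inputs_labels_for_multimodal) contrasting the two layouts: splicing each image's embeddings in at its <image> position versus prepending all image embeddings:

import torch

def assemble_interleaved(text_embeds, image_embeds, image_positions):
    """Splice each image's patch embeddings into the text sequence at its <image> slot.
    text_embeds: (T, D); image_embeds: list of (P, D); image_positions: sorted indices."""
    pieces, prev = [], 0
    for pos, img in zip(image_positions, image_embeds):
        pieces.append(text_embeds[prev:pos])  # text before this image token
        pieces.append(img)                    # the image's patch embeddings
        prev = pos + 1                        # skip the placeholder token itself
    pieces.append(text_embeds[prev:])
    return torch.cat(pieces, dim=0)

def assemble_prepended(text_embeds, image_embeds, image_positions):
    """Single-image LLaVA layout generalized naively: all image embeddings first."""
    keep = [i for i in range(text_embeds.size(0)) if i not in set(image_positions)]
    return torch.cat(image_embeds + [text_embeds[keep]], dim=0)

# Toy shapes: 10 text tokens (positions 2 and 6 are <image> placeholders), 3 patches per image, D=8.
text = torch.randn(10, 8)
imgs = [torch.randn(3, 8), torch.randn(3, 8)]
print(assemble_interleaved(text, imgs, [2, 6]).shape)  # torch.Size([14, 8])
print(assemble_prepended(text, imgs, [2, 6]).shape)    # torch.Size([14, 8])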

Training data organization

Original SFT training data format (the same entry is duplicated here purely for illustration):

[
    {
        "conversations": [
            {"from": "human", "value": "Please tell me what's unusual about this image: <image>"},
            {"from": "gpt", "value": "A man is ironing his clothes on a vehicle. "},
            {"from": "human", "value": "What's funny about this?"},
            {"from": "gpt", "value": "Because people don't usually do this at home."}
        ],
        "image": "llava/image_folder(local image path)"
    },
    {
        "conversations": [
            {"from": "human", "value": "Please tell me what's unusual about this image: <image>"},
            {"from": "gpt", "value": "A man is ironing his clothes on a vehicle. "},
            {"from": "human", "value": "What's funny about this?"},
            {"from": "gpt", "value": "Because people don't usually do this at home."}
        ],
        "image": "llava/image_folder(local image path)"
    }
]

SFT training data format after the change:

[
    {
        "conversations": [
            {"from": "human", "value": "Please tell me what's unusual about this image: <image>"},
            {"from": "gpt", "value": "A man is ironing his clothes on a vehicle. "},
            {"from": "human", "value": "What's funny about this?"},
            {"from": "gpt", "value": "Because people don't usually do this at home."}
        ],
        "images": ["llava/image_folder(local image path)", "llava/image_folder(local image path)"]
    },
    {
        "conversations": [
            {"from": "human", "value": "Please tell me what's unusual about this image: <image>"},
            {"from": "gpt", "value": "A man is ironing his clothes on a vehicle. "},
            {"from": "human", "value": "What's funny about this?"},
            {"from": "gpt", "value": "Because people don't usually do this at home."}
        ],
        "images": ["llava/image_folder(local image path)", "llava/image_folder(local image path)"]
    }
]
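If you already have data in the original single-image format, a small one-off script (hypothetical, not part of the repo; the filenames are placeholders) can migrate it to the new images list field:

import json

with open("llava_sft_single_image.json") as f:        # assumed input filename
    data = json.load(f)

for sample in data:
    if "image" in sample:
        sample["images"] = [sample.pop("image")]       # wrap the old field into a list

with open("llava_sft_multi_image.json", "w") as f:     # assumed output filename
    json.dump(data, f, ensure_ascii=False, indent=2)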

Code changes

  1. (Optional) Change where the image tokens end up: instead of stacking a single <image> token at the front, keep multiple tokens where they appear
# llava/train/train.py
def preprocess_multimodal(
    sources: Sequence[str],
    data_args: DataArguments
) -> Dict:
    is_multimodal = data_args.is_multimodal
    if not is_multimodal:
        return sources

    for source in sources:
        for sentence in source:
            if DEFAULT_IMAGE_TOKEN in sentence['value']:
                # keep every <image> token where the user wrote it (followed by '\n'),
                # instead of the original single-image behavior of moving one token to the front:
                # sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value']
                replace_token = DEFAULT_IMAGE_TOKEN + '\n'
                sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, replace_token).strip()
                if "mmtag" in conversation_lib.default_conversation.version:
                    sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '<Image>' + DEFAULT_IMAGE_TOKEN + '</Image>')
            replace_token = DEFAULT_IMAGE_TOKEN
            if data_args.mm_use_im_start_end:
                replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
            sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)

    return sources
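A quick standalone check of what this rewrite does to one user turn (ignoring the mmtag and mm_use_im_start_end branches):

DEFAULT_IMAGE_TOKEN = "<image>"
value = "Please compare <image> and <image> for me."
value = value.replace(DEFAULT_IMAGE_TOKEN, DEFAULT_IMAGE_TOKEN + "\n").strip()
print(value)
# Please compare <image>
#  and <image>
#  for me.
# Both <image> tokens stay where the user wrote them instead of being moved to the front.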
  2. Modify the multi-image input in __getitem__
# llava/train/train.py LazySupervisedDataset
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        sources = self.list_data_dict[i]
        if isinstance(i, int):
            sources = [sources]
        assert len(sources) == 1, "Don't know why it is wrapped to a list"  # FIXME
        if 'images' in sources[0]:
            image_files = self.list_data_dict[i]['images']
            image_folder = self.data_args.image_folder
            processor = self.data_args.image_processor
            images = []
            for image_file in image_files:
                image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
                if self.data_args.image_aspect_ratio == 'pad':
                    def expand2square(pil_img, background_color):
                        width, height = pil_img.size
                        if width == height:
                            return pil_img
                        elif width > height:
                            result = Image.new(pil_img.mode, (width, width), background_color)
                            result.paste(pil_img, (0, (width - height) // 2))
                            return result
                        else:
                            result = Image.new(pil_img.mode, (height, height), background_color)
                            result.paste(pil_img, ((height - width) // 2, 0))
                            return result
                    image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
                    # [0] keeps each image as a (3, H, W) tensor so the collator can stack them
                    image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
                else:
                    image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
                images.append(image)
            # run the conversation preprocessing once per sample, after the image loop
            sources = preprocess_multimodal(
                copy.deepcopy([e["conversations"] for e in sources]),
                self.data_args)
        else:
            sources = copy.deepcopy([e["conversations"] for e in sources])
        data_dict = preprocess(
            sources,
            self.tokenizer,
            has_image=('images' in self.list_data_dict[i]))
        if isinstance(i, int):
            data_dict = dict(input_ids=data_dict["input_ids"][0],
                             labels=data_dict["labels"][0])

        # images exist in the data
        if 'images' in self.list_data_dict[i]:
            data_dict['images'] = images
        elif self.data_args.is_multimodal:
            # image does not exist in the data, but the model is multimodal
            crop_size = self.data_args.image_processor.crop_size
            data_dict['images'] = [torch.zeros(3, crop_size['height'], crop_size['width'])]
        return data_dict
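With this change, each sample's data_dict['images'] becomes a Python list of per-image tensors (each of shape (3, 336, 336) assuming the CLIP-ViT-L/14-336px processor used by LLaVA-1.5; adjust if you swap the vision tower), while input_ids and labels remain 1-D token sequences. The collator below is what turns those lists into a batch.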
  3. Modify the batch input in the data collator
# llava/train/train.py
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances]
                                  for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id)
        labels = torch.nn.utils.rnn.pad_sequence(labels,
                                                 batch_first=True,
                                                 padding_value=IGNORE_INDEX)
        input_ids = input_ids[:, :self.tokenizer.model_max_length]
        labels = labels[:, :self.tokenizer.model_max_length]
        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

        if 'images' in instances[0]:
            # each instance holds a list of (3, H, W) tensors; stack them so every
            # sample contributes an (N, 3, H, W) tensor, then batch if shapes match
            images = [torch.stack(instance['images']) for instance in instances]
            if all(x is not None and x.shape == images[0].shape for x in images):
                batch['images'] = torch.stack(images)
            else:
                batch['images'] = images

        return batch
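Finally, on the model side (not shown here), the multimodal input preparation has to accept either the stacked (B, N, 3, H, W) tensor or the ragged list produced by the collator. Below is a schematic sketch of the encoding step, with encode_images standing in for the vision tower plus projector; this is an assumption about how you might wire it, not the repo's code:

import torch

def encode_batched_images(images, encode_images):
    """images: a (B, N, 3, H, W) tensor or a list of (N_i, 3, H, W) tensors.
    encode_images: callable mapping (M, 3, H, W) pixels -> (M, P, D) features."""
    per_sample = list(images)                 # iterating a 5-D tensor yields (N, 3, H, W) slices
    concat = torch.cat(per_sample, dim=0)     # run the vision tower once over every image
    feats = encode_images(concat)
    split_sizes = [img.shape[0] for img in per_sample]
    return list(torch.split(feats, split_sizes, dim=0))  # one (N_i, P, D) block per sample

Each per-sample feature block can then be spliced into the text embeddings in whichever order you chose earlier (prepended or interleaved).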

Happy coding! If you're interested, follow the project on GitHub:

Chengru-Song/awesome-MultiModel-LLM-SFT

More multimodal fine-tuning code will be added over time. Stay tuned!