多模態支援¶
本文件將引導您完成擴充套件基礎模型以接受 多模態輸入 的步驟。
1. 更新基礎 vLLM 模型¶
假設您已按照 這些步驟 在 vLLM 中實現了模型。請按如下方式進一步更新模型:
-
實現 get_placeholder_str 來定義用於在文字 Prompt 中表示多模態項的佔位符字串。這應與模型的聊天模板保持一致。
-
在 forward 中為每個對應於多模態輸入的輸入張量保留一個關鍵字引數,如下例所示:
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
+ pixel_values: torch.Tensor,
) -> SamplerOutput:
更方便的做法是,您可以直接將 **kwargs 傳遞給 forward 方法,並從中檢索多模態輸入的關鍵字引數。
-
實現 embed_multimodal,該方法透過模型的“多模態 Tokenizer”執行多模態輸入並返回嵌入。下面我們提供了一個典型實現模式的樣板程式碼,但您可以根據自己的需求進行調整。
程式碼
class YourModelForImage2Seq(nn.Module): ... def _process_image_input(self, image_input: YourModelImageInputs) -> torch.Tensor: assert self.vision_encoder is not None image_features = self.vision_encoder(image_input) return self.multi_modal_projector(image_features) def embed_multimodal( self, **kwargs: object, ) -> MultiModalEmbeddings | None: # Validate the multimodal input keyword arguments image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return None # Run multimodal inputs through encoder and projector vision_embeddings = self._process_image_input(image_input) return vision_embeddings
重要
返回的 multimodal_embeddings 必須是形狀為 (num_items, feature_size, hidden_size) 的 **3D torch.Tensor**,或者是一系列形狀為 (feature_size, hidden_size) 的 **2D torch.Tensor**,以便 multimodal_embeddings[i] 檢索從請求的第 i 個多模態資料項(例如影像)生成的嵌入。
注意
預設情況下,vLLM 根據 PlaceholderRange 中定義的位置資訊,將多模態嵌入合併到文字嵌入中。此邏輯可在 embed_input_ids 中找到。
如果合併嵌入時需要額外的邏輯,您可以覆蓋此方法。
-
實現 get_language_model getter,以提供對底層語言模型的穩定訪問。
-
完成上述步驟後,使用 SupportsMultiModal 介面更新模型類。
+ from vllm.model_executor.models.interfaces import SupportsMultiModal
- class YourModelForImage2Seq(nn.Module):
+ class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
注意
模型類不必命名為 *ForCausalLM。有關示例,請參閱 HuggingFace Transformers 文件。
2. 指定處理資訊¶
接下來,建立 BaseProcessingInfo 的子類,以提供與 HF 處理相關的基本資訊。
輸入項的最大數量¶
您需要覆蓋抽象方法 get_supported_mm_limits,以返回模型支援的每種模態的最大輸入項數量。
例如,如果模型支援任意數量的影像,但每個 Prompt 只有一個影片
3. 指定虛擬輸入¶
然後,繼承 BaseDummyInputsBuilder 來為 HF 處理和記憶體分析構建虛擬輸入。
用於記憶體分析¶
覆蓋抽象方法 get_dummy_text 和 get_dummy_mm_data 以構建用於記憶體分析的虛擬輸入。這些虛擬輸入應導致模型達到最壞情況的記憶體使用量,以便 vLLM 可以為其預留正確的記憶體量。
假設記憶體使用量隨 Token 數量而增加,則虛擬輸入的構建可以最大化輸出嵌入的數量,這與佔位符特徵 Token 的數量相同。
檢視 HF 的 LlavaForConditionalGeneration 程式碼
程式碼
# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/llava/modeling_llava.py#L530-L544
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
n_image_features = image_features.shape[0] * image_features.shape[1]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (
(input_ids == self.config.image_token_index)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
每個影像的佔位符特徵 Token 數量為 image_features.shape[1]。image_features 在 get_image_features 方法中計算。
程式碼
# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/llava/modeling_llava.py#L290-L300
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
if vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
else:
raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
image_features = self.multi_modal_projector(selected_image_feature)
return image_features
我們可以推斷 image_features.shape[1] 基於(對於 llava-hf/llava-1.5-7b-hf 模型)視覺塔(CLIPVisionModel)的 image_outputs.hidden_states.shape[1]。此外,我們只需要序列長度(張量的第二個維度)來獲得 image_features.shape[1]。序列長度由 CLIPVisionTransformer 中的初始隱藏狀態確定,因為注意力機制不改變輸出隱藏狀態的序列長度。
# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/clip/modeling_clip.py#L1094-L1102
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
hidden_states = self.pre_layrnorm(hidden_states)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
為了找到序列長度,我們檢視 CLIPVisionEmbeddings 的程式碼。
程式碼
# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/clip/modeling_clip.py#L247-L257
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
else:
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings
我們可以推斷 embeddings.shape[1] == self.num_positions,其中
# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/clip/modeling_clip.py#L195-L196
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
總而言之,影像的佔位符特徵 Token 數量可以計算為:
程式碼
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
hf_config = self.get_hf_config()
hf_processor = self.get_hf_processor()
image_size = hf_config.vision_config.image_size
patch_size = hf_config.vision_config.patch_size
num_image_tokens = (image_size // patch_size) ** 2 + 1
if hf_processor.vision_feature_select_strategy == "default":
num_image_tokens -= 1
return num_image_tokens
請注意,影像 Token 的數量不依賴於影像的寬度和高度。我們可以簡單地使用虛擬 image_size 來計算多模態分析資料。
程式碼
# NOTE: In actuality, this is usually implemented as part of the
# model's subclass of `BaseProcessingInfo`, but we show it as is
# here for simplicity.
def get_image_size_with_most_features(self) -> ImageSize:
hf_config = self.get_hf_config()
width = height = hf_config.image_size
return ImageSize(width=width, height=height)
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
target_width, target_height = \
self.info.get_image_size_with_most_features()
image_overrides = mm_options.get("image") if mm_options else None
return {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images,
overrides=image_overrides)
}
對於文字,我們只需將模型配置中的多模態影像 Token 擴充套件到所需影像數量。
檢視 HF 的 FuyuForCausalLM 程式碼
程式碼
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/modeling_fuyu.py#L311-L322
if image_patches is not None and past_key_values is None:
patch_embeddings = [
self.vision_embed_tokens(patch.to(self.vision_embed_tokens.weight.dtype))
.squeeze(0)
.to(inputs_embeds.device)
for patch in image_patches
]
inputs_embeds = self.gather_continuous_embeddings(
word_embeddings=inputs_embeds,
continuous_embeddings=patch_embeddings,
image_patch_input_indices=image_patches_indices,
)
批次中第 i 個項的佔位符特徵 Token 數量為 patch_embeddings[i].shape[0],這與 image_patches[i].shape[0] 相同,即 num_total_patches。
與 LLaVA 不同,Fuyu 沒有在建模檔案中定義 Patch 的數量。我們可以在哪裡獲取更多資訊?考慮到模型輸入來自 FuyuProcessor 的輸出,讓我們**檢視預處理檔案**。
影像輸出是透過在 FuyuProcessor 中呼叫 FuyuImageProcessor.preprocess,然後呼叫 FuyuImageProcessor.preprocess_with_tokenizer_info 來獲得的。
在 FuyuImageProcessor.preprocess 中,影像被調整大小並填充到目標 FuyuImageProcessor.size,將調整大小後的尺寸(但在填充之前)作為元資料返回。
程式碼
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/processing_fuyu.py#L541-L544
image_encoding = self.image_processor.preprocess(images, **output_kwargs["images_kwargs"])
batch_images = image_encoding["images"]
image_unpadded_heights = image_encoding["image_unpadded_heights"]
image_unpadded_widths = image_encoding["image_unpadded_widths"]
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L480-L
if do_resize:
batch_images = [
[self.resize(image, size=size, input_data_format=input_data_format) for image in images]
for images in batch_images
]
image_sizes = [get_image_size(images[0], channel_dim=input_data_format) for images in batch_images]
image_unpadded_heights = [[image_size[0]] for image_size in image_sizes]
image_unpadded_widths = [[image_size[1]] for image_size in image_sizes]
if do_pad:
batch_images = [
[
self.pad_image(
image,
size=size,
mode=padding_mode,
constant_values=padding_value,
input_data_format=input_data_format,
)
for image in images
]
for images in batch_images
]
在 FuyuImageProcessor.preprocess_with_tokenizer_info 中,影像根據此元資料被分割成 Patch。
程式碼
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/processing_fuyu.py#L417-L425
model_image_input = self.image_processor.preprocess_with_tokenizer_info(
image_input=tensor_batch_images,
image_present=image_present,
image_unpadded_h=image_unpadded_heights,
image_unpadded_w=image_unpadded_widths,
image_placeholder_id=image_placeholder_id,
image_newline_id=image_newline_id,
variable_sized=True,
)
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L638-L658
image_height, image_width = image.shape[1], image.shape[2]
if variable_sized: # variable_sized=True
new_h = min(
image_height,
math.ceil(image_unpadded_h[batch_index, subseq_index] / patch_height) * patch_height,
)
new_w = min(
image_width,
math.ceil(image_unpadded_w[batch_index, subseq_index] / patch_width) * patch_width,
)
image = image[:, :new_h, :new_w]
image_height, image_width = new_h, new_w
num_patches = self.get_num_patches(image_height=image_height, image_width=image_width)
tensor_of_image_ids = torch.full(
[num_patches], image_placeholder_id, dtype=torch.int32, device=image_input.device
)
patches = self.patchify_image(image=image.unsqueeze(0)).squeeze(0)
assert num_patches == patches.shape[0]
Patch 的數量又由 FuyuImageProcessor.get_num_patches 定義。
程式碼
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L552-L562
patch_size = patch_size if patch_size is not None else self.patch_size
patch_height, patch_width = self.patch_size["height"], self.patch_size["width"]
if image_height % patch_height != 0:
raise ValueError(f"{image_height=} must be divisible by {patch_height}")
if image_width % patch_width != 0:
raise ValueError(f"{image_width=} must be divisible by {patch_width}")
num_patches_per_dim_h = image_height // patch_height
num_patches_per_dim_w = image_width // patch_width
num_patches = num_patches_per_dim_h * num_patches_per_dim_w
這些影像 Patch 對應於佔位符 Token(|SPEAKER|)。因此,我們只需要最大化影像 Patch 的數量。由於輸入影像首先被調整大小以適應 image_processor.size,我們可以透過輸入大小等於 image_processor.size 的影像來最大化影像 Patch 的數量。
def get_image_size_with_most_features(self) -> ImageSize:
image_processor = self.get_image_processor()
return ImageSize(
width=image_processor.size["width"],
height=image_processor.size["height"],
)
Fuyu 在 HF 處理器的輸入中不期望影像佔位符,因此虛擬 Prompt 文字為空,無論影像數量如何。
對於多模態影像分析資料,其邏輯與 LLaVA 非常相似。
程式碼
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict:
target_width, target_height = \
self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0)
image_overrides = mm_options.get("image") if mm_options else None
return {
"image":
self._get_dummy_images(
width=target_width,
height=target_height,
num_images=num_images,
overrides=image_overrides,
)
}
4. 指定處理細節¶
之後,建立 BaseMultiModalProcessor 的子類,以填充 HF 處理的缺失細節。
資訊
多模態欄位¶
覆蓋 _get_mm_fields_config 以返回與輸入多模態項相關的 HF 處理器輸出的張量架構。
CLIPImageProcessor 的輸出是一個簡單的張量,形狀為 (num_images, num_channels, image_height, image_width)。
# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/clip/image_processing_clip.py#L339-L345
images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
for image in all_images
]
data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)
因此,我們如下覆蓋 _get_mm_fields_config:
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
)
注意
我們的 實際程式碼 還支援透過 image_embeds 引數傳遞預計算的影像嵌入。
FuyuImageProcessor.preprocess_with_tokenizer_info 的 image_patches 輸出會連線批次中每個影像的 Patch。
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L673-L679
image_input_ids.append(tensor_of_image_ids)
image_patches.append(patches)
else:
image_input_ids.append(torch.tensor([], dtype=torch.int32, device=image_input.device))
batch_image_input_ids.append(image_input_ids)
batch_image_patches.append(image_patches)
因此,FuyuImageProcessor 輸出的 image_patches 的形狀為 (1, num_images, num_patches, patch_width * patch_height * num_channels)。
為了支援像 LLaVA 那樣使用 MultiModalFieldConfig.batched,我們透過覆蓋 BaseMultiModalProcessor._call_hf_processor 來移除額外的批次維度。
程式碼
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
image_patches = processed_outputs.get("image_patches")
if image_patches is not None:
images = mm_data["images"]
assert isinstance(images, list)
# Original output: (1, num_images, Pn, Px * Py * C)
# New output: (num_images, Pn, Px * Py * C)
assert (isinstance(image_patches, list)
and len(image_patches) == 1)
assert (isinstance(image_patches[0], torch.Tensor)
and len(image_patches[0]) == len(images))
processed_outputs["image_patches"] = image_patches[0]
return processed_outputs
注意
我們的 實際程式碼 對僅文字輸入進行了特殊處理,以防止 HF 處理器發出不必要的警告。
注意
_call_hf_processor 方法為處理指定了 mm_kwargs 和 tok_kwargs。mm_kwargs 用於初始化和呼叫 Huggingface 處理器,而 tok_kwargs 僅用於呼叫 Huggingface 處理器。
這使我們能夠如下覆蓋 _get_mm_fields_config:
Prompt 更新¶
覆蓋 _get_prompt_updates 以返回 PromptUpdate 例項的列表。
每個 PromptUpdate 例項指定了 HF 處理器執行的更新操作(例如:插入、替換)。
檢視 HF 的 LlavaProcessor
# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/llava/processing_llava.py#L167-L170
prompt_strings = []
for sample in text:
sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
prompt_strings.append(sample)
它只是將每個輸入 image_token 重複 num_image_tokens 次(佔位符特徵 Token 的數量)。基於此,我們如下覆蓋 _get_prompt_updates:
程式碼
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index
def get_replacement(item_idx: int):
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
num_image_tokens = self.info.get_num_image_tokens(
image_width=image_size.width,
image_height=image_size.height,
)
return [image_token_id] * num_image_tokens
return [
PromptReplacement(
modality="image",
target=[image_token_id],
replacement=get_replacement,
),
]
回憶步驟 2 中的特徵 Token 佈局。
|SPEAKER||SPEAKER|...|SPEAKER||NEWLINE|
|SPEAKER||SPEAKER|...|SPEAKER||NEWLINE|
...
|SPEAKER||SPEAKER|...|SPEAKER||NEWLINE|
我們定義一個輔助函式以直接返回 ncols 和 nrows。
程式碼
def get_image_feature_grid_size(
self,
*,
image_width: int,
image_height: int,
) -> tuple[int, int]:
image_processor = self.get_image_processor()
target_width = image_processor.size["width"]
target_height = image_processor.size["height"]
patch_width = image_processor.patch_size["width"]
patch_height = image_processor.patch_size["height"]
if not (image_width <= target_width and image_height <= target_height):
height_scale_factor = target_height / image_height
width_scale_factor = target_width / image_width
optimal_scale_factor = min(height_scale_factor, width_scale_factor)
image_height = int(image_height * optimal_scale_factor)
image_width = int(image_width * optimal_scale_factor)
ncols = math.ceil(image_width / patch_width)
nrows = math.ceil(image_height / patch_height)
return ncols, nrows
基於此,我們最初可以定義我們的替換 Token 為:
程式碼
def get_replacement(item_idx: int):
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
ncols, nrows = self.info.get_image_feature_grid_size(
image_width=image_size.width,
image_height=image_size.height,
)
# `_IMAGE_TOKEN_ID` corresponds to `|SPEAKER|`
# `_NEWLINE_TOKEN_ID` corresponds to `|NEWLINE|`
return ([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows
但是,這並不完全正確。呼叫 FuyuImageProcessor.preprocess_with_tokenizer_info 後,還會向 Prompt 新增一個 BOS Token(<s>)。
程式碼
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/processing_fuyu.py#L417-L435
model_image_input = self.image_processor.preprocess_with_tokenizer_info(
image_input=tensor_batch_images,
image_present=image_present,
image_unpadded_h=image_unpadded_heights,
image_unpadded_w=image_unpadded_widths,
image_placeholder_id=image_placeholder_id,
image_newline_id=image_newline_id,
variable_sized=True,
)
prompt_tokens, prompts_length = _tokenize_prompts_with_image_and_batch(
tokenizer=self.tokenizer,
prompts=prompts,
scale_factors=scale_factors,
max_tokens_to_generate=self.max_tokens_to_generate,
max_position_embeddings=self.max_position_embeddings,
add_BOS=True,
add_beginning_of_answer_token=True,
)
為了僅將視覺嵌入分配給影像 Token,您可以返回 PromptUpdateDetails 資料類的例項,而不是字串。
程式碼
hf_config = self.info.get_hf_config()
bos_token_id = hf_config.bos_token_id # `<s>`
assert isinstance(bos_token_id, int)
def get_replacement_fuyu(item_idx: int):
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
ncols, nrows = self.info.get_image_feature_grid_size(
image_width=image_size.width,
image_height=image_size.height,
)
image_tokens = ([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows
return PromptUpdateDetails.select_token_id(
image_tokens + [bos_token_id],
embed_token_id=_IMAGE_TOKEN_ID,
)
最後,注意到 HF 處理器會從 Tokenized Prompt 中刪除 |ENDOFTEXT| Token,我們可以搜尋它以在字串開頭進行替換。
程式碼
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config()
bos_token_id = hf_config.bos_token_id
assert isinstance(bos_token_id, int)
tokenizer = self.info.get_tokenizer()
eot_token_id = tokenizer.bos_token_id
assert isinstance(eot_token_id, int)
def get_replacement_fuyu(item_idx: int):
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
ncols, nrows = self.info.get_image_feature_grid_size(
image_width=image_size.width,
image_height=image_size.height,
)
image_tokens = ([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows
return PromptUpdateDetails.select_token_id(
image_tokens + [bos_token_id],
embed_token_id=_IMAGE_TOKEN_ID,
)
return [
PromptReplacement(
modality="image",
target=[eot_token_id],
replacement=get_replacement_fuyu,
)
]
5. 註冊與處理器相關的類¶
定義了 BaseProcessingInfo(步驟 2)、BaseDummyInputsBuilder(步驟 3)和 BaseMultiModalProcessor(步驟 4)後,使用 MULTIMODAL_REGISTRY.register_processor 裝飾模型類,將其註冊到多模態登錄檔中。
from vllm.model_executor.models.interfaces import SupportsMultiModal
+ from vllm.multimodal import MULTIMODAL_REGISTRY
+ @MULTIMODAL_REGISTRY.register_processor(
+ YourMultiModalProcessor,
+ info=YourProcessingInfo,
+ dummy_inputs=YourDummyInputsBuilder,
+ )
class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
註釋¶
插入特徵 Token 而不替換¶
一些 HF 處理器直接插入特徵 Token 而不替換原始 Prompt 中的任何內容。在這種情況下,您可以在 _get_prompt_updates 中使用 PromptInsertion 而不是 PromptReplacement。
示例
- BLIP-2(在 Prompt 開始處插入): vllm/model_executor/models/blip2.py
- Molmo(在
<|endoftext|>Token 之後插入): vllm/model_executor/models/molmo.py
處理與多模態資料無關的 Prompt 更新¶
_get_prompt_updates 假設每次應用 Prompt 更新都對應一個多模態項。如果 HF 處理器執行的額外處理與多模態項的數量無關,您應該覆蓋 _apply_hf_processor_tokens_only,以便處理後的 Token 輸入與對文字輸入應用 HF 處理器後的結果一致。這是因為根據 我們的設計,Token 輸入會繞過 HF 處理器。
示例
- Chameleon(附加
sep_token): vllm/model_executor/models/chameleon.py - Fuyu(附加
boa_token): vllm/model_executor/models/fuyu.py - Molmo(應用未在別處定義的聊天模板): vllm/model_executor/models/molmo.py
自定義 HF 處理器¶
某些模型在 HF Hub 上沒有定義 HF 處理器類。在這種情況下,您可以定義一個與 HF 處理器具有相同呼叫簽名的自定義 HF 處理器,並將其傳遞給 _call_hf_processor。
示例
- DeepSeek-VL2: vllm/model_executor/models/deepseek_vl2.py
- InternVL: vllm/model_executor/models/internvl.py
- Qwen-VL: vllm/model_executor/models/qwen_vl.py