
Multimodal Inputs

This page teaches you how to pass multi-modal inputs to multi-modal models in vLLM.

Note

We are actively iterating on multi-modal support. See the RFC for upcoming changes, and open an issue on GitHub if you have any feedback or feature requests.

Tip

When deploying multi-modal models, consider setting --allowed-media-domains to restrict the domains that vLLM is allowed to access, so that it cannot reach arbitrary endpoints that may be vulnerable to server-side request forgery (SSRF) attacks. You can provide a list of domains for this argument, for example: --allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com

Additionally, consider setting VLLM_MEDIA_URL_ALLOW_REDIRECTS=0 to prevent HTTP redirects from being followed to bypass the domain restriction.

This restriction is especially important if you run vLLM in a containerized environment, where the vLLM pod may have unrestricted access to the internal network.

Offline Inference

To input multi-modal data, follow this schema in vllm.inputs.PromptType:

  • prompt: the text prompt.
  • multi_modal_data: a dictionary containing multi-modal data keyed by modality, as defined in vllm.multimodal.inputs.MultiModalDataDict.

Stable UUIDs for Caching (multi_modal_uuids)

When using multi-modal inputs, vLLM normally hashes each media item by content for caching across requests. You can optionally pass multi_modal_uuids to give each item your own stable ID, so the cache can reuse work across requests without re-hashing the raw content.

Code
from vllm import LLM
from PIL import Image

# Qwen2.5-VL example with two images
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct")

prompt = "USER: <image><image>\nDescribe the differences.\nASSISTANT:"
img_a = Image.open("/path/to/a.jpg")
img_b = Image.open("/path/to/b.jpg")

outputs = llm.generate({
    "prompt": prompt,
    "multi_modal_data": {"image": [img_a, img_b]},
    # Provide stable IDs for caching.
    # Requirements (matched by this example):
    #  - Include every modality present in multi_modal_data.
    #  - For lists, provide the same number of entries.
    #  - Use None to fall back to content hashing for that item.
    "multi_modal_uuids": {"image": ["sku-1234-a", None]},
})

for o in outputs:
    print(o.outputs[0].text)

With UUIDs, you can even skip sending the media data entirely if you expect a cache hit. Note that the request will fail if the skipped media has no corresponding UUID, or if the UUID fails to hit the cache.

Code
from vllm import LLM
from PIL import Image

# Qwen2.5-VL example with two images
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct")

prompt = "USER: <image><image>\nDescribe the differences.\nASSISTANT:"
img_b = Image.open("/path/to/b.jpg")

outputs = llm.generate({
    "prompt": prompt,
    "multi_modal_data": {"image": [None, img_b]},
    # Since img_a is expected to be cached, we can skip sending the actual
    # image entirely.
    "multi_modal_uuids": {"image": ["sku-1234-a", None]},
})

for o in outputs:
    print(o.outputs[0].text)

Warning

If both the multi-modal processor cache and prefix caching are disabled, user-provided multi_modal_uuids are ignored.

Image Inputs

You can pass a single image to the 'image' field of the multi-modal dictionary, as shown in the following examples:

Code
from vllm import LLM
import PIL.Image

llm = LLM(model="llava-hf/llava-1.5-7b-hf")

# Refer to the HuggingFace repo for the correct format to use
prompt = "USER: <image>\nWhat is the content of this image?\nASSISTANT:"

# Load the image using PIL.Image
image = PIL.Image.open(...)

# Single prompt inference
outputs = llm.generate({
    "prompt": prompt,
    "multi_modal_data": {"image": image},
})

for o in outputs:
    generated_text = o.outputs[0].text
    print(generated_text)

# Batch inference
image_1 = PIL.Image.open(...)
image_2 = PIL.Image.open(...)
outputs = llm.generate(
    [
        {
            "prompt": "USER: <image>\nWhat is the content of this image?\nASSISTANT:",
            "multi_modal_data": {"image": image_1},
        },
        {
            "prompt": "USER: <image>\nWhat's the color of this image?\nASSISTANT:",
            "multi_modal_data": {"image": image_2},
        }
    ]
)

for o in outputs:
    generated_text = o.outputs[0].text
    print(generated_text)

Full example: examples/offline_inference/vision_language.py

To substitute multiple images inside the same text prompt, you can pass in a list of images instead:

Code
from vllm import LLM
import PIL.Image

llm = LLM(
    model="microsoft/Phi-3.5-vision-instruct",
    trust_remote_code=True,  # Required to load Phi-3.5-vision
    max_model_len=4096,  # Otherwise, it may not fit in smaller GPUs
    limit_mm_per_prompt={"image": 2},  # The maximum number to accept
)

# Refer to the HuggingFace repo for the correct format to use
prompt = "<|user|>\n<|image_1|>\n<|image_2|>\nWhat is the content of each image?<|end|>\n<|assistant|>\n"

# Load the images using PIL.Image
image1 = PIL.Image.open(...)
image2 = PIL.Image.open(...)

outputs = llm.generate({
    "prompt": prompt,
    "multi_modal_data": {"image": [image1, image2]},
})

for o in outputs:
    generated_text = o.outputs[0].text
    print(generated_text)

Full example: examples/offline_inference/vision_language_multi_image.py

If you use the LLM.chat method, you can use images of various formats directly in the message content: image URLs, PIL Image objects, or pre-computed embeddings:

import torch
from vllm import LLM
from vllm.assets.image import ImageAsset

# enable_mm_embeds=True is required because image embeddings are passed below.
llm = LLM(model="llava-hf/llava-1.5-7b-hf", enable_mm_embeds=True)
image_url = "https://picsum.photos/id/32/512/512"
image_pil = ImageAsset('cherry_blossom').pil_image
image_embeds = torch.load(...)

conversation = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hello! How can I assist you today?"},
    {
        "role": "user",
        "content": [
            {
                "type": "image_url",
                "image_url": {"url": image_url},
            },
            {
                "type": "image_pil",
                "image_pil": image_pil,
            },
            {
                "type": "image_embeds",
                "image_embeds": image_embeds,
            },
            {
                "type": "text",
                "text": "What's in these images?",
            },
        ],
    },
]

# Perform inference and log output.
outputs = llm.chat(conversation)

for o in outputs:
    generated_text = o.outputs[0].text
    print(generated_text)

Multi-image input can be extended to perform video captioning. We show this with Qwen2-VL, as it supports videos.

Code
from vllm import LLM

# Specify the maximum number of frames per video to be 4. This can be changed.
llm = LLM("Qwen/Qwen2-VL-2B-Instruct", limit_mm_per_prompt={"image": 4})

# Create the request payload.
video_frames = ... # load your video making sure it only has the number of frames specified earlier.
message = {
    "role": "user",
    "content": [
        {
            "type": "text",
            "text": "Describe this set of frames. Consider the frames to be a part of the same video.",
        },
    ],
}
for i in range(len(video_frames)):
    base64_image = encode_image(video_frames[i])  # base64-encode the frame; encode_image is a helper you provide.
    new_image = {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
    message["content"].append(new_image)

# Perform inference and log output.
outputs = llm.chat([message])

for o in outputs:
    generated_text = o.outputs[0].text
    print(generated_text)

Custom RGBA Background Color

When loading RGBA images (images with transparency), vLLM converts them to RGB format. By default, transparent pixels are replaced with a white background. You can customize this background color via the rgba_background_color parameter in media_io_kwargs:

Code
from vllm import LLM

# Default white background (no configuration needed)
llm = LLM(model="llava-hf/llava-1.5-7b-hf")

# Custom black background for dark theme
llm = LLM(
    model="llava-hf/llava-1.5-7b-hf",
    media_io_kwargs={"image": {"rgba_background_color": [0, 0, 0]}},
)

# Custom brand color background (e.g., blue)
llm = LLM(
    model="llava-hf/llava-1.5-7b-hf",
    media_io_kwargs={"image": {"rgba_background_color": [0, 0, 255]}},
)

Note

  • rgba_background_color accepts RGB values as a list [R, G, B] or a tuple (R, G, B), where each value is between 0 and 255.
  • This setting only affects the transparency handling of RGBA images; RGB images are unaffected.
  • If not specified, the default white background (255, 255, 255) is used for compatibility.

Video Inputs

You can pass a list of NumPy arrays directly to the 'video' field of the multi-modal dictionary, instead of using multi-image input.
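
As a minimal sketch (not taken from the official examples), the snippet below stacks dummy frames into a single NumPy array of shape (num_frames, height, width, 3). The Qwen2-VL prompt string and the frame count/resolution are assumptions; check the model's HF repo for the exact video placeholder tokens.

Code
import numpy as np
from vllm import LLM

llm = LLM("Qwen/Qwen2-VL-2B-Instruct", limit_mm_per_prompt={"video": 1})

# Assumed prompt format; refer to the HuggingFace repo for the correct one.
prompt = (
    "<|im_start|>user\n<|vision_start|><|video_pad|><|vision_end|>"
    "Describe this video.<|im_end|>\n<|im_start|>assistant\n"
)

# Dummy clip: 8 RGB frames of 224x224. Replace with real decoded frames.
video = np.random.randint(0, 255, size=(8, 224, 224, 3), dtype=np.uint8)

outputs = llm.generate({
    "prompt": prompt,
    "multi_modal_data": {"video": video},
})

for o in outputs:
    print(o.outputs[0].text)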

In addition to NumPy arrays, you can also pass 'torch.Tensor' instances, as shown in this example using Qwen2.5-VL:

Code
from transformers import AutoProcessor
from vllm import LLM, SamplingParams
from qwen_vl_utils import process_vision_info

model_path = "Qwen/Qwen2.5-VL-3B-Instruct"
video_path = "https://content.pexels.com/videos/free-videos.mp4"

llm = LLM(
    model=model_path,
    gpu_memory_utilization=0.8,
    enforce_eager=True,
    limit_mm_per_prompt={"video": 1},
)

sampling_params = SamplingParams(max_tokens=1024)

video_messages = [
    {
        "role": "system",
        "content": "You are a helpful assistant.",
    },
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "describe this video."},
            {
                "type": "video",
                "video": video_path,
                "total_pixels": 20480 * 28 * 28,
                "min_pixels": 16 * 28 * 28,
            },
        ]
    },
]

messages = video_messages
processor = AutoProcessor.from_pretrained(model_path)
prompt = processor.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)

image_inputs, video_inputs = process_vision_info(messages)
mm_data = {}
if video_inputs is not None:
    mm_data["video"] = video_inputs

llm_inputs = {
    "prompt": prompt,
    "multi_modal_data": mm_data,
}

outputs = llm.generate([llm_inputs], sampling_params=sampling_params)
for o in outputs:
    generated_text = o.outputs[0].text
    print(generated_text)

Note

'process_vision_info' only applies to Qwen2.5-VL and similar models.

Full example: examples/offline_inference/vision_language.py

Audio Inputs

You can pass a tuple (array, sampling_rate) to the 'audio' field of the multi-modal dictionary.
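
As an illustration, here is a minimal sketch reusing the Ultravox model and prompt format that appear later on this page; AudioAsset is only used to obtain an (array, sampling_rate) tuple, and any such tuple (e.g. from librosa.load) can be passed the same way.

Code
from vllm import LLM
from vllm.assets.audio import AudioAsset

llm = LLM(model="fixie-ai/ultravox-v0_5-llama-3_2-1b")

# Refer to the HuggingFace repo for the correct prompt format to use.
prompt = "USER: <audio>\nWhat is in this audio?\nASSISTANT:"

# A (numpy array, sampling_rate) tuple, here taken from a bundled test asset.
audio_and_sample_rate = AudioAsset("winning_call").audio_and_sample_rate

outputs = llm.generate({
    "prompt": prompt,
    "multi_modal_data": {"audio": audio_and_sample_rate},
})

for o in outputs:
    print(o.outputs[0].text)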

Full example: examples/offline_inference/audio_language.py

Embedding Inputs

To feed pre-computed embeddings belonging to a data type (i.e. image, video, or audio) directly into the language model, pass a tensor of shape (num_items, feature_size, hidden_size of LM) to the corresponding field of the multi-modal dictionary.

You must enable this feature via enable_mm_embeds=True.

Warning

The vLLM engine may crash if the embeddings passed have an incorrect shape. Only enable this flag for trusted users!

Image Embeddings

Code
import torch
from vllm import LLM

# Inference with image embeddings as input
llm = LLM(model="llava-hf/llava-1.5-7b-hf", enable_mm_embeds=True)

# Refer to the HuggingFace repo for the correct format to use
prompt = "USER: <image>\nWhat is the content of this image?\nASSISTANT:"

# Embeddings for single image
# torch.Tensor of shape (1, image_feature_size, hidden_size of LM)
image_embeds = torch.load(...)

outputs = llm.generate({
    "prompt": prompt,
    "multi_modal_data": {"image": image_embeds},
})

for o in outputs:
    generated_text = o.outputs[0].text
    print(generated_text)

For Qwen2-VL and MiniCPM-V, we accept additional parameters alongside the embeddings:

Code
import torch
from vllm import LLM

# Construct the prompt based on your model
prompt = ...

# Embeddings for multiple images
# torch.Tensor of shape (num_images, image_feature_size, hidden_size of LM)
image_embeds = torch.load(...)

# Qwen2-VL
llm = LLM(
    "Qwen/Qwen2-VL-2B-Instruct",
    limit_mm_per_prompt={"image": 4},
    enable_mm_embeds=True,
)
mm_data = {
    "image": {
        "image_embeds": image_embeds,
        # image_grid_thw is needed to calculate positional encoding.
        "image_grid_thw": torch.load(...),  # torch.Tensor of shape (1, 3),
    }
}

# MiniCPM-V
llm = LLM(
    "openbmb/MiniCPM-V-2_6",
    trust_remote_code=True,
    limit_mm_per_prompt={"image": 4},
    enable_mm_embeds=True,
)
mm_data = {
    "image": {
        "image_embeds": image_embeds,
        # image_sizes is needed to calculate details of the sliced image.
        "image_sizes": [image.size for image in images],  # list of image sizes
    }
}

outputs = llm.generate({
    "prompt": prompt,
    "multi_modal_data": mm_data,
})

for o in outputs:
    generated_text = o.outputs[0].text
    print(generated_text)

For Qwen3-VL, image_embeds should contain both the base image embeddings and the deepstack features.

Audio Embedding Inputs

You can pass pre-computed audio embeddings in the same way as image embeddings:

Code
from vllm import LLM
import torch

# Enable audio embeddings support
llm = LLM(model="fixie-ai/ultravox-v0_5-llama-3_2-1b", enable_mm_embeds=True)

# Refer to the HuggingFace repo for the correct format to use
prompt = "USER: <audio>\nWhat is in this audio?\nASSISTANT:"

# Load pre-computed audio embeddings
# torch.Tensor of shape (1, audio_feature_size, hidden_size of LM)
audio_embeds = torch.load(...)

outputs = llm.generate({
    "prompt": prompt,
    "multi_modal_data": {"audio": audio_embeds},
})

for o in outputs:
    generated_text = o.outputs[0].text
    print(generated_text)

Online Serving

Our OpenAI-compatible server accepts multi-modal data via the Chat Completions API. Media inputs also support optional UUIDs, which users can provide to uniquely identify each media item so that media processing results can be cached across requests.

Important

A chat template is required to use the Chat Completions API. For HF-format models, the default chat template is defined in chat_template.json or tokenizer_config.json.

If no default chat template is available, we first look for a built-in fallback template. If there is no fallback, an error is raised and you have to provide a chat template manually via the --chat-template argument.

For certain models, we provide alternative chat templates in examples. For example, VLM2Vec uses examples/template_vlm2vec_phi3v.jinja, which differs from the default template for Phi-3-Vision.

Image Inputs

Image input is supported following the OpenAI Vision API. Here is a simple example using Phi-3.5-Vision.

First, launch the OpenAI-compatible server:

vllm serve microsoft/Phi-3.5-vision-instruct --runner generate \
  --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt '{"image":2}'

Then, you can use the OpenAI client as follows:

Code
from openai import OpenAI

openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)

# Single-image input inference
image_url = "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"

chat_response = client.chat.completions.create(
    model="microsoft/Phi-3.5-vision-instruct",
    messages=[
        {
            "role": "user",
            "content": [
                # NOTE: The prompt formatting with the image token `<image>` is not needed
                # since the prompt will be processed automatically by the API server.
                {
                    "type": "text",
                    "text": "What’s in this image?",
                },
                {
                    "type": "image_url",
                    "image_url": {"url": image_url},
                    "uuid": image_url,  # Optional
                },
            ],
        }
    ],
)
print("Chat completion output:", chat_response.choices[0].message.content)

# Multi-image input inference
image_url_duck = "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/duck.jpg"
image_url_lion = "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/lion.jpg"

chat_response = client.chat.completions.create(
    model="microsoft/Phi-3.5-vision-instruct",
    messages=[
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What are the animals in these images?",
                },
                {
                    "type": "image_url",
                    "image_url": {"url": image_url_duck},
                    "uuid": image_url_duck,  # Optional
                },
                {
                    "type": "image_url",
                    "image_url": {"url": image_url_lion},
                    "uuid": image_url_lion,  # Optional
                },
            ],
        }
    ],
)
print("Chat completion output:", chat_response.choices[0].message.content)

Full example: examples/online_serving/openai_chat_completion_client_for_multimodal.py

Tip

vLLM also supports loading media from local file paths: you can specify the allowed local media paths via --allowed-local-media-path when launching the API server/engine, and pass the file path as url in the API request.
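
For example, a minimal sketch (the model, the allowed directory, and the file name are assumptions; the referenced path must sit under the directory passed to --allowed-local-media-path):

Code
from openai import OpenAI

# Assumes the server was launched with something like:
#   vllm serve microsoft/Phi-3.5-vision-instruct --allowed-local-media-path /data/images ...
client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")

chat_response = client.chat.completions.create(
    model="microsoft/Phi-3.5-vision-instruct",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What's in this image?"},
            # The local file is referenced as a file:// URL in place of an HTTP URL.
            {"type": "image_url", "image_url": {"url": "file:///data/images/example.jpg"}},
        ],
    }],
)
print(chat_response.choices[0].message.content)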

Tip

There is no need to place image placeholders in the text content of the API request - they are already represented by the image content. In fact, you can place image placeholders in the middle of the text by interleaving text and image content, as sketched below.
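
For instance, a minimal sketch of interleaved text and image content, reusing the Phi-3.5-Vision server and image URLs from the examples above:

Code
from openai import OpenAI

client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")

image_url_duck = "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/duck.jpg"
image_url_lion = "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/lion.jpg"

chat_response = client.chat.completions.create(
    model="microsoft/Phi-3.5-vision-instruct",
    messages=[{
        "role": "user",
        "content": [
            # Each image placeholder is inserted where the image content appears,
            # so text and images can be freely interleaved.
            {"type": "text", "text": "Here is the first animal:"},
            {"type": "image_url", "image_url": {"url": image_url_duck}},
            {"type": "text", "text": "And here is the second one:"},
            {"type": "image_url", "image_url": {"url": image_url_lion}},
            {"type": "text", "text": "Which animal is larger?"},
        ],
    }],
)
print(chat_response.choices[0].message.content)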

Note

By default, the timeout for fetching images through an HTTP URL is 5 seconds. You can override this by setting the environment variable:

export VLLM_IMAGE_FETCH_TIMEOUT=<timeout>

Video Inputs

Instead of image_url, you can pass a video file via video_url. Here is a simple example using LLaVA-OneVision.

First, launch the OpenAI-compatible server:

vllm serve llava-hf/llava-onevision-qwen2-0.5b-ov-hf --runner generate --max-model-len 8192

Then, you can use the OpenAI client as follows:

Code
from openai import OpenAI

openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)

video_url = "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/ForBiggerFun.mp4"

## Use video url in the payload
chat_completion_from_url = client.chat.completions.create(
    messages=[
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this video?",
                },
                {
                    "type": "video_url",
                    "video_url": {"url": video_url},
                    "uuid": video_url,  # Optional
                },
            ],
        }
    ],
    model=model,
    max_completion_tokens=64,
)

result = chat_completion_from_url.choices[0].message.content
print("Chat completion output from image url:", result)

Full example: examples/online_serving/openai_chat_completion_client_for_multimodal.py

Note

By default, the timeout for fetching videos through an HTTP URL is 30 seconds. You can override this by setting the environment variable:

export VLLM_VIDEO_FETCH_TIMEOUT=<timeout>

Custom RGBA Background Color

To use a custom background color for RGBA images, pass the rgba_background_color parameter via --media-io-kwargs:

# Example: Black background for dark theme
vllm serve llava-hf/llava-1.5-7b-hf \
  --media-io-kwargs '{"image": {"rgba_background_color": [0, 0, 0]}}'

# Example: Custom gray background
vllm serve llava-hf/llava-1.5-7b-hf \
  --media-io-kwargs '{"image": {"rgba_background_color": [128, 128, 128]}}'

Audio Inputs

Audio input is supported following the OpenAI Audio API. Here is a simple example using Ultravox-v0.5-1B.

First, launch the OpenAI-compatible server:

vllm serve fixie-ai/ultravox-v0_5-llama-3_2-1b

Then, you can use the OpenAI client as follows:

Code
import base64
import requests
from openai import OpenAI
from vllm.assets.audio import AudioAsset

def encode_base64_content_from_url(content_url: str) -> str:
    """Encode a content retrieved from a remote url to base64 format."""

    with requests.get(content_url) as response:
        response.raise_for_status()
        result = base64.b64encode(response.content).decode('utf-8')

    return result

openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
model = "fixie-ai/ultravox-v0_5-llama-3_2-1b"

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)

# Any format supported by librosa is supported
audio_url = AudioAsset("winning_call").url
audio_base64 = encode_base64_content_from_url(audio_url)

chat_completion_from_base64 = client.chat.completions.create(
    messages=[
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this audio?",
                },
                {
                    "type": "input_audio",
                    "input_audio": {
                        "data": audio_base64,
                        "format": "wav",
                    },
                    "uuid": audio_url,  # Optional
                },
            ],
        },
    ],
    model=model,
    max_completion_tokens=64,
)

result = chat_completion_from_base64.choices[0].message.content
print("Chat completion output from input audio:", result)

Alternatively, you can pass audio_url, which is the audio counterpart of image_url:

Code
chat_completion_from_url = client.chat.completions.create(
    messages=[
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this audio?",
                },
                {
                    "type": "audio_url",
                    "audio_url": {"url": audio_url},
                    "uuid": audio_url,  # Optional
                },
            ],
        }
    ],
    model=model,
    max_completion_tokens=64,
)

result = chat_completion_from_url.choices[0].message.content
print("Chat completion output from audio url:", result)

Full example: examples/online_serving/openai_chat_completion_client_for_multimodal.py

Note

By default, the timeout for fetching audio through an HTTP URL is 10 seconds. You can override this by setting the environment variable:

export VLLM_AUDIO_FETCH_TIMEOUT=<timeout>

Embedding Inputs

To feed pre-computed embeddings belonging to a data type (i.e. image, video, or audio) directly into the language model, pass a tensor of shape (num_items, feature_size, hidden_size of LM) to the corresponding field of the multi-modal dictionary.

You must enable this feature via the --enable-mm-embeds flag of vllm serve.

Warning

The vLLM engine may crash if the embeddings passed have an incorrect shape. Only enable this flag for trusted users!

Image Embedding Inputs

For image embeddings, you can pass a base64-encoded tensor in the image_embeds field. The following example demonstrates how to pass image embeddings to the OpenAI server:

Code
import torch
from openai import OpenAI
from vllm.utils.serial_utils import tensor2base64

openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

image_embedding = torch.load(...)
grid_thw = torch.load(...)  # Required by Qwen/Qwen2-VL-2B-Instruct

base64_image_embedding = tensor2base64(image_embedding)
base64_image_grid_thw = tensor2base64(grid_thw)

client = OpenAI(
    # defaults to os.environ.get("OPENAI_API_KEY")
    api_key=openai_api_key,
    base_url=openai_api_base,
)

# Basic usage - this is equivalent to the LLaVA example for offline inference
model = "llava-hf/llava-1.5-7b-hf"
embeds = {
    "type": "image_embeds",
    "image_embeds": f"{base64_image_embedding}",
    "uuid": image_url,  # Optional
}

# Pass additional parameters (available to Qwen2-VL and MiniCPM-V)
model = "Qwen/Qwen2-VL-2B-Instruct"
embeds = {
    "type": "image_embeds",
    "image_embeds": {
        "image_embeds": f"{base64_image_embedding}",  # Required
        "image_grid_thw": f"{base64_image_grid_thw}",  # Required by Qwen/Qwen2-VL-2B-Instruct
    },
    "uuid": image_url,  # Optional
}
model = "openbmb/MiniCPM-V-2_6"
embeds = {
    "type": "image_embeds",
    "image_embeds": {
        "image_embeds": f"{base64_image_embedding}",  # Required
        "image_sizes": f"{base64_image_sizes}",  # Required by openbmb/MiniCPM-V-2_6
    },
    "uuid": image_url,  # Optional
}
chat_completion = client.chat.completions.create(
    messages=[
        {
            "role": "system",
            "content": "You are a helpful assistant.",
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?",
                },
                embeds,
            ],
        },
    ],
    model=model,
)

For online serving, you can also skip sending media when you expect a cache hit via the provided UUID. You can do so as follows:

Code
    # Image/video/audio URL:
    {
        "type": "image_url",
        "image_url": None,
        "uuid": image_uuid,
    },

    # image_embeds
    {
        "type": "image_embeds",
        "image_embeds": None,
        "uuid": image_uuid,
    },

    # input_audio:
    {
        "type": "input_audio",
        "input_audio": None,
        "uuid": audio_uuid,
    },

    # PIL Image:
    {
        "type": "image_pil",
        "image_pil": None,
        "uuid": image_uuid,
    },


Note

Multiple message contents can contain {"type": "image_embeds"}, allowing you to pass multiple image embeddings in a single request (similar to regular images). The number of embeddings is limited by --limit-mm-per-prompt.

Important: the expected shape of the embeddings depends on how many you pass:

  • Single embedding: a 3D tensor of shape (1, feature_size, hidden_size)
  • Multiple embeddings: a list of 2D tensors, each of shape (feature_size, hidden_size)

If used with a model that requires additional parameters, you must also provide a tensor for each of them, e.g. image_grid_thw, image_sizes, etc.
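
To illustrate, here is a minimal sketch that sends two image embeddings in one request; the tensor file names and the --limit-mm-per-prompt setting are assumptions:

Code
import torch
from openai import OpenAI
from vllm.utils.serial_utils import tensor2base64

# Assumes the server was launched with --enable-mm-embeds and
# --limit-mm-per-prompt '{"image":2}' (or higher).
client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")

# Each tensor holds a single-image embedding of shape (1, feature_size, hidden_size).
embeds_a = tensor2base64(torch.load("embeds_a.pt"))
embeds_b = tensor2base64(torch.load("embeds_b.pt"))

chat_completion = client.chat.completions.create(
    model="llava-hf/llava-1.5-7b-hf",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Compare these two images."},
            {"type": "image_embeds", "image_embeds": embeds_a},
            {"type": "image_embeds", "image_embeds": embeds_b},
        ],
    }],
)
print(chat_completion.choices[0].message.content)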