vllm.attention.ops.vit_attn_wrappers

This file contains ops that make ViT attention compatible with torch.compile, since some of the operations used here are not supported by torch.compile (for instance, the .item() call in flash attention).

Using these ops and wrapping vision blocks with torch.compile can improve throughput of vision models by ~5% (relative) on H100 and reduce token latencies by ~7% (see qwen2_5_vl for example usage).

To use these ops, you must have PyTorch >= 2.4.0 installed.
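
As a minimal sketch of the intended usage (not taken from vLLM itself; it assumes that importing this module registers the torch.ops.vllm custom ops, that a CUDA device is available, and that the illustrative (batch, seq, heads, head_dim) shapes below are acceptable), a vision attention call can be routed through these wrappers and compiled:

import torch

# Importing the module is assumed to register the torch.ops.vllm custom ops.
from vllm.attention.ops.vit_attn_wrappers import vit_torch_sdpa_wrapper

def attn_block(q, k, v, cu_seqlens):
    # SDPA path; see vit_flash_attn_wrapper below for the flash-attention path.
    return vit_torch_sdpa_wrapper(q, k, v, cu_seqlens)

compiled_attn = torch.compile(attn_block)

b, s, h, d = 1, 1024, 16, 80  # illustrative ViT shapes
q = torch.randn(b, s, h, d, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)
# Cumulative token counts per image in the packed sequence, starting at 0.
cu_seqlens = torch.tensor([0, 256, 1024], device="cuda", dtype=torch.int32)
out = compiled_attn(q, k, v, cu_seqlens)  # shape: (s, b, h * d)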

flash_attn_maxseqlen_wrapper

flash_attn_maxseqlen_wrapper(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    cu_seqlens: Tensor,
    max_seqlen: Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    use_upstream_fa: bool,
) -> Tensor
Source code in vllm/attention/ops/vit_attn_wrappers.py
def flash_attn_maxseqlen_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    max_seqlen: torch.Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    use_upstream_fa: bool,
) -> torch.Tensor:
    if is_rocm_aiter:
        from aiter import flash_attn_varlen_func
    else:
        if use_upstream_fa:
            from flash_attn import flash_attn_varlen_func
        else:
            from vllm.attention.utils.fa_utils import flash_attn_varlen_func
    # Flatten (batch, seq) into the packed layout expected by the varlen kernel.
    q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
    output = flash_attn_varlen_func(
        q,
        k,
        v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        # .item() is safe here because it runs inside the custom op,
        # outside the torch.compile graph.
        max_seqlen_q=max_seqlen.item(),
        max_seqlen_k=max_seqlen.item(),
        dropout_p=0.0,
        causal=False,
    )
    # Unpack back to (seq, batch, hidden) with heads merged.
    context_layer = einops.rearrange(
        output, "(b s) h d -> s b (h d)", b=batch_size
    ).contiguous()
    return context_layer
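
max_seqlen is passed as a 0-dim tensor rather than a Python int so that the .item() call happens inside the custom op, where torch.compile never traces it. A sketch of deriving both arguments from per-image token counts (hypothetical variable names; flash attention expects int32 cu_seqlens):

import torch

# Per-image patch counts, e.g. two images with 256 and 768 patches.
seqlens = torch.tensor([256, 768], device="cuda", dtype=torch.int32)
cu_seqlens = torch.cat([seqlens.new_zeros(1), seqlens.cumsum(dim=0, dtype=torch.int32)])
max_seqlen = seqlens.max()  # kept as a 0-dim tensor; .item() is deferred into the op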

flash_attn_maxseqlen_wrapper_fake

flash_attn_maxseqlen_wrapper_fake(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    cu_seqlens: Tensor,
    max_seqlen: Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    use_upstream_fa: bool,
) -> Tensor
Source code in vllm/attention/ops/vit_attn_wrappers.py
def flash_attn_maxseqlen_wrapper_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    max_seqlen: torch.Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    use_upstream_fa: bool,
) -> torch.Tensor:
    # Fake (meta) implementation: only the output's shape, dtype and device
    # matter here, so torch.compile can trace the custom op without running it.
    b, s, h, d = q.shape
    return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
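
vLLM registers each real/fake pair as a torch.ops.vllm custom op. The snippet below is an illustrative sketch of the same pattern using the public torch.library API from PyTorch >= 2.4; it is not vLLM's actual registration code, and mylib::toy_attn is a made-up op:

import torch

@torch.library.custom_op("mylib::toy_attn", mutates_args=())
def toy_attn(q: torch.Tensor) -> torch.Tensor:
    # Real implementation; may contain graph-breaking ops such as .item().
    return q * q.abs().max().item()

@toy_attn.register_fake
def _(q: torch.Tensor) -> torch.Tensor:
    # Fake implementation: shape/dtype/device metadata only, no computation.
    return torch.empty_like(q)

# The .item() above no longer breaks the graph: the op is opaque to Dynamo.
compiled = torch.compile(lambda x: toy_attn(x) + 1.0, fullgraph=True)
out = compiled(torch.randn(8))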

torch_sdpa_wrapper

torch_sdpa_wrapper(
    q: Tensor, k: Tensor, v: Tensor, cu_seqlens: Tensor
) -> Tensor
Source code in vllm/attention/ops/vit_attn_wrappers.py
def torch_sdpa_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
) -> torch.Tensor:
    outputs = []
    # cu_seqlens holds cumulative sequence lengths: each (start, end) pair
    # delimits one image's tokens along the packed sequence dimension.
    for i in range(1, len(cu_seqlens)):
        start_idx = cu_seqlens[i - 1]
        end_idx = cu_seqlens[i]
        q_i = q[:, start_idx:end_idx]
        k_i = k[:, start_idx:end_idx]
        v_i = v[:, start_idx:end_idx]
        q_i, k_i, v_i = (
            einops.rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
        )
        output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
        output_i = einops.rearrange(output_i, "b h s d -> b s h d")
        outputs.append(output_i)
    # Re-concatenate along the sequence dimension and return (seq, batch, hidden).
    context_layer = torch.cat(outputs, dim=1)
    context_layer = einops.rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
    return context_layer
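
For a single sequence (cu_seqlens = [0, s]) the loop reduces to one scaled_dot_product_attention call. A small equivalence sketch with illustrative shapes (assumes vllm and einops are importable):

import einops
import torch
import torch.nn.functional as F
from vllm.attention.ops.vit_attn_wrappers import torch_sdpa_wrapper

b, s, h, d = 1, 64, 4, 32
q = torch.randn(b, s, h, d)
k, v = torch.randn_like(q), torch.randn_like(q)
cu_seqlens = torch.tensor([0, s])

out = torch_sdpa_wrapper(q, k, v, cu_seqlens)  # (s, b, h * d)
ref = F.scaled_dot_product_attention(
    *(einops.rearrange(x, "b s h d -> b h s d") for x in (q, k, v))
)
ref = einops.rearrange(ref, "b h s d -> s b (h d)")
assert torch.allclose(out, ref, atol=1e-6)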

torch_sdpa_wrapper_fake

torch_sdpa_wrapper_fake(
    q: Tensor, k: Tensor, v: Tensor, cu_seqlens: Tensor
) -> Tensor
Source code in vllm/attention/ops/vit_attn_wrappers.py
def torch_sdpa_wrapper_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
) -> torch.Tensor:
    # Same pattern as flash_attn_maxseqlen_wrapper_fake: return a tensor with
    # the correct output shape, dtype and device, but no computation.
    b, s, h, d = q.shape
    return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)

vit_flash_attn_wrapper

vit_flash_attn_wrapper(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    cu_seqlens: Tensor,
    max_seqlen: Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    use_upstream_fa: bool,
) -> Tensor
Source code in vllm/attention/ops/vit_attn_wrappers.py
def vit_flash_attn_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    max_seqlen: torch.Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    use_upstream_fa: bool,
) -> torch.Tensor:
    # Dispatch through the registered custom op so that torch.compile treats
    # the flash-attention call (including max_seqlen.item()) as opaque.
    return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
        q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, use_upstream_fa
    )
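
The flash-attention path needs a GPU and one of the supported kernels: is_rocm_aiter selects AMD's aiter kernel, use_upstream_fa selects the upstream flash-attn package, and otherwise vLLM's bundled flash-attention is used. A hedged usage sketch (illustrative shapes; assumes a CUDA device with a flash-attention backend installed):

import torch
from vllm.attention.ops.vit_attn_wrappers import vit_flash_attn_wrapper

b, s, h, d = 1, 1024, 16, 80
q = torch.randn(b, s, h, d, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)
cu_seqlens = torch.tensor([0, 256, 1024], device="cuda", dtype=torch.int32)
max_seqlen = torch.tensor(768, device="cuda")  # longest per-image length, as a tensor

out = vit_flash_attn_wrapper(
    q, k, v, cu_seqlens, max_seqlen,
    batch_size=b, is_rocm_aiter=False, use_upstream_fa=False,
)
# out has shape (s, b, h * d)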

vit_torch_sdpa_wrapper

vit_torch_sdpa_wrapper(
    q: Tensor, k: Tensor, v: Tensor, cu_seqlens: Tensor
) -> Tensor
Source code in vllm/attention/ops/vit_attn_wrappers.py
def vit_torch_sdpa_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
) -> torch.Tensor:
    # Dispatch through the registered custom op so that torch.compile treats
    # the per-sequence SDPA loop (with its data-dependent slicing) as opaque.
    return torch.ops.vllm.torch_sdpa_wrapper(q, k, v, cu_seqlens)
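
Both entry points return the context layer in (seq, batch, heads * head_dim) layout. If the surrounding block expects batch-first tensors, a follow-up rearrange (an illustrative sketch, not part of this module) restores it:

import einops
import torch

s, b, h, d = 1024, 1, 16, 80
out = torch.randn(s, b, h * d)  # stand-in for the wrappers' output
hidden_states = einops.rearrange(out, "s b hd -> b s hd")      # (b, s, h * d)
per_head = einops.rearrange(out, "s b (h d) -> b s h d", h=h)  # split heads again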