
vllm.config.compilation

logger module-attribute

logger = init_logger(__name__)

CUDAGraphMode

Bases: Enum

Constants for the cudagraph mode in CompilationConfig. The subset NONE, PIECEWISE, and FULL also serve as concrete runtime modes for cudagraph runtime dispatching.

Source code in vllm/config/compilation.py
class CUDAGraphMode(enum.Enum):
    """Constants for the cudagraph mode in CompilationConfig.
    Meanwhile, the subset enum `NONE`, `PIECEWISE` and `FULL` are also
    treated as concrete runtime mode for cudagraph runtime dispatching.
    """

    NONE = 0
    PIECEWISE = 1
    FULL = 2
    FULL_DECODE_ONLY = (FULL, NONE)
    FULL_AND_PIECEWISE = (FULL, PIECEWISE)

    def decode_mode(self) -> "CUDAGraphMode":
        return CUDAGraphMode(self.value[0]) if self.separate_routine() else self

    def mixed_mode(self) -> "CUDAGraphMode":
        return CUDAGraphMode(self.value[1]) if self.separate_routine() else self

    def has_mode(self, mode: "CUDAGraphMode") -> bool:
        assert not mode.separate_routine()
        if self.separate_routine():
            return mode.value in self.value
        return self == mode

    def requires_piecewise_compilation(self) -> bool:
        return self.has_mode(CUDAGraphMode.PIECEWISE)

    def max_cudagraph_mode(self) -> "CUDAGraphMode":
        return CUDAGraphMode(max(self.value)) if self.separate_routine() else self

    def has_full_cudagraphs(self) -> bool:
        return self.max_cudagraph_mode() == CUDAGraphMode.FULL

    def has_piecewise_cudagraphs(self) -> bool:
        return self.requires_piecewise_compilation()

    def separate_routine(self) -> bool:
        return isinstance(self.value, tuple)

    def valid_runtime_modes(self) -> bool:
        return self in [CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]

    def __str__(self) -> str:
        return self.name
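
A short sketch of how the combined modes decompose into the concrete runtime modes, using only the methods shown above (assumes CUDAGraphMode is importable from vllm.config.compilation, as documented here):

from vllm.config.compilation import CUDAGraphMode

mode = CUDAGraphMode.FULL_AND_PIECEWISE
assert mode.separate_routine()                       # value is the tuple (FULL, PIECEWISE)
assert mode.decode_mode() == CUDAGraphMode.FULL      # uniform decode batches use full graphs
assert mode.mixed_mode() == CUDAGraphMode.PIECEWISE  # mixed prefill-decode batches use piecewise graphs
assert mode.has_full_cudagraphs() and mode.has_piecewise_cudagraphs()
assert not mode.valid_runtime_modes()                # only NONE, PIECEWISE and FULL are runtime modes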

FULL class-attribute instance-attribute

FULL = 2

FULL_AND_PIECEWISE class-attribute instance-attribute

FULL_AND_PIECEWISE = (FULL, PIECEWISE)

FULL_DECODE_ONLY class-attribute instance-attribute

FULL_DECODE_ONLY = (FULL, NONE)

NONE class-attribute instance-attribute

NONE = 0

PIECEWISE class-attribute instance-attribute

PIECEWISE = 1

__str__

__str__() -> str
Source code in vllm/config/compilation.py
def __str__(self) -> str:
    return self.name

decode_mode

decode_mode() -> CUDAGraphMode
Source code in vllm/config/compilation.py
def decode_mode(self) -> "CUDAGraphMode":
    return CUDAGraphMode(self.value[0]) if self.separate_routine() else self

has_full_cudagraphs

has_full_cudagraphs() -> bool
Source code in vllm/config/compilation.py
def has_full_cudagraphs(self) -> bool:
    return self.max_cudagraph_mode() == CUDAGraphMode.FULL

has_mode

has_mode(mode: CUDAGraphMode) -> bool
Source code in vllm/config/compilation.py
def has_mode(self, mode: "CUDAGraphMode") -> bool:
    assert not mode.separate_routine()
    if self.separate_routine():
        return mode.value in self.value
    return self == mode

has_piecewise_cudagraphs

has_piecewise_cudagraphs() -> bool
Source code in vllm/config/compilation.py
def has_piecewise_cudagraphs(self) -> bool:
    return self.requires_piecewise_compilation()

max_cudagraph_mode

max_cudagraph_mode() -> CUDAGraphMode
Source code in vllm/config/compilation.py
def max_cudagraph_mode(self) -> "CUDAGraphMode":
    return CUDAGraphMode(max(self.value)) if self.separate_routine() else self

mixed_mode

mixed_mode() -> CUDAGraphMode
Source code in vllm/config/compilation.py
def mixed_mode(self) -> "CUDAGraphMode":
    return CUDAGraphMode(self.value[1]) if self.separate_routine() else self

requires_piecewise_compilation

requires_piecewise_compilation() -> bool
Source code in vllm/config/compilation.py
def requires_piecewise_compilation(self) -> bool:
    return self.has_mode(CUDAGraphMode.PIECEWISE)

separate_routine

separate_routine() -> bool
Source code in vllm/config/compilation.py
def separate_routine(self) -> bool:
    return isinstance(self.value, tuple)

valid_runtime_modes

valid_runtime_modes() -> bool
Source code in vllm/config/compilation.py
def valid_runtime_modes(self) -> bool:
    return self in [CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]

CompilationConfig

Configuration for compilation. It has three parts:

- Top-level compilation control: mode, debug_dump_path, cache_dir, backend, custom_ops, splitting_ops, compile_mm_encoder
- CudaGraph capture: cudagraph_mode, cudagraph_capture_sizes, max_cudagraph_capture_size, cudagraph_num_of_warmups, cudagraph_copy_inputs
- Inductor compilation: use_inductor, compile_sizes, inductor_compile_config, inductor_passes, and custom inductor passes

Why we have different sizes for cudagraph and inductor:

- cudagraph: a cudagraph captured for a specific size can only be used for that same size, so we need to capture every size we want to use.
- inductor: a graph compiled by Inductor for a general shape can be used for different sizes. Inductor can also compile for specific sizes, where it has more information to optimize the graph with fully static shapes. However, we find that general-shape compilation is sufficient for most cases. It might still be beneficial to compile for certain small batch sizes, where Inductor is good at optimizing.
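
Before diving into the individual fields, here is a hedged sketch of how this config is typically supplied from user code; it assumes the LLM entrypoint of your vLLM version accepts a compilation_config dict (string and integer values are parsed by the validators shown later on this page):

from vllm import LLM

llm = LLM(
    model="facebook/opt-125m",  # illustrative model
    compilation_config={
        "mode": 3,                               # VLLM_COMPILE
        "cudagraph_mode": "FULL_AND_PIECEWISE",  # parsed by validate_cudagraph_mode_before
        "cudagraph_capture_sizes": [1, 2, 4, 8, 16],
    },
)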

Source code in vllm/config/compilation.py
@config
@dataclass
class CompilationConfig:
    """Configuration for compilation. It has three parts:

    - Top-level Compilation control:
        - [`mode`][vllm.config.CompilationConfig.mode]
        - [`debug_dump_path`][vllm.config.CompilationConfig.debug_dump_path]
        - [`cache_dir`][vllm.config.CompilationConfig.cache_dir]
        - [`backend`][vllm.config.CompilationConfig.backend]
        - [`custom_ops`][vllm.config.CompilationConfig.custom_ops]
        - [`splitting_ops`][vllm.config.CompilationConfig.splitting_ops]
        - [`compile_mm_encoder`][vllm.config.CompilationConfig.compile_mm_encoder]
    - CudaGraph capture:
        - [`cudagraph_mode`][vllm.config.CompilationConfig.cudagraph_mode]
        - [`cudagraph_capture_sizes`]
        [vllm.config.CompilationConfig.cudagraph_capture_sizes]
        - [`max_cudagraph_capture_size`]
        [vllm.config.CompilationConfig.max_cudagraph_capture_size]
        - [`cudagraph_num_of_warmups`]
        [vllm.config.CompilationConfig.cudagraph_num_of_warmups]
        - [`cudagraph_copy_inputs`]
        [vllm.config.CompilationConfig.cudagraph_copy_inputs]
    - Inductor compilation:
        - [`use_inductor`][vllm.config.CompilationConfig.use_inductor]
        - [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes]
        - [`inductor_compile_config`]
        [vllm.config.CompilationConfig.inductor_compile_config]
        - [`inductor_passes`][vllm.config.CompilationConfig.inductor_passes]
        - custom inductor passes

    Why we have different sizes for cudagraph and inductor:
    - cudagraph: a cudagraph captured for a specific size can only be used
        for the same size. We need to capture all the sizes we want to use.
    - inductor: a graph compiled by inductor for a general shape can be used
        for different sizes. Inductor can also compile for specific sizes,
        where it can have more information to optimize the graph with fully
        static shapes. However, we find the general shape compilation is
        sufficient for most cases. It might be beneficial to compile for
        certain small batchsizes, where inductor is good at optimizing.
    """

    # Top-level Compilation control
    level: int | None = None
    """
    Level is deprecated and will be removed in the next release,
    either 0.12.0 or 0.11.2 whichever is soonest.
    Please use mode. Currently all levels are mapped to mode.
    """
    # Top-level Compilation control
    mode: CompilationMode | None = None
    """The compilation approach used for torch.compile-based compilation of the
    model.

    - None: If None, we will select the default compilation mode.
      For V1 engine this is 3.
    - 0: NONE: No torch.compile compilation is applied, model runs in fully
         eager pytorch mode. The model runs as-is.
    - 1: STOCK_TORCH_COMPILE: The standard `torch.compile` compilation pipeline.
    - 2: DYNAMO_TRACE_ONCE: Single Dynamo trace through the model, avoiding
         recompilation by removing guards.
         Requires no dynamic-shape-dependent control-flow.
    - 3: VLLM_COMPILE: Custom vLLM Inductor-based backend with caching,
         piecewise compilation, shape specialization, and custom passes."""
    debug_dump_path: Path | None = None
    """The path to dump the debug information."""
    cache_dir: str = ""
    """The directory to store the compiled graph, to accelerate Inductor
    compilation. By default, it will use model-related information to generate
    a cache directory."""
    compile_cache_save_format: Literal["binary", "unpacked"] = field(
        default_factory=lambda: envs.VLLM_COMPILE_CACHE_SAVE_FORMAT
    )
    """Format for saving torch compile cache:\n
    - "binary": saves as binary file (multiprocess safe)\n
    - "unpacked": saves as directory structure for inspection/debugging
    (NOT multiprocess safe)\n
    Defaults to `VLLM_COMPILE_CACHE_SAVE_FORMAT` if not specified.
    """
    backend: str = ""
    """The backend for compilation. It needs to be a string:

    - "" (empty string): use the default backend ("inductor" on CUDA-alike
    platforms).
    - "eager"/"openxla"/...: use the specified backend registered in PyTorch.
    - "full.module.name": a qualified name which can be used to import the

    backend function.
    We use string to avoid serialization issues when using compilation in a
    distributed setting. When the compilation mode is 1 or 2, the backend is
    used for the compilation directly (it sees the whole graph). When the
    compilation mode is 3, the backend is used for the piecewise compilation
    (it sees a part of the graph). The backend can not be custom for compilation
    mode 3, i.e. the backend must be either eager or inductor. Furthermore,
    compilation is only piecewise if splitting ops is set accordingly and
    use_inductor_graph_partition is off. Note that the default options for
    splitting ops are sufficient for piecewise compilation.
    """
    custom_ops: list[str] = field(default_factory=list)
    """Fine-grained control over which custom ops to enable/disable. Use 'all'
    to enable all, 'none' to disable all. Also specify a list of custom op
    names to enable (prefixed with a '+'), or disable (prefixed with a '-').
    Examples:

    - 'all,-op1' to enable all except op1
    - 'none,+op1,+op2' to enable only op1 and op2

    By default, all custom ops are enabled when running without Inductor and
    disabled when running with Inductor: mode>=VLLM_COMPILE and use_inductor=True.
    Inductor generates (fused) Triton kernels for disabled custom ops."""
    splitting_ops: list[str] | None = None
    """A list of ops to exclude from cudagraphs, used in piecewise compilation.

    The behavior depends on use_inductor_graph_partition:

    - When use_inductor_graph_partition=False (default):
        These ops are used for Dynamo FX-level graph splitting. The graph is
        split at these ops before Inductor compilation, creating separate
        subgraphs for cudagraph capture.

    - When use_inductor_graph_partition=True:
        These ops are used to register Inductor partition rules. The graph
        partitioning happens at Inductor codegen time after all passes and
        fusions are finished, allowing compilation and custom passes to operate
        on the full graph while still excluding these ops from cudagraphs.

    If None, defaults to attention ops for piecewise cudagraphs.
    If empty list [], no ops are excluded (suitable for full cudagraphs)."""
    compile_mm_encoder: bool = False
    """Whether or not to compile the multimodal encoder.
    Currently, this only works for `Qwen2_5_vl` on selected platforms.
    Disabled by default until more models are supported/tested to work."""

    # Inductor capture
    use_inductor: bool | None = None
    """
    Whether to use inductor compilation.

    This flag is deprecated and will be removed in the next release 0.12.0.
    Please use the 'backend' option instead.

    - False: inductor compilation is not used. graph runs in eager
        (custom_ops enabled by default).
    - True: inductor compilation is used (custom_ops disabled by default).
        One graph for symbolic shape and one graph per size in compile_sizes
        are compiled using configurations in inductor_compile_config.

    This setting is ignored if mode<VLLM_COMPILE.

    For future compatibility:
    If use_inductor is True, backend="inductor" otherwise backend="eager".
    """
    compile_sizes: list[int | str] | None = None
    """Sizes to compile for inductor. In addition
    to integers, it also supports "cudagraph_capture_sizes" to
    specify the sizes for cudagraph capture."""

    inductor_compile_config: dict = field(default_factory=dict)
    """Additional configurations for inductor.
    - None: use default configurations."""

    inductor_passes: dict[str, str] = field(default_factory=dict)
    """Additional passes for inductor. It is a dictionary
    from pass name to pass function qualified name. We use function
    name because the config uses JSON format. If we pass the config
    from Python, functions can also be passed directly via Python object
    constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`."""

    # CudaGraph compilation
    cudagraph_mode: CUDAGraphMode | None = None
    """
    The mode of the cudagraph:

    - NONE, no cudagraph capture.
    - PIECEWISE.
    - FULL.
    - FULL_DECODE_ONLY.
    - FULL_AND_PIECEWISE. (v1 default)

    PIECEWISE mode build piecewise cudagraph only, keeping the cudagraph
    incompatible ops (i.e. some attention ops) outside the cudagraph
    for general flexibility.

    FULL mode: Capture full cudagraph for all batches. Can be good for small
    models or workloads with small prompts; not supported by many backends.
    Generally for performance FULL_AND_PIECEWISE is better.

    FULL_DECODE_ONLY mode: Capture full cudagraph for decode batches only.
    Mixed prefill-decode batches are run without cudagraphs. Can be good for
    decode instances in a P/D setup where prefill is not as important so we
    can save some memory.

    FULL_AND_PIECEWISE mode: Capture full cudagraph for decode batches and
    piecewise cudagraph for prefill and mixed prefill-decode batches.
    This is the most performant mode for most models and is the default.

    Currently, the cudagraph mode is only used for the v1 engine.
    Note that the cudagraph logic is generally orthogonal to the
    compilation logic. While piecewise cudagraphs require piecewise
    compilation (mode=VLLM_COMPILE and non-empty splitting_ops), full
    cudagraphs are supported with and without compilation.

    Warning: This flag is new and subject to change in addition
    more modes may be added.
    """
    cudagraph_num_of_warmups: int = 0
    """Number of warmup runs for cudagraph.
    It means the first several runs will be treated as warmup runs.
    Only after that, the execution will be recorded, and the recorded
    cudagraph will be used for subsequent runs."""
    cudagraph_capture_sizes: list[int] | None = None
    """Sizes to capture cudagraph.
    - None (default): capture sizes are inferred from vllm config.
    - list[int]: capture sizes are specified as given."""
    cudagraph_copy_inputs: bool = False
    """Whether to copy input tensors for
    cudagraph. If the caller can guarantee that the same input buffers
    are always used, it can set this to False. Otherwise, it should
    set this to True, and the compiler will copy the input to an
    internally managed buffer. Default is False.
    Note that this flag is only effective when cudagraph_mode is PIECEWISE.
    """
    cudagraph_specialize_lora: bool = True
    """Whether to create separate cuda graphs for cases with and without active
    LoRA adapters. When set to False, the LoRA-enabled cuda graph will be used
    for all cases, incurring the overhead of running LoRA ops even when no
    adapters are active. Setting this to True will remove this overhead at the
    cost of increased startup time and slightly higher memory usage.
    When `enable_lora` is False, this option has no effect.
    """

    use_inductor_graph_partition: bool = False
    """Use inductor graph partition to split the graph at cudagraph_unsafe ops.
    This partition happens at inductor codegen time after all passes and fusions
    are finished. It generates a single `call` function which wraps
    cudagraph-safe ops into partition functions and leave cudagraph-unsafe ops
    outside the partition functions. For a graph with N cudagraph-unsafe ops
    (e.g., Attention), there would be N+1 partitions. To mark an op as
    cudagraph unsafe, we can add `tags=(torch._C.Tag.cudagraph_unsafe)` when
    register the custom op.

    This config supports both full cudagraph and piecewise cudagraph without
    compiling twice. For piecewise cudagraph, it applies vLLM CUDAGraph wrapper
    to each partition. For N+1 partitions, there would be N+1
    CUDAGraph wrapper instances.

    For full CUDAGraph, we always apply a single CUDAGraph wrapper outside the
    inductor `call` function in the model runner. The top-level full cudagraph
    capture ignores all partitioning.
    """

    pass_config: PassConfig = field(default_factory=PassConfig)
    """Custom inductor passes, see PassConfig for more details"""

    max_cudagraph_capture_size: int | None = field(default=None)
    """The maximum cudagraph capture size.

    If cudagraph_capture_sizes is specified, this will be set to the largest
    size in that list (or checked for consistency if specified). If
    cudagraph_capture_sizes is not specified, the list of sizes is generated
    automatically following the pattern:

        [1, 2, 4] + list(range(8, 256, 8)) + list(
        range(256, max_cudagraph_capture_size + 1, 16))

    If not specified, max_cudagraph_capture_size is set to min(max_num_seqs*2,
    512) by default. This voids OOM in tight memory scenarios with small
    max_num_seqs, and prevents capture of many large graphs (>512) that would
    greatly increase startup time with limited performance benefit.
    """

    dynamic_shapes_config: DynamicShapesConfig = field(
        default_factory=DynamicShapesConfig
    )
    """Configuration for dynamic shapes options"""

    local_cache_dir: str = field(default=None, init=False)  # type: ignore
    """local cache dir for each rank"""

    bs_to_padded_graph_size: list[int] = field(
        default=None,  # type: ignore
        init=False,
    )
    """optimization:
    Intuitively, bs_to_padded_graph_size should be dict[int, int].
    since we know all keys are in a range [0, max_cudagraph_capture_size],
    we can optimize it to list[int] for better lookup performance."""

    # keep track of enabled and disabled custom ops
    enabled_custom_ops: Counter[str] = field(default_factory=Counter, init=False)
    """custom ops that are enabled"""
    disabled_custom_ops: Counter[str] = field(default_factory=Counter, init=False)
    """custom ops that are disabled"""
    traced_files: set[str] = field(default_factory=set, init=False)
    """files that are traced for compilation"""
    compilation_time: float = field(default=0.0, init=False)
    """time taken for compilation"""

    static_forward_context: dict[str, Any] = field(default_factory=dict, init=False)
    """Per-model forward context
    Map from layer name to layer objects that need to be accessed outside
    model code, e.g., Attention, FusedMOE when dp_size>1."""

    # Attention ops; used for piecewise cudagraphs
    # Use PyTorch operator format: "namespace::name"
    _attention_ops: ClassVar[list[str]] = [
        "vllm::unified_attention",
        "vllm::unified_attention_with_output",
        "vllm::unified_mla_attention",
        "vllm::unified_mla_attention_with_output",
        "vllm::mamba_mixer2",
        "vllm::mamba_mixer",
        "vllm::short_conv",
        "vllm::linear_attention",
        "vllm::plamo2_mamba_mixer",
        "vllm::gdn_attention_core",
        "vllm::kda_attention",
        "vllm::sparse_attn_indexer",
    ]

    def compute_hash(self) -> str:
        """
        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # Opt-out: default-include declared fields; keep a tiny exclude set;
        # normalize types; keep SHA-256. For nested opaque configs, include a
        # stable identifier (e.g., pass_config.compute_hash()) instead of object id.

        ignored_factors = {
            # Paths/dirs and runtime/metrics that don’t affect compiled graph
            "debug_dump_path",
            "cache_dir",
            "local_cache_dir",
            "bs_to_padded_graph_size",
            "traced_files",
            "compilation_time",
            "static_forward_context",
            "pass_config",  # handled separately below
        }

        from vllm.config.utils import get_hash_factors, hash_factors

        factors = get_hash_factors(self, ignored_factors)

        factors["pass_config"] = self.pass_config.compute_hash()
        return hash_factors(factors)

    def __repr__(self) -> str:
        exclude = {
            "static_forward_context": True,
            "enabled_custom_ops": True,
            "disabled_custom_ops": True,
            "compilation_time": True,
            "bs_to_padded_graph_size": True,
            "traced_files": True,
            "inductor_compile_config": {
                "post_grad_custom_post_pass": True,
            },
        }

        # exclude default attr in pass_config
        pass_config_exclude = {}
        for attr, default_val in vars(PassConfig()).items():
            if getattr(self.pass_config, attr) == default_val:
                pass_config_exclude[attr] = True
        if pass_config_exclude:
            exclude["pass_config"] = pass_config_exclude

        config = TypeAdapter(CompilationConfig).dump_python(
            self, exclude=exclude, exclude_unset=True
        )

        return str(config)

    __str__ = __repr__

    @field_validator("mode", mode="before")
    @classmethod
    def validate_mode_before(cls, value: Any) -> Any:
        """
        Enable parsing the `mode` field from string mode names.
        Accepts both integers (0-3) and string names, like NONE, STOCK_TORCH_COMPILE,
        DYNAMO_TRACE_ONCE, VLLM_COMPILE.
        """
        if isinstance(value, str):
            # Convert string mode name to integer value
            mode_name = value.upper()

            if mode_name not in CompilationMode.__members__:
                raise ValueError(
                    f"Invalid compilation mode: {value}. "
                    f"Valid modes are: {', '.join(CompilationMode.__members__.keys())}"
                )

            return CompilationMode[mode_name]
        return value

    @field_validator("cudagraph_mode", mode="before")
    @classmethod
    def validate_cudagraph_mode_before(cls, value: Any) -> Any:
        """Enable parsing of the `cudagraph_mode` enum type from string."""
        if isinstance(value, str):
            return CUDAGraphMode[value.upper()]
        return value

    @field_validator("pass_config", mode="before")
    @classmethod
    def validate_pass_config_before(cls, value: Any) -> Any:
        """Enable parsing of the `pass_config` field from a dictionary."""
        if isinstance(value, dict):
            return PassConfig(**value)
        return value

    @field_validator("compile_cache_save_format")
    @classmethod
    def validate_compile_cache_save_format(cls, value: str) -> str:
        if value not in ("binary", "unpacked"):
            raise ValueError(
                f"compile_cache_save_format must be 'binary' or 'unpacked', "
                f"got: {value}"
            )
        return value

    def __post_init__(self) -> None:
        if self.level is not None:
            logger.warning(
                "Level is deprecated and will be removed in the next release,"
                "either 0.12.0 or 0.11.2 whichever is soonest."
                "Use mode instead."
                "If both level and mode are given,"
                "only mode will be used."
            )
            if self.mode is None:
                self.mode = self.level

        count_none = self.custom_ops.count("none")
        count_all = self.custom_ops.count("all")
        assert count_none + count_all <= 1, "Can only specify 'none' or 'all'"

        # TODO(zou3519/luka): There are 2 issues with auto-functionalization V2:
        # 1. A bug in PyTorch, fixed in 2.7:
        #    https://github.com/pytorch/pytorch/issues/147924
        # 2. Custom passes (fusion) rely on auto-functionalization V1 and don't
        #    work with V2. Addressing this will take extra engineering effort
        #    and it is not yet a priority. RFC here:
        #    https://github.com/vllm-project/vllm/issues/14703

        if is_torch_equal_or_newer("2.6"):
            KEY = "enable_auto_functionalized_v2"
            if KEY not in self.inductor_compile_config:
                self.inductor_compile_config[KEY] = False

        for k, v in self.inductor_passes.items():
            if not isinstance(v, str):
                assert callable(v), f"pass {k} should be callable or a qualified name"
                self.inductor_compile_config[k] = (
                    v if isinstance(v, InductorPass) else CallableInductorPass(v)
                )
                continue

            # resolve function from qualified name
            names = v.split(".")
            module = ".".join(names[:-1])
            func_name = names[-1]
            func = __import__(module).__dict__[func_name]
            self.inductor_compile_config[k] = (
                func if isinstance(func, InductorPass) else CallableInductorPass(func)
            )

        if self.pass_config.enable_qk_norm_rope_fusion:
            # TODO(zhuhaoran): support rope native forward match and remove this.
            # Linked issue: https://github.com/vllm-project/vllm/issues/28042
            self.custom_ops.append("+rotary_embedding")

        if (
            is_torch_equal_or_newer("2.9.0.dev")
            and "combo_kernels" not in self.inductor_compile_config
            and "benchmark_combo_kernel" not in self.inductor_compile_config
            # (fixme @boyuan) combo kernel does not support cpu yet.
            and not current_platform.is_cpu()
        ):
            # use horizontal fusion, which is useful for fusing qk-norm and
            # qk-rope when query and key have different shapes.
            self.inductor_compile_config["combo_kernels"] = True
            self.inductor_compile_config["benchmark_combo_kernel"] = True

        if self.use_inductor_graph_partition and not is_torch_equal_or_newer(
            "2.9.0.dev"
        ):
            raise ValueError(
                "use_inductor_graph_partition is only "
                "supported with torch>=2.9.0.dev. Set "
                "use_inductor_graph_partition=False instead."
            )

        for op in self.custom_ops:
            if op[0] not in {"+", "-"} and op not in {"all", "none"}:
                raise ValueError(
                    f"Invalid syntax '{op}' for custom op, "
                    "must be 'all', 'none', '+op' or '-op' "
                    "(where 'op' is the registered op name)"
                )

        # Currently only eager and inductor backend are supported.
        # for piecewise compilation. Custom backends are not suppported for
        # piecewise compilation. Update when more backends are supported.
        if self.mode == CompilationMode.VLLM_COMPILE and self.backend not in [
            "",
            "eager",
            "inductor",
        ]:
            raise ValueError(
                f"Invalid backend for piecewise compilation: {self.backend}"
            )

        if self.use_inductor is not None:
            logger.warning_once(
                "The 'use_inductor' flag is deprecated and will be "
                "removed in the next release (v0.12.0). "
                "Please use the 'backend' option instead.",
            )
            self.backend = "inductor" if self.use_inductor else "eager"

        if self.backend == "":
            self.backend = current_platform.simple_compile_backend

    def init_backend(self, vllm_config: "VllmConfig") -> str | Callable:
        """
        Initialize the backend for the compilation config from a vllm config.
        Arguments:
            vllm_config: The vllm config to initialize the backend from.
        Returns:
            The backend for the compilation config.
        """
        if self.mode is None:
            raise ValueError(
                "No compilation mode is set. This method should only be \
                called via vllm config where the level is set if none is \
                provided."
            )
        if self.mode == CompilationMode.NONE:
            raise ValueError("No compilation mode is set.")

        from torch._dynamo.backends.registry import list_backends

        torch_backends = list_backends(exclude_tags=tuple())
        if self.mode in [
            CompilationMode.STOCK_TORCH_COMPILE,
            CompilationMode.DYNAMO_TRACE_ONCE,
        ]:
            if self.backend in torch_backends:
                return self.backend
            return resolve_obj_by_qualname(self.backend)

        assert self.mode == CompilationMode.VLLM_COMPILE
        if self.backend not in ["eager", "inductor"]:
            raise ValueError(
                f"Invalid backend for piecewise compilation: {self.backend}"
            )

        from vllm.compilation.backends import VllmBackend

        # TODO[@lucaskabela]: See if we can forward prefix
        # https://github.com/vllm-project/vllm/issues/27045
        return VllmBackend(vllm_config)

    def post_init_cudagraph_sizes(self) -> None:
        """To complete the initialization after cudagraph related
        configs are set. This includes:
        - initialize compile_sizes
        - pre-compute the mapping bs_to_padded_graph_size
        """

        computed_compile_sizes = []
        if self.compile_sizes is not None:
            # de-duplicate the sizes provided by the config
            self.compile_sizes = list(set(self.compile_sizes))
            for x in self.compile_sizes:
                if isinstance(x, str):
                    assert x == "cudagraph_capture_sizes", (
                        "Unrecognized size type in compile_sizes, "
                        f"expect 'cudagraph_capture_sizes', got {x}"
                    )
                    computed_compile_sizes.extend(self.cudagraph_capture_sizes)
                else:
                    assert isinstance(x, int)
                    computed_compile_sizes.append(x)
        self.compile_sizes = computed_compile_sizes  # type: ignore

        # make sure the sizes are in ascending order
        self.cudagraph_capture_sizes.sort()
        if self.cudagraph_capture_sizes:
            assert self.cudagraph_capture_sizes[-1] == self.max_cudagraph_capture_size

        # May get recomputed in the model runner if adjustment is needed for spec-decode
        self.compute_bs_to_padded_graph_size()

    def set_splitting_ops_for_v1(self):
        # NOTE: this function needs to be called only when mode is
        # CompilationMode.VLLM_COMPILE
        assert self.mode == CompilationMode.VLLM_COMPILE, (
            "set_splitting_ops_for_v1 should only be called when "
            "mode is CompilationMode.VLLM_COMPILE"
        )

        if self.use_inductor_graph_partition:
            self.set_splitting_ops_for_inductor_graph_partition()
            return

        if self.pass_config.enable_attn_fusion:
            # here use_inductor_graph_partition is False
            self.set_splitting_ops_for_attn_fusion()
            return

        if self.splitting_ops is None:
            # NOTE: When using full cudagraph, instead of setting an empty
            # list and capture the full cudagraph inside the flattened fx
            # graph, we keep the piecewise fx graph structure but capture
            # the full cudagraph outside the fx graph. This reduces some
            # cpu overhead when the runtime batch_size is not cudagraph
            # captured. see https://github.com/vllm-project/vllm/pull/20059
            # for details. Make a copy to avoid mutating the class-level
            # list via reference.
            self.splitting_ops = list(self._attention_ops)
        elif len(self.splitting_ops) == 0:
            logger.warning_once("Using piecewise compilation with empty splitting_ops")
            if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
                logger.warning_once(
                    "Piecewise compilation with empty splitting_ops do not"
                    "contains piecewise cudagraph. Setting cudagraph_"
                    "mode to NONE. Hint: If you are using attention backends "
                    "that support cudagraph, consider manually setting "
                    "cudagraph_mode to FULL or FULL_DECODE_ONLY to enable "
                    "full cudagraphs."
                )
                self.cudagraph_mode = CUDAGraphMode.NONE
            elif self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
                logger.warning_once(
                    "Piecewise compilation with empty splitting_ops do not "
                    "contains piecewise cudagraph. Setting cudagraph_mode "
                    "to FULL."
                )
                self.cudagraph_mode = CUDAGraphMode.FULL
            self.splitting_ops = []

    def set_splitting_ops_for_inductor_graph_partition(self):
        assert self.use_inductor_graph_partition
        if self.splitting_ops is None:
            self.splitting_ops = list(self._attention_ops)

    def set_splitting_ops_for_attn_fusion(self):
        assert self.pass_config.enable_attn_fusion
        if self.splitting_ops is None:
            self.splitting_ops = []
            if self.cudagraph_mode.has_piecewise_cudagraphs():
                logger.warning_once(
                    "enable_attn_fusion is incompatible with piecewise "
                    "cudagraph when use_inductor_graph_partition is off. "
                    "In this case, splitting_ops will be set to empty "
                    "list, and cudagraph_mode will be set to FULL. "
                    "Please ensure you are using attention backends that "
                    "support cudagraph or set cudagraph_mode to NONE "
                    "explicitly if encountering any problems."
                )
                self.cudagraph_mode = CUDAGraphMode.FULL

        assert not self.splitting_ops_contain_attention(), (
            "attention ops should not be in splitting_ops "
            "when enable_attn_fusion is True"
        )

    def splitting_ops_contain_attention(self) -> bool:
        return self.splitting_ops is not None and all(
            op in self.splitting_ops for op in self._attention_ops
        )

    def is_attention_compiled_piecewise(self) -> bool:
        if not self.splitting_ops_contain_attention():
            return False

        if not self.use_inductor_graph_partition:
            # Dynamo-level FX split case
            return self.mode == CompilationMode.VLLM_COMPILE

        # Inductor partition case
        return self.backend == "inductor" and self.mode != CompilationMode.NONE

    def custom_op_log_check(self):
        """
        This method logs the enabled/disabled custom ops and checks that the
        passed custom_ops field only contains relevant ops.
        It is called at the end of set_current_vllm_config,
        after the custom ops have been instantiated.
        """

        if len(self.enabled_custom_ops) + len(self.disabled_custom_ops) == 0:
            logger.debug("No custom ops found in model.")
            return

        logger.debug("enabled custom ops: %s", self.enabled_custom_ops)
        logger.debug("disabled custom ops: %s", self.disabled_custom_ops)

        all_ops_in_model = self.enabled_custom_ops | self.disabled_custom_ops
        for op in self.custom_ops:
            if op in {"all", "none"}:
                continue

            assert op[0] in {"+", "-"}, (
                "Invalid custom op syntax (should be checked during init)"
            )

            # check if op name exists in model
            op_name = op[1:]
            if op_name not in all_ops_in_model:
                from vllm.model_executor.custom_op import CustomOp

                # Does op exist at all or is it just not present in this model?
                # Note: Only imported op classes appear in the registry.
                missing_str = (
                    "doesn't exist (or wasn't imported/registered)"
                    if op_name not in CustomOp.op_registry
                    else "not present in model"
                )

                enable_str = "enabling" if op[0] == "+" else "disabling"
                logger.warning_once(
                    "Op '%s' %s, %s with '%s' has no effect",
                    op_name,
                    missing_str,
                    enable_str,
                    op,
                )

    def adjust_cudagraph_sizes_for_spec_decode(
        self, uniform_decode_query_len: int, tensor_parallel_size: int
    ):
        multiple_of = uniform_decode_query_len
        if tensor_parallel_size > 1 and self.pass_config.enable_sequence_parallelism:
            multiple_of = max(uniform_decode_query_len, tensor_parallel_size)
            if (
                multiple_of % uniform_decode_query_len != 0
                or multiple_of % tensor_parallel_size != 0
            ):
                raise ValueError(
                    f"Can't determine cudagraph shapes that are both a "
                    f"multiple of {uniform_decode_query_len} "
                    f"(num_speculative_tokens + 1) required by spec-decode "
                    f"and {tensor_parallel_size} (tensor_parallel_size) "
                    f"required by sequence parallelism please adjust "
                    f"num_speculative_tokens or disable sequence parallelism"
                )

        if not self.cudagraph_capture_sizes or multiple_of <= 1:
            return

        assert self.max_cudagraph_capture_size is not None
        rounded_sizes = sorted(
            set(
                round_up(size, multiple_of)
                for size in self.cudagraph_capture_sizes
                if round_up(size, multiple_of) <= self.max_cudagraph_capture_size
            )
        )

        if len(rounded_sizes) == 0 and multiple_of <= self.max_cudagraph_capture_size:
            # if one valid but would be round_down use that
            rounded_sizes = [multiple_of]

        if len(rounded_sizes) == 0:
            raise ValueError(
                f"No valid cudagraph sizes after rounding to multiple of {multiple_of} "
                f"(num_speculative_tokens + 1 or tp if sequence parallelism is enabled)"
                f" please adjust num_speculative_tokens ({uniform_decode_query_len - 1}"
                f") or max_cudagraph_capture_size ({self.max_cudagraph_capture_size})"
                f" or cudagraph_capture_sizes ({self.cudagraph_capture_sizes})"
            )

        self.max_cudagraph_capture_size = rounded_sizes[-1]
        self.cudagraph_capture_sizes = rounded_sizes

        # Recompute after adjusting the cudagraph sizes
        self.compute_bs_to_padded_graph_size()

    def compute_bs_to_padded_graph_size(self):
        # pre-compute the mapping from batch size to padded graph size
        self.bs_to_padded_graph_size = [
            0 for i in range(self.max_cudagraph_capture_size + 1)
        ]
        for end, start in zip(
            self.cudagraph_capture_sizes + [self.max_cudagraph_capture_size + 1],
            [0] + self.cudagraph_capture_sizes,
        ):
            for bs in range(start, end):
                if bs == start:
                    self.bs_to_padded_graph_size[bs] = start
                else:
                    self.bs_to_padded_graph_size[bs] = end

__str__ class-attribute instance-attribute

__str__ = __repr__

_attention_ops class-attribute

_attention_ops: list[str] = [
    "vllm::unified_attention",
    "vllm::unified_attention_with_output",
    "vllm::unified_mla_attention",
    "vllm::unified_mla_attention_with_output",
    "vllm::mamba_mixer2",
    "vllm::mamba_mixer",
    "vllm::short_conv",
    "vllm::linear_attention",
    "vllm::plamo2_mamba_mixer",
    "vllm::gdn_attention_core",
    "vllm::kda_attention",
    "vllm::sparse_attn_indexer",
]

backend class-attribute instance-attribute

backend: str = ''

The backend for compilation. It needs to be a string:

  • "" (empty string): use the default backend ("inductor" on CUDA-alike platforms).
  • "eager"/"openxla"/...: use the specified backend registered in PyTorch.
  • "full.module.name": a qualified name which can be used to import the backend function.

We use a string to avoid serialization issues when using compilation in a distributed setting. When the compilation mode is 1 or 2, the backend is used for the compilation directly (it sees the whole graph). When the compilation mode is 3, the backend is used for the piecewise compilation (it sees a part of the graph). The backend cannot be custom for compilation mode 3, i.e. the backend must be either eager or inductor. Furthermore, compilation is only piecewise if splitting_ops is set accordingly and use_inductor_graph_partition is off. Note that the default options for splitting_ops are sufficient for piecewise compilation.

bs_to_padded_graph_size class-attribute instance-attribute

bs_to_padded_graph_size: list[int] = field(
    default=None, init=False
)

Optimization: intuitively, bs_to_padded_graph_size should be dict[int, int]. Since we know all keys are in the range [0, max_cudagraph_capture_size], we can optimize it to list[int] for better lookup performance.
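
For illustration, a standalone re-implementation of the lookup table built by compute_bs_to_padded_graph_size (shown in the source above), using hypothetical capture sizes:

# Hypothetical capture sizes, for illustration only.
cudagraph_capture_sizes = [1, 2, 4, 8]
max_cudagraph_capture_size = 8

bs_to_padded_graph_size = [0] * (max_cudagraph_capture_size + 1)
for end, start in zip(
    cudagraph_capture_sizes + [max_cudagraph_capture_size + 1],
    [0] + cudagraph_capture_sizes,
):
    for bs in range(start, end):
        # A batch size equal to a captured size maps to itself;
        # anything in between is padded up to the next captured size.
        bs_to_padded_graph_size[bs] = start if bs == start else end

assert bs_to_padded_graph_size == [0, 1, 2, 4, 4, 8, 8, 8, 8]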

cache_dir class-attribute instance-attribute

cache_dir: str = ''

The directory to store the compiled graph, to accelerate Inductor compilation. By default, it will use model-related information to generate a cache directory.

compilation_time class-attribute instance-attribute

compilation_time: float = field(default=0.0, init=False)

time taken for compilation

compile_cache_save_format class-attribute instance-attribute

compile_cache_save_format: Literal["binary", "unpacked"] = (
    field(
        default_factory=lambda: VLLM_COMPILE_CACHE_SAVE_FORMAT
    )
)

Format for saving torch compile cache:

  • "binary": saves as binary file (multiprocess safe)

  • "unpacked": saves as directory structure for inspection/debugging (NOT multiprocess safe)

Defaults to VLLM_COMPILE_CACHE_SAVE_FORMAT if not specified.

compile_mm_encoder class-attribute instance-attribute

compile_mm_encoder: bool = False

Whether or not to compile the multimodal encoder. Currently, this only works for Qwen2_5_vl on selected platforms. Disabled by default until more models are supported/tested to work.

compile_sizes class-attribute instance-attribute

compile_sizes: list[int | str] | None = None

Sizes to compile for inductor. In addition to integers, it also supports "cudagraph_capture_sizes" to specify the sizes for cudagraph capture.
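
For example (a sketch with illustrative sizes), mixing explicit integers with the "cudagraph_capture_sizes" placeholder, which post_init_cudagraph_sizes expands into the captured sizes:

from vllm.config.compilation import CompilationConfig

cfg = CompilationConfig(
    cudagraph_capture_sizes=[1, 2, 4, 8],
    # 16 is compiled in addition to every captured size.
    compile_sizes=["cudagraph_capture_sizes", 16],
)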

cudagraph_capture_sizes class-attribute instance-attribute

cudagraph_capture_sizes: list[int] | None = None

Sizes to capture cudagraph.

  • None (default): capture sizes are inferred from vllm config.
  • list[int]: capture sizes are specified as given.

cudagraph_copy_inputs class-attribute instance-attribute

cudagraph_copy_inputs: bool = False

Whether to copy input tensors for cudagraph. If the caller can guarantee that the same input buffers are always used, it can set this to False. Otherwise, it should set this to True, and the compiler will copy the input to an internally managed buffer. Default is False. Note that this flag is only effective when cudagraph_mode is PIECEWISE.

cudagraph_mode class-attribute instance-attribute

cudagraph_mode: CUDAGraphMode | None = None

The mode of the cudagraph:

  • NONE, no cudagraph capture.
  • PIECEWISE.
  • FULL.
  • FULL_DECODE_ONLY.
  • FULL_AND_PIECEWISE. (v1 default)

PIECEWISE mode builds piecewise cudagraphs only, keeping cudagraph-incompatible ops (e.g. some attention ops) outside the cudagraph for general flexibility.

FULL mode: Capture full cudagraph for all batches. Can be good for small models or workloads with small prompts; not supported by many backends. Generally for performance FULL_AND_PIECEWISE is better.

FULL_DECODE_ONLY mode: Capture full cudagraph for decode batches only. Mixed prefill-decode batches are run without cudagraphs. Can be good for decode instances in a P/D setup where prefill is not as important so we can save some memory.

FULL_AND_PIECEWISE mode: Capture full cudagraph for decode batches and piecewise cudagraph for prefill and mixed prefill-decode batches. This is the most performant mode for most models and is the default.

Currently, the cudagraph mode is only used for the v1 engine. Note that the cudagraph logic is generally orthogonal to the compilation logic. While piecewise cudagraphs require piecewise compilation (mode=VLLM_COMPILE and non-empty splitting_ops), full cudagraphs are supported with and without compilation.

Warning: This flag is new and subject to change; in addition, more modes may be added.
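
For instance, a decode-focused instance in a P/D deployment might opt into decode-only full graphs (a sketch; string names are parsed by validate_cudagraph_mode_before when the config comes from JSON/CLI):

from vllm.config.compilation import CompilationConfig, CUDAGraphMode

# Full cudagraphs for uniform decode batches; mixed prefill-decode batches
# run without cudagraphs, saving capture memory.
cfg = CompilationConfig(cudagraph_mode=CUDAGraphMode.FULL_DECODE_ONLY)

# Equivalent string form when supplied as JSON/CLI: {"cudagraph_mode": "FULL_DECODE_ONLY"}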

cudagraph_num_of_warmups class-attribute instance-attribute

cudagraph_num_of_warmups: int = 0

Number of warmup runs for cudagraph. It means the first several runs will be treated as warmup runs. Only after that, the execution will be recorded, and the recorded cudagraph will be used for subsequent runs.

cudagraph_specialize_lora class-attribute instance-attribute

cudagraph_specialize_lora: bool = True

Whether to create separate cuda graphs for cases with and without active LoRA adapters. When set to False, the LoRA-enabled cuda graph will be used for all cases, incurring the overhead of running LoRA ops even when no adapters are active. Setting this to True will remove this overhead at the cost of increased startup time and slightly higher memory usage. When enable_lora is False, this option has no effect.

custom_ops class-attribute instance-attribute

custom_ops: list[str] = field(default_factory=list)

Fine-grained control over which custom ops to enable/disable. Use 'all' to enable all, 'none' to disable all. Also specify a list of custom op names to enable (prefixed with a '+'), or disable (prefixed with a '-'). Examples:

  • 'all,-op1' to enable all except op1
  • 'none,+op1,+op2' to enable only op1 and op2

By default, all custom ops are enabled when running without Inductor and disabled when running with Inductor: mode>=VLLM_COMPILE and use_inductor=True. Inductor generates (fused) Triton kernels for disabled custom ops.
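
A hedged sketch of the syntax; the op names below are illustrative and only take effect if the corresponding CustomOp classes are registered and used by your model (see custom_op_log_check below):

from vllm.config.compilation import CompilationConfig

# Disable all custom ops except one, equivalent to the string form 'none,+rms_norm'.
cfg = CompilationConfig(custom_ops=["none", "+rms_norm"])

# Enable everything except one op, equivalent to 'all,-rotary_embedding'.
cfg = CompilationConfig(custom_ops=["all", "-rotary_embedding"])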

debug_dump_path class-attribute instance-attribute

debug_dump_path: Path | None = None

The path to dump the debug information.

disabled_custom_ops class-attribute instance-attribute

disabled_custom_ops: Counter[str] = field(
    default_factory=Counter, init=False
)

custom ops that are disabled

dynamic_shapes_config class-attribute instance-attribute

dynamic_shapes_config: DynamicShapesConfig = field(
    default_factory=DynamicShapesConfig
)

Configuration for dynamic shapes options

enabled_custom_ops class-attribute instance-attribute

enabled_custom_ops: Counter[str] = field(
    default_factory=Counter, init=False
)

custom ops that are enabled

inductor_compile_config class-attribute instance-attribute

inductor_compile_config: dict = field(default_factory=dict)

Additional configurations for inductor.

  • None: use default configurations.

inductor_passes class-attribute instance-attribute

inductor_passes: dict[str, str] = field(
    default_factory=dict
)

Additional passes for inductor. It is a dictionary from pass name to pass function qualified name. We use function name because the config uses JSON format. If we pass the config from Python, functions can also be passed directly via Python object constructor, e.g. CompilationConfig(inductor_passes={"a": func}).
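
A sketch of the two documented forms (the pass name and module path are hypothetical; a real pass rewrites the torch.fx graph):

from vllm.config.compilation import CompilationConfig

def remove_noops(graph):
    """Hypothetical pass; a real pass mutates or returns the FX graph."""
    return graph

# Callable form, available when constructing the config from Python;
# __post_init__ wraps it in CallableInductorPass.
cfg = CompilationConfig(inductor_passes={"my_pass": remove_noops})

# JSON-friendly string form resolves the qualified name at construction time,
# so the module must be importable:
#   {"inductor_passes": {"my_pass": "my_pkg.passes.remove_noops"}}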

level class-attribute instance-attribute

level: int | None = None

Level is deprecated and will be removed in the next release, either 0.12.0 or 0.11.2 whichever is soonest. Please use mode. Currently all levels are mapped to mode.

local_cache_dir class-attribute instance-attribute

local_cache_dir: str = field(default=None, init=False)

local cache dir for each rank

max_cudagraph_capture_size class-attribute instance-attribute

max_cudagraph_capture_size: int | None = field(default=None)

The maximum cudagraph capture size.

If cudagraph_capture_sizes is specified, this will be set to the largest size in that list (or checked for consistency if specified). If cudagraph_capture_sizes is not specified, the list of sizes is generated automatically following the pattern:

[1, 2, 4] + list(range(8, 256, 8)) + list(
range(256, max_cudagraph_capture_size + 1, 16))

If not specified, max_cudagraph_capture_size is set to min(max_num_seqs*2, 512) by default. This avoids OOM in tight-memory scenarios with small max_num_seqs, and prevents capture of many large graphs (>512) that would greatly increase startup time with limited performance benefit.
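
A worked example of the documented generation pattern (illustrative only; vLLM may further trim the list, e.g. based on max_num_seqs):

max_cudagraph_capture_size = 512  # the default cap when max_num_seqs is large

sizes = (
    [1, 2, 4]
    + list(range(8, 256, 8))
    + list(range(256, max_cudagraph_capture_size + 1, 16))
)
# -> [1, 2, 4, 8, 16, ..., 248, 256, 272, ..., 512]; 51 sizes in total.
assert len(sizes) == 51 and sizes[-1] == 512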

mode class-attribute instance-attribute

mode: CompilationMode | None = None

The compilation approach used for torch.compile-based compilation of the model.

  • None: If None, we will select the default compilation mode. For V1 engine this is 3.
  • 0: NONE: No torch.compile compilation is applied, model runs in fully eager pytorch mode. The model runs as-is.
  • 1: STOCK_TORCH_COMPILE: The standard torch.compile compilation pipeline.
  • 2: DYNAMO_TRACE_ONCE: Single Dynamo trace through the model, avoiding recompilation by removing guards. Requires no dynamic-shape-dependent control-flow.
  • 3: VLLM_COMPILE: Custom vLLM Inductor-based backend with caching, piecewise compilation, shape specialization, and custom passes.

pass_config class-attribute instance-attribute

pass_config: PassConfig = field(default_factory=PassConfig)

Custom inductor passes, see PassConfig for more details

splitting_ops class-attribute instance-attribute

splitting_ops: list[str] | None = None

A list of ops to exclude from cudagraphs, used in piecewise compilation.

The behavior depends on use_inductor_graph_partition:

  • When use_inductor_graph_partition=False (default): These ops are used for Dynamo FX-level graph splitting. The graph is split at these ops before Inductor compilation, creating separate subgraphs for cudagraph capture.

  • When use_inductor_graph_partition=True: These ops are used to register Inductor partition rules. The graph partitioning happens at Inductor codegen time after all passes and fusions are finished, allowing compilation and custom passes to operate on the full graph while still excluding these ops from cudagraphs.

If None, defaults to attention ops for piecewise cudagraphs. If empty list [], no ops are excluded (suitable for full cudagraphs).
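
For example (a sketch; whether full cudagraphs are actually usable also depends on the attention backend):

from vllm.config.compilation import CompilationConfig, CUDAGraphMode

# Default behavior: splitting_ops=None lets vLLM split at the attention ops,
# enabling piecewise cudagraphs for mixed batches.
piecewise = CompilationConfig(cudagraph_mode=CUDAGraphMode.FULL_AND_PIECEWISE)

# Exclude nothing from cudagraphs, which suits capturing full graphs only.
full_only = CompilationConfig(splitting_ops=[], cudagraph_mode=CUDAGraphMode.FULL)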

static_forward_context class-attribute instance-attribute

static_forward_context: dict[str, Any] = field(
    default_factory=dict, init=False
)

Per-model forward context. Map from layer name to layer objects that need to be accessed outside model code, e.g., Attention, FusedMOE when dp_size>1.

traced_files class-attribute instance-attribute

traced_files: set[str] = field(
    default_factory=set, init=False
)

files that are traced for compilation

use_inductor class-attribute instance-attribute

use_inductor: bool | None = None

Whether to use inductor compilation.

This flag is deprecated and will be removed in the next release 0.12.0. Please use the 'backend' option instead.

  • False: inductor compilation is not used; the graph runs in eager mode (custom_ops enabled by default).
  • True: inductor compilation is used (custom_ops disabled by default). One graph for symbolic shape and one graph per size in compile_sizes are compiled using configurations in inductor_compile_config.

This setting is ignored if mode<VLLM_COMPILE.

For future compatibility: If use_inductor is True, backend="inductor" otherwise backend="eager".

use_inductor_graph_partition class-attribute instance-attribute

use_inductor_graph_partition: bool = False

Use inductor graph partition to split the graph at cudagraph_unsafe ops. This partition happens at inductor codegen time after all passes and fusions are finished. It generates a single call function which wraps cudagraph-safe ops into partition functions and leaves cudagraph-unsafe ops outside the partition functions. For a graph with N cudagraph-unsafe ops (e.g., Attention), there would be N+1 partitions. To mark an op as cudagraph unsafe, we can add tags=(torch._C.Tag.cudagraph_unsafe) when registering the custom op.

This config supports both full cudagraph and piecewise cudagraph without compiling twice. For piecewise cudagraph, it applies vLLM CUDAGraph wrapper to each partition. For N+1 partitions, there would be N+1 CUDAGraph wrapper instances.

For full CUDAGraph, we always apply a single CUDAGraph wrapper outside the inductor call function in the model runner. The top-level full cudagraph capture ignores all partitioning.
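
A rough sketch of tagging a custom op as cudagraph-unsafe. The library and op names below are hypothetical, and the tags keyword on torch.library.Library.define plus the torch._C.Tag.cudagraph_unsafe tag assume a recent PyTorch:

import torch
from torch.library import Library

# Hypothetical extension library and op name, for illustration only.
my_lib = Library("my_ext", "FRAGMENT")
my_lib.define(
    "my_attention(Tensor q, Tensor k, Tensor v) -> Tensor",
    tags=(torch._C.Tag.cudagraph_unsafe,),
)

def _my_attention(q, k, v):
    # Placeholder math; a real op would dispatch to an attention kernel.
    return torch.softmax(q @ k.transpose(-1, -2), dim=-1) @ v

my_lib.impl("my_attention", _my_attention, "CUDA")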

__post_init__

__post_init__() -> None
Source code in vllm/config/compilation.py
def __post_init__(self) -> None:
    if self.level is not None:
        logger.warning(
            "Level is deprecated and will be removed in the next release,"
            "either 0.12.0 or 0.11.2 whichever is soonest."
            "Use mode instead."
            "If both level and mode are given,"
            "only mode will be used."
        )
        if self.mode is None:
            self.mode = self.level

    count_none = self.custom_ops.count("none")
    count_all = self.custom_ops.count("all")
    assert count_none + count_all <= 1, "Can only specify 'none' or 'all'"

    # TODO(zou3519/luka): There are 2 issues with auto-functionalization V2:
    # 1. A bug in PyTorch, fixed in 2.7:
    #    https://github.com/pytorch/pytorch/issues/147924
    # 2. Custom passes (fusion) rely on auto-functionalization V1 and don't
    #    work with V2. Addressing this will take extra engineering effort
    #    and it is not yet a priority. RFC here:
    #    https://github.com/vllm-project/vllm/issues/14703

    if is_torch_equal_or_newer("2.6"):
        KEY = "enable_auto_functionalized_v2"
        if KEY not in self.inductor_compile_config:
            self.inductor_compile_config[KEY] = False

    for k, v in self.inductor_passes.items():
        if not isinstance(v, str):
            assert callable(v), f"pass {k} should be callable or a qualified name"
            self.inductor_compile_config[k] = (
                v if isinstance(v, InductorPass) else CallableInductorPass(v)
            )
            continue

        # resolve function from qualified name
        names = v.split(".")
        module = ".".join(names[:-1])
        func_name = names[-1]
        func = __import__(module).__dict__[func_name]
        self.inductor_compile_config[k] = (
            func if isinstance(func, InductorPass) else CallableInductorPass(func)
        )

    if self.pass_config.enable_qk_norm_rope_fusion:
        # TODO(zhuhaoran): support rope native forward match and remove this.
        # Linked issue: https://github.com/vllm-project/vllm/issues/28042
        self.custom_ops.append("+rotary_embedding")

    if (
        is_torch_equal_or_newer("2.9.0.dev")
        and "combo_kernels" not in self.inductor_compile_config
        and "benchmark_combo_kernel" not in self.inductor_compile_config
        # (fixme @boyuan) combo kernel does not support cpu yet.
        and not current_platform.is_cpu()
    ):
        # use horizontal fusion, which is useful for fusing qk-norm and
        # qk-rope when query and key have different shapes.
        self.inductor_compile_config["combo_kernels"] = True
        self.inductor_compile_config["benchmark_combo_kernel"] = True

    if self.use_inductor_graph_partition and not is_torch_equal_or_newer(
        "2.9.0.dev"
    ):
        raise ValueError(
            "use_inductor_graph_partition is only "
            "supported with torch>=2.9.0.dev. Set "
            "use_inductor_graph_partition=False instead."
        )

    for op in self.custom_ops:
        if op[0] not in {"+", "-"} and op not in {"all", "none"}:
            raise ValueError(
                f"Invalid syntax '{op}' for custom op, "
                "must be 'all', 'none', '+op' or '-op' "
                "(where 'op' is the registered op name)"
            )

    # Currently only the eager and inductor backends are supported
    # for piecewise compilation. Custom backends are not supported;
    # update this check when more backends are supported.
    if self.mode == CompilationMode.VLLM_COMPILE and self.backend not in [
        "",
        "eager",
        "inductor",
    ]:
        raise ValueError(
            f"Invalid backend for piecewise compilation: {self.backend}"
        )

    if self.use_inductor is not None:
        logger.warning_once(
            "The 'use_inductor' flag is deprecated and will be "
            "removed in the next release (v0.12.0). "
            "Please use the 'backend' option instead.",
        )
        self.backend = "inductor" if self.use_inductor else "eager"

    if self.backend == "":
        self.backend = current_platform.simple_compile_backend

__repr__

__repr__() -> str
Source code in vllm/config/compilation.py
def __repr__(self) -> str:
    exclude = {
        "static_forward_context": True,
        "enabled_custom_ops": True,
        "disabled_custom_ops": True,
        "compilation_time": True,
        "bs_to_padded_graph_size": True,
        "traced_files": True,
        "inductor_compile_config": {
            "post_grad_custom_post_pass": True,
        },
    }

    # exclude default attr in pass_config
    pass_config_exclude = {}
    for attr, default_val in vars(PassConfig()).items():
        if getattr(self.pass_config, attr) == default_val:
            pass_config_exclude[attr] = True
    if pass_config_exclude:
        exclude["pass_config"] = pass_config_exclude

    config = TypeAdapter(CompilationConfig).dump_python(
        self, exclude=exclude, exclude_unset=True
    )

    return str(config)

adjust_cudagraph_sizes_for_spec_decode

adjust_cudagraph_sizes_for_spec_decode(
    uniform_decode_query_len: int, tensor_parallel_size: int
)
Source code in vllm/config/compilation.py
def adjust_cudagraph_sizes_for_spec_decode(
    self, uniform_decode_query_len: int, tensor_parallel_size: int
):
    multiple_of = uniform_decode_query_len
    if tensor_parallel_size > 1 and self.pass_config.enable_sequence_parallelism:
        multiple_of = max(uniform_decode_query_len, tensor_parallel_size)
        if (
            multiple_of % uniform_decode_query_len != 0
            or multiple_of % tensor_parallel_size != 0
        ):
            raise ValueError(
                f"Can't determine cudagraph shapes that are both a "
                f"multiple of {uniform_decode_query_len} "
                f"(num_speculative_tokens + 1) required by spec-decode "
                f"and {tensor_parallel_size} (tensor_parallel_size) "
                f"required by sequence parallelism please adjust "
                f"num_speculative_tokens or disable sequence parallelism"
            )

    if not self.cudagraph_capture_sizes or multiple_of <= 1:
        return

    assert self.max_cudagraph_capture_size is not None
    rounded_sizes = sorted(
        set(
            round_up(size, multiple_of)
            for size in self.cudagraph_capture_sizes
            if round_up(size, multiple_of) <= self.max_cudagraph_capture_size
        )
    )

    if len(rounded_sizes) == 0 and multiple_of <= self.max_cudagraph_capture_size:
        # no capture size survived rounding, but multiple_of itself fits:
        # fall back to capturing just that single size
        rounded_sizes = [multiple_of]

    if len(rounded_sizes) == 0:
        raise ValueError(
            f"No valid cudagraph sizes after rounding to multiple of {multiple_of} "
            f"(num_speculative_tokens + 1 or tp if sequence parallelism is enabled)"
            f" please adjust num_speculative_tokens ({uniform_decode_query_len - 1}"
            f") or max_cudagraph_capture_size ({self.max_cudagraph_capture_size})"
            f" or cudagraph_capture_sizes ({self.cudagraph_capture_sizes})"
        )

    self.max_cudagraph_capture_size = rounded_sizes[-1]
    self.cudagraph_capture_sizes = rounded_sizes

    # Recompute after adjusting the cudagraph sizes
    self.compute_bs_to_padded_graph_size()
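
A standalone sketch of the rounding step, re-implemented here for illustration rather than calling the method:

# Capture sizes are rounded up to a multiple of uniform_decode_query_len
# (num_speculative_tokens + 1), dropping anything above the maximum size.
def round_up(x: int, multiple: int) -> int:
    return ((x + multiple - 1) // multiple) * multiple

capture_sizes = [1, 2, 4, 8, 16]
multiple_of = 4          # e.g. num_speculative_tokens + 1
max_capture_size = 16

rounded = sorted({
    round_up(s, multiple_of)
    for s in capture_sizes
    if round_up(s, multiple_of) <= max_capture_size
})
print(rounded)  # [4, 8, 16]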

compute_bs_to_padded_graph_size

compute_bs_to_padded_graph_size()
Source code in vllm/config/compilation.py
def compute_bs_to_padded_graph_size(self):
    # pre-compute the mapping from batch size to padded graph size
    self.bs_to_padded_graph_size = [
        0 for i in range(self.max_cudagraph_capture_size + 1)
    ]
    for end, start in zip(
        self.cudagraph_capture_sizes + [self.max_cudagraph_capture_size + 1],
        [0] + self.cudagraph_capture_sizes,
    ):
        for bs in range(start, end):
            if bs == start:
                self.bs_to_padded_graph_size[bs] = start
            else:
                self.bs_to_padded_graph_size[bs] = end
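
For example, with capture sizes [1, 2, 4] every batch size is padded up to the next captured size. A standalone re-implementation of the same loop for illustration:

capture_sizes = [1, 2, 4]
max_capture_size = 4

bs_to_padded = [0] * (max_capture_size + 1)
for end, start in zip(capture_sizes + [max_capture_size + 1], [0] + capture_sizes):
    for bs in range(start, end):
        bs_to_padded[bs] = start if bs == start else end

print(bs_to_padded)  # [0, 1, 2, 4, 4] -> e.g. batch size 3 is padded to 4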

compute_hash

compute_hash() -> str

Provide a hash that uniquely identifies all the configs that affect the structure of the computation graph from input ids/embeddings to the final hidden states, excluding anything before input ids/embeddings and after the final hidden states.

Source code in vllm/config/compilation.py
def compute_hash(self) -> str:
    """
    Provide a hash that uniquely identifies all the configs
    that affect the structure of the computation
    graph from input ids/embeddings to the final hidden states,
    excluding anything before input ids/embeddings and after
    the final hidden states.
    """
    # Opt-out: default-include declared fields; keep a tiny exclude set;
    # normalize types; keep SHA-256. For nested opaque configs, include a
    # stable identifier (e.g., pass_config.compute_hash()) instead of object id.

    ignored_factors = {
        # Paths/dirs and runtime/metrics that don’t affect compiled graph
        "debug_dump_path",
        "cache_dir",
        "local_cache_dir",
        "bs_to_padded_graph_size",
        "traced_files",
        "compilation_time",
        "static_forward_context",
        "pass_config",  # handled separately below
    }

    from vllm.config.utils import get_hash_factors, hash_factors

    factors = get_hash_factors(self, ignored_factors)

    factors["pass_config"] = self.pass_config.compute_hash()
    return hash_factors(factors)
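
A hedged illustration of the opt-out behavior, assuming cache_dir is a regular constructor field of CompilationConfig (it appears in the ignored factors above):

from vllm.config.compilation import CompilationConfig

# cache_dir is in the ignored factors, so it does not change the hash.
a = CompilationConfig(cache_dir="/tmp/a")
b = CompilationConfig(cache_dir="/tmp/b")
assert a.compute_hash() == b.compute_hash()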

custom_op_log_check

custom_op_log_check()

This method logs the enabled/disabled custom ops and checks that the passed custom_ops field only contains relevant ops. It is called at the end of set_current_vllm_config, after the custom ops have been instantiated.

Source code in vllm/config/compilation.py
def custom_op_log_check(self):
    """
    This method logs the enabled/disabled custom ops and checks that the
    passed custom_ops field only contains relevant ops.
    It is called at the end of set_current_vllm_config,
    after the custom ops have been instantiated.
    """

    if len(self.enabled_custom_ops) + len(self.disabled_custom_ops) == 0:
        logger.debug("No custom ops found in model.")
        return

    logger.debug("enabled custom ops: %s", self.enabled_custom_ops)
    logger.debug("disabled custom ops: %s", self.disabled_custom_ops)

    all_ops_in_model = self.enabled_custom_ops | self.disabled_custom_ops
    for op in self.custom_ops:
        if op in {"all", "none"}:
            continue

        assert op[0] in {"+", "-"}, (
            "Invalid custom op syntax (should be checked during init)"
        )

        # check if op name exists in model
        op_name = op[1:]
        if op_name not in all_ops_in_model:
            from vllm.model_executor.custom_op import CustomOp

            # Does op exist at all or is it just not present in this model?
            # Note: Only imported op classes appear in the registry.
            missing_str = (
                "doesn't exist (or wasn't imported/registered)"
                if op_name not in CustomOp.op_registry
                else "not present in model"
            )

            enable_str = "enabling" if op[0] == "+" else "disabling"
            logger.warning_once(
                "Op '%s' %s, %s with '%s' has no effect",
                op_name,
                missing_str,
                enable_str,
                op,
            )
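
A hedged sketch of the custom_ops syntax this check operates on; "rms_norm" is used as an example of a registered CustomOp name:

from vllm.config.compilation import CompilationConfig

# 'none' disables custom ops by default; individual ops are then toggled
# with '+name' / '-name'.
cfg = CompilationConfig(custom_ops=["none", "+rms_norm"])
# After the model is instantiated, custom_op_log_check() warns if a listed
# op is missing from the registry or simply not present in the model.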

init_backend

init_backend(vllm_config: VllmConfig) -> str | Callable

Initialize the backend for the compilation config from a vllm config.

Arguments:
  vllm_config: The vllm config to initialize the backend from.

Returns:
  The backend for the compilation config.

Source code in vllm/config/compilation.py
def init_backend(self, vllm_config: "VllmConfig") -> str | Callable:
    """
    Initialize the backend for the compilation config from a vllm config.
    Arguments:
        vllm_config: The vllm config to initialize the backend from.
    Returns:
        The backend for the compilation config.
    """
    if self.mode is None:
        raise ValueError(
            "No compilation mode is set. This method should only be \
            called via vllm config where the level is set if none is \
            provided."
        )
    if self.mode == CompilationMode.NONE:
        raise ValueError("No compilation mode is set.")

    from torch._dynamo.backends.registry import list_backends

    torch_backends = list_backends(exclude_tags=tuple())
    if self.mode in [
        CompilationMode.STOCK_TORCH_COMPILE,
        CompilationMode.DYNAMO_TRACE_ONCE,
    ]:
        if self.backend in torch_backends:
            return self.backend
        return resolve_obj_by_qualname(self.backend)

    assert self.mode == CompilationMode.VLLM_COMPILE
    if self.backend not in ["eager", "inductor"]:
        raise ValueError(
            f"Invalid backend for piecewise compilation: {self.backend}"
        )

    from vllm.compilation.backends import VllmBackend

    # TODO[@lucaskabela]: See if we can forward prefix
    # https://github.com/vllm-project/vllm/issues/27045
    return VllmBackend(vllm_config)

is_attention_compiled_piecewise

is_attention_compiled_piecewise() -> bool
Source code in vllm/config/compilation.py
def is_attention_compiled_piecewise(self) -> bool:
    if not self.splitting_ops_contain_attention():
        return False

    if not self.use_inductor_graph_partition:
        # Dynamo-level FX split case
        return self.mode == CompilationMode.VLLM_COMPILE

    # Inductor partition case
    return self.backend == "inductor" and self.mode != CompilationMode.NONE

post_init_cudagraph_sizes

post_init_cudagraph_sizes() -> None

To complete the initialization after cudagraph related configs are set. This includes:

  • initialize compile_sizes
  • pre-compute the mapping bs_to_padded_graph_size

Source code in vllm/config/compilation.py
def post_init_cudagraph_sizes(self) -> None:
    """To complete the initialization after cudagraph related
    configs are set. This includes:
    - initialize compile_sizes
    - pre-compute the mapping bs_to_padded_graph_size
    """

    computed_compile_sizes = []
    if self.compile_sizes is not None:
        # de-duplicate the sizes provided by the config
        self.compile_sizes = list(set(self.compile_sizes))
        for x in self.compile_sizes:
            if isinstance(x, str):
                assert x == "cudagraph_capture_sizes", (
                    "Unrecognized size type in compile_sizes, "
                    f"expect 'cudagraph_capture_sizes', got {x}"
                )
                computed_compile_sizes.extend(self.cudagraph_capture_sizes)
            else:
                assert isinstance(x, int)
                computed_compile_sizes.append(x)
    self.compile_sizes = computed_compile_sizes  # type: ignore

    # make sure the sizes are in ascending order
    self.cudagraph_capture_sizes.sort()
    if self.cudagraph_capture_sizes:
        assert self.cudagraph_capture_sizes[-1] == self.max_cudagraph_capture_size

    # May get recomputed in the model runner if adjustment is needed for spec-decode
    self.compute_bs_to_padded_graph_size()
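
For instance, the string entry "cudagraph_capture_sizes" in compile_sizes expands into the capture sizes. A standalone sketch of that expansion:

compile_sizes = [8, "cudagraph_capture_sizes"]
cudagraph_capture_sizes = [1, 2, 4]

computed = []
for x in set(compile_sizes):          # de-duplicate first, as in the method
    if isinstance(x, str):
        computed.extend(cudagraph_capture_sizes)
    else:
        computed.append(x)
print(sorted(computed))  # [1, 2, 4, 8]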

set_splitting_ops_for_attn_fusion

set_splitting_ops_for_attn_fusion()
Source code in vllm/config/compilation.py
def set_splitting_ops_for_attn_fusion(self):
    assert self.pass_config.enable_attn_fusion
    if self.splitting_ops is None:
        self.splitting_ops = []
        if self.cudagraph_mode.has_piecewise_cudagraphs():
            logger.warning_once(
                "enable_attn_fusion is incompatible with piecewise "
                "cudagraph when use_inductor_graph_partition is off. "
                "In this case, splitting_ops will be set to empty "
                "list, and cudagraph_mode will be set to FULL. "
                "Please ensure you are using attention backends that "
                "support cudagraph or set cudagraph_mode to NONE "
                "explicitly if encountering any problems."
            )
            self.cudagraph_mode = CUDAGraphMode.FULL

    assert not self.splitting_ops_contain_attention(), (
        "attention ops should not be in splitting_ops "
        "when enable_attn_fusion is True"
    )

set_splitting_ops_for_inductor_graph_partition

set_splitting_ops_for_inductor_graph_partition()
Source code in vllm/config/compilation.py
def set_splitting_ops_for_inductor_graph_partition(self):
    assert self.use_inductor_graph_partition
    if self.splitting_ops is None:
        self.splitting_ops = list(self._attention_ops)

set_splitting_ops_for_v1

set_splitting_ops_for_v1()
Source code in vllm/config/compilation.py
def set_splitting_ops_for_v1(self):
    # NOTE: this function needs to be called only when mode is
    # CompilationMode.VLLM_COMPILE
    assert self.mode == CompilationMode.VLLM_COMPILE, (
        "set_splitting_ops_for_v1 should only be called when "
        "mode is CompilationMode.VLLM_COMPILE"
    )

    if self.use_inductor_graph_partition:
        self.set_splitting_ops_for_inductor_graph_partition()
        return

    if self.pass_config.enable_attn_fusion:
        # here use_inductor_graph_partition is False
        self.set_splitting_ops_for_attn_fusion()
        return

    if self.splitting_ops is None:
        # NOTE: When using full cudagraph, instead of setting an empty
        # list and capture the full cudagraph inside the flattened fx
        # graph, we keep the piecewise fx graph structure but capture
        # the full cudagraph outside the fx graph. This reduces some
        # cpu overhead when the runtime batch_size is not cudagraph
        # captured. see https://github.com/vllm-project/vllm/pull/20059
        # for details. Make a copy to avoid mutating the class-level
        # list via reference.
        self.splitting_ops = list(self._attention_ops)
    elif len(self.splitting_ops) == 0:
        logger.warning_once("Using piecewise compilation with empty splitting_ops")
        if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
            logger.warning_once(
                "Piecewise compilation with empty splitting_ops do not"
                "contains piecewise cudagraph. Setting cudagraph_"
                "mode to NONE. Hint: If you are using attention backends "
                "that support cudagraph, consider manually setting "
                "cudagraph_mode to FULL or FULL_DECODE_ONLY to enable "
                "full cudagraphs."
            )
            self.cudagraph_mode = CUDAGraphMode.NONE
        elif self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
            logger.warning_once(
                "Piecewise compilation with empty splitting_ops do not "
                "contains piecewise cudagraph. Setting cudagraph_mode "
                "to FULL."
            )
            self.cudagraph_mode = CUDAGraphMode.FULL
        self.splitting_ops = []

splitting_ops_contain_attention

splitting_ops_contain_attention() -> bool
Source code in vllm/config/compilation.py
def splitting_ops_contain_attention(self) -> bool:
    return self.splitting_ops is not None and all(
        op in self.splitting_ops for op in self._attention_ops
    )

validate_compile_cache_save_format classmethod

validate_compile_cache_save_format(value: str) -> str
Source code in vllm/config/compilation.py
@field_validator("compile_cache_save_format")
@classmethod
def validate_compile_cache_save_format(cls, value: str) -> str:
    if value not in ("binary", "unpacked"):
        raise ValueError(
            f"compile_cache_save_format must be 'binary' or 'unpacked', "
            f"got: {value}"
        )
    return value

validate_cudagraph_mode_before classmethod

validate_cudagraph_mode_before(value: Any) -> Any

Enable parsing of the cudagraph_mode enum type from string.

Source code in vllm/config/compilation.py
@field_validator("cudagraph_mode", mode="before")
@classmethod
def validate_cudagraph_mode_before(cls, value: Any) -> Any:
    """Enable parsing of the `cudagraph_mode` enum type from string."""
    if isinstance(value, str):
        return CUDAGraphMode[value.upper()]
    return value
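
The validator upper-cases the string and indexes the enum by member name; the equivalent lookup, assuming CUDAGraphMode is importable from vllm.config.compilation:

from vllm.config.compilation import CUDAGraphMode

assert CUDAGraphMode["full_and_piecewise".upper()] is CUDAGraphMode.FULL_AND_PIECEWISE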

validate_mode_before classmethod

validate_mode_before(value: Any) -> Any

Enable parsing the mode field from string mode names. Accepts both integers (0-3) and string names, like NONE, STOCK_TORCH_COMPILE, DYNAMO_TRACE_ONCE, VLLM_COMPILE.

Source code in vllm/config/compilation.py
@field_validator("mode", mode="before")
@classmethod
def validate_mode_before(cls, value: Any) -> Any:
    """
    Enable parsing the `mode` field from string mode names.
    Accepts both integers (0-3) and string names, like NONE, STOCK_TORCH_COMPILE,
    DYNAMO_TRACE_ONCE, VLLM_COMPILE.
    """
    if isinstance(value, str):
        # Convert string mode name to integer value
        mode_name = value.upper()

        if mode_name not in CompilationMode.__members__:
            raise ValueError(
                f"Invalid compilation mode: {value}. "
                f"Valid modes are: {', '.join(CompilationMode.__members__.keys())}"
            )

        return CompilationMode[mode_name]
    return value

validate_pass_config_before classmethod

validate_pass_config_before(value: Any) -> Any

Enable parsing of the pass_config field from a dictionary.

Source code in vllm/config/compilation.py
@field_validator("pass_config", mode="before")
@classmethod
def validate_pass_config_before(cls, value: Any) -> Any:
    """Enable parsing of the `pass_config` field from a dictionary."""
    if isinstance(value, dict):
        return PassConfig(**value)
    return value

CompilationMode

Bases: IntEnum

The compilation approach used for torch.compile-based compilation of the model.

Source code in vllm/config/compilation.py
class CompilationMode(enum.IntEnum):
    """The compilation approach used for torch.compile-based compilation of the
    model."""

    NONE = 0
    """No torch.compile compilation is applied, model runs in fully eager pytorch mode.
    The model runs as-is."""
    STOCK_TORCH_COMPILE = 1
    """The standard `torch.compile` compilation pipeline."""
    DYNAMO_TRACE_ONCE = 2
    """Single Dynamo trace through the model, avoiding recompilation."""
    VLLM_COMPILE = 3
    """Custom vLLM Inductor-based backend with caching, piecewise compilation,
    shape specialization, and custom passes."""

DYNAMO_TRACE_ONCE class-attribute instance-attribute

DYNAMO_TRACE_ONCE = 2

Single Dynamo trace through the model, avoiding recompilation.

NONE class-attribute instance-attribute

NONE = 0

No torch.compile compilation is applied; the model runs as-is in fully eager PyTorch mode.

STOCK_TORCH_COMPILE class-attribute instance-attribute

STOCK_TORCH_COMPILE = 1

The standard torch.compile compilation pipeline.

VLLM_COMPILE class-attribute instance-attribute

VLLM_COMPILE = 3

Custom vLLM Inductor-based backend with caching, piecewise compilation, shape specialization, and custom passes.
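
A small sketch of how the IntEnum behaves, assuming CompilationMode is importable from vllm.config.compilation:

from vllm.config.compilation import CompilationMode

# IntEnum members convert from and compare like integers, which is how
# checks such as "mode < VLLM_COMPILE" behave elsewhere in the config.
assert CompilationMode(2) is CompilationMode.DYNAMO_TRACE_ONCE
assert CompilationMode.STOCK_TORCH_COMPILE < CompilationMode.VLLM_COMPILE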

DynamicShapesConfig

Configuration to control/debug torch compile dynamic shapes.

Source code in vllm/config/compilation.py
@config
@dataclass
class DynamicShapesConfig:
    """Configuration to control/debug torch compile dynamic shapes."""

    type: DynamicShapesType = DynamicShapesType.BACKED
    """Controls the type of dynamic shapes handling to use with torch.compile().

    - BACKED: Default PyTorch behavior with potential guards ignored.
    - UNBACKED: No guards guaranteed (most sound) but may throw
      data dependent errors.
    - BACKED_SIZE_OBLIVIOUS: Experimental safer alternative to
      backed/unbacked.
    """

    # TODO add a debug mode to fail

    def compute_hash(self) -> str:
        """
        Provide a hash for DynamicShapesConfig
        """

        from vllm.config.utils import get_hash_factors, hash_factors

        factors = get_hash_factors(self, {})
        return hash_factors(factors)

type class-attribute instance-attribute

type: DynamicShapesType = DynamicShapesType.BACKED

Controls the type of dynamic shapes handling to use with torch.compile().

  • BACKED: Default PyTorch behavior with potential guards ignored.
  • UNBACKED: No guards guaranteed (most sound) but may throw data dependent errors.
  • BACKED_SIZE_OBLIVIOUS: Experimental safer alternative to backed/unbacked.

compute_hash

compute_hash() -> str

Provide a hash for DynamicShapesConfig

Source code in vllm/config/compilation.py
def compute_hash(self) -> str:
    """
    Provide a hash for DynamicShapesConfig
    """

    from vllm.config.utils import get_hash_factors, hash_factors

    factors = get_hash_factors(self, {})
    return hash_factors(factors)

DynamicShapesType

Bases: str, Enum

Types of dynamic shapes handling in torch.compile(). See "Dynamic shapes and vllm guard dropping" in torch_compile.md for more details.

Source code in vllm/config/compilation.py
class DynamicShapesType(str, enum.Enum):
    """Types of dynamic shapes handling in torch.compile().
    see  Dynamic shapes and vllm guard dropping in torch_compile.md
    for more details."""

    BACKED = "backed"
    """Use backed dynamic shapes. torch.compile() guards on backed dynamic
    shapes and may add guards. Symbols are specialized to 0, 1, or >=2 even
    without encountering branching on those ranges."""

    UNBACKED = "unbacked"
    """Use unbacked dynamic shapes. Guaranteed not to be guarded on and not
    0/1 specialized, but may throw data dependent errors when branches require
    their value without explicit unbacked handling."""

    BACKED_SIZE_OBLIVIOUS = "backed_size_oblivious"
    """Experimental flag that treats backed symbols as unbacked when explicit
    unbacked handling is defined."""

BACKED class-attribute instance-attribute

BACKED = 'backed'

Use backed dynamic shapes. torch.compile() guards on backed dynamic shapes and may add guards. Symbols are specialized to 0, 1, or >=2 even without encountering branching on those ranges.

BACKED_SIZE_OBLIVIOUS class-attribute instance-attribute

BACKED_SIZE_OBLIVIOUS = 'backed_size_oblivious'

Experimental flag that treats backed symbols as unbacked when explicit unbacked handling is defined.

UNBACKED class-attribute instance-attribute

UNBACKED = 'unbacked'

Use unbacked dynamic shapes. Guaranteed not to be guarded on and not 0/1 specialized, but may throw data dependent errors when branches require their value without explicit unbacked handling.
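
A short usage sketch, assuming DynamicShapesConfig and DynamicShapesType are importable from vllm.config.compilation:

from vllm.config.compilation import DynamicShapesConfig, DynamicShapesType

# DynamicShapesType is a str-valued enum, so members can be built from strings.
assert DynamicShapesType("unbacked") is DynamicShapesType.UNBACKED

cfg = DynamicShapesConfig(type=DynamicShapesType.BACKED_SIZE_OBLIVIOUS)
print(cfg.compute_hash())  # stable hash of the selected shapes handling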

PassConfig

Configuration for custom Inductor passes.

This is separate from general CompilationConfig so that inductor passes don't all have access to full configuration - that would create a cycle as the PassManager is set as a property of config.

Source code in vllm/config/compilation.py
@config
@dataclass
class PassConfig:
    """Configuration for custom Inductor passes.

    This is separate from general `CompilationConfig` so that inductor passes
    don't all have access to full configuration - that would create a cycle as
    the `PassManager` is set as a property of config."""

    enable_fusion: bool = False
    """Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass."""
    enable_attn_fusion: bool = False
    """Whether to enable the custom attention+quant fusion pass."""
    enable_noop: bool = False
    """Whether to enable the custom no-op elimination pass."""
    enable_sequence_parallelism: bool = False
    """Whether to enable sequence parallelism."""
    enable_async_tp: bool = False
    """Whether to enable async TP."""
    enable_fi_allreduce_fusion: bool = False
    """Whether to enable flashinfer allreduce fusion."""
    fi_allreduce_fusion_max_size_mb: float | None = None
    """The threshold of the communicated tensor sizes under which
    vllm should use flashinfer fused allreduce. Specified as a
    float in MB.
    Unspecified will fallback to default values
    which are compute capability and world size dependent.
        FI_ALLREDUCE_FUSION_MAX_SIZE_MB = {
            90: {
                2: 64,  # 64MB
                4: 2,  # 2MB
                8: 1,  # 1MB
            },
            100: {
                2: 64,  # 64MB
                4: 32,  # 32MB
                8: 1,  # 1MB
            },
        }, where key is the device capability"""
    enable_qk_norm_rope_fusion: bool = False
    """Whether to enable the fused Q/K RMSNorm + RoPE pass."""

    # TODO(luka) better pass enabling system.

    def flashinfer_max_size(self, world_size: int) -> int | None:
        """
        Returns the max communication size in bytes for flashinfer
        allreduce fusion for the given world size. Returns None if world size
        is not supported by configs as it's not supported by flashinfer.
        """

        MiB = 1024 * 1024
        max_size_mb = self.fi_allreduce_fusion_max_size_mb
        if max_size_mb is None:
            max_size_mb = self.default_fi_allreduce_fusion_max_size_mb().get(world_size)

        return int(max_size_mb * MiB) if max_size_mb is not None else None

    @staticmethod
    def default_fi_allreduce_fusion_max_size_mb() -> dict[int, float]:
        from vllm.compilation.collective_fusion import FI_ALLREDUCE_FUSION_MAX_SIZE_MB
        from vllm.platforms import current_platform

        if not current_platform.is_cuda():
            return {}
        return FI_ALLREDUCE_FUSION_MAX_SIZE_MB.get(
            current_platform.get_device_capability().to_int(), {}
        )

    def compute_hash(self) -> str:
        """
        Produces a hash unique to the pass configuration.
        Any new fields that affect compilation should be added to the hash.
        Any future fields that don't affect compilation should be excluded.
        """
        return InductorPass.hash_dict(asdict(self))

    def __post_init__(self) -> None:
        if not self.enable_noop:
            if self.enable_fusion:
                logger.warning_once(
                    "Fusion enabled but reshape elimination disabled. "
                    "RMSNorm/SiluMul + quant (fp8) fusion might not work"
                )
            if self.enable_attn_fusion:
                logger.warning_once(
                    "Fusion enabled but reshape elimination disabled. "
                    "Attention + quant (fp8) fusion might not work"
                )
            if self.enable_fi_allreduce_fusion:
                logger.warning_once(
                    "Fusion enabled but reshape elimination disabled. "
                    "Allreduce + rms norm + quant (fp8) fusion might not work"
                )
        if self.enable_qk_norm_rope_fusion and not current_platform.is_cuda_alike():
            logger.warning_once(
                "QK Norm + RoPE fusion enabled but the current platform is not "
                "CUDA or ROCm. The fusion will be disabled."
            )
            self.enable_qk_norm_rope_fusion = False

enable_async_tp class-attribute instance-attribute

enable_async_tp: bool = False

Whether to enable async TP.

enable_attn_fusion class-attribute instance-attribute

enable_attn_fusion: bool = False

Whether to enable the custom attention+quant fusion pass.

enable_fi_allreduce_fusion class-attribute instance-attribute

enable_fi_allreduce_fusion: bool = False

Whether to enable flashinfer allreduce fusion.

enable_fusion class-attribute instance-attribute

enable_fusion: bool = False

Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass.

enable_noop class-attribute instance-attribute

enable_noop: bool = False

Whether to enable the custom no-op elimination pass.

enable_qk_norm_rope_fusion class-attribute instance-attribute

enable_qk_norm_rope_fusion: bool = False

Whether to enable the fused Q/K RMSNorm + RoPE pass.

enable_sequence_parallelism class-attribute instance-attribute

enable_sequence_parallelism: bool = False

Whether to enable sequence parallelism.

fi_allreduce_fusion_max_size_mb class-attribute instance-attribute

fi_allreduce_fusion_max_size_mb: float | None = None

The threshold of the communicated tensor sizes under which vLLM should use the flashinfer fused allreduce. Specified as a float in MB. If unspecified, it falls back to default values that depend on compute capability and world size:

    FI_ALLREDUCE_FUSION_MAX_SIZE_MB = {
        90: {
            2: 64,  # 64MB
            4: 2,   # 2MB
            8: 1,   # 1MB
        },
        100: {
            2: 64,  # 64MB
            4: 32,  # 32MB
            8: 1,   # 1MB
        },
    }

where the outer key is the device capability and the inner key is the world size.

__post_init__

__post_init__() -> None
Source code in vllm/config/compilation.py
def __post_init__(self) -> None:
    if not self.enable_noop:
        if self.enable_fusion:
            logger.warning_once(
                "Fusion enabled but reshape elimination disabled. "
                "RMSNorm/SiluMul + quant (fp8) fusion might not work"
            )
        if self.enable_attn_fusion:
            logger.warning_once(
                "Fusion enabled but reshape elimination disabled. "
                "Attention + quant (fp8) fusion might not work"
            )
        if self.enable_fi_allreduce_fusion:
            logger.warning_once(
                "Fusion enabled but reshape elimination disabled. "
                "Allreduce + rms norm + quant (fp8) fusion might not work"
            )
    if self.enable_qk_norm_rope_fusion and not current_platform.is_cuda_alike():
        logger.warning_once(
            "QK Norm + RoPE fusion enabled but the current platform is not "
            "CUDA or ROCm. The fusion will be disabled."
        )
        self.enable_qk_norm_rope_fusion = False

compute_hash

compute_hash() -> str

Produces a hash unique to the pass configuration. Any new fields that affect compilation should be added to the hash. Any future fields that don't affect compilation should be excluded.

Source code in vllm/config/compilation.py
def compute_hash(self) -> str:
    """
    Produces a hash unique to the pass configuration.
    Any new fields that affect compilation should be added to the hash.
    Any future fields that don't affect compilation should be excluded.
    """
    return InductorPass.hash_dict(asdict(self))

default_fi_allreduce_fusion_max_size_mb staticmethod

default_fi_allreduce_fusion_max_size_mb() -> dict[
    int, float
]
Source code in vllm/config/compilation.py
@staticmethod
def default_fi_allreduce_fusion_max_size_mb() -> dict[int, float]:
    from vllm.compilation.collective_fusion import FI_ALLREDUCE_FUSION_MAX_SIZE_MB
    from vllm.platforms import current_platform

    if not current_platform.is_cuda():
        return {}
    return FI_ALLREDUCE_FUSION_MAX_SIZE_MB.get(
        current_platform.get_device_capability().to_int(), {}
    )

flashinfer_max_size

flashinfer_max_size(world_size: int) -> int | None

Returns the max communication size in bytes for flashinfer allreduce fusion for the given world size. Returns None if world size is not supported by configs as it's not supported by flashinfer.

Source code in vllm/config/compilation.py
def flashinfer_max_size(self, world_size: int) -> int | None:
    """
    Returns the max communication size in bytes for flashinfer
    allreduce fusion for the given world size. Returns None if world size
    is not supported by configs as it's not supported by flashinfer.
    """

    MiB = 1024 * 1024
    max_size_mb = self.fi_allreduce_fusion_max_size_mb
    if max_size_mb is None:
        max_size_mb = self.default_fi_allreduce_fusion_max_size_mb().get(world_size)

    return int(max_size_mb * MiB) if max_size_mb is not None else None
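
A hedged usage sketch, assuming PassConfig is importable from vllm.config.compilation and can be constructed directly with keyword arguments:

from vllm.config.compilation import PassConfig

pc = PassConfig(enable_fi_allreduce_fusion=True, fi_allreduce_fusion_max_size_mb=2.0)
# The threshold is converted from MB to bytes (x * 1024 * 1024).
assert pc.flashinfer_max_size(world_size=4) == 2 * 1024 * 1024
# Without an explicit threshold, the capability/world-size defaults are used,
# and unsupported world sizes return None.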