vllm.model_executor.models.gemma4

Gemma 4 model implementation for vLLM.

Gemma4CrossDecoderLayers

Bases: Module

Cross-decoder layers (YOCO second half, KV-shared).

Source code in vllm/model_executor/models/gemma4.py
@support_torch_compile(
    enable_if=lambda vllm_config: vllm_config.cache_config.kv_sharing_fast_prefill
)
class Gemma4CrossDecoderLayers(nn.Module):
    """Cross-decoder layers (YOCO second half, KV-shared)."""

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        decoder_layers: list[Gemma4DecoderLayer],
        layer_idx_start: int,
    ):
        super().__init__()
        self.decoder_layers = decoder_layers
        self.layer_idx_start = layer_idx_start

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        per_layer_inputs: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        return _run_decoder_layers(
            self.decoder_layers,
            self.layer_idx_start,
            positions,
            hidden_states,
            per_layer_inputs,
            **kwargs,
        )
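
For context on "YOCO second half": the decoder stack is split so that the first K layers (the self-decoder) own the KV cache and the remaining layers (this class) reuse it. A minimal sketch of how that split index is derived, mirroring the construction in Gemma4Model below; the layer counts here are illustrative, not taken from a real config.

# Illustrative values only; the real ones come from the model config.
num_hidden_layers = 32
num_kv_shared_layers = 8

# Layers [0, K) form the self-decoder and write their own KV cache;
# layers [K, N) form the cross-decoder and share that cache.
first_kv_shared_layer_idx = num_hidden_layers - num_kv_shared_layers  # K = 24

self_decoder_indices = range(first_kv_shared_layer_idx)                      # 0..23
cross_decoder_indices = range(first_kv_shared_layer_idx, num_hidden_layers)  # 24..31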

Gemma4MoE

Bases: Module

Mixture of Experts for Gemma4 using vLLM's FusedMoE.

Wraps FusedMoE with custom routing. The router projection is external (Gemma4Router) — this class only handles expert dispatch.

Gemma4 routing: softmax over ALL experts → top-k → renormalize. per_expert_scale is folded into routing weights for mathematical correctness with FusedMoE's fused kernel.

Source code in vllm/model_executor/models/gemma4.py
class Gemma4MoE(nn.Module):
    """Mixture of Experts for Gemma4 using vLLM's FusedMoE.

    Wraps FusedMoE with custom routing. The router projection is
    external (Gemma4Router) — this class only handles expert dispatch.

    Gemma4 routing: softmax over ALL experts → top-k → renormalize.
    per_expert_scale is folded into routing weights for mathematical
    correctness with FusedMoE's fused kernel.
    """

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_experts = config.num_experts

        # Per-expert output scale folded into routing weights so that
        # FusedMoE's fused kernel computes: Σ_e (expert_e * w_e * scale_e)
        self.per_expert_scale = nn.Parameter(torch.ones(config.num_experts))

        # Gemma4 routing: softmax over ALL experts → top-k → renormalize.
        # FusedMoE's built-in fused_topk scopes softmax differently, so
        # a custom routing function is needed for numerical correctness.
        per_expert_scale = self.per_expert_scale

        def routing_function(
            hidden_states: torch.Tensor,
            gating_output: torch.Tensor,
            topk: int,
            renormalize: bool,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            _, topk_ids = torch.topk(gating_output, k=topk, dim=-1)
            router_probabilities = torch.nn.functional.softmax(gating_output, dim=-1)
            indicator = torch.nn.functional.one_hot(
                topk_ids, num_classes=gating_output.size(-1)
            ).sum(dim=-2)
            gate_weights = indicator * router_probabilities
            renorm_factor = torch.sum(gate_weights, dim=-1, keepdim=True)
            renorm_factor = torch.where(renorm_factor > 0.0, renorm_factor, 1.0)
            dispatch_weights = gate_weights / renorm_factor

            topk_weights = dispatch_weights.gather(1, topk_ids)

            # Fold per_expert_scale into routing weights
            expert_scales = per_expert_scale[topk_ids].to(topk_weights.dtype)
            topk_weights = topk_weights * expert_scales
            return topk_weights.to(torch.float32), topk_ids.to(torch.int32)

        # FusedMoE experts with custom Gemma4 routing
        self.experts = FusedMoE(
            num_experts=config.num_experts,
            top_k=config.top_k_experts,
            hidden_size=config.hidden_size,
            intermediate_size=getattr(
                config,
                "moe_intermediate_size",
                getattr(config, "expert_intermediate_size", None),
            ),
            reduce_results=True,
            renormalize=True,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            custom_routing_function=routing_function,
            activation="gelu",
        )

    def forward(self, x: torch.Tensor, router_logits: torch.Tensor) -> torch.Tensor:
        return self.experts(x, router_logits)
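
The routing described in the docstring can be reproduced with plain PyTorch. Below is a minimal sketch with made-up sizes (num_experts, top_k, and the tensors are illustrative only); it gathers the top-k probabilities directly, which is equivalent to the one-hot indicator formulation used in routing_function above.

import torch

num_experts, top_k = 4, 2
gating_output = torch.tensor([[2.0, 1.0, 0.5, -1.0]])   # router logits [T, E]
per_expert_scale = torch.tensor([1.0, 0.5, 2.0, 1.0])   # learned per-expert output scale

# 1) softmax over ALL experts, not just the selected ones
probs = torch.softmax(gating_output, dim=-1)

# 2) top-k selection on the raw logits
_, topk_ids = torch.topk(gating_output, k=top_k, dim=-1)
topk_probs = probs.gather(1, topk_ids)

# 3) renormalize the selected probabilities so they sum to 1 per token
topk_weights = topk_probs / topk_probs.sum(dim=-1, keepdim=True)

# 4) fold per_expert_scale into the weights, so a fused kernel computing
#    sum_e expert_e(x) * w_e also applies scale_e
topk_weights = topk_weights * per_expert_scale[topk_ids]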

Gemma4Model

Bases: Module

Source code in vllm/model_executor/models/gemma4.py
@support_torch_compile(
    enable_if=lambda vllm_config: not vllm_config.cache_config.kv_sharing_fast_prefill
)
class Gemma4Model(nn.Module):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = _get_text_config(vllm_config.model_config.hf_config)
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config

        # PLE config values (default to 0 if not present — disables PLE)
        self.hidden_size_per_layer_input = getattr(
            config, "hidden_size_per_layer_input", 0
        )
        self.vocab_size_per_layer_input = getattr(
            config, "vocab_size_per_layer_input", config.vocab_size
        )

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.embed_tokens",
        )

        # Per-Layer Embedding (PLE) components
        if (
            self.hidden_size_per_layer_input is not None
            and self.hidden_size_per_layer_input > 0
        ):
            total_ple_dim = self.hidden_size_per_layer_input * config.num_hidden_layers
            self.embed_tokens_per_layer = VocabParallelEmbedding(
                self.vocab_size_per_layer_input,
                total_ple_dim,
                quant_config=quant_config,
                prefix=f"{prefix}.embed_tokens_per_layer",
            )
            # Scaled embedding factor (from config, not hardcoded)
            # Register as buffer so it moves to GPU with the model
            # and interacts correctly with torch.compile AOT caching.
            self.register_buffer(
                "embed_scale_per_layer",
                torch.tensor(self.hidden_size_per_layer_input**0.5),
                persistent=False,
            )
            # Projection: hidden_size → total_ple_dim
            # ColumnParallelLinear with gather_output=True
            self.per_layer_model_projection = ColumnParallelLinear(
                config.hidden_size,
                total_ple_dim,
                bias=False,
                gather_output=True,
                return_bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.per_layer_model_projection",
            )
            # PLE projection norm: output = norm(x) * weight
            self.per_layer_projection_norm = RMSNorm(
                self.hidden_size_per_layer_input,
                eps=config.rms_norm_eps,
            )
            # Scale factor for combining projection + per_layer_inputs
            # Register as buffer so it moves to GPU with the model
            # and interacts correctly with torch.compile AOT caching.
            self.register_buffer(
                "per_layer_input_scale",
                torch.rsqrt(torch.tensor(2.0)),
                persistent=False,
            )
            # Scaled projection: multiply output by hidden_size**-0.5.
            # Register as buffer for GPU placement and torch.compile.
            self.register_buffer(
                "per_layer_projection_scale",
                torch.tensor(config.hidden_size**-0.5),
                persistent=False,
            )
        else:
            self.embed_tokens_per_layer = None
            self.embed_scale_per_layer = None
            self.per_layer_model_projection = None
            self.per_layer_projection_norm = None
            self.per_layer_input_scale = None
            self.per_layer_projection_scale = None

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Gemma4DecoderLayer(
                config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )
        # Final norm: output = norm(x) * weight
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Embedding scale = sqrt(hidden_size)
        # Downcast to model dtype (bfloat16 etc.) for numerical parity
        self.register_buffer(
            "normalizer",
            torch.tensor(config.hidden_size**0.5),
            persistent=False,
        )

        # --- You Only Cache Once (YOCO) split for fast prefill ---
        first_kv_shared_layer_idx = config.num_hidden_layers - getattr(
            config, "num_kv_shared_layers", 0
        )

        from vllm.compilation.backends import set_model_tag

        # Layers 0..(K-1) are self-decoder layers in YOCO
        with set_model_tag("self_decoder"):
            self.self_decoder = Gemma4SelfDecoderLayers(
                vllm_config=vllm_config,
                prefix=f"{prefix}.self_decoder",
                decoder_layers=self.layers[:first_kv_shared_layer_idx],
                layer_idx_start=0,
                embed_tokens=self.embed_tokens,
                normalizer=self.normalizer,
                embed_tokens_per_layer=getattr(self, "embed_tokens_per_layer", None),
                embed_scale_per_layer=getattr(self, "embed_scale_per_layer", None),
                per_layer_model_projection=getattr(
                    self, "per_layer_model_projection", None
                ),
                per_layer_projection_norm=getattr(
                    self, "per_layer_projection_norm", None
                ),
                per_layer_input_scale=getattr(self, "per_layer_input_scale", None),
                per_layer_projection_scale=getattr(
                    self, "per_layer_projection_scale", None
                ),
            )
        # Layers K..(N-1) are cross-decoder layers in YOCO
        with set_model_tag("cross_decoder"):
            self.cross_decoder = Gemma4CrossDecoderLayers(
                vllm_config=vllm_config,
                prefix=f"{prefix}.cross_decoder",
                decoder_layers=self.layers[first_kv_shared_layer_idx:],
                layer_idx_start=first_kv_shared_layer_idx,
            )

        self.fast_prefill_enabled = cache_config.kv_sharing_fast_prefill

        if self.fast_prefill_enabled:
            # Allocate static buffers for CUDAGraph
            max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
            device = next(self.parameters()).device
            self.positions = torch.zeros(
                max_num_tokens, dtype=torch.int64, device=device
            )
            self.hidden_states = torch.zeros(
                (max_num_tokens, config.hidden_size),
                dtype=self.embed_tokens.weight.dtype,
                device=device,
            )
            if (
                self.hidden_size_per_layer_input
                and self.hidden_size_per_layer_input > 0
            ):
                self.per_layer_inputs = torch.zeros(
                    (
                        max_num_tokens,
                        config.num_hidden_layers,
                        self.hidden_size_per_layer_input,
                    ),
                    dtype=self.embed_tokens.weight.dtype,
                    device=device,
                )
            else:
                self.per_layer_inputs = None

        # Custom factory that includes per_layer_inputs for PLE-enabled PP.
        # per_layer_inputs has shape (batch, num_layers, per_layer_dim),
        # which differs from the standard (batch, hidden_size) shape,
        # so we can't use the default factory.
        ple_dim = self.hidden_size_per_layer_input
        num_layers = config.num_hidden_layers
        hidden_size = config.hidden_size

        def _make_empty_intermediate_tensors(
            batch_size: int,
            dtype: torch.dtype,
            device: torch.device,
        ) -> IntermediateTensors:
            tensors: dict[str, torch.Tensor] = {
                "hidden_states": torch.zeros(
                    (batch_size, hidden_size),
                    dtype=dtype,
                    device=device,
                ),
                "residual": torch.zeros(
                    (batch_size, hidden_size),
                    dtype=dtype,
                    device=device,
                ),
            }
            if ple_dim and ple_dim > 0:
                tensors["per_layer_inputs"] = torch.zeros(
                    (batch_size, num_layers, ple_dim),
                    dtype=dtype,
                    device=device,
                )
            return IntermediateTensors(tensors)

        self.make_empty_intermediate_tensors = _make_empty_intermediate_tensors

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.self_decoder.embed_input_ids(input_ids)

    def get_per_layer_inputs(self, input_ids: torch.Tensor) -> torch.Tensor | None:
        """Get per-layer embeddings from embed_tokens_per_layer.

        Returns:
            Per-layer embeddings (num_tokens, num_layers,
            hidden_size_per_layer_input)
        """
        return self.self_decoder.get_per_layer_inputs(input_ids)

    def project_per_layer_inputs(
        self,
        inputs_embeds: torch.Tensor,
        per_layer_inputs: torch.Tensor | None,
    ) -> torch.Tensor | None:
        """Project inputs_embeds and combine with per_layer_inputs.

        Steps:
        1. Project inputs_embeds: hidden_size → total_ple_dim
        2. Scale by hidden_size^{-0.5}
        3. Reshape to (num_tokens, num_layers, per_layer_dim)
        4. Normalize with per_layer_projection_norm
        5. Combine: (projection + per_layer_inputs) * 1/sqrt(2)
        """
        return self.self_decoder.project_per_layer_inputs(
            inputs_embeds, per_layer_inputs
        )

    def fast_prefill_forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        per_layer_inputs: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        logits_indices_padded, num_logits_indices = None, None
        attn_metadata = get_forward_context().attn_metadata

        if attn_metadata is not None:
            assert isinstance(attn_metadata, dict)
            layer_attn_metadata = attn_metadata[
                self.layers[-1].self_attn.attn.layer_name
            ]
            if isinstance(layer_attn_metadata, KVSharingFastPrefillMetadata):
                logits_indices_padded = layer_attn_metadata.logits_indices_padded
                num_logits_indices = layer_attn_metadata.num_logits_indices

        batch_size = positions.size(0)
        self.positions[:batch_size].copy_(positions)
        self_decoder_hidden_states, per_layer_inputs = self.self_decoder(
            input_ids=input_ids,
            positions=self.positions[:batch_size],
            inputs_embeds=inputs_embeds,
            per_layer_inputs=per_layer_inputs,
            **kwargs,
        )

        if logits_indices_padded is None:
            logits_indices_padded = torch.arange(
                batch_size,
                dtype=positions.dtype,
                device=positions.device,
            )

        # NOTE: Keep .clone() until fix in
        # https://github.com/vllm-project/vllm/pull/22282
        hidden_states = self_decoder_hidden_states.clone()

        num_padded = logits_indices_padded.size(0)
        self.positions[:num_padded].copy_(positions[logits_indices_padded])
        self.hidden_states[:num_padded].copy_(
            self_decoder_hidden_states[logits_indices_padded]
        )
        if self.per_layer_inputs is not None and per_layer_inputs is not None:
            self.per_layer_inputs[:num_padded].copy_(
                per_layer_inputs[logits_indices_padded]
            )

        # Update batch_descriptor so the cross-decoder's piecewise
        # CUDAGraphWrapper dispatches to the correct (reduced) batch size.
        forward_context = get_forward_context()
        orig_batch_desc = forward_context.batch_descriptor
        if orig_batch_desc is not None:
            forward_context.batch_descriptor = replace(
                orig_batch_desc, num_tokens=num_padded
            )

        cross_per_layer = (
            self.per_layer_inputs[:num_padded]
            if self.per_layer_inputs is not None
            else None
        )
        cross_hidden_states = self.cross_decoder(
            self.positions[:num_padded],
            self.hidden_states[:num_padded],
            cross_per_layer,
            **kwargs,
        )

        # Restore the original batch_descriptor
        forward_context.batch_descriptor = orig_batch_desc

        if num_logits_indices is not None:
            assert num_logits_indices > 0
            hidden_states[logits_indices_padded[:num_logits_indices]] = (
                cross_hidden_states[:num_logits_indices]
            )
        else:
            hidden_states = cross_hidden_states

        return hidden_states

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
        per_layer_inputs: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor | IntermediateTensors:
        if self.fast_prefill_enabled:
            hidden_states = self.fast_prefill_forward(
                input_ids,
                positions,
                inputs_embeds,
                per_layer_inputs,
                **kwargs,
            )
            hidden_states = self.norm(hidden_states)
            return hidden_states

        # Normal (non-fast-prefill) path with PP support
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
                # When called from the multimodal wrapper, raw PLE
                # embeddings are pre-computed and passed explicitly.
                # Project them through per_layer_model_projection.
                per_layer_inputs = self.project_per_layer_inputs(
                    hidden_states, per_layer_inputs
                )
            else:
                hidden_states = self.embed_input_ids(input_ids)
                # Compute per-layer inputs for PLE
                per_layer_embeds = self.get_per_layer_inputs(input_ids)
                per_layer_inputs = self.project_per_layer_inputs(
                    hidden_states, per_layer_embeds
                )
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
            per_layer_inputs = intermediate_tensors.get("per_layer_inputs")

        for layer_idx, layer in enumerate(
            islice(self.layers, self.start_layer, self.end_layer)
        ):
            # Extract the per-layer embedding for this specific layer
            if per_layer_inputs is not None:
                actual_layer_idx = self.start_layer + layer_idx
                layer_per_input = per_layer_inputs[
                    :, actual_layer_idx, :
                ]  # (num_tokens, per_layer_dim)
            else:
                layer_per_input = None
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
                per_layer_input=layer_per_input,
                **kwargs,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {
                    "hidden_states": hidden_states,
                    "residual": residual,
                    "per_layer_inputs": per_layer_inputs,
                }
            )
        # Gemma4 incorporates residual into hidden_states directly
        # Apply norm without residual fusion when possible.
        if residual is None:
            hidden_states = self.norm(hidden_states)
        else:
            hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # MoE expert weight mapping: checkpoint 3D packed tensors are
        # exploded in _weight_iterator to per-expert 2D weights like:
        #   moe.experts.{id}.gate_proj → FusedMoE w1 (shard of w13)
        #   moe.experts.{id}.up_proj   → FusedMoE w3 (shard of w13)
        #   moe.experts.{id}.down_proj → FusedMoE w2
        # We build the mapping directly since Gemma4 uses bare param
        # names (no .weight suffix) unlike standard MoE checkpoints.
        num_experts = getattr(self.config, "num_experts", None) or 0
        expert_params_mapping = [
            # (param_name, weight_name, expert_id, shard_id)
            (
                "experts.w13_weight"
                if proj_name in ["gate_proj", "up_proj"]
                else "experts.w2_weight",
                f"experts.{expert_id}.{proj_name}",
                expert_id,
                shard_id,
            )
            for expert_id in range(num_experts)
            for shard_id, proj_name in [
                ("w1", "gate_proj"),
                ("w2", "down_proj"),
                ("w3", "up_proj"),
            ]
        ]
        params_dict = dict(self.named_parameters())
        # Include buffers (e.g. layer_scalar) so they can be loaded too
        params_dict.update(dict(self.named_buffers()))
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                loaded_weight = loaded_weight[0]
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            if name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale")):
                remapped_name = maybe_remap_kv_scale_name(name, params_dict)
                if remapped_name is not None and remapped_name in params_dict:
                    param = params_dict[remapped_name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
                    loaded_params.add(remapped_name)
                    continue

            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                stacked_name = name.replace(shard_name, param_name)
                # k_eq_v layers use separate q_proj/k_proj instead of
                # packed qkv_proj. If the stacked param doesn't exist,
                # skip this mapping and fall through to direct load.
                if stacked_name not in params_dict:
                    continue
                if is_pp_missing_parameter(stacked_name, self):
                    continue
                param = params_dict[stacked_name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(stacked_name)
                break
            else:
                for (
                    param_name,
                    weight_name,
                    expert_id,
                    shard_id,
                ) in expert_params_mapping:
                    if weight_name not in name:
                        continue
                    moe_name = name.replace(weight_name, param_name)
                    if moe_name not in params_dict:
                        continue
                    if is_pp_missing_parameter(moe_name, self):
                        continue
                    param = params_dict[moe_name]
                    # Expert weights are already in the correct
                    # orientation for FusedMoE after _weight_iterator:
                    #   gate/up: [I, H] → w1/w3 expects [I, H]
                    #   down:    [H, I] → w2 expects [H, I]
                    assert loaded_weight.dim() == 2, (
                        f"Expected 2D expert weight for {weight_name}, "
                        f"got shape {loaded_weight.shape}"
                    )
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        weight_name + ".weight",
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    loaded_params.add(moe_name)
                    break
                else:
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)

        return loaded_params

get_per_layer_inputs

get_per_layer_inputs(input_ids: Tensor) -> Tensor | None

Get per-layer embeddings from embed_tokens_per_layer.

Returns:

Type          Description
Tensor | None Per-layer embeddings (num_tokens, num_layers, hidden_size_per_layer_input)

Source code in vllm/model_executor/models/gemma4.py
def get_per_layer_inputs(self, input_ids: torch.Tensor) -> torch.Tensor | None:
    """Get per-layer embeddings from embed_tokens_per_layer.

    Returns:
        Per-layer embeddings (num_tokens, num_layers,
        hidden_size_per_layer_input)
    """
    return self.self_decoder.get_per_layer_inputs(input_ids)

project_per_layer_inputs

project_per_layer_inputs(
    inputs_embeds: Tensor, per_layer_inputs: Tensor | None
) -> Tensor | None

Project inputs_embeds and combine with per_layer_inputs.

Steps:

1. Project inputs_embeds: hidden_size → total_ple_dim
2. Scale by hidden_size^{-0.5}
3. Reshape to (num_tokens, num_layers, per_layer_dim)
4. Normalize with per_layer_projection_norm
5. Combine: (projection + per_layer_inputs) * 1/sqrt(2)

Source code in vllm/model_executor/models/gemma4.py
def project_per_layer_inputs(
    self,
    inputs_embeds: torch.Tensor,
    per_layer_inputs: torch.Tensor | None,
) -> torch.Tensor | None:
    """Project inputs_embeds and combine with per_layer_inputs.

    Steps:
    1. Project inputs_embeds: hidden_size → total_ple_dim
    2. Scale by hidden_size^{-0.5}
    3. Reshape to (num_tokens, num_layers, per_layer_dim)
    4. Normalize with per_layer_projection_norm
    5. Combine: (projection + per_layer_inputs) * 1/sqrt(2)
    """
    return self.self_decoder.project_per_layer_inputs(
        inputs_embeds, per_layer_inputs
    )
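
A shape-level sketch of the five steps above, using plain tensors as stand-ins for the projection and norm modules (all dimensions below are illustrative, and the RMSNorm epsilon is an assumption):

import torch

num_tokens, hidden_size = 8, 64
num_layers, per_layer_dim = 4, 16
total_ple_dim = num_layers * per_layer_dim

inputs_embeds = torch.randn(num_tokens, hidden_size)
per_layer_inputs = torch.randn(num_tokens, num_layers, per_layer_dim)
proj_weight = torch.randn(total_ple_dim, hidden_size)  # stand-in for per_layer_model_projection

# 1) project hidden_size -> total_ple_dim and 2) scale by hidden_size**-0.5
proj = (inputs_embeds @ proj_weight.T) * hidden_size**-0.5
# 3) reshape to (num_tokens, num_layers, per_layer_dim)
proj = proj.reshape(num_tokens, num_layers, per_layer_dim)
# 4) RMS-normalize the last dim (stand-in for per_layer_projection_norm)
proj = proj * torch.rsqrt(proj.pow(2).mean(-1, keepdim=True) + 1e-6)
# 5) combine with the raw per-layer embeddings, rescaled by 1/sqrt(2)
combined = (proj + per_layer_inputs) * torch.rsqrt(torch.tensor(2.0))
assert combined.shape == (num_tokens, num_layers, per_layer_dim)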

Gemma4Router

Bases: Module

Router for Gemma4 MoE that preprocesses input before projection.

Applies RMSNorm (no learned weight), root_size scaling (hidden_size^{-0.5}), then a learned per-dimension scale before projecting to expert logits.

This preprocessing is applied ONLY to the router's input, not to the expert MLPs' input.

Source code in vllm/model_executor/models/gemma4.py
class Gemma4Router(nn.Module):
    """Router for Gemma4 MoE that preprocesses input before projection.

    Applies RMSNorm (no learned weight), root_size scaling
    (hidden_size^{-0.5}), then a learned per-dimension scale before
    projecting to expert logits.

    This preprocessing is applied ONLY to the router's input, not to
    the expert MLPs' input.
    """

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size

        # RMSNorm without learned weight — pure normalization only
        self.norm = RMSNorm(self.hidden_size, eps=config.rms_norm_eps, has_weight=False)
        # Per-dimension learned scale, applied after norm + root_size
        self.scale = nn.Parameter(torch.ones(self.hidden_size))
        # Constant 1/sqrt(hidden_size) scaling factor
        self.register_buffer(
            "root_size",
            torch.tensor(self.hidden_size**-0.5),
            persistent=False,
        )
        # Project to expert logits; replicated across TP for consistent routing
        # GateLinear supports bf16 W/A → fp32 output, which is important
        # because the topk kernel often needs fp32 for stable routing.
        self.proj = GateLinear(
            self.hidden_size,
            config.num_experts,
            bias=False,
            out_dtype=torch.float32,
            prefix=f"{prefix}.proj",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Returns raw router logits [T, E]."""
        x = self.norm(x)
        x = x * self.root_size.to(x.dtype)
        x = x * self.scale.to(x.dtype)
        router_logits, _ = self.proj(x)
        return router_logits

forward

forward(x: Tensor) -> Tensor

Returns raw router logits [T, E].

Source code in vllm/model_executor/models/gemma4.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Returns raw router logits [T, E]."""
    x = self.norm(x)
    x = x * self.root_size.to(x.dtype)
    x = x * self.scale.to(x.dtype)
    router_logits, _ = self.proj(x)
    return router_logits
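
The router's preprocessing pipeline can be written out directly. Below is a minimal sketch with illustrative sizes, using plain tensors in place of RMSNorm and GateLinear (the epsilon and dtypes here are assumptions):

import torch

num_tokens, hidden_size, num_experts = 8, 64, 4
x = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16)
scale = torch.ones(hidden_size, dtype=torch.bfloat16)                  # learned per-dim scale
proj_weight = torch.randn(num_experts, hidden_size, dtype=torch.bfloat16)

# RMSNorm with no learned weight: pure normalization
x = x * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + 1e-6).to(x.dtype)
# constant 1/sqrt(hidden_size), then the learned per-dimension scale
x = x * hidden_size**-0.5
x = x * scale
# project to expert logits in fp32 for stable top-k routing
router_logits = x.float() @ proj_weight.float().T                      # [T, E]

These logits are then passed to Gemma4MoE.forward above, which handles expert dispatch.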

Gemma4SelfDecoderLayers

Bases: Module

Compiled wrapper: embedding + non-KV-shared layers (YOCO first half).

Owns the embedding and PLE modules so they are inside the compiled graph. Gemma4Model delegates embedding methods here.

Source code in vllm/model_executor/models/gemma4.py
@support_torch_compile(
    enable_if=lambda vllm_config: vllm_config.cache_config.kv_sharing_fast_prefill
)
class Gemma4SelfDecoderLayers(nn.Module):
    """Compiled wrapper: embedding + non-KV-shared layers (YOCO first half).

    Owns the embedding and PLE modules so they are inside the compiled
    graph. Gemma4Model delegates embedding methods here.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        decoder_layers: list[Gemma4DecoderLayer],
        layer_idx_start: int,
        embed_tokens: VocabParallelEmbedding,
        normalizer: torch.Tensor,
        embed_tokens_per_layer: VocabParallelEmbedding | None,
        embed_scale_per_layer: torch.Tensor | None,
        per_layer_model_projection: ColumnParallelLinear | None,
        per_layer_projection_norm: RMSNorm | None,
        per_layer_input_scale: torch.Tensor | None,
        per_layer_projection_scale: torch.Tensor | None,
    ):
        super().__init__()
        self.decoder_layers = decoder_layers
        self.layer_idx_start = layer_idx_start

        config = _get_text_config(vllm_config.model_config.hf_config)
        self.config = config
        self.hidden_size_per_layer_input = getattr(
            config, "hidden_size_per_layer_input", 0
        )
        self.vocab_size_per_layer_input = getattr(
            config, "vocab_size_per_layer_input", config.vocab_size
        )

        # Shared references to modules owned by Gemma4Model — must be
        # inside this nn.Module so torch.compile captures them.
        self.embed_tokens = embed_tokens
        self.normalizer = normalizer
        self.embed_tokens_per_layer = embed_tokens_per_layer
        self.embed_scale_per_layer = embed_scale_per_layer
        self.per_layer_model_projection = per_layer_model_projection
        self.per_layer_projection_norm = per_layer_projection_norm
        self.per_layer_input_scale = per_layer_input_scale
        self.per_layer_projection_scale = per_layer_projection_scale

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids) * self.normalizer

    def get_per_layer_inputs(self, input_ids: torch.Tensor) -> torch.Tensor | None:
        """Get per-layer embeddings from embed_tokens_per_layer.

        Returns:
            Per-layer embeddings (num_tokens, num_layers,
            hidden_size_per_layer_input)
        """
        if self.embed_tokens_per_layer is None:
            return None
        per_layer_inputs_mask = torch.logical_and(
            input_ids >= 0,
            input_ids < self.vocab_size_per_layer_input,
        )
        per_layer_inputs_tokens = torch.where(
            per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids)
        )
        per_layer_embeds = self.embed_tokens_per_layer(per_layer_inputs_tokens)
        per_layer_embeds = per_layer_embeds * self.embed_scale_per_layer
        return per_layer_embeds.reshape(
            *input_ids.shape,
            self.config.num_hidden_layers,
            self.hidden_size_per_layer_input,
        )

    def project_per_layer_inputs(
        self,
        inputs_embeds: torch.Tensor,
        per_layer_inputs: torch.Tensor | None,
    ) -> torch.Tensor | None:
        """Project inputs_embeds and combine with per_layer_inputs.

        Steps:
        1. Project inputs_embeds: hidden_size → total_ple_dim
        2. Scale by hidden_size^{-0.5}
        3. Reshape to (num_tokens, num_layers, per_layer_dim)
        4. Normalize with per_layer_projection_norm
        5. Combine: (projection + per_layer_inputs) * 1/sqrt(2)
        """
        if self.per_layer_model_projection is None:
            return None
        per_layer_projection = self.per_layer_model_projection(inputs_embeds)
        per_layer_projection = per_layer_projection * self.per_layer_projection_scale
        per_layer_projection = per_layer_projection.reshape(
            *inputs_embeds.shape[:-1],
            self.config.num_hidden_layers,
            self.hidden_size_per_layer_input,
        )
        per_layer_projection = self.per_layer_projection_norm(per_layer_projection)
        if per_layer_inputs is None:
            return per_layer_projection
        return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        per_layer_inputs: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
            per_layer_inputs = self.project_per_layer_inputs(
                hidden_states, per_layer_inputs
            )
        else:
            hidden_states = self.embed_input_ids(input_ids)
            per_layer_embeds = self.get_per_layer_inputs(input_ids)
            per_layer_inputs = self.project_per_layer_inputs(
                hidden_states, per_layer_embeds
            )

        hidden_states = _run_decoder_layers(
            self.decoder_layers,
            self.layer_idx_start,
            positions,
            hidden_states,
            per_layer_inputs,
            **kwargs,
        )
        return hidden_states, per_layer_inputs

get_per_layer_inputs

get_per_layer_inputs(input_ids: Tensor) -> Tensor | None

Get per-layer embeddings from embed_tokens_per_layer.

Returns:

Type          Description
Tensor | None Per-layer embeddings (num_tokens, num_layers, hidden_size_per_layer_input)

Source code in vllm/model_executor/models/gemma4.py
def get_per_layer_inputs(self, input_ids: torch.Tensor) -> torch.Tensor | None:
    """Get per-layer embeddings from embed_tokens_per_layer.

    Returns:
        Per-layer embeddings (num_tokens, num_layers,
        hidden_size_per_layer_input)
    """
    if self.embed_tokens_per_layer is None:
        return None
    per_layer_inputs_mask = torch.logical_and(
        input_ids >= 0,
        input_ids < self.vocab_size_per_layer_input,
    )
    per_layer_inputs_tokens = torch.where(
        per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids)
    )
    per_layer_embeds = self.embed_tokens_per_layer(per_layer_inputs_tokens)
    per_layer_embeds = per_layer_embeds * self.embed_scale_per_layer
    return per_layer_embeds.reshape(
        *input_ids.shape,
        self.config.num_hidden_layers,
        self.hidden_size_per_layer_input,
    )

project_per_layer_inputs

project_per_layer_inputs(
    inputs_embeds: Tensor, per_layer_inputs: Tensor | None
) -> Tensor | None

Project inputs_embeds and combine with per_layer_inputs.

Steps:

1. Project inputs_embeds: hidden_size → total_ple_dim
2. Scale by hidden_size^{-0.5}
3. Reshape to (num_tokens, num_layers, per_layer_dim)
4. Normalize with per_layer_projection_norm
5. Combine: (projection + per_layer_inputs) * 1/sqrt(2)

Source code in vllm/model_executor/models/gemma4.py
def project_per_layer_inputs(
    self,
    inputs_embeds: torch.Tensor,
    per_layer_inputs: torch.Tensor | None,
) -> torch.Tensor | None:
    """Project inputs_embeds and combine with per_layer_inputs.

    Steps:
    1. Project inputs_embeds: hidden_size → total_ple_dim
    2. Scale by hidden_size^{-0.5}
    3. Reshape to (num_tokens, num_layers, per_layer_dim)
    4. Normalize with per_layer_projection_norm
    5. Combine: (projection + per_layer_inputs) * 1/sqrt(2)
    """
    if self.per_layer_model_projection is None:
        return None
    per_layer_projection = self.per_layer_model_projection(inputs_embeds)
    per_layer_projection = per_layer_projection * self.per_layer_projection_scale
    per_layer_projection = per_layer_projection.reshape(
        *inputs_embeds.shape[:-1],
        self.config.num_hidden_layers,
        self.hidden_size_per_layer_input,
    )
    per_layer_projection = self.per_layer_projection_norm(per_layer_projection)
    if per_layer_inputs is None:
        return per_layer_projection
    return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale

_get_text_config

_get_text_config(config)

Dereference text_config if config is a nested Gemma4Config.

Gemma4 checkpoints use architectures=["Gemma4ForConditionalGeneration"] which yields a Gemma4Config with nested text_config. This function transparently returns the text config regardless of nesting.

Source code in vllm/model_executor/models/gemma4.py
def _get_text_config(config):
    """Dereference text_config if config is a nested Gemma4Config.

    Gemma4 checkpoints use architectures=["Gemma4ForConditionalGeneration"]
    which yields a Gemma4Config with nested text_config. This function
    transparently returns the text config regardless of nesting.
    """
    if hasattr(config, "text_config"):
        return config.text_config
    return config
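
A small usage sketch (the SimpleNamespace configs stand in for real HF config objects; only the presence of a text_config attribute matters):

from types import SimpleNamespace

text_config = SimpleNamespace(hidden_size=64, num_hidden_layers=4)
nested = SimpleNamespace(text_config=text_config)   # multimodal-style Gemma4Config
flat = text_config                                  # text-only config

assert _get_text_config(nested) is text_config
assert _get_text_config(flat) is text_config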

_run_decoder_layers

_run_decoder_layers(
    decoder_layers: list[Gemma4DecoderLayer],
    layer_idx_start: int,
    positions: Tensor,
    hidden_states: Tensor,
    per_layer_inputs: Tensor | None = None,
    **kwargs,
) -> Tensor

Run a slice of decoder layers with PLE extraction.

Source code in vllm/model_executor/models/gemma4.py
def _run_decoder_layers(
    decoder_layers: list[Gemma4DecoderLayer],
    layer_idx_start: int,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    per_layer_inputs: torch.Tensor | None = None,
    **kwargs,
) -> torch.Tensor:
    """Run a slice of decoder layers with PLE extraction."""
    residual = None
    for idx, layer in enumerate(decoder_layers):
        layer_idx = idx + layer_idx_start
        layer_per_input = (
            per_layer_inputs[:, layer_idx, :] if per_layer_inputs is not None else None
        )
        hidden_states, residual = layer(
            positions,
            hidden_states,
            residual,
            per_layer_input=layer_per_input,
            **kwargs,
        )
    return hidden_states