Skip to content

vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils

dequantize_to_dtype

dequantize_to_dtype(
    tensor_fp4: Tensor,
    tensor_sf: Tensor,
    global_scale: Tensor | float,
    dtype: dtype,
    block_size: int = 16,
    swizzle: bool | None = True,
)

Dequantize a packed FP4 tensor (two 4-bit values per byte) back to a higher-precision dtype, applying per-block and global scale factors.

Source code in vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py
def dequantize_to_dtype(
    tensor_fp4: torch.Tensor,
    tensor_sf: torch.Tensor,
    global_scale: torch.Tensor | float,
    dtype: torch.dtype,
    block_size: int = 16,
    swizzle: bool | None = True,
):
    """Dequantize a packed FP4 tensor back to a higher-precision dtype.

    Args:
        tensor_fp4: 2-D ``torch.uint8`` tensor of shape ``(m, packed_k)``.
            Each byte packs two FP4 values, so the logical width is
            ``k = 2 * packed_k``.
        tensor_sf: Per-block scale factors. The raw bytes are reinterpreted
            via ``view`` as ``torch.float8_e4m3fn``. After the optional
            de-swizzle, the tensor must broadcast against
            ``(m, k // block_size)``.
        global_scale: Scalar, or broadcastable tensor, multiplied into every
            block scale.
        dtype: dtype of the returned tensor.
        block_size: Number of consecutive elements along ``k`` that share one
            scale factor. ``k`` is expected to be divisible by it; otherwise
            the reshape fails.
        swizzle: When truthy, ``tensor_sf`` is first converted from the
            swizzled layout to a linear layout.

    Returns:
        Tensor of shape ``(m, k)`` cast to ``dtype``.

    Raises:
        AssertionError: If ``tensor_fp4`` is not ``torch.uint8``.
    """
    # Each uint8 byte carries a pair of fp4 values.
    assert tensor_fp4.dtype == torch.uint8
    rows, packed_cols = tensor_fp4.shape
    cols = 2 * packed_cols
    num_blocks = cols // block_size

    # Unpack to float32 and group the elements by scale block.
    # NOTE(review): assumes break_fp4_bytes yields rows * cols values,
    # as the reshape below requires; the helper is defined elsewhere.
    values = break_fp4_bytes(tensor_fp4, torch.float32).reshape(
        rows, num_blocks, block_size
    )

    # Reinterpret the scale bytes as fp8 (e4m3), then fold in the global scale.
    scales = tensor_sf.view(torch.float8_e4m3fn)
    if swizzle:
        scales = convert_swizzled_to_linear(scales, rows, cols, block_size)
    scales = scales.to(torch.float32) * global_scale

    # Broadcast each block's scale across its block_size elements.
    scaled = values * scales.unsqueeze(-1)
    return scaled.reshape(rows, cols).to(dtype)