Dequantize an FP4-quantized tensor back to a high-precision dtype.
Source code in vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py
def dequantize_to_dtype(
    tensor_fp4: torch.Tensor,
    tensor_sf: torch.Tensor,
    global_scale: torch.Tensor | float,
    dtype: torch.dtype,
    block_size: int = 16,
    swizzle: bool | None = True,
):
    """Dequantize a packed FP4 tensor back to high precision.

    Args:
        tensor_fp4: uint8 tensor of shape (m, k // 2); each byte packs two
            FP4 values.
        tensor_sf: per-block scale factors, reinterpreted as float8_e4m3fn.
            Presumably swizzled when ``swizzle`` is True — layout is decoded
            by ``convert_swizzled_to_linear``.
        global_scale: scalar applied on top of the per-block scales.
        dtype: target dtype of the returned tensor.
        block_size: number of FP4 elements sharing one scale factor.
        swizzle: when truthy, de-swizzle ``tensor_sf`` into linear layout
            before applying it.

    Returns:
        Dequantized tensor of shape (m, k) cast to ``dtype``.
    """
    # Each uint8 byte holds two packed FP4 values.
    assert tensor_fp4.dtype == torch.uint8
    num_rows, packed_cols = tensor_fp4.shape
    num_cols = packed_cols * 2

    # Unpack bytes into individual FP4 values promoted to float32, then
    # group columns into blocks so each block aligns with one scale factor.
    unpacked = break_fp4_bytes(tensor_fp4, torch.float32)
    unpacked = unpacked.reshape(num_rows, num_cols // block_size, block_size)

    # Reinterpret the raw scale bytes as float8_e4m3fn.
    scales = tensor_sf.view(torch.float8_e4m3fn)
    if swizzle:
        scales = convert_swizzled_to_linear(scales, num_rows, num_cols, block_size)

    # Combine per-block scales with the global scale in float32.
    effective_scales = scales.to(torch.float32) * global_scale

    # Broadcast one scale over each block, flatten back to (m, k), and cast.
    dequantized = (unpacked * effective_scales.unsqueeze(-1)).reshape(
        num_rows, num_cols
    )
    return dequantized.to(dtype)