Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

generalize deepspeed linear and implement it for non cuda systems #6932

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
9 changes: 7 additions & 2 deletions deepspeed/linear/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@
from dataclasses import dataclass, field
from typing import List

import torch


@dataclass
class LoRAConfig:
"""
Configuration settings for LoRAOptimizedLinear.

Attributes:
lora_r (int): LoRA attention dimension, also know as the rank. Defaults is 64.
lora_r (int): LoRA attention dimension, also known as the rank. Defaults is 64.
lora_alpha (float): LoRA scaling factor, default is 16.
base_weight_sharding (int): The degree to which the base weights are sharded,
should typically be set to the data-parallel world size to maximize the memory
Expand Down Expand Up @@ -42,8 +44,11 @@ class QuantizationConfig:
Attributes:
q_bits (int): The number of bits used for quantization. Default is 8.
mantissa_bits (int): The number of bits reserved for the mantissa in fixed-point quantization. Default is 3.
group_size (int): The size of the group used for quantization. Default is 512.
group_size (int): The number of elements used for quantization. Default is 512.
q_dtype (torch.dtype): The data type to quantize to. Default is uint8. (in CUDA, buffers are allocated as
uint8, but inside the kernels the quantization is done to fp8)
"""
q_bits: int = 8
mantissa_bits: int = 3
group_size: int = 512
q_dtype: torch.dtype = torch.uint8
8 changes: 4 additions & 4 deletions deepspeed/linear/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,24 +51,24 @@ def __new__(
self.quantizer = quantizer
else:
# if FPQuantizerBuilder is not compatible in this env this init will fail
self.quantizer = FP_Quantize(group_size=self.quantization_config.group_size)
self.quantizer = FP_Quantize(quantization_config=self.quantization_config)
self._ensure_quantized(self)
return self

def _ensure_quantized(self, tensor: torch.Tensor):
# If the tensor is on the accelerator and is not quantized, then quantize it in-place.
if get_accelerator().on_accelerator(tensor) and tensor.dtype != torch.uint8:
if get_accelerator().on_accelerator(tensor) and tensor.dtype != self.quantization_config.q_dtype:
with get_accelerator().stream(get_accelerator().current_stream(tensor.device)):
tensor.data = self.quantizer.quantize(tensor.data,
q_bits=self.quantization_config.q_bits,
q_mantisa_bits=self.quantization_config.mantissa_bits)
assert tensor.dtype == torch.uint8
assert tensor.dtype == self.quantization_config.q_dtype

def dequantized(self) -> torch.Tensor:
"""
Return a tensor containing the dequantized weights of this parameter.
"""
if get_accelerator().on_accelerator(self.data) and self.data.dtype == torch.uint8:
if get_accelerator().on_accelerator(self.data) and self.data.dtype == self.quantization_config.q_dtype:
with get_accelerator().stream(get_accelerator().current_stream(self.data.device)):
return self.quantizer.dequantize(self.data,
q_bits=self.quantization_config.q_bits,
Expand Down
61 changes: 42 additions & 19 deletions deepspeed/ops/fp_quantizer/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

class Quantizer(ABC):
"""
Abstract Quantizer class that implmenents quantize/dequantize methods.
Abstract Quantizer class that implements quantize/dequantize methods.

Arguments:
group_size (int, optional): number of values or elements that are grouped
Expand All @@ -42,12 +42,18 @@ def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=Non

class FP_Quantize(Quantizer):

def __init__(self, group_size=512) -> None:
def __init__(self, quantization_config) -> None:
global fp_quant_module
super().__init__(group_size=group_size)
super().__init__(group_size=quantization_config.group_size)
if fp_quant_module is None:
fp_quant_module = FPQuantizerBuilder().load()
self.is_python_impl = getattr(fp_quant_module, "PYTHON_IMPL", False)
self.q_config = quantization_config

self.orig_dtype = None
self.num_groups = None
self.input_q = None
self.scale = None

def quantize(self,
input,
Expand All @@ -73,15 +79,27 @@ def quantize(self,
else:
assert (0), \
f"Missing {q_bits}-quantization, please add the template arguments for the kernel to support this precision!"
self.num_groups = input.numel() // self.group_size
self.input_q = torch.ones(self.num_groups,
int(self.group_size * q_bits) // 8 + 4,
dtype=torch.uint8,
device=input.device)
out = fp_quant_module.quantize(self.input_q, input, self.group_size, stochastic_mode, q_bits, q_mantisa_bits)

# Adding (group_size - 1) is for padding
self.num_groups = (input.numel() + self.q_config.group_size - 1) // self.q_config.group_size
# group_size should be the minimal number between the defined group size and number of elements in tensor.
group_size = int(min(self.q_config.group_size, input.numel()) * q_bits) // 8
# CUDA quantization kernel saves the scale as (fp32) inside the quantized tensor for each group
if not self.is_python_impl:
group_size += 4
# CUDA quantization kernel allocates tensors as uint8, but handles them as fp8 inside the kernel.
self.input_q = torch.ones(self.num_groups, group_size, dtype=self.q_config.q_dtype, device=input.device)
# CUDA quantization kernel attaches scales to quantized result, in python implementation it can't be done
# because they are of different types.
self.scale = torch.ones(self.num_groups, 1, device=input.device)
out = fp_quant_module.quantize(self.input_q, input, self.scale, group_size, stochastic_mode, q_bits,
q_mantisa_bits)
if return_meta_tensor:
data, self.scale = out.split(self.group_size, dim=-1)
data = data.contiguous().reshape(input.shape)
if not self.is_python_impl:
data, self.scale = out.split(group_size, dim=-1)
data = data.contiguous().reshape(input.shape)
else:
data = out.contiguous().reshape(input.shape)
self.scale = self.scale.contiguous()
del self.input_q
del out
Expand All @@ -93,9 +111,9 @@ def quantize(self,

def to(self, *args, **kwargs):
# Intermediate tensors may need to be moved to different devices
if hasattr(self, 'input_q'):
if hasattr(self, 'input_q') and self.input_q is not None:
self.input_q = self.input_q.to(*args, **kwargs)
if hasattr(self, 'scale'):
if hasattr(self, 'scale') and self.scale is not None:
self.scale = self.scale.to(*args, **kwargs)

def get_scales(self):
Expand All @@ -118,11 +136,16 @@ def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=Non
assert (0), \
f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!"

if scale is not None:
if scale is not None and not self.is_python_impl:
assert input_q.numel() == fp_out.numel(), \
f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!'
input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous()
fp_quant_module.dequantize(fp_out, input_q, self.group_size, q_mantisa_bits, q_bits - q_mantisa_bits - 1)
input_q = torch.cat([input_q.reshape(-1, self.q_config.group_size), scale], dim=-1).contiguous()
elif scale is not None and self.is_python_impl:
group_size = int(min(self.q_config.group_size, input_q.numel()) * q_bits) // 8
input_q = input_q.reshape(-1, group_size)

fp_quant_module.dequantize(fp_out, input_q, self.scale, self.q_config.group_size, q_mantisa_bits,
q_bits - q_mantisa_bits - 1)
return fp_out

def selective_dequantize(self,
Expand Down Expand Up @@ -151,11 +174,11 @@ def selective_dequantize(self,
assert (0), \
f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!"

if scale is not None:
if scale is not None and not self.is_python_impl:
assert input_q.numel() == fp_out.numel(), \
f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!'
input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous()
input_q = torch.cat([input_q.reshape(-1, self.q_config.group_size), scale], dim=-1).contiguous()

fp_quant_module.selective_dequantize(fp_out, input_q, indexes, self.group_size, q_mantisa_bits,
fp_quant_module.selective_dequantize(fp_out, input_q, indexes, self.q_config.group_size, q_mantisa_bits,
q_bits - q_mantisa_bits - 1)
return fp_out
17 changes: 17 additions & 0 deletions op_builder/fp_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,20 @@ def extra_ldflags(self):

def include_paths(self):
return ['csrc/fp_quantizer/includes', 'csrc/includes']

@staticmethod
def get_default_quant_dtype():
import torch
return torch.uint8

@staticmethod
def get_quant_range(q_bits=None):
if q_bits == 8:
return 480
elif q_bits == 6:
return 28.
elif q_bits == 12:
return 510.
else:
assert (0), \
"Please specify the right quantization range for the selected precision!"
86 changes: 86 additions & 0 deletions op_builder/hpu/fp_quantizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright (c) 2024 Habana Labs, Ltd. an Intel Company
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
try:
# is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
# if successful this also means we're doing a local install and not JIT compile path
from op_builder import __deepspeed__ # noqa: F401 # type: ignore
from op_builder.builder import OpBuilder
except ImportError:
from deepspeed.ops.op_builder.builder import OpBuilder


class FPQuantizerBuilder(OpBuilder):
BUILD_VAR = "DS_BUILD_FP_QUANTIZER"
NAME = "fp_quantizer"

def __init__(self, name=None):
name = self.NAME if name is None else name
super().__init__(name=name)

def absolute_name(self):
return f'deepspeed.ops.fp_quantizer.{self.NAME}_op'

def sources(self):
return []

def load(self, verbose=True):
return FPQuantizer

@staticmethod
def get_default_quant_dtype():
return torch.float8_e4m3fn

@staticmethod
def get_quant_range(q_bits=None):
import habana_frameworks.torch.utils.experimental as htexp
if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2:
dtype = torch.float8_e4m3fnuz
else:
dtype = torch.float8_e4m3fn
return torch.finfo(dtype).max


class FPQuantizer:
PYTHON_IMPL = True

@classmethod
def selective_dequantize(cls, val_q, scales, indexes, group_size, q_mantisa_bits, q_exponent_bits):
assert False, "Selective dequantize isn't implemented for HPU!"

@classmethod
def dequantize(cls, fp_out, input_q, scale, group_size, q_mantisa_bits, q_exponent_bits):
orig_shape = fp_out.shape
orig_dtype = fp_out.dtype
dequant_out = torch.ops.hpu.cast_from_fp8(input_q, (1.0 / scale), orig_dtype).view(orig_shape)
fp_out.copy_(dequant_out)
return fp_out

@classmethod
def quantize(cls, out, val, scale, group_size, stochastic_rounding, q_bits, q_mantisa_bits):
assert q_bits == 8, "Quantize on HPU only supports quantization to FP8"
assert q_mantisa_bits == 3, "Quantize on HPU only supports q_mantissa_bits = 3"
assert out.dtype.is_floating_point, "Quantization on HPU is only to float dtypes"

num_groups, group_size = out.shape

# Reshape the tensor
val_reshaped = val.view(num_groups, group_size).float()
# Calculate the scale
max_vals = val_reshaped.abs().max(dim=1, keepdim=True)[0]
q_range = torch.finfo(out.dtype).max
tmp_scale = q_range / max_vals
scale.copy_(tmp_scale)
# Copy quantized
quant, _ = torch.ops.hpu.cast_to_fp8_v2(val_reshaped, scale, stochastic_rounding, dtype=out.dtype)
out.copy_(quant)

return out

@classmethod
def get_scales(cls, out, num_groups):
return out
9 changes: 2 additions & 7 deletions tests/unit/linear/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ class TestLoRALinear(DistributedTest):

def test(self, base_weight_sharding):
rank = dist.get_rank()
lora_config = None
quantization_config = None

input_features = 64 # Number of input features
Expand Down Expand Up @@ -77,15 +76,13 @@ class TestQuantLinear(DistributedTest):
world_size = 2

def test(self, q_bits):
rank = dist.get_rank()
lora_config = None

input_features = 64 # Number of input features
output_features = 64 # Number of output features
batch_size = 5 # Number of samples in a batch

lora_config = None
quantization_config = QuantizationConfig(q_bits=q_bits)
quantization_config.q_dtype = FPQuantizerBuilder.get_default_quant_dtype()

linear_layer = OptimizedLinear(input_dim=input_features,
output_dim=output_features,
Expand All @@ -106,15 +103,13 @@ class TestOptimizedLinear(DistributedTest):
world_size = 2

def test(self, base_weight_sharding, q_bits):
rank = dist.get_rank()
lora_config = None

input_features = 64 # Number of input features
output_features = 64 # Number of output features
batch_size = 5 # Number of samples in a batch

lora_config = LoRAConfig(lora_r=16, lora_alpha=16, base_weight_sharding=base_weight_sharding)
quantization_config = QuantizationConfig(q_bits=q_bits)
quantization_config.q_dtype = FPQuantizerBuilder.get_default_quant_dtype()

linear_layer = OptimizedLinear(input_dim=input_features,
output_dim=output_features,
Expand Down
6 changes: 4 additions & 2 deletions tests/unit/linear/test_quant_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,13 @@ def test_requires_grad(self):
def test_move_to_accelerator(self):
device = get_accelerator().current_device()
data = torch.rand(5, 5, device='cpu', dtype=torch.bfloat16)
qp = QuantizedParameter(data)
quantization_config = QuantizationConfig()
quantization_config.q_dtype = FPQuantizerBuilder.get_default_quant_dtype()
qp = QuantizedParameter(data, quantization_config=quantization_config)
assert qp.device == torch.device('cpu')
qp = qp.to(get_accelerator().current_device_name())
assert qp.device == torch.device(device)
assert qp.dtype == torch.uint8
assert qp.dtype == quantization_config.q_dtype

def test_hf_clone(self):
device = get_accelerator().current_device_name()
Expand Down
7 changes: 6 additions & 1 deletion tests/unit/ops/fp_quantizer/test_fp8_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from deepspeed.ops.fp_quantizer import FP_Quantize, matmul_fp8

from deepspeed import get_accelerator
from deepspeed.linear import QuantizationConfig


@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
Expand All @@ -25,7 +26,11 @@
def test_fp_quant(dtype, q_bits, M):
device_name = get_accelerator().device_name()
quantization_group_size = 128
fpq = FP_Quantize(group_size=quantization_group_size)

quant_config = QuantizationConfig()
quant_config.q_dtype = FPQuantizerBuilder.get_default_quant_dtype()
quant_config.group_size = quantization_group_size
fpq = FP_Quantize(quantization_config=quant_config)

N = 8192
H = 4096
Expand Down
Loading