refine inference step 2 by wenhuach21 · Pull Request #498 · intel/auto-round · GitHub

refine inference step 2 #498


Merged
merged 14 commits into from
Apr 9, 2025
11 changes: 3 additions & 8 deletions README.md
@@ -73,7 +73,7 @@ pip install auto-round-lib

## Model Quantization

### Basic Usage (Gaudi2/CPU/GPU)
### Basic Usage (Gaudi2/CPU/XPU/GPU)

A user guide detailing the full list of supported arguments is provided by calling ```auto-round -h``` on the terminal.
Set the format you want in `format` and
@@ -268,7 +268,7 @@ autoround.save_quantized(output_dir, format='auto_round', inplace=True)

### Export Formats
**AutoRound Format**: This format is well-suited for CPU, HPU devices, 2 bits, as well as mixed-precision
inference. **[2,4] bits are supported**. However, it has not yet gained widespread community adoption.
inference. **[2,3,4,8] bits are supported**. However, it has not yet gained widespread community adoption.

**AutoGPTQ Format**: This format is well-suited for symmetric quantization on CUDA devices and is widely adopted by the
community, **[2,3,4,8] bits are supported**. However, **the
@@ -324,14 +324,9 @@ in [Gaudi Guide](https://docs.habana.ai/en/latest/).
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRoundConfig

backend = "auto" ##cpu, hpu, cuda
quantization_config = AutoRoundConfig(
backend=backend
)
quantized_model_path = "./tmp_autoround"
model = AutoModelForCausalLM.from_pretrained(quantized_model_path,
device_map=backend.split(':')[0],
quantization_config=quantization_config)
device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(quantized_model_path)
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
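The simplified README snippet above relies on `device_map="auto"`. For readers who still want to pin the inference backend explicitly, below is a minimal sketch of the pre-existing `AutoRoundConfig` route that this change removes from the README; the model path is a placeholder and the backend strings follow the comment in the removed line (`cpu`, `hpu`, `cuda`).

```python
# Minimal sketch (not part of this diff): pinning an inference backend explicitly
# via AutoRoundConfig, mirroring the snippet the README change removes.
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRoundConfig

quantized_model_path = "./tmp_autoround"  # placeholder path to a local export
quantization_config = AutoRoundConfig(backend="auto")  # or "cpu", "hpu", "cuda"

model = AutoModelForCausalLM.from_pretrained(
    quantized_model_path,
    device_map="auto",
    quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(quantized_model_path)
```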
11 changes: 10 additions & 1 deletion auto_round/__init__.py
@@ -15,5 +15,14 @@
from .mllm import AutoRoundMLLM
from auto_round.utils import LazyImport

from auto_round.inference.auto_quantizer import AutoHfQuantizer,AutoRoundConfig
def __getattr__(name):
if name == 'AutoHfQuantizer':
from auto_round.inference.auto_quantizer import AutoHfQuantizer
return AutoHfQuantizer
if name == 'AutoRoundConfig':
from auto_round.inference.auto_quantizer import AutoRoundConfig
return AutoRoundConfig

raise AttributeError(f"auto-round has no attribute '{name}'")

from .version import __version__
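The `__getattr__` hook added above is the module-level lazy-import pattern from PEP 562; a short sketch of the intended effect:

```python
# Sketch of the effect of the lazy import above: the inference stack is only
# pulled in when the attribute is first accessed, not at `import auto_round`.
import auto_round                        # does not import auto_round.inference yet
config_cls = auto_round.AutoRoundConfig  # __getattr__ fires and performs the import
```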
15 changes: 11 additions & 4 deletions auto_round/autoround.py
@@ -475,8 +475,8 @@ def quantize_and_save(self, output_dir: str = "tmp_autoround", format: str = "au
for index in range(len(formats)):
format = formats[index]
if "auto_round" in format:
if self.sym and ("gptq" not in format and "awq" not in format):
format = format.replace('auto_round', 'auto_round:gptq')
if (self.sym and ("gptq" not in format and "awq" not in format)) or self.bits==3:
format = format.replace('auto_round', 'auto_round:auto_gptq')
formats[index] = format

# Remove duplicates from formats list
@@ -496,6 +496,13 @@ def remove_duplicates(lst):
# Save the quantized model in the specified formats
folders = []
for format in formats:
if "gptq" in format and not self.sym:
logger.warning(
"The asymmetrical kernel of the GPTQ format may result in a noticeable accuracy drop,"
" particularly for 2-bit quantization and smaller models."
" We recommend exporting to either the AutoAWQ format ( only 4 bits) or "
"the AutoRound format(2/4/8 bits)."
)
save_format_ = format.replace(":", "-").replace("_", "-")
save_folder = os.path.join(output_dir, save_format_) if len(formats) > 1 else output_dir
self.save_quantized(save_folder, format=format, inplace=inplace, **kwargs)
@@ -1598,8 +1605,8 @@ def save_quantized(self, output_dir=None, format="auto_round", inplace=True, **k
logger.warning(
"The asymmetrical kernel of the GPTQ format may result in a noticeable accuracy drop,"
" particularly for 2-bit quantization and smaller models."
" We recommend exporting to either the AutoAWQ format (4 bits) or "
"the AutoRound format (2 bits) to enhance performance."
" We recommend exporting to either the AutoAWQ format ( only 4 bits) or "
"the AutoRound format(2/4/8 bits)."
)
if "awq" in format and not self.bits == 4:
raise ValueError("The AWQ format only supports W4 quantization ")
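Taken together with the warning added above, the recommended export path for asymmetric or low-bit settings is the AutoRound format. A hedged usage sketch (the model id and output path are placeholders; argument names follow the signatures visible in this diff and the README):

```python
# Hedged sketch: quantize a small model and export it in the AutoRound format,
# which the warning above recommends over an asymmetric GPTQ export.
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRound

model_name = "facebook/opt-125m"  # placeholder model id
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

autoround = AutoRound(model, tokenizer, bits=4, group_size=128, sym=True)
autoround.quantize_and_save("./tmp_autoround", format="auto_round")
```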
19 changes: 12 additions & 7 deletions auto_round/export/export_to_autoround/export.py
@@ -23,7 +23,7 @@

import auto_round.export.export_to_autoround.qlinear_triton_act
import auto_round_extension.cuda.qlinear_tritonv2
from auto_round.utils import get_layer_names_in_block, get_module, logger, set_module, supported_layer_types
from auto_round.utils import get_module, logger, set_module, supported_layer_types
import threadpoolctl as tctl
import inspect
from tqdm import tqdm
@@ -75,15 +75,15 @@ def dynamic_import_quant_linear_for_packing(backend, bits, group_size, sym, act_

from auto_round_extension.cuda.qlinear_tritonv2 import QuantLinear
return QuantLinear
elif "auto_round" in backend and "gptq" in backend:
from auto_round.export.export_to_autoround.qlinear_triton import QuantLinear ##no g_idx
return QuantLinear
elif "auto_round" in backend and "gptq" in backend and bits in (2, 4, 8):
from auto_round.export.export_to_autoround.qlinear_triton import QuantLinear ##no g_idx
return QuantLinear
elif "awq" in backend:
from ..export_to_awq.utils import WQLinear_GEMM
return WQLinear_GEMM
elif "gptqmodel" in backend:
return auto_round_extension.cuda.qlinear_tritonv2.QuantLinear
elif "gptq" in backend and not "gptqmodel" in backend: ## have g_idx
elif "gptq" in backend and not "gptqmodel" in backend: ## have g_idx
return get_autogptq_packing_qlinear(backend, bits, group_size, sym)
else:
assert False, f"only support auto_gptq, auto_awq and auto_round backend"
@@ -190,7 +190,9 @@ def pack_layer(layer_name, model, backend):
new_layer.device = device
set_module(model, layer_name, new_layer)
qlayer = new_layer
if sym:
import auto_round.export.export_to_autoround.qlinear_triton
if sym and isinstance(QuantLinear, (auto_round.export.export_to_autoround.qlinear_triton.QuantLinear,
auto_round_extension.cuda.qlinear_tritonv2.QuantLinear)):
zp = int(zp.flatten()[0])

qlayer.to("cpu")
@@ -248,7 +250,7 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex

##if using sym, we change to gptq sym kernel to avoid compiling from auto_round source
if (kwargs.get("sym") is None or kwargs.get("sym") == True) and ("gptq" not in backend and "awq" not in backend):
backend = backend.replace('auto_round', 'auto_round:gptq')
backend = backend.replace('auto_round', 'auto_round:auto_gptq')

model = kwargs["model"]
safe_serialization = True if 'safe_serialization' not in kwargs.keys() else kwargs["safe_serialization"]
@@ -260,6 +262,9 @@
quantization_config["quant_method"] = "intel/auto-round"

quantization_config["backend"] = backend
if quantization_config["bits"]==3:
backend = "auto_round:auto_gptq"

tokenizer = kwargs.get("tokenizer", None)
processor = kwargs.get("processor", None)
extra_config = {}
1 change: 1 addition & 0 deletions auto_round/inference/__init__.py
@@ -11,4 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_round.inference.convert_model import convert_hf_model, infer_target_device, post_init

49 changes: 26 additions & 23 deletions auto_round/inference/auto_quantizer.py
@@ -155,26 +155,11 @@ def merge_quantization_configs(
loading_attr_dict = quantization_config_from_args.get_loading_attributes() \
if quantization_config_from_args is not None else None
if isinstance(quantization_config, dict):
if "auto-round" in quantization_config["quant_method"]:
if "auto-round" in quantization_config[
"quant_method"] or quantization_config_from_args.__class__.__name__ == "AutoRoundConfig":
quantization_config = AutoRoundConfig.from_dict(quantization_config)
else:
if isinstance(quantization_config_from_args, (AutoRoundConfig)):
logger.info(f"Loading quantized model in auto_round format.")
tmp_backend = quantization_config["quant_method"]
if "auto-round" not in tmp_backend and "gptq" not in tmp_backend and "awq" not in tmp_backend:
logger.error("could not convert to auto_round format, currently only supports `gptq`,`awq` or "
"`auto-round` format")
exit(-1)
target_backend = quantization_config["backend"] if "backend" in quantization_config else "auto"
if loading_attr_dict is not None and "backend" in loading_attr_dict:
target_backend = loading_attr_dict["backend"]
loading_attr_dict.pop("backend")
if "auto_round" not in target_backend:
target_backend = f"auto_round:{tmp_backend}" #
quantization_config = AutoRoundConfig.from_dict(quantization_config)
setattr(quantization_config, "backend", target_backend)
else:
quantization_config = AutoQuantizationConfig.from_dict(quantization_config) # pylint: disable=E1101
quantization_config = AutoQuantizationConfig.from_dict(quantization_config) # pylint: disable=E1101

if isinstance(quantization_config,
(GPTQConfig, AwqConfig, AutoRoundConfig)) and quantization_config_from_args is not None:
@@ -265,8 +250,8 @@ def __init__(

def post_init(self):
r"""Safety checker that arguments are correct."""
if self.bits not in [2, 4, 8]:
raise ValueError(f"Only support quantization to [2,4,8] bits but found {self.bits}")
if self.bits not in [2, 3, 4, 8]:
raise ValueError(f"Only support quantization to [2,3,4,8] bits but found {self.bits}")
if self.group_size != -1 and self.group_size <= 0:
raise ValueError("group_size must be greater than 0 or equal to -1")

@@ -278,6 +263,26 @@ def to_dict(self):
config_dict = super().to_dict()
return config_dict

@classmethod
def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
quant_method = config_dict["quant_method"]
if "auto-round" not in quant_method and "gptq" not in quant_method and "awq" not in quant_method:
raise NotImplementedError(
"Failed to convert to auto_round format. Only `gptqv1`, `awq`, and `auto-round` formats are supported."
)

if "gptq" in quant_method and "meta" in config_dict:
raise NotImplementedError(
"Failed to convert gptq format to auto_round format. Only supports `gptqv1`")

if "awq" in quant_method and config_dict.get("version", "gemm") != "gemm":
raise NotImplementedError(
"Failed to convert awq format to auto_round format. Only supports awq format with gemm version")

if "auto-round" not in quant_method:
config_dict["backend"] = f"auto_round:{quant_method}"
return super().from_dict(config_dict, return_unused_kwargs=return_unused_kwargs, **kwargs)


class AutoRoundQuantizer(HfQuantizer):
"""Quantizer of the AutoRound method, currently only triton and exllamav2 backend has been supported."""
@@ -306,7 +311,6 @@ def validate_environment(self, *args, **kwargs):
"auto-round` or install from source")

def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
self.target_device = infer_target_device(self.device_map)
if torch_dtype is None:
torch_dtype = torch.float16
elif torch_dtype != torch.float16:
@@ -330,8 +334,7 @@ class StoreAttr(object):

def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
if self.pre_quantized:
target_device = self.target_device if hasattr(self, self.target_device) else infer_target_device(
self.device_map)
target_device = infer_target_device(self.device_map)
model, used_backends = convert_hf_model(model, target_device)
self.used_backends = used_backends

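The `merge_quantization_configs` and `from_dict` changes above allow an existing GPTQ or AWQ checkpoint to be routed through auto-round's kernels by passing an `AutoRoundConfig` explicitly. A minimal sketch, assuming a hypothetical GPTQ checkpoint id:

```python
# Minimal sketch: load a pre-quantized GPTQ checkpoint through auto-round by
# supplying an AutoRoundConfig; the model id below is hypothetical.
from transformers import AutoModelForCausalLM
from auto_round import AutoRoundConfig

model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-model-GPTQ",            # hypothetical GPTQ-format checkpoint
    device_map="auto",
    quantization_config=AutoRoundConfig(),  # routes loading through auto-round
)
```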
73 changes: 68 additions & 5 deletions auto_round/inference/backend.py
@@ -126,7 +126,7 @@ def feature_num_greater_checker(in_feature, out_feature, num):
bits=[2, 4, 8],
priority=1, feature_checks=[feature_multiply_checker_32],
alias=["auto_round", "tritonv2"],
requirements=["auto-round>=0.2"]
requirements=["auto-round>=0.5.0"]
)

BackendInfos['auto_round:tritonv2_zp'] = BackendInfo(device=["cuda"], sym=[True], ## asym has accuracy issue
Expand All @@ -135,7 +135,7 @@ def feature_num_greater_checker(in_feature, out_feature, num):
bits=[2, 4, 8],
priority=1, feature_checks=[feature_multiply_checker_32],
alias=["tritonv2", "tritonv2_zp"],
requirements=["auto-round>=0.5"]
requirements=["auto-round>=0.5.0"]
)

BackendInfos['gptqmodel:marlin'] = BackendInfo(device=["cuda"], sym=[True],
Expand All @@ -145,7 +145,7 @@ def feature_num_greater_checker(in_feature, out_feature, num):
dtype=["float16", "bfloat16"],
priority=6, feature_checks=[in_output_feature_multiply_checker_32],
alias=["marlin", "gptqmodel"],
requirements=["gptqmodel>=2.0"]
requirements=["gptqmodel>=2.0"],
)

BackendInfos['gptqmodel:marlin_zp'] = BackendInfo(device=["cuda"], sym=[True],
Expand Down Expand Up @@ -504,7 +504,7 @@ def get_autogptq_infer_linear(backend, bits=4, group_size=128, sym=False):
return QuantLinear


def find_backend(target_backend: str, orig_backend: str = None) -> str | None:
def find_backend(target_backend: str, orig_backend: str = None):
"""
Finds the matching backend key based on the target backend name or its aliases.

@@ -620,7 +620,10 @@ def get_layer_backend(device, backend, orig_backend, bits, group_size, sym, in_f
try:
require_version(requirement)
except ImportError:
logger.error(f"pip install '{requirement}' ")
if "gptqmodel" in requirement:
logger.error(f"pip install -v '{requirement}' --no-build-isolation")
else:
logger.error(f"pip install '{requirement}' ")
else:
str_info = requirement()[1]
logger.error(str_info)
@@ -633,3 +636,63 @@
reverse=True)

return supported_backends[0]


def get_highest_priority_backend(bits, sym, group_size, device, packing_format):
supported_backends = []
for key in BackendInfos.keys():
backend = BackendInfos[key]
# Check if device is supported by the backend
if device not in backend.device:
continue

# Check if bit-width is supported
if bits not in backend.bits:
continue

# Check if group_size is valid (if required by backend)
        if backend.group_size is not None and group_size not in backend.group_size:
continue

# Check if symmetric/asymmetric quantization is supported
if sym not in backend.sym:
continue

# Check if the format is convertible when packing formats differ
if packing_format == backend.packing_format or packing_format in backend.convertable_format:
pass
else:
continue
supported_backends.append(key)

if len(supported_backends) > 0:

supported_backends = sorted(supported_backends,
key=lambda support_backend: BackendInfos[support_backend].priority,
reverse=True)
return supported_backends[0]
else:
return None


def process_requirement(requirements: list):
gptqmodel_requirements = None
other_requirements = []
for requirement in requirements:
if "gptqmodel" in requirement:
gptqmodel_requirements = requirement
else:
other_requirements.append(requirement)

infos = []

if gptqmodel_requirements is not None:
infos.append(f"pip install -v '{gptqmodel_requirements}' --no-build-isolation")
infos.append(f"pip install 'numpy<2.0'")

other_info = f"pip install"
if len(other_requirements) > 0:
for requirement in other_requirements:
other_info += f" {requirement}"
infos.append(other_info)
return infos
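A short usage sketch for `process_requirement` above, using requirement strings taken from the `BackendInfos` entries in this file:

```python
# Usage sketch: turn backend requirement strings into installation hints.
hints = process_requirement(["gptqmodel>=2.0", "auto-round>=0.5.0"])
for hint in hints:
    print(hint)
# Expected output, per the logic above:
#   pip install -v 'gptqmodel>=2.0' --no-build-isolation
#   pip install 'numpy<2.0'
#   pip install auto-round>=0.5.0
```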