support rtn via iters==0 by wenhuach21 · Pull Request #510 · intel/auto-round · GitHub

support rtn via iters==0 #510

Merged · 9 commits · Apr 11, 2025
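
In short: `iters=0` is now accepted and routes `quantize()` to a new `quantize_rtn()` fast path that skips the gradient-based rounding/min-max tuning entirely and applies plain round-to-nearest (RTN) quantization layer by layer. A minimal usage sketch, assuming the public constructor arguments visible in the diff below (the model name and output folder are placeholders only):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRound

model_name = "facebook/opt-125m"  # placeholder model, not part of this PR
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# iters=0 selects the RTN path added in this PR; any positive value keeps the
# original tuning behaviour (where lr still defaults to 1/iters).
ar = AutoRound(model, tokenizer, bits=4, group_size=128, iters=0)
model, layer_config = ar.quantize()
# or: ar.quantize_and_save("tmp_autoround", format="auto_round")
```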
52 changes: 43 additions & 9 deletions auto_round/autoround.py
@@ -54,7 +54,7 @@
find_matching_blocks, is_debug_mode,
TORCH_VERSION_AT_LEAST_2_6,
supported_layer_types,
get_layer_features,
get_layer_features, set_module,
)
from .low_cpu_mem.utils import get_layers_before_block

@@ -192,10 +192,13 @@ def __init__(
self.nblocks = nblocks
self.dataset = dataset
self.iters = iters
if self.iters <= 0:
logger.warning("iters must be positive, reset it to 200")
if self.iters < 0:
logger.warning("`iters` must be non-negative, reset it to 200")
self.iters = 200
self.lr = lr or (1.0 / self.iters) ##must after iter setting
if self.iters == 0:
self.lr = 5e-3
else:
self.lr = lr or (1.0 / self.iters) ##must after iter setting
self.minmax_lr = minmax_lr or self.lr

self.data_type = data_type
@@ -360,7 +363,7 @@ def _dq_check(self):
self.super_bits = gguf_config["super_bits"] if self.super_bits is None else self.super_bits
self.super_group_size = gguf_config["super_group_size"] \
if self.super_group_size is None else self.super_group_size

def check_configs(self):

"""Checks if the configurations are valid.
@@ -375,7 +378,7 @@ def check_configs(self):
assert self.act_group_size == -1 or self.act_group_size >= 1, \
"only supports positive group_size or -1(per channel)"
assert self.batch_size > 0, "batch size must be positive"
assert self.iters > 0, "iters must be positive"
assert self.iters >= 0, "iters must be non-negative"
assert self.seqlen > 0, "seqlen must be positive"
assert self.nblocks > 0, "nblocks must be positive"
assert self.gradient_accumulate_steps > 0, "gradient accumulate step must be positive"
@@ -475,7 +478,7 @@ def quantize_and_save(self, output_dir: str = "tmp_autoround", format: str = "au
for index in range(len(formats)):
format = formats[index]
if "auto_round" in format:
if (self.sym and ("gptq" not in format and "awq" not in format)) or self.bits==3:
if (self.sym and ("gptq" not in format and "awq" not in format)) or self.bits == 3:
format = format.replace('auto_round', 'auto_round:auto_gptq')
formats[index] = format

@@ -506,18 +509,49 @@ def remove_duplicates(lst):
save_format_ = format.replace(":", "-").replace("_", "-")
save_folder = os.path.join(output_dir, save_format_) if len(formats) > 1 else output_dir
self.save_quantized(save_folder, format=format, inplace=inplace, **kwargs)

folders.append(save_folder)

return model, folders

@torch.inference_mode
def quantize_rtn(self):
if self.amp:
self.model.to(self.amp_dtype)
self.model.to("cpu")
all_to_quantized_module_names = []
for n, m in self.model.named_modules():
if check_to_quantized(m):
all_to_quantized_module_names.append(n)
pbar = tqdm(all_to_quantized_module_names)

for name in pbar:
pbar.set_description(f"Quantizing {name}")
m = get_module(self.model, name)

m.to(self.device)
m = WrapperLinear(m, enable_minmax_tuning=False, enable_norm_bias_tuning=False, enable_round_tuning=False)
m = m.unwrapper({})
m.to("cpu")
if self.is_packing_immediate:
from auto_round.export import PACKING_LAYER_WITH_FORMAT
if check_to_quantized(m):
target_backend = self.formats[0].split(":")[0] if ":" in self.formats[0] else self.formats[0]
PACKING_LAYER_WITH_FORMAT[target_backend](name, self.model, self.formats[0])
else:
set_module(self.model, name, m)
self.quantized = True
return self.model, self.layer_config

def quantize(self):
"""Quantize the model and return the quantized model along with layer configurations.
the entry of AutoRound.

Returns:
The quantized model and layer configurations.
"""
if self.iters == 0:
return self.quantize_rtn()

if bool(self.quant_block_list):
all_blocks = self.quant_block_list
@@ -585,7 +619,7 @@ def quantize(self):
quantized_layers.append(n)
else:
unquantized_layers.append(n)
elif hasattr(m, "scales") or hasattr(m, "scale"): ##packing_immediately
elif hasattr(m, "scales") or hasattr(m, "scale"): ##packing_immediately
quantized_layers.append(n)
summary_info = (
f"Summary: quantized {len(quantized_layers)}/{len(quantized_layers) + len(unquantized_layers)} in the model"
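For context on what the new `quantize_rtn()` path computes per layer: instead of optimizing a rounding offset, it applies a data-free round-to-nearest quant-dequant to each matched module. A rough, self-contained sketch of an asymmetric per-group RTN quant-dequant (illustrative only; it mirrors the clamped min/max seen in the `wrapper.py` hunk below, but it is not the library's exact quant function, which also handles symmetric and other data types):

```python
import torch

def rtn_qdq(weight: torch.Tensor, bits: int = 4, group_size: int = 128) -> torch.Tensor:
    """Asymmetric round-to-nearest quant-dequant for a 2-D weight (illustration only).
    Assumes in_features is divisible by group_size."""
    out_features, in_features = weight.shape
    w = weight.reshape(out_features, -1, group_size)
    wmin = torch.clamp(w.min(dim=-1, keepdim=True).values, max=0)  # same clamps as WrapperLinear
    wmax = torch.clamp(w.max(dim=-1, keepdim=True).values, min=0)
    scale = torch.clamp((wmax - wmin) / (2 ** bits - 1), min=1e-5)
    zp = torch.round(-wmin / scale)
    q = torch.clamp(torch.round(w / scale) + zp, 0, 2 ** bits - 1)
    return ((q - zp) * scale).reshape(out_features, in_features)

w = torch.randn(8, 256)
print((rtn_qdq(w) - w).abs().max())  # error stays within roughly one quantization step
```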
2 changes: 1 addition & 1 deletion auto_round/inference/backend.py
@@ -113,7 +113,7 @@ def feature_num_greater_checker(in_feature, out_feature, num):
BackendInfos['auto_gptq:cuda'] = BackendInfo(device=["cuda"], sym=[True, False],
packing_format="triton_zp",
bits=[2, 3, 4, 8], group_size=None,
priority=1, feature_checks=[feature_multiply_checker_32],
priority=0, feature_checks=[feature_multiply_checker_32],
alias=["auto_gptq:cuda"],
dtype=["float16"],
convertable_format=["triton_zp"],
2 changes: 1 addition & 1 deletion auto_round/inference/convert_model.py
@@ -470,7 +470,7 @@ def post_init(model, used_backends):
for l in data_types[1:]:
common &= set(l)
common = list(common)
if str(model.dtype).split('.')[-1] not in common:
if len(common)>0 and str(model.dtype).split('.')[-1] not in common:
if common[0] == "float16":
model = model.to(torch.float16)
logger.warning("force model to float16")
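The added `len(common) > 0` guard matters when the used backends report no shared dtype: `common` is then an empty list and the old code would fail on `common[0]`. A standalone sketch of the intersection logic, assuming the same list-of-lists shape for `data_types` as in the hunk above:

```python
def common_dtypes(data_types):
    # Intersect the dtype lists reported by each backend (sketch of the logic above).
    common = set(data_types[0])
    for l in data_types[1:]:
        common &= set(l)
    return list(common)

print(common_dtypes([["float16", "bfloat16"], ["float16"]]))  # ['float16'] -> may force a cast
print(common_dtypes([["float16"], ["bfloat16"]]))             # [] -> cast is now skipped
```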
4 changes: 2 additions & 2 deletions auto_round/utils.py
@@ -522,10 +522,10 @@ def check_to_quantized(config):
False otherwise.
"""
if isinstance(config, dict):
bits = int(config.get("bits", 4))
bits = int(config.get("bits", 16))
act_bits = int(config.get("act_bits", 16))
else:
bits = int(config.bits) if hasattr(config, "bits") else 4
bits = int(config.bits) if hasattr(config, "bits") else 16
act_bits = int(config.act_bits) if hasattr(config, "act_bits") else 16

return bits <= 8 or act_bits <= 8
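The new 16-bit default changes how layer configs that omit `bits` are classified: they used to default to 4 bits and count as to-be-quantized, whereas now they are skipped unless their activations are quantized. A quick illustration of the updated predicate for the dict case (a sketch mirroring the code above):

```python
def check_to_quantized(config: dict) -> bool:
    # Dict branch of the updated predicate (sketch).
    bits = int(config.get("bits", 16))
    act_bits = int(config.get("act_bits", 16))
    return bits <= 8 or act_bits <= 8

print(check_to_quantized({"bits": 4}))      # True  (weight-quantized layer)
print(check_to_quantized({}))               # False (was True with the old default of 4)
print(check_to_quantized({"act_bits": 8}))  # True  (activation quantization alone qualifies)
```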
14 changes: 8 additions & 6 deletions auto_round/wrapper.py
@@ -58,7 +58,8 @@ class WrapperLinear(torch.nn.Module):
device (str): Device on which to run computations (e.g., 'cpu' or 'cuda').
"""

def __init__(self, orig_layer, enable_minmax_tuning=True, enable_norm_bias_tuning=False, device='cpu', **kwargs):
def __init__(self, orig_layer, enable_minmax_tuning=True, enable_norm_bias_tuning=False, device='cpu',
enable_round_tuning=True, **kwargs):
"""Initializes the WrapperLinear module.

Args:
@@ -72,6 +73,7 @@ def __init__(self, orig_layer, enable_minmax_tuning=True, enable_norm_bias_tunin
self.output_device = device
self.device = self.orig_layer.tuning_device if hasattr(self.orig_layer, "tuning_device") else device
self.enable_minmax_tuning = enable_minmax_tuning
self.enable_round_tuning = enable_round_tuning
self.enable_norm_bias_tuning = enable_norm_bias_tuning and (orig_layer.bias is not None)
self.enable_act_quant = self.orig_layer.act_bits <= 8 or self._check_act_quantization(
self.orig_layer.act_data_type)
@@ -108,7 +110,7 @@ def _init_tuning_params_and_quant_func(self):
weight_reshape = reshape_and_pad_tensor(orig_weight.data, orig_layer.group_size)
self.weight_min = torch.clamp(weight_reshape.min(1)[0], max=0)
self.weight_max = torch.clamp(weight_reshape.max(1)[0], min=0)
self._init_params("value", p_dtype, weight_reshape.shape, 0, True)
self._init_params("value", p_dtype, weight_reshape.shape, 0, self.enable_round_tuning)
# Min-max scale initialization
shape = get_scale_shape(orig_weight, orig_layer.group_size)
self._init_params("min_scale", p_dtype, shape, 1.0, self.enable_minmax_tuning)
@@ -166,7 +168,7 @@ def _qdq_weight(self, value, min_scale, max_scale):
weight = self.orig_layer.get_weight().to(self.device)
if isinstance(self.orig_layer, transformers.modeling_utils.Conv1D):
weight = weight.t()

quant_kwargs = {}
if hasattr(self.orig_layer, "super_bits"):
quant_kwargs["super_bits"] = self.orig_layer.super_bits
@@ -185,7 +187,7 @@ def _qdq_weight(self, value, min_scale, max_scale):
data_type=self.data_type,
q_scale_thresh=self.q_scale_thresh,
**quant_kwargs
)
)
weight_q = weight_q.to(weight.dtype)
if isinstance(self.orig_layer, transformers.modeling_utils.Conv1D):
weight_q = weight_q.t()
@@ -257,12 +259,12 @@ def _set_dict_attr(attr_dict, attr_name):
else:
name = "w_" + key
setattr(self.orig_layer, name, attr_dict[key].to("cpu"))

if isinstance(scale, dict):
_set_dict_attr(scale, "scale")
else:
self.orig_layer.scale = scale.reshape(shape[0], -1).to("cpu")

if zp is not None:
if isinstance(zp, dict):
_set_dict_attr(zp, "zp")
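What the new `enable_round_tuning` switch controls, in isolation: AutoRound's learnable rounding offset (`value`) is added to the scaled weight before rounding, and freezing it at zero, as the `iters == 0` path does, reduces the operation to plain round-to-nearest. A tiny self-contained illustration, not library code (the tuned offsets are made-up numbers):

```python
import torch

w_over_scale = torch.tensor([1.3, 2.7, 0.4])    # weight already divided by its scale
frozen_value = torch.zeros_like(w_over_scale)   # enable_round_tuning=False: offset stays 0
tuned_value = torch.tensor([0.3, -0.3, 0.2])    # hypothetical offsets tuning might learn

print(torch.round(w_over_scale + frozen_value))  # tensor([1., 3., 0.]) -> pure RTN
print(torch.round(w_over_scale + tuned_value))   # tensor([2., 2., 1.]) -> tuned rounding
```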
42 changes: 42 additions & 0 deletions test/_test_helpers.py
@@ -7,3 +7,45 @@ def is_pytest_mode_compile():

def is_pytest_mode_lazy():
return pytest.mode == "lazy"


def model_infer(model, tokenizer, apply_chat_template=False):
prompts = [
"Hello,my name is",
# "The president of the United States is",
# "The capital of France is",
# "The future of AI is",
]
if apply_chat_template:
texts = []
for prompt in prompts:
messages = [
{"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
texts.append(text)
prompts = texts

inputs = tokenizer(prompts, return_tensors="pt", padding=False, truncation=True)

outputs = model.generate(
input_ids=inputs["input_ids"].to(model.device),
attention_mask=inputs["attention_mask"].to(model.device),
do_sample=False, ## change this to follow official usage
max_new_tokens=5
)
generated_ids = [
output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs["input_ids"], outputs)
]

decoded_outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

for i, prompt in enumerate(prompts):
print(f"Prompt: {prompt}")
print(f"Generated: {decoded_outputs[i]}")
print("-" * 50)
return decoded_outputs[0]
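
A hypothetical call site for the helper (model name and import path are placeholders, not part of this PR):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from _test_helpers import model_infer  # assumes the test runs from the test/ directory

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
text = model_infer(model, tokenizer, apply_chat_template=False)
assert isinstance(text, str)
```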