diff --git a/backends/openvino/quantizer/__init__.py b/backends/openvino/quantizer/__init__.py index e819aaf5159..2c76bbc5b7a 100644 --- a/backends/openvino/quantizer/__init__.py +++ b/backends/openvino/quantizer/__init__.py @@ -1,9 +1,10 @@ -from .llm_compression import apply_nncf_data_aware_compression +from .llm_compression import apply_nncf_data_aware_compression, apply_nncf_data_aware_compression_from_builder from .quantizer import OpenVINOQuantizer, QuantizationMode, quantize_model __all__ = [ "OpenVINOQuantizer", "quantize_model", "QuantizationMode", + "apply_nncf_data_aware_compression_from_builder", "apply_nncf_data_aware_compression", ] diff --git a/backends/openvino/quantizer/llm_compression.py b/backends/openvino/quantizer/llm_compression.py index 1737f638bf9..a6f98c58e07 100644 --- a/backends/openvino/quantizer/llm_compression.py +++ b/backends/openvino/quantizer/llm_compression.py @@ -6,9 +6,12 @@ # mypy: disable-error-code=import-not-found -from typing import Tuple +import logging +import random +from typing import Optional, Tuple import torch +from datasets import load_dataset # type: ignore[import-untyped] from executorch.extension.llm.export.builder import LLMEdgeManager from torchao.quantization.pt2e.quantizer import Quantizer @@ -18,10 +21,20 @@ except ImportError: raise ImportError("Please install nncf via backends/openvino/requirements.txt") +TASK_TO_HF_DATASET = { + "wikitext": { + "path": "Salesforce/wikitext", + "name": "wikitext-2-raw-v1", + "split": "train", + }, +} + -# This code is adapted from https://github.com/pytorch/executorch/blob/0c54fd0483314da173f8e14d63d2ed9591c7133a/extension/llm/export/builder.py#L278 def get_calibration_data( - module: torch.fx.GraphModule, tokenizer, prompts: str, max_len: int + tokenizer, + data, + nsamples: int, + seqlen: int, ): """ This method is used to obtain calibration data from a prompt so that the algorithm
Currently, this method is only tested with Llama models. """ - # TODO: change criteria & support batch inputs if necessary - pos = 0 - token_list = tokenizer.encode(prompts, bos=True, eos=False) - - with torch.no_grad(): - while token_list[-1] != tokenizer.eos_id and pos < max_len: - logits = module( - torch.full((1, 1), token_list[pos]), - {"input_pos": torch.tensor((pos,))}, - ) - pos += 1 - if pos >= len(token_list): - token_list.append(torch.argmax(logits[:], dim=-1).item()) - token_list = [ - ( - torch.tensor(pos, dtype=torch.int64), - token, - ) - for pos, token in enumerate(token_list) - ] - return token_list + # Copied from optimum.gptq.data.get_wikitext2 with added computation of `limit` variable: + limit = nsamples * seqlen // 4 # ~1k for 128 samples with seqlen=32 to be aligned with optimum + text = "".join([" \n" if s == "" else s for s in data["text"][:limit]]) + + enc = tokenizer.encode(text, bos=True, eos=False) + dataset = [] + for _ in range(nsamples): + i = random.randint(0, len(enc) - seqlen - 1) + j = i + seqlen + inp = enc[i:j] + dataset.extend([(token, pos) for pos, token in enumerate(inp)]) + return dataset def transform_fn(token_pos_map: Tuple[int, int]): @@ -60,74 +64,147 @@ def transform_fn(token_pos_map: Tuple[int, int]): :param token_pos_map: This input contains the position and its token ID """ inputs = ( - torch.tensor([[token_pos_map[1]]]), - {"input_pos": torch.tensor([token_pos_map[0]])}, + torch.tensor([[token_pos_map[0]]]), + {"input_pos": torch.tensor([token_pos_map[1]])}, ) return inputs -def apply_nncf_data_aware_compression( - builder_exported: LLMEdgeManager, +def _build_nncf_calibration_dataset( + calibration_task: Optional[str], + tokenizer, + seq_len: Optional[int], + subset_size: Optional[int], + awq: bool, + scale_estimation: bool, +): + if not (awq or scale_estimation): + return None + + if subset_size is None or subset_size <= 0: + raise ValueError("subset_size must be a positive integer when calibration is enabled.") 
+ + has_calibration_inputs = ( + calibration_task is not None and tokenizer is not None and seq_len is not None + ) + + # Scale estimation requires full calibration setup. + if scale_estimation and not has_calibration_inputs: + missing_params = [] + if calibration_task is None: + missing_params.append("calibration_task") + if tokenizer is None: + missing_params.append("tokenizer") + if seq_len is None: + missing_params.append("seq_len") + raise ValueError( + "Missing required calibration parameter(s): " + + ", ".join(missing_params) + + ". Please provide calibration_task, tokenizer, and seq_len." + ) + + if not has_calibration_inputs: + return None + + if calibration_task not in TASK_TO_HF_DATASET: + raise ValueError( + f"Unsupported calibration task: {calibration_task}. Supported tasks are: {list(TASK_TO_HF_DATASET.keys())}" + ) + + dataset = load_dataset(**TASK_TO_HF_DATASET[calibration_task]) + calibration_data = get_calibration_data( + tokenizer, + dataset, + subset_size, + seq_len, + ) + + return nncf.Dataset( + calibration_data, + transform_func=transform_fn, + ) + + +def apply_nncf_data_aware_compression_from_builder( + builder: LLMEdgeManager, quantizer: Quantizer, awq: bool, scale_estimation: bool, ) -> LLMEdgeManager: + """ + Applies NNCF data-aware weight compression to the exported LLM graph using the builder's configuration. + :param builder: LLMEdgeManager containing the pre-autograd graph module and calibration configuration. + :param quantizer: TorchAO quantizer to use for compression. + :param awq: If True, enables Activation-aware Weights Quantization (AWQ). + :param scale_estimation: If True, enables NNCF's scale estimation algorithm. + The tokenizer is loaded from the builder's tokenizer_path when available. + Calibration task, sequence length and subset size use the defaults of apply_nncf_data_aware_compression. + :return: LLMEdgeManager with compressed pre-autograd graph module.
+ """ + tokenizer_path = builder.tokenizer_path + tokenizer = get_tokenizer(tokenizer_path) if tokenizer_path is not None else None + compressed_model = apply_nncf_data_aware_compression( + model=builder.pre_autograd_graph_module, + quantizer=quantizer, + awq=awq, + scale_estimation=scale_estimation, + tokenizer=tokenizer, + ) + builder.pre_autograd_graph_module = compressed_model + return builder + + +def apply_nncf_data_aware_compression( + model: torch.fx.GraphModule, + quantizer: Quantizer, + awq: bool, + scale_estimation: bool, + calibration_task: Optional[str] = "wikitext", + tokenizer: Optional[str] = None, + seq_len: Optional[int] = 32, + subset_size: Optional[int] = 128, +) -> torch.fx.GraphModule: """ Applies NNCF data-aware weight compression to the exported LLM graph. Uses the builder's tokenizer and calibration prompt to generate token-level calibration data, then runs `nncf.experimental.torch.fx.compress_pt2e` with the given quantizer and optional AWQ / scale estimation enabled. - :param builder_exported: LLMEdgeManager containing the FX graph, tokenizer path, - calibration prompt, and max sequence length. + :param model: torch.fx.GraphModule to be compressed. :param quantizer: TorchAO quantizer to use for compression. :param awq: If True, enables Activation-aware Weights Quantization (AWQ). :param scale_estimation: If True, enables NNCF's scale estimation algorithm. - :return: The updated LLMEdgeManager with compressed torch FX model + :param calibration_task: Optional task key for calibration dataset when passing + GraphModule directly (e.g. "wikitext", "c4", "gsm8k"). + :param tokenizer: Optional tokenizer when passing GraphModule directly. + :param seq_len: Optional max sequence length of each calibration prompt when passing GraphModule directly. + :param subset_size: Optional max number of samples from the calibration dataset to use for calibration. + Default is 128. This is high because it is token-level data, not sample-level. 
The number of tokens is much higher than the number of samples. + :return: Compressed torch FX model. """ - nncf_calibration_data = None - if ( - builder_exported.calibration_seq_length is not None - and builder_exported.calibration_data is not None - and builder_exported.tokenizer_path is not None - and (awq or scale_estimation) - ): - tokenizer = get_tokenizer(builder_exported.tokenizer_path) - nncf_calibration_data = nncf.Dataset( - get_calibration_data( - builder_exported.pre_autograd_graph_module, # type: ignore[arg-type] - tokenizer, - builder_exported.calibration_data, - builder_exported.calibration_seq_length, - ), - transform_func=transform_fn, - ) + if not quantizer: + logging.info("No quantizer provided, skipping NNCF compression.") + return model + + nncf_calibration_data = _build_nncf_calibration_dataset( + calibration_task=calibration_task, + tokenizer=tokenizer, + seq_len=seq_len, + subset_size=subset_size, + awq=awq, + scale_estimation=scale_estimation, + ) - # AWQ can work without a dataset as well. - if scale_estimation and not nncf_calibration_data: - missing_params = [] - if builder_exported.calibration_data is None: - missing_params.append("calibration_data") - if builder_exported.calibration_seq_length is None: - missing_params.append("calibration_seq_length") - if builder_exported.tokenizer_path is None: - missing_params.append("tokenizer_path") - if missing_params: - msg = ( - "Missing required calibration parameter(s): " - + ", ".join(missing_params) - + ". Please provide calibration_data, calibration_seq_length, and tokenizer_path." - ) - raise ValueError(msg) - - builder_exported.pre_autograd_graph_module = ( - nncf.experimental.torch.fx.compress_pt2e( - builder_exported.pre_autograd_graph_module, - quantizer=quantizer, - dataset=nncf_calibration_data, - awq=awq, - scale_estimation=scale_estimation, - ) + # Since it is a static model, each input is a single token. 
+ total_calibration_dataset_size = subset_size * seq_len + compressed_model = nncf.experimental.torch.fx.compress_pt2e( + model, + quantizer=quantizer, + dataset=nncf_calibration_data, + awq=awq, + scale_estimation=scale_estimation, + subset_size=total_calibration_dataset_size, ) - return builder_exported + + return compressed_model diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 0394bf7f320..8cc7315c652 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -997,11 +997,14 @@ def _to_edge_and_lower_llama_openvino( for partitioner in partitioners: logging.info(f"--> {partitioner.__class__.__name__}") - from executorch.backends.openvino.quantizer import apply_nncf_data_aware_compression + from executorch.backends.openvino.quantizer import apply_nncf_data_aware_compression_from_builder logging.info(f"Applying AWQ = {awq}, Scale Estimation = {scale_estimation}") - builder = apply_nncf_data_aware_compression( - builder_exported, quantizers[0], awq, scale_estimation + quantizer = None + if quantizers: + quantizer = quantizers[0] + builder = apply_nncf_data_aware_compression_from_builder( + builder_exported, quantizer, awq, scale_estimation ) builder = builder.to_edge_transform_and_lower(partitioners)