
[quantization][DRAFT] Disk space consumption improvements for full model quantization#495

Closed
stamalakhov wants to merge 1 commit into Samsung:main from stamalakhov:quant_full_model_impr_size

Conversation

@stamalakhov
Contributor

@stamalakhov stamalakhov commented Feb 16, 2026

This PR fixes how the static `causal_mask`/`position_embeddings` are propagated through the layers, to save disk space.

It precomputes the static `causal_mask`/`position_embeddings` once, for use in llama/quant_decoder_layer, so that each quantized decoder layer is not populated with its own copy of these statically computed parameters.

With this PR, the circle model for HuggingFaceTB/SmolLM2-135M-Instruct is just 105 MiB (vs. 300 MiB with #492).
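The sharing idea can be sketched roughly like this (all names here are mine for illustration, not the PR's actual API): the mask is built once at the model level and every layer slices the shared buffer instead of materializing, and later serializing, its own copy.

```python
import torch

# Hypothetical sketch: build the static additive causal mask once at the
# model level. Positions a token may not attend to are set to -inf; the
# rest stay 0. Every quantized decoder layer references this shared buffer
# instead of serializing a private copy into the circle model.
def build_static_causal_mask(seq_len: int) -> torch.Tensor:
    mask = torch.full((seq_len, seq_len), float("-inf"))
    return torch.triu(mask, diagonal=1)

shared_mask = build_static_causal_mask(8)
# A layer running on a shorter sequence just slices the shared buffer:
layer_view = shared_mask[:4, :4]
```

One shared (seq_len, seq_len) buffer amortized over N layers is what turns the ~300 MiB model into ~105 MiB in the numbers above.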

Draft: #436

TICO-DCO-1.0-Signed-off-by: s.malakhov s.malakhov@partner.samsung.com

@stamalakhov stamalakhov self-assigned this Feb 16, 2026
@stamalakhov stamalakhov changed the title [quantization][DRAFT] Improvements in disk space for full model quantization [quantization][DRAFT] Disk space consumption improvements for full model quantization Feb 16, 2026
@stamalakhov
Contributor Author

@mhs4670go
Could you please take a look? Right now, increasing seq_len increases the size of the final circle model considerably 😢

Comment on lines +158 to +160
causal_mask = self.layers[0].wrapped.get_attention_mask_for(hidden_states)
causal_mask = self._fq(causal_mask, self.obs_causal_mask)
position_embeddings = self.layers[0].wrapped.get_position_embeddings_for(
Contributor

The current wrappers for a decoder layer and an attention create their own masks because they have self-contained attributes.

How about creating its own mask and embeddings instead of using the first layer's? It needs some duplicated code, but it removes the dependency on the first layer.

Contributor Author

> How about creating its own mask and embeddings instead of using the first layer's? It needs some duplicated code, but it removes the dependency on the first layer.

@mhs4670go
Ok. I'll fix it.

Contributor Author

@mhs4670go
Fixed.

@stamalakhov stamalakhov force-pushed the quant_full_model_impr_size branch 5 times, most recently from 19aca2b to 5442616 Compare February 20, 2026 12:58
) -> Union[Tuple, CausalLMOutputWithPast]:
# fixed input size, due to position_ids fixed
orig_len = input_ids.shape[-1]
input_ids = fix_inputs(self, self.tokenizer, input_ids)
Contributor

Now, fix_inputs can be removed.

# to prevent introduction of attention_mask as a parameter let's use preset attention_mask
L = hidden_states.size(1)
attention_mask = self._slice_causal(L, hidden_states.device)
if attention_mask is None or attention_mask.dtype == torch.bool:
Contributor

Why does this condition come again?

Contributor Author

@mhs4670go
Ahhh. Sorry.

  • The check was recently removed from quant_decoder_layer.py because the causal_mask received from transformers' modeling_llama.py was float; to get a fully integer model, line 206 was removed.
  • This draft uses the quantized causal_mask from quant_model.py, so the check can be restored. That makes it possible to convert a decoder layer via tico.convert(layer, (inp,)) without a causal_mask in the parameters.
  • If it is left as is (no check), every decoder layer gets populated with its own attention_mask, which consumes disk space.
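The restored condition described above amounts to something like the following sketch (function and variable names are mine, not the PR's): the shared pre-quantized float mask is reused, and a local mask is only built when none was supplied or a boolean mask would need conversion.

```python
import torch

# Hypothetical sketch of the guard: rebuild the additive mask only when no
# mask was passed in, or when a boolean mask needs converting to the
# additive float form; otherwise reuse the shared pre-quantized mask.
def resolve_attention_mask(attention_mask, seq_len: int) -> torch.Tensor:
    if attention_mask is None or attention_mask.dtype == torch.bool:
        # Fall back to building a local additive mask for this layer only.
        return torch.triu(
            torch.full((seq_len, seq_len), float("-inf")), diagonal=1
        )
    return attention_mask  # shared float mask: no per-layer copy is created

shared = torch.zeros(4, 4)  # stand-in for the pre-quantized float mask
reused = resolve_attention_mask(shared, 4)
```

Without the guard, every layer takes the fallback branch and bakes its own mask constant into the exported model, which is exactly the disk-space blow-up the PR avoids.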

Comment on lines +214 to +215
self._fq(cos.unsqueeze(1), self.obs_cos),
self._fq(sin.unsqueeze(1), self.obs_sin),
Contributor

Is this change for constant folding of unsqueeze?

Contributor Author

@mhs4670go
Ahh. Actually, these lines were introduced to match quant_model.py and keep the code consistent. It turned out that the transforms done inside quant_attn produced additional constants during tracing. Originally these lines were located at lines 164–165 of quant_attn.py (you can find them in this draft):

      cos_u = cos.unsqueeze(unsqueeze_dim)
      sin_u = sin.unsqueeze(unsqueeze_dim)

So I combined all of the transforms and applied them once, to prevent them from being duplicated into every layer.
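The point about duplicated constants can be illustrated with a minimal sketch (shapes are my assumption for illustration): applying the unsqueeze once, right before fake-quantization, means tracing sees a single transformed tensor rather than re-emitting the reshaped constant inside every traced attention call.

```python
import torch

# Sketch: rotary-embedding tables of shape (batch, seq, head_dim).
cos = torch.randn(1, 8, 16)
sin = torch.randn(1, 8, 16)

# Apply the unsqueeze once, up front (as the fused version does), instead
# of inside each traced attention call, so tracing does not materialize a
# fresh reshaped constant per decoder layer.
cos_u = cos.unsqueeze(1)  # (batch, 1, seq, head_dim)
sin_u = sin.unsqueeze(1)
```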

Contributor Author

@mhs4670go
The latest version of quant_attn.py doesn't use unsqueeze_dim, so it can be removed here as well.

Comment on lines +164 to +165
if hasattr(cur_layer, "copy_quantizers"):
cur_layer.copy_quantizers(q_m.wrapped.model.wrapped)
Contributor

@mhs4670go mhs4670go Feb 23, 2026


Hmm, how about introducing an API that copies observers only when it's really needed? Wrappers having a copy_quantizers method doesn't seem proper. In this script we can export just the full qmodel instead of all the layers.

The script came about because we need the static buffers to be shared, so exporting a single decoder can be done in other scripts.

Contributor Author

@mhs4670go
Sure. It was just an attempt to keep all the layers as they are in the fully quantized model. I'll remove it. Thank you!

@stamalakhov stamalakhov force-pushed the quant_full_model_impr_size branch 9 times, most recently from 003d8f4 to 5ee8d8c Compare February 26, 2026 10:17
@stamalakhov stamalakhov force-pushed the quant_full_model_impr_size branch 4 times, most recently from 28a32e2 to 6ac97c0 Compare February 27, 2026 08:01
…el quantization

This PR quantizes the full `Llama` model and converts it to circle format.

TICO-DCO-1.0-Signed-off-by: s.malakhov <s.malakhov@partner.samsung.com>
@stamalakhov
Contributor Author

It's now merged, so we can close this one.

@stamalakhov stamalakhov deleted the quant_full_model_impr_size branch February 27, 2026 11:26