Gemma 3 support for packed sequence training with FlashAttention 2?

Hi everyone,

I have a question about finetuning Gemma 3 models on Hugging Face with sequence packing and FlashAttention 2.

From the recent Hugging Face “packing + FA2” work, it looks like some architectures support training with:

  • Packed sequences (multiple samples concatenated into a single long sequence)

  • Variable‑length attention via cumulative sequence lengths (e.g. cu_seqlens) instead of a classic 0/1 padding mask

  • Proper attention isolation between individual sequences inside the same pack, so that tokens from one sample cannot attend to tokens from another
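For context, the "cumulative sequence lengths" mentioned above are the boundary offsets that FlashAttention's variable-length kernels consume. A minimal sketch of how they relate to the per-sample lengths in a pack (a standalone illustrative helper, not Transformers-internal code):

```python
# Sketch: derive FlashAttention-style cumulative sequence lengths
# (cu_seqlens) from the lengths of the samples packed into one sequence.
# cu_seqlens[i] is the start offset of sample i; the last entry is the
# total token count. Illustrative helper, not an official API.

def cu_seqlens_from_lengths(lengths):
    """Return [0, l0, l0+l1, ...] boundary offsets for varlen attention."""
    cu = [0]
    for n in lengths:
        cu.append(cu[-1] + n)
    return cu

# Three samples of lengths 3, 5, 2 packed into one 10-token sequence:
print(cu_seqlens_from_lengths([3, 5, 2]))  # [0, 3, 8, 10]
```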

I would like to clarify how this applies specifically to Gemma 3 models on Hugging Face:

  1. Do the current Gemma 3 implementations (e.g. gemma-3-... on HF) officially support training with packed sequences when using FlashAttention 2?

  2. If yes, what is the expected interface?

    • Should the trainer/collator provide a standard binary attention_mask of shape [batch_size, seq_len], and the FA2 integration internally derives cumulative sequence lengths?

    • Or is there a supported variant where the model is driven by cumulative sequence lengths / cu_seqlens (e.g. via position_ids or another field) and no binary mask is passed?

  3. Finally, is there any official example or recommended configuration (Trainer/TRL/SFTTrainer + collator) that demonstrates:

    • Gemma 3 + FlashAttention2

    • Sequence packing

    • Correct per‑sequence isolation in a packed batch

I am fine implementing a custom collator (e.g. flattening multiple examples into a single sequence and computing the right metadata), but I would like to align with the intended / supported behavior for Gemma 3 rather than relying on assumptions from other architectures.

Thanks in advance!


Hi @SyrineM

  1. Gemma-family models are officially supported for sequence packing with FlashAttention 2 in Hugging Face Transformers, as long as the model exposes position_ids, which Gemma does.

  2. In a standard Hugging Face Trainer flow, your collator should provide flattened input_ids together with position_ids that reset to 0 at the start of each packed sequence. The model’s FlashAttention‑2 integration uses these position resets to derive cu_seqlens internally, so you do not need to compute or pass cumulative sequence lengths yourself, and a standard binary attention_mask is typically unnecessary.
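As an illustration of that interface, here is a minimal sketch of a packing collator (names are illustrative, not an official Transformers API): it flattens several tokenized examples into one sequence and builds position_ids that restart at 0 at each sample boundary, which is the signal the FA2 integration uses to recover sequence boundaries.

```python
# Sketch: flatten multiple tokenized examples into a single packed
# sequence and emit position_ids that reset to 0 at each sample start.
# An FA2 integration that supports packing can recover per-sequence
# boundaries (and hence cu_seqlens) from these resets.
# Illustrative code only, not the official Transformers collator.

def pack_examples(examples):
    """examples: list of dicts, each with an 'input_ids' list."""
    input_ids, position_ids = [], []
    for ex in examples:
        ids = ex["input_ids"]
        input_ids.extend(ids)
        position_ids.extend(range(len(ids)))  # restart at 0 per sample
    return {"input_ids": input_ids, "position_ids": position_ids}

batch = pack_examples([
    {"input_ids": [5, 6, 7]},
    {"input_ids": [8, 9]},
])
print(batch["position_ids"])  # [0, 1, 2, 0, 1]
```

In real training you would return tensors and batch multiple packs, but the key point is the position_ids reset pattern rather than a 0/1 attention_mask.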

  3. Please check out these references:
    Improving Hugging Face Training Efficiency Through Packing with Flash Attention 2
    https://ai.google.dev/gemma/docs/core/huggingface_text_full_finetune

    Thanks