I am printing the shape of the output tensor of the MLP block during causal inference with the Gemma model for a given input. What I observe is that during the first token generation the shape is (batch_size, input_seq_length, hidden_size), but for all subsequent token generations it changes to (batch_size, 1, hidden_size). For example, consider an input sequence of length 5 and a desired output length of 2:

(The screenshot showed the printed shapes: (batch_size, 5, hidden_size) for the first generated token, then (batch_size, 1, hidden_size) for the second.)
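For reference, this is roughly how I print the shapes: a minimal sketch assuming the Hugging Face transformers implementation of Gemma (the checkpoint id and the prompt are just example placeholders):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2b"  # example checkpoint, any Gemma variant should behave the same
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
model.eval()

def print_mlp_shape(module, inputs, output):
    # the MLP of a decoder layer returns a single tensor (batch, seq_len, hidden_size)
    print("MLP output shape:", tuple(output.shape))

# hook only the first decoder layer's MLP to keep the printout readable
hook = model.model.layers[0].mlp.register_forward_hook(print_mlp_shape)

inputs = tokenizer("The quick brown fox", return_tensors="pt")  # ~5 tokens incl. <bos>
with torch.no_grad():
    model.generate(**inputs, max_new_tokens=2, do_sample=False)
hook.remove()
# prints (1, 5, hidden_size) for the first generated token,
# then (1, 1, hidden_size) for the second
```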

Why does this happen? My understanding is that during the first token's inference, the model processes the entire input sequence (prepended with the <SOS>/start-of-sentence token) through the Gemma decoder blocks, obtaining hidden states for every input position and generating the first output token. For subsequent token generations, however, it only feeds in the last generated token to produce the next one, retrieving information about the previous tokens from the KV cache built up during inference.
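To make that concrete, here is a hedged sketch of what I believe generate() is doing under the hood with greedy decoding (it reuses model, tokenizer, and inputs from the snippet above; the shapes in the comments are my expectation, not something taken from the library docs):

```python
import torch

prompt_ids = inputs["input_ids"]                      # (batch, 5)

with torch.no_grad():
    # prefill: the whole prompt goes through every decoder layer once,
    # so each MLP sees (batch, 5, hidden_size) and the KV cache gets filled
    out = model(input_ids=prompt_ids, use_cache=True)
    past = out.past_key_values
    next_token = out.logits[:, -1:].argmax(dim=-1)    # first generated token, (batch, 1)

    generated = [next_token]
    for _ in range(1):                                # one more step -> output length 2
        # decode: only the newest token is fed in; attention reads the older
        # keys/values from the cache, so each MLP sees (batch, 1, hidden_size)
        out = model(input_ids=next_token, past_key_values=past, use_cache=True)
        past = out.past_key_values
        next_token = out.logits[:, -1:].argmax(dim=-1)
        generated.append(next_token)

print(torch.cat(generated, dim=-1))                   # ids of the 2 generated tokens
```

Is this per-step picture the actual reason the MLP output collapses to sequence length 1 after the first step?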

I would love to understand this in more depth, so if anyone can provide links to resources, it would be of great help.