【Learning VLA】Understanding the π0 codebase framework【I】

Model architecture overview

  1. Vision-Language Backbone: Based on PaliGemma, a pre-trained VLM that processes visual inputs and language instructions. This backbone combines the SigLIP So400m vision encoder (about 400M parameters) with the Gemma 2B language model (about 2.6B parameters).

  2. Action Expert: A specialized module that translates the visual and language representations into robot actions. This component has roughly 300M parameters and uses flow matching to generate continuous action distributions suitable for robotic control.
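
For orientation, here is how I believe the two Gemma configs used in the initialization below are obtained. This is a sketch based on my reading of openpi's Pi0Config defaults; the import path and variant names are assumptions, so check the actual config code.

```python
# Assumed import path and variant names; verify against your openpi checkout.
from openpi.models import gemma as _gemma

paligemma_config = _gemma.get_config("gemma_2b")        # ~2.6B-param PaliGemma backbone
action_expert_config = _gemma.get_config("gemma_300m")  # ~300M-param action expert
```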

VLM initialization

```python
# The LLM module runs two sets of transformer weights: the PaliGemma backbone
# (for image/language prefix tokens) and the smaller action expert (for
# state/action suffix tokens).
llm = nnx_bridge.ToNNX(
    _gemma.Module(
        configs=[paligemma_config, action_expert_config],
        embed_dtype=config.dtype,
    )
)
# The vision tower is SigLIP (So400m/14); pool_type="none" keeps every patch
# token, and num_classes projects patch features to the PaliGemma width.
img = nnx_bridge.ToNNX(
    _siglip.Module(
        num_classes=paligemma_config.width,
        variant="So400m/14",
        pool_type="none",
        scan=True,
        dtype_mm=config.dtype,
    )
)
self.PaliGemma = nnx.Dict(llm=llm, img=img)
```

Token calculation

prefix embedding

  • visual & language inputs (a rough token count is sketched after the code below)
```python
def embed_prefix(self, obs: _model.Observation):
    input_mask = []
    ar_mask = []
    tokens = []

    # Embed images (each camera view contributes one block of SigLIP patch tokens)
    for name in obs.images:
        image_tokens, _ = self.PaliGemma.img(obs.images[name], train=False)
        tokens.append(image_tokens)
        # ... handle masks and attention (prefix ar_mask entries are False: fully bidirectional)

    # Add language inputs
    if obs.tokenized_prompt is not None:
        tokenized_inputs = self.PaliGemma.llm(obs.tokenized_prompt, method="embed")
        tokens.append(tokenized_inputs)
        # ... handle masks and attention (language tokens join the same bidirectional block)
```
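
To make the token calculation concrete, here is a back-of-the-envelope count of the prefix length. The numbers (224×224 images, three camera views, a 48-token prompt budget) are my assumptions about the default config, not values taken from the code above:

```python
# So400m/14 with pool_type="none" keeps every 14x14 patch as one token.
patches_per_image = (224 // 14) ** 2   # 16 x 16 = 256 image tokens per view
num_cameras = 3                        # assumed number of camera views
max_prompt_tokens = 48                 # assumed tokenized-prompt budget

prefix_len = num_cameras * patches_per_image + max_prompt_tokens
print(prefix_len)  # 816 prefix tokens fed to the PaliGemma expert
```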

suffix embedding

  • action & state information (the resulting attention structure is sketched after the embedding code below)
  1. projection network initialization
```python
# state -> a single token in the action expert's embedding space
self.state_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs)
# noisy actions -> one token per action step
self.action_in_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs)
# two-layer MLP that fuses the action tokens with the time embedding
self.action_time_mlp_in = nnx.Linear(2 * action_expert_config.width, action_expert_config.width, rngs=rngs)
self.action_time_mlp_out = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs)
# action expert output -> predicted velocity in action space
self.action_out_proj = nnx.Linear(action_expert_config.width, config.action_dim, rngs=rngs)
```
  2. calculate embedding
```python
@at.typecheck
def embed_suffix(
    self, obs: _model.Observation, noisy_actions: _model.Actions, timestep: at.Float[at.Array, " b"]
) -> tuple[at.Float[at.Array, "b s emb"], at.Bool[at.Array, "b s"], at.Bool[at.Array, " s"]]:
    input_mask = []
    ar_mask = []
    tokens = []

    # 1. Project state to token space (a single token)
    state_token = self.state_proj(obs.state)[:, None, :]
    tokens.append(state_token)
    input_mask.append(jnp.ones((obs.state.shape[0], 1), dtype=jnp.bool_))
    ar_mask += [True]  # image/language prefix tokens cannot attend to the state token

    # 2. Project actions to token space
    action_tokens = self.action_in_proj(noisy_actions)

    # 3. Create time embeddings using sine-cosine positional encoding
    time_emb = posemb_sincos(timestep, self.action_in_proj.out_features, min_period=4e-3, max_period=4.0)
    time_tokens = einops.repeat(time_emb, "b emb -> b s emb", s=self.action_horizon)

    # 4. Mix action and time information with a small MLP
    action_time_tokens = jnp.concatenate([action_tokens, time_tokens], axis=-1)
    action_time_tokens = self.action_time_mlp_in(action_time_tokens)
    action_time_tokens = nnx.swish(action_time_tokens)
    action_time_tokens = self.action_time_mlp_out(action_time_tokens)

    tokens.append(action_time_tokens)
    input_mask.append(jnp.ones(action_time_tokens.shape[:2], dtype=jnp.bool_))
    # prefix/state tokens cannot attend to action tokens; action tokens attend to each other
    ar_mask += [True] + ([False] * (self.action_horizon - 1))

    tokens = jnp.concatenate(tokens, axis=1)
    input_mask = jnp.concatenate(input_mask, axis=1)
    ar_mask = jnp.array(ar_mask)
    return tokens, input_mask, ar_mask
```
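
The suffix therefore contains one state token plus action_horizon action tokens (51 tokens with the default horizon of 50, if I read the config correctly). The ar_mask booleans mark block boundaries for a big_vision-style attention-mask builder. Below is a minimal sketch of how such a make_attn_mask works, written from how it is used in this file rather than copied from openpi, so treat the real implementation as authoritative:

```python
import jax.numpy as jnp

def make_attn_mask(input_mask, mask_ar):
    """Token i may attend to token j iff j is a valid (non-padding) token and
    cumsum(mask_ar)[j] <= cumsum(mask_ar)[i]: a True in mask_ar starts a new
    block that earlier tokens cannot attend into, while the new block can
    still attend backward."""
    mask_ar = jnp.broadcast_to(mask_ar, input_mask.shape)
    cumsum = jnp.cumsum(mask_ar, axis=1)
    attn_mask = cumsum[:, None, :] <= cumsum[:, :, None]   # (b, q, k) block-causal pattern
    valid_mask = input_mask[:, None, :]                    # mask out padding keys
    return jnp.logical_and(attn_mask, valid_mask)

# With prefix ar_mask all False and suffix ar_mask = [True] + [True] + [False] * (H - 1):
#   - image/language tokens attend bidirectionally among themselves only,
#   - the state token attends to the prefix and itself,
#   - the H action tokens attend to everything: prefix, state, and each other.
```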

Flow Matching

Inference

sample actions

  • flow matching

Starting from pure Gaussian noise at t = 1, repeatedly apply the predicted velocity to integrate backward until t = 0 (the update rule is written out after this list).

  • code reading

def step(carry): given the current (x_t, time), compute the next (x_t, time) with one Euler step.

A KV cache lets the prefix (images + language) be processed only once; each denoising step re-embeds and re-runs only the suffix.
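
In the code's convention below (t = 1 is pure noise, t = 0 is the data distribution), each call to step performs one Euler step of the learned flow:

$$x_{t+\delta} = x_t + \delta \, v_\theta(x_t, t), \qquad \delta = -\frac{1}{\text{num\_steps}}, \qquad x_1 \sim \mathcal{N}(0, I),$$

and the while loop stops once t reaches 0 (the time comparison is written to be robust to floating-point error).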

```python
@override
def sample_actions(
    self,
    rng,
    observation,
    *,
    num_steps,
) -> _model.Actions:
    observation = _model.preprocess_observation(None, observation, train=False)
    # note that we use the convention more common in diffusion literature, where t=1 is noise and t=0 is the target
    # distribution. yes, this is the opposite of the pi0 paper, and I'm sorry.
    dt = -1.0 / num_steps
    batch_size = observation.state.shape[0]
    noise = jax.random.normal(rng, (batch_size, self.action_horizon, self.action_dim))

    # first fill KV cache with a forward pass of the prefix
    prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)
    prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)
    positions = jnp.cumsum(prefix_mask, axis=1) - 1
    _, kv_cache = self.PaliGemma.llm([prefix_tokens, None], mask=prefix_attn_mask, positions=positions)

    def step(carry):
        x_t, time = carry
        suffix_tokens, suffix_mask, suffix_ar_mask = self.embed_suffix(
            observation, x_t, jnp.broadcast_to(time, batch_size)
        )
        # `suffix_attn_mask` is shape (b, suffix_len, suffix_len) indicating how the suffix tokens can attend to each other
        suffix_attn_mask = make_attn_mask(suffix_mask, suffix_ar_mask)
        # `prefix_attn_mask` is shape (b, suffix_len, prefix_len) indicating how the suffix tokens can attend to the prefix tokens
        prefix_attn_mask = einops.repeat(prefix_mask, "b p -> b s p", s=suffix_tokens.shape[1])
        # `full_attn_mask` is shape (b, suffix_len, prefix_len + suffix_len) indicating how the suffix tokens (which generate the queries) can attend to the full prefix + suffix sequence (which generates the keys and values)
        full_attn_mask = jnp.concatenate([prefix_attn_mask, suffix_attn_mask], axis=-1)
        assert full_attn_mask.shape == (
            batch_size,
            suffix_tokens.shape[1],
            prefix_tokens.shape[1] + suffix_tokens.shape[1],
        )
        # `positions` is shape (b, suffix_len) indicating the positions of the suffix tokens
        positions = jnp.sum(prefix_mask, axis=-1)[:, None] + jnp.cumsum(suffix_mask, axis=-1) - 1

        (prefix_out, suffix_out), _ = self.PaliGemma.llm(
            [None, suffix_tokens], mask=full_attn_mask, positions=positions, kv_cache=kv_cache
        )
        assert prefix_out is None
        v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])

        return x_t + dt * v_t, time + dt

    def cond(carry):
        x_t, time = carry
        # robust to floating-point error
        return time >= -dt / 2

    x_0, _ = jax.lax.while_loop(cond, step, (noise, 1.0))
    return x_0
```
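
A minimal usage sketch; `model`, `obs`, and num_steps=10 are placeholders I am assuming for illustration (in openpi the model is normally wrapped in a policy class rather than called directly):

```python
import jax

# Hypothetical driver code: `model` is an instantiated Pi0 and `obs` a preprocessed
# _model.Observation batch; both are placeholders, not openpi entry points.
rng = jax.random.key(0)
actions = model.sample_actions(rng, obs, num_steps=10)
print(actions.shape)  # (batch_size, action_horizon, action_dim)
```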

Training

  • flow matching loss
  1. generate random noise

  2. sample time points from a beta distribution

  3. interpolate between the noise and the ground-truth actions -> x_t

x_t: the noisy action at time t; the model takes x_t as input

  4. calculate the ground-truth velocity u_t (derived after this list)

  5. get model predictions

  6. project the model output to velocity space

  7. compute the MSE loss between the predicted and ground-truth velocity
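
Where the ground-truth velocity comes from: the code interpolates linearly between noise and actions, so differentiating with respect to t gives the regression target directly:

$$x_t = t\,\varepsilon + (1 - t)\,a \quad\Rightarrow\quad u_t = \frac{dx_t}{dt} = \varepsilon - a,$$

where $\varepsilon$ is the sampled noise and $a$ the ground-truth actions. This is exactly the u_t = noise - actions computed below, and it is the velocity field that sample_actions later integrates from t = 1 down to t = 0.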

```python
def compute_loss(self, rng, observation, actions, *, train=False):
    preprocess_rng, noise_rng, time_rng = jax.random.split(rng, 3)
    observation = _model.preprocess_observation(preprocess_rng, observation, train=train)

    # Generate noise and time samples
    batch_shape = actions.shape[:-2]
    noise = jax.random.normal(noise_rng, actions.shape)
    time = jax.random.beta(time_rng, 1.5, 1, batch_shape) * 0.999 + 0.001

    # Interpolate between noise and actions
    time_expanded = time[..., None, None]
    x_t = time_expanded * noise + (1 - time_expanded) * actions
    u_t = noise - actions  # ground-truth velocity

    # Embed prefix (images + language) and suffix (state + noisy actions + time)
    prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)
    suffix_tokens, suffix_mask, suffix_ar_mask = self.embed_suffix(observation, x_t, time)

    # Build one attention mask over the full prefix + suffix sequence and run the model
    input_mask = jnp.concatenate([prefix_mask, suffix_mask], axis=1)
    ar_mask = jnp.concatenate([prefix_ar_mask, suffix_ar_mask], axis=0)
    attn_mask = make_attn_mask(input_mask, ar_mask)
    positions = jnp.cumsum(input_mask, axis=1) - 1
    (prefix_out, suffix_out), _ = self.PaliGemma.llm(
        [prefix_tokens, suffix_tokens], mask=attn_mask, positions=positions
    )
    v_t = self.action_out_proj(suffix_out[:, -self.action_horizon:])

    # MSE between predicted and target velocity (averaged over the action dimension)
    return jnp.mean(jnp.square(v_t - u_t), axis=-1)
```