

# 【Learning VLA】Understanding the π0 Codebase Framework【I】

Notes on the π0 implementation on GitHub.
## Model architecture overview

- Vision-Language Backbone: based on PaliGemma, a pre-trained VLM that processes visual inputs and language instructions. The backbone combines the SigLIP vision encoder (the 400M-parameter So400m variant) with the Gemma 2B language model (roughly 2.6B parameters).
- Action Expert: a smaller transformer (~300M parameters) that turns the visual and language representations into robot actions. It uses flow matching to generate continuous action chunks suitable for robotic control.
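To make the split concrete, here is a rough shape-level sketch of how one observation is tokenized. The camera count, prompt length, action horizon, and action dimension below are made-up example values; only the 256-patches-per-image figure follows from the So400m/14 encoder at 224×224 resolution.

```python
import jax.numpy as jnp

# Hypothetical example inputs.
images = jnp.zeros((1, 3, 224, 224, 3))              # three camera views
prompt_tokens = jnp.zeros((1, 48), dtype=jnp.int32)  # tokenized instruction
state = jnp.zeros((1, 32))                           # proprioceptive state -> the single state token
noisy_actions = jnp.zeros((1, 50, 32))               # action chunk being denoised

# Prefix (PaliGemma backbone): SigLIP patch tokens + Gemma text embeddings.
patches_per_image = (224 // 14) ** 2                 # 256 patches for So400m/14 at 224x224
prefix_len = images.shape[1] * patches_per_image + prompt_tokens.shape[1]

# Suffix (action expert): one state token + one token per action step.
suffix_len = 1 + noisy_actions.shape[1]

print(prefix_len, suffix_len)  # 816 51
```

The action expert's outputs at the action-token positions are projected back to the action dimension and interpreted as a flow-matching velocity, as detailed below.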
## VLM initialization

```python
# Gemma language model, wrapped for Flax NNX; it carries both the PaliGemma
# config and the action-expert config.
llm = nnx_bridge.ToNNX(
    _gemma.Module(
        configs=[paligemma_config, action_expert_config],
        embed_dtype=config.dtype,
    )
)
# SigLIP So400m/14 vision encoder; produces one token per image patch, with
# width matching the PaliGemma embedding width.
img = nnx_bridge.ToNNX(
    _siglip.Module(
        num_classes=paligemma_config.width,
        variant="So400m/14",
        pool_type="none",
        scan=True,
        dtype_mm=config.dtype,
    )
)
self.PaliGemma = nnx.Dict(llm=llm, img=img)
```
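Note that `_gemma.Module` receives two configs: one for the PaliGemma backbone and one for the action expert. As the calls later in the file show (`self.PaliGemma.llm([prefix_tokens, suffix_tokens], ...)` returns `(prefix_out, suffix_out)`), the wrapped LLM takes a list of two token streams, so prefix and suffix tokens can run through separate expert weights while sharing a single attention pass. The toy snippet below illustrates that "two experts, one attention" idea with made-up shapes; it is not the openpi Gemma implementation.

```python
import jax
import jax.numpy as jnp

d_prefix, d_suffix, d_head = 64, 32, 16  # hypothetical widths
key = jax.random.PRNGKey(0)
k1, k2, k3, k4 = jax.random.split(key, 4)

# Separate per-expert projection weights.
Wq_prefix = jax.random.normal(k1, (d_prefix, d_head)) * 0.02
Wq_suffix = jax.random.normal(k2, (d_suffix, d_head)) * 0.02
Wk_prefix = jax.random.normal(k3, (d_prefix, d_head)) * 0.02
Wk_suffix = jax.random.normal(k4, (d_suffix, d_head)) * 0.02

prefix_tokens = jnp.ones((1, 816, d_prefix))  # e.g. image + text tokens
suffix_tokens = jnp.ones((1, 51, d_suffix))   # state + action/time tokens

# Each expert projects its own tokens...
q = jnp.concatenate([prefix_tokens @ Wq_prefix, suffix_tokens @ Wq_suffix], axis=1)
k = jnp.concatenate([prefix_tokens @ Wk_prefix, suffix_tokens @ Wk_suffix], axis=1)
# ...but attention scores are computed over the full, shared sequence.
scores = jnp.einsum("bqd,bkd->bqk", q, k) / jnp.sqrt(d_head)
print(scores.shape)  # (1, 867, 867)
```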
## Token calculation
### Prefix embedding

- Visual and language inputs: each camera image is encoded into SigLIP patch tokens, and the tokenized prompt is embedded with the Gemma embedding table.
```python
def embed_prefix(self, obs: _model.Observation):
    tokens, input_mask, ar_mask = [], [], []
    # Embed images: each camera view becomes a sequence of SigLIP patch tokens.
    for name in obs.images:
        image_tokens, _ = self.PaliGemma.img(obs.images[name], train=False)
        tokens.append(image_tokens)
        # ... handle masks and attention
    # Add language inputs: embed the tokenized prompt with the Gemma embedding table.
    if obs.tokenized_prompt is not None:
        tokenized_inputs = self.PaliGemma.llm(obs.tokenized_prompt, method="embed")
        tokens.append(tokenized_inputs)
        # ... handle masks and attention
```
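The elided mask handling does two things: each per-image validity flag is broadcast to one boolean per patch token, and every prefix token is marked with `ar_mask = False`, so image and language tokens attend to each other fully bidirectionally. Below is a sketch of that bookkeeping with made-up sizes; it illustrates the idea and is not the exact openpi code.

```python
import einops
import jax.numpy as jnp

batch, num_patches, prompt_len = 2, 256, 48  # hypothetical sizes

input_mask, ar_mask = [], []

# Per-image validity flag (True if the camera view is present), broadcast to
# one boolean per patch token.
image_mask = jnp.ones((batch,), dtype=jnp.bool_)
input_mask.append(einops.repeat(image_mask, "b -> b s", s=num_patches))
# ar_mask = False: prefix tokens form one fully bidirectional attention block.
ar_mask += [False] * num_patches

# Language tokens reuse the tokenizer's padding mask and are also ar_mask = False.
prompt_mask = jnp.ones((batch, prompt_len), dtype=jnp.bool_)
input_mask.append(prompt_mask)
ar_mask += [False] * prompt_len

input_mask = jnp.concatenate(input_mask, axis=1)
print(input_mask.shape, len(ar_mask))  # (2, 304) 304
```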
### Suffix embedding

- Action and state information is embedded by the action expert's input projections.
- Projection network initialization:
```python
self.state_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs)
self.action_in_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs)
self.action_time_mlp_in = nnx.Linear(2 * action_expert_config.width, action_expert_config.width, rngs=rngs)
self.action_time_mlp_out = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs)
self.action_out_proj = nnx.Linear(action_expert_config.width, config.action_dim, rngs=rngs)
```
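A quick shape check of these projections, using hypothetical dimensions (`action_dim=32`, an action-expert width of 1024, `action_horizon=50`): the state becomes a single token, and each of the 50 noisy actions becomes one token in the expert's embedding space.

```python
import jax.numpy as jnp
from flax import nnx

action_dim, width, horizon, batch = 32, 1024, 50, 2  # hypothetical sizes
rngs = nnx.Rngs(0)

state_proj = nnx.Linear(action_dim, width, rngs=rngs)
action_in_proj = nnx.Linear(action_dim, width, rngs=rngs)

state = jnp.zeros((batch, action_dim))
noisy_actions = jnp.zeros((batch, horizon, action_dim))

print(state_proj(state)[:, None, :].shape)  # (2, 1, 1024): one state token
print(action_in_proj(noisy_actions).shape)  # (2, 50, 1024): one token per action step
```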
- Embedding computation:
```python
@at.typecheck
def embed_suffix(
    self, obs: _model.Observation, noisy_actions: _model.Actions, timestep: at.Float[at.Array, " b"]
) -> tuple[at.Float[at.Array, "b s emb"], at.Bool[at.Array, "b s"], at.Bool[at.Array, " s"]]:
    input_mask = []
    ar_mask = []
    tokens = []
    # 1. Project state to token space (a single token)
    state_token = self.state_proj(obs.state)[:, None, :]
    tokens.append(state_token)
    input_mask.append(jnp.ones((obs.state.shape[0], 1), dtype=jnp.bool_))
    ar_mask += [True]  # prefix (image/language) tokens cannot attend to the state token
    # 2. Project actions to token space
    action_tokens = self.action_in_proj(noisy_actions)
    # 3. Create time embeddings using sine-cosine positional encoding
    time_emb = posemb_sincos(timestep, self.action_in_proj.out_features, min_period=4e-3, max_period=4.0)
    time_tokens = einops.repeat(time_emb, "b emb -> b s emb", s=self.action_horizon)
    # 4. Mix action and time information with a small MLP
    action_time_tokens = jnp.concatenate([action_tokens, time_tokens], axis=-1)
    action_time_tokens = self.action_time_mlp_in(action_time_tokens)
    action_time_tokens = nnx.swish(action_time_tokens)
    action_time_tokens = self.action_time_mlp_out(action_time_tokens)
    tokens.append(action_time_tokens)
    input_mask.append(jnp.ones(action_time_tokens.shape[:2], dtype=jnp.bool_))
    # prefix and state tokens cannot attend to the action tokens; action tokens attend to each other
    ar_mask += [True] + ([False] * (self.action_horizon - 1))
    # ... concatenate tokens, input_mask, ar_mask and return
```
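These `ar_mask` flags are consumed by `make_attn_mask`, which builds a blockwise-causal attention mask: a `True` entry starts a new block that earlier tokens cannot attend to, while tokens in that block (and in all later blocks) can still attend to everything before them. A minimal re-implementation of that cumulative-sum construction (a sketch following the big_vision convention, not copied from openpi), applied to a tiny prefix + state + 3-action-token sequence:

```python
import jax.numpy as jnp

def make_attn_mask_sketch(input_mask, ar_mask):
    """Blockwise-causal mask: ar_mask=True starts a new block that earlier
    tokens cannot attend to; padding is removed via input_mask."""
    ar_mask = jnp.broadcast_to(ar_mask, input_mask.shape)
    # Tokens sharing a cumulative-sum value form one fully-attending block.
    cumsum = jnp.cumsum(ar_mask, axis=1)
    attn_mask = cumsum[:, None, :] <= cumsum[:, :, None]  # key block <= query block
    valid = input_mask[:, None, :] * input_mask[:, :, None]
    return jnp.logical_and(attn_mask, valid)

# Prefix (2 tokens, ar_mask=False) + state (True) + 3 action tokens (True, False, False):
input_mask = jnp.ones((1, 6), dtype=bool)
ar_mask = jnp.array([False, False, True, True, False, False])
print(make_attn_mask_sketch(input_mask, ar_mask).astype(int)[0])
# [[1 1 0 0 0 0]
#  [1 1 0 0 0 0]
#  [1 1 1 0 0 0]
#  [1 1 1 1 1 1]
#  [1 1 1 1 1 1]
#  [1 1 1 1 1 1]]
```

Reading the rows: prefix tokens see only the prefix, the state token additionally sees itself, and the action tokens see everything, which matches the comments above.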
## Flow Matching

### Inference

#### Sample actions
- Flow matching: start from Gaussian noise at t=1 and repeatedly apply the predicted velocity until t=0 (the code follows the diffusion convention where t=1 is noise, the opposite of the π0 paper).
- Code reading:
  - `step(carry)`: given the current `(x_t, time)`, compute the next `(x_t, time)` with one Euler step.
  - KV caching is used so the prefix is processed only once; each denoising step re-runs only the suffix tokens.
```python
@override
def sample_actions(
    self,
    rng,
    observation,
    *,
    num_steps,
) -> _model.Actions:
    observation = _model.preprocess_observation(None, observation, train=False)
    # note that we use the convention more common in diffusion literature, where t=1 is noise and t=0 is the target
    # distribution. yes, this is the opposite of the pi0 paper, and I'm sorry.
    dt = -1.0 / num_steps
    batch_size = observation.state.shape[0]
    noise = jax.random.normal(rng, (batch_size, self.action_horizon, self.action_dim))

    # first fill KV cache with a forward pass of the prefix
    prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)
    prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)
    positions = jnp.cumsum(prefix_mask, axis=1) - 1
    _, kv_cache = self.PaliGemma.llm([prefix_tokens, None], mask=prefix_attn_mask, positions=positions)

    def step(carry):
        x_t, time = carry
        suffix_tokens, suffix_mask, suffix_ar_mask = self.embed_suffix(
            observation, x_t, jnp.broadcast_to(time, batch_size)
        )
        # `suffix_attn_mask` is shape (b, suffix_len, suffix_len) indicating how the suffix tokens can attend to
        # each other
        suffix_attn_mask = make_attn_mask(suffix_mask, suffix_ar_mask)
        # `prefix_attn_mask` is shape (b, suffix_len, prefix_len) indicating how the suffix tokens can attend to
        # the prefix tokens
        prefix_attn_mask = einops.repeat(prefix_mask, "b p -> b s p", s=suffix_tokens.shape[1])
        # `full_attn_mask` is shape (b, suffix_len, prefix_len + suffix_len) indicating how the suffix tokens
        # (which generate the queries) can attend to the full prefix + suffix sequence (which generates the keys
        # and values)
        full_attn_mask = jnp.concatenate([prefix_attn_mask, suffix_attn_mask], axis=-1)
        assert full_attn_mask.shape == (
            batch_size,
            suffix_tokens.shape[1],
            prefix_tokens.shape[1] + suffix_tokens.shape[1],
        )
        # `positions` is shape (b, suffix_len) indicating the positions of the suffix tokens
        positions = jnp.sum(prefix_mask, axis=-1)[:, None] + jnp.cumsum(suffix_mask, axis=-1) - 1

        (prefix_out, suffix_out), _ = self.PaliGemma.llm(
            [None, suffix_tokens], mask=full_attn_mask, positions=positions, kv_cache=kv_cache
        )
        assert prefix_out is None
        v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])

        # one Euler step: move x_t along the predicted velocity and advance time by dt (< 0)
        return x_t + dt * v_t, time + dt

    def cond(carry):
        x_t, time = carry
        # robust to floating-point error
        return time >= -dt / 2

    x_0, _ = jax.lax.while_loop(cond, step, (noise, 1.0))
    return x_0
```
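To see why `cond` compares against `-dt / 2` rather than zero, the integration schedule can be run on its own. This standalone sketch stubs out the velocity update and only steps the time variable (`num_steps=10` is an arbitrary example value):

```python
import jax
import jax.numpy as jnp

num_steps = 10
dt = -1.0 / num_steps

def cond(carry):
    x_t, time = carry
    # keep stepping while time is still clearly above 0 (half-step tolerance)
    return time >= -dt / 2

def step(carry):
    x_t, time = carry
    return x_t, time + dt  # the real step also applies x_t + dt * v_t

x_0, final_time = jax.lax.while_loop(cond, step, (jnp.zeros(()), 1.0))
print(final_time)  # ~0.0 after exactly 10 Euler steps; accumulated rounding never adds or drops a step
```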
### Training

- Flow matching loss:
  - Generate random noise.
  - Sample time points from a Beta distribution.
  - Interpolate between noise and the ground-truth actions -> `x_t`, the noisy action at the current timestep, which the model takes as input.
  - Compute the ground-truth velocity `u_t = noise - actions`.
  - Get the model predictions.
  - Project the model output to velocity space.
  - Compute the MSE loss between the predicted and ground-truth velocity.
```python
def compute_loss(self, rng, observation, actions, *, train=False):
    preprocess_rng, noise_rng, time_rng = jax.random.split(rng, 3)
    observation = _model.preprocess_observation(preprocess_rng, observation, train=train)

    # Generate noise and time samples
    batch_shape = actions.shape[:-2]
    noise = jax.random.normal(noise_rng, actions.shape)
    time = jax.random.beta(time_rng, 1.5, 1, batch_shape) * 0.999 + 0.001

    # Interpolate between noise and actions: x_t = t * noise + (1 - t) * actions
    time_expanded = time[..., None, None]
    x_t = time_expanded * noise + (1 - time_expanded) * actions
    # Ground-truth velocity pointing from actions towards noise
    u_t = noise - actions

    # Embed prefix (images + prompt) and suffix (state + noisy actions + time)
    prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)
    suffix_tokens, suffix_mask, suffix_ar_mask = self.embed_suffix(observation, x_t, time)

    # Build the joint attention mask and positions, then run one forward pass
    input_mask = jnp.concatenate([prefix_mask, suffix_mask], axis=1)
    ar_mask = jnp.concatenate([prefix_ar_mask, suffix_ar_mask], axis=0)
    attn_mask = make_attn_mask(input_mask, ar_mask)
    positions = jnp.cumsum(input_mask, axis=1) - 1
    (prefix_out, suffix_out), _ = self.PaliGemma.llm(
        [prefix_tokens, suffix_tokens], mask=attn_mask, positions=positions
    )
    # Predicted velocity from the action-expert outputs at the action-token positions
    v_t = self.action_out_proj(suffix_out[:, -self.action_horizon:])

    # Compute MSE loss between predicted and target velocity
    return jnp.mean(jnp.square(v_t - u_t), axis=-1)
```
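The training target follows directly from the interpolation path: since x_t = t · noise + (1 − t) · actions, its time derivative is noise − actions for every t, which is exactly `u_t`. A quick standalone numerical check with made-up shapes:

```python
import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
noise_rng, action_rng, time_rng = jax.random.split(rng, 3)

actions = jax.random.normal(action_rng, (2, 50, 32))  # hypothetical (batch, horizon, dim)
noise = jax.random.normal(noise_rng, actions.shape)
t = jax.random.beta(time_rng, 1.5, 1, (2,)) * 0.999 + 0.001

def x_t(t):
    te = t[..., None, None]
    return te * noise + (1 - te) * actions

# Finite-difference approximation of d(x_t)/dt.
eps = 1e-3
finite_diff = (x_t(t + eps) - x_t(t)) / eps
u_t = noise - actions
print(jnp.max(jnp.abs(finite_diff - u_t)))  # small (float32 rounding): u_t is d(x_t)/dt
```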