

【Generative Model】A look into the Video-Prediction-Policy codebase
Understanding the VPP implementation on GitHub
Credits#
- link: VPP repo ↗
@article{hu2024video,
title={Video Prediction Policy: A Generalist Robot Policy with Predictive Visual Representations},
author={Hu, Yucheng and Guo, Yanjiang and Wang, Pengchao and Chen, Xiaoyu and Wang, Yen-Jen and Zhang, Jianke and Sreenath, Koushil and Lu, Chaochao and Chen, Jianyu},
journal={arXiv preprint arXiv:2412.14803},
year={2024}
}
Training#
text conditioning in SVP with CLIP#
- CLIP text encoding:
  - tokenize the instruction
  - call the text encoder
  - add positional encodings
- concat with the image embeddings
- details in def encode_text(...); a minimal sketch of this path follows below
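A minimal sketch of that text path, assuming the standard Hugging Face CLIP tokenizer/text encoder rather than the repo's exact encode_text (function and variable names here are illustrative):
# Illustrative sketch only -- not the repo's encode_text; assumes the
# Hugging Face CLIP text stack that SVD-style pipelines ship with.
import torch
from transformers import CLIPTokenizer, CLIPTextModel

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

def encode_text_sketch(prompts, max_length=77):
    # 1. tokenize the instructions
    tokens = tokenizer(prompts, padding="max_length", max_length=max_length,
                       truncation=True, return_tensors="pt")
    # 2. call the encoder (CLIP adds its own positional encodings internally;
    #    an extra learned position embedding could be added on top here)
    with torch.no_grad():
        text_emb = text_encoder(tokens.input_ids).last_hidden_state  # (B, L, 512)
    return text_emb

# 3. concat with the image embeddings along the token dimension (shapes assumed)
# cond = torch.cat([image_emb, text_emb], dim=1)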
Feature Extraction in SVP#
- The video-diffusion UNet processes video frames through multiple layers; intermediate activations are taken as the predictive features
Overall#
- in policy_models.VPP_policy
class VPP_Policy(...):
    def __init__(...):
        # img encoder: the video-prediction model used as a feature extractor
        self.TVP_encoder = Diffusion_feature_extractor(pipeline=pipeline,
                                                       tokenizer=tokenizer,
                                                       text_encoder=text_encoder,
                                                       position_encoding=self.use_position_encoding)
        # goal encoder: CLIP text encoder for the language instruction
        self.language_goal = LangClip(model_name='ViT-B/32').to(self.device)
        ...

    def extract_predictive_feature(self, dataset_batch):
        ...
        latent_goal = self.language_goal(dataset_batch["lang"])
        with torch.no_grad():
            # stack the static and gripper camera views along the batch dimension ...
            input_rgb = torch.cat([rgb_static, rgb_gripper], dim=0)
            # ... and duplicate the prompts so both views share the same instruction
            language = language + language
            perceptual_features = self.TVP_encoder(input_rgb, language, self.timestep,
                                                   self.extract_layer_idx,
                                                   all_layer=self.use_all_layer,
                                                   step_time=1, max_length=self.max_length)
        perceptual_features = self.Video_Former(perceptual_features)
        ...
        return predictive_feature, latent_goal

    def training_step(self, dataset_batch):
        predictive_feature, latent_goal = self.extract_predictive_feature(dataset_batch)
        act_loss, _, _ = self.diffusion_loss(
            predictive_feature,
            latent_goal,
            dataset_batch["actions"],
        )
details in SVP#
- features are extracted at specific layers (controlled by extract_layer_idx)
class VPP_Policy(pl.LightningModule):
    def __init__(
        self,
        latent_dim: int = 512,
        use_Former: str = '3d',
        extract_layer_idx: int = 1,   # controls which layer to extract features from
        use_all_layer: bool = False,  # option to use all layers
        action_dim: int = 7,
        action_seq_len: int = 10,
    ):
- in policy_models.module.diffusion_extract
class Diffusion_feature_extractor(nn.Module):
    def forward(
        self, ...
        extract_layer_idx, ...
    ):
        ...
        for i, t in enumerate(timesteps):
            ...
            # run the UNet for this denoising step and keep the requested layer's activations
            feature_pred = self.step_unet(
                ...
                use_layer_idx=extract_layer_idx
            )[0]
            ...
        return feature_pred
cross-attention and self-attention on features in VideoFormer#
- These features capture both spatial and temporal information
- if use_Former == '3d':
self.Video_Former = Video_Former_3D(...)
class Video_Former_3D(nn.Module):
    def __init__(
        self,
        dim: int,
        depth: int,
        condition_dim: int = 1280,
        dim_head: int = 64,
        heads: int = 8,
        num_latents: int = 64,
        num_frame: int = 16,
        num_time_embeds: int = 4,
        use_temporal: bool = False,
    ):
        # ...
        self.layers = nn.ModuleList([])
        if self.use_temporal:
            for _ in range(depth):
                self.layers.append(
                    nn.ModuleList([
                        # cross-attention: learned latents attend to the visual features
                        PerceiverAttentionLayer(dim=dim, dim_head=dim_head, heads=heads),
                        # temporal self-attention across frames
                        Attention(dim, num_heads=heads, qkv_bias=True, use_cross_attn=False,
                                  y_dim=512, attn_mask=attn_mask),
                        feed_forward_layer(dim=dim, mult=ff_mult, activation=activation),
                    ])
                )

    def forward(self, x_f, mask):  # x_f: input visual embedding
        # 1. Mask the position embeddings for the padded frames
        ...
        # 2. Apply the position embeddings
        ...
        x_f = x_f + time_pos_emb
        # 3. Apply attention and feed-forward layers
        for attn, temp_attn, ffw in self.layers:
            # spatial cross-attention: latents query the per-frame visual tokens
            x = x + attn(x_f, x)
            # temporal self-attention: regroup so each row spans the frame dimension
            x = rearrange(x, '(b T) q d -> (b q) T d', b=batch_size)
            x = x + temp_attn(x)
            x = rearrange(x, '(b q) T d -> (b T) q d', b=batch_size)
            x = x + ffw(x)
        x = x.reshape(batch_size, -1, x.shape[1], x.shape[2])
        x = rearrange(x, 'b T q d -> b (T q) d')
        norm = self.norm(x)
        return norm
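The two rearrange calls are what turn the second attention into a temporal one: tokens are regrouped so that each row holds the same latent query across all frames. A standalone shape check (sizes are arbitrary):
# Standalone check of the rearrange pattern used above (sizes are arbitrary).
import torch
from einops import rearrange

batch_size, T, q, d = 2, 16, 64, 512            # batch, frames, latent queries, feature dim
x = torch.randn(batch_size * T, q, d)           # spatial layout: one row per (batch, frame)

x_t = rearrange(x, '(b T) q d -> (b q) T d', b=batch_size)
assert x_t.shape == (batch_size * q, T, d)      # temporal attention now sees all T frames per query

x_back = rearrange(x_t, '(b q) T d -> (b T) q d', b=batch_size)
assert torch.equal(x, x_back)                   # the round trip is lossless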
class PerceiverAttentionLayer(nn.Module):
    def forward(self, features, latents):
        # Layer normalization
        x = self.norm_media(features)
        latents = self.norm_latents(latents)
        # Cross-attention: latents form the queries; keys/values come from
        # the visual features concatenated with the latents
        q = self.to_q(latents)
        kv_input = torch.cat((x, latents), dim=-2)
        k = self.to_k(kv_input)
        v = self.to_v(kv_input)
        # attention scores
        sim = einsum('b h q d, b h f d -> b h q f', q, k)
        alphas = sim.softmax(dim=-1)
        out = einsum('b h q f, b h f v -> b h q v', alphas, v)
        out = rearrange(out, 'b h q v -> b q (h v)')
        return self.to_out(out)
class Attention(...):
    def forward(self, x):
        # Perform self-attention
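The body of Attention is not reproduced in these notes; as a reference point, a minimal multi-head self-attention forward (standard qkv projections, not the repo's exact class) looks like:
# Minimal self-attention sketch -- illustrative, not the repo's Attention class.
import torch
import torch.nn as nn
from einops import rearrange

class SelfAttentionSketch(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=True):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):  # x: (batch, tokens, dim)
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h=self.num_heads) for t in (q, k, v))
        attn = (q @ k.transpose(-2, -1) * self.scale).softmax(dim=-1)
        out = rearrange(attn @ v, 'b h n d -> b n (h d)')
        return self.proj(out)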
diffusion loss in DiT#
diffusion loss#
class VPP_Policy(...):
    def diffusion_loss(
        self,
        perceptual_emb,
        latent_goal,
        actions
    ):
        # sample one noise level per action sequence, then Gaussian noise in action space
        sigmas = self.make_sample_density()(shape=(len(actions),))
        noise = torch.randn_like(actions)
        loss = self.model.loss(perceptual_emb, actions, latent_goal, noise, sigmas)
        return loss, sigmas, noise
self.model = GCDenoiser()
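make_sample_density returns a sampler over noise levels. Its body is not shown in these notes; a common k-diffusion-style choice is a log-normal distribution over sigma, sketched here under that assumption (the repo's distribution and parameters may differ):
# Sigma-sampler sketch, assuming a k-diffusion-style log-normal density;
# loc/scale/sigma_min/sigma_max are placeholder values, not the repo's.
import torch

def make_sample_density_sketch(loc=-1.2, scale=1.2, sigma_min=0.001, sigma_max=80.0):
    def sample_density(shape):
        log_sigma = loc + scale * torch.randn(shape)
        return log_sigma.exp().clamp(sigma_min, sigma_max)
    return sample_density

sigmas = make_sample_density_sketch()(shape=(32,))   # one noise level per action sequence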
score matching loss#
The loss effectively implements denoising score matching: predicting the denoised action at a given noise level is equivalent, up to a scaling, to predicting the score (gradient of the log-density) of the noised data distribution.
class GCDenoiser(nn.Module):
    def loss(self, state, action, goal, noise, sigma, **kwargs):
        c_skip, c_out, c_in = [append_dims(x, action.ndim) for x in self.get_scalings(sigma)]
        noised_input = action + noise * append_dims(sigma, action.ndim)
        model_output = self.inner_model(state, noised_input * c_in, goal, sigma, **kwargs)
        target = (action - c_skip * noised_input) / c_out  # NOTE: preconditioned regression target
        return (model_output - target).pow(2).flatten(1).mean(), model_output
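get_scalings is not shown above. The c_skip / c_out / c_in form matches the Karras-style (EDM / k-diffusion) preconditioning, so a plausible sketch is the standard one below; sigma_data is an assumed hyperparameter and the repo's exact values may differ:
# get_scalings sketch assuming the standard EDM / k-diffusion preconditioning;
# sigma_data (often 0.5) is an assumed hyperparameter.
def get_scalings_sketch(sigma, sigma_data=0.5):
    c_skip = sigma_data ** 2 / (sigma ** 2 + sigma_data ** 2)
    c_out = sigma * sigma_data / (sigma ** 2 + sigma_data ** 2) ** 0.5
    c_in = 1 / (sigma ** 2 + sigma_data ** 2) ** 0.5
    return c_skip, c_out, c_in
With these scalings the denoised prediction is recovered as c_skip * noised_input + c_out * model_output, which is what the samplers at inference time consume.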
inner_model architecture#
Inputs:
- States (visual observations)
- Goals (language instructions)
- Actions (current actions)
- Sigma (noise level)
Process:
- Encode context: States + Goals → Encoder → Context
- Decode actions: Actions + Context + Sigma → Decoder → Predicted Actions
class GCDenoiser(nn.Module):
    def __init__(...):
        self.inner_model = DiffusionTransformer(
            action_dim=action_dim,
            obs_dim=obs_dim,
            goal_dim=goal_dim,
            proprio_dim=proprio_dim,
            goal_conditioned=True,
            ...
        )
- in policy_models.module.diffusion_decoder
class DiffusionTransformer(nn.Module):
    def __init__(...):
        self.encoder = TransformerEncoder(...)
        self.decoder = TransformerFiLMDecoder(...)
        self.proprio_emb = nn.Sequential(
            nn.Linear(...),
            nn.Mish(),
            nn.Linear(...),
        )
        self.sigma_emb = ...
        self.action_emb = nn.Linear(...)

    def forward(self, states, actions, goals, sigma, uncond: Optional[bool]):
        # actions: actually noise (or the noised input during training)
        context = self.forward_enc_only(states, actions, goals, sigma, uncond)
        pred_actions = self.forward_dec_only(context, actions, sigma)
        return pred_actions
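The name TransformerFiLMDecoder suggests the decoder is conditioned via FiLM, i.e. a feature-wise scale and shift predicted from the sigma/context embedding. A minimal FiLM block, as an illustration rather than the repo's implementation:
# Minimal FiLM-conditioning sketch -- illustrative, not the repo's TransformerFiLMDecoder.
import torch
import torch.nn as nn

class FiLMBlockSketch(nn.Module):
    def __init__(self, dim, cond_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.to_scale_shift = nn.Linear(cond_dim, 2 * dim)   # predict per-channel gamma / beta

    def forward(self, x, cond):
        # x: (B, T, dim) token features; cond: (B, cond_dim), e.g. the sigma embedding
        scale, shift = self.to_scale_shift(cond).chunk(2, dim=-1)
        return self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)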
Inference#
class VPP_Policy:
    def eval_forward(self, obs, goals):
        ...
        # denoise an action sequence from the predictive features and the language goal
        act_seq = self.denoise_actions(
            torch.zeros_like(latent_goal).to(latent_goal.device),
            perceptual_emb,
            latent_goal,
            inference=True,
        )
- denoise_actions depends on the specific sampler used
- TODO: look into the different samplers; a generic Euler-style sketch follows below
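For orientation until the samplers are written up: a plain Euler sampler over a decreasing sigma schedule (k-diffusion style) is sketched below. It assumes the denoiser returns the denoised action sequence (c_skip * x + c_out * model_output from GCDenoiser above); the repo's actual samplers may be fancier (ancestral, DPM-style, etc.).
# Euler-sampler sketch (k-diffusion style); `denoiser` is assumed to return the
# denoised action sequence, and `sigmas` is a decreasing schedule ending near 0.
import torch

@torch.no_grad()
def sample_euler_sketch(denoiser, state, goal, sigmas, action_shape, device='cpu'):
    x = torch.randn(action_shape, device=device) * sigmas[0]   # start from pure noise
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        denoised = denoiser(state, x, goal, sigma)
        d = (x - denoised) / sigma            # probability-flow ODE derivative dx/dsigma
        x = x + d * (sigma_next - sigma)      # Euler step to the next noise level
    return x                                  # final denoised action sequence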