Completed & Archived ✓ v1/

v1: VAE + GRU
MPC Planner

The first-generation world model: a Variational Autoencoder that compresses Game Boy frames to 32-dimensional continuous latents, paired with a GRU dynamics model and a lookahead MPC planner that scores imagined futures.

Motivation & Framing

v1 establishes the core thesis of PokéDreamer: that an agent can reason about Pokémon Red using a learned dynamics model rather than requiring access to the real emulator at every decision step. This is the fundamental distinction between model-based and model-free RL.

The Key Question
Can you point at a component and say "this learned (s_t, a_t) → s_{t+1} from data, and the planner used its rollouts - not the real emulator - to decide"? If yes at every decision point, you have a world model project.

The v1 approach uses a symbolic intermediate representation: the VAE compresses pixel frames to a compact latent z ∈ ℝ³², the GRU dynamics model propagates those latents autoregressively, and linear/MLP probes read out player position from frozen latents for the planner's objective function.

Architecture Overview

The v1 system is composed of four components working in sequence: three learned models (VAE, GRU dynamics, probes) and a planner that uses them.

| Component | Input | Output | Description |
|---|---|---|---|
| VAE Encoder | 40×36×3 pixels | z ∈ ℝ³² | Compresses screen frames to compact latents |
| GRU Dynamics | (z_t, a_t) | z_{t+1} | Predicts next latent given current state and action |
| RAM Probes | z ∈ ℝ³² | (x, y, map_id) | Decode player position from latents |
| MPC Planner | current z | action sequence | Imagines K-step futures, scores by probe output |

The observation space is the downsampled 40×36×3 pixel format used by the PWhiddy PPO agent. Sharing that format lets the frozen PPO checkpoint serve as both data-collection policy and executor, while the world model is trained on the same observations.

Variational Autoencoder (VAE)

The VAE is the first piece of the world model pipeline. It learns to compress Game Boy screen frames into a low-dimensional latent space z ∈ ℝ³², from which the decoder can reconstruct the original image.

Architecture Details

Training Configuration

text
Checkpoint:    checkpoints/vae/best_vae.pt
Epochs:        15
latent_dim:    32
beta (KL):     1.0
lr:            1e-3
batch_size:    128
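
With beta = 1.0, the KL term is weighted equally with reconstruction. The following is a minimal sketch of such an objective in plain Python, assuming a diagonal-Gaussian posterior and a summed squared-error reconstruction term. The function names are illustrative and are not taken from the project code.

```python
import math

def kl_diag_gaussian(mu, logvar):
    """KL( N(mu, exp(logvar)) || N(0, I) ) for a single latent vector."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))

def vae_loss(x, x_recon, mu, logvar, beta=1.0):
    """Return (total, recon, kl) with total = recon + beta * kl (assumed form)."""
    recon = sum((a - b) ** 2 for a, b in zip(x, x_recon))
    kl = kl_diag_gaussian(mu, logvar)
    return recon + beta * kl, recon, kl

# A posterior equal to the prior costs nothing; shifting the mean by 1 costs 0.5 nats.
total, recon, kl = vae_loss([1.0, 0.0], [0.5, 0.0], mu=[1.0], logvar=[0.0])
```

Raising beta above 1.0 would trade reconstruction quality for a latent distribution closer to the prior; v1 keeps the two terms unweighted.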

GRU Dynamics Model

The dynamics model learns the transition function z_{t+1} ≈ f(z_t, a_t) - predicting the next latent from the current latent plus the chosen action. This is what enables the agent to imagine future states without running the emulator.
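
As a rough illustration of what "imagining" means here, the sketch below rolls a one-step model forward by feeding each predicted latent back in as the next input. The `step` callable stands in for the trained GRU; its signature, the explicit hidden state `h`, and the toy model are assumptions made for this example only.

```python
from typing import Callable, List, Sequence, Tuple

Latent = List[float]
Hidden = List[float]
StepFn = Callable[[Latent, int, Hidden], Tuple[Latent, Hidden]]

def imagine(step: StepFn, z0: Latent, h0: Hidden, actions: Sequence[int]) -> List[Latent]:
    """Roll the dynamics model forward without touching the emulator."""
    z, h = z0, h0
    trajectory = []
    for a in actions:
        z, h = step(z, a, h)      # z_{t+1} ≈ f(z_t, a_t); h carries recurrent state
        trajectory.append(z)      # the prediction becomes the next input
    return trajectory

# Toy stand-in for the trained GRU: shifts every latent dimension by the action index.
def toy_step(z, a, h):
    return [v + a for v in z], h

print(imagine(toy_step, [0.0, 0.0], [], [1, 2, 3])[-1])  # [6.0, 6.0]
```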

Scheduled Sampling

A critical training detail: scheduled sampling gradually shifts the model from using ground-truth previous latents (teacher forcing) to using its own predicted latents. This mitigates exposure bias, where the model sees perfect inputs at train time but must consume its own imperfect predictions at inference time.

Why Scheduled Sampling Matters
Without scheduled sampling (pure teacher forcing), compounding errors cause the dynamics model to drift rapidly, exceeding 10 tiles of position error after 29 imagined steps. With scheduled sampling, the error stays roughly flat at 3.3 to 3.7 tiles over the same horizon.
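
One way to realise such a schedule is sketched below. The values 15 and 0.2 come from the hyperparameter registry (`decay_epochs`, `min_teacher_forcing`); the starting ratio of 1.0, the per-step coin flip, and the function names are assumptions for illustration.

```python
import random

def teacher_forcing_ratio(epoch, decay_epochs=15, min_ratio=0.2):
    """Linear decay from 1.0 at epoch 0 to min_ratio at decay_epochs, then constant."""
    progress = min(epoch / decay_epochs, 1.0)
    return 1.0 - progress * (1.0 - min_ratio)

def choose_input(z_true, z_pred, ratio, rng=random):
    """Feed the ground-truth latent with probability `ratio`, else the model's own prediction."""
    return z_true if rng.random() < ratio else z_pred

for epoch in (0, 5, 10, 15, 20):
    print(epoch, round(teacher_forcing_ratio(epoch), 3))
```

Under these assumptions the model still sees ground truth on about one step in five at the end of training, which matches the "20% min TF ratio" noted for the primary model.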

Architecture Details

RAM State Probes

To use the world model for planning, the agent needs a way to extract meaningful information from the latent space z. We train lightweight probes on frozen latents - they never modify the VAE or dynamics model.

Probe Architecture

What This Proves
A 98.7% map ID classification accuracy and a 1.23-tile coordinate error, both read from a 32-dim continuous latent, demonstrate that the VAE is capturing rich semantic information about game state - not just pixel statistics.
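
For reference, the two probe metrics can be computed as in the sketch below. It assumes the coordinate error is the mean of |Δx| + |Δy| over validation samples, which is how the "Manhattan Distance (MAE)" label reads; the function names are hypothetical.

```python
def map_accuracy(pred_ids, true_ids):
    """Fraction of samples whose predicted map ID matches the RAM value."""
    return sum(p == t for p, t in zip(pred_ids, true_ids)) / len(true_ids)

def mean_manhattan(pred_xy, true_xy):
    """Mean |dx| + |dy| in tiles between predicted and true player positions."""
    total = sum(abs(px - tx) + abs(py - ty)
                for (px, py), (tx, ty) in zip(pred_xy, true_xy))
    return total / len(true_xy)

# Made-up values, only to show the call shape.
acc = map_accuracy([1, 2, 3, 4], [1, 2, 3, 0])
err = mean_manhattan([(1, 1), (5, 5)], [(0, 0), (5, 7)])
```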

MPC Planner

The Model Predictive Control (MPC) planner is the brain of the v1 agent. Rather than reacting to observations, it looks ahead by simulating multiple action sequences through the dynamics model and selecting the best-scoring imagined future.

Planning Loop

pseudocode
For each candidate action sequence (length K):
    imagined_state = dynamics.rollout(z_current, sequence, K steps)
    (x_pred, y_pred, map_pred) = probe.decode(imagined_state)
    score = evaluate(x_pred, y_pred, map_pred, target)    # higher is better

best_sequence = argmax over candidates of score
Execute first action of best_sequence via frozen PPO

The evaluation function is deliberately simple: it rewards a small Manhattan distance to a target tile/map, so the argmax picks the sequence whose imagined end state is closest to the target. The contribution is the world model + imagination capability, not the sophistication of the planner.
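
The planning loop can be sketched in runnable form as follows. This is a toy under stated assumptions, not the project's planner: `toy_rollout` replaces the GRU rollout plus probe decode with exact grid moves, candidates are enumerated exhaustively (the source does not say how candidates are generated), y is assumed to grow downward, and the fixed wrong-map penalty of 100 is invented for the example.

```python
from itertools import product

UP, DOWN, LEFT, RIGHT = 0, 1, 2, 3
MOVES = {UP: (0, -1), DOWN: (0, 1), LEFT: (-1, 0), RIGHT: (1, 0)}

def toy_rollout(state, sequence):
    """Stand-in for dynamics.rollout + probe.decode: moves (x, y) on an open grid."""
    x, y, map_id = state
    for a in sequence:
        dx, dy = MOVES[a]
        x, y = x + dx, y + dy
    return x, y, map_id

def score(pred, target):
    """Negative Manhattan distance, minus an assumed penalty when the map differs."""
    (x, y, m), (tx, ty, tm) = pred, target
    return -(abs(x - tx) + abs(y - ty)) - (0 if m == tm else 100)

def plan(state, target, horizon=3, rollout=toy_rollout):
    """Score every action sequence of length `horizon`; return the best one's first action."""
    best = max(product(MOVES, repeat=horizon),
               key=lambda seq: score(rollout(state, seq), target))
    return best[0]

first = plan(state=(0, 0, 1), target=(3, 0, 1))   # RIGHT
```

Exhaustive enumeration costs 4^K rollouts for four movement actions, so a larger horizon or the full eight-button action space would call for sampling candidates instead.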

Results: VAE Training

| Metric | Value | Notes |
|---|---|---|
| Train Loss (Total) | 1258.41 | Recon: 1227.76, KL: 30.66 |
| Val Loss (Total) | 1255.95 | Recon: 1225.83, KL: 30.13 |
| Epochs | 15 | Converged before 15 |
| Latent Dim | 32 | Continuous Gaussian |

Note: Loss values are in raw MSE pixel space (not normalized). The train/val parity indicates no significant overfitting.

Results: RAM State Probes

| Task | Metric | Value |
|---|---|---|
| Map ID Classification | Validation Accuracy | 98.7% |
| Player Coordinate Decoding | Manhattan Distance (MAE) | 1.23 tiles |
| Val Loss (Combined) | Probe loss | 8.3260 |

Results: Dynamics Model

| Model | Val AR Loss | Val TF Loss | Notes |
|---|---|---|---|
| Scheduled Sampling (Primary) | 0.10314 | 0.03675 | 20% min TF ratio |
| Pure Teacher Forcing (Ablation) | 0.72547 | 0.01913 | TF=1.0 always |

Autoregressive (AR) loss is the meaningful metric - it measures performance when the model uses its own predictions as input, as it must do during imagination. The TF model has low teacher-forced loss but fails catastrophically in autoregressive rollout.
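
The gap between the two metrics is easy to reproduce on a toy problem. In the sketch below, a scalar "latent" grows by 1.0 per step and a hypothetical model under-predicts each step by 0.1. The teacher-forced loss only ever sees that single-step error, while the autoregressive loss accumulates it. All names and numbers here are illustrative.

```python
def teacher_forced_mse(model, traj):
    """Every prediction starts from the ground-truth previous value."""
    errs = [(model(traj[t]) - traj[t + 1]) ** 2 for t in range(len(traj) - 1)]
    return sum(errs) / len(errs)

def autoregressive_mse(model, traj):
    """Only the first value is ground truth; later predictions build on earlier ones."""
    z, errs = traj[0], []
    for target in traj[1:]:
        z = model(z)
        errs.append((z - target) ** 2)
    return sum(errs) / len(errs)

traj = [float(t) for t in range(30)]   # ground truth: +1.0 per step, 29 transitions
model = lambda z: z + 0.9              # toy model with a small per-step bias

tf_loss = teacher_forced_mse(model, traj)
ar_loss = autoregressive_mse(model, traj)
```

Here the teacher-forced loss is about 0.01 while the autoregressive loss is about 2.95, so a model can look accurate under teacher forcing and still be unusable for multi-step imagination.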

Results: Rollout Drift Analysis

This is the key ablation. Over 14,564 validation trajectories of length 29, we compare the scheduled sampling model vs. the pure teacher forcing ablation as both are rolled out autoregressively.

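| Step k | SS Latent MSE | TF Latent MSE | SS Tile Error | TF Tile Error | Ratio (TF / SS tile error) |
|---|---|---|---|---|---|
| 1 | 0.09268 | 0.12240 | 3.72 tiles | 4.06 tiles | 1.09× |
| 5 | 0.08751 | 0.23061 | 3.32 tiles | 5.08 tiles | 1.53× |
| 10 | 0.08842 | 0.40189 | 3.30 tiles | 6.47 tiles | 1.96× |
| 15 | 0.09238 | 0.59876 | 3.33 tiles | 7.81 tiles | 2.35× |
| 20 | 0.09830 | 0.78455 | 3.36 tiles | 9.13 tiles | 2.72× |
| 25 | 0.10606 | 0.92376 | 3.42 tiles | 9.72 tiles | 2.84× |
| 29 | 0.11197 | 1.04063 | 3.47 tiles | 10.44 tiles | 3.01× |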

The SS model's tile error stays roughly flat (3.30 to 3.72 tiles) across all 29 steps. The TF ablation degrades from 4.06 tiles at step 1 to 10.44 tiles at step 29, a 2.6× increase that leaves it at 3× the SS model's error by the end of the rollout.
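
The derived figures can be checked directly from the tile-error columns of the table above. The short script below recomputes the Ratio column as TF tile error divided by SS tile error, along with the growth of the TF error over the rollout; variable names are illustrative.

```python
# (step, SS tile error, TF tile error), copied from the drift table
drift = [(1, 3.72, 4.06), (5, 3.32, 5.08), (10, 3.30, 6.47), (15, 3.33, 7.81),
         (20, 3.36, 9.13), (25, 3.42, 9.72), (29, 3.47, 10.44)]

# Ratio column: TF error relative to SS error at the same step
ratios = {k: round(tf / ss, 2) for k, ss, tf in drift}

# How much the TF ablation's own error grows between step 1 and step 29
tf_growth = drift[-1][2] / drift[0][2]

# Smallest and largest SS tile error across the rollout
ss_spread = (min(ss for _, ss, _ in drift), max(ss for _, ss, _ in drift))
```

The recomputed ratios match the table (1.09× at step 1 up to 3.01× at step 29), and the TF error grows by a factor of about 2.57 over the rollout.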

Full Hyperparameter Registry

yaml
# VAE
vae:
  latent_dim: 32
  beta: 1.0
  lr: 1.0e-3
  batch_size: 128
  epochs: 15
  input_resolution: [40, 36]   # WxH (160x144 Game Boy screen downsampled 4x)

# Dynamics (Scheduled Sampling)
dynamics:
  hidden_dim: 256               # GRU hidden size
  action_embed_dim: 16
  seq_len: 30                   # BPTT rollout steps
  decay_epochs: 15              # linear SS decay
  min_teacher_forcing: 0.2
  lr: 1.0e-3
  batch_size: 128
  epochs: 20

# Probes
probes:
  input_dim: 32                 # frozen latent dim
  hidden_dim: 64                # per-task head
  lr: 1.0e-3
  batch_size: 128
  epochs: 10

State & Action Schema

The symbolic state representation extracted from WRAM at each emulator tick:

python
state = {
    'map_id':      int,          # raw map ID byte (248 locations)
    'x':           int,          # player tile X coordinate
    'y':           int,          # player tile Y coordinate
    'facing':      int,          # 0=down, 4=up, 8=left, 12=right
    'in_battle':   bool,         # True when battle is active
    'dialog_open': bool,         # True when text box is showing
    'badges':      int,          # bitmask of 8 gym badges
    'party_hp':    list[int],    # current HP per party slot
    'party_max_hp':list[int],    # max HP per party slot
}

# Action space: 8 discrete buttons, indexed 0-7
ACTIONS = {'UP': 0, 'DOWN': 1, 'LEFT': 2, 'RIGHT': 3, 'A': 4, 'B': 5, 'START': 6, 'SELECT': 7}
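
A couple of small helpers show how fields of this schema might be consumed. The helper names and the example values are made up for illustration; only the field names and their meanings come from the schema above.

```python
def badge_count(badges: int) -> int:
    """Number of set bits in the 8-bit gym badge mask."""
    return bin(badges & 0xFF).count("1")

def party_hp_fraction(party_hp, party_max_hp) -> float:
    """Total current HP divided by total max HP (0.0 for an empty party)."""
    total_max = sum(party_max_hp)
    return sum(party_hp) / total_max if total_max else 0.0

# Hypothetical snapshot, not a real game state.
example = {
    'map_id': 40, 'x': 5, 'y': 3, 'facing': 4,
    'in_battle': False, 'dialog_open': False,
    'badges': 0b00000101,
    'party_hp': [20, 0], 'party_max_hp': [25, 15],
}

n_badges = badge_count(example['badges'])
hp_frac = party_hp_fraction(example['party_hp'], example['party_max_hp'])
```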