Overview & Motivation
v2 is the central contribution of PokéDreamer: a full Dreamer-style discrete RSSM world model operating at native Game Boy resolution. The key advances over v1:
- 4× resolution increase per axis (16× more pixels): 160×144 vs. 40×36, i.e. native Game Boy pixels, preserving text, sprite details, and UI elements
- Discrete latents: 32 categories × 32 classes → 1024-dim one-hot encoding (vs. continuous ℝ³²)
- Full RSSM: Deterministic GRU path + stochastic categorical path trained jointly
- Prior imagination: The prior network enables pure latent rollouts with no emulator access
- Multi-task decoders: Pixel reconstruction + reward prediction + continue probability
Residual CNN Encoder
The encoder (src/models.py · Encoder) maps a raw 160×144×3 Game Boy frame to a 512-dimensional embedding vector. It uses 4 downsampling convolutional layers with residual blocks at each scale.
Architecture
```
Input:  (B, 3, 144, 160)
conv1:  Conv2d(3→32,    k=4, s=2, p=1) → (B, 32, 72, 80)
res1:   ResidualBlock(32)
conv2:  Conv2d(32→64,   k=4, s=2, p=1) → (B, 64, 36, 40)
res2:   ResidualBlock(64)
conv3:  Conv2d(64→128,  k=4, s=2, p=1) → (B, 128, 18, 20)
res3:   ResidualBlock(128)
conv4:  Conv2d(128→256, k=4, s=2, p=1) → (B, 256, 9, 10)
res4:   ResidualBlock(256)
flatten → (B, 23040)
fc:     Linear(23040, 512)
Output: (B, 512) ← encoder embedding e_t
```
Each ResidualBlock consists of two 3×3 convolutions with BatchNorm and ReLU, plus a skip connection. The skip path helps preserve high-frequency spatial details (text, sprite outlines) that a plain stack of strided convolutions tends to blur.
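The spatial sizes in the listing above follow from the standard convolution output formula, out = floor((in + 2p - k) / s) + 1. This stdlib-only sketch is a quick sanity check. It traces the four stride-2 convolutions from 144×160 down to 9×10 and recovers the 23040-dim flatten size. The helpers `conv2d_out` and `encoder_shapes` are illustrative and not part of src/models.py. The sketch also assumes the residual blocks preserve shape (3×3 convolutions with padding 1).

```python
def conv2d_out(size: int, k: int = 4, s: int = 2, p: int = 1) -> int:
    """Output length of one spatial dim for Conv2d (dilation=1)."""
    return (size + 2 * p - k) // s + 1


def encoder_shapes(h: int = 144, w: int = 160, channels=(32, 64, 128, 256)):
    """Trace (C, H, W) through the four stride-2 convs.

    Residual blocks are assumed shape-preserving, so only the convs change H and W.
    """
    shapes = []
    for c in channels:
        h, w = conv2d_out(h), conv2d_out(w)
        shapes.append((c, h, w))
    return shapes


shapes = encoder_shapes()
for c, h, w in shapes:
    print(f"(B, {c}, {h}, {w})")
c, h, w = shapes[-1]
print("flatten:", c * h * w)
```

Running it prints the four intermediate shapes from the listing, followed by `flatten: 23040`.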
RSSM Architecture
The Recurrent State-Space Model (src/models.py · RSSMCell) maintains two interacting components at each time step:
- Deterministic path h_t: A GRU hidden state (512-dim) that carries temporal context across steps
- Stochastic path s_t: Discrete categorical latents (32 × 32 = 1024-dim one-hot) that capture the stochastic aspects of the environment
RSSM Step (for time t, pseudocode):
```python
# 1. Deterministic update (always runs)
gru_input = concat([prev_s, action_onehot])      # (B, 1024 + 8)
h_t = GRUCell(gru_input, prev_h)                 # (B, 512)
# 2. Prior distribution (always computed: drives imagination and is the KL target)
prior_logits = prior_net(h_t)                    # (B, 1024)
# 3. Posterior distribution (training with real frames)
if e_t is not None:                              # encoder embedding available
    post_logits = post_net(concat([h_t, e_t]))   # (B, 1024)
    s_t = gumbel_softmax(post_logits)            # (B, 1024)
else:                                            # imagination mode
    s_t = gumbel_softmax(prior_logits)           # (B, 1024)
# Latent state for decoders
latent = concat([h_t, s_t])                      # (B, 512 + 1024 = 1536)
```
Prior vs. Posterior
The RSSM has two distributions over the stochastic state s_t:
| Distribution | Notation | Input | When Used |
|---|---|---|---|
| Prior | p(s_t \| h_t) | GRU hidden state h_t only | Imagination / generation (no real frames); also the KL target during training |
| Posterior | q(s_t \| h_t, e_t) | h_t + encoder embedding e_t | Training (real observations available) |
During training, the KL divergence between posterior and prior is minimized as part of the ELBO (with KL balancing, described under Training & Losses). At inference time, the prior enables pure imagination: future frames are generated from the prior without any emulator access.
Discrete Categorical Latents
v2 uses discrete categorical latents - 32 independent categorical variables, each with 32 classes. The total stochastic state is a 32×32 = 1024-dimensional one-hot vector.
Why Discrete?
- Categorical distributions can represent multi-modal or sharp distinctions (e.g., "in battle" vs. "overworld") more naturally than Gaussian
- Gradient flows through discrete variables via Gumbel-Softmax straight-through: forward pass uses hard one-hot samples, backward pass uses soft Gumbel-Softmax gradients
- DreamerV2 adopted categorical latents for its world model and reported that they outperformed Gaussian latents on Atari
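The sketch below makes the 32 × 32 = 1024 layout concrete. It is a minimal stdlib sketch and assumes the row-major ordering implied by `reshape(batch_size, category_num, class_num)`, so variable i with class j sits at flat index i·32 + j. The helpers `to_onehot` and `from_onehot` are illustrative and are not functions from src/models.py.

```python
CATEGORY_NUM = 32  # independent categorical variables
CLASS_NUM = 32     # classes per variable


def to_onehot(indices: list[int]) -> list[float]:
    """Flatten one chosen class per variable into a 1024-dim one-hot vector."""
    if len(indices) != CATEGORY_NUM or not all(0 <= c < CLASS_NUM for c in indices):
        raise ValueError("need one class index in [0, 32) per variable")
    vec = [0.0] * (CATEGORY_NUM * CLASS_NUM)
    for var, cls in enumerate(indices):
        vec[var * CLASS_NUM + cls] = 1.0  # row-major: variable-major, class-minor
    return vec


def from_onehot(vec: list[float]) -> list[int]:
    """Recover each variable's class (argmax inside its 32-wide block)."""
    out = []
    for var in range(CATEGORY_NUM):
        block = vec[var * CLASS_NUM:(var + 1) * CLASS_NUM]
        out.append(max(range(CLASS_NUM), key=block.__getitem__))
    return out


state = [(7 * v) % CLASS_NUM for v in range(CATEGORY_NUM)]
s_t = to_onehot(state)
print(len(s_t), sum(s_t))  # 1024 32.0
print(from_onehot(s_t) == state)  # True
```

Each variable contributes log2(32) = 5 bits, so the stochastic state can index 32^32 = 2^160 distinct configurations while remaining a sparse vector with exactly 32 ones.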
Gumbel-Softmax Sampling
```python
import torch.nn.functional as F

# From src/models.py · RSSMCell.sample_stochastic() (excerpt; batch_size is logits' leading dim)
logits = logits.reshape(batch_size, self.category_num, self.class_num)
# → (B, 32, 32): category_num variables × class_num classes
# Gumbel-Softmax with hard=True (straight-through), applied over the last dim (classes)
sample = F.gumbel_softmax(logits, tau=1.0, hard=True)
# Forward: one-hot argmax (discrete)
# Backward: Gumbel-Softmax gradients (continuous)
return sample.reshape(batch_size, self.stoch_dim)  # (B, 1024)
```
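The straight-through trick can be seen on a single categorical variable without any framework. This stdlib sketch uses our own helper `gumbel_softmax_st`, not the PyTorch function. It draws Gumbel noise g = -log(-log(u)), forms the soft sample softmax((logits + g) / τ), and takes its argmax as the hard one-hot. In an autograd framework the returned value would be `hard - soft.detach() + soft`. That expression equals `hard` in the forward pass while gradients flow through `soft`. Plain Python has no autograd, so the sketch just returns both.

```python
import math
import random


def softmax(xs: list[float]) -> list[float]:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def gumbel_softmax_st(logits: list[float], tau: float = 1.0, rng=random):
    """Straight-through Gumbel-Softmax forward pass for ONE categorical variable.

    Returns (hard, soft): `hard` is the one-hot forward value (hard=True);
    `soft` is the relaxed sample whose gradient an autograd framework would use.
    """
    gumbels = []
    for _ in logits:
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)  # keep u strictly inside (0, 1)
        gumbels.append(-math.log(-math.log(u)))        # Gumbel(0, 1) noise
    soft = softmax([(x + g) / tau for x, g in zip(logits, gumbels)])
    k = max(range(len(soft)), key=soft.__getitem__)    # argmax of the perturbed logits
    hard = [1.0 if i == k else 0.0 for i in range(len(soft))]
    return hard, soft


hard, soft = gumbel_softmax_st([2.0, 0.5, -1.0, 0.0], tau=1.0, rng=random.Random(0))
print(hard, round(sum(soft), 6))
```

Over many draws, the index of `hard` is distributed as softmax(logits) (the Gumbel-max property). Lowering τ pushes `soft` toward `hard`, which trades lower bias for higher-variance gradients.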
Decoder & Prediction Heads
The latent state concat(h_t, s_t) (1536-dim) feeds into three decoder heads trained jointly:
Pixel Decoder (Decoder)
Mirror of the encoder: transposed convolutions with residual blocks at each scale, ending in a sigmoid that produces a 3×144×160 reconstruction in [0, 1].
```
Input:   (B, 1536)
fc:      Linear(1536, 23040) → reshape (B, 256, 9, 10)
res4:    ResidualBlock(256)
deconv4: ConvTranspose2d(256→128, k=4, s=2, p=1) → (B, 128, 18, 20)
res3:    ResidualBlock(128)
deconv3: ConvTranspose2d(128→64,  k=4, s=2, p=1) → (B, 64, 36, 40)
res2:    ResidualBlock(64)
deconv2: ConvTranspose2d(64→32,   k=4, s=2, p=1) → (B, 32, 72, 80)
res1:    ResidualBlock(32)
deconv1: ConvTranspose2d(32→3,    k=4, s=2, p=1) → (B, 3, 144, 160)
sigmoid  → output in [0, 1]
```
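The decoder shapes only work out if each k=4, s=2 transposed convolution uses padding p=1. PyTorch documents the ConvTranspose2d output size as (in - 1)·stride - 2·padding + dilation·(kernel_size - 1) + output_padding + 1. This stdlib sketch applies that formula four times from 9×10. The helper names are illustrative, not taken from src/models.py.

```python
def conv_transpose2d_out(size: int, k: int = 4, s: int = 2, p: int = 1,
                         output_padding: int = 0, dilation: int = 1) -> int:
    """Output length of one spatial dim for ConvTranspose2d (PyTorch's formula)."""
    return (size - 1) * s - 2 * p + dilation * (k - 1) + output_padding + 1


def decoder_hw(h: int = 9, w: int = 10, p: int = 1, stages: int = 4):
    """Apply `stages` stride-2 transposed convs to an (h, w) feature map."""
    for _ in range(stages):
        h, w = conv_transpose2d_out(h, p=p), conv_transpose2d_out(w, p=p)
    return h, w


print(decoder_hw(p=1))  # (144, 160): matches the Game Boy frame exactly
print(decoder_hw(p=0))  # (174, 190): without padding every stage gives 2*in + 2
```

With p=1 each stage exactly doubles H and W, so the decoder inverts the encoder's four halvings.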
Reward Predictor (RewardPredictor)
MLP: Linear(1536, 256) → ELU → Linear(256, 1) - scalar reward prediction.
Continue Predictor (ContinuePredictor)
MLP: Linear(1536, 256) → ELU → Linear(256, 1) → Sigmoid - episode continuation probability (not-done mask).
Training & Losses
The RSSM world model is trained end-to-end with a combined multi-task loss:
Total Loss = Reconstruction Loss + KL-Balanced Loss + Reward Loss
1. Reconstruction Loss (pixel-wise BCE or MSE; written here as MSE):
L_recon = MSE(decoder(concat(h_t, s_t)), x_t)
2. KL-Balancing Loss (prior vs. posterior):
KL = α × KL(sg[posterior] || prior) + (1-α) × KL(posterior || sg[prior])
α = 0.8 (80% of the KL weight trains the prior, 20% trains the posterior)
3. Reward Loss:
L_reward = MSE(reward_pred(latent), r_t)
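For categorical latents, the KL term factorizes over the 32 independent variables: it is the sum of the per-variable KLs. The stdlib sketch below evaluates KL(q ‖ p) that way, along with the balanced combination. The helper names are ours, not the training script's. Because sg[·] only blocks gradients, both balanced terms have the same numeric value. In the forward pass the balanced loss therefore equals the plain KL, and α = 0.8 only changes how the gradient is split between prior and posterior. In PyTorch, sg would be written with `.detach()`.

```python
import math

Dist = list[list[float]]  # 32 variables × 32 class probabilities


def categorical_kl(q: Dist, p: Dist, eps: float = 1e-8) -> float:
    """KL(q || p) for factorized categoricals, summed over variables."""
    total = 0.0
    for q_var, p_var in zip(q, p):
        for qj, pj in zip(q_var, p_var):
            if qj > 0.0:  # terms with q_j = 0 contribute 0 by convention
                total += qj * (math.log(qj) - math.log(max(pj, eps)))
    return total


def balanced_kl(post: Dist, prior: Dist, alpha: float = 0.8) -> float:
    """Forward value of alpha*KL(sg[post] || prior) + (1-alpha)*KL(post || sg[prior]).

    Stop-gradient changes which side receives gradients, not the value,
    so without autograd the result equals KL(post || prior).
    """
    kl_prior_term = categorical_kl(post, prior)  # with autograd: gradient reaches the prior only
    kl_post_term = categorical_kl(post, prior)   # with autograd: gradient reaches the posterior only
    return alpha * kl_prior_term + (1.0 - alpha) * kl_post_term


uniform = [[1.0 / 32] * 32 for _ in range(32)]
peaked = [[0.9 if j == 0 else 0.1 / 31 for j in range(32)] for _ in range(32)]
print(categorical_kl(uniform, uniform))       # 0.0
print(categorical_kl(peaked, uniform) > 0.0)  # True
print(balanced_kl(peaked, uniform))
```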
Training Configuration
```bash
python scripts/train_rssm.py \
  --data-dir data \
  --epochs 12 \
  --batch-size 64 \
  --seq-len 15 \
  --out-dir checkpoints/rssm_v2
```
Dataset
The training dataset was collected using the PWhiddy PPO checkpoint as the behavior policy, running for 20 episodes on the PyBoy emulator.
| Property | Value |
|---|---|
| Format | NPZ files (transitions_*.npz) |
| Number of files | 20 |
| Transitions per file | ~800 |
| Total transitions | ~16,000 |
| Observation resolution | 160×144×3 (native Game Boy) |
| RAM fields | map_id, x, y, facing, in_battle, dialog_open, badges, party_hp |
| Total dataset size | ~340MB |
| Hugging Face | xxxTEMPESTxxx/PokeDreamer |
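Training consumes sequences of seq_len = 15 consecutive transitions rather than single frames, because the GRU path has to be unrolled through time. The loader's actual sampling scheme is not described here, so this sketch is illustration only. It counts the full length-15 windows that an 800-transition file (the table's approximate size) yields under two common schemes: non-overlapping chunks and a stride-1 sliding window. `window_starts` is a hypothetical helper. The sketch also assumes one file per episode, as the matching 20 episodes / 20 files counts suggest, so windows never cross a file boundary.

```python
def window_starts(num_transitions: int, seq_len: int = 15, stride: int = 15) -> list[int]:
    """Start indices of every full length-`seq_len` window inside one file.

    Windows never cross a file boundary, so (assuming one file per episode)
    each training sequence stays within a single episode.
    """
    if num_transitions < seq_len:
        return []
    return list(range(0, num_transitions - seq_len + 1, stride))


non_overlapping = window_starts(800, seq_len=15, stride=15)
sliding = window_starts(800, seq_len=15, stride=1)
print(len(non_overlapping), len(sliding))  # 53 786
print(20 * len(non_overlapping))           # 1060 chunks across 20 files
```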
Data Collection
```bash
python scripts/collect_data.py --episodes 20 --out-dir data
```
Training Results
Training was run for 4 of the 12 configured epochs (approx. 50 min/epoch on an RTX GPU), with batch size 64 and sequence length 15:
| Epoch | Train Loss | Train Recon | Train KL | Val Loss | Val Recon | Val KL |
|---|---|---|---|---|---|---|
| 1 | 0.1476 | 0.1379 | 0.0078 | 0.1266 | 0.1256 | 0.0010 |
| 2 | 0.1207 | 0.1144 | 0.0063 | 0.1172 | 0.1110 | 0.0062 |
| 3 | 0.1490 | 0.1068 | 0.0422 | 0.1228 | 0.1142 | 0.0086 |
| 4 | 0.1021 | 0.1015 | 0.0005 | 0.1651 | 0.1003 ★ | 0.0648 |
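Training reconstruction loss decreases every epoch (0.1379 → 0.1015). Validation reconstruction loss trends down as well, apart from a small rise at epoch 3 (0.1110 → 0.1142). The best checkpoint by validation reconstruction loss is epoch 4 (val recon = 0.1003), which produces pixel-level reconstructions at native Game Boy resolution. The same epoch has the highest total validation loss (0.1651), driven by a jump in validation KL to 0.0648. The KL term fluctuates by nearly two orders of magnitude across epochs (train 0.0005 to 0.0422, val 0.0010 to 0.0648). This is a commonly reported difficulty in categorical RSSM training, as the discrete latent space balances expressiveness against compressibility.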
Reconstruction Quality
Reconstruction grids are saved to checkpoints/rssm_v2/ during training. Each grid shows a row of real Game Boy frames alongside their RSSM reconstructions. The best epoch produces visually recognizable reconstructions of the overworld, buildings, and UI elements at native 160×144 resolution.
Imagination Demo
The v2 demo video is generated purely from the RSSM prior. Starting from a single seed frame, the model imagines the visual consequences of each action autoregressively, without any further emulator access.
```bash
# Generate side-by-side: real emulator (left) vs. RSSM imagination (right)
python scripts/generate_demo_video_v2.py \
  --checkpoint checkpoints/rssm_v2/best_world_model.pt \
  --save-state saves/intro_done.state \
  --out-video checkpoints/rssm_v2/side_by_side_demo_v2.mp4
```
In the output MP4, the left panel shows the real emulator stepping through the action sequence. The right panel shows the RSSM prior's imagination of the same sequence.
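The control flow of a prior-only rollout can be sketched independently of the model internals. In this minimal sketch, `encode`, `rssm_step`, and `decode` are hypothetical stand-ins for the encoder, RSSM cell, and pixel decoder described above. They are not the demo script's actual API. The `noop_action` used for the seed step is also an assumption. The seed frame is encoded once and passed to the posterior. Every later step calls the RSSM with no embedding, so s_t comes from the prior.

```python
from typing import Any, Callable, List, Optional, Sequence


def imagine(seed_frame: Any, actions: Sequence[int],
            encode: Callable[[Any], Any],
            rssm_step: Callable[[Any, int, Optional[Any]], Any],
            decode: Callable[[Any], Any],
            init_state: Any, noop_action: int = 0) -> List[Any]:
    """Return one imagined frame per action, using real data only for the seed."""
    # Seed: the only real observation goes through the posterior q(s_t | h_t, e_t).
    state = rssm_step(init_state, noop_action, encode(seed_frame))
    frames = []
    for a in actions:
        # No emulator access from here on: embed=None, so s_t comes from the prior.
        state = rssm_step(state, a, None)
        frames.append(decode(state))
    return frames


# Usage with stub components that just record which path was taken.
calls = []


def stub_step(state, action, embed):
    calls.append(("posterior" if embed is not None else "prior", action))
    return state + 1  # pretend the state is a step counter


frames = imagine("seed.png", [3, 1, 4], encode=lambda f: f, rssm_step=stub_step,
                 decode=lambda s: f"frame@{s}", init_state=0)
print(frames)  # ['frame@2', 'frame@3', 'frame@4']
print(calls)   # [('posterior', 0), ('prior', 3), ('prior', 1), ('prior', 4)]
```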
Hyperparameter Reference
```yaml
# RSSM Architecture
rssm:
  embed_dim: 512      # encoder output dim
  det_dim: 512        # GRU hidden size (h_t)
  class_num: 32       # classes per categorical variable
  category_num: 32    # number of categorical variables
  stoch_dim: 1024     # class_num × category_num
  action_dim: 8       # one-hot action size
  latent_dim: 1536    # det_dim + stoch_dim (for decoders)
# Training
training:
  batch_size: 64
  seq_len: 15         # BPTT sequence length
  lr: 3.0e-4
  kl_alpha: 0.8       # KL balancing weight (prior fraction)
  gumbel_temp: 1.0    # Gumbel-Softmax temperature
  epochs: 12          # (4 completed in this run)
# Hardware
hardware:
  device: RTX GPU
  time_per_epoch: "~50 min"
```
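Several entries in the RSSM reference are derived from others. This small stdlib check is hypothetical and not part of the repository. It makes those relationships explicit and returns the input widths implied by the RSSM step: the GRU consumes concat(prev_s, action_onehot), the posterior consumes concat(h_t, e_t), and the decoders consume concat(h_t, s_t).

```python
rssm = {
    "embed_dim": 512,
    "det_dim": 512,
    "class_num": 32,
    "category_num": 32,
    "stoch_dim": 1024,
    "action_dim": 8,
    "latent_dim": 1536,
}


def check_rssm_dims(cfg: dict) -> dict:
    """Validate derived sizes and return the input width each module should expect."""
    if cfg["stoch_dim"] != cfg["class_num"] * cfg["category_num"]:
        raise ValueError("stoch_dim must equal class_num * category_num")
    if cfg["latent_dim"] != cfg["det_dim"] + cfg["stoch_dim"]:
        raise ValueError("latent_dim must equal det_dim + stoch_dim")
    return {
        "gru_input": cfg["stoch_dim"] + cfg["action_dim"],     # concat(prev_s, action_onehot)
        "posterior_input": cfg["det_dim"] + cfg["embed_dim"],  # concat(h_t, e_t)
        "decoder_input": cfg["latent_dim"],                    # concat(h_t, s_t)
    }


print(check_rssm_dims(rssm))
# {'gru_input': 1032, 'posterior_input': 1024, 'decoder_input': 1536}
```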
Checkpoints & Files
| File | Description |
|---|---|
| `checkpoints/rssm_v2/best_world_model.pt` | Best validation checkpoint (epoch 4, val recon 0.1003) |
| `checkpoints/rssm_v2/rssm_epoch_*.pt` | Per-epoch snapshots |
| `checkpoints/rssm_v2/recon_grid_epoch_*.png` | Reconstruction visualization grids |
| `checkpoints/rssm_v2/side_by_side_demo_v2.mp4` | Imagination vs. real emulator demo video |
All checkpoints are also available on Hugging Face.