Completed & Archived ✓ v2/ · src/models.py

v2: Discrete RSSM World Model

The major leap: native 160×144 Game Boy resolution, a 4-layer residual CNN encoder, and a full Recurrent State-Space Model (RSSM) with 32×32 discrete categorical latents. Sampling from the RSSM's prior enables pure imagination rollouts decoded to full-resolution frames.

Overview & Motivation

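v2 is the central contribution of PokéDreamer: a full Dreamer-style discrete RSSM world model operating at native Game Boy resolution. Its key advances over v1 are native 160×144 input and reconstruction, a residual CNN encoder, and a full RSSM with 32×32 discrete categorical latents whose prior supports imagination rollouts.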

Dreamer Lineage
v2 is architecturally based on DreamerV2 (Hafner et al., 2021). Its discrete categorical latents are trained with straight-through Gumbel-Softmax estimation, a standard technique for backpropagating through samples from discrete distributions.

Residual CNN Encoder

The encoder (src/models.py · Encoder) maps a raw 160×144 RGB Game Boy frame, stored as a (B, 3, 144, 160) tensor, to a 512-dimensional embedding vector e_t. It uses 4 stride-2 convolutional layers, each followed by a residual block, so every stage halves the spatial resolution.

Architecture

text
Input:    (B, 3, 144, 160)

conv1:    Conv2d(3→32,    k=4, s=2, p=1) → (B, 32,  72, 80)
res1:     ResidualBlock(32)
conv2:    Conv2d(32→64,   k=4, s=2, p=1) → (B, 64,  36, 40)
res2:     ResidualBlock(64)
conv3:    Conv2d(64→128,  k=4, s=2, p=1) → (B, 128, 18, 20)
res3:     ResidualBlock(128)
conv4:    Conv2d(128→256, k=4, s=2, p=1) → (B, 256, 9,  10)
res4:     ResidualBlock(256)

flatten:  → (B, 23040)    # 256 × 9 × 10
fc:       Linear(23040, 512)

Output:   (B, 512)  ← encoder embedding e_t

Each ResidualBlock consists of two 3×3 convolutions with BatchNorm and ReLU, plus an identity skip connection that adds the block's input to its output. The skip path is intended to preserve high-frequency spatial detail (text, sprite outlines) that a plain stack of strided convolutions tends to blur away.
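The listing above translates into a short PyTorch module. This is a minimal sketch, not the code in src/models.py: the residual-block ordering (conv, BatchNorm, ReLU, conv, BatchNorm, add skip, ReLU) and the ReLU after each strided convolution are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    """Two 3x3 conv + BatchNorm layers with an identity skip; shape-preserving."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + x)  # skip connection adds the input back


class Encoder(nn.Module):
    """Maps (B, 3, 144, 160) frames to (B, embed_dim) embeddings."""

    def __init__(self, embed_dim: int = 512):
        super().__init__()
        layers, in_ch = [], 3
        for out_ch in (32, 64, 128, 256):
            # k=4, s=2, p=1 halves height and width: 144x160 -> 72x80 -> ... -> 9x10
            layers += [
                nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.ReLU(),  # assumed activation after each downsampling conv
                ResidualBlock(out_ch),
            ]
            in_ch = out_ch
        self.conv = nn.Sequential(*layers)
        self.fc = nn.Linear(256 * 9 * 10, embed_dim)  # 23040 -> 512

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.conv(x)              # (B, 256, 9, 10)
        return self.fc(features.flatten(1))  # (B, 512)


encoder = Encoder()
frames = torch.rand(2, 3, 144, 160)  # pixel values in [0, 1]
print(encoder(frames).shape)  # torch.Size([2, 512])
```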

RSSM Architecture

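The Recurrent State-Space Model (src/models.py · RSSMCell) maintains two interacting state components at each time step: a deterministic recurrent state h_t (the GRU hidden state) and a stochastic categorical state s_t.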

python
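import torch

# One RSSM step at time t, written as an RSSMCell method.
# Attribute names (self.gru, self.prior_net, self.post_net) are illustrative.
def step(self, prev_h, prev_s, action_onehot, e_t=None):
    # 1. Deterministic update (always runs)
    gru_input = torch.cat([prev_s, action_onehot], dim=-1)  # (B, 1024 + 8)
    h_t = self.gru(gru_input, prev_h)                       # (B, 512)

    # 2. Prior p(s_t | h_t): used for imagination / prior rollouts
    prior_logits = self.prior_net(h_t)                      # (B, 1024)

    # 3. Posterior q(s_t | h_t, e_t): used in training, when a real frame is encoded
    if e_t is not None:
        post_logits = self.post_net(torch.cat([h_t, e_t], dim=-1))  # (B, 1024)
        s_t = self.sample_stochastic(post_logits)   # per-variable sample, (B, 1024)
    else:  # imagination mode: no observation, sample from the prior
        post_logits = None
        s_t = self.sample_stochastic(prior_logits)  # per-variable sample, (B, 1024)

    # Latent state for the decoder heads
    latent = torch.cat([h_t, s_t], dim=-1)  # (B, 512 + 1024 = 1536)
    return h_t, s_t, prior_logits, post_logits, latent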

Prior vs. Posterior

The RSSM has two distributions over the stochastic state s_t:

| Distribution | Notation | Input | When used |
|---|---|---|---|
| Prior | p(s_t \| h_t) | GRU hidden state h_t only | Imagination / generation (no real frames) |
| Posterior | q(s_t \| h_t, e_t) | h_t + encoder embedding e_t | Training (real observations available) |

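During training, we minimize the KL divergence between posterior and prior as part of the ELBO, which teaches the prior to predict what the posterior infers from real frames. At inference time, this lets the prior drive pure imagination: future frames are generated by sampling from the prior, without any emulator access.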
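For categorical latents, the KL term can be computed in closed form from the logits. The sketch below assumes the exact (analytic) KL, summed over the 32 variables; `categorical_kl` is a hypothetical helper name, and the project may compute this differently (for example via `torch.distributions`).

```python
import torch
import torch.nn.functional as F


def categorical_kl(post_logits, prior_logits, num_vars=32, num_classes=32):
    """KL(q || p) for 32 independent categoricals, summed over the variables.

    Both inputs are flat (B, 1024) logits; the result has shape (B,).
    """
    batch = post_logits.shape[0]
    q_log = F.log_softmax(post_logits.reshape(batch, num_vars, num_classes), dim=-1)
    p_log = F.log_softmax(prior_logits.reshape(batch, num_vars, num_classes), dim=-1)
    # sum_k q_k * (log q_k - log p_k) per variable, then sum over the 32 variables
    return (q_log.exp() * (q_log - p_log)).sum(dim=(-2, -1))


post = torch.randn(2, 1024)
prior = torch.randn(2, 1024)
print(categorical_kl(post, prior))  # one non-negative KL value per batch element
```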

Discrete Categorical Latents

v2 uses discrete categorical latents: 32 independent categorical variables, each with 32 classes. The stochastic state s_t is the concatenation of 32 one-hot vectors, flattened to a 32×32 = 1024-dimensional vector with exactly 32 non-zero entries.

Why Discrete?

Gumbel-Softmax Sampling

python
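import torch.nn.functional as F

# From src/models.py · RSSMCell.sample_stochastic()
def sample_stochastic(self, logits):
    batch_size = logits.shape[0]
    # Split the flat logits into 32 categorical variables × 32 classes
    logits = logits.reshape(batch_size, self.category_num, self.class_num)  # (B, 32, 32)

    # Gumbel-Softmax with hard=True (straight-through), applied per variable
    sample = F.gumbel_softmax(logits, tau=1.0, hard=True, dim=-1)
    # Forward: one-hot argmax of the Gumbel-perturbed logits (discrete)
    # Backward: gradients of the soft Gumbel-Softmax relaxation (continuous)

    return sample.reshape(batch_size, self.stoch_dim)  # (B, 1024)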
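
A quick way to see the straight-through behavior is to sample once and inspect both passes. This standalone check is not part of src/models.py; it reuses the same reshape and `F.gumbel_softmax(..., hard=True)` call, and the random downstream weights are there only because the gradient of a plain sum of softmax outputs is zero.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, num_vars, num_classes = 4, 32, 32
logits = torch.randn(B, num_vars * num_classes, requires_grad=True)

# Same operation as sample_stochastic(): reshape, then hard Gumbel-Softmax per variable
sample = F.gumbel_softmax(logits.reshape(B, num_vars, num_classes), tau=1.0, hard=True, dim=-1)
flat = sample.reshape(B, num_vars * num_classes)  # (B, 1024)

# Forward pass: each of the 32 variables has exactly one active class
assert torch.all((sample != 0).sum(dim=-1) == 1)
assert torch.allclose(flat.sum(dim=-1), torch.full((B,), 32.0))

# Backward pass: gradients still reach the logits through the soft relaxation
weights = torch.randn_like(flat)
(flat * weights).sum().backward()
print(logits.grad.abs().sum() > 0)
```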

Decoder & Prediction Heads

The latent state concat(h_t, s_t) (1536-dim) feeds into three decoder heads trained jointly:

Pixel Decoder (Decoder)

Mirrors the encoder: transposed convolutions with a residual block at each scale, followed by a sigmoid so the 3×144×160 reconstruction lies in [0, 1].

text
Input:    (B, 1536)
fc:       Linear(1536, 23040) → reshape (B, 256, 9, 10)
res4:     ResidualBlock(256)
deconv4:  ConvTranspose2d(256→128, k=4, s=2, p=1) → (B, 128, 18, 20)
res3:     ResidualBlock(128)
deconv1:  ConvTranspose2d(128→64,  k=4, s=2, p=1) → (B, 64,  36, 40)
res2:     ResidualBlock(64)
deconv2:  ConvTranspose2d(64→32,   k=4, s=2, p=1) → (B, 32,  72, 80)
res1:     ResidualBlock(32)
deconv3:  ConvTranspose2d(32→3,    k=4, s=2, p=1) → (B, 3,  144, 160)
sigmoid → output in [0, 1]
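
A minimal PyTorch sketch of the decoder. The listing gives only kernel size and stride; `padding=1` is assumed because it is what makes each ConvTranspose2d exactly double the spatial size (output = 2 × (in − 1) − 2 + 4 = 2 × in). The ReLU after each upsampling layer and the residual-block layout are also assumptions, not confirmed details of src/models.py.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    """Two 3x3 conv + BatchNorm layers with an identity skip; shape-preserving."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = F.relu(self.bn1(self.conv1(x)))
        return F.relu(self.bn2(self.conv2(out)) + x)


class Decoder(nn.Module):
    """Maps the (B, 1536) latent concat(h_t, s_t) to a (B, 3, 144, 160) frame in [0, 1]."""

    def __init__(self, latent_dim: int = 1536):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 256 * 9 * 10)
        self.res4 = ResidualBlock(256)
        self.up = nn.Sequential(
            # each k=4, s=2, p=1 transposed conv doubles height and width
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), nn.ReLU(), ResidualBlock(128),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), nn.ReLU(), ResidualBlock(64),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1), nn.ReLU(), ResidualBlock(32),
            nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1),
        )

    def forward(self, latent: torch.Tensor) -> torch.Tensor:
        h = self.fc(latent).reshape(-1, 256, 9, 10)  # (B, 256, 9, 10)
        return torch.sigmoid(self.up(self.res4(h)))  # (B, 3, 144, 160)


decoder = Decoder()
recon = decoder(torch.randn(2, 1536))
```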

Reward Predictor (RewardPredictor)

MLP: Linear(1536, 256) → ELU → Linear(256, 1), producing a scalar reward prediction.

Continue Predictor (ContinuePredictor)

MLP: Linear(1536, 256) → ELU → Linear(256, 1) → Sigmoid, producing the probability that the episode continues (a not-done mask).
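Both heads are small MLPs over the same 1536-dim latent. In this sketch the trailing size-1 dimension is squeezed away so each head returns one value per batch element; that squeeze is an assumption about the interface, not something the text specifies.

```python
import torch
import torch.nn as nn


class RewardPredictor(nn.Module):
    """latent (B, 1536) -> predicted reward (B,)."""

    def __init__(self, latent_dim: int = 1536, hidden: int = 256):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(latent_dim, hidden), nn.ELU(), nn.Linear(hidden, 1))

    def forward(self, latent: torch.Tensor) -> torch.Tensor:
        return self.net(latent).squeeze(-1)


class ContinuePredictor(nn.Module):
    """latent (B, 1536) -> probability that the episode continues (B,)."""

    def __init__(self, latent_dim: int = 1536, hidden: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden), nn.ELU(), nn.Linear(hidden, 1), nn.Sigmoid()
        )

    def forward(self, latent: torch.Tensor) -> torch.Tensor:
        return self.net(latent).squeeze(-1)


latent = torch.randn(4, 1536)
reward = RewardPredictor()(latent)        # unbounded scalar per sample
p_continue = ContinuePredictor()(latent)  # probability in [0, 1]
```

Because the continue head already ends in a sigmoid, a matching training loss would be `nn.BCELoss` on 0/1 not-done targets; if the head returned raw logits instead, `nn.BCEWithLogitsLoss` would be the numerically safer choice.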

Training & Losses

The RSSM world model is trained end-to-end with a combined multi-task loss:

text
Total Loss = Reconstruction Loss + KL-Balanced Loss + Reward Loss

1. Reconstruction Loss (BCE / MSE on pixels):
   L_recon = MSE(decoder(concat(h_t, s_t)), x_t)

2. KL-Balancing Loss (prior vs. posterior):
   KL = α × KL(sg[posterior] || prior) + (1-α) × KL(posterior || sg[prior])
   α = 0.8  (80% trains prior, 20% trains posterior)
   
3. Reward Loss:
   L_reward = MSE(reward_pred(latent), r_t)
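
KL Balancing Rationale
In the standard ELBO, the KL term pulls the posterior and the prior toward each other with equal weight, so the posterior can be regularized toward a prior that is still poorly trained. KL balancing with α = 0.8 directs most of the gradient toward improving the prior, which is what imagination relies on, while the posterior receives lighter updates (weight 1 − α = 0.2).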
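The balanced KL term can be implemented with `detach()` playing the role of the stop-gradient sg[·]. In this sketch (helper names are hypothetical), both terms have the same numerical value, so the loss equals the plain KL; only the gradients differ, with the prior receiving α = 0.8 of the KL gradient and the posterior 1 − α = 0.2. Averaging over the batch is an assumption; for (B, T) sequence batches, flatten B and T first.

```python
import torch
import torch.nn.functional as F


def categorical_kl(q_logits: torch.Tensor, p_logits: torch.Tensor) -> torch.Tensor:
    """KL(q || p) for (B, 32, 32) logits, summed over the 32 variables -> (B,)."""
    q_log = F.log_softmax(q_logits, dim=-1)
    p_log = F.log_softmax(p_logits, dim=-1)
    return (q_log.exp() * (q_log - p_log)).sum(dim=(-2, -1))


def kl_balanced_loss(post_logits, prior_logits, alpha: float = 0.8) -> torch.Tensor:
    """alpha * KL(sg[q] || p) + (1 - alpha) * KL(q || sg[p]), averaged over the batch."""
    kl_train_prior = categorical_kl(post_logits.detach(), prior_logits)  # grads reach prior only
    kl_train_post = categorical_kl(post_logits, prior_logits.detach())   # grads reach posterior only
    return (alpha * kl_train_prior + (1 - alpha) * kl_train_post).mean()


post_logits = torch.randn(8, 32, 32, requires_grad=True)
prior_logits = torch.randn(8, 32, 32, requires_grad=True)
loss = kl_balanced_loss(post_logits, prior_logits)
loss.backward()  # prior gets 0.8x and posterior 0.2x of the plain KL gradient
```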

Training Configuration

bash
python scripts/train_rssm.py \
    --data-dir data \
    --epochs 12 \
    --batch-size 64 \
    --seq-len 15 \
    --out-dir checkpoints/rssm_v2

Dataset

The training dataset was collected with the PWhiddy PPO checkpoint as the behavior policy, running 20 episodes in the PyBoy emulator.

| Property | Value |
|---|---|
| Format | NPZ files (transitions_*.npz) |
| Number of files | 20 |
| Transitions per file | ~800 |
| Total transitions | ~16,000 |
| Observation resolution | 160×144×3 (native Game Boy) |
| RAM fields | map_id, x, y, facing, in_battle, dialog_open, badges, party_hp |
| Total dataset size | ~340 MB |
| Hugging Face | xxxTEMPESTxxx/PokeDreamer |
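
With --seq-len 15, training consumes fixed-length windows of consecutive transitions. The sketch below shows one way to build them from the NPZ files. The array name `obs`, the use of non-overlapping windows, and the assumption that each file holds contiguous transitions (plausibly one episode each) are all hypothetical; check `np.load(path).files` for the real keys. At ~800 transitions per file this gives roughly 53 windows per file.

```python
import glob
import os

import numpy as np


def load_sequences(data_dir: str, seq_len: int = 15, key: str = "obs") -> np.ndarray:
    """Cut every transitions_*.npz file into non-overlapping windows of seq_len steps.

    `key` is a hypothetical array name. Windows never cross file boundaries,
    so each sequence stays inside one file's rollout.
    """
    windows = []
    for path in sorted(glob.glob(os.path.join(data_dir, "transitions_*.npz"))):
        with np.load(path) as data:
            arr = data[key]                       # (T, ...) for this file
        n = len(arr) // seq_len                   # drop the ragged tail
        if n:
            windows.append(arr[: n * seq_len].reshape(n, seq_len, *arr.shape[1:]))
    if not windows:
        raise FileNotFoundError(f"no usable transitions_*.npz files in {data_dir}")
    return np.concatenate(windows)                # (N, seq_len, ...)
```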

Data Collection

bash
python scripts/collect_data.py --episodes 20 --out-dir data

Training Results

Of the 12 configured epochs, 4 were completed (approx. 50 min/epoch on an RTX GPU), with batch size 64 and sequence length 15:

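| Epoch | Train loss | Train recon | Train KL | Val loss | Val recon | Val KL |
|---|---|---|---|---|---|---|
| 1 | 0.1476 | 0.1379 | 0.0078 | 0.1266 | 0.1256 | 0.0010 |
| 2 | 0.1207 | 0.1144 | 0.0063 | 0.1172 | 0.1110 | 0.0062 |
| 3 | 0.1490 | 0.1068 | 0.0422 | 0.1228 | 0.1142 | 0.0086 |
| 4 | 0.1021 | 0.1015 | 0.0005 | 0.1651 | 0.1003 ★ | 0.0648 |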

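Training reconstruction loss decreases every epoch (0.1379 → 0.1015), and validation reconstruction loss falls overall (0.1256 → 0.1003) despite a small rise at epoch 3. The best checkpoint, marked ★, is the one with the lowest validation reconstruction loss (epoch 4, 0.1003) and produces pixel-level reconstructions at native Game Boy resolution. Note that epoch 4 also has the highest total validation loss (0.1651), driven by a jump in validation KL to 0.0648. The KL term fluctuates from epoch to epoch (train KL ranges from 0.0005 to 0.0422), a commonly reported difficulty with categorical RSSMs, where training must balance how much information the posterior encodes against how well the prior can predict it.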

Reconstruction Quality

Reconstruction grids are saved to checkpoints/rssm_v2/ during training; each grid shows a row of real Game Boy frames alongside their RSSM reconstructions. The best epoch produces visually recognizable reconstructions of the overworld, buildings, and UI elements at native 160×144 resolution.

Imagination Demo

The v2 demo video generates frames purely from the RSSM prior: starting from a single seed frame, the model autoregressively imagines the visual consequences of each action, with no further emulator access.

bash
# Generate side-by-side: real emulator (left) vs. RSSM imagination (right)
python scripts/generate_demo_video_v2.py \
    --checkpoint checkpoints/rssm_v2/best_world_model.pt \
    --save-state saves/intro_done.state \
    --out-video checkpoints/rssm_v2/side_by_side_demo_v2.mp4

In the output MP4, the left panel shows the real emulator stepping through the actions and the right panel shows the RSSM prior's imagination of the same action sequence.
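The imagination loop is the RSSM step with no encoder embedding, repeated once per action. This is a hedged sketch, not scripts/generate_demo_video_v2.py: it assumes the seed frame has already been passed through one posterior step to give (h, s), samples from the prior rather than taking its mode, and takes the GRU, prior network, and decoder as plain callables. The tiny stand-in modules at the bottom are hypothetical placeholders for the trained ones.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


@torch.no_grad()
def imagine(gru, prior_net, decode, h, s, actions, num_vars=32, num_classes=32):
    """Autoregressive prior rollout.

    h: (B, det_dim) and s: (B, stoch_dim) summarize the seed frame;
    actions: (T, B, action_dim) one-hot. Returns decoded outputs stacked as (T, B, ...).
    """
    frames = []
    for a_t in actions:                                              # one step per action
        h = gru(torch.cat([s, a_t], dim=-1), h)                      # deterministic update
        logits = prior_net(h).reshape(-1, num_vars, num_classes)     # prior p(s_t | h_t)
        s = F.gumbel_softmax(logits, tau=1.0, hard=True).flatten(1)  # sample; no e_t needed
        frames.append(decode(torch.cat([h, s], dim=-1)))             # decode concat(h_t, s_t)
    return torch.stack(frames)


# Hypothetical stand-ins with the sizes from this page (a real run would load the checkpoint)
gru = nn.GRUCell(1024 + 8, 512)
prior_net = nn.Linear(512, 1024)
decode = nn.Linear(1536, 10)  # placeholder for the pixel decoder
actions = F.one_hot(torch.randint(0, 8, (5, 2)), num_classes=8).float()  # T=5, B=2
frames = imagine(gru, prior_net, decode, torch.zeros(2, 512), torch.zeros(2, 1024), actions)
```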

Hyperparameter Reference

yaml
# RSSM Architecture
rssm:
  embed_dim:     512      # encoder output dim
  det_dim:       512      # GRU hidden size (h_t)
  class_num:     32       # classes per variable
  category_num:  32       # categorical variables
  stoch_dim:     1024     # class_num × category_num
  action_dim:    8        # one-hot action size
  latent_dim:    1536     # det_dim + stoch_dim (for decoders)

# Training
training:
  batch_size:    64
  seq_len:       15       # BPTT sequence length
  lr:            3.0e-4
  kl_alpha:      0.8      # KL balancing weight (prior fraction)
  gumbel_temp:   1.0      # Gumbel-Softmax temperature
  epochs:        12       # (4 completed in this run)
  
# Hardware
hardware:
  device:        RTX GPU
  time_per_epoch: ~50 min
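
Several sizes in this YAML are derived from others, so a config object can compute them instead of storing them. `RSSMConfig` below is a hypothetical helper, not part of the project; it only reproduces the numbers quoted throughout this page.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RSSMConfig:
    embed_dim: int = 512          # encoder output e_t
    det_dim: int = 512            # GRU hidden size h_t
    category_num: int = 32        # number of categorical variables
    class_num: int = 32           # classes per variable
    action_dim: int = 8           # one-hot action size
    frame_hw: tuple = (144, 160)  # (height, width) of a Game Boy frame
    conv_channels: int = 256      # channels after the last encoder conv
    num_downsamples: int = 4      # stride-2 convs in the encoder

    @property
    def stoch_dim(self) -> int:
        return self.category_num * self.class_num  # 32 * 32 = 1024

    @property
    def latent_dim(self) -> int:
        return self.det_dim + self.stoch_dim  # 512 + 1024 = 1536

    @property
    def gru_input_dim(self) -> int:
        return self.stoch_dim + self.action_dim  # 1024 + 8 = 1032

    @property
    def post_input_dim(self) -> int:
        return self.det_dim + self.embed_dim  # concat(h_t, e_t) = 1024

    @property
    def conv_flat_dim(self) -> int:
        h, w = (d // 2 ** self.num_downsamples for d in self.frame_hw)
        return self.conv_channels * h * w  # 256 * 9 * 10 = 23040


cfg = RSSMConfig()
print(cfg.stoch_dim, cfg.latent_dim, cfg.conv_flat_dim)  # 1024 1536 23040
```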

Checkpoints & Files

| File | Description |
|---|---|
| checkpoints/rssm_v2/best_world_model.pt | Best validation checkpoint (epoch 4, val recon 0.1003) |
| checkpoints/rssm_v2/rssm_epoch_*.pt | Per-epoch snapshots |
| checkpoints/rssm_v2/recon_grid_epoch_*.png | Reconstruction visualization grids |
| checkpoints/rssm_v2/side_by_side_demo_v2.mp4 | Imagination vs. real emulator demo video |

All checkpoints are also available on Hugging Face.

Next: v3 Roadmap →