Code Reference

Complete API documentation for every module in the PokéDreamer codebase - models, dataset, game state extraction, RAM addresses, and training scripts.

src/models.py

Contains all neural network modules for the v2 RSSM world model system. All classes are torch.nn.Module subclasses. Import with:

python
from src.models import (
    ResidualBlock, Encoder, Decoder,
    RSSMCell, RewardPredictor, ContinuePredictor,
    Actor, Critic
)
ResidualBlock(channels)
Standard convolutional residual block: two 3×3 Conv2d layers with BatchNorm and ReLU, plus a skip connection. Used inside both Encoder and Decoder.
__init__(self, channels: int)
Initializes two conv layers and two BatchNorm layers.
channels (int): Number of input and output channels (they must match for the skip connection)
forward(self, x: Tensor) → Tensor
Computes relu(x + BN2(conv2(relu(BN1(conv1(x)))))). Preserves spatial dimensions and channel count.
Encoder(embed_dim=512)
4-stage residual CNN that maps native-resolution Game Boy frames (160×144 pixels, RGB) to a 512-dim embedding vector. Used by the RSSM to compute the posterior distribution from real observations.
__init__(self, embed_dim: int = 512)
Builds 4 downsampling convolutional layers (k=4, s=2) with residual blocks at each stage, plus a final linear projection.
embed_dim (int): Output embedding dimension. Default: 512
forward(self, x: Tensor) → Tensor
Maps pixel frames to embeddings.
x (Tensor): Shape (B, 3, 144, 160). Values in [0, 1].
returns (Tensor): Shape (B, embed_dim). Encoder embedding e_t.
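The size of the final feature map follows from the standard Conv2d output-length formula. Below is a minimal sketch. It assumes padding=1, which is not stated above; under that assumption each k=4, s=2 stage halves both axes.

```python
def conv_out(size, kernel=4, stride=2, padding=1):
    """Output length of a Conv2d along one axis (dilation 1). padding=1 is an assumption."""
    return (size + 2 * padding - kernel) // stride + 1


def encoder_feature_hw(height=144, width=160, stages=4):
    """Apply the same k=4, s=2 downsampling `stages` times to each spatial axis."""
    for _ in range(stages):
        height, width = conv_out(height), conv_out(width)
    return height, width


print(encoder_feature_hw())  # (9, 10)
```

Under that assumption, the final linear projection sees a C×9×10 feature map. C is the last stage's channel count, which is not documented here.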
Decoder(latent_dim=1536)
Mirror of the Encoder: ConvTranspose2d layers with residual blocks reconstruct the pixel frame from the RSSM latent state. A final sigmoid keeps outputs in [0, 1].
__init__(self, latent_dim: int = 512 + 1024)
latent_dim (int): Input latent dimension = det_dim + stoch_dim = 512 + 1024 = 1536 by default
forward(self, latent: Tensor) → Tensor
latent (Tensor): Shape (B, latent_dim). Concatenated [h_t, s_t].
returns (Tensor): Shape (B, 3, 144, 160). Reconstructed frame, values in [0, 1].
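Running the same arithmetic in reverse recovers the native resolution exactly. This sketch makes two assumptions: that the Decoder first reshapes its projected latent to a 9×10 map, and that it then applies four ConvTranspose2d(kernel_size=4, stride=2, padding=1) stages.

```python
def conv_transpose_out(size, kernel=4, stride=2, padding=1, output_padding=0):
    """Output length of a ConvTranspose2d along one axis (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


def decoder_output_hw(height=9, width=10, stages=4):
    """Each assumed k=4, s=2, p=1 stage exactly doubles both spatial axes."""
    for _ in range(stages):
        height, width = conv_transpose_out(height), conv_transpose_out(width)
    return height, width


print(decoder_output_hw())  # (144, 160)
```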
RSSMCell(action_dim=8, det_dim=512, class_num=32, category_num=32)
Single recurrent step of the RSSM. Maintains a deterministic GRU hidden state h_t and a stochastic discrete categorical state s_t. Supports both prior (imagination) and posterior (training with real frames) modes.
__init__(self, action_dim, det_dim, class_num, category_num)
action_dim (int): Size of the one-hot action vector. Default: 8 (8-button Game Boy)
det_dim (int): GRU hidden dimension. Default: 512
class_num (int): Number of categorical variables. Default: 32
category_num (int): Classes per categorical variable. Default: 32, so stoch_dim = 32 × 32 = 1024
get_initial_state(self, batch_size, device) → tuple[Tensor, Tensor]
Returns zero-initialized h and uniform-initialized s for sequence starts.
returns (tuple): (h, s), where h is a (B, det_dim) zero tensor and s is a (B, stoch_dim) uniform softmax
sample_stochastic(self, logits, use_gumbel=True, hard=True, temp=1.0) → Tensor
Samples discrete categorical latents from logits. Uses Gumbel-Softmax straight-through by default.
logits (Tensor): Shape (B, stoch_dim)
use_gumbel (bool): If True, use Gumbel-Softmax. If False, use argmax straight-through.
hard (bool): If True (default), the forward pass uses the one-hot argmax (discrete) and the backward pass uses soft gradients.
temp (float): Gumbel-Softmax temperature. Lower → sharper samples.
returns (Tensor): Shape (B, stoch_dim). One-hot categorical sample.
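The forward (hard) pass of the Gumbel trick can be illustrated without autograd. This plain-Python sketch assumes the stoch_dim = 1024 logits are laid out as 32 consecutive groups of 32 classes. sample_onehot_groups is a hypothetical helper. It omits the straight-through gradient entirely. It also omits the temperature, since dividing by it does not change which class wins the argmax.

```python
import math
import random


def sample_onehot_groups(logits, class_num=32, category_num=32, rng=random):
    """Hard Gumbel-max sample of `class_num` independent categoricals.

    `logits` is a flat list of length class_num * category_num (stoch_dim),
    assumed to be laid out group by group. Returns a flat list of the same
    length with exactly one 1.0 per group.
    """
    assert len(logits) == class_num * category_num
    sample = [0.0] * len(logits)
    for g in range(class_num):
        start = g * category_num
        perturbed = []
        for x in logits[start:start + category_num]:
            # Gumbel(0, 1) noise: -log(-log(U)), with U clamped away from 0 and 1.
            u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
            perturbed.append(x - math.log(-math.log(u)))
        winner = max(range(category_num), key=perturbed.__getitem__)
        sample[start + winner] = 1.0
    return sample


s = sample_onehot_groups([0.0] * 1024, rng=random.Random(0))
print(len(s), sum(s))  # 1024 32.0
```

In PyTorch, one way to get the straight-through version is to reshape the logits to (B, class_num, category_num), call torch.nn.functional.gumbel_softmax(logits, tau=temp, hard=True, dim=-1), and flatten the result back to (B, stoch_dim). Whether RSSMCell does exactly this is not confirmed here.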
forward(self, prev_h, prev_s, action, embed=None, use_gumbel=True, temp=1.0) → dict
Runs a single RSSM step. If embed is provided, samples from posterior (training); otherwise from prior (imagination).
prev_h (Tensor): Previous deterministic state. Shape: (B, det_dim)
prev_s (Tensor): Previous stochastic state. Shape: (B, stoch_dim)
action (Tensor): One-hot action. Shape: (B, action_dim)
embed (Tensor, optional): Encoder embedding. Shape: (B, 512). None for imagination mode.
returns (dict): {'h': Tensor, 's': Tensor, 'prior_logits': Tensor, 'post_logits': Tensor}
RewardPredictor(latent_dim=1536)
MLP predicting scalar reward from the concatenated RSSM latent state [h_t, s_t]. Part of the world model multi-task loss.
forward(self, latent: Tensor) → Tensor
latent (Tensor): Shape (B, latent_dim)
returns (Tensor): Shape (B,). Predicted scalar reward.
ContinuePredictor(latent_dim=1536)
MLP predicting episode continuation probability (1 - done). Output is sigmoid-normalized to [0, 1]. Used as the not-done discount mask during imagination training.
forward(self, latent: Tensor) → Tensor
latent (Tensor): Shape (B, latent_dim)
returns (Tensor): Shape (B,). Continue probability in [0, 1].
Actor(latent_dim=1536, action_dim=8)
Policy network for the v3 imagination-based RL agent. Outputs action logits from the RSSM latent state. Trained via policy gradient on λ-returns inside imagined rollouts.
forward(self, latent: Tensor) → Tensor
latent (Tensor): Shape (B, latent_dim). Concatenated [h_t, s_t].
returns (Tensor): Shape (B, action_dim). Raw action logits (apply softmax for probabilities).
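Two steps sit between the Actor's output and the next RSSM step: turning the raw logits into an action index, and turning that index back into the one-hot vector RSSMCell expects. Both can be sketched in plain Python. softmax, sample_action and one_hot are hypothetical helpers, and the mapping from index to Game Boy button is not documented here.

```python
import math
import random


def softmax(logits):
    """Convert raw logits to probabilities (max subtracted for numerical stability)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_action(logits, rng=random):
    """Draw one action index from the categorical distribution defined by `logits`."""
    probs = softmax(logits)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


def one_hot(index, action_dim=8):
    """One-hot vector of length action_dim, the form RSSMCell.forward takes for `action`."""
    vec = [0.0] * action_dim
    vec[index] = 1.0
    return vec


logits = [0.1, 2.0, -1.0, 0.0, 0.5, 0.0, 0.0, 0.3]  # illustrative row of Actor output
a = sample_action(logits, rng=random.Random(0))
print(a, one_hot(a))
```

With tensors, torch.distributions.Categorical(logits=logits).sample() and torch.nn.functional.one_hot(index, num_classes=8) cover the same two steps.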
Critic(latent_dim=1536)
Value function network. Estimates expected discounted return from the current latent state. Trained to minimize MSE vs. λ-return targets during imagination rollouts.
forward(self, latent: Tensor) → Tensor
latent (Tensor): Shape (B, latent_dim)
returns (Tensor): Shape (B,). Scalar value estimate V(h_t, s_t).
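Below is a sketch of the λ-return targets the Actor and Critic train on. It assumes the DreamerV2 recursion, in which the ContinuePredictor output scales the discount. The function name and the gamma and lam defaults are illustrative, not the script's confirmed names or values.

```python
def lambda_returns(rewards, continues, values, gamma=0.99, lam=0.95):
    """DreamerV2-style λ-returns for one imagined rollout of length H.

    rewards[t], continues[t]: RewardPredictor / ContinuePredictor outputs for step t (length H).
    values[t]: Critic estimate V(h_t, s_t) for t = 0..H (length H + 1); values[H] bootstraps.
    Recursion: R_t = r_t + gamma * c_t * ((1 - lam) * V_{t+1} + lam * R_{t+1}), with R_H = V_H.
    """
    horizon = len(rewards)
    assert len(continues) == horizon and len(values) == horizon + 1
    returns = [0.0] * horizon
    next_return = values[horizon]
    for t in reversed(range(horizon)):
        discount = gamma * continues[t]  # continue probability acts as a soft not-done mask
        next_return = rewards[t] + discount * ((1 - lam) * values[t + 1] + lam * next_return)
        returns[t] = next_return
    return returns


print(lambda_returns([1.0, 1.0], [1.0, 0.0], [0.0, 0.0, 0.0], gamma=1.0, lam=1.0))  # [2.0, 1.0]
```

Under this scheme, the Critic regresses V(h_t, s_t) toward these targets with MSE, treating them as constants.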

src/dataset.py

PokemonDataset(data_dir, seq_len=15, transform=None, max_files=None)
PyTorch Dataset loading gameplay transitions from NPZ files. Supports both single-frame mode (seq_len=1) and trajectory sequence mode (seq_len>1). Automatically validates sequences to avoid episode-boundary crossings.
__init__(self, data_dir, seq_len=15, transform=None, max_files=None)
Loads all transitions_*.npz files from data_dir into memory. Computes valid starting indices for non-boundary-crossing sequences.
data_dir (str | Path): Directory containing transitions_*.npz files
seq_len (int): Sequence length for RSSM training. Use 1 for VAE frame mode.
max_files (int, optional): Cap on the number of loaded files (for debugging)
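Below is one way to compute the valid starting indices. It assumes episode_starts (the NPZ field written by collect_data.py) flags the first frame of each episode. A window is rejected if a new episode begins anywhere after its first frame. valid_start_indices is a hypothetical helper.

```python
def valid_start_indices(episode_starts, seq_len):
    """Start indices i such that frames i .. i+seq_len-1 lie within one episode.

    episode_starts[k] is truthy when frame k begins a new episode.
    """
    n = len(episode_starts)
    # prefix[k] = number of episode starts among frames 0 .. k-1
    prefix = [0]
    for flag in episode_starts:
        prefix.append(prefix[-1] + (1 if flag else 0))
    valid = []
    for i in range(n - seq_len + 1):
        # Any start strictly inside the window (frames i+1 .. i+seq_len-1) crosses a boundary.
        if prefix[i + seq_len] - prefix[i + 1] == 0:
            valid.append(i)
    return valid


print(valid_start_indices([1, 0, 0, 1, 0, 0, 0], seq_len=3))  # [0, 3, 4]
```

With seq_len=1 every index is valid, which matches single-frame mode.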
__getitem__(self, idx) → dict
Returns a sample dict. In sequence mode, includes time dimension T:
python
# seq_len > 1 (sequence mode):
{
    'obs':           Tensor[(T, 3, H, W)],  # normalized [0, 1]
    'actions':       Tensor[(T,)],          # long int action indices
    'rewards':       Tensor[(T,)],
    # RAM fields (if available in NPZ):
    'map_ids':       Tensor[(T,)],
    'xs':            Tensor[(T,)],
    'ys':            Tensor[(T,)],
    'facings':       Tensor[(T,)],
    'in_battles':    Tensor[(T,)],
    'dialog_opens':  Tensor[(T,)],
    'badges':        Tensor[(T,)],
    'party_hps':     Tensor[(T, 6)],
    'party_max_hps': Tensor[(T, 6)],
}

src/game_state.py

Provides the GameState dataclass and extraction utilities for reading synchronized RAM + screen state from a live PyBoy instance.

GameState (dataclass)
Snapshot of all relevant Pokémon Red game state at a single emulator tick. Guaranteed to be synchronized - RAM and screen are read after the same tick.

Fields

map_id (int): Raw WRAM map ID byte (0–248)
x, y (int): Player tile coordinates
facing (int): Direction: 0=down, 4=up, 8=left, 12=right
in_battle (bool): True when the battle engine is active
dialog_open (bool): True when any text box is on screen
party_hp (list[int]): Current HP for each party slot
party_max_hp (list[int]): Max HP for each party slot
badges (int): Bitmask of the 8 gym badges; badge count = bin(badges).count('1')

Properties

badge_count (int): Number of earned badges (popcount of the bitmask)
map_name (str): Human-readable map name from the MAP_NAMES dict
facing_name (str): "up" / "down" / "left" / "right"
total_hp (int): Sum of current HP across the party
total_max_hp (int): Sum of max HP across the party
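The derived properties are simple reductions over the fields. The minimal stand-in dataclass below shows them. GameStateSketch is hypothetical. map_name is left out because the contents of MAP_NAMES are not listed here, and the "unknown" fallback for unexpected facing bytes is an assumption.

```python
from dataclasses import dataclass, field

# Facing byte → name, per the `facing` field description above.
FACING_NAMES = {0: "down", 4: "up", 8: "left", 12: "right"}


@dataclass
class GameStateSketch:
    """Hypothetical stand-in for src.game_state.GameState (map_name omitted)."""
    map_id: int
    x: int
    y: int
    facing: int
    in_battle: bool
    dialog_open: bool
    party_hp: list[int] = field(default_factory=list)
    party_max_hp: list[int] = field(default_factory=list)
    badges: int = 0

    @property
    def badge_count(self) -> int:
        return bin(self.badges).count("1")  # popcount of the 8-bit badge mask

    @property
    def facing_name(self) -> str:
        return FACING_NAMES.get(self.facing, "unknown")  # fallback is an assumption

    @property
    def total_hp(self) -> int:
        return sum(self.party_hp)

    @property
    def total_max_hp(self) -> int:
        return sum(self.party_max_hp)
```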
extract_game_state(pyboy: PyBoy) → GameState
Reads all relevant RAM addresses from a running PyBoy instance and returns a synchronized GameState snapshot. Must be called after pyboy.tick().
Reads map_id, x, y, facing, in_battle, dialog_open, badges, and party HP (up to 6 slots) from WRAM.
pyboy (PyBoy): Running PyBoy emulator instance
returns (GameState): Complete game state snapshot
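Below is a sketch of the WRAM decoding step, written against any object indexable by address; in PyBoy 2.x, pyboy.memory supports pyboy.memory[addr]. The addresses are the ones listed for src/ram_addresses.py. read_state_fields is hypothetical. HP is read as a big-endian 16-bit value starting at the hi byte. Two further assumptions: the party is truncated to PARTY_SIZE, and any nonzero IN_BATTLE byte counts as a battle. dialog_open is omitted because how TEXT_BOX_ID is interpreted is not documented here.

```python
# Addresses as documented for src/ram_addresses.py.
MAP_ID = 0xD35E
PLAYER_Y = 0xD361
PLAYER_X = 0xD362
PLAYER_FACING = 0xC109
IN_BATTLE = 0xD057
BADGE_FLAGS = 0xD356
PARTY_SIZE = 0xD163


def read_u16_be(memory, address):
    """16-bit big-endian read: hi byte at `address`, lo byte at `address + 1`."""
    return (memory[address] << 8) | memory[address + 1]


def read_state_fields(memory, hp_addresses, max_hp_addresses):
    """Hypothetical helper: decode raw fields from anything indexable by address.

    Usage (assumption): read_state_fields(pyboy.memory, HP_ADDRESSES, MAX_HP_ADDRESSES)
    """
    party_size = min(memory[PARTY_SIZE], 6)  # assumption: only occupied slots are read
    return {
        "map_id": memory[MAP_ID],
        "x": memory[PLAYER_X],
        "y": memory[PLAYER_Y],
        "facing": memory[PLAYER_FACING],
        "in_battle": memory[IN_BATTLE] != 0,  # assumption: any nonzero value means in battle
        "badges": memory[BADGE_FLAGS],
        "party_hp": [read_u16_be(memory, a) for a in hp_addresses[:party_size]],
        "party_max_hp": [read_u16_be(memory, a) for a in max_hp_addresses[:party_size]],
    }
```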
screen_capture(pyboy: PyBoy) → np.ndarray
Captures the current Game Boy screen as a (144, 160, 3) uint8 RGB array. Handles both RGBA and RGB output from PyBoy ≥2.0.
returns (ndarray): Shape (144, 160, 3), dtype uint8. Native Game Boy resolution.

src/ram_addresses.py

Defines all WRAM address constants for Pokémon Red (USA/Europe). Addresses are verified against documented memory maps. All addresses are integers in the 0xC000–0xDFFF range.

python
from src.ram_addresses import (
    PLAYER_X,          # 0xD362 - tile X coordinate
    PLAYER_Y,          # 0xD361 - tile Y coordinate
    MAP_ID,            # 0xD35E - current map ID
    PLAYER_FACING,     # 0xC109 - facing direction byte
    IN_BATTLE,         # 0xD057 - battle state flag
    TEXT_BOX_ID,       # 0xD125 - dialog/text box active flag
    BADGE_FLAGS,       # 0xD356 - 8-bit badge bitmask
    PARTY_SIZE,        # 0xD163 - number of Pokémon in party
    HP_ADDRESSES,      # list[int] - current HP hi-byte per slot
    MAX_HP_ADDRESSES,  # list[int] - max HP hi-byte per slot
    MAP_NAMES,         # dict[int, str] - map ID → name
    FACING_NAMES,      # dict[int, str] - facing byte → direction
)
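The claim that every constant lies in WRAM is easy to check mechanically. Below is a hypothetical test helper using the addresses from the comments above. The list-valued HP_ADDRESSES and MAX_HP_ADDRESSES are omitted because their per-slot values are not given here.

```python
WRAM_START, WRAM_END = 0xC000, 0xDFFF

# Scalar address values copied from the import comments above.
ADDRESSES = {
    "PLAYER_X": 0xD362,
    "PLAYER_Y": 0xD361,
    "MAP_ID": 0xD35E,
    "PLAYER_FACING": 0xC109,
    "IN_BATTLE": 0xD057,
    "TEXT_BOX_ID": 0xD125,
    "BADGE_FLAGS": 0xD356,
    "PARTY_SIZE": 0xD163,
}


def out_of_wram(addresses):
    """Names of any addresses outside the 0xC000-0xDFFF WRAM window."""
    return [name for name, addr in addresses.items() if not WRAM_START <= addr <= WRAM_END]


print(out_of_wram(ADDRESSES))  # []
```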

scripts/collect_data.py

Runs the PyBoy emulator with the PWhiddy frozen PPO checkpoint as the behavior policy, collecting native-resolution transition data into NPZ files.

bash
# --episodes:   number of episodes to collect
# --out-dir:    output directory for NPZ files
# --rom:        path to the Pokémon Red ROM
# --save-state: starting save state
python scripts/collect_data.py \
    --episodes 20 \
    --out-dir data \
    --rom path/to/rom.gb \
    --save-state saves/intro_done.state

Output: one transitions_{i:04d}.npz per episode, each containing obs, actions, rewards, episode_starts, and all RAM fields.

scripts/train_rssm.py

Trains the full RSSM world model - encoder, RSSMCell, decoder, reward predictor, continue predictor - with the multi-task loss including KL balancing.

bash
# --data-dir:     NPZ dataset directory
# --epochs:       training epochs
# --batch-size:   batch size
# --seq-len:      BPTT sequence length
# --det-dim:      GRU hidden dimension
# --class-num:    categorical variable count
# --category-num: classes per variable
# --kl-alpha:     KL balancing weight
# --lr:           learning rate
# --out-dir:      output directory
python scripts/train_rssm.py \
    --data-dir data \
    --epochs 12 \
    --batch-size 64 \
    --seq-len 15 \
    --det-dim 512 \
    --class-num 32 \
    --category-num 32 \
    --kl-alpha 0.8 \
    --lr 3e-4 \
    --out-dir checkpoints/rssm_v2

Saves per-epoch checkpoints, reconstruction grids, and the best model by validation reconstruction loss.
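The loss below assumes --kl-alpha follows the DreamerV2 KL-balancing formulation; the script's exact loss is not reproduced here. q is the posterior from post_logits, p is the prior from prior_logits, and sg is the stop-gradient operator:

```latex
\mathcal{L}_{\mathrm{KL}} = \alpha \,\mathrm{KL}\big[\mathrm{sg}(q) \,\|\, p\big] + (1 - \alpha)\,\mathrm{KL}\big[q \,\|\, \mathrm{sg}(p)\big], \qquad \alpha = 0.8
```

The first term pulls the prior toward the frozen posterior. The second regularizes the posterior toward the frozen prior. With α = 0.8, the prior receives most of the gradient.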

scripts/generate_demo_video_v2.py

Generates the side-by-side demo video: real emulator (left) vs. RSSM imagination (right). The imagination panel runs purely from the RSSM prior - no emulator access.

bash
# --steps: number of steps to record
python scripts/generate_demo_video_v2.py \
    --checkpoint checkpoints/rssm_v2/best_world_model.pt \
    --save-state saves/intro_done.state \
    --steps 200 \
    --out-video output/demo.mp4

scripts/upload_to_hf.py

Utility script for uploading checkpoints and datasets to the Hugging Face Hub.

bash
python scripts/upload_to_hf.py \
    --checkpoint checkpoints/rssm_v2/best_world_model.pt \
    --repo xxxTEMPESTxxx/PokeDreamer \
    --token $HF_TOKEN

scripts/train_policy.py

v3 preparation script - trains the Actor-Critic policy inside RSSM imagination rollouts. Currently under development.

bash
python scripts/train_policy.py \
    --world-model checkpoints/rssm_v2/best_world_model.pt \
    --imagination-horizon 15 \
    --epochs 50 \
    --out-dir checkpoints/policy_v3