src/models.py
Contains all neural network modules for the v2 RSSM world model system. All classes are torch.nn.Module subclasses. Import with:
from src.models import (
ResidualBlock, Encoder, Decoder,
RSSMCell, RewardPredictor, ContinuePredictor,
Actor, Critic
)
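The shapes documented for these modules refer to stoch_dim and latent_dim without defining them. The following is a minimal sketch, assuming that stoch_dim equals the number of categorical variables times the classes per variable (the --class-num and --category-num flags of scripts/train_rssm.py) and that latent_dim equals det_dim + stoch_dim, as the concatenation [h_t, s_t] suggests. The helper names latent_sizes and sample_one_hot are hypothetical. The pure-Python sampler only illustrates what a one-hot categorical sample looks like; the repository itself operates on torch tensors.

```python
import math
import random


def latent_sizes(det_dim, class_num, category_num):
    """Return (stoch_dim, latent_dim) under the assumptions stated above."""
    stoch_dim = class_num * category_num   # assumed: variables x classes per variable
    latent_dim = det_dim + stoch_dim       # assumed: latent is the concatenation [h_t, s_t]
    return stoch_dim, latent_dim


def sample_one_hot(logits, class_num, category_num, rng):
    """Sample one class per categorical variable and return a flat one-hot list."""
    if len(logits) != class_num * category_num:
        raise ValueError("logits must have class_num * category_num entries")
    sample = []
    for v in range(class_num):
        row = logits[v * category_num:(v + 1) * category_num]
        peak = max(row)
        # Unnormalized softmax weights; subtracting the max keeps exp() stable.
        weights = [math.exp(x - peak) for x in row]
        chosen = rng.choices(range(category_num), weights=weights, k=1)[0]
        sample.extend(1.0 if c == chosen else 0.0 for c in range(category_num))
    return sample


# With the training defaults listed for scripts/train_rssm.py:
print(latent_sizes(512, 32, 32))  # (1024, 1536)
```

Each group of category_num entries in the returned list contains exactly one 1.0, which is the property the "one-hot categorical sample" description relies on.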
ResidualBlock
relu(BN(conv1(x))) → BN(conv2) → + x → relu. Preserves spatial dimensions.
Encoder
Input: (B, 3, 144, 160). Values in [0, 1].
Output: (B, embed_dim). Encoder embedding e_t.
Decoder
Input: (B, latent_dim). Concatenated [h_t, s_t].
Output: (B, 3, 144, 160). Reconstructed frame, values in [0, 1].
RSSMCell
Initial state: (h, s) where h is (B, det_dim) zeros, s is (B, stoch_dim) uniform softmax.
Sampling: (B, stoch_dim) → (B, stoch_dim). One-hot categorical sample.
Forward step: if embed is provided, samples from posterior (training); otherwise from prior (imagination).
h: (B, det_dim)
s: (B, stoch_dim)
action: (B, action_dim)
embed: (B, 512). None for imagination mode.
Returns: {'h': Tensor, 's': Tensor, 'prior_logits': Tensor, 'post_logits': Tensor}
RewardPredictor
Input: (B, latent_dim)
Output: (B,). Predicted scalar reward.
ContinuePredictor
Input: (B, latent_dim)
Output: (B,). Continue probability in [0, 1].
Actor
Input: (B, latent_dim). Concatenated [h_t, s_t].
Output: (B, action_dim). Raw action logits (apply softmax for probabilities).
Critic
Input: (B, latent_dim)
Output: (B,). Scalar value estimate V(h_t, s_t).
src/dataset.py
Loads all transitions_*.npz files from data_dir into memory. Computes valid starting indices for non-boundary-crossing sequences.
data_dir: directory containing transitions_*.npz files.
# seq_len > 1 (sequence mode):
{
'obs': Tensor[(T, 3, H, W)], # normalized [0,1]
'actions': Tensor[(T,)], # long int action indices
'rewards': Tensor[(T,)],
# RAM fields (if available in NPZ):
'map_ids': Tensor[(T,)],
'xs': Tensor[(T,)],
'ys': Tensor[(T,)],
'facings': Tensor[(T,)],
'in_battles': Tensor[(T,)],
'dialog_opens': Tensor[(T,)],
'badges': Tensor[(T,)],
'party_hps': Tensor[(T, 6)],
'party_max_hps': Tensor[(T, 6)],
}
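The computation of valid starting indices for sequences that do not cross an episode boundary can be sketched as follows. This is an illustration, not the repository's implementation: it assumes an episode_starts array (the field listed in the collect_data.py output) whose entries are truthy on the first step of each episode, and the function name valid_start_indices is hypothetical.

```python
def valid_start_indices(episode_starts, seq_len):
    """Return every start index i whose window i..i+seq_len-1 stays in one episode."""
    n = len(episode_starts)
    valid = []
    for i in range(n - seq_len + 1):
        # The window crosses a boundary if any step after its first one
        # is flagged as the beginning of a new episode.
        if not any(episode_starts[i + 1:i + seq_len]):
            valid.append(i)
    return valid


# Two episodes of lengths 4 and 3, concatenated into one buffer.
starts = [1, 0, 0, 0, 1, 0, 0]
print(valid_start_indices(starts, seq_len=3))  # [0, 1, 4]
```

Windows starting at indices 2 and 3 are rejected because they would include step 4, the first step of the second episode.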
src/game_state.py
Provides the GameState dataclass and extraction utilities for reading synchronized RAM + screen state from a live PyBoy instance.
Fields
badges: bin(badges).count('1') = badge count
Properties
Screen array of shape (144, 160, 3), dtype uint8. Native Game Boy resolution.
src/ram_addresses.py
Defines all WRAM address constants for Pokémon Red (USA/Europe). Addresses are verified against documented memory maps. All addresses are integers in the 0xC000–0xDFFF range.
from src.ram_addresses import (
PLAYER_X, # 0xD362 - tile X coordinate
PLAYER_Y, # 0xD361 - tile Y coordinate
MAP_ID, # 0xD35E - current map ID
PLAYER_FACING, # 0xC109 - facing direction byte
IN_BATTLE, # 0xD057 - battle state flag
TEXT_BOX_ID, # 0xD125 - dialog/text box active flag
BADGE_FLAGS, # 0xD356 - 8-bit badge bitmask
PARTY_SIZE, # 0xD163 - number of Pokémon in party
HP_ADDRESSES, # list[int] - current HP hi-byte per slot
MAX_HP_ADDRESSES, # list[int] - max HP hi-byte per slot
MAP_NAMES, # dict[int, str] - map ID → name
FACING_NAMES, # dict[int, str] - facing byte → direction
)
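Two of these constants need a small decoding step: BADGE_FLAGS is a bitmask, and the HP lists point at the high byte of a two-byte value. A minimal sketch follows, assuming the low byte sits at the next address (big-endian layout) and using a plain dict as a stand-in for emulator memory. The helper names and the example address are illustrative and should be replaced by values from src.ram_addresses and by the memory access of the PyBoy version in use.

```python
def badge_count(badge_flags):
    """Count the set bits in the 8-bit badge bitmask."""
    return bin(badge_flags & 0xFF).count("1")


def read_u16_be(memory, hi_addr):
    """Combine a high byte at hi_addr with the low byte assumed at hi_addr + 1."""
    return (memory[hi_addr] << 8) | memory[hi_addr + 1]


# Fake memory snapshot; in practice these bytes come from the emulator.
HP_HI = 0xD16C  # illustrative address; take real ones from HP_ADDRESSES
memory = {HP_HI: 0x01, HP_HI + 1: 0x2C}

print(badge_count(0b00000101))     # 2
print(read_u16_be(memory, HP_HI))  # 300
```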
scripts/collect_data.py
Runs the PyBoy emulator with the PWhiddy frozen PPO checkpoint as behavior policy, collecting native-resolution transition data into NPZ files.
# --episodes: number of episodes to collect
# --out-dir: output directory for NPZ files
# --rom: path to Pokémon Red ROM
# --save-state: starting save state
python scripts/collect_data.py \
--episodes 20 \
--out-dir data \
--rom path/to/rom.gb \
--save-state saves/intro_done.state
Output: one transitions_{i:04d}.npz per episode, each containing obs, actions, rewards, episode_starts, and all RAM fields.
scripts/train_rssm.py
Trains the full RSSM world model - encoder, RSSMCell, decoder, reward predictor, continue predictor - with the multi-task loss including KL balancing.
# --data-dir: NPZ dataset directory
# --epochs: training epochs
# --batch-size: batch size
# --seq-len: BPTT sequence length
# --det-dim: GRU hidden dimension
# --class-num: categorical variable count
# --category-num: classes per variable
# --kl-alpha: KL balancing weight
# --lr: learning rate
# --out-dir: output directory
python scripts/train_rssm.py \
--data-dir data \
--epochs 12 \
--batch-size 64 \
--seq-len 15 \
--det-dim 512 \
--class-num 32 \
--category-num 32 \
--kl-alpha 0.8 \
--lr 3e-4 \
--out-dir checkpoints/rssm_v2
Saves per-epoch checkpoints, reconstruction grids, and the best model by validation reconstruction loss.
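The --kl-alpha flag is described only as a KL balancing weight. Assuming the script follows the DreamerV2-style formulation (an assumption; this document does not confirm it), the KL term of the loss would take the form below, where q is the posterior produced when embed is provided, p is the prior, sg denotes stop-gradient, and α is the value passed as --kl-alpha:

```latex
\mathcal{L}_{\mathrm{KL}}
  = \alpha \, \mathrm{KL}\!\left[\operatorname{sg}\big(q(s_t \mid h_t, e_t)\big) \,\big\|\, p(s_t \mid h_t)\right]
  + (1 - \alpha) \, \mathrm{KL}\!\left[q(s_t \mid h_t, e_t) \,\big\|\, \operatorname{sg}\big(p(s_t \mid h_t)\big)\right]
```

Under this reading, α = 0.8 puts most of the gradient on training the prior toward the posterior, and the remaining 0.2 regularizes the posterior toward the prior.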
scripts/generate_demo_video_v2.py
Generates the side-by-side demo video: real emulator (left) vs. RSSM imagination (right). The imagination panel runs purely from the RSSM prior - no emulator access.
# --steps: number of steps to record
python scripts/generate_demo_video_v2.py \
--checkpoint checkpoints/rssm_v2/best_world_model.pt \
--save-state saves/intro_done.state \
--steps 200 \
--out-video output/demo.mp4
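Running "purely from the RSSM prior" means stepping the recurrent cell with no encoder embedding, so every new state is sampled from the prior. The loop below is a rough sketch of that idea. It assumes the cell is callable as cell(h, s, action, embed) and returns a dict with 'h' and 's' keys, matching the RSSMCell description; the real call signature may differ. StubCell is a hypothetical stand-in used only to make the sketch runnable and is not the repository's RSSMCell.

```python
def imagine_rollout(cell, h, s, actions):
    """Advance the latent state with embed=None so only the prior is used."""
    trajectory = []
    for action in actions:
        out = cell(h, s, action, None)  # assumed argument order: h, s, action, embed
        h, s = out["h"], out["s"]
        trajectory.append((h, s))
    return trajectory


class StubCell:
    """Hypothetical stand-in that mimics the documented return keys."""

    def __call__(self, h, s, action, embed=None):
        new_h = [x + action for x in h]  # toy deterministic update
        return {"h": new_h, "s": s, "prior_logits": s, "post_logits": None}


states = imagine_rollout(StubCell(), h=[0.0], s=[1.0], actions=[1, 2, 3])
print(len(states))    # 3
print(states[-1][0])  # [6.0]
```

In the actual demo each latent pair would additionally be passed through the Decoder to produce the frame shown in the right-hand panel.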
scripts/upload_to_hf.py
Utility script for uploading checkpoints and datasets to Hugging Face Hub.
python scripts/upload_to_hf.py \
--checkpoint checkpoints/rssm_v2/best_world_model.pt \
--repo xxxTEMPESTxxx/PokeDreamer \
--token $HF_TOKEN
scripts/train_policy.py
v3 preparation script - trains the Actor-Critic policy inside RSSM imagination rollouts. Currently under development.
python scripts/train_policy.py \
--world-model checkpoints/rssm_v2/best_world_model.pt \
--imagination-horizon 15 \
--epochs 50 \
--out-dir checkpoints/policy_v3