Deep Q-Network (DQN) CartPole With PyTorch

Introduction

This document provides a detailed explanation of a Deep Q-Network (DQN) implementation for solving the CartPole-v1 environment from Gymnasium (the maintained successor to OpenAI Gym). The implementation uses PyTorch for the neural network components.

CartPole Problem: The agent must balance a pole on a moving cart by applying forces to the cart. The state space consists of cart position, cart velocity, pole angle, and pole angular velocity. The action space is discrete (left or right force).

1. Environment Setup and Imports

The first section imports necessary libraries and sets up the environment.

import gymnasium as gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from IPython import display

# Create the CartPole environment with human rendering
env = gym.make("CartPole-v1", render_mode="human")

# Use GPU if available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Key Components:

  • gymnasium: the maintained successor to OpenAI Gym, providing reinforcement learning environments
  • torch: PyTorch for neural network implementation
  • device: Automatically selects GPU if available for faster computation
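
As a quick sanity check, the spaces of the environment created above can be inspected directly (a small sketch, assuming the env and device objects from the setup code):

# Sketch: inspect the CartPole-v1 spaces defined by the setup above
print(env.observation_space.shape)  # (4,) -> cart position, cart velocity, pole angle, pole angular velocity
print(env.action_space.n)           # 2    -> push cart left or push cart right
print(device)                       # cuda or cpu, depending on availability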

2. Experience Replay Memory

Experience replay is a crucial component of DQN that helps break correlations between consecutive samples.

Transition = namedtuple("Transition", ("state", "action", "next_state", "reward"))  

class ReplayMemory(object):
    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Saves a transition."""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        """Samples a batch of transitions."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

Components Explained:

Component     Description
Transition    A named tuple holding one (state, action, next_state, reward) experience
ReplayMemory  A bounded buffer that stores experiences for later sampling
push()        Adds a new experience to memory
sample()      Randomly samples a batch of experiences

Why Experience Replay? It helps to break correlations between consecutive samples, provides more efficient data usage, and helps prevent catastrophic forgetting in neural networks.
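
A minimal usage sketch of the ReplayMemory class above (the tensors here are placeholders rather than real environment data):

# Sketch: pushing and sampling transitions (placeholder tensors, not real rollouts)
demo_memory = ReplayMemory(capacity=1000)
for _ in range(5):
    s  = torch.zeros(1, 4)                      # state
    a  = torch.tensor([[0]], dtype=torch.long)  # action
    s2 = torch.zeros(1, 4)                      # next state
    r  = torch.tensor([1.0])                    # reward
    demo_memory.push(s, a, s2, r)

batch = demo_memory.sample(3)                   # list of 3 Transition named tuples
print(len(demo_memory), batch[0].reward)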

3. Deep Q-Network Architecture

The neural network that approximates the Q-function.

class DQN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)  # input_dim = 4, output_dim = 2 
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, output_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

Network Architecture:

Layer  Type    Activation  Description
fc1    Linear  ReLU        Input layer (4 units) → hidden layer (128 units)
fc2    Linear  ReLU        Hidden layer (128 units) → hidden layer (128 units)
fc3    Linear  None        Hidden layer (128 units) → output layer (2 units)

Network Design Choices:

  • Input size matches CartPole's state space (4 dimensions)
  • Output size matches action space (2 actions: left/right)
  • ReLU activations provide non-linearity for better function approximation
  • No activation on final layer to get raw Q-values
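
A short shape check for the network above (a sketch, assuming the DQN class as defined): a batch of CartPole states maps to one Q-value per action.

# Sketch: forward-pass shape check for the DQN defined above
net = DQN(input_dim=4, output_dim=2)
dummy_states = torch.randn(8, 4)  # batch of 8 CartPole states
q_values = net(dummy_states)
print(q_values.shape)             # torch.Size([8, 2]) -> one Q-value per action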

4. Hyperparameters and Helper Functions

Configuration and supporting functions for the DQN algorithm.

# Hyperparameters
batch_size = 128
gamma = 0.999
eps_start = 0.9
eps_end = 0.05
eps_decay = 1000
tau = 0.005
learning_rate = 1e-4

# Environment setup
n_actions = env.action_space.n
state, info = env.reset()
n_observations = len(state)

# Networks
policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())

optimizer = optim.Adam(policy_net.parameters(), lr=learning_rate)
memory = ReplayMemory(10000)

steps_done = 0

Hyperparameters Explained:

Parameter      Value  Purpose
batch_size     128    Number of experiences sampled from memory per update
gamma          0.999  Discount factor for future rewards
eps_start      0.9    Initial exploration rate
eps_end        0.05   Minimum exploration rate
eps_decay      1000   Rate of exponential exploration decay (in steps)
tau            0.005  Soft-update rate for the target network
learning_rate  1e-4   Optimizer learning rate

Action Selection (ε-greedy policy)

def select_action(state):
    global steps_done
    sample = random.random()
    eps_threshold = eps_end + (eps_start - eps_end) * \
        math.exp(-1. * steps_done / eps_decay)  

    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            return policy_net(state).max(1).indices.view(1, 1)  # greedy action
    else:
        return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)  # random action

This function implements ε-greedy action selection:

  • Starts with high exploration (ε = 0.9)
  • Exponentially decays exploration rate
  • Eventually settles at minimum exploration (ε = 0.05)
  • With probability ε: random action (exploration)
  • With probability 1-ε: best action according to policy_net (exploitation)
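
To make the schedule concrete, the threshold computed in select_action can be evaluated at a few step counts (a small sketch using the hyperparameters defined above):

# Sketch: how the exploration rate decays as steps_done grows
for steps in (0, 500, 1000, 2000, 5000):
    eps = eps_end + (eps_start - eps_end) * math.exp(-steps / eps_decay)
    print(f"steps={steps:5d}  epsilon={eps:.3f}")
# epsilon starts at 0.9 and approaches 0.05 as steps_done grows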

Training Visualization

episode_durations = []
def plot_durations(show_result=False):
    plt.figure(1)
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    if show_result:
        plt.title('Result')
    else:
        plt.clf()  # clear the figure between updates during training
        plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())

    # Plot the 100-episode moving average once enough episodes are recorded
    if len(durations_t) >= 100:
        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy())
    plt.pause(0.001)  # pause a bit so that plots are updated

    display.display(plt.gcf())
    display.clear_output(wait=True)

This function plots episode durations during training, showing:

  • Raw episode durations
  • Moving average (window=100) for smoother trend visualization
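
The unfold call is the non-obvious part: it slices the tensor into overlapping windows along dimension 0, which are then averaged. A toy sketch with a window of 3 instead of 100:

# Sketch: sliding-window average via unfold (window of 3 for illustration)
t = torch.arange(6, dtype=torch.float)  # tensor([0., 1., 2., 3., 4., 5.])
windows = t.unfold(0, 3, 1)             # shape (4, 3): [0,1,2], [1,2,3], [2,3,4], [3,4,5]
print(windows.mean(1))                  # tensor([1., 2., 3., 4.])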

5. Optimization and Training Loop

The core of the DQN algorithm that performs the training.

def optimize_model():
    if len(memory) < batch_size:
        return
    transitions = memory.sample(batch_size)
    batch = Transition(*zip(*transitions))

    # Compute non-final next states
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None]).to(device)
    
    # Prepare batches
    state_batch = torch.cat(batch.state).to(device)
    action_batch = torch.cat(batch.action).to(device)
    reward_batch = torch.cat(batch.reward).to(device)

    # Compute Q(s_t, a)
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for all next states
    next_state_values = torch.zeros(batch_size, device=device)
    with torch.no_grad():
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1).values
    
    # Compute expected Q values
    expected_state_action_values = (next_state_values * gamma) + reward_batch
    
    # Compute Huber loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))
    
    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(policy_net.parameters(), 1)
    optimizer.step()

Training Steps:

  1. Sample batch from replay memory
  2. Separate final and non-final states
  3. Compute current Q-values for taken actions
  4. Compute target Q-values using target network
  5. Calculate loss between current and target Q-values
  6. Backpropagate and update policy network
  7. Clip gradients to prevent explosion

Key Concepts:

  • Target Network: Provides stable Q-value targets by using a separate network that updates slowly
  • Huber Loss: Combines benefits of MSE and MAE, less sensitive to outliers
  • Gradient Clipping: Prevents exploding gradients in deep networks
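
For reference, SmoothL1Loss with its default beta of 1.0 is the Huber loss: quadratic for small errors and linear for large ones. A sketch of the element-wise definition:

# Sketch: element-wise Smooth L1 (Huber) loss, matching nn.SmoothL1Loss with beta = 1.0
def smooth_l1(error, beta=1.0):
    abs_err = error.abs()
    return torch.where(abs_err < beta,
                       0.5 * abs_err ** 2 / beta,  # quadratic near zero (like MSE)
                       abs_err - 0.5 * beta)       # linear for large errors (like MAE)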

Main Training Loop

num_episodes = 250
for i_episode in range(num_episodes):
    state, info = env.reset()
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    for t in count():       
        action = select_action(state)
        observation, reward, terminated, truncated, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)
        done = terminated or truncated
        
        if terminated:
            next_state = None
        else:
            next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)
        
        memory.push(state, action, next_state, reward)
        state = next_state
        
        optimize_model()
        
        # Soft update of the target network
        target_net_state_dict = target_net.state_dict()
        policy_net_state_dict = policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key]*tau + target_net_state_dict[key]*(1-tau)
        target_net.load_state_dict(target_net_state_dict)
        
        if done:
            episode_durations.append(t + 1)
            plot_durations()
            break

Training Process:

  1. Initialize environment
  2. Select action using ε-greedy policy
  3. Execute action and observe reward/next state
  4. Store experience in replay memory
  5. Optimize model periodically
  6. Soft update target network
  7. Repeat until episode termination
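
After training, the learned policy can be run greedily with no exploration to check how long it balances the pole. A minimal evaluation sketch using the trained policy_net:

# Sketch: one greedy rollout with the trained policy_net (no exploration)
state, info = env.reset()
state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
total_steps = 0
while True:
    with torch.no_grad():
        action = policy_net(state).argmax(dim=1).item()  # pick the highest-Q action
    observation, reward, terminated, truncated, _ = env.step(action)
    total_steps += 1
    if terminated or truncated:
        break
    state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)
print(f"Greedy episode length: {total_steps}")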

6. Network Exploration and Analysis

Understanding how the DQN learns to solve the CartPole problem.

Learning Dynamics:

  • Initial Phase: Random exploration dominates (high ε)
  • Middle Phase: Network begins to learn patterns, ε decreases
  • Final Phase: Network exploits learned policy, minimal exploration

What the Network Learns:

The DQN learns to map states to Q-values that represent:

  • Expected long-term reward for moving left in current state
  • Expected long-term reward for moving right in current state

Policy Extraction: The optimal policy at any state is simply the action with the highest Q-value: π(s) = argmax_a Q(s, a)

Performance Metrics:

  • Episode duration should increase over time
  • Moving average should show clear upward trend
  • Successful training typically reaches maximum duration (500 steps) consistently

7. Potential Improvements

This basic DQN can be enhanced with several advanced techniques:

Algorithmic Improvements:

  • Double DQN: Reduces overestimation of Q-values (sketched after this list)
  • Dueling DQN: Separates value and advantage streams
  • Prioritized Experience Replay: Samples important transitions more frequently
  • Noisy Nets: Adds parametric noise for exploration
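
As an illustration of Double DQN, only the target computation in optimize_model changes: the policy network selects the next action and the target network evaluates it. A hedged sketch of the replacement lines:

# Sketch: Double DQN target, replacing the next_state_values block in optimize_model
next_state_values = torch.zeros(batch_size, device=device)
with torch.no_grad():
    # policy_net chooses the next action, target_net evaluates its value
    next_actions = policy_net(non_final_next_states).argmax(dim=1, keepdim=True)
    next_state_values[non_final_mask] = target_net(non_final_next_states).gather(1, next_actions).squeeze(1)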

Architecture Improvements:

  • Batch normalization for more stable training
  • Different network architectures (wider/deeper)
  • Different activation functions

Hyperparameter Tuning:

  • Learning rate scheduling (see the sketch after this list)
  • Adaptive exploration rate
  • Different discount factors
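
For learning rate scheduling, PyTorch's built-in schedulers can be attached to the existing optimizer; a small sketch using StepLR (the step_size and gamma values are illustrative, not tuned):

# Sketch: halve the learning rate every 50 scheduler steps (values are illustrative)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

# called once per episode, e.g. right after the episode ends in the training loop:
# scheduler.step()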

Conclusion

This implementation demonstrates how Deep Q-Networks can solve the CartPole problem by:

  • Using a neural network to approximate the Q-function
  • Employing experience replay for stable learning
  • Using a target network to provide stable targets
  • Balancing exploration and exploitation with ε-greedy policy

The solution showcases fundamental concepts in deep reinforcement learning that scale to more complex problems.
