Deep Q-Network (DQN) CartPole With PyTorch
Deep Q-Network (DQN) Implementation for CartPole-v1
Introduction
This document provides a detailed explanation of a Deep Q-Network (DQN) implementation for solving the CartPole-v1 environment from Gymnasium (the maintained successor to OpenAI Gym). The implementation uses PyTorch for the neural network components.
CartPole Problem: The agent must balance a pole on a moving cart by applying forces to the cart. The state space consists of cart position, cart velocity, pole angle, and pole angular velocity. The action space is discrete (left or right force).
1. Environment Setup and Imports
The first section imports necessary libraries and sets up the environment.
```python
import gymnasium as gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from IPython import display

# Create the CartPole environment with human rendering
env = gym.make("CartPole-v1", render_mode="human")

# Use GPU if available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```
Key Components:
- gymnasium: The maintained successor to OpenAI Gym, providing reinforcement learning environments
- torch: PyTorch for neural network implementation
- device: Automatically selects GPU if available for faster computation
2. Experience Replay Memory
Experience replay is a crucial component of DQN that helps break correlations between consecutive samples.
```python
Transition = namedtuple("Transition", ("state", "action", "next_state", "reward"))


class ReplayMemory(object):
    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Saves a transition."""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        """Samples a batch of transitions."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
```
Components Explained:
Component | Description |
---|---|
Transition | A named tuple of (state, action, next_state, reward) |
ReplayMemory | Stores experiences for later sampling |
push() | Adds a new experience to memory |
sample() | Randomly samples a batch of experiences |
Why Experience Replay? It helps to break correlations between consecutive samples, provides more efficient data usage, and helps prevent catastrophic forgetting in neural networks.
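To make the interface concrete, here is a minimal usage sketch of the ReplayMemory defined above (the variable names and values are purely illustrative): transitions are pushed as tensors and later sampled as a random batch, mirroring how the training loop below uses the buffer.

```python
# Minimal usage sketch of ReplayMemory (illustrative values only)
demo_memory = ReplayMemory(capacity=1000)

# Push one fake transition: 1x4 state, 1x1 action, 1x4 next_state, 1-element reward
s = torch.zeros(1, 4)
a = torch.tensor([[0]], dtype=torch.long)
s_next = torch.zeros(1, 4)
r = torch.tensor([1.0])
demo_memory.push(s, a, s_next, r)

# Sample a random batch and regroup it, as optimize_model() does later
if len(demo_memory) >= 1:
    transitions = demo_memory.sample(1)     # list of Transition namedtuples
    batch = Transition(*zip(*transitions))  # Transition of tuples
    print(batch.state, batch.action, batch.reward)
```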
3. Deep Q-Network Architecture
The neural network that approximates the Q-function.
```python
class DQN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)  # input_dim = 4, output_dim = 2
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, output_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
```
Network Architecture:
Layer | Type | Activation | Description |
---|---|---|---|
fc1 | Linear | ReLU | Input layer (4 units) → Hidden layer (128 units) |
fc2 | Linear | ReLU | Hidden layer (128 units) → Hidden layer (128 units) |
fc3 | Linear | None | Hidden layer (128 units) → Output layer (2 units) |
Network Design Choices:
- Input size matches CartPole's state space (4 dimensions)
- Output size matches action space (2 actions: left/right)
- ReLU activations provide non-linearity for better function approximation
- No activation on final layer to get raw Q-values
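As a quick sanity check of these choices, a dummy forward pass (a sketch with random inputs, not real environment states) confirms that a batch of shape (N, 4) maps to Q-values of shape (N, 2):

```python
# Shape sanity check for the DQN above (dummy inputs, not real states)
net = DQN(input_dim=4, output_dim=2)
dummy_states = torch.randn(8, 4)       # batch of 8 fake CartPole states
q_values = net(dummy_states)           # shape: (8, 2), one Q-value per action
print(q_values.shape)                  # torch.Size([8, 2])
best_actions = q_values.argmax(dim=1)  # greedy action for each state in the batch
```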
4. Hyperparameters and Helper Functions
Configuration and supporting functions for the DQN algorithm.
```python
# Hyperparameters
batch_size = 128
gamma = 0.999
eps_start = 0.9
eps_end = 0.05
eps_decay = 1000
tau = 0.005
learning_rate = 1e-4

# Environment setup
n_actions = env.action_space.n
state, info = env.reset()
n_observations = len(state)

# Networks
policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())

optimizer = optim.Adam(policy_net.parameters(), lr=learning_rate)
memory = ReplayMemory(10000)

steps_done = 0
```
Hyperparameters Explained:
Parameter | Value | Purpose |
---|---|---|
batch_size | 128 | Number of experiences sampled from memory |
gamma | 0.999 | Discount factor for future rewards |
eps_start | 0.9 | Initial exploration rate |
eps_end | 0.05 | Minimum exploration rate |
eps_decay | 1000 | Rate of exploration decay |
tau | 0.005 | Target network update rate |
learning_rate | 1e-4 | Optimizer learning rate |
Action Selection (ε-greedy policy)
```python
def select_action(state):
    global steps_done
    sample = random.random()
    eps_threshold = eps_end + (eps_start - eps_end) * \
        math.exp(-1. * steps_done / eps_decay)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            return policy_net(state).max(1).indices.view(1, 1)  # greedy action
    else:
        return torch.tensor([[env.action_space.sample()]],
                            device=device, dtype=torch.long)  # random action
```
This function implements ε-greedy action selection:
- Starts with high exploration (ε = 0.9)
- Exponentially decays exploration rate
- Eventually settles at minimum exploration (ε = 0.05)
- With probability ε: random action (exploration)
- With probability 1-ε: best action according to policy_net (exploitation)
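For intuition, the schedule can be evaluated at a few step counts (a small sketch using the hyperparameters above; the printed values are approximate):

```python
# Approximate exploration threshold at a few step counts
for steps in (0, 1000, 3000, 10000):
    eps = eps_end + (eps_start - eps_end) * math.exp(-1. * steps / eps_decay)
    print(steps, round(eps, 3))
# Roughly: 0 -> 0.90, 1000 -> 0.36, 3000 -> 0.09, 10000 -> 0.05
```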
Training Visualization
```python
episode_durations = []


def plot_durations(show_result=False):
    plt.figure(1)
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    if show_result:
        plt.title('Result')
    else:
        plt.clf()  # clear the figure so curves are not stacked on every call
        plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())
    # Plot the 100-episode moving average once enough episodes are available
    if len(durations_t) >= 100:
        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy())
    plt.pause(0.001)  # pause a bit so that plots are updated
    display.display(plt.gcf())
    display.clear_output(wait=True)
```
This function plots episode durations during training, showing:
- Raw episode durations
- Moving average (window=100) for smoother trend visualization
5. Optimization and Training Loop
The core of the DQN algorithm that performs the training.
```python
def optimize_model():
    if len(memory) < batch_size:
        return
    transitions = memory.sample(batch_size)
    batch = Transition(*zip(*transitions))

    # Compute non-final next states
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)),
                                  device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None]).to(device)

    # Prepare batches
    state_batch = torch.cat(batch.state).to(device)
    action_batch = torch.cat(batch.action).to(device)
    reward_batch = torch.cat(batch.reward).to(device)

    # Compute Q(s_t, a)
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for all next states
    next_state_values = torch.zeros(batch_size, device=device)
    with torch.no_grad():
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1).values

    # Compute expected Q values
    expected_state_action_values = (next_state_values * gamma) + reward_batch

    # Compute Huber loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(policy_net.parameters(), 1)
    optimizer.step()
```
Training Steps:
- Sample batch from replay memory
- Separate final and non-final states
- Compute current Q-values for taken actions
- Compute target Q-values using target network
- Calculate loss between current and target Q-values
- Backpropagate and update policy network
- Clip gradients to prevent explosion
Key Concepts:
- Target Network: Provides stable Q-value targets by using a separate network that updates slowly
- Huber Loss: Combines benefits of MSE and MAE, less sensitive to outliers
- Gradient Clipping: Prevents exploding gradients in deep networks
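In equation form (a LaTeX restatement of what optimize_model() computes), the target for a non-terminal transition and the Huber loss on the temporal-difference error δ are:

```latex
y_t = r_t + \gamma \max_{a'} Q_{\text{target}}(s_{t+1}, a'), \qquad
\delta = Q_{\text{policy}}(s_t, a_t) - y_t
```

```latex
\mathcal{L}(\delta) =
\begin{cases}
\tfrac{1}{2}\,\delta^2 & \text{if } |\delta| \le 1 \\
|\delta| - \tfrac{1}{2} & \text{otherwise}
\end{cases}
```

For terminal transitions the bootstrap term is dropped, so y_t = r_t, which is exactly what the non_final_mask implements.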
Main Training Loop
```python
num_episodes = 250

for i_episode in range(num_episodes):
    state, info = env.reset()
    state = torch.tensor(state, device=device).unsqueeze(0)
    for t in count():
        action = select_action(state)
        observation, reward, terminated, truncated, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)
        done = terminated or truncated

        if terminated:
            next_state = None
        else:
            next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

        memory.push(state, action, next_state, reward)
        state = next_state

        optimize_model()

        # Soft update of the target network
        target_net_state_dict = target_net.state_dict()
        policy_net_state_dict = policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key] * tau + target_net_state_dict[key] * (1 - tau)
        target_net.load_state_dict(target_net_state_dict)

        if done:
            episode_durations.append(t + 1)
            plot_durations()
            break
```
Training Process:
- Initialize environment
- Select action using ε-greedy policy
- Execute action and observe reward/next state
- Store experience in replay memory
- Optimize model periodically
- Soft update target network
- Repeat until episode termination
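The soft update in the loop above is the Polyak averaging rule, with τ = 0.005 taken from the hyperparameters:

```latex
\theta_{\text{target}} \leftarrow \tau\,\theta_{\text{policy}} + (1 - \tau)\,\theta_{\text{target}}
```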
6. Network Exploration and Analysis
Understanding how the DQN learns to solve the CartPole problem.
Learning Dynamics:
- Initial Phase: Random exploration dominates (high ε)
- Middle Phase: Network begins to learn patterns, ε decreases
- Final Phase: Network exploits learned policy, minimal exploration
What the Network Learns:
The DQN learns to map states to Q-values that represent:
- Expected long-term reward for moving left in current state
- Expected long-term reward for moving right in current state
Policy Extraction: The optimal policy at any state is simply the action with the highest Q-value: π(s) = argmax_a Q(s, a)
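In code, extracting and running this greedy policy might look like the following sketch (eval_env is a hypothetical evaluation environment created here for illustration; the trained policy_net is assumed):

```python
# Sketch: run the trained policy greedily for one evaluation episode
eval_env = gym.make("CartPole-v1")  # hypothetical separate evaluation environment
obs, info = eval_env.reset()
done, total_steps = False, 0
while not done:
    state_t = torch.tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
    with torch.no_grad():
        action = policy_net(state_t).argmax(dim=1).item()  # pi(s) = argmax_a Q(s, a)
    obs, reward, terminated, truncated, _ = eval_env.step(action)
    done = terminated or truncated
    total_steps += 1
print("Episode length:", total_steps)
eval_env.close()
```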
Performance Metrics:
- Episode duration should increase over time
- Moving average should show clear upward trend
- Successful training typically reaches maximum duration (500 steps) consistently
7. Potential Improvements
This basic DQN can be enhanced with several advanced techniques:
Algorithmic Improvements:
- Double DQN: Reduces overestimation of Q-values (see the sketch after this list)
- Dueling DQN: Separates value and advantage streams
- Prioritized Experience Replay: Samples important transitions more frequently
- Noisy Nets: Adds parametric noise for exploration
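As one example, a Double DQN variant of the target computation would replace the max over target_net inside optimize_model() with action selection by the policy network and evaluation by the target network (a sketch, not part of the original implementation; variable names match those in optimize_model()):

```python
# Sketch: Double DQN target (would replace the next_state_values block in optimize_model())
next_state_values = torch.zeros(batch_size, device=device)
with torch.no_grad():
    # Select the next action with the policy network ...
    next_actions = policy_net(non_final_next_states).argmax(dim=1, keepdim=True)
    # ... but evaluate it with the target network
    next_state_values[non_final_mask] = (
        target_net(non_final_next_states).gather(1, next_actions).squeeze(1)
    )
expected_state_action_values = (next_state_values * gamma) + reward_batch
```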
Architecture Improvements:
- Batch normalization for more stable training
- Different network architectures (wider/deeper)
- Different activation functions
Hyperparameter Tuning:
- Learning rate scheduling
- Adaptive exploration rate
- Different discount factors
Conclusion
This implementation demonstrates how Deep Q-Networks can solve the CartPole problem by:
- Using a neural network to approximate the Q-function
- Employing experience replay for stable learning
- Using a target network to provide stable targets
- Balancing exploration and exploitation with ε-greedy policy
The solution showcases fundamental concepts in deep reinforcement learning that scale to more complex problems.