Deep Q-Networks (DQN) Basics (Intro to Deep RL)
Deep Q-Networks (DQN) combine reinforcement learning with deep neural networks to solve complex decision-making tasks. By using neural networks to approximate the Q-value function, DQN handles high-dimensional state spaces that are intractable for classical tabular methods.
Reinforcement Learning and Q-Learning Foundations
Reinforcement learning formalizes sequential decision-making as a Markov Decision Process; Q-learning finds an optimal policy by estimating state-action values with the Bellman optimality equation.
MDP and the Bellman Equation
Reinforcement learning is framed using Markov Decision Processes (MDPs), defined by a tuple \\( (S, A, P, R, \\gamma) \\), where \\( S \\) is the state space, \\( A \\) is the action space, \\( P \\) is the state transition probability, \\( R \\) is the reward function, and \\( \\gamma \\in [0, 1] \\) is the discount factor. The agent's goal is to find a policy \\( \\pi \\) that maximizes the expected return \\( \\mathbb{E}_\\pi[G_t] \\), where the return is the cumulative discounted reward \\( G_t = \\sum_{k=0}^{\\infty} \\gamma^k R_{t+k+1} \\). The action-value function \\( Q^\\pi(s, a) = \\mathbb{E}_\\pi[G_t \\mid S_t = s, A_t = a] \\) is the expected return of taking action \\( a \\) in state \\( s \\) and following policy \\( \\pi \\) thereafter.
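As a quick illustration of the return \\( G_t \\), the sketch below sums a finite list of rewards, assuming an episodic setting in which every reward after termination is zero. The function name `discounted_return` is hypothetical; it walks the rewards backwards so that each step applies the recursion \\( G_t = R_{t+1} + \\gamma G_{t+1} \\).

```python
def discounted_return(rewards, gamma):
    """Return G_t = sum_k gamma**k * rewards[k] for a finite episode.

    rewards[k] plays the role of R_{t+k+1}; rewards after the episode ends count as zero.
    """
    g = 0.0
    # Walk backwards, applying G_t = R_{t+1} + gamma * G_{t+1} at each step.
    for r in reversed(rewards):
        g = r + gamma * g
    return g


print(discounted_return([1.0, 1.0, 1.0], gamma=0.9))  # 1 + 0.9 + 0.81, i.e. about 2.71
```

The backward recursion costs one multiply-add per reward and avoids computing powers of \\( \\gamma \\) explicitly.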
The optimal Q-value function \\( Q^*(s, a) = \\max_\\pi Q^\\pi(s, a) \\) satisfies the Bellman optimality equation:
\\( Q^*(s, a) = \\mathbb{E}_{s' \\sim P(\\cdot \\mid s, a)} \\left[ R(s, a) + \\gamma \\max_{a'} Q^*(s', a') \\right] \\)
where \\( s' \\sim P(\\cdot \\mid s, a) \\) is the next state. Tabular Q-learning turns this equation into an incremental update on a lookup table that stores one Q-value per state-action pair. However, when the state space is continuous or high-dimensional (such as raw screen pixels in Atari games), such a table becomes intractable, which necessitates value function approximation.
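A minimal sketch of how tabular Q-learning turns the Bellman optimality equation into a sampled update, assuming a learning rate \\( \\alpha \\) (not part of the text above) and the standard rule \\( Q(s, a) \\leftarrow Q(s, a) + \\alpha \\left[ r + \\gamma \\max_{a'} Q(s', a') - Q(s, a) \\right] \\). The helper `q_learning_update` and its dictionary-based table are illustrative, not taken from any library.

```python
from collections import defaultdict


def q_learning_update(q_table, s, a, r, s_next, done, num_actions, alpha=0.1, gamma=0.99):
    """One tabular Q-learning step (hypothetical helper for illustration).

    q_table maps (state, action) -> value; unseen entries default to 0.0.
    Returns the TD error r + gamma * max_a' Q(s', a') - Q(s, a).
    """
    if done:
        # No bootstrapping from a terminal state: its future return is zero.
        target = r
    else:
        # Sampled Bellman optimality target: r + gamma * max_a' Q(s', a').
        target = r + gamma * max(q_table[(s_next, b)] for b in range(num_actions))
    td_error = target - q_table[(s, a)]
    # Move the current estimate a fraction alpha of the way toward the target.
    q_table[(s, a)] += alpha * td_error
    return td_error


q = defaultdict(float)
q_learning_update(q, s=0, a=1, r=1.0, s_next=1, done=False, num_actions=2, alpha=0.5, gamma=0.9)
print(q[(0, 1)])  # 0.5: all entries start at 0, so target = 1.0 and 0 + 0.5 * 1.0 = 0.5
```

The table needs one entry per visited state-action pair, which is exactly what stops scaling once states are continuous or pixel-based.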
Value Function Approximation
Value function approximation replaces the Q-table with a parameterized function, typically a deep neural network: \\( Q(s, a; \\theta) \\approx Q^*(s, a) \\), where \\( \\theta \\) represents the network weights. The network takes the state representation as input and outputs one Q-value per possible action, so a single forward pass evaluates every action in that state.
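The sketch below, using a hypothetical toy network with 4-dimensional states and 2 actions, shows what one output per action buys in practice: a single forward pass over a batch of states yields a `[batch_size, num_actions]` tensor, and the greedy action \\( \\arg\\max_a Q(s, a; \\theta) \\) is its row-wise argmax.

```python
import torch
import torch.nn as nn

# Illustrative sizes: 4-dimensional states, 2 discrete actions.
state_dim, num_actions = 4, 2
q_net = nn.Sequential(nn.Linear(state_dim, 64), nn.ReLU(), nn.Linear(64, num_actions))

states = torch.randn(3, state_dim)  # a batch of 3 random states
with torch.no_grad():
    q_values = q_net(states)  # shape [3, num_actions]: one Q(s, a; theta) per action
greedy_actions = q_values.argmax(dim=1)  # shape [3]: argmax_a Q(s, a; theta) for each state
print(q_values.shape, greedy_actions.shape)  # torch.Size([3, 2]) torch.Size([3])
```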
This design allows the model to generalize across similar states, enabling the agent to make decisions in states it has never seen during training. However, combining deep neural networks with reinforcement learning can be unstable because the data sequence is highly correlated and the target values are non-stationary.
DQN Core Innovations
DQN stabilizes deep reinforcement learning by incorporating experience replay buffers and separate target networks.
Experience Replay
In online Q-learning, updating the network weights on transitions \\( (s_t, a_t, r_t, s_{t+1}) \\) in the order they are collected tends to make optimization unstable. Consecutive transitions are strongly correlated in time, which violates the independent and identically distributed (i.i.d.) assumption behind stochastic gradient descent. Furthermore, the agent's current policy determines the incoming data, creating feedback loops that can cause the network to diverge.
DQN (Mnih et al., 2015) mitigates this with an experience replay buffer. During training, the agent stores each transition \\( (s, a, r, s', d) \\) (where \\( d \\) is a boolean flag indicating whether \\( s' \\) is a terminal state) in a large buffer \\( \\mathcal{D} \\). Instead of training on the latest transition only, it samples a mini-batch of transitions uniformly at random from \\( \\mathcal{D} \\) for each weight update. This random sampling breaks the temporal correlation within each mini-batch and stabilizes optimization; as a side benefit, each transition can be reused in many updates.
Target Network
The second innovation in DQN is the use of a separate target network. In naive Q-learning with a single network, the weights \\( \\theta \\) are used to compute both the current Q-value \\( Q(s, a; \\theta) \\) and the target Q-value \\( Y_t = r + \\gamma \\max_{a'} Q(s', a'; \\theta) \\). Every weight update therefore shifts the target as well as the prediction, so the network is chasing a moving target. This feedback loop can cause the values to oscillate or diverge.
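To make the moving-target problem concrete, the sketch below uses a single toy linear Q-network. It is zero-initialized only so the numbers are deterministic, and the sizes, learning rate, and the helper `bootstrap_target` are illustrative assumptions. One gradient step toward \\( Y_t \\) changes the very network that computes \\( Y_t \\), so the target itself shifts even though \\( r \\) and \\( s' \\) are unchanged.

```python
import torch
import torch.nn as nn

gamma = 0.99
# Toy linear Q-network: 4-dim state, 2 actions (illustrative sizes).
# Zero-initialized only so that the numbers below are deterministic.
q_net = nn.Linear(4, 2)
nn.init.zeros_(q_net.weight)
nn.init.zeros_(q_net.bias)
optimizer = torch.optim.SGD(q_net.parameters(), lr=0.1)

s = torch.tensor([[1.0, 0.0, 0.0, 0.0]])
s_next = torch.tensor([[1.0, 0.0, 0.0, 0.0]])  # same vector as s, to keep the arithmetic simple
a, r = 0, 1.0


def bootstrap_target():
    # Y = r + gamma * max_a' Q(s', a'; theta), computed with the SAME weights being trained.
    with torch.no_grad():
        return r + gamma * q_net(s_next).max(dim=1).values.item()


target_before = bootstrap_target()                # 1.0, since all Q-values start at 0
loss = (q_net(s)[0, a] - target_before) ** 2      # squared TD error for the taken action
optimizer.zero_grad()
loss.backward()
optimizer.step()
target_after = bootstrap_target()                 # about 1.396: the target has moved
print(target_before, target_after)
```

After one step, \\( Q(s', a) \\) rises from 0 to about 0.4, so the target moves from 1.0 to roughly \\( 1 + 0.99 \\times 0.4 = 1.396 \\). With frozen target weights, it would stay at 1.0 until the next synchronization.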
DQN addresses this by introducing a target network with weights \\( \\theta^- \\), which are used only to compute the target value: \\( Y_t = r + \\gamma \\max_{a'} Q(s', a'; \\theta^-) \\) (or simply \\( Y_t = r \\) when \\( s' \\) is terminal). The target weights are held fixed for \\( C \\) steps and then copied from the online network: \\( \\theta^- \\leftarrow \\theta \\). Alternatively, they can be updated slowly at every step using Polyak averaging: \\( \\theta^- \\leftarrow \\tau \\theta + (1 - \\tau) \\theta^- \\) (where \\( \\tau \\ll 1 \\)). Decoupling the target from the weights being trained keeps the regression targets fixed, or only slowly changing, between updates. In practice this substantially stabilizes training, although it does not guarantee convergence.
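Both synchronization schemes take only a few lines in PyTorch. The helper names `hard_update` and `soft_update`, and the values chosen for \\( C \\) and \\( \\tau \\), are illustrative assumptions rather than settings from the text. The soft update implements \\( \\theta^- \\leftarrow \\tau \\theta + (1 - \\tau) \\theta^- \\) in place.

```python
import copy

import torch
import torch.nn as nn


def hard_update(target_net: nn.Module, online_net: nn.Module) -> None:
    """theta^- <- theta: copy every online weight into the target network."""
    target_net.load_state_dict(online_net.state_dict())


@torch.no_grad()
def soft_update(target_net: nn.Module, online_net: nn.Module, tau: float) -> None:
    """Polyak averaging: theta^- <- tau * theta + (1 - tau) * theta^-."""
    for target_param, online_param in zip(target_net.parameters(), online_net.parameters()):
        target_param.mul_(1.0 - tau).add_(online_param, alpha=tau)


# Usage sketch with a toy network (4-dim state, 2 actions; sizes are illustrative).
online_net = nn.Linear(4, 2)
target_net = copy.deepcopy(online_net)  # start from theta^- = theta
target_net.requires_grad_(False)        # the target network is never trained directly

C, tau = 1000, 0.005  # hypothetical values for C and tau
for step in range(1, 3 * C + 1):
    # ... one optimizer step on online_net would go here ...
    if step % C == 0:
        hard_update(target_net, online_net)
    # Alternative scheme: call soft_update(target_net, online_net, tau) on every step instead.
```

Freezing the target copy with `requires_grad_(False)` keeps it out of any gradient computation; a given run would normally use only one of the two schemes.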
PyTorch DQN Implementation
A PyTorch implementation defines the DQN network structure, transition storage buffer, and optimization step.
DQN Model and Replay Buffer
The following PyTorch code implements the DQN neural network architecture and the experience replay buffer:
<pre><code class="language-python">import random
from collections import deque

import numpy as np
import torch
import torch.nn as nn


class QNetwork(nn.Module):
    def __init__(self, state_dim, num_actions):
        super().__init__()
        # Two hidden layers of 64 units; the output layer has one unit per action.
        self.network = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, num_actions),
        )

    def forward(self, state):
        # state shape: [batch_size, state_dim]
        return self.network(state)  # [batch_size, num_actions]


class ReplayBuffer:
    def __init__(self, capacity):
        # Once full, the deque discards the oldest transition on each append.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        # Sample a mini-batch of transitions uniformly at random (without replacement)
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = zip(*batch)
        # Stack into NumPy arrays first: building a tensor from a tuple of arrays is slow
        return (torch.as_tensor(np.array(state), dtype=torch.float32),
                torch.as_tensor(np.array(action), dtype=torch.int64),
                torch.as_tensor(np.array(reward), dtype=torch.float32),
                torch.as_tensor(np.array(next_state), dtype=torch.float32),
                torch.as_tensor(np.array(done), dtype=torch.float32))

    def __len__(self):
        return len(self.buffer)
</code></pre>
Training Loop and Optimization
This PyTorch code demonstrates the core Q-value update step, calculating the loss against the target network and performing backpropagation:
<pre><code class="language-python"># Assumes torch, nn, QNetwork and ReplayBuffer from the previous listing.
def optimize_dqn_step(policy_net, target_net, replay_buffer, optimizer, batch_size, gamma):
    if len(replay_buffer) < batch_size:
        return None  # not enough transitions collected yet

    # Sample transition batch
    states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)

    # Get current Q-values: Q(s, a; theta)
    # gather selects, for each row, the Q-value of the action actually taken
    q_values = policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

    # Compute target Q-values using the target network: target = r + gamma * max_a' Q(s', a'; theta^-)
    # (1 - dones) zeroes the bootstrap term for terminal transitions
    with torch.no_grad():
        max_next_q_values = target_net(next_states).max(dim=1).values
        targets = rewards + gamma * max_next_q_values * (1 - dones)

    # Compute Huber loss (less sensitive to large TD errors than MSE)
    loss_fn = nn.SmoothL1Loss()
    loss = loss_fn(q_values, targets)

    # Optimize
    optimizer.zero_grad()
    loss.backward()
    # Clip the gradient norm to prevent exploding updates
    nn.utils.clip_grad_norm_(policy_net.parameters(), max_norm=1.0)
    optimizer.step()
    return loss.item()


# Instantiate nets
policy_net = QNetwork(state_dim=4, num_actions=2)
target_net = QNetwork(state_dim=4, num_actions=2)
target_net.load_state_dict(policy_net.state_dict())  # Align weights initially
opt = torch.optim.Adam(policy_net.parameters(), lr=1e-3)
buffer = ReplayBuffer(capacity=1000)

# Simulate pushing a transition
buffer.push([0.1, -0.2, 0.3, 0.1], 1, 1.0, [0.1, -0.1, 0.4, 0.0], 0)
# (In real loops, we accumulate steps before running optimize_dqn_step)
</code></pre>
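Two lines of `optimize_dqn_step` do most of the indexing work: `gather` picks \\( Q(s, a; \\theta) \\) for the action actually taken in each transition, and `(1 - dones)` removes the bootstrap term for terminal transitions. The toy tensors below are made-up numbers chosen only to show the shapes and semantics.

```python
import torch

# Q-values for a batch of 3 states and 2 actions (made-up numbers for illustration).
q_values = torch.tensor([[1.0, 2.0],
                         [3.0, 4.0],
                         [5.0, 6.0]])
actions = torch.tensor([1, 0, 1])  # action taken in each state (int64)

# gather(1, index) picks q_values[i, actions[i]] for every row i.
# index must have as many dimensions as q_values, hence unsqueeze(1) -> shape [3, 1].
taken = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
print(taken)  # tensor([2., 3., 6.])

# Equivalent advanced-indexing form.
same = q_values[torch.arange(len(actions)), actions]

# Terminal masking: the second transition ends the episode, so its target is just r.
rewards = torch.tensor([1.0, 1.0, 1.0])
dones = torch.tensor([0.0, 1.0, 0.0])
max_next_q = torch.tensor([10.0, 10.0, 10.0])
targets = rewards + 0.5 * max_next_q * (1 - dones)  # gamma = 0.5 here, for easy arithmetic
print(targets)  # tensor([6., 1., 6.])
```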