import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from collections import deque
import random
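
# Q-network: a small fully connected MLP that maps a state vector to one
# Q-value per discrete action.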
class DQN(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim=128):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
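
# Experience replay: transitions are stored in a fixed-size buffer and sampled
# uniformly at random, which decorrelates consecutive updates and lets each
# transition be reused for several gradient steps.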
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        state, action, reward, next_state, done = zip(*random.sample(self.buffer, batch_size))
        # Stack into numpy arrays first; building tensors directly from a tuple
        # of per-step arrays is slow and triggers a PyTorch warning.
        return (torch.as_tensor(np.array(state), dtype=torch.float32),
                torch.as_tensor(np.array(action), dtype=torch.int64),
                torch.as_tensor(np.array(reward), dtype=torch.float32),
                torch.as_tensor(np.array(next_state), dtype=torch.float32),
                torch.as_tensor(np.array(done), dtype=torch.float32))

    def __len__(self):
        return len(self.buffer)
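
# DQN agent: a policy network is trained online while a separate target
# network, refreshed every `target_update` steps, provides the bootstrap
# targets; actions are chosen epsilon-greedily with epsilon annealed from
# epsilon_start toward epsilon_end.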
class DQNAgent:
    def __init__(self, state_dim, action_dim, hidden_dim=128, lr=1e-3, gamma=0.99,
                 epsilon_start=1.0, epsilon_end=0.01, epsilon_decay=0.995,
                 buffer_size=10000, batch_size=64, target_update=10):
        self.action_dim = action_dim
        self.gamma = gamma
        self.epsilon = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size
        self.target_update = target_update
        # Networks
        self.policy_net = DQN(state_dim, action_dim, hidden_dim)
        self.target_net = DQN(state_dim, action_dim, hidden_dim)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
        self.memory = ReplayBuffer(buffer_size)
        self.steps = 0
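
    # Epsilon-greedy exploration: with probability epsilon take a uniformly
    # random action, otherwise act greedily with respect to the current
    # Q-value estimates.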
    def select_action(self, state):
        if random.random() > self.epsilon:
            with torch.no_grad():
                state = torch.FloatTensor(state).unsqueeze(0)
                q_values = self.policy_net(state)
                return q_values.max(1)[1].item()
        return random.randrange(self.action_dim)
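
    # One learning step on a sampled minibatch: fit the policy network's
    # Q-values to the TD target reward + gamma * max_a' Q_target(next_state, a'),
    # zeroing the bootstrap term for terminal transitions.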
    def update(self):
        if len(self.memory) < self.batch_size:
            return None
        # Sample a minibatch of transitions
        state, action, reward, next_state, done = self.memory.sample(self.batch_size)
        # Q-values predicted for the actions that were actually taken
        current_q = self.policy_net(state).gather(1, action.unsqueeze(1)).squeeze(1)
        # Bootstrap targets from the target network; (1 - done) masks out
        # the bootstrap term for terminal transitions
        with torch.no_grad():
            next_q = self.target_net(next_state).max(1)[0]
            target_q = reward + (1 - done) * self.gamma * next_q
        # Huber loss and one optimizer step on the policy network
        loss = F.smooth_l1_loss(current_q, target_q)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        # Periodically copy the policy weights into the target network
        self.steps += 1
        if self.steps % self.target_update == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())
        # Decay epsilon
        self.epsilon = max(self.epsilon_end, self.epsilon * self.epsilon_decay)
        return loss.item()

# Example usage: a minimal training loop for an environment with a flat
# observation vector and a discrete action space (e.g. CartPole), using the
# classic Gym reset()/step() API.
def train_dqn(env, episodes=1000):
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    agent = DQNAgent(state_dim, action_dim)
    for episode in range(episodes):
        state = env.reset()
        total_reward = 0
        done = False
        while not done:
            # Select and perform an action
            action = agent.select_action(state)
            next_state, reward, done, _ = env.step(action)
            # Store the transition and take one learning step
            agent.memory.push(state, action, reward, next_state, done)
            agent.update()
            total_reward += reward
            state = next_state
        if episode % 10 == 0:
            print(f"Episode {episode}, Total Reward: {total_reward}, Epsilon: {agent.epsilon:.2f}")

# To use:
# import gym
# env = gym.make('CartPole-v1')
# train_dqn(env)
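
# Note: with gymnasium (or gym >= 0.26) the API differs: env.reset() returns
# (obs, info) and env.step() returns (obs, reward, terminated, truncated, info),
# so the loop above would unpack those and use done = terminated or truncated.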