import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical

class PolicyNetwork(nn.Module):
    """Maps a state vector to a probability distribution over discrete actions."""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
            nn.Softmax(dim=-1),  # output is a valid probability distribution
        )

    def forward(self, x):
        return self.network(x)
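
# Quick sanity check for the policy net (illustrative; the 4-dim input and
# 2 actions assume a CartPole-like task):
#   net = PolicyNetwork(input_size=4, hidden_size=128, output_size=2)
#   probs = net(torch.zeros(4))  # shape (2,); entries are positive and sum to 1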

class REINFORCE:
    """Monte Carlo policy gradient (REINFORCE) with a single policy network."""

    def __init__(self, input_size, hidden_size, output_size, learning_rate=0.01):
        self.policy = PolicyNetwork(input_size, hidden_size, output_size)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate)
        self.episode_rewards = []
        self.episode_log_probs = []

    def select_action(self, state):
        # Sample an action from the current policy and remember its log-probability,
        # which the update step needs to form the policy-gradient loss.
        state = torch.as_tensor(state, dtype=torch.float32)
        probs = self.policy(state)
        m = Categorical(probs)
        action = m.sample()
        self.episode_log_probs.append(m.log_prob(action))
        return action.item()
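
    # Illustrative call (hypothetical sizes; a 4-dim state and 2 actions as in
    # CartPole):
    #   agent = REINFORCE(input_size=4, hidden_size=128, output_size=2)
    #   a = agent.select_action(np.zeros(4))  # returns 0 or 1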

    def update(self, gamma=0.99):
        # Compute discounted returns G_t by iterating backwards over the rewards.
        returns = []
        R = 0
        for r in reversed(self.episode_rewards):
            R = r + gamma * R
            returns.insert(0, R)
        returns = torch.tensor(returns, dtype=torch.float32)
        # Normalize returns to reduce gradient variance (1e-8 avoids division by zero).
        returns = (returns - returns.mean()) / (returns.std() + 1e-8)
        # REINFORCE loss: -log pi(a_t | s_t) * G_t, summed over the episode.
        policy_loss = []
        for log_prob, R in zip(self.episode_log_probs, returns):
            policy_loss.append(-log_prob * R)
        # Each log_prob is a 0-dim tensor, so stack (not cat) before summing.
        policy_loss = torch.stack(policy_loss).sum()
        # Backpropagate and take one optimizer step.
        self.optimizer.zero_grad()
        policy_loss.backward()
        self.optimizer.step()
        # Clear episode buffers for the next rollout.
        self.episode_rewards = []
        self.episode_log_probs = []
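
# The loss built in update() is the standard REINFORCE estimator: for one episode,
# grad J(theta) ~= sum_t grad log pi_theta(a_t | s_t) * G_t, where
# G_t = sum_{k>=t} gamma^(k-t) * r_k is the discounted return from step t.
# Normalizing the G_t (a common variance-reduction trick, not part of the
# original algorithm) keeps gradient magnitudes stable across episodes.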


# Example usage:
def train(env, agent, num_episodes=1000):
    for episode in range(num_episodes):
        # Gymnasium-style API: reset() returns (observation, info) and step()
        # returns (obs, reward, terminated, truncated, info).
        state, _ = env.reset()
        done = False
        total_reward = 0.0
        while not done:
            action = agent.select_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            agent.episode_rewards.append(reward)
            total_reward += reward
            state = next_state
            done = terminated or truncated
        agent.update()
        if episode % 100 == 0:
            print(f"Episode {episode}, reward: {total_reward:.1f}")


# To run (assumes the `gymnasium` package, the maintained successor to gym,
# is installed):
if __name__ == "__main__":
    import gymnasium as gym

    env = gym.make("CartPole-v1")
    agent = REINFORCE(
        input_size=env.observation_space.shape[0],
        hidden_size=128,
        output_size=env.action_space.n,
    )
    train(env, agent)