Here is my code:
import tensorflow as tf
from tensorflow.keras.layers import Dense
import numpy as np
import gymnasium as gym
from collections import deque
import tqdm
# ----------------- MEMORY -----------------
class Memory:
    def __init__(self):
        self.states = []
        self.actions = []
        self.rewards = []
        self.dones = []
        self.log_probs = []

    def clear_memory(self):
        self.states = []
        self.actions = []
        self.rewards = []
        self.dones = []
        self.log_probs = []
# ------------------ MODEL --------------------
class ActorCritic(tf.keras.Model):
    def __init__(self, action_dim):
        super().__init__()
        self.common1 = Dense(128, activation='relu')
        self.common2 = Dense(64, activation='relu')
        self.actor = Dense(action_dim)
        self.critic = Dense(1)

    def call(self, inputs):
        x = self.common2(self.common1(inputs))
        return self.actor(x), self.critic(x)
# ---------------- AGENT --------------------
class PPO:
    def __init__(self, action_dim, lr, gamma, clip, update_every, epochs):
        self.policy = ActorCritic(action_dim)
        self.old_policy = ActorCritic(action_dim)
        self.old_policy.set_weights(self.policy.get_weights())
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
        self.gamma = gamma
        self.clip = clip
        self.update_every = update_every
        self.update_step = 0
        self.epochs = epochs
        self.criterion = tf.keras.losses.MeanSquaredError()
        self.memory = Memory()

    def select_action(self, state):
        # sample an action from the old policy's logits and store its log-prob
        logit, _ = self.old_policy(tf.expand_dims(state, 0))
        action = tf.random.categorical(logit, 1)[0, 0]
        prob = tf.nn.softmax(logit)
        log_prob = tf.math.log(prob[0, action])
        self.memory.states.append(state)
        self.memory.actions.append(action)
        self.memory.log_probs.append(log_prob)
        return int(action)

    def take_step(self, reward, done):
        self.memory.rewards.append(reward)
        self.memory.dones.append(done)
        self.update_step += 1
        if self.update_step % self.update_every == 0:
            self.train()
            self.memory.clear_memory()
            self.old_policy.set_weights(self.policy.get_weights())

    def get_expected_returns(self):
        # discounted returns, reset at episode boundaries, then normalized
        returns = []
        discounted_reward = 0
        for (reward, done) in zip(reversed(self.memory.rewards), reversed(self.memory.dones)):
            if done:
                discounted_reward = 0
            discounted_reward = reward + self.gamma * discounted_reward
            returns.insert(0, discounted_reward)
        returns = tf.convert_to_tensor(returns, tf.float32)
        returns = (returns - tf.math.reduce_mean(returns)) / (tf.math.reduce_std(returns) + 1e-5)
        return returns

    def compute_loss(self, returns, values, action_probs, old_log_probs):
        log_probs = tf.math.log(action_probs)
        advantages = returns - tf.stop_gradient(values)
        ratio = tf.exp(log_probs - tf.stop_gradient(old_log_probs))
        surr1 = ratio * advantages
        # tf.where form of the PPO clip (instead of clipping the ratio directly)
        surr2 = tf.where(advantages > 0, (1 + self.clip) * advantages, (1 - self.clip) * advantages)
        actor_loss = -tf.math.reduce_mean(tf.minimum(surr1, surr2))
        critic_loss = tf.reduce_mean((returns - values) ** 2)
        return actor_loss + critic_loss

    @tf.function
    def train(self):
        states = tf.convert_to_tensor(np.stack(self.memory.states))
        actions = tf.convert_to_tensor(self.memory.actions)
        old_log_probs = tf.convert_to_tensor(self.memory.log_probs)
        old_log_probs = tf.expand_dims(old_log_probs, 1)
        # print(states.shape, actions.shape, old_log_probs.shape)
        for _ in tf.range(self.epochs):
            with tf.GradientTape() as tape:
                logits, values = self.policy(states)
                probs = tf.nn.softmax(logits)
                action_probs = tf.gather(probs, actions, batch_dims=1)
                returns = self.get_expected_returns()
                returns, action_probs = [tf.expand_dims(x, 1) for x in [returns, action_probs]]
                loss = self.compute_loss(tf.stop_gradient(returns), values, action_probs, old_log_probs)
            grads = tape.gradient(loss, self.policy.trainable_variables)
            self.optimizer.apply_gradients(zip(grads, self.policy.trainable_variables))
# ------------------ TRAIN ------------------------
def train_loop(env_name, max_steps=1000, episodes=10000):
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    print(f"Env : {env_name} | State : {state_dim} | Action : {action_dim}")
    lr = 2e-3
    gamma = 0.99
    clip = 0.2
    update_every = 2000
    epochs = 4
    agent = PPO(action_dim, lr, gamma, clip, update_every, epochs)
    iter = tqdm.trange(episodes)
    reward_buffer = deque(maxlen=100)
    for _ in iter:
        state, _ = env.reset()
        total_reward = 0
        for _ in range(max_steps):
            action = agent.select_action(state)
            state, reward, done, _, _ = env.step(action)
            total_reward += reward
            agent.take_step(reward, done)
            if done:
                break
        reward_buffer.append(total_reward)
        iter.set_postfix(episode_reward=total_reward, running_reward=np.mean(reward_buffer))


train_loop("CartPole-v1")
I tried it with and without tf.stop_gradient, but no luck: the agent doesn't learn, while the same code in PyTorch learns very fast.
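For reference, my TF loss writes the clip with tf.where; below is a minimal sketch of the same clipped surrogate written with tf.clip_by_value instead, mirroring the torch.clamp line in the PyTorch code (the names here are placeholders, not from my script). As far as I can tell the two forms are mathematically equivalent, so I don't think the clip term itself is the problem.

import tensorflow as tf

# Sketch: PPO clipped surrogate with tf.clip_by_value (placeholder names).
def clipped_surrogate(log_probs, old_log_probs, advantages, clip=0.2):
    # probability ratio pi_new(a|s) / pi_old(a|s)
    ratio = tf.exp(log_probs - tf.stop_gradient(old_log_probs))
    surr1 = ratio * advantages
    surr2 = tf.clip_by_value(ratio, 1.0 - clip, 1.0 + clip) * advantages
    # negative sign: minimizing this loss maximizes the surrogate objective
    return -tf.reduce_mean(tf.minimum(surr1, surr2))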
Here is the PyTorch implementation:
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
import gymnasium as gym
import numpy as np
import tqdm
from collections import deque
# --------------- MEMORY ----------------------
class Memory:
    def __init__(self):
        self.states = []
        self.actions = []
        self.rewards = []
        self.dones = []
        self.log_probs = []

    def clear_memory(self):
        self.states = []
        self.actions = []
        self.rewards = []
        self.dones = []
        self.log_probs = []
# ---------------- MODEL -----------------
class ActorCritic(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.actor = nn.Sequential(nn.Linear(input_dim, 128),
                                   nn.ReLU(),
                                   nn.Linear(128, 64),
                                   nn.ReLU(),
                                   nn.Linear(64, output_dim),
                                   nn.Softmax(dim=1))
        self.critic = nn.Sequential(nn.Linear(input_dim, 128),
                                    nn.ReLU(),
                                    nn.Linear(128, 64),
                                    nn.ReLU(),
                                    nn.Linear(64, 1))

    def evaluate(self, states, actions):
        values = torch.squeeze(self.critic(states))
        dist = Categorical(self.actor(states))
        log_probs = dist.log_prob(actions)
        return values, log_probs
# ---------------- AGENT ----------------
class PPO:
    def __init__(self, state_dim, action_dim, lr, gamma, clip, update_every, epochs):
        self.policy = ActorCritic(state_dim, action_dim)
        self.old_policy = ActorCritic(state_dim, action_dim)
        self.old_policy.load_state_dict(self.policy.state_dict())
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)
        self.gamma = gamma
        self.clip = clip
        self.update_every = update_every
        self.update_step = 0
        self.epochs = epochs
        self.criterion = nn.MSELoss()
        self.memory = Memory()

    def select_action(self, state):
        state = torch.from_numpy(state).view(1, -1).type(torch.float32)
        dist = Categorical(self.old_policy.actor(state))
        action = dist.sample()
        log_prob = dist.log_prob(action)
        self.memory.states.append(state)
        self.memory.actions.append(action)
        self.memory.log_probs.append(log_prob)
        return action.item()

    def take_step(self, reward, done):
        self.memory.rewards.append(reward)
        self.memory.dones.append(done)
        self.update_step += 1
        if self.update_step % self.update_every == 0:
            self.train()
            self.memory.clear_memory()
            self.old_policy.load_state_dict(self.policy.state_dict())

    def train(self):
        states = torch.squeeze(torch.stack(self.memory.states)).type(torch.float32)
        actions = torch.squeeze(torch.stack(self.memory.actions)).long()
        old_log_probs = torch.squeeze(torch.stack(self.memory.log_probs)).detach().type(torch.float32)
        returns = []
        discounted_reward = 0
        for (reward, done) in zip(reversed(self.memory.rewards), reversed(self.memory.dones)):
            if done:
                discounted_reward = 0
            discounted_reward = reward + self.gamma * discounted_reward
            returns.insert(0, discounted_reward)
        returns = torch.tensor(returns, dtype=torch.float32)
        returns = (returns - returns.mean()) / (returns.std() + 1e-5)
        for _ in range(self.epochs):
            values, log_probs = self.policy.evaluate(states, actions)
            advantages = returns - values.detach()
            ratio = torch.exp(log_probs - old_log_probs.detach())
            surr1 = ratio * advantages
            surr2 = torch.clamp(ratio, 1 - self.clip, 1 + self.clip) * advantages
            actor_loss = -torch.min(surr1, surr2).mean()
            critic_loss = self.criterion(returns, values)
            loss = actor_loss + critic_loss
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
# ------------------ TRAIN -------------------
def train_loop(env_name, max_steps=1000, episodes=10000):
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    print(f"Env : {env_name} | State : {state_dim} | Action : {action_dim}")
    lr = 0.01
    gamma = 0.99
    clip = 0.2
    update_every = 2000
    epochs = 4
    agent = PPO(state_dim, action_dim, lr, gamma, clip, update_every, epochs)
    iter = tqdm.trange(episodes)
    reward_buffer = deque(maxlen=100)
    for _ in iter:
        state, _ = env.reset()
        total_reward = 0
        for _ in range(max_steps):
            action = agent.select_action(state)
            state, reward, done, _, _ = env.step(action)
            total_reward += reward
            agent.take_step(reward, done)
            if done:
                break
        reward_buffer.append(total_reward)
        iter.set_postfix(episode_reward=total_reward, running_reward=np.mean(reward_buffer))


train_loop("CartPole-v1")