PPO Problem with TensorFlow

Here is my TensorFlow code:

import tensorflow as tf
from tensorflow.keras.layers import Dense
import numpy as np
import gymnasium as gym
from collections import deque
import tqdm

# ----------------- MEMORY -----------------
class Memory:
    def __init__(self):
        self.states = []
        self.actions = []
        self.rewards = []
        self.dones = []
        self.log_probs = []

    def clear_memory(self):
        self.states = []
        self.actions = []
        self.rewards = []
        self.dones = []
        self.log_probs = []

# ------------------ MODEL --------------------
class ActorCritic(tf.keras.Model):
    def __init__(self, action_dim):
        super().__init__()

        self.common1 = Dense(128, activation='relu')
        self.common2 = Dense(64, activation='relu')
        self.actor = Dense(action_dim)
        self.critic = Dense(1)

    def call(self, inputs):
        x = self.common2(self.common1(inputs))
        return self.actor(x), self.critic(x)


# ---------------- AGENT --------------------
class PPO:
    def __init__(self, action_dim, lr, gamma, clip, update_every, epochs):
        self.policy = ActorCritic(action_dim)
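        # old_policy is a frozen copy of the policy: it collects rollouts and provides the
        # "old" log-probs for the PPO ratio, and is synced with policy after every update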
        self.old_policy = ActorCritic(action_dim)
        self.old_policy.set_weights(self.policy.get_weights())

        self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
        self.gamma = gamma
        self.clip = clip
        self.update_every = update_every
        self.update_step = 0
        self.epochs = epochs
        self.criterion = tf.keras.losses.MeanSquaredError()
        self.memory = Memory()

    def select_action(self, state):
        logit, _ = self.old_policy(tf.expand_dims(state, 0))
        action = tf.random.categorical(logit, 1)[0, 0]
        prob = tf.nn.softmax(logit)
        log_prob = tf.math.log(prob[0, action])

        self.memory.states.append(state)
        self.memory.actions.append(action)
        self.memory.log_probs.append(log_prob)
        return int(action)

    def take_step(self, reward, done):
        self.memory.rewards.append(reward)
        self.memory.dones.append(done)
        self.update_step += 1
        if self.update_step % self.update_every == 0:
            self.train()
            self.memory.clear_memory()
            self.old_policy.set_weights(self.policy.get_weights())

    def get_expected_returns(self):

        returns = []
        discounted_reward = 0
        for (reward, done) in zip(reversed(self.memory.rewards), reversed(self.memory.dones)):
            if done:
                discounted_reward = 0
            discounted_reward = reward + self.gamma * discounted_reward
            returns.insert(0, discounted_reward)

        returns = tf.convert_to_tensor(returns, tf.float32)
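        # normalize the discounted returns for numerical stability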
        returns = (returns - tf.math.reduce_mean(returns)) / (tf.math.reduce_std(returns) + 1e-5)
        return returns

    def compute_loss(self, returns, values, action_probs, old_log_probs):
        log_probs = tf.math.log(action_probs)
        advantages = returns - tf.stop_gradient(values)
        ratio = tf.exp(log_probs - tf.stop_gradient(old_log_probs))
        surr1 = ratio * advantages
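        # tf.where form of the clipped surrogate: combined with the min() below, this is
        # equivalent to the usual clip(ratio, 1 - clip, 1 + clip) * advantages term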
        surr2 = tf.where(advantages > 0, (1 + self.clip) * advantages, (1 - self.clip) * advantages)
        actor_loss = -tf.math.reduce_mean(tf.minimum(surr1, surr2))
        critic_loss = tf.reduce_mean((returns - values) ** 2)
        return actor_loss + critic_loss

    @tf.function
    def train(self):
        states = tf.convert_to_tensor(np.stack(self.memory.states))
        actions = tf.convert_to_tensor(self.memory.actions)
        old_log_probs = tf.convert_to_tensor(self.memory.log_probs)
        old_log_probs = tf.expand_dims(old_log_probs, 1)
        #print(states.shape, actions.shape, old_log_probs.shape)
        for _ in tf.range(self.epochs):
            with tf.GradientTape() as tape:
                logits, values = self.policy(states)
                probs = tf.nn.softmax(logits)
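                # probability of the action actually taken in each state (row-wise gather)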
                action_probs = tf.gather(probs, actions, batch_dims=1)
                returns = self.get_expected_returns()
                returns, action_probs = [tf.expand_dims(x, 1) for x in [returns, action_probs]]
                loss = self.compute_loss(tf.stop_gradient(returns), values, action_probs, old_log_probs)

            grads = tape.gradient(loss, self.policy.trainable_variables)
            self.optimizer.apply_gradients(zip(grads, self.policy.trainable_variables))



# ------------------ TRAIN ------------------------
def train_loop(env_name, max_steps=1000, episodes=10000):
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    print(f"Env : {env_name} | State : {state_dim} | Action : {action_dim}")
    lr = 2e-3
    gamma = 0.99
    clip = 0.2
    update_every = 2000
    epochs = 4
    agent = PPO(action_dim, lr, gamma, clip, update_every, epochs)
    pbar = tqdm.trange(episodes)
    reward_buffer = deque(maxlen=100)
    for _ in pbar:
        state, _ = env.reset()
        total_reward = 0
        for _ in range(max_steps):
            action = agent.select_action(state)
            state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            total_reward += reward
            agent.take_step(reward, done)
            if done:
                break

        reward_buffer.append(total_reward)
        pbar.set_postfix(episode_reward=total_reward, running_reward=np.mean(reward_buffer))


train_loop("CartPole-v1")



I have tried it both with and without tf.stop_gradient, with no luck: the agent doesn't learn, while the same code in PyTorch learns very fast.
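
To be concrete about what I mean by "with and without": the only difference between the two variants I tried is whether values and old_log_probs are wrapped in tf.stop_gradient inside compute_loss. Roughly (toy tensors here just for illustration):

import tensorflow as tf

# toy tensors with shape (batch, 1), only to show the two variants
returns = tf.constant([[1.0], [0.5]])
values = tf.constant([[0.8], [0.6]])
log_probs = tf.math.log(tf.constant([[0.4], [0.7]]))
old_log_probs = tf.math.log(tf.constant([[0.5], [0.6]]))

# variant 1: with stop_gradient (as in compute_loss above)
advantages = returns - tf.stop_gradient(values)
ratio = tf.exp(log_probs - tf.stop_gradient(old_log_probs))

# variant 2: without stop_gradient
advantages = returns - values
ratio = tf.exp(log_probs - old_log_probs)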

Here is the PyTorch implementation:

import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
import gymnasium as gym
import numpy as np
import tqdm
from collections import deque
# --------------- MEMORY ----------------------
class Memory:
    def __init__(self):
        self.states = []
        self.actions = []
        self.rewards = []
        self.dones = []
        self.log_probs = []

    def clear_memory(self):
        self.states = []
        self.actions = []
        self.rewards = []
        self.dones = []
        self.log_probs = []


# ---------------- MODEL -----------------
class ActorCritic(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.actor = nn.Sequential(nn.Linear(input_dim, 128),
                                   nn.ReLU(),
                                   nn.Linear(128, 64),
                                   nn.ReLU(),
                                   nn.Linear(64, output_dim),
                                   nn.Softmax(dim=1))

        self.critic = nn.Sequential(nn.Linear(input_dim, 128),
                                    nn.ReLU(),
                                    nn.Linear(128, 64),
                                    nn.ReLU(),
                                    nn.Linear(64, 1))

    def evaluate(self, states, actions):
        values = torch.squeeze(self.critic(states))
        dist = Categorical(self.actor(states))
        log_probs = dist.log_prob(actions)
        return values, log_probs


# ---------------- AGENT ----------------
class PPO:
    def __init__(self, state_dim, action_dim, lr, gamma, clip, update_every, epochs):
        self.policy = ActorCritic(state_dim, action_dim)
        self.old_policy = ActorCritic(state_dim, action_dim)
        self.old_policy.load_state_dict(self.policy.state_dict())

        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)
        self.gamma = gamma
        self.clip = clip
        self.update_every = update_every
        self.update_step = 0
        self.epochs = epochs

        self.criterion = nn.MSELoss()
        self.memory = Memory()

    def select_action(self, state):
        state = torch.from_numpy(state).view(1, -1).type(torch.float32)
        dist = Categorical(self.old_policy.actor(state))
        action = dist.sample()
        log_prob = dist.log_prob(action)

        self.memory.states.append(state)
        self.memory.actions.append(action)
        self.memory.log_probs.append(log_prob)
        return action.item()

    def take_step(self, reward, done):
        self.memory.rewards.append(reward)
        self.memory.dones.append(done)
        self.update_step += 1
        if self.update_step % self.update_every == 0:
            self.train()
            self.memory.clear_memory()
            self.old_policy.load_state_dict(self.policy.state_dict())

    def train(self):
        states = torch.squeeze(torch.stack(self.memory.states)).type(torch.float32)
        actions = torch.squeeze(torch.stack(self.memory.actions)).long()
        old_log_probs = torch.squeeze(torch.stack(self.memory.log_probs)).detach().type(torch.float32)

        returns = []
        discounted_reward = 0
        for (reward, done) in zip(reversed(self.memory.rewards), reversed(self.memory.dones)):
            if done:
                discounted_reward = 0
            discounted_reward = reward + self.gamma * discounted_reward
            returns.insert(0, discounted_reward)

        returns = torch.tensor(returns, dtype=torch.float32)
        returns = (returns - returns.mean()) / (returns.std() + 1e-5)
        for _ in range(self.epochs):
            values, log_probs = self.policy.evaluate(states, actions)
            advantages = returns - values.detach()
            ratio = torch.exp(log_probs - old_log_probs.detach())
            surr1 = ratio * advantages
            surr2 = torch.clamp(ratio, 1 - self.clip, 1 + self.clip) * advantages
            actor_loss = -torch.min(surr1, surr2).mean()
            critic_loss = self.criterion(returns, values)
            loss = actor_loss + critic_loss
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()


# ------------------ TRAIN -------------------
def train_loop(env_name, max_steps=1000, episodes=10000):
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    print(f"Env : {env_name} | State : {state_dim} | Action : {action_dim}")
    lr = 0.01
    gamma = 0.99
    clip = 0.2
    update_every = 2000
    epochs = 4
    agent = PPO(state_dim, action_dim, lr, gamma, clip, update_every, epochs)
    pbar = tqdm.trange(episodes)
    reward_buffer = deque(maxlen=100)
    for _ in pbar:
        state, _ = env.reset()
        total_reward = 0
        for _ in range(max_steps):
            action = agent.select_action(state)
            state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            total_reward += reward
            agent.take_step(reward, done)
            if done:
                break

        reward_buffer.append(total_reward)
        pbar.set_postfix(episode_reward=total_reward, running_reward=np.mean(reward_buffer))


train_loop("CartPole-v1")