from .policies import DTPolicy, SB3Policy, ObliqueDTPolicy
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.utils import check_for_correct_spaces
from stable_baselines3.common.monitor import Monitor
from rlberry.agents import AgentWithSimplePolicy
from gymnasium.spaces import Discrete, Box
from gymnasium.wrappers.time_limit import TimeLimit
import numpy as np
from copy import deepcopy
from tqdm import tqdm
class Interpreter(AgentWithSimplePolicy):
"""
A class to interpret a neural net policy using a decision tree policy.
It follows algorithm 1 from https://arxiv.org/abs/2405.14956
By default, the trajectories will be sampled in a DAgger-like way.
Parameters
----------
oracle : object
The oracle model that generates the data for training.
Usually a stable-baselines3 model from the hugging face hub.
learner : object
The decision tree policy.
env : object
The environment in which the policies are evaluated (gym.Env).
data_per_iter : int, optional
The number of data points to generate per iteration (default is 5000).
kwargs : optional
Attributes
----------
oracle : object
The oracle model that generates the data for training.
learner : object
The decision tree policy to be trained.
data_per_iter : int
The number of data points to generate per iteration.
env : object
The monitored environment in which the policies are evaluated.
tree_policies : list
A list to store the trained tree policies over iterations.
tree_policies_rewards : list
A list to store the rewards of the trained tree policies over iterations.
"""
def __init__(self, oracle, learner, env, data_per_iter=5000, **kwargs):
        assert isinstance(oracle, SB3Policy) and isinstance(
            learner, (DTPolicy, ObliqueDTPolicy)
        ), "oracle must be an SB3Policy and learner a DTPolicy or an ObliqueDTPolicy"
AgentWithSimplePolicy.__init__(self, env, **kwargs)
if not isinstance(self.eval_env, Monitor):
self.eval_env = Monitor(self.eval_env)
self._oracle = oracle
self._learner = learner
self._policy = deepcopy(learner)
self._data_per_iter = data_per_iter
check_for_correct_spaces(
self.env,
self._learner.observation_space,
self._learner.action_space,
)
check_for_correct_spaces(
self.env, self._oracle.observation_space, self._oracle.action_space
)
check_for_correct_spaces(
self.eval_env,
self._learner.observation_space,
self._learner.action_space,
)
check_for_correct_spaces(
self.eval_env, self._oracle.observation_space, self._oracle.action_space
)
def fit(self, nb_timesteps):
"""
Train the decision tree policy using data generated by the oracle.
Parameters
----------
nb_timesteps : int
The number of environment transitions used for learning.
"""
print("Fitting tree nb {} ...".format(0))
nb_iter = int(max(1, nb_timesteps // self._data_per_iter))
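        # e.g. nb_timesteps=50000 with data_per_iter=5000 gives nb_iter=10
        # DAgger-style iterations after the initial fit on oracle data.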
S, A = self.generate_data(self._oracle, self._data_per_iter)
self._learner.fit(S, A)
self._policy = deepcopy(self._learner)
tree_reward, _ = evaluate_policy(self._learner, self.eval_env)
current_max_reward = tree_reward
        self.tree_policies = [deepcopy(self._learner)]
        self.tree_policies_rewards = [tree_reward]
for t in range(1, nb_iter + 1):
print("Fitting tree nb {} ...".format(t + 1))
S_tree, _ = self.generate_data(self._learner, self._data_per_iter)
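            # DAgger-style aggregation: states visited by the current tree are
            # relabelled with the oracle's actions and appended to the dataset.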
S = np.concatenate((S, S_tree))
A = np.concatenate((A, self._oracle.predict(S_tree)[0]))
self._learner.fit(S, A)
tree_reward, _ = evaluate_policy(self._learner, self.eval_env)
if tree_reward > current_max_reward:
current_max_reward = tree_reward
self._policy = deepcopy(self._learner)
print("New best tree reward: {}".format(tree_reward))
            self.tree_policies += [deepcopy(self._learner)]
            self.tree_policies_rewards += [tree_reward]
    def policy(self, obs):
        """Return the action of the best tree policy found so far for `obs`."""
        return self._policy.predict(obs)[0]
    def eval(self, eval_horizon=10**5, n_simulations=10, gamma=1.0):
        """Mean episodic reward of the best tree policy over `n_simulations`
        episodes capped at `eval_horizon` steps; `gamma` is unused and kept
        for compatibility with the rlberry eval interface."""
        return evaluate_policy(
            self._policy,
            TimeLimit(self.eval_env, eval_horizon),
            n_eval_episodes=n_simulations,
        )[0]
def generate_data(self, policy, nb_data):
"""
Generate data by running the policy in the environment.
Parameters
----------
policy:
The policy to generate transitions
nb_data : int
The number of data points to generate.
Returns
-------
S : np.ndarray
The generated observations.
A : np.ndarray
The generated actions.
"""
        assert nb_data >= 0
        if isinstance(self.env.action_space, Discrete):
            A = np.zeros(nb_data)
        elif isinstance(self.env.action_space, Box):
            A = np.zeros((nb_data, self.env.action_space.shape[0]))
        else:
            raise NotImplementedError(
                "Only Discrete and Box action spaces are supported."
            )
        S = np.zeros((nb_data, self.env.observation_space.shape[0]))
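        # Roll out the policy, storing each visited state and the action taken;
        # the environment is reset whenever an episode terminates or is truncated.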
s, _ = self.env.reset()
for i in tqdm(range(nb_data)):
action, _ = policy.predict(s)
S[i] = s
A[i] = action
s, _, term, trunc, _ = self.env.step(action)
if term or trunc:
s, _ = self.env.reset()
return S, A
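

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, kept commented out). It assumes a CartPole-v1
# oracle trained with stable-baselines3 PPO; the constructor arguments of
# SB3Policy and DTPolicy below are assumptions, not this module's documented
# API, so adapt them to the actual wrappers in .policies.
# ---------------------------------------------------------------------------
# import gymnasium as gym
# from stable_baselines3 import PPO
#
# env = gym.make("CartPole-v1")
# oracle_model = PPO("MlpPolicy", env).learn(total_timesteps=50_000)
# oracle = SB3Policy(oracle_model)          # assumed wrapper signature
# learner = DTPolicy(max_depth=4)           # assumed wrapper signature
#
# interpreter = Interpreter(oracle, learner, env, data_per_iter=5000)
# interpreter.fit(nb_timesteps=50_000)      # initial fit + 10 DAgger iterations
# print("Mean return of the distilled tree:", interpreter.eval(n_simulations=10))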