Source code for interpreter.interpreter

from .policies import DTPolicy, SB3Policy, ObliqueDTPolicy

from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.utils import check_for_correct_spaces
from stable_baselines3.common.monitor import Monitor

from rlberry.agents import AgentWithSimplePolicy

from gymnasium.spaces import Discrete, Box
from gymnasium.wrappers.time_limit import TimeLimit

import numpy as np
from copy import deepcopy
from tqdm import tqdm


class Interpreter(AgentWithSimplePolicy):
    """
    A class to interpret a neural-net policy using a decision tree policy.

    It follows Algorithm 1 of https://arxiv.org/abs/2405.14956.
    By default, trajectories are sampled in a DAgger-like way: states are
    collected by running the current tree and relabeled with the oracle's
    actions.

    Parameters
    ----------
    oracle : SB3Policy
        The oracle model that generates the data for training.
        Usually a stable-baselines3 model from the Hugging Face hub.
    learner : DTPolicy or ObliqueDTPolicy
        The decision tree policy to be trained.
    env : gym.Env
        The environment in which the policies are evaluated. The evaluation
        environment is wrapped in a Monitor if it is not one already.
    data_per_iter : int, optional
        The number of data points to generate per iteration (default is 5000).
    kwargs : optional
        Extra arguments forwarded to ``AgentWithSimplePolicy``.
    """

    def __init__(self, oracle, learner, env, data_per_iter=5000, **kwargs):
        assert isinstance(oracle, SB3Policy) and isinstance(
            learner, (DTPolicy, ObliqueDTPolicy)
        )
        AgentWithSimplePolicy.__init__(self, env, **kwargs)
        if not isinstance(self.eval_env, Monitor):
            self.eval_env = Monitor(self.eval_env)
        self._oracle = oracle
        self._learner = learner
        self._policy = deepcopy(learner)
        self._data_per_iter = data_per_iter
        # Both policies must match the spaces of the training and eval envs.
        check_for_correct_spaces(
            self.env,
            self._learner.observation_space,
            self._learner.action_space,
        )
        check_for_correct_spaces(
            self.env, self._oracle.observation_space, self._oracle.action_space
        )
        check_for_correct_spaces(
            self.eval_env,
            self._learner.observation_space,
            self._learner.action_space,
        )
        check_for_correct_spaces(
            self.eval_env, self._oracle.observation_space, self._oracle.action_space
        )
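
    # A minimal construction sketch (illustrative only: the SB3Policy and
    # DTPolicy constructor signatures below are assumptions, not guaranteed
    # by this module; check .policies for the real ones):
    #
    #     import gymnasium as gym
    #     from stable_baselines3 import PPO
    #
    #     env = gym.make("CartPole-v1")
    #     oracle = SB3Policy(PPO.load("ppo_cartpole").policy)  # assumed wrapper
    #     learner = DTPolicy(max_depth=4)                      # assumed signature
    #     interp = Interpreter(oracle, learner, env, data_per_iter=5000)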

    def fit(self, nb_timesteps):
        """
        Train the decision tree policy using data generated by the oracle.

        Parameters
        ----------
        nb_timesteps : int
            The number of environment transitions used for learning.
        """
        print("Fitting tree nb {} ...".format(0))
        nb_iter = int(max(1, nb_timesteps // self._data_per_iter))
        S, A = self.generate_data(self._oracle, self._data_per_iter)
        self._learner.fit(S, A)
        self._policy = deepcopy(self._learner)
        tree_reward, _ = evaluate_policy(self._learner, self.eval_env)
        current_max_reward = tree_reward
        for t in range(1, nb_iter + 1):
            print("Fitting tree nb {} ...".format(t))
            # DAgger-style step: sample states with the current tree, then
            # relabel them with the oracle's actions and aggregate.
            S_tree, _ = self.generate_data(self._learner, self._data_per_iter)
            S = np.concatenate((S, S_tree))
            A = np.concatenate((A, self._oracle.predict(S_tree)[0]))
            self._learner.fit(S, A)
            tree_reward, _ = evaluate_policy(self._learner, self.eval_env)
            if tree_reward > current_max_reward:
                # Keep the best tree seen so far as the deployed policy.
                current_max_reward = tree_reward
                self._policy = deepcopy(self._learner)
                print("New best tree reward: {}".format(tree_reward))
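
    # Schematic view of the loop in ``fit`` (Algorithm 1 of the paper):
    #
    #     D <- rollouts(oracle); fit tree T_0 on D
    #     for t = 1..nb_iter:
    #         S_t <- states visited by T_{t-1}          (DAgger-style sampling)
    #         D   <- D U {(s, oracle(s)) for s in S_t}  (relabel with the oracle)
    #         fit T_t on D; keep the tree with the best evaluation reward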

    def policy(self, obs):
        """Return the prediction of the best tree policy found so far."""
        return self._policy.predict(obs)

    def eval(self, eval_horizon=10**5, n_simulations=10, gamma=1.0):
        """
        Return the mean reward of the best tree policy on the eval env.

        Parameters
        ----------
        eval_horizon : int, optional
            Maximum number of steps per evaluation episode.
        n_simulations : int, optional
            Number of evaluation episodes.
        gamma : float, optional
            Discount factor; unused here, the evaluation is undiscounted.
        """
        return evaluate_policy(
            self._policy,
            TimeLimit(self.eval_env, eval_horizon),
            n_eval_episodes=n_simulations,
        )[0]
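
    # Evaluation relies on stable-baselines3's ``evaluate_policy`` over the
    # Monitor-wrapped eval env, truncated at ``eval_horizon`` steps, e.g.
    # (continuing the construction sketch above):
    #
    #     mean_reward = interp.eval(eval_horizon=500, n_simulations=20)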

    def generate_data(self, policy, nb_data):
        """
        Generate data by running the policy in the environment.

        Parameters
        ----------
        policy : object
            The policy used to generate transitions.
        nb_data : int
            The number of data points to generate.

        Returns
        -------
        S : np.ndarray
            The generated observations.
        A : np.ndarray
            The generated actions.
        """
        assert nb_data >= 0
        if isinstance(self.env.action_space, Discrete):
            A = np.zeros(nb_data)
        elif isinstance(self.env.action_space, Box):
            A = np.zeros((nb_data, self.env.action_space.shape[0]))
        else:
            raise TypeError(
                "Unsupported action space: {}".format(self.env.action_space)
            )
        S = np.zeros((nb_data, self.env.observation_space.shape[0]))
        s, _ = self.env.reset()
        for i in tqdm(range(nb_data)):
            action, _ = policy.predict(s)
            S[i] = s
            A[i] = action
            s, _, term, trunc, _ = self.env.step(action)
            if term or trunc:
                # Start a fresh episode when the current one ends.
                s, _ = self.env.reset()
        return S, A
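
# Usage note for ``generate_data``: it returns fixed-size arrays of states and
# actions, resetting the env whenever an episode terminates or truncates, e.g.
# (continuing the sketch above; ``interp`` and ``oracle`` are illustrative):
#
#     S, A = interp.generate_data(oracle, nb_data=1000)
#     assert S.shape == (1000, env.observation_space.shape[0])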