from copy import deepcopy

import numpy as np
from gymnasium.spaces import Box, Discrete
from rlberry.agents import AgentWithSimplePolicy
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.utils import check_for_correct_spaces
from tqdm import tqdm

# Package-local policy wrappers; the exact import path is assumed.
from .policies import DTPolicy, ObliqueDTPolicy, SB3Policy


class Interpreter(AgentWithSimplePolicy):
    """
    A class to interpret a neural-net policy using a decision tree policy.
    It follows Algorithm 1 from https://arxiv.org/abs/2405.14956
    By default, the trajectories are sampled in a DAgger-like way.
    A usage sketch is given at the bottom of this module.

    Parameters
    ----------
    oracle : SB3Policy
        The oracle model that generates the data for training.
        Usually a stable-baselines3 model from the Hugging Face Hub.
    learner : DTPolicy or ObliqueDTPolicy
        The decision tree policy to be trained.
    env : gym.Env
        The environment in which the policies are evaluated.
    data_per_iter : int, optional
        The number of data points to generate per iteration (default is 5000).
    kwargs : optional
        Extra arguments forwarded to ``AgentWithSimplePolicy``.

    Attributes
    ----------
    eval_env : Monitor
        The monitored environment in which trained trees are evaluated.
    """

    def __init__(self, oracle, learner, env, data_per_iter=5000, **kwargs):
        assert isinstance(oracle, SB3Policy)
        assert isinstance(learner, (DTPolicy, ObliqueDTPolicy))
        AgentWithSimplePolicy.__init__(self, env, **kwargs)
        if not isinstance(self.eval_env, Monitor):
            self.eval_env = Monitor(self.eval_env)
        self._oracle = oracle
        self._learner = learner
        self._policy = deepcopy(learner)
        self._data_per_iter = data_per_iter
        # Learner and oracle must act in the same spaces as both the training
        # and the evaluation environments.
        for env_ in (self.env, self.eval_env):
            check_for_correct_spaces(
                env_, self._learner.observation_space, self._learner.action_space
            )
            check_for_correct_spaces(
                env_, self._oracle.observation_space, self._oracle.action_space
            )
    def fit(self, nb_timesteps):
        """
        Train the decision tree policy on data generated by the oracle,
        then refine it with DAgger-style iterations.

        Parameters
        ----------
        nb_timesteps : int
            The total number of environment transitions used for learning.
        """
        nb_iter = int(max(1, nb_timesteps // self._data_per_iter))

        # Iteration 0: behavioral cloning on states visited by the oracle.
        print("Fitting tree nb 0 ...")
        S, A = self.generate_data(self._oracle, self._data_per_iter)
        self._learner.fit(S, A)
        self._policy = deepcopy(self._learner)
        tree_reward, _ = evaluate_policy(self._learner, self.eval_env)
        current_max_reward = tree_reward

        for t in range(1, nb_iter + 1):
            print("Fitting tree nb {} ...".format(t))
            # DAgger step: collect states visited by the current tree,
            # relabel them with the oracle's actions, and refit on the
            # aggregated dataset.
            S_tree, _ = self.generate_data(self._learner, self._data_per_iter)
            S = np.concatenate((S, S_tree))
            A = np.concatenate((A, self._oracle.predict(S_tree)[0]))
            self._learner.fit(S, A)
            # Keep the best tree seen so far as the returned policy.
            tree_reward, _ = evaluate_policy(self._learner, self.eval_env)
            if tree_reward > current_max_reward:
                current_max_reward = tree_reward
                self._policy = deepcopy(self._learner)
                print("New best tree reward: {}".format(tree_reward))
    def generate_data(self, policy, nb_data):
        """
        Generate data by running the policy in the environment.

        Parameters
        ----------
        policy : object
            The policy used to generate transitions.
        nb_data : int
            The number of data points to generate.

        Returns
        -------
        S : np.ndarray
            The generated observations.
        A : np.ndarray
            The generated actions.
        """
        assert nb_data >= 0
        if isinstance(self.env.action_space, Discrete):
            A = np.zeros(nb_data)
        elif isinstance(self.env.action_space, Box):
            A = np.zeros((nb_data, self.env.action_space.shape[0]))
        else:
            raise TypeError(
                "Unsupported action space: {}".format(self.env.action_space)
            )
        S = np.zeros((nb_data, self.env.observation_space.shape[0]))
        s, _ = self.env.reset()
        for i in tqdm(range(nb_data)):
            action, _ = policy.predict(s)
            S[i] = s
            A[i] = action
            s, _, term, trunc, _ = self.env.step(action)
            # Start a new episode when the current one ends.
            if term or trunc:
                s, _ = self.env.reset()
        return S, A
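

# The sketch below shows how the pieces above fit together. It is illustrative
# only: CartPole-v1, the from-scratch PPO oracle, and the SB3Policy / DTPolicy
# constructor arguments are assumptions, not a documented API of this module.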
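if __name__ == "__main__":
    import gymnasium as gym
    from stable_baselines3 import PPO

    env = gym.make("CartPole-v1")

    # Oracle: a neural-net policy wrapped for the Interpreter. Training PPO
    # from scratch here stands in for loading a pre-trained model from the
    # Hugging Face Hub; SB3Policy(model) is an assumed wrapper signature.
    oracle = SB3Policy(PPO("MlpPolicy", env).learn(10_000))

    # Learner: a depth-limited axis-aligned tree. The constructor arguments
    # (max_depth, env) are assumptions.
    learner = DTPolicy(max_depth=4, env=env)

    agent = Interpreter(oracle, learner, env, data_per_iter=2_000)
    agent.fit(nb_timesteps=10_000)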
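
# Design note: relabeling states visited by the tree with oracle actions (the
# loop in fit) trains the tree on its own induced state distribution, which
# usually tracks the oracle better than plain behavioral cloning on oracle
# trajectories alone.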