interpreter.Interpreter#

class interpreter.Interpreter(oracle, learner, env, data_per_iter=5000, **kwargs)[source]#

Bases: AgentWithSimplePolicy

A class to interpret a neural net policy using a decision tree policy. It follows Algorithm 1 from https://arxiv.org/abs/2405.14956. By default, the trajectories are sampled in a DAgger-like way.

Parameters:
oracle : object

The oracle model that generates the data for training. Usually a Stable-Baselines3 model from the Hugging Face Hub.

learner : object

The decision tree policy.

env : object

The environment in which the policies are evaluated (gym.Env).

data_per_iter : int, optional

The number of data points to generate per iteration (default is 5000).

kwargs : optional

Attributes:
oracle : object

The oracle model that generates the data for training.

learner : object

The decision tree policy to be trained.

data_per_iter : int

The number of data points to generate per iteration.

env : object

The monitored environment in which the policies are evaluated.

tree_policies : list

A list to store the trained tree policies over iterations.

tree_policies_rewards : list

A list to store the rewards of the trained tree policies over iterations.
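
For orientation, a minimal construction sketch is given below. It assumes a CartPole environment, a PPO oracle from Stable-Baselines3 and a scikit-learn decision tree as the learner; the import path of Interpreter and the exact types accepted for oracle and learner should be checked against your installation.

import gymnasium as gym
from sklearn.tree import DecisionTreeClassifier
from stable_baselines3 import PPO

from interpreter import Interpreter  # assumed import path

env = gym.make("CartPole-v1")

# Oracle: the neural-net policy to be distilled. Trained here for brevity;
# a pre-trained model from the Hugging Face Hub works the same way.
oracle = PPO("MlpPolicy", env, verbose=0)
oracle.learn(total_timesteps=10_000)

# Learner: the decision tree policy that will imitate the oracle.
learner = DecisionTreeClassifier(max_depth=4)

agent = Interpreter(oracle, learner, env, data_per_iter=5000)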

Methods

eval([eval_horizon, n_simulations, gamma])

Monte-Carlo policy evaluation [1] method to estimate the mean discounted reward using the current policy on the evaluation environment.

fit(nb_timesteps)

Train the decision tree policy using data generated by the oracle.

generate_data(policy, nb_data)

Generate data by running the policy in the environment.

get_params([deep])

Get parameters for this agent.

load(filename, **kwargs)

Load agent object from filepath.

policy(obs)

Abstract method.

reseed([seed_seq])

Get new random number generator for the agent.

sample_parameters(trial)

Sample hyperparameters for hyperparam optimization using Optuna (https://optuna.org/)

save(filename)

Save agent object.

set_writer(writer)

Set self._writer.

eval(eval_horizon=100000, n_simulations=10, gamma=1.0)[source]#

Monte-Carlo policy evaluation [1] method to estimate the mean discounted reward using the current policy on the evaluation environment.

Parameters:
eval_horizon : int, optional, default: 10**5

Maximum episode length, representing the horizon for each simulation.

n_simulations : int, optional, default: 10

Number of Monte Carlo simulations to perform for the evaluation.

gamma : float, optional, default: 1.0

Discount factor for future rewards.

Returns:
float

The mean, over ‘n_simulations’ simulations, of the (discounted) sum of rewards obtained in each simulation.

References

[1]

Sutton, R. S., & Barto, A. G. (2018). Reinforcement Learning: An Introduction. MIT Press.
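
A short usage sketch, assuming an already-fitted agent as constructed above; the horizon, number of simulations and discount are illustrative.

# Estimate the mean (discounted) return of the current policy over
# 20 Monte-Carlo rollouts of at most 500 steps each.
mean_return = agent.eval(eval_horizon=500, n_simulations=20, gamma=1.0)
print(f"Estimated mean return: {mean_return:.2f}")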

fit(nb_timesteps)[source]#

Train the decision tree policy using data generated by the oracle.

Parameters:
nb_timesteps : int

The number of environment transitions used for learning.
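
A typical call, continuing the construction sketch above; the budget is illustrative. With data_per_iter=5000, a budget of 50,000 transitions corresponds to roughly ten DAgger-like iterations.

agent.fit(50_000)  # nb_timesteps: total environment transitions used for learning

# The intermediate trees and their evaluation rewards are stored on the agent.
print(len(agent.tree_policies), "tree policies trained")
print(agent.tree_policies_rewards)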

generate_data(policy, nb_data)[source]#

Generate data by running the policy in the environment.

Parameters:
policy : object

The policy used to generate transitions.

nb_data : int

The number of data points to generate.

Returns:
S : np.ndarray

The generated observations.

A : np.ndarray

The generated actions.
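
A sketch of calling generate_data directly, assuming the oracle object is accepted as the policy argument (the implementation may expect a specific policy interface):

# Roll out the oracle for 5,000 transitions and collect (observation, action) pairs.
S, A = agent.generate_data(agent.oracle, nb_data=5000)
print(S.shape, A.shape)  # observations, and the actions taken on them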

get_params(deep=True)#

Get parameters for this agent.

Parameters:
deep : bool, default=True

If True, will return the parameters for this agent and contained subobjects.

Returns:
params : dict

Parameter names mapped to their values.

classmethod load(filename, **kwargs)#

Load agent object from filepath.

If overridden, the save() method must also be overridden.

Parameters:
filename: str

Path to the object (pickle) to load.

**kwargs: Keyword Arguments

Arguments required by the __init__ method of the Agent subclass to load.
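
A hedged sketch of reloading an agent; since **kwargs are forwarded to __init__, the constructor arguments are passed again (the filename is illustrative):

agent = Interpreter.load(
    "interpreter_agent.pickle",  # illustrative path, e.g. the one returned by save()
    oracle=oracle,
    learner=learner,
    env=env,
)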

property output_dir#

Directory that the agent can use to store data.

policy(obs)[source]#

Abstract method. The policy function takes an observation from the environment and returns an action. The specific implementation of the policy function depends on the agent’s learning algorithm or strategy, which can be deterministic or stochastic.

Parameters:
observation : any

An observation from the environment.

Returns:
action : any

The action to be taken based on the provided observation.

Notes

The data type of ‘observation’ and ‘action’ can vary depending on the specific agent and the environment it interacts with.
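
A small rollout sketch using policy, assuming a Gymnasium-style reset/step API:

obs, info = env.reset()
done = False
episode_return = 0.0
while not done:
    action = agent.policy(obs)  # tree policy action for this observation
    obs, reward, terminated, truncated, info = env.step(action)
    episode_return += reward
    done = terminated or truncated
print("Episode return:", episode_return)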

reseed(seed_seq=None)#

Get new random number generator for the agent.

Parameters:
seed_seq : numpy.random.SeedSequence, rlberry.seeding.seeder.Seeder or int, default: None

Seed sequence from which to spawn the random number generator. If None, generate a random seed. If an int, use it as entropy for the SeedSequence. If a Seeder, use seeder.seed_seq.

property rng#

Random number generator.

classmethod sample_parameters(trial)#

Sample hyperparameters for hyperparam optimization using Optuna (https://optuna.org/)

Note: only the kwargs sent to __init__ are optimized. Make sure to include all “optimizable” parameters in the Agent constructor.

Parameters:
trial : optuna.trial
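
sample_parameters is meant to be overridden in a subclass so that Optuna can search over constructor kwargs; a hypothetical sketch, assuming the method returns a dict of sampled kwargs (the searched parameter and its range are illustrative):

class TunedInterpreter(Interpreter):
    @classmethod
    def sample_parameters(cls, trial):
        # Only kwargs forwarded to __init__ are optimized, so sample them here.
        data_per_iter = trial.suggest_int("data_per_iter", 1_000, 20_000, log=True)
        return {"data_per_iter": data_per_iter}
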
save(filename)#

Save agent object. By default, the agent is pickled.

If overridden, the load() method must also be overridden.

Before saving, consider setting writer to None if it can’t be pickled (tensorboard writers keep references to files and cannot be pickled).

Note: dill is used when pickle fails (see https://stackoverflow.com/a/25353243, for instance). Pickle is tried first, since it is faster.

Parameters:
filename: Path or str

File in which to save the Agent.

Returns:
pathlib.Path

If save() is successful, a Path object corresponding to the filename is returned. Otherwise, None is returned.

Warning

The returned filename might differ from the input filename: for instance, the method can append the correct suffix to the name before saving.
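
A short save sketch; per the note above, the writer is dropped first (assuming set_writer(None) is accepted), and the returned path may carry an appended suffix:

agent.set_writer(None)                        # tensorboard writers cannot be pickled
saved_path = agent.save("interpreter_agent")  # suffix may be appended
print("Agent saved to:", saved_path)          # a pathlib.Path, or None on failure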

set_writer(writer)#

Set self._writer. If the writer is not None, add parameter values to the writer.

property thread_shared_data#

Data shared by agent instances among different threads.

property unique_id#

Unique identifier for the agent instance. Can be used, for example, to create files/directories for the agent to log data safely.

property writer#

Writer object to log the output (e.g. a tensorboard SummaryWriter).