
# Install the package first (run this in a shell, not in Python): pip install git+<interpreter repository URL>


from interpreter import Interpreter
from interpreter import ObliqueDTPolicy, SB3Policy, DTPolicy

from stable_baselines3 import SAC
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor

import gymnasium as gym
from sklearn.tree import DecisionTreeRegressor
from huggingface_sb3 import load_from_hub

from pickle import dump, load

# Download a pretrained SAC policy from the stable-baselines3 zoo on the Hugging Face Hub.
# SB3 zoo repositories store the model archive as "<algo>-<env>.zip".
checkpoint = load_from_hub(
    repo_id="sb3/sac-HalfCheetah-v3", filename="sac-HalfCheetah-v3.zip"
)

# Load the oracle policy.
# NOTE(review): the checkpoint comes from a HalfCheetah-v3 repo while the env here
# is HalfCheetah-v4 -- confirm the observation/action spaces are compatible.
env = gym.make("HalfCheetah-v4")
model = SAC.load(checkpoint)
oracle = SB3Policy(model.policy)

# Get oracle performance (mean episode reward). The env is wrapped in Monitor so
# evaluate_policy reports the unmodified episode returns.
print(evaluate_policy(oracle, Monitor(env))[0])

# Instantiate the decision tree class (here a regression tree with at most 16 leaves).
# Change to DecisionTreeClassifier for discrete actions.
clf = DecisionTreeRegressor(max_leaf_nodes=16)
learner = ObliqueDTPolicy(clf, env)
# You can replace it by DTPolicy(clf, env) for interpretable axis-parallel DTs.

# Start the imitation learning (the tree learner is trained to imitate the oracle).
# Without an explicit fit() call the tree would never be trained.
interpret = Interpreter(oracle, learner, env)
interpret.fit(10)  # NOTE(review): confirm the unit/scale of fit()'s argument in the interpreter docs

# Evaluate the learned tree with the same Monitor wrapping as the oracle so the two
# scores are comparable. `_policy` is a private attribute of Interpreter.
final_tree_reward, _ = evaluate_policy(
    interpret._policy, env=Monitor(env), n_eval_episodes=10
)
print(final_tree_reward)

# Save the best tree. Here you can replace pickle with joblib or cloudpickle.
TREE_PATH = "tree_halfcheetah.pkl"
with open(TREE_PATH, "wb") as f:
    dump(interpret._policy.clf, f)

# Reload the saved tree. Only unpickle files you created or trust:
# unpickling can execute arbitrary code.
with open(TREE_PATH, "rb") as f:
    clf = load(f)