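interpreter.policies.ObliqueDTPolicy
- class interpreter.policies.ObliqueDTPolicy(clf, env)
Bases: Policy
Oblique Decision Tree Policy class. The decision tree is trained on oblique data, i.e. the observations stacked with pairwise differences (see get_oblique_data).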
- Parameters:
- clf : sklearn.base.BaseEstimator
The decision tree classifier or regressor.
- env : gym.Env
The environment in which the policy operates.
- Attributes:
- clf : sklearn.base.BaseEstimator
The decision tree classifier or regressor.
- observation_space : gym.Space
The observation space of the environment.
- action_space : gym.Space
The action space of the environment.
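The constructor and attributes listed above can be pictured with the sketch below. `ObliqueDTPolicySketch` is a hypothetical stand-in, not the library class: it assumes the policy simply stores `clf` and copies `observation_space` and `action_space` from `env`, and a `SimpleNamespace` takes the place of a real `gym.Env`.

```python
from types import SimpleNamespace


class ObliqueDTPolicySketch:
    """Hypothetical stand-in for ObliqueDTPolicy's constructor (not the real class)."""

    def __init__(self, clf, env):
        # Keep the (not yet fitted) decision tree estimator.
        self.clf = clf
        # Assumption: the spaces are copied directly from the environment.
        self.observation_space = env.observation_space
        self.action_space = env.action_space


# Minimal stand-in for a gym.Env exposing only the two attributes the sketch reads.
env = SimpleNamespace(
    observation_space=SimpleNamespace(shape=(4,)),
    action_space=SimpleNamespace(n=2),
)
policy = ObliqueDTPolicySketch(clf=None, env=env)
print(policy.observation_space.shape, policy.action_space.n)  # (4,) 2
```

The real class may do more work at construction time; the sketch only shows where each documented attribute comes from.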
Methods
- fit(S, A): Fit the decision tree with the provided oblique observations and actions.
- get_oblique_data(S): Generate oblique data by creating pairwise differences between observations.
- predict(obs[, state, deterministic, ...]): Predict the action to take given an observation.
- fit(S, A)
Fit the decision tree with the provided oblique observations and actions.
- Parameters:
- S : np.ndarray
The observations.
- A : np.ndarray
The actions.
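A minimal sketch of what fitting on oblique data could look like, assuming `fit` augments `S` with the differences `S[:, i] - S[:, j]` for every feature pair `i < j` and then calls the estimator's own `fit`. The helpers `oblique_features` and `fit_oblique_tree` are hypothetical names, and the toy data is synthetic.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier


def oblique_features(S):
    """Original features followed by S[:, i] - S[:, j] for each pair i < j (assumed layout)."""
    S = np.atleast_2d(np.asarray(S, dtype=float))
    i, j = np.triu_indices(S.shape[1], k=1)
    return np.hstack([S, S[:, i] - S[:, j]])


def fit_oblique_tree(clf, S, A):
    """Hypothetical stand-in for ObliqueDTPolicy.fit: augment S, then fit clf."""
    clf.fit(oblique_features(S), A)
    return clf


# Synthetic data: the action is 1 exactly when feature 0 exceeds feature 1.
rng = np.random.default_rng(0)
S = rng.uniform(-1.0, 1.0, size=(200, 2))
A = (S[:, 0] > S[:, 1]).astype(int)

clf = fit_oblique_tree(DecisionTreeClassifier(max_depth=1, random_state=0), S, A)
print(clf.tree_.feature[0])               # 2: the root splits on the x0 - x1 column
print(clf.score(oblique_features(S), A))  # 1.0
```

The target here depends only on whether x0 > x1. An axis-aligned tree on the raw features can only approximate that diagonal boundary with a staircase of splits, while a single split on the added difference column represents it exactly.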
- get_oblique_data(S)
Generate oblique data by creating pairwise differences between observations.
- Parameters:
- S : np.ndarray
The input observations.
- Returns:
- final : np.ndarray
The original observations stacked with pairwise differences.
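A NumPy sketch of this transformation, assuming "pairwise differences" means `S[:, i] - S[:, j]` for every pair of feature columns `i < j`, appended after the original columns; the library's actual pairing and column order may differ. Under this assumption, `d` features produce `d + d(d - 1)/2` columns.

```python
import numpy as np


def get_oblique_data(S):
    """Sketch: stack the observations with all pairwise feature differences."""
    # Treat a single 1-D observation as a batch of one.
    S = np.atleast_2d(np.asarray(S, dtype=float))
    # Index pairs (i, j) with i < j in row-major order: (0, 1), (0, 2), (1, 2), ...
    i, j = np.triu_indices(S.shape[1], k=1)
    diffs = S[:, i] - S[:, j]
    return np.hstack([S, diffs])


S = np.array([[1.0, 4.0, 2.0],
              [0.0, -1.0, 3.0]])
final = get_oblique_data(S)
print(final.shape)        # (2, 6): 3 original columns + 3 pairs
print(final[0].tolist())  # [1.0, 4.0, 2.0, -3.0, -1.0, 2.0]
```

Using `np.triu_indices` computes all differences in one vectorized step instead of a Python double loop over feature pairs.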
- predict(obs, state=None, deterministic=True, episode_start=0)
Predict the action to take given an observation.
- Parameters:
- obs : np.ndarray
The observation input.
- state : object, optional
The state of the policy (default is None).
- deterministic : bool, optional
Whether to use a deterministic policy (default is True).
- episode_start : int, optional
The episode start index (default is 0).
- Returns:
- action : np.ndarray
The action to take.
- state : object
The updated state of the policy.
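A sketch of the prediction path, assuming `predict` applies the same oblique transform used in training before calling the estimator's `predict`, and returns the `(action, state)` pair documented above. In this sketch `state` is passed through unchanged, and `deterministic` and `episode_start` are accepted only for signature compatibility, since a fitted tree is memoryless and deterministic. The names `oblique_features` and `predict_action` are hypothetical.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier


def oblique_features(S):
    """Original features followed by S[:, i] - S[:, j] for each pair i < j (assumed layout)."""
    S = np.atleast_2d(np.asarray(S, dtype=float))
    i, j = np.triu_indices(S.shape[1], k=1)
    return np.hstack([S, S[:, i] - S[:, j]])


def predict_action(clf, obs, state=None, deterministic=True, episode_start=0):
    """Hypothetical stand-in for ObliqueDTPolicy.predict.

    `deterministic` and `episode_start` are ignored here: a fitted tree always
    returns the same action for the same input and keeps no memory, so
    `state` is returned unchanged.
    """
    action = clf.predict(oblique_features(obs))
    return action, state


# Tiny training set: action 1 when feature 0 is larger than feature 1.
S = np.array([[0.9, 0.1], [0.2, 0.8], [0.7, 0.3], [0.1, 0.6]])
A = np.array([1, 0, 1, 0])
clf = DecisionTreeClassifier(random_state=0).fit(oblique_features(S), A)

action, state = predict_action(clf, np.array([0.6, 0.2]))
print(action, state)  # [1] None
```

A batch of shape `(n, d)` passed as `obs` yields one action per row, because the transform and `clf.predict` both operate row-wise.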