interpreter.policies.ObliqueDTPolicy
- class interpreter.policies.ObliqueDTPolicy(clf, env)
Bases: Policy
Oblique Decision Tree Policy class.
- Parameters:
- clf : sklearn.base.BaseEstimator
The decision tree classifier or regressor.
- env : gym.Env
The environment in which the policy operates.
- Attributes:
- clf : sklearn.base.BaseEstimator
The decision tree classifier or regressor.
- observation_space : gym.Space
The observation space of the environment.
- action_space : gym.Space
The action space of the environment.
Methods
- fit(S, A): Fit the decision tree with the provided oblique observations and actions.
- get_oblique_data(S): Generate oblique data by creating pairwise differences between observations.
- predict(obs[, state, deterministic, ...]): Predict the action to take given an observation.
- fit(S, A)
Fit the decision tree with the provided oblique observations and actions.
- Parameters:
- S : np.ndarray
The observations.
- A : np.ndarray
The actions.
- get_oblique_data(S)
Generate oblique data by creating pairwise differences between observations.
- Parameters:
- S : np.ndarray
The input observations.
- Returns:
- final : np.ndarray
The original observations stacked with pairwise differences.
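The transformation described for get_oblique_data can be illustrated with a short NumPy sketch. This is not the library's implementation: the helper name `pairwise_difference_features` is hypothetical, and the sketch assumes the differences are taken between pairs of features (columns i < j) within each observation and appended after the original columns. The actual pairing and column order may differ.

```python
import numpy as np


def pairwise_difference_features(S):
    """Hypothetical helper: stack observations with pairwise feature differences.

    Assumption: differences are S[:, i] - S[:, j] for every column pair i < j.
    """
    S = np.atleast_2d(np.asarray(S, dtype=float))
    # Indices of the strict upper triangle give each unordered pair once.
    i, j = np.triu_indices(S.shape[1], k=1)
    diffs = S[:, i] - S[:, j]
    # Original features first, then the difference features.
    return np.hstack([S, diffs])


S = np.array([[1.0, 2.0, 4.0],
              [0.0, 5.0, 3.0]])
final = pairwise_difference_features(S)
print(final.shape)  # (2, 6): 3 original features + 3 pairwise differences
```

Under this assumption, d input features produce d + d(d-1)/2 output columns, so an axis-aligned split on a difference column corresponds to an oblique split of the form x_i - x_j <= t in the original space.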
- predict(obs, state=None, deterministic=True, episode_start=0)
Predict the action to take given an observation.
- Parameters:
- obs : np.ndarray
The observation input.
- state : object, optional
The state of the policy (default is None).
- deterministic : bool, optional
Whether to use a deterministic policy (default is True).
- episode_start : int, optional
The episode start index (default is 0).
- Returns:
- action : np.ndarray
The action to take.
- state : object
The updated state of the policy.
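The fit and predict calling convention documented above can be sketched end to end with a stand-in class. `SketchObliquePolicy` and `oblique` are hypothetical names used only for illustration; the sketch omits the `env` argument and any space checks the real class may perform, and it assumes that difference features are built from column pairs i < j and that `predict` returns the state unchanged.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier


def oblique(S):
    # Assumption: append S[:, i] - S[:, j] for every column pair i < j.
    S = np.atleast_2d(np.asarray(S, dtype=float))
    i, j = np.triu_indices(S.shape[1], k=1)
    return np.hstack([S, S[:, i] - S[:, j]])


class SketchObliquePolicy:
    """Hypothetical stand-in mirroring the documented signatures."""

    def __init__(self, clf):
        self.clf = clf

    def fit(self, S, A):
        # Train the tree on the augmented (oblique) features.
        self.clf.fit(oblique(S), A)

    def predict(self, obs, state=None, deterministic=True, episode_start=0):
        # Return (action, state), matching the documented return pair.
        return self.clf.predict(oblique(obs)), state


rng = np.random.default_rng(0)
S = rng.uniform(-1.0, 1.0, size=(200, 2))
A = (S[:, 0] > S[:, 1]).astype(int)  # label depends only on x0 - x1

policy = SketchObliquePolicy(DecisionTreeClassifier(max_depth=1, random_state=0))
policy.fit(S, A)
action, state = policy.predict(np.array([0.5, -0.5]))
```

The toy labels depend only on the sign of x0 - x1, so a depth-1 tree can separate them with a single split on the difference column, which is not possible with one axis-aligned split on x0 or x1 alone.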