interpreter.policies.ObliqueDTPolicy

class interpreter.policies.ObliqueDTPolicy(clf, env)

Bases: Policy

Oblique Decision Tree Policy class.

Parameters:
clf : sklearn.base.BaseEstimator

The decision tree classifier or regressor.

env : gym.Env

The environment in which the policy operates.

Attributes:
clf : sklearn.base.BaseEstimator

The decision tree classifier or regressor.

observation_space : gym.Space

The observation space of the environment.

action_space : gym.Space

The action space of the environment.

Methods

fit(S, A)

Fit the decision tree with the provided oblique observations and actions.

get_oblique_data(S)

Generate oblique data by creating pairwise differences between observations.

predict(obs[, state, deterministic, ...])

Predict the action to take given an observation.

fit(S, A)

Fit the decision tree with the provided oblique observations and actions.

Parameters:
S : np.ndarray

The observations.

A : np.ndarray

The actions.

get_oblique_data(S)

Generate oblique data by creating pairwise differences between observations.

Parameters:
S : np.ndarray

The input observations.

Returns:
final : np.ndarray

The original observations stacked with pairwise differences.
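As a rough illustration of the stacking described above, the sketch below builds such an array with NumPy. It is not the library's implementation: the helper name `oblique_features` is hypothetical, and it assumes the differences are taken between pairs of features within each observation, with pair order `i < j` and sign `x_i - x_j`.

```python
import itertools
import numpy as np


def oblique_features(S):
    """Append x_i - x_j for every feature pair i < j to each observation.

    Hypothetical stand-in for get_oblique_data; the pair ordering and the
    sign convention are assumptions, not taken from the library.
    """
    S = np.atleast_2d(np.asarray(S, dtype=float))
    pairs = list(itertools.combinations(range(S.shape[1]), 2))
    if not pairs:  # a single feature has no pairs to difference
        return S
    diffs = np.stack([S[:, i] - S[:, j] for i, j in pairs], axis=1)
    # Original features first, then one column per feature pair.
    return np.hstack([S, diffs])


S = np.array([[1.0, 4.0, 6.0],
              [2.0, 2.0, 5.0]])
X = oblique_features(S)
print(X.shape)        # (2, 6): 3 original features + 3 pairwise differences
print(X[0].tolist())  # [1.0, 4.0, 6.0, -3.0, -5.0, -2.0]
```

Under these assumptions, an axis-aligned split on one of the appended columns, such as `x_0 - x_1 <= t`, corresponds to an oblique split in the original observation space.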

predict(obs, state=None, deterministic=True, episode_start=0)

Predict the action to take given an observation.

Parameters:
obs : np.ndarray

The observation input.

state : object, optional

The state of the policy (default is None).

deterministic : bool, optional

Whether to use a deterministic policy (default is True).

episode_start : int, optional

The episode start index (default is 0).

Returns:
action : np.ndarray

The action to take.

state : object

The updated state of the policy.
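To show how the `(action, state)` return pair might be consumed, here is a minimal sketch with a hand-written stand-in that mirrors the `predict` signature. `OneSplitObliquePolicy` is hypothetical and not part of the library; it hard-codes a single oblique test on `x_0 - x_1`, and returning the incoming state unchanged is an assumption about stateless policies.

```python
import numpy as np


class OneSplitObliquePolicy:
    """Hypothetical stand-in: a one-split oblique tree on x0 - x1."""

    def __init__(self, threshold=0.0):
        self.threshold = threshold

    def predict(self, obs, state=None, deterministic=True, episode_start=0):
        obs = np.atleast_2d(np.asarray(obs, dtype=float))
        # Oblique test: compare a difference of two features to a threshold.
        action = (obs[:, 0] - obs[:, 1] > self.threshold).astype(int)
        # Assumption: the policy is stateless, so state passes through as-is.
        return action, state


policy = OneSplitObliquePolicy()
action, state = policy.predict(np.array([[0.5, 0.2], [0.1, 0.9]]))
print(action.tolist())  # [1, 0]
print(state)            # None
```

Callers that only need the action can unpack and discard the second element, for example `action, _ = policy.predict(obs)`.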