Note
Click here to download the full example code
Illustration of the forest effect
In this example we show that the decision function of a forest is the average of the decision functions of its independent trees, and that this averaging produces a smoother decision function.
Out:
2020-04-30 21:08:19 Building the graph...
2020-04-30 21:08:28 Saved the forest effect plot in forest_effect.pdf
import sys
import warnings
warnings.filterwarnings("ignore")
import logging
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
sys.path.extend([".", ".."])
from onelearn import AMFClassifier
from experiments.plot import (
plot_contour_binary_classif,
plot_scatter_binary_classif,
get_mesh,
)
# Emit timestamped log records at INFO level and above.
logging.basicConfig(
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)

# Number of contour levels and colour normalisation (fixed to [0, 1]) shared
# by every contour panel, so that all panels use the same colour scale.
levels = 30
norm = plt.Normalize(vmin=0.0, vmax=1.0)
def plot_forest_effect(forest, dataset):
    """Plot the input data, each tree's decision function and the forest's.

    The current figure is laid out as a two-row grid of panels: first a
    scatter plot of the data, then one contour panel per tree, and finally
    one contour panel for the whole forest.

    Parameters
    ----------
    forest
        Estimator exposing ``n_estimators``, ``partial_fit``,
        ``predict_proba_tree`` and ``predict_proba``. It is trained on
        ``dataset`` by this function (via ``partial_fit``).
    dataset
        Pair ``(X, y)``; ``X`` is used to build the evaluation mesh and both
        are passed to ``partial_fit``.
    """
    n_estimators = forest.n_estimators
    # One panel for the data, one per tree, one for the forest.
    n_panels = n_estimators + 2
    # Integer ceiling of n_panels / 2: matplotlib requires integer grid
    # dimensions, and rounding up keeps the forest panel inside the grid even
    # when n_estimators is odd.
    n_cols = (n_panels + 1) // 2
    plt.figure(figsize=(2 * n_cols, 4))

    X, y = dataset
    xx, yy, X_mesh = get_mesh(X)

    # Plot the training points
    ax = plt.subplot(2, n_cols, 1)
    plot_scatter_binary_classif(ax, xx, yy, X, y, title="Input data")

    forest.partial_fit(X, y)

    # One contour panel per tree, using column 1 of its predicted
    # probabilities (presumably the positive class -- confirm against
    # AMFClassifier's documentation).
    for idx_tree in range(n_estimators):
        ax = plt.subplot(2, n_cols, idx_tree + 2)
        Z = forest.predict_proba_tree(X_mesh, idx_tree)[:, 1].reshape(xx.shape)
        plot_contour_binary_classif(
            ax, xx, yy, Z, title="Tree #%d" % (idx_tree + 1), norm=norm, levels=levels
        )

    # Last panel: prediction of the whole forest on the same mesh.
    ax = plt.subplot(2, n_cols, n_panels)
    Z = forest.predict_proba(X_mesh)[:, 1].reshape(xx.shape)
    plot_contour_binary_classif(ax, xx, yy, Z, title="Forest", norm=norm, levels=levels)
    plt.tight_layout()
# Experiment constants.
n_samples, n_features, n_classes = 100, 2, 2
random_state = 42
n_estimators = 10

# Two interleaving half-moons with some noise, generated with a fixed seed.
dataset = make_moons(
    n_samples=n_samples,
    noise=0.15,
    random_state=random_state,
)

# The forest whose individual trees and aggregate prediction are plotted.
amf = AMFClassifier(
    n_classes=n_classes,
    n_estimators=n_estimators,
    random_state=random_state,
    use_aggregation=True,
    split_pure=True,
)

# Train the forest, draw all panels, then write the figure to disk.
logging.info("Building the graph...")
plot_forest_effect(amf, dataset)
plt.savefig("forest_effect.pdf")
logging.info("Saved the forest effect plot in forest_effect.pdf")
Total running time of the script: ( 0 minutes 9.546 seconds)