{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\nWeighted depths of `AMFRegressor` on several 1D signals.\n========================================================\n\nThe example below illustrates the weighted depth learned internally by the\nAMF algorithm to estimate 1D regression functions. We observe that AMF automatically\nadapts to the local regularity of signals, by putting more emphasis on deeper trees\nwhere the regression function is less smooth.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import sys\nimport numpy as np\nimport matplotlib.pyplot as plt\nimport logging\n\nsys.path.extend([\".\", \"..\"])\n\nfrom onelearn import AMFRegressor\nfrom onelearn.datasets import get_signal, make_regression\n\nlogging.basicConfig(\n    level=logging.INFO, format=\"%(asctime)s %(message)s\", datefmt=\"%Y-%m-%d %H:%M:%S\"\n)\n\n# matplotlib.cm.get_cmap was removed in matplotlib 3.9; plt.get_cmap exists in\n# both old and recent releases.\ncolormap = plt.get_cmap(\"tab20\")\n\nn_samples_train = 5000\nn_samples_test = 1000\nrandom_state = 42\n\n\nnoise = 0.03\nuse_aggregation = True\nsplit_pure = True\nn_estimators = 100\nstep = 10.0\n\n\nsignals = [\"heavisine\", \"bumps\", \"blocks\", \"doppler\"]\n\n\ndef plot_weighted_depth(signal):\n    \"\"\"Fit an AMFRegressor on a noisy 1D signal and plot its weighted depths.\n\n    Draws a three-row figure (samples and true signal, prediction, weighted\n    depths with their mean) and saves it as ``weighted_depths_<signal>.pdf``\n    in the current working directory.\n\n    Parameters\n    ----------\n    signal : str\n        Signal name, forwarded to ``make_regression`` and ``get_signal``.\n    \"\"\"\n    X_train, y_train = make_regression(\n        n_samples=n_samples_train, signal=signal, noise=noise, random_state=random_state\n    )\n    X_test = np.linspace(0, 1, num=n_samples_test)\n    # The estimator is given column vectors; reshape the test grid once and reuse it.\n    X_test_2d = X_test.reshape(n_samples_test, 1)\n\n    amf = AMFRegressor(\n        random_state=random_state,\n        use_aggregation=use_aggregation,\n        n_estimators=n_estimators,\n        split_pure=split_pure,\n        step=step,\n    )\n\n    amf.partial_fit(X_train.reshape(n_samples_train, 1), y_train)\n    y_pred = amf.predict(X_test_2d)\n    weighted_depths = amf.weighted_depth(X_test_2d)\n\n    fig, (ax1, ax2, ax3) = plt.subplots(nrows=3, ncols=1, sharex=True, figsize=(6, 5))\n\n    plot_samples = ax1.plot(\n        X_train, y_train, color=colormap.colors[1], lw=2, label=\"Samples\"\n    )[0]\n    plot_signal = ax1.plot(\n        X_test,\n        get_signal(X_test, signal),\n        lw=2,\n        color=colormap.colors[0],\n        label=\"Signal\",\n    )[0]\n    plot_prediction = ax2.plot(\n        X_test.ravel(), y_pred, lw=2, color=colormap.colors[2], label=\"Prediction\"\n    )[0]\n    # One faint curve per column of weighted_depths (presumably one per tree).\n    ax3.plot(\n        X_test,\n        weighted_depths[:, 1:],\n        lw=1,\n        color=colormap.colors[5],\n        alpha=0.2,\n        label=\"Weighted depths\",\n    )\n    # Column 0 is drawn separately so that a single line handle is available for the legend.\n    plot_weighted_depths = ax3.plot(\n        X_test, weighted_depths[:, 0], lw=1, color=colormap.colors[5], alpha=0.2\n    )[0]\n\n    plot_mean_weighted_depths = ax3.plot(\n        X_test,\n        weighted_depths.mean(axis=1),\n        lw=2,\n        color=colormap.colors[4],\n        label=\"Mean weighted depth\",\n    )[0]\n    filename = \"weighted_depths_%s.pdf\" % signal\n    fig.subplots_adjust(hspace=0.1)\n    fig.legend(\n        (\n            plot_signal,\n            plot_samples,\n            plot_mean_weighted_depths,\n            plot_weighted_depths,\n            plot_prediction,\n        ),\n        (\n            \"Signal\",\n            \"Samples\",\n            \"Average weighted depths\",\n            \"Weighted depths\",\n            \"Prediction\",\n        ),\n        fontsize=12,\n        loc=\"upper center\",\n        bbox_to_anchor=(0.5, 1.0),\n        ncol=3,\n    )\n    # Save through the figure object rather than pyplot's implicit current figure.\n    fig.savefig(filename)\n    logging.info(\"Saved the weighted depths in '%s'\", filename)\n\n\nfor signal in signals:\n    plot_weighted_depth(signal)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}