{ "cells": [ { "cell_type": "markdown", "metadata": { "collapsed": true, "pycharm": { "name": "#%% md\n" } }, "source": [ "# Iris Multi-metric" ] }, { "cell_type": "code", "execution_count": 1, "outputs": [], "source": [ "from sklearn_genetic import GASearchCV\n", "from sklearn_genetic.space import Categorical, Integer, Continuous\n", "from sklearn.model_selection import train_test_split, StratifiedKFold\n", "from sklearn.tree import DecisionTreeClassifier\n", "from sklearn.datasets import load_iris\n", "from sklearn.metrics import make_scorer\n", "from sklearn.metrics import balanced_accuracy_score" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "### Import the data and split it into train and test sets" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 2, "outputs": [], "source": [ "data = load_iris()\n", "X, y = data[\"data\"], data[\"target\"]\n", "\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=0)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "### Define the classifier, the hyperparameter search space and the multi-metric scoring\n" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 3, "outputs": [], "source": [ "clf = DecisionTreeClassifier()\n", "\n", "params_grid = {\n", " \"min_weight_fraction_leaf\": Continuous(0, 0.5),\n", " \"criterion\": Categorical([\"gini\", \"entropy\"]),\n", " \"max_depth\": Integer(2, 20),\n", " \"max_leaf_nodes\": Integer(2, 30),\n", "}\n", "\n", "scoring = {\"accuracy\": \"accuracy\",\n", " \"balanced_accuracy\": make_scorer(balanced_accuracy_score)}" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "### Define the GASearchCV options" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 4, "outputs": [], "source": [ "# Low number of 
generations and population\n", "# Just to see the effect of multimetric\n", "# In logbook and cv_results_\n", "\n", "evolved_estimator = GASearchCV(\n", " clf,\n", " scoring=scoring,\n", " population_size=3,\n", " generations=2,\n", " crossover_probability=0.9,\n", " mutation_probability=0.05,\n", " param_grid=params_grid,\n", " algorithm=\"eaSimple\",\n", " n_jobs=-1,\n", " verbose=True,\n", " error_score='raise',\n", " refit=\"accuracy\")" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "### Fit the model and see some results" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 5, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "gen\tnevals\tfitness \tfitness_std\tfitness_max\tfitness_min\n", "0 \t3 \t0.856902\t0.117921 \t0.940285 \t0.690137 \n", "1 \t2 \t0.940285\t0 \t0.940285 \t0.940285 \n", "2 \t2 \t0.940285\t0 \t0.940285 \t0.940285 \n" ] } ], "source": [ "evolved_estimator.fit(X_train, y_train)\n", "y_predict_ga = evolved_estimator.predict(X_test)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 6, "outputs": [ { "data": { "text/plain": "{'param_min_weight_fraction_leaf': [0.22963955365985156,\n 0.11807874354582698,\n 0.4566955700628974,\n 0.11807874354582698,\n 0.22963955365985156,\n 0.22963955365985156,\n 0.22963955365985156],\n 'param_criterion': ['gini',\n 'entropy',\n 'entropy',\n 'gini',\n 'entropy',\n 'entropy',\n 'gini'],\n 'param_max_depth': [2, 9, 10, 2, 9, 2, 9],\n 'param_max_leaf_nodes': [13, 7, 3, 7, 13, 13, 13],\n 'split0_test_accuracy': [0.9117647058823529,\n 0.9117647058823529,\n 0.6764705882352942,\n 0.9117647058823529,\n 0.9117647058823529,\n 0.9117647058823529,\n 0.9117647058823529],\n 'split1_test_accuracy': [0.9696969696969697,\n 0.9696969696969697,\n 0.696969696969697,\n 0.9696969696969697,\n 0.9696969696969697,\n 0.9696969696969697,\n 0.9696969696969697],\n 
'split2_test_accuracy': [0.9393939393939394,\n 0.9393939393939394,\n 0.696969696969697,\n 0.9393939393939394,\n 0.9393939393939394,\n 0.9393939393939394,\n 0.9393939393939394],\n 'mean_test_accuracy': [0.9402852049910874,\n 0.9402852049910874,\n 0.690136660724896,\n 0.9402852049910874,\n 0.9402852049910874,\n 0.9402852049910874,\n 0.9402852049910874],\n 'std_test_accuracy': [0.023659142890153965,\n 0.023659142890153965,\n 0.009663372529584432,\n 0.023659142890153965,\n 0.023659142890153965,\n 0.023659142890153965,\n 0.023659142890153965],\n 'rank_test_accuracy': array([1, 1, 7, 1, 1, 1, 1]),\n 'split0_train_accuracy': [0.9696969696969697,\n 0.9696969696969697,\n 0.696969696969697,\n 0.9696969696969697,\n 0.9696969696969697,\n 0.9696969696969697,\n 0.9696969696969697],\n 'split1_train_accuracy': [0.9552238805970149,\n 0.9552238805970149,\n 0.6865671641791045,\n 0.9552238805970149,\n 0.9552238805970149,\n 0.9552238805970149,\n 0.9552238805970149],\n 'split2_train_accuracy': [0.9701492537313433,\n 0.9701492537313433,\n 0.6865671641791045,\n 0.9701492537313433,\n 0.9701492537313433,\n 0.9701492537313433,\n 0.9701492537313433],\n 'mean_train_accuracy': [0.9650233680084427,\n 0.9650233680084427,\n 0.6900346751093019,\n 0.9650233680084427,\n 0.9650233680084427,\n 0.9650233680084427,\n 0.9650233680084427],\n 'std_train_accuracy': [0.006931743665052123,\n 0.006931743665052123,\n 0.004903800985162277,\n 0.006931743665052123,\n 0.006931743665052123,\n 0.006931743665052123,\n 0.006931743665052123],\n 'rank_train_accuracy': array([1, 1, 7, 1, 1, 1, 1]),\n 'split0_test_balanced_accuracy': [0.9090909090909092,\n 0.9090909090909092,\n 0.6666666666666666,\n 0.9090909090909092,\n 0.9090909090909092,\n 0.9090909090909092,\n 0.9090909090909092],\n 'split1_test_balanced_accuracy': [0.9722222222222222,\n 0.9722222222222222,\n 0.6666666666666666,\n 0.9722222222222222,\n 0.9722222222222222,\n 0.9722222222222222,\n 0.9722222222222222],\n 'split2_test_balanced_accuracy': 
[0.9333333333333332,\n 0.9388888888888888,\n 0.6666666666666666,\n 0.9333333333333332,\n 0.9388888888888888,\n 0.9333333333333332,\n 0.9333333333333332],\n 'mean_test_balanced_accuracy': [0.9382154882154882,\n 0.94006734006734,\n 0.6666666666666666,\n 0.9382154882154882,\n 0.94006734006734,\n 0.9382154882154882,\n 0.9382154882154882],\n 'std_test_balanced_accuracy': [0.02600342607735869,\n 0.02578671796107406,\n 0.0,\n 0.02600342607735869,\n 0.02578671796107406,\n 0.02600342607735869,\n 0.02600342607735869],\n 'rank_test_balanced_accuracy': array([3, 1, 7, 3, 1, 3, 3]),\n 'split0_train_balanced_accuracy': [0.9722222222222222,\n 0.9722222222222222,\n 0.6666666666666666,\n 0.9722222222222222,\n 0.9722222222222222,\n 0.9722222222222222,\n 0.9722222222222222],\n 'split1_train_balanced_accuracy': [0.9551414768806072,\n 0.9551414768806072,\n 0.6666666666666666,\n 0.9551414768806072,\n 0.9551414768806072,\n 0.9551414768806072,\n 0.9551414768806072],\n 'split2_train_balanced_accuracy': [0.9710144927536232,\n 0.9710144927536232,\n 0.6666666666666666,\n 0.9710144927536232,\n 0.9710144927536232,\n 0.9710144927536232,\n 0.9710144927536232],\n 'mean_train_balanced_accuracy': [0.9661260639521508,\n 0.9661260639521508,\n 0.6666666666666666,\n 0.9661260639521508,\n 0.9661260639521508,\n 0.9661260639521508,\n 0.9661260639521508],\n 'std_train_balanced_accuracy': [0.007782909373174586,\n 0.007782909373174586,\n 0.0,\n 0.007782909373174586,\n 0.007782909373174586,\n 0.007782909373174586,\n 0.007782909373174586],\n 'rank_train_balanced_accuracy': array([1, 1, 7, 1, 1, 1, 1]),\n 'mean_fit_time': [0.001999060312906901,\n 0.0016531944274902344,\n 0.0016682147979736328,\n 0.0019936561584472656,\n 0.0016682942708333333,\n 0.0023442904154459634,\n 0.0016681353251139324],\n 'std_fit_time': [8.104673248279548e-07,\n 0.00048135126754460846,\n 0.0004735620798937051,\n 8.939901952387178e-06,\n 0.00047260810461652655,\n 0.0004931964528559586,\n 0.0004699654688748367],\n 'mean_score_time': 
[0.0019881725311279297,\n 0.0026591618855794272,\n 0.0026796658833821616,\n 0.001337607701619466,\n 0.0013335545857747395,\n 0.0023202896118164062,\n 0.002334038416544596],\n 'std_score_time': [1.2906940492414977e-05,\n 0.0009508785819826155,\n 0.0009365762649996311,\n 0.00047234976131361057,\n 0.00047080875797289405,\n 0.0004528267864869186,\n 0.00047109380912835715],\n 'params': [{'min_weight_fraction_leaf': 0.22963955365985156,\n 'criterion': 'gini',\n 'max_depth': 2,\n 'max_leaf_nodes': 13},\n {'min_weight_fraction_leaf': 0.11807874354582698,\n 'criterion': 'entropy',\n 'max_depth': 9,\n 'max_leaf_nodes': 7},\n {'min_weight_fraction_leaf': 0.4566955700628974,\n 'criterion': 'entropy',\n 'max_depth': 10,\n 'max_leaf_nodes': 3},\n {'min_weight_fraction_leaf': 0.11807874354582698,\n 'criterion': 'gini',\n 'max_depth': 2,\n 'max_leaf_nodes': 7},\n {'min_weight_fraction_leaf': 0.22963955365985156,\n 'criterion': 'entropy',\n 'max_depth': 9,\n 'max_leaf_nodes': 13},\n {'min_weight_fraction_leaf': 0.22963955365985156,\n 'criterion': 'entropy',\n 'max_depth': 2,\n 'max_leaf_nodes': 13},\n {'min_weight_fraction_leaf': 0.22963955365985156,\n 'criterion': 'gini',\n 'max_depth': 9,\n 'max_leaf_nodes': 13}]}" }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "evolved_estimator.cv_results_\n" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 7, "outputs": [ { "data": { "text/plain": "[{'index': 0,\n 'min_weight_fraction_leaf': 0.22963955365985156,\n 'criterion': 'gini',\n 'max_depth': 2,\n 'max_leaf_nodes': 13,\n 'score': 0.9402852049910874,\n 'cv_scores': array([0.91176471, 0.96969697, 0.93939394]),\n 'fit_time': array([0.00199986, 0.00199795, 0.00199938]),\n 'score_time': array([0.00197005, 0.00199533, 0.00199914]),\n 'test_accuracy': array([0.91176471, 0.96969697, 0.93939394]),\n 'train_accuracy': array([0.96969697, 0.95522388, 0.97014925]),\n 'test_balanced_accuracy': 
array([0.90909091, 0.97222222, 0.93333333]),\n 'train_balanced_accuracy': array([0.97222222, 0.95514148, 0.97101449])},\n {'index': 1,\n 'min_weight_fraction_leaf': 0.11807874354582698,\n 'criterion': 'entropy',\n 'max_depth': 9,\n 'max_leaf_nodes': 7,\n 'score': 0.9402852049910874,\n 'cv_scores': array([0.91176471, 0.96969697, 0.93939394]),\n 'fit_time': array([0.00200057, 0.0019865 , 0.00097251]),\n 'score_time': array([0.00200391, 0.00196981, 0.00400376]),\n 'test_accuracy': array([0.91176471, 0.96969697, 0.93939394]),\n 'train_accuracy': array([0.96969697, 0.95522388, 0.97014925]),\n 'test_balanced_accuracy': array([0.90909091, 0.97222222, 0.93888889]),\n 'train_balanced_accuracy': array([0.97222222, 0.95514148, 0.97101449])},\n {'index': 2,\n 'min_weight_fraction_leaf': 0.4566955700628974,\n 'criterion': 'entropy',\n 'max_depth': 10,\n 'max_leaf_nodes': 3,\n 'score': 0.690136660724896,\n 'cv_scores': array([0.67647059, 0.6969697 , 0.6969697 ]),\n 'fit_time': array([0.00200272, 0.00200343, 0.0009985 ]),\n 'score_time': array([0.00203657, 0.00199842, 0.004004 ]),\n 'test_accuracy': array([0.67647059, 0.6969697 , 0.6969697 ]),\n 'train_accuracy': array([0.6969697 , 0.68656716, 0.68656716]),\n 'test_balanced_accuracy': array([0.66666667, 0.66666667, 0.66666667]),\n 'train_balanced_accuracy': array([0.66666667, 0.66666667, 0.66666667])},\n {'index': 3,\n 'min_weight_fraction_leaf': 0.11807874354582698,\n 'criterion': 'gini',\n 'max_depth': 2,\n 'max_leaf_nodes': 7,\n 'score': 0.9402852049910874,\n 'cv_scores': array([0.91176471, 0.96969697, 0.93939394]),\n 'fit_time': array([0.00200033, 0.00199962, 0.00198102]),\n 'score_time': array([0.00200558, 0.00099778, 0.00100946]),\n 'test_accuracy': array([0.91176471, 0.96969697, 0.93939394]),\n 'train_accuracy': array([0.96969697, 0.95522388, 0.97014925]),\n 'test_balanced_accuracy': array([0.90909091, 0.97222222, 0.93333333]),\n 'train_balanced_accuracy': array([0.97222222, 0.95514148, 0.97101449])},\n {'index': 4,\n 
'min_weight_fraction_leaf': 0.22963955365985156,\n 'criterion': 'entropy',\n 'max_depth': 9,\n 'max_leaf_nodes': 13,\n 'score': 0.9402852049910874,\n 'cv_scores': array([0.91176471, 0.96969697, 0.93939394]),\n 'fit_time': array([0.00200105, 0.00200391, 0.00099993]),\n 'score_time': array([0.00100136, 0.00099993, 0.00199938]),\n 'test_accuracy': array([0.91176471, 0.96969697, 0.93939394]),\n 'train_accuracy': array([0.96969697, 0.95522388, 0.97014925]),\n 'test_balanced_accuracy': array([0.90909091, 0.97222222, 0.93888889]),\n 'train_balanced_accuracy': array([0.97222222, 0.95514148, 0.97101449])},\n {'index': 5,\n 'min_weight_fraction_leaf': 0.22963955365985156,\n 'criterion': 'entropy',\n 'max_depth': 2,\n 'max_leaf_nodes': 13,\n 'score': 0.9402852049910874,\n 'cv_scores': array([0.91176471, 0.96969697, 0.93939394]),\n 'fit_time': array([0.00304174, 0.00198984, 0.00200129]),\n 'score_time': array([0.00296068, 0.00200129, 0.0019989 ]),\n 'test_accuracy': array([0.91176471, 0.96969697, 0.93939394]),\n 'train_accuracy': array([0.96969697, 0.95522388, 0.97014925]),\n 'test_balanced_accuracy': array([0.90909091, 0.97222222, 0.93333333]),\n 'train_balanced_accuracy': array([0.97222222, 0.95514148, 0.97101449])},\n {'index': 6,\n 'min_weight_fraction_leaf': 0.22963955365985156,\n 'criterion': 'gini',\n 'max_depth': 9,\n 'max_leaf_nodes': 13,\n 'score': 0.9402852049910874,\n 'cv_scores': array([0.91176471, 0.96969697, 0.93939394]),\n 'fit_time': array([0.00200057, 0.0010035 , 0.00200033]),\n 'score_time': array([0.00200343, 0.00300026, 0.00199842]),\n 'test_accuracy': array([0.91176471, 0.96969697, 0.93939394]),\n 'train_accuracy': array([0.96969697, 0.95522388, 0.97014925]),\n 'test_balanced_accuracy': array([0.90909091, 0.97222222, 0.93333333]),\n 'train_balanced_accuracy': array([0.97222222, 0.95514148, 0.97101449])}]" }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "evolved_estimator.logbook.chapters[\"parameters\"]\n", "\n" 
], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 0 }