{ "cells": [ { "cell_type": "markdown", "source": [ "# Digits Classification\n", "\n", "Tune a `DecisionTreeClassifier` on the scikit-learn digits dataset with `GASearchCV`, the genetic-algorithm hyperparameter search from `sklearn_genetic`." ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from sklearn_genetic import GASearchCV\n", "from sklearn_genetic.space import Categorical, Integer, Continuous\n", "from sklearn.model_selection import train_test_split, StratifiedKFold\n", "from sklearn.tree import DecisionTreeClassifier\n", "from sklearn.datasets import load_digits\n", "from sklearn.metrics import accuracy_score\n", "from sklearn_genetic.callbacks import DeltaThreshold, TimerStopping" ] }, { "cell_type": "markdown", "source": [ "### Import the data and split it into train and test sets" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 2, "outputs": [], "source": [ "data = load_digits()\n", "label_names = data[\"target_names\"]\n", "y = data[\"target\"]\n", "X = data[\"data\"]\n", "\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "### Define the classifier to tune and the parameter grid" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 3, "outputs": [], "source": [ "# Seed the tree: scikit-learn permutes features at each split, so an unseeded tree can fit differently between runs.\n", "clf = DecisionTreeClassifier(random_state=42)\n", "\n", "params_grid = {\n", " \"min_weight_fraction_leaf\": Continuous(0, 0.5),\n", " \"criterion\": Categorical([\"gini\", \"entropy\"]),\n", " \"max_depth\": Integer(2, 20),\n", " \"max_leaf_nodes\": Integer(2, 30),\n", "}" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "### Create the CV strategy and optionally some callbacks" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 4, "outputs": [], "source": [ "# Seed the shuffle so split() yields the same folds on every call: all candidates are scored on identical folds and reruns are reproducible.\n", "cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)\n", "\n", "delta_callback 
= DeltaThreshold(threshold=0.001, metric=\"fitness\")\n", "timer_callback = TimerStopping(total_seconds=60)\n", "\n", "callbacks = [delta_callback, timer_callback]" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "### Define the GASearchCV options" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 5, "outputs": [], "source": [ "evolved_estimator = GASearchCV(\n", " clf,\n", " cv=cv,\n", " scoring=\"accuracy\",\n", " population_size=16,\n", " generations=30,\n", " crossover_probability=0.9,\n", " mutation_probability=0.05,\n", " param_grid=params_grid,\n", " algorithm=\"eaSimple\",\n", " n_jobs=-1,\n", " verbose=True)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "### Fit the model and see some results" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 6, "outputs": [ { "data": { "text/plain": " 0%| | 0/31 [00:00