Checkpointing and Persistence

This notebook shows how to keep a genetic search reproducible and recoverable. Long-running searches should be able to write intermediate checkpoints, save the fitted search object, and load it again later for inspection or prediction.

Create a compact classification problem

The breast cancer dataset is small enough for a tutorial, but its 30 features still give hyperparameters such as max_features and max_depth a real effect, so the optimizer has meaningful choices to make.

[1]:
from pathlib import Path

import pandas as pd
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, balanced_accuracy_score, roc_auc_score
from sklearn.model_selection import StratifiedKFold, train_test_split

from sklearn_genetic import (
    EvolutionConfig,
    GASearchCV,
    OptimizationConfig,
    PopulationConfig,
    RuntimeConfig,
)
from sklearn_genetic.callbacks import ConsecutiveStopping, ModelCheckpoint, TimerStopping
from sklearn_genetic.plots import plot_fitness_evolution
from sklearn_genetic.schedules import ExponentialAdapter, InverseAdapter
from sklearn_genetic.space import Categorical, Continuous, Integer

data = load_breast_cancer(as_frame=True)
X = data.data
y = data.target

X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.25,
    stratify=y,
    random_state=42,
)

X_train.shape, X_test.shape
[1]:
((426, 30), (143, 30))

Configure a search with checkpoints

ModelCheckpoint writes the current estimator configuration and the logbook to disk during the fit, once per generation in the run below. This is useful for long runs and interrupted sessions, and for storing the generation history next to other experiment artifacts.

The save and load methods serve a different purpose: they persist the fitted GASearchCV object after training, so it can be restored later for prediction and analysis.

[2]:
artifact_dir = Path("ga_artifacts")
artifact_dir.mkdir(exist_ok=True)

checkpoint_path = artifact_dir / "breast_cancer_ga_checkpoint.pkl"
saved_search_path = artifact_dir / "breast_cancer_ga_search.pkl"

param_grid = {
    "n_estimators": Integer(40, 160),
    "max_depth": Integer(2, 12),
    "min_samples_leaf": Integer(1, 8),
    "max_features": Continuous(0.25, 1.0),
    "criterion": Categorical(["gini", "entropy", "log_loss"]),
    "class_weight": Categorical([None, "balanced"]),
}

warm_start_configs = [
    {
        "n_estimators": 80,
        "max_depth": 5,
        "min_samples_leaf": 2,
        "max_features": 0.7,
        "criterion": "gini",
        "class_weight": None,
    }
]

callbacks = [
    ModelCheckpoint(checkpoint_path),
    ConsecutiveStopping(generations=6, metric="fitness_best"),
    TimerStopping(total_seconds=180),
]
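To make the idea of a per-generation checkpoint concrete, here is a minimal standard-library sketch of the general pattern. It is not how ModelCheckpoint is implemented: the helper names `write_checkpoint` and `read_checkpoint` are hypothetical, and the payload only mirrors the two top-level keys that the callback's checkpoint contains. The write goes to a temporary file first and is then renamed, so an interrupted session never leaves a half-written checkpoint behind.

```python
import os
import pickle
import tempfile
from pathlib import Path


def write_checkpoint(path, estimator_state, logbook):
    """Write a checkpoint dict to `path` without exposing partial files."""
    path = Path(path)
    payload = {"estimator_state": estimator_state, "logbook": list(logbook)}
    # Write next to the target, then rename over it in one step.
    fd, tmp_name = tempfile.mkstemp(dir=str(path.parent), suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as handle:
            pickle.dump(payload, handle)
        os.replace(tmp_name, path)
    except BaseException:
        # Remove the temporary file if anything went wrong, then re-raise.
        Path(tmp_name).unlink(missing_ok=True)
        raise


def read_checkpoint(path):
    """Load a checkpoint dict written by write_checkpoint."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Simulate a three-generation run that checkpoints after every generation.
# A temporary directory is used only so that this demo cleans up after itself.
with tempfile.TemporaryDirectory() as demo_dir:
    demo_path = Path(demo_dir) / "demo_checkpoint.pkl"
    demo_logbook = []
    for gen in range(3):
        demo_logbook.append({"gen": gen, "fitness_best": 0.90 + 0.01 * gen})
        write_checkpoint(demo_path, {"generations": 3}, demo_logbook)
    restored = read_checkpoint(demo_path)

print(sorted(restored), len(restored["logbook"]))
```

Because the file is overwritten each generation, the latest checkpoint always holds the full history up to that point.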

Fit with optimizer controls

This example combines persistence with smart initial populations, warm-start seeds, adaptive probabilities, local refinement, diversity control, fitness sharing, caching, and automatic parallel execution.

[3]:
cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)

search = GASearchCV(
    estimator=RandomForestClassifier(random_state=42, n_jobs=1),
    cv=cv,
    scoring="roc_auc",
    param_grid=param_grid,
    evolution_config=EvolutionConfig(
        population_size=14,
        generations=10,
        crossover_probability=ExponentialAdapter(
            initial_value=0.85,
            end_value=0.45,
            adaptive_rate=0.08,
        ),
        mutation_probability=InverseAdapter(
            initial_value=0.18,
            end_value=0.55,
            adaptive_rate=0.12,
        ),
        keep_top_k=4,
    ),
    population_config=PopulationConfig(
        initializer="smart",
        warm_start_configs=warm_start_configs,
    ),
    runtime_config=RuntimeConfig(
        use_cache=True,
        parallel_backend="auto",
        n_jobs=-1,
        verbose=False,
    ),
    optimization_config=OptimizationConfig(
        local_search=True,
        local_search_top_k=2,
        local_search_steps=2,
        local_search_radius=0.12,
        diversity_control=True,
        diversity_threshold=0.2,
        diversity_stagnation_generations=3,
        diversity_mutation_boost=1.8,
        random_immigrants_fraction=0.15,
        fitness_sharing=True,
        sharing_radius=0.25,
    ),
    refit=True,
)

search.fit(X_train, y_train, callbacks=callbacks)
search.best_params_, search.best_score_
Checkpoint save in ga_artifacts\breast_cancer_ga_checkpoint.pkl
Checkpoint save in ga_artifacts\breast_cancer_ga_checkpoint.pkl
Checkpoint save in ga_artifacts\breast_cancer_ga_checkpoint.pkl
Checkpoint save in ga_artifacts\breast_cancer_ga_checkpoint.pkl
Checkpoint save in ga_artifacts\breast_cancer_ga_checkpoint.pkl
Checkpoint save in ga_artifacts\breast_cancer_ga_checkpoint.pkl
Checkpoint save in ga_artifacts\breast_cancer_ga_checkpoint.pkl
Checkpoint save in ga_artifacts\breast_cancer_ga_checkpoint.pkl
Checkpoint save in ga_artifacts\breast_cancer_ga_checkpoint.pkl
Checkpoint save in ga_artifacts\breast_cancer_ga_checkpoint.pkl
Checkpoint save in ga_artifacts\breast_cancer_ga_checkpoint.pkl
INFO: TimerStopping callback met its criteria
INFO: Stopping the algorithm
[3]:
({'n_estimators': 42,
  'max_depth': 8,
  'min_samples_leaf': 4,
  'max_features': 0.8576422839762377,
  'criterion': 'gini',
  'class_weight': None},
 0.9904600381598474)
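The adaptive probabilities configured above change from generation to generation. The sketch below shows how such schedules can behave, assuming the textbook exponential-decay and inverse-decay forms. The functions `exponential_schedule` and `inverse_schedule` are hypothetical stand-ins; the exact formulas used by ExponentialAdapter and InverseAdapter are not shown in this notebook and may differ.

```python
import math


def exponential_schedule(initial_value, end_value, adaptive_rate, generation):
    """Move from initial_value toward end_value with exponential decay (assumed form)."""
    return (initial_value - end_value) * math.exp(-adaptive_rate * generation) + end_value


def inverse_schedule(initial_value, end_value, adaptive_rate, generation):
    """Move from initial_value toward end_value with inverse decay (assumed form)."""
    return (initial_value - end_value) / (1 + adaptive_rate * generation) + end_value


# Same initial values, end values, and rates as in the search configuration.
schedule = []
for generation in range(11):
    crossover = exponential_schedule(0.85, 0.45, 0.08, generation)
    mutation = inverse_schedule(0.18, 0.55, 0.12, generation)
    schedule.append((generation, crossover, mutation))

for generation, crossover, mutation in schedule[::5]:
    print(f"gen {generation:2d}: crossover={crossover:.3f} mutation={mutation:.3f}")
```

Under these assumed forms, crossover starts high and falls while mutation starts low and rises, which shifts the search from recombining early candidates toward perturbing good ones later in the run.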
[4]:
y_pred = search.predict(X_test)
y_proba = search.predict_proba(X_test)[:, 1]

pd.Series(
    {
        "accuracy": accuracy_score(y_test, y_pred),
        "balanced_accuracy": balanced_accuracy_score(y_test, y_pred),
        "roc_auc": roc_auc_score(y_test, y_proba),
    }
).to_frame("test_score")
[4]:
                   test_score
accuracy             0.951049
balanced_accuracy    0.945597
roc_auc              0.993082

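The fit statistics summarize how much work the search did: how many candidates were evaluated, how many of them were unique, and how many evaluations the cache avoided. They are useful when comparing runs with different population sizes, smart initialization, caching, local search, or diversity controls.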

[5]:
pd.Series(search.fit_stats_).to_frame("value")
[5]:
                             value
evaluated_candidates           298
unique_candidates              293
cross_validate_calls           293
cache_hits                       5
duplicate_candidates             0
skipped_invalid_candidates       0
population_parallel_batches     12
population_serial_batches        0
random_immigrants               15
local_refinement_candidates      4
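A few derived numbers make these counts easier to compare across runs. The sketch below copies four values from the output above into a plain dict named `observed_stats` (a hypothetical name). In a live session the same arithmetic should work on `dict(search.fit_stats_)`, assuming it behaves like a mapping. The estimate of model fits assumes each cross-validation call trains one model per fold of the 3-fold splitter and ignores the final refit.

```python
# Values copied from the fit statistics printed above.
observed_stats = {
    "evaluated_candidates": 298,
    "unique_candidates": 293,
    "cross_validate_calls": 293,
    "cache_hits": 5,
}

n_splits = 3  # matches the StratifiedKFold used for the search

# Share of evaluations answered from the cache instead of cross-validation.
cache_hit_rate = observed_stats["cache_hits"] / observed_stats["evaluated_candidates"]

# Cross-validation calls that the cache made unnecessary.
calls_avoided = (
    observed_stats["evaluated_candidates"] - observed_stats["cross_validate_calls"]
)

# Rough count of model fits during the search (one fit per fold per call).
model_fits = observed_stats["cross_validate_calls"] * n_splits

print(f"cache hit rate: {cache_hit_rate:.1%}")
print(f"cross-validation calls avoided: {calls_avoided}")
print(f"approximate model fits: {model_fits}")
```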
[6]:
history = pd.DataFrame(search.history)
history[
    [
        "gen",
        "fitness_best",
        "fitness",
        "fitness_max",
        "unique_individual_ratio",
        "genotype_diversity",
        "stagnation_generations",
    ]
].tail()
[6]:
    gen  fitness_best   fitness  fitness_max  unique_individual_ratio  genotype_diversity  stagnation_generations
 6    6      0.988481  0.986634     0.987951                 0.785714            0.320513                       5
 7    7      0.990460  0.987086     0.990460                 0.785714            0.410256                       0
 8    8      0.990460  0.986912     0.988905                 0.714286            0.294872                       1
 9    9      0.990460  0.986813     0.988128                 0.785714            0.410256                       2
10   10      0.990460  0.986846     0.989612                 0.714286            0.320513                       4
[7]:
plot_fitness_evolution(search)
[7]:
<Axes: title={'center': 'Best fitness so far'}, xlabel='generations', ylabel='fitness (score)'>
[Figure: line plot titled "Best fitness so far", with generations on the x-axis and fitness (score) on the y-axis]

Inspect checkpoint contents

The checkpoint callback stores a dictionary with two top-level keys: estimator_state and logbook. The state is intentionally lightweight and focuses on the search configuration; the logbook contains the generation records.

[8]:
checkpoint = ModelCheckpoint(checkpoint_path).load()

checkpoint.keys(), len(checkpoint["logbook"])
[8]:
(dict_keys(['estimator_state', 'logbook']), 11)
[9]:
sorted(checkpoint["estimator_state"].keys())
[9]:
['algorithm',
 'crossover_probability',
 'cv',
 'diversity_control',
 'diversity_mutation_boost',
 'diversity_stagnation_generations',
 'diversity_threshold',
 'estimator',
 'fitness_sharing',
 'generations',
 'local_search',
 'local_search_radius',
 'local_search_steps',
 'local_search_top_k',
 'mutation_probability',
 'param_grid',
 'population_size',
 'random_immigrants_fraction',
 'scoring',
 'sharing_alpha',
 'sharing_radius']
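Before relying on a checkpoint for recovery or auditing, it can help to validate its shape. The helper below is a small sketch with a hypothetical name, `summarize_checkpoint`. It checks for the two top-level keys shown above and assumes that each logbook record is dict-like with a "gen" field, as the history columns suggest. The stand-in dict only imitates a loaded checkpoint so that the example runs on its own.

```python
REQUIRED_KEYS = ("estimator_state", "logbook")


def summarize_checkpoint(checkpoint):
    """Validate the top-level layout and return a short summary dict."""
    missing = [key for key in REQUIRED_KEYS if key not in checkpoint]
    if missing:
        raise KeyError(f"checkpoint is missing keys: {missing}")

    logbook = checkpoint["logbook"]
    if len(logbook) == 0:
        raise ValueError("checkpoint logbook is empty")

    last_record = logbook[-1]
    return {
        "generations_recorded": len(logbook),
        # Assumes dict-like records; .get returns None if "gen" is absent.
        "last_generation": last_record.get("gen"),
        "state_keys": sorted(checkpoint["estimator_state"]),
    }


# Stand-in for a loaded checkpoint: 11 records, generations 0 through 10.
demo_checkpoint = {
    "estimator_state": {"generations": 10, "population_size": 14},
    "logbook": [{"gen": gen, "fitness_best": 0.98} for gen in range(11)],
}

demo_summary = summarize_checkpoint(demo_checkpoint)
print(demo_summary)
```

With the real checkpoint loaded in the cell above, the same helper could be called as `summarize_checkpoint(checkpoint)`, provided the record layout matches the assumption.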

Clean up generated files

Leave this cell commented out if you want to inspect the artifacts after the notebook finishes. Note that rmdir only removes the directory if it is empty.

[11]:
# for path in [checkpoint_path, saved_search_path]:
#     path.unlink(missing_ok=True)
# artifact_dir.rmdir()

Practical notes

  • Use ModelCheckpoint for progress recovery and audit trails during a fit.

  • Use save and load for fitted search objects that need to be reused for prediction or later analysis.

  • Store checkpoints outside temporary notebook directories for long runs.

  • Keep random_state fixed across the estimator, splitter, and search inputs when you need repeatable artifacts.
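One way to act on the last two notes is to record the seeds and artifact paths in a small manifest stored next to the artifacts. The sketch below uses only the standard library. The helper `write_run_manifest` and the file name `run_manifest.json` are hypothetical and are not part of sklearn_genetic. The demo writes into a temporary directory only so that it cleans up after itself; for a long run, point it at a durable location instead.

```python
import json
import tempfile
from pathlib import Path


def write_run_manifest(run_dir, seeds, artifacts):
    """Record seeds and artifact paths as JSON inside run_dir."""
    run_dir = Path(run_dir)
    run_dir.mkdir(parents=True, exist_ok=True)
    manifest = {
        "seeds": dict(seeds),
        # Paths are stored as strings so the manifest stays plain JSON.
        "artifacts": {name: str(path) for name, path in artifacts.items()},
    }
    manifest_path = run_dir / "run_manifest.json"
    manifest_path.write_text(json.dumps(manifest, indent=2, sort_keys=True))
    return manifest_path


with tempfile.TemporaryDirectory() as demo_root:
    demo_run_dir = Path(demo_root) / "ga_artifacts"
    demo_manifest_path = write_run_manifest(
        demo_run_dir,
        # The seeds used in this notebook for the split, estimator, and splitter.
        seeds={"train_test_split": 42, "estimator": 42, "splitter": 42},
        artifacts={
            "checkpoint": demo_run_dir / "breast_cancer_ga_checkpoint.pkl",
            "saved_search": demo_run_dir / "breast_cancer_ga_search.pkl",
        },
    )
    demo_manifest = json.loads(demo_manifest_path.read_text())

print(demo_manifest["seeds"])
```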