{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# pyribs Dask demo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This notebook uses the pre-0.5 pyribs API (Optimizer, ImprovementEmitter),\n",
    "# so ribs is pinned below 0.5. matplotlib is imported directly further down\n",
    "# and is also needed by ribs.visualize.\n",
    "%pip install \"ribs<0.5\" matplotlib dask distributed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ribs.archives import GridArchive\n",
    "from ribs.emitters import ImprovementEmitter\n",
    "from ribs.optimizers import Optimizer\n",
    "\n",
    "def create_optimizer():\n",
    "    \"\"\"Builds a fresh archive, emitter list, and optimizer.\n",
    "\n",
    "    Returns:\n",
    "        tuple: ``(archive, emitters, optimizer)`` -- a GridArchive with\n",
    "        dims [20, 20] and ranges (-1, 1) on both axes, a list holding one\n",
    "        ImprovementEmitter (x0 is ten zeros, sigma0 is 0.1), and an Optimizer\n",
    "        built from that archive and emitter list. See https://pyribs.org\n",
    "        for the component APIs.\n",
    "    \"\"\"\n",
    "    archive = GridArchive(dims=[20, 20], ranges=[(-1, 1), (-1, 1)])\n",
    "    initial_solution = [0.0] * 10\n",
    "    emitters = [ImprovementEmitter(archive, x0=initial_solution, sigma0=0.1)]\n",
    "    return archive, emitters, Optimizer(archive, emitters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Evaluation function. ##\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "def sphere(x):\n",
    "    \"\"\"Technically, this is the negative sphere function.\"\"\"\n",
    "    obj = -np.sum(np.square(x))\n",
    "    bcs = x[:2]\n",
    "    return [obj, bcs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from ribs.visualize import grid_archive_heatmap"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Running in a Single Process\n",
    "\n",
    "Executes CMA-ME on the negative Sphere function.\n",
    "\n",
    "Refer to https://pyribs.org for this code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "archive, emitters, optimizer = create_optimizer()\n",
    "\n",
    "# Ask / evaluate / tell loop, with every evaluation done in this process.\n",
    "for itr in range(1000):\n",
    "    solutions = optimizer.ask()\n",
    "\n",
    "    # sphere returns [obj, bc] per solution; unzip the list of pairs into\n",
    "    # one tuple of objectives and one tuple of bcs.\n",
    "    evals = list(map(sphere, solutions))\n",
    "    objectives, bcs = zip(*evals)\n",
    "\n",
    "    optimizer.tell(objectives, bcs)\n",
    "\n",
    "# Plot the archive contents after the final iteration.\n",
    "grid_archive_heatmap(archive, square=True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Running with Dask\n",
    "\n",
    "We set up a local cluster with 2 workers and run the same evaluation function as above on the cluster. We also add some logging to the experiment.\n",
    "\n",
    "To access the Dask dashboard and see worker metrics, open up <http://localhost:8787> in your browser."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dask.distributed import Client, LocalCluster\n",
    "\n",
    "# Two worker processes, each running a single thread.\n",
    "cluster = LocalCluster(n_workers=2, threads_per_worker=1, processes=True)\n",
    "\n",
    "# The client is what later cells use to submit evaluations to the cluster.\n",
    "client = Client(cluster)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration 100 - Archive Size: 287\n",
      "Iteration 200 - Archive Size: 365\n",
      "Iteration 300 - Archive Size: 386\n",
      "Iteration 400 - Archive Size: 397\n",
      "Iteration 500 - Archive Size: 398\n",
      "Iteration 600 - Archive Size: 400\n",
      "Iteration 700 - Archive Size: 400\n",
      "Iteration 800 - Archive Size: 400\n",
      "Iteration 900 - Archive Size: 400\n",
      "Iteration 1000 - Archive Size: 400\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "archive, emitters, optimizer = create_optimizer()\n",
    "\n",
    "for itr in range(1000):\n",
    "    solutions = optimizer.ask()\n",
    "\n",
    "    ## Distributed Evaluations ##\n",
    "    # Submit one sphere call per solution, then collect the results.\n",
    "    evals = client.gather(client.map(sphere, solutions))\n",
    "    objectives, bcs = zip(*evals)\n",
    "\n",
    "    optimizer.tell(objectives, bcs)\n",
    "\n",
    "    # Only log on every 100th iteration.\n",
    "    if (itr + 1) % 100 != 0:\n",
    "        continue\n",
    "\n",
    "    ## Logging ##\n",
    "    archive_size = len(archive.as_pandas(include_solutions=False))\n",
    "    print(f\"Iteration {itr + 1} - Archive Size: {archive_size}\")\n",
    "\n",
    "grid_archive_heatmap(archive, square=True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dask with Reloading\n",
    "\n",
    "Every 100 iterations, we save the optimizer to a pickle file. `pickle` also takes care of the archive and emitters because they are members of the `Optimizer`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration 100 - Archive Size: 269\n",
      "Iteration 200 - Archive Size: 367\n",
      "Iteration 300 - Archive Size: 377\n",
      "Iteration 400 - Archive Size: 396\n",
      "Iteration 500 - Archive Size: 396\n",
      "Iteration 600 - Archive Size: 397\n",
      "Iteration 700 - Archive Size: 400\n",
      "Iteration 800 - Archive Size: 400\n",
      "Iteration 900 - Archive Size: 400\n",
      "Iteration 1000 - Archive Size: 400\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "archive, emitters, optimizer = create_optimizer()\n",
    "\n",
    "for itr in range(1000):\n",
    "    solutions = optimizer.ask()\n",
    "\n",
    "    ## Distributed Evaluations ##\n",
    "    futures = client.map(sphere, solutions)\n",
    "    evals = client.gather(futures)\n",
    "    objectives, bcs = zip(*evals)\n",
    "\n",
    "    optimizer.tell(objectives, bcs)\n",
    "\n",
    "    if (itr + 1) % 100 == 0:\n",
    "        ## Logging ##\n",
    "        archive_size = len(archive.as_pandas(include_solutions=False))\n",
    "        print(f\"Iteration {itr + 1} - Archive Size: {archive_size}\")\n",
    "\n",
    "        ## Reloading ##\n",
    "        # Write the checkpoint to a temporary file first so that an\n",
    "        # interrupted dump never leaves a truncated reload.pkl behind.\n",
    "        with open(\"tmp.pkl\", \"wb\") as tmp_file:\n",
    "            data = {\n",
    "                \"itr\": itr + 1,\n",
    "                \"optimizer\": optimizer,\n",
    "            }\n",
    "            pickle.dump(data, tmp_file)\n",
    "        # os.replace overwrites an existing reload.pkl on every platform;\n",
    "        # os.rename raises FileExistsError on Windows when the target\n",
    "        # exists, which would fail at the second checkpoint.\n",
    "        os.replace(\"tmp.pkl\", \"reload.pkl\")\n",
    "\n",
    "grid_archive_heatmap(archive, square=True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And here, we can reload the data and see the experiment state. We access the archive through the optimizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reloaded Itr: 1000\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# reload.pkl was written by the previous cell. Only unpickle files you\n",
    "# created yourself, since pickle.load can execute arbitrary code.\n",
    "with open(\"reload.pkl\", \"rb\") as reload_file:\n",
    "    data = pickle.load(reload_file)\n",
    "\n",
    "reloaded_itr = data[\"itr\"]\n",
    "reloaded_optimizer = data[\"optimizer\"]\n",
    "\n",
    "print(\"Reloaded Itr:\", reloaded_itr)\n",
    "# The archive is reached through the optimizer, which is what was pickled.\n",
    "grid_archive_heatmap(reloaded_optimizer.archive, square=True)\n",
    "# Render explicitly, matching the other plotting cells in this notebook.\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}