{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## 0. Setup the environment:\n", "### 0.1 Setup the conda environment and install JAX:\n", "`$ conda create --prefix /path/to/conda/directory/jax`\n", "\n", "`$ conda activate /path/to/conda/directory/jax`\n", "\n", "`$ pip install --upgrade --user \"jax[cpu]\"`" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "## 1. Numpy optimizations:\n", "### 1.0 Vectorizing for-loops." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import numpy as np" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def multiply_lists(li_a, li_b):\n", "    '''Return the element-wise products of two equal-length lists, using an explicit Python loop.'''\n", "    products = []\n", "    for i in range(len(li_a)):\n", "        products.append(li_a[i] * li_b[i])\n", "    return products" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "li_a = [i**2 for i in range(1000)]\n", "li_b = [i**3 for i in range(1000)]\n", "\n", "# Explicit int64: the products reach ~1e15, which overflows a 32-bit default integer dtype.\n", "arr_a = np.array(li_a, dtype=np.int64)\n", "arr_b = np.array(li_b, dtype=np.int64)\n", "\n", "def multiply_arrays(arr_a, arr_b):\n", "    '''Return the element-wise product of two NumPy arrays (vectorized, no Python loop).'''\n", "    return arr_a * arr_b" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%timeit -n 10000 -r 5 multiply_lists(li_a, li_b)\n", "%timeit -n 10000 -r 5 multiply_arrays(arr_a, arr_b)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1.1. Use broadcasting on arrays." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import numpy as np" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# No broadcasting: both operands already have the same shape (3,)\n", "a = np.array([1.0, 2.0, 3.0])\n", "\n", "b = np.array([2.0, 2.0, 2.0])\n", "\n", "a * b" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Broadcasting: the scalar b is stretched to the shape of a\n", "a = np.array([1.0, 2.0, 3.0])\n", "\n", "b = 2.0\n", "\n", "a * b" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def simple_sum_broadcast(n): # uses broadcasting\n", "    '''Add 1 to every element of arange(n) in place; the scalar is broadcast across the array.'''\n", "    x = np.arange(n)\n", "    x += 1\n", "\n", "def simple_sum_no_broadcast(n): # no broadcasting\n", "    '''Add 1 to every element of arange(n) in place via a full-size array of ones.'''\n", "    x = np.arange(n)\n", "    # Match arange's dtype so only the extra array (not a dtype conversion) is being measured.\n", "    x += np.ones(n, dtype=x.dtype)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%timeit simple_sum_broadcast(10**7)\n", "%timeit simple_sum_no_broadcast(10**7)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Numba\n", "### 2.0. 
JIT feature of Numba" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import numba\n", "from numba import njit\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "x = np.arange(100).reshape(10, 10)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@njit  # njit already implies nopython mode; passing nopython=True only triggers a warning\n", "def simple_math_numba(a): # Function is compiled to machine code when called the first time\n", "    '''Return `a` plus the sum of sin() over its main diagonal (assumes a square matrix).'''\n", "    trace = 0.0\n", "    # assuming square input matrix\n", "    for i in range(a.shape[0]): # Numba likes loops\n", "        trace += np.sin(a[i, i]) # Numba likes NumPy functions\n", "    return a + trace # Numba likes NumPy broadcasting" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "simple_math_numba(x)  # warm-up call: compile once so JIT compilation is not included in the timing\n", "%timeit -n100000 -r7 simple_math_numba(x)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%timeit -n100000 -r7 simple_math_numba.py_func(x) # Run with NO-JIT feature" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def simple_math_numpy(a):\n", "    '''Vectorized NumPy equivalent of simple_math_numba: `a` plus the sum of sin() over its diagonal.'''\n", "    return a + np.sin(np.diagonal(a)).sum()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%timeit -n100000 -r7 simple_math_numpy(x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.1. 
Parallel feature of Numba" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "SQRT_2PI = np.sqrt(2 * np.pi)\n", "\n", "@njit(parallel=True)  # njit already implies nopython mode\n", "def gaussians(x, means, widths):\n", "    '''Return the value of gaussian kernels.\n", "\n", "    x - location of evaluation\n", "    means - array of kernel means\n", "    widths - array of kernel widths\n", "\n", "    Returns an array with one value per kernel, each divided by the number of kernels.\n", "    '''\n", "    n = means.shape[0]\n", "    result = np.exp( -0.5 * ((x - means) / widths)**2 ) / widths\n", "    return result / SQRT_2PI / n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "np.random.seed(0)  # fixed seed so the inputs are reproducible across runs\n", "means = np.random.uniform(-1, 1, size=1000000)\n", "widths = np.random.uniform(0.1, 0.3, size=1000000)\n", "\n", "gaussians(0.4, means, widths)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "gaussians_nothread = njit(gaussians.py_func)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "gaussians_nothread(0.4, means, widths)  # warm-up call: compile once so JIT compilation is not timed\n", "%timeit gaussians_nothread(0.4, means, widths) # No Parallel" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%timeit gaussians(0.4, means, widths) # JIT + Parallel" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%timeit gaussians.py_func(0.4, means, widths) # No JIT" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "numba.config.NUMBA_DEFAULT_NUM_THREADS" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. 
JAX" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import jax.numpy as jnp\n", "from jax import jit" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def slow_f(x):\n", "    '''Element-wise x*x + 2*x; un-jitted, each operation is dispatched separately.'''\n", "    # Element-wise ops see a large benefit from fusion\n", "    return x * x + x * 2.0" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "x = jnp.ones((5000, 5000))\n", "fast_f = jit(slow_f) # Compile with JIT feature of JAX\n", "fast_f(x).block_until_ready()  # warm-up call: trace + compile once so compilation is not timed\n", "# JAX dispatches work asynchronously; block_until_ready() makes %timeit measure the actual computation.\n", "%timeit -n10 -r3 fast_f(x).block_until_ready()\n", "%timeit -n10 -r3 slow_f(x).block_until_ready()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "@jit\n", "def simple_math_jax(a): # Function is traced and compiled when called the first time\n", "    '''Return `a` plus the sum of sin() over its main diagonal (assumes a square matrix).'''\n", "    # A Python for-loop over a.shape[0] would be unrolled while tracing under jit\n", "    # (5000 steps for the x defined above), so use vectorized jax.numpy operations instead.\n", "    trace = jnp.sin(jnp.diagonal(a)).sum()\n", "    return a + trace" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "simple_math_jax(x).block_until_ready()  # warm-up call: compile once so compilation is not timed\n", "%timeit simple_math_jax(x).block_until_ready()" ] } ], "metadata": { "kernelspec": { "display_name": "JAX", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 2 }