## 0. Set up the environment:
### 0.1 Set up the conda environment and install JAX:
`$ conda create --prefix /path/to/conda/directory/jax python pip`

`$ conda activate /path/to/conda/directory/jax`

`$ pip install --upgrade "jax[cpu]"`

## 1. NumPy optimizations:
### 1.0. Vectorizing for-loops.

In [None]:
import numpy as np

In [None]:
def multiply_lists(li_a, li_b):
    """Multiply two sequences element by element in pure Python.

    This is the un-vectorized baseline: every product is computed by the
    Python interpreter, one pair at a time.

    Args:
        li_a: First sequence of numbers.
        li_b: Second sequence of numbers.

    Returns:
        A list whose i-th item is ``li_a[i] * li_b[i]``. If the inputs
        differ in length, the surplus items of the longer one are ignored.
    """
    # zip() walks both sequences in lockstep, so no index bookkeeping is needed.
    return [a * b for a, b in zip(li_a, li_b)]

In [None]:
# Benchmark inputs: the first 1000 squares and cubes as plain Python lists ...
li_a = [n * n for n in range(1000)]
li_b = [n * n * n for n in range(1000)]

# ... and the same values as NumPy arrays.
arr_a = np.array(li_a)
arr_b = np.array(li_b)


def multiply_arrays(li_a, li_b):
    """Multiply the module-level arrays ``arr_a`` and ``arr_b`` element-wise.

    NOTE: the ``li_a`` and ``li_b`` arguments are not used; the product is
    always taken over the module-level ``arr_a`` and ``arr_b``. The result
    is discarded and the function returns ``None``.
    """
    arr_a * arr_b

In [None]:
%timeit -n 10000 -r 5 multiply_lists(li_a, li_b)
%timeit -n 10000 -r 5 multiply_arrays(li_a, li_b)

### 1.1. Use broadcasting on arrays.

In [None]:
import numpy as np

In [None]:
# No broadcasting: both operands are already arrays with three elements,
# so the product is taken pairwise.
a = np.array([1.0, 2.0, 3.0])
b = np.array([2.0, 2.0, 2.0])
a * b

In [None]:
# Broadcasting: the scalar b is stretched across every element of a,
# giving the same product without building a second array.
a = np.array([1.0, 2.0, 3.0])
b = 2.0
a * b

In [None]:
def simple_sum_broadcast(n):  # uses broadcasting
    """Create ``np.arange(n)`` and add 1 to every element in place.

    The scalar ``1`` is broadcast over the whole array. Nothing is
    returned; the function exists only to be timed.
    """
    values = np.arange(n)
    values += 1

def simple_sum_no_broadcast(n):  # no broadcasting
    """Create ``np.arange(n)`` and add an explicit array of ones in place.

    A second ``n``-element int64 array is allocated just to hold the ones.
    Nothing is returned; the function exists only to be timed.
    """
    values = np.arange(n)
    ones = np.ones(n, dtype=np.int64)
    values += ones

In [None]:
# Compare adding a broadcast scalar with adding a full array of ones,
# each on an array of 10**7 elements.
%timeit simple_sum_broadcast(10**7)
%timeit simple_sum_no_broadcast(10**7)

## 2. Numba
### 2.0. JIT feature of Numba

In [None]:
import numba
from numba import njit
import numpy as np

In [None]:
# 10x10 integer test matrix holding 0..99, filled row by row.
x = np.reshape(np.arange(100), (10, 10))

In [None]:
# njit already means nopython mode; passing nopython=True again only makes
# Numba warn that the argument is ignored.
@njit
def simple_math_numba(a):  # Function is compiled to machine code when called the first time
    """Add the sum of sin() over the main diagonal of ``a`` to every element.

    Args:
        a: 2-D array, read as ``a[i, i]`` for ``i`` in ``range(a.shape[0])``
            and therefore assumed to be square.

    Returns:
        ``a + trace``, where ``trace`` is the sum of ``sin(a[i, i])``.
    """
    trace = 0.0
    # assuming square input matrix
    for i in range(a.shape[0]):  # Numba likes loops
        trace += np.sin(a[i, i])  # Numba likes NumPy functions
    return a + trace  # Numba likes NumPy broadcasting

In [None]:
# Time the JIT-compiled function: 7 runs of 100000 calls each.
%timeit -n100000 -r7 simple_math_numba(x)

In [None]:
# Same measurement on .py_func, the original uncompiled Python function.
%timeit -n100000 -r7 simple_math_numba.py_func(x) # Run with NO-JIT feature

In [None]:
def simple_math_numpy(a):
    """Add the sum of sin() over the diagonal of ``a`` to every element.

    Pure-NumPy counterpart of the loop-based version: the diagonal is
    extracted, transformed and summed with array operations only.
    """
    diagonal = np.diagonal(a)
    trace = np.sin(diagonal).sum()
    return a + trace

In [None]:
# Time the vectorized NumPy version with the same settings for comparison.
%timeit -n100000 -r7 simple_math_numpy(x)

### 2.1. Parallel feature of Numba

In [None]:
# Normalization constant of the Gaussian density, computed once.
SQRT_2PI = np.sqrt(2 * np.pi)


# njit already means nopython mode; passing nopython=True again only makes
# Numba warn that the argument is ignored.
@njit(parallel=True)
def gaussians(x, means, widths):
    """Return the value of gaussian kernels.

    Every kernel is evaluated at ``x`` as
    ``exp(-0.5 * ((x - mean) / width)**2) / (width * sqrt(2*pi))``
    and then divided by the number of kernels.

    x - location of evaluation
    means - array of kernel means
    widths - array of kernel widths

    Returns an array with one value per kernel.
    """
    n = means.shape[0]
    result = np.exp(-0.5 * ((x - means) / widths) ** 2) / widths
    return result / SQRT_2PI / n

In [None]:
# One million random kernels: means drawn from [-1, 1), widths from [0.1, 0.3).
means = np.random.uniform(low=-1, high=1, size=10**6)
widths = np.random.uniform(low=0.1, high=0.3, size=10**6)

# Evaluate all kernels at x = 0.4; the cell displays the resulting array.
gaussians(0.4, means, widths)

In [None]:
# Compile the same Python function again, this time without parallel=True.
# njit already means nopython mode, so nopython=True is not passed.
gaussians_nothread = njit(gaussians.py_func)

In [None]:
# Variant compiled with njit but without parallel=True.
%timeit gaussians_nothread(0.4, means, widths) # No Parallel

In [None]:
# Variant compiled with njit and parallel=True.
%timeit gaussians(0.4, means, widths) # JIT + Parallel

In [None]:
# Original Python function, run without Numba compilation.
%timeit gaussians.py_func(0.4, means, widths) # No JIT

In [None]:
# Show Numba's default thread count for parallel code
# (presumably the number of available CPU cores; confirm against the Numba docs).
numba.config.NUMBA_DEFAULT_NUM_THREADS

## 3. JAX

In [None]:
import jax.numpy as jnp
from jax import jit

In [None]:
def slow_f(x):
    """Return ``x * x + x * 2.0``, evaluated element-wise for array input."""
    # Element-wise ops see a large benefit from fusion
    squared = x * x
    doubled = x * 2.0
    return squared + doubled

In [None]:
x = jnp.ones((5000, 5000))
fast_f = jit(slow_f) # Compile with JIT feature of JAX
%timeit -n10 -r3 fast_f(x) 
%timeit -n10 -r3 slow_f(x) 

In [None]:
@jit
def simple_math_jax(a):  # Function is compiled to machine code when called the first time
    """Add the sum of sin() over the main diagonal of ``a`` to every element.

    A Python ``for`` loop inside a jitted function is unrolled while JAX
    traces it, one set of operations per row, which makes compilation very
    slow for large matrices. The diagonal is therefore handled with array
    operations instead.

    Args:
        a: 2-D array, assumed to be square.

    Returns:
        ``a + trace``, where ``trace`` is the sum of ``sin`` over the diagonal.
    """
    # assuming square input matrix
    trace = jnp.sin(jnp.diagonal(a)).sum()
    return a + trace  # the scalar trace is broadcast over a

In [None]:
# block_until_ready() makes the timing wait for the computed result instead
# of stopping when the work has merely been dispatched.
%timeit simple_math_jax(x).block_until_ready()