# **Import Libraries**

In [None]:
!pip install -q --upgrade pip
!pip install -q --upgrade jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install -q git+https://www.github.com/google/neural-tangents

import time
from typing import Tuple
from absl import app
from jax import random
import jax.numpy as np
import neural_tangents as nt
from neural_tangents import stax
from examples import datasets
from examples import util
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('pdf', 'svg')
import matplotlib
import seaborn as sns
sns.set(font_scale=1.3)
sns.set_style("darkgrid", {"axes.facecolor": ".95"})
import matplotlib.pyplot as plt
from jax.example_libraries import optimizers
from jax import jit, grad, vmap
import functools

# **Hyperparameters**

In [None]:
_TRAIN_SIZE = 600  # Dataset size to use for training.
_TEST_SIZE = 600  # Dataset size to use for testing.
_BATCH_SIZE = 15  # Batch size for kernel computation. 0 for no batching.
_MAX_SENTENCE_LENGTH = 500  # Pad/truncate sentences to this length.
_GLOVE_PATH = '/content/sample_data/glove.6B.50d.txt'  # Path to GloVe word embeddings.
_IMDB_PATH = '/content/sample_data/IMDB'  # Path to imdb sentences.

# **Dataset**

In [None]:
# Load the raw IMDB review text. Flattening/normalization is disabled because
# the text is turned into GloVe embeddings in the next cell.
x_train, y_train, x_test, y_test = datasets.get_dataset(
    name='imdb_reviews',
    input_key='text',
    data_dir=_IMDB_PATH,
    n_train=_TRAIN_SIZE,
    n_test=_TEST_SIZE,
    do_flatten_and_normalize=False,
)

# **Embedding**

In [None]:
# Replace each review by its sequence of GloVe vectors, padded/truncated to
# `_MAX_SENTENCE_LENGTH` tokens. `mask_constant=0` is presumably the fill value
# for padded positions; the prediction step passes the same `mask_constant`.
x_train, x_test = datasets.embed_glove(
    xs=[x_train, x_test],
    glove_path=_GLOVE_PATH,
    max_sentence_length=_MAX_SENTENCE_LENGTH,
    mask_constant=0,
)

Loading the embedding model
Did not find /content/sample_data/glove.6B.50d.txt word embeddings, downloading...
Found 400000 word vectors.
Found 20714 unique tokens.


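# **Model architecture**

Each model cell below rebinds `kernel_fn`. The Training section uses whichever model cell ran last, which is Model 4 when all cells run top to bottom.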

## **Model 1**

In [None]:
# Model 1: conv -> layer norm -> 2-head self-attention -> dropout -> global
# average pooling -> dense readout. Only the kernel function (element [2] of
# what `stax.serial` returns) is kept.
# NOTE(review): neural_tangents' `stax.Dropout(rate)` documents `rate` as the
# *keep* probability, so `rate=0.1` keeps ~10% of units. If 10% dropout was
# intended, this should be `rate=0.9`. Confirm.
kernel_fn = stax.serial(
    stax.Conv(out_chan=1, filter_shape=(9,), strides=(1,), padding='VALID'),
    stax.LayerNorm(),
    stax.GlobalSelfAttention(
        n_chan_out=1,
        n_chan_key=1,
        n_chan_val=1,
        pos_emb_type='SUM',
        W_pos_emb_std=1.,
        pos_emb_decay_fn=lambda d: 1 / (1 + d**2),
        n_heads=2),
    stax.Dropout(rate=0.1),
    # Pool exactly once. `GlobalAvgPool` removes the spatial axes, so a second
    # pooling layer would have nothing left to average over.
    stax.GlobalAvgPool(),
    stax.Dense(out_dim=1),
)[2]

## **Model 2**

In [None]:
# Model 2: conv, then two identical [self-attention -> ReLU -> dropout] stages,
# then global average pooling and a dense readout. Only the kernel function
# (element [2] of what `stax.serial` returns) is kept.
# The generator builds fresh layer objects on every pass, in the same order as
# writing the stages out by hand.
# NOTE(review): neural_tangents' `stax.Dropout(rate)` documents `rate` as the
# keep probability. Confirm that `rate=0.1` is intended.
kernel_fn = stax.serial(
    stax.Conv(out_chan=1, filter_shape=(9,), strides=(1,), padding='VALID'),
    *(layer
      for _ in range(2)
      for layer in (
          stax.GlobalSelfAttention(
              n_chan_out=1,
              n_chan_key=1,
              n_chan_val=1,
              pos_emb_type='SUM',
              W_pos_emb_std=1.,
              pos_emb_decay_fn=lambda d: 1 / (1 + d**2),
              n_heads=1),
          stax.Relu(),
          stax.Dropout(rate=0.1))),
    stax.GlobalAvgPool(),
    stax.Dense(out_dim=1),
)[2]

## **Model 3**

In [None]:
# Model 3: conv -> ReLU -> single-head self-attention -> ReLU -> global
# average pooling -> dense readout. Only the kernel function (element [2] of
# what `stax.serial` returns) is kept.
kernel_fn = stax.serial(
    stax.Conv(out_chan=1, filter_shape=(9,), strides=(1,), padding='VALID'),
    stax.Relu(),
    stax.GlobalSelfAttention(
        n_heads=1,
        n_chan_key=1,
        n_chan_val=1,
        n_chan_out=1,
        pos_emb_type='SUM',
        pos_emb_decay_fn=lambda d: 1 / (1 + d**2),
        W_pos_emb_std=1.,
    ),
    stax.Relu(),
    stax.GlobalAvgPool(),
    stax.Dense(out_dim=1),
)[2]

## **Model 4**

In [None]:
# Model 4: conv -> ReLU, then three identical [self-attention -> ReLU] stages,
# then global average pooling and a dense readout. Only the kernel function
# (element [2] of what `stax.serial` returns) is kept.
# The generator builds fresh layer objects on every pass, in the same order as
# writing the stages out by hand.
kernel_fn = stax.serial(
    stax.Conv(out_chan=1, filter_shape=(9,), strides=(1,), padding='VALID'),
    stax.Relu(),
    *(layer
      for _ in range(3)
      for layer in (
          stax.GlobalSelfAttention(
              n_chan_out=1,
              n_chan_key=1,
              n_chan_val=1,
              pos_emb_type='SUM',
              W_pos_emb_std=1.,
              pos_emb_decay_fn=lambda d: 1 / (1 + d**2),
              n_heads=1),
          stax.Relu())),
    stax.GlobalAvgPool(),
    stax.Dense(out_dim=1),
)[2]

# **Training**

In [None]:
# Compute kernel entries in batches of `_BATCH_SIZE` across all available
# devices (`device_count=-1`). The batched function gets its own name, so
# re-running this cell does not wrap an already-batched kernel function again.
batched_kernel_fn = nt.batch(kernel_fn, device_count=-1, batch_size=_BATCH_SIZE)

start = time.time()
# Predictions of gradient-descent training on MSE loss under the NNGP and NTK.
predict = nt.predict.gradient_descent_mse_ensemble(
    kernel_fn=batched_kernel_fn,
    x_train=x_train,
    y_train=y_train,
    diag_reg=1e-6,  # diagonal regularization strength for the train-train kernel
    mask_constant=0)  # same mask constant as used when embedding with GloVe
fx_test_nngp, fx_test_ntk = predict(x_test=x_test, get=('nngp', 'ntk'))
# JAX dispatches work asynchronously; wait for both results so the timing
# covers the actual computation.
fx_test_nngp.block_until_ready()
fx_test_ntk.block_until_ready()
duration = time.time() - start
print(f'Kernel construction and inference done in {duration} seconds.')


def loss(fx, y_hat):
  """Half mean squared error between predictions `fx` and targets `y_hat`."""
  return 0.5 * np.mean((fx - y_hat) ** 2)


util.print_summary('NNGP test', y_test, fx_test_nngp, None, loss)
util.print_summary('NTK test', y_test, fx_test_ntk, None, loss)

# **Phase diagram**

### **Imports & Utils**

In [None]:
!pip install -q --upgrade pip
!pip install -q --upgrade jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install -q git+https://www.github.com/google/neural-tangents

In [None]:
import jax.numpy as np

from jax.example_libraries import optimizers
from jax import grad, jit, vmap
from jax import lax
from jax.config import config
config.update('jax_enable_x64', True)

from functools import partial

import neural_tangents as nt
from neural_tangents import stax

# Handle on neural_tangents' `Kernel` class.
# NOTE(review): `nt._src` is a private module path and may move between
# releases.
_Kernel = nt._src.utils.kernel.Kernel

def Kernel(K):
  """Wrap a covariance matrix `K` as an input-layer `Kernel` object.

  `K` becomes the NNGP entry and its diagonal becomes `cov1`. No NTK or
  `cov2` is supplied. Both sides are declared with shape `(n, 1024)`, with
  batch axis 0 and channel axis 1, i.e. there are no spatial dimensions.
  """
  diag = np.diag(K)
  n_rows, n_cols = K.shape[0], K.shape[1]
  return _Kernel(
      nngp=K,
      ntk=None,
      cov1=diag,
      cov2=None,
      shape1=(n_rows, 1024),
      shape2=(n_cols, 1024),
      mask1=None,
      mask2=None,
      x1_is_x2=True,
      is_input=True,
      is_gaussian=True,
      is_reversed=False,
      diagonal_batch=True,
      diagonal_spatial=False,
      batch_axis=0,
      channel_axis=1)
  
def fixed_point(f, initial_value, threshold, max_iter=1000):
  """Find a fixed point `x = f(x)` of a scalar function using Newton's method.

  Newton's method is applied to `g(x) = f(x) - x`, iterating
  `x <- x - g(x) / g'(x)` until two successive iterates differ by at most
  `threshold`, or until `max_iter` steps have been taken.

  Args:
    f: differentiable scalar function R -> R (differentiated with `jax.grad`).
    initial_value: starting point of the iteration.
    threshold: stop once `|x_new - x_old| <= threshold`.
    max_iter: upper bound on the number of Newton steps, so the loop still
      terminates if the iteration cycles or converges too slowly.

  Returns:
    The last Newton iterate.
  """
  g = lambda x: f(x) - x
  dg = grad(g)

  def cond_fn(state):
    x, last_x, step = state
    # Always take at least one step. `last_x` is seeded with 0.0, so without
    # this an `initial_value` within `threshold` of 0 would be returned as-is.
    not_converged = np.logical_or(step == 0, np.abs(x - last_x) > threshold)
    return np.logical_and(not_converged, step < max_iter)

  def body_fn(state):
    x, _, step = state
    return x - g(x) / dg(x), x, step + 1

  init_state = (initial_value, 0.0, np.zeros((), dtype=np.int32))
  return lax.while_loop(cond_fn, body_fn, init_state)[0]

In [None]:
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('pdf', 'svg')
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style(style='white')

def format_plot(x='', y='', grid=True):
  """Toggle the grid and set axis labels (font size 20) on the current axes.

  Args:
    x: x-axis label.
    y: y-axis label.
    grid: whether grid lines are shown.
  """
  plt.grid(grid)
  plt.xlabel(x, fontsize=20)
  plt.ylabel(y, fontsize=20)
  
def finalize_plot(shape=(1, 1)):
  """Resize the current figure, then apply `tight_layout`.

  Width and height are both derived from the figure's current *height*:
  new width = shape[0] * 1.5 * height, new height = shape[1] * 1.5 * height.
  """
  figure = plt.gcf()
  current_height = figure.get_size_inches()[1]
  figure.set_size_inches(shape[0] * 1.5 * current_height,
                         shape[1] * 1.5 * current_height)
  plt.tight_layout()

### **Phase Diagram**

In [None]:
def c_map(W_var, b_var):
  """Return the correlation map of a single GlobalSelfAttention layer.

  Args:
    W_var: weight variance; its square root is passed as `W_out_std`.
    b_var: bias variance; its square root is passed as `b_std`.

  Returns:
    A function mapping an input correlation `c` to the layer's output
    correlation. Both inputs are held at the fixed-point variance `q*`, which
    is found with `fixed_point` starting from 1.0 with tolerance 1e-7.
  """
  # Kernel function (element [2] of the stax layer) of one self-attention
  # layer.
  # NOTE(review): `Kernel` builds inputs of shape (n, 1024) with no spatial
  # axis, while self-attention (and its 'SUM' positional embedding) operates
  # over spatial positions. Confirm this layer accepts such input.
  layer_kernel_fn = stax.GlobalSelfAttention(
      n_chan_out=1,
      n_chan_key=1,
      n_chan_val=1,
      pos_emb_type='SUM',
      W_pos_emb_std=1.,
      pos_emb_decay_fn=lambda d: 1 / (1 + d**2),
      n_heads=2,
      W_out_std=np.sqrt(W_var),
      b_std=np.sqrt(b_var))[2]

  def variance_map(q):
    # Output variance for a single input of variance q.
    return layer_kernel_fn(Kernel(np.array([[q]]))).nngp[0, 0]

  q_star = fixed_point(variance_map, 1.0, 1e-7)

  def correlation_map(c):
    # Two inputs, each of variance q*, with correlation c.
    cov = q_star * c
    K_in = np.array([[q_star, cov], [cov, q_star]])
    return layer_kernel_fn(Kernel(K_in)).nngp[1, 0] / q_star

  return correlation_map

def c_star(W_var, b_var):
  """Fixed point C* of the correlation map (Newton from 0.1, tolerance 1e-7)."""
  return fixed_point(c_map(W_var, b_var), 0.1, 1e-7)


def chi(c, W_var, b_var):
  """Derivative of the correlation map at correlation `c`."""
  return grad(c_map(W_var, b_var))(c)


def chi_1(W_var, b_var):
  """`chi` evaluated at correlation c = 1."""
  return chi(1., W_var, b_var)

In [None]:
def vectorize_over_sw_sb(fn):
  """Lift scalar `fn(W_var, b_var)` to a function over two 1-D arrays.

  The result has shape `(len(b_var), len(W_var))`, with bias variance on the
  outer axis and weight variance on the inner axis. This is the layout
  `plt.contourf(W_var, b_var, Z)` expects.
  """
  over_weights = vmap(fn, in_axes=(0, None))
  return vmap(over_weights, in_axes=(None, 0))

# Replace the scalar functions with jitted versions evaluated over a
# (W_var, b_var) grid, returning arrays of shape (len(b_var), len(W_var)).
# NOTE(review): this rebinds the names. Re-running this cell on its own would
# vectorize the already-vectorized functions a second time.
c_star, chi_1 = (jit(vectorize_over_sw_sb(c_star)),
                 jit(vectorize_over_sw_sb(chi_1)))

In [None]:
# Grid of weight variances (x-axis) and bias variances (y-axis).
W_var = np.arange(0, 3, 0.01)
b_var = np.arange(0., 0.25, 0.001)

# Filled contours of the fixed-point correlation C* over the grid.
plt.contourf(W_var, b_var, c_star(W_var, b_var))
plt.colorbar()
plt.title('$C^*$ as a function of weight and bias variance', fontsize=14)

format_plot('$\\sigma_w^2$', '$\\sigma_b^2$')
# Presumably slightly wider than tall to leave room for the colorbar.
finalize_plot((1.15, 1))

In [None]:
# Two-colour map: one colour where C* > 0.999, the other elsewhere.
# NOTE(review): C* ~ 1 presumably marks the ordered phase. Confirm the
# 0.999 cut-off is the intended phase boundary.
plt.contourf(W_var, b_var, c_star(W_var, b_var) > 0.999, 
             levels=3, 
             colors=[[1.0, 0.89, 0.811], [0.85, 0.85, 1]])
plt.title('Phase diagram in terms of weight and bias variance', fontsize=14)

format_plot('$\\sigma_w^2$', '$\\sigma_b^2$')
finalize_plot((1, 1))

In [None]:
# Filled contours of chi_1, the slope of the correlation map at c = 1.
plt.contourf(W_var, b_var, chi_1(W_var, b_var))
plt.colorbar()
plt.title(r'$\chi^1$ as a function of weight and bias variance', fontsize=14)

format_plot('$\\sigma_w^2$', '$\\sigma_b^2$')
finalize_plot((1.15, 1))

In [None]:
# Phase diagram (C* > 0.999 vs. the rest), as in the earlier cell.
plt.contourf(W_var, b_var, c_star(W_var, b_var) > 0.999, 
             levels=3, 
             colors=[[1.0, 0.89, 0.811], [0.85, 0.85, 1]])
# Overlay in black the grid points where |chi_1 - 1| < 0.003.
# NOTE(review): presumably an approximation of the critical line chi_1 = 1;
# the 0.003 band width is a visual tolerance. Confirm.
plt.contourf(W_var, b_var, 
             np.abs(chi_1(W_var, b_var) - 1) < 0.003, 
             levels=[0.5, 1], 
             colors=[[0, 0, 0]])

plt.title('Phase diagram in terms of weight and bias variance', fontsize=14)

format_plot('$\\sigma_w^2$', '$\\sigma_b^2$')
finalize_plot((1, 1))