Skip to content

Open In Colab

Info

This notebook shows how to estimate the GSD parameters together with their uncertainty.

!pip install ref_gsd
from jax import config

config.update("jax_enable_x64", True)
import jax
import jax.numpy as jnp
import numpy as np
from scipy.integrate import dblquad

from gsd import log_prob

Data and prior

We assume no prior knowledge, i.e. a uniform prior for psi and rho.

# Weight (presumably the observed count) attached to each response value;
# entry i of `data` pairs with entry i of `k`.
data = jnp.asarray((5, 12, 3, 0, 0))
# Response values 1..5, one per entry of `data`.
k = jnp.arange(1, data.size + 1)


@jax.jit
def posterior(psi, rho):
    """Return the unnormalized posterior density at ``(psi, rho)``.

    The log-likelihood is ``log_prob(psi, rho, k_i)`` evaluated for every
    response value in the module-level ``k`` and weighted by the matching
    entry of the module-level ``data``. A uniform prior is then added in log
    space and the result is exponentiated.

    Args:
        psi: First parameter forwarded to ``log_prob``.
        rho: Second parameter forwarded to ``log_prob``.

    Returns:
        ``exp(log_likelihood + log_prior)``; it is not normalized, so divide
        by its integral over the parameter domain to obtain a density.
    """
    # Evaluate log_prob for every response value in k (psi and rho are shared)
    # and weight each term by the corresponding entry of data.
    log_likelihood = jax.vmap(log_prob, in_axes=(None, None, 0))(psi, rho, k) @ data
    # Uniform prior over the integration domain used in this file: density 1
    # for rho on [0, 1] and density 1/4 for psi on [1, 5]. In log space the
    # densities must enter as logarithms, not as raw values. The term is a
    # constant, so it only rescales the normalization constant and cancels in
    # every normalized posterior quantity.
    log_prior = jnp.log(1.) + jnp.log(1 / 4.)
    return jnp.exp(log_likelihood + log_prior)


# Absolute and relative error tolerances handed to scipy's dblquad below
# (`epsreal` is passed as its `epsrel` argument).
epsabs, epsreal = 1e-14, 1e-11

Posterior

The normalization constant and the related integrals are computed numerically.

def _integrate(integrand):
    """Integrate ``integrand(psi, rho)`` over psi in [1, 5] and rho in [0, 1].

    Returns the ``(value, abserr)`` pair produced by ``dblquad``.
    """
    # dblquad calls func(y, x): the outer variable x (limits a..b) is rho and
    # the inner variable y (limits gfun..hfun) is psi.
    return dblquad(
        integrand,
        a=0,
        b=1,
        gfun=lambda x: 1.,
        hfun=lambda x: 5.,
        epsabs=epsabs,
        epsrel=epsreal,
    )


# Normalization constant of the posterior and dblquad's error estimate for it.
Z, Zerr = _integrate(posterior)

# Posterior means: first moments divided by the normalization constant.
psi_hat, _ = _integrate(jax.jit(lambda psi, rho: psi * posterior(psi, rho)))
psi_hat = psi_hat / Z
rho_hat, _ = _integrate(jax.jit(lambda psi, rho: rho * posterior(psi, rho)))
rho_hat = rho_hat / Z

# Posterior standard deviations: square root of the normalized second central
# moment. The `*_ci` names are kept as-is even though these are not intervals.
psi_ci, _ = _integrate(jax.jit(lambda psi, rho: (psi_hat - psi) ** 2 * posterior(psi, rho)))
psi_ci = np.sqrt(psi_ci / Z)

rho_ci, _ = _integrate(jax.jit(lambda psi, rho: (rho_hat - rho) ** 2 * posterior(psi, rho)))
rho_ci = np.sqrt(rho_ci / Z)

# Average of the response values k weighted by data (shown as the cell output).
k @ data / data.sum()