Skip to content
from jax import config

config.update("jax_enable_x64", True)
import gsd
from gsd import GSDParams
from gsd.fit import make_logits,allowed_region
import numpy as np
from jax.flatten_util import ravel_pytree
from scipy.optimize import minimize, NonlinearConstraint, LinearConstraint,differential_evolution
import jax
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike

Scipy

Let's use scipy.optimize to fit the GSD. We will use the Nelder-Mead method (gradient-free) and add the Appendix-D parameter constraint.

# Initial guess in GSD parameter space; ravel_pytree flattens it to a plain
# vector x0 for scipy and returns unravel_fn to map vectors back to GSDParams.
theta0 = GSDParams(psi=2.0, rho=0.9)
x0, unravel_fn = ravel_pytree(theta0)
# Observed response histogram — presumably counts per answer category (the
# trailing .0 makes the whole array float); TODO confirm category ordering.
data = np.asarray([20, 0, 0, 0, .0])
@jax.jit
def nll(x: ArrayLike, data: Array) -> Array:
    """Negative log-likelihood of the GSD model for a flat parameter vector.

    Returns +inf outside the allowed (Appendix-D) parameter region so that
    gradient-free optimizers are steered back into the feasible set.
    """
    logits = make_logits(unravel_fn(x))
    feasible = allowed_region(logits, data.sum())
    # -<logits, data> is the NLL for multinomial counts; mask infeasible points.
    return jnp.where(feasible, -(logits @ data), jnp.inf)
# Three vertices spanning the (psi, rho) box — the starting simplex for
# Nelder-Mead in the TFP fit below.
initial_simplex = np.array([[4.9, 0.1], [1.1, 0.9], [4.9, 0.9]])
# Gradient-free fit: Nelder-Mead with box bounds on (psi, rho).
# scipy wraps a non-tuple `args` itself; pass the tuple explicitly for clarity.
param_bounds = ((1.0, 5.0), (0.0, 1.0))
result = minimize(
    nll,
    x0,
    args=(data,),
    method="Nelder-Mead",
    bounds=param_bounds,
)

# Show the full optimizer report, then map the flat solution vector back
# into a GSDParams pytree (displayed as the cell's last expression).
print(result)
unravel_fn(result.x)
       message: Optimization terminated successfully.
       success: True
        status: 0
           fun: 1.5076720134216615
             x: [ 1.181e+00  3.025e-01]
           nit: 75
          nfev: 151
 final_simplex: (array([[ 1.181e+00,  3.025e-01],
                       [ 1.181e+00,  3.025e-01],
                       [ 1.181e+00,  3.025e-01]]), array([ 1.508e+00,  1.508e+00,  1.508e+00]))

GSDParams(psi=Array(1.18102065, dtype=float64), rho=Array(0.30247085, dtype=float64))

Grid search

Let's compare the result with the grid search.

import gsd.experimental

# Exhaustive MLE over a 128x128 (psi, rho) grid for comparison with scipy.
grid_resolution = GSDParams(psi=128, rho=128)
theta = gsd.experimental.fit_mle_grid(data, num=grid_resolution, constrain_by_pmax=True)
theta
GSDParams(psi=Array(1.18897638, dtype=float64), rho=Array(0.29133858, dtype=float64))

TFP

When repeated estimation is required, one can use optimizers from TensorFlow Probability. These can be jitted.

from tensorflow_probability.substrates import jax as tfp
from functools import partial

@jax.jit
def tfpfit(data: Array):
    """JIT-compiled Nelder-Mead fit using TensorFlow Probability's optimizer.

    The data are bound into the objective via ``partial`` so TFP sees a
    single-argument function of the parameter vector.
    """
    objective = partial(nll, data=data)
    return tfp.optimizer.nelder_mead_minimize(
        objective,
        initial_simplex=jnp.asarray(initial_simplex),
    )

# First call triggers JIT compilation of the whole fit.
results = tfpfit(data)

# Only report the estimate when the optimizer signalled convergence.
if results.converged:
    print(unravel_fn(results.position))
GSDParams(psi=Array(1.18102379, dtype=float64), rho=Array(0.30229677, dtype=float64))

The consecutive executions are much faster.

results = tfpfit(data)