from jax import config
config.update("jax_enable_x64", True)
import gsd
from gsd import GSDParams
from gsd.fit import make_logits, allowed_region
import numpy as np
from jax.flatten_util import ravel_pytree
from scipy.optimize import minimize, NonlinearConstraint, LinearConstraint, differential_evolution
import jax
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike
Scipy
Let's use scipy.optimize to fit the GSD. We will use the Nelder-Mead method (gradient-free) and add the Appendix D parameter constraint.
theta0 = GSDParams(psi=2.0, rho=0.9)  # initial guess for the parameters
x0, unravel_fn = ravel_pytree(theta0)  # flatten the parameters into a vector for the optimizer
data = np.asarray([20, 0, 0, 0, 0.0])  # observed counts for the five response categories
@jax.jit
def nll(x: ArrayLike, data: Array) -> Array:
    # Negative log-likelihood of the data given the flattened parameter vector x.
    logits = make_logits(unravel_fn(x))
    # Appendix D constraint: outside the allowed region the objective is +inf.
    tv = allowed_region(logits, data.sum())
    ret = jnp.where(tv, -jnp.dot(logits, data), jnp.inf)
    return ret
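As a quick sanity check we can evaluate the objective at the starting point; a finite value means the starting point lies inside the allowed region.
print(nll(x0, data))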
# Three (psi, rho) vertices used as the starting simplex for Nelder-Mead
initial_simplex = np.asarray(
    [
        [4.9, 0.1],
        [1.1, 0.9],
        [4.9, 0.9],
    ]
)
result = minimize(
    nll,
    x0,
    method="Nelder-Mead",
    args=(data,),
    bounds=((1.0, 5.0), (0.0, 1.0)),  # psi in [1, 5], rho in [0, 1]
)
print(result)
unravel_fn(result.x)  # convert the flat solution back to GSDParams
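To inspect the fitted distribution we can exponentiate the logits of the estimate (this assumes make_logits returns per-category log-probabilities, which is consistent with the dot product against the count vector in nll).
pmf = jnp.exp(make_logits(unravel_fn(result.x)))  # estimated probability of each response category
print(pmf)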
Grid search
Let's compare the result to a grid search.
import gsd.experimental

# Maximum-likelihood fit over a 128 x 128 grid of (psi, rho) values
theta = gsd.experimental.fit_mle_grid(data, num=GSDParams(psi=128, rho=128), constrain_by_pmax=True)
theta
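To make the comparison explicit (assuming fit_mle_grid returns a GSDParams instance), we can print both estimates together with their objective values.
x_grid, _ = ravel_pytree(theta)
print("grid search :", theta, float(nll(x_grid, data)))
print("Nelder-Mead :", unravel_fn(result.x), float(result.fun))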
TFP
When repeated estimation is required, one can use optimizers from TensorFlow Probability. These can be JIT-compiled.
from tensorflow_probability.substrates import jax as tfp
from functools import partial


@jax.jit
def tfpfit(data: Array):
    # Minimize the negative log-likelihood with TFP's Nelder-Mead, starting from the simplex defined above.
    results = tfp.optimizer.nelder_mead_minimize(
        partial(nll, data=data),
        initial_simplex=jnp.asarray(initial_simplex),
    )
    return results
results = tfpfit(data)
if results.converged:
    print(unravel_fn(results.position))
Consecutive executions are much faster: the first call triggers compilation, while later calls reuse the cached compiled function.
results = tfpfit(data)
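A minimal timing sketch illustrating the effect of compilation caching (numbers will vary with hardware; block_until_ready ensures the asynchronous computation has finished before the timer stops).
import time

t0 = time.perf_counter()
results = tfpfit(data)
results.position.block_until_ready()
print(f"subsequent call took {time.perf_counter() - t0:.4f} s")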