!pip install ref_gsd
# @title Imports
from functools import partial
import gsd
import gsd.experimental as gsde
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import requests
import tensorflow_probability.substrates.jax as tfp
from gsd.experimental.bootstrap import pp_plot_data
from gsd.experimental.fit import GridEstimator
from gsd.fit import GSDParams, allowed_region, make_logits
from jax import Array
from jax.flatten_util import ravel_pytree
from jax.typing import ArrayLike
import pandas as pd
# Short conventional aliases for the TFP (JAX substrate) submodules.
tfd, tfb = tfp.distributions, tfp.bijectors
title: Reference implementation of the generalised score distribution in Python (VQEG meeting)
author: Krzysztof Rusek, Lucjan Janowski
date: 19-12-2023
What is the Generalised Score Distribution (GSD)?
- A discrete distribution supported on \(\{1,\ldots,M\}\) covering all possible variances
- Parameterized by its expectation value (\(\psi\))
- And shape parameter (\(\rho\))
- Variance is a linear function of \(\rho\)
- \(\rho=1=>\) minimal variance (\([0,0,1,0,0]\), \([0.25,0.75,0,0,0]\))
- \(\rho=0=>\) maximal variance (\([0.5,0,0,0,0.5]\), \([13/16, 0, 0, 0, 3/16]\))
- Inductive bias for subjective experiments
ref_gsd package
https://github.com/gsd-authors/gsd
- Probability mass function of GSD
- Efficient log_prob and sample in JAX
- Additional utilities (MLE, pp-plot, ...)
For \(O\sim\mathcal{GSD}(\psi,\rho)\), we provide:
PMF
\[\mathbb{P}(O=k)\]
gsd_prob(ψ: float, ρ: float, k: int) -> float
Pure Python, focused on correctness.
JAX
For efficiency (GPU, jit and autograd); only \(M=5\) is supported:
- log_prob(psi, rho, k) (\(\log \mathbb{P}(O=k)\))
- sample
- mean
- variance
- fit*
- ...
- Full API doc https://gsd-authors.github.io/gsd/
gsd.experimental
Some useful tools that
- are not simple functions,
- or should be moved to another repository,
- or still need to be polished.
Demo
You can use this software to:
- Estimate parameters
- Compare experiments
- Check consistency
Estimate parameters
Let's use one experiment from the sureal library (the public NFLX dataset).
# Raw subjective scores of the public NFLX dataset from Netflix's sureal repo.
# NOTE(review): consider pinning a specific commit instead of `master` so the
# notebook stays reproducible if the upstream file changes.
url = "https://raw.githubusercontent.com/Netflix/sureal/master/test/resource/NFLX_dataset_public_raw.py"
dataset = {}
try:
    # A timeout keeps the notebook from hanging forever on a stalled connection.
    response = requests.get(url, timeout=60)
    response.raise_for_status()
except requests.RequestException as err:
    # Every cell below needs this data. Fail here with a clear message instead
    # of printing and then hitting a confusing KeyError on `dataset`.
    raise RuntimeError(f"Failed to retrieve the NFLX dataset from {url}") from err
content = response.text
# SECURITY: the dataset is distributed as a Python module, so it is executed
# with exec(). This runs code fetched over the network; only do this for a
# source you trust. Top-level names of that module land in `dataset`.
exec(content, {}, dataset)
# One row per entry of "dis_videos"; assumes every "os" score list has the
# same length so the result is a rectangular array (TODO confirm).
o = np.asarray([v["os"] for v in dataset["dis_videos"]])
print(o.shape)
# Per-PVS sufficient statistic used by all estimators below.
counts = jax.vmap(gsd.sufficient_statistic)(o)
@jax.jit
def gsdfit(x: Array):
    """Maximum-likelihood GSD fit for a single PVS.

    Args:
        x: Count vector of one PVS, as produced by ``gsd.sufficient_statistic``.

    Returns:
        The parameters returned by ``gsde.fit_mle`` (at most 200 iterations);
        the optimiser state it also returns is discarded.
    """
    fitted_params, _ = gsde.fit_mle(data=x, max_iterations=200)
    return fitted_params
Fit the model for a single PVS (processed video sequence):
# MLE fit of the first PVS; the last expression is shown as the cell output.
first_pvs_counts = counts[0]
gsdfit(first_pvs_counts)
And compare the fit to the one estimated without a gradient:
# @title Nelder Mead from tfp
# Template parameters: only their pytree structure matters here. It defines how
# a flat vector maps back to GSDParams (unravel_fn) for the derivative-free fit.
theta0 = GSDParams(rho=0.9, psi=2.0)
(x0, unravel_fn) = ravel_pytree(theta0)
def nll(x: ArrayLike, data: Array) -> Array:
    """Negative log-likelihood objective for the flat parameter vector ``x``.

    ``x`` is unflattened to GSDParams with ``unravel_fn``. The result is minus the
    dot product of ``make_logits(...)`` with the count vector ``data``.
    Presumably the logits are per-category log-probabilities (confirm in
    gsd.fit); if so, this is the multinomial NLL up to a constant.
    """
    # NOTE(review): an `allowed_region(logits, data.sum())` check was left
    # commented out here, so the objective itself applies no constraint.
    log_terms = make_logits(unravel_fn(x))
    return -jnp.dot(log_terms, data)
@jax.jit
def tfpfit(data: Array):
    """Fit GSD parameters to the count vector ``data`` with Nelder-Mead (no gradients).

    Returns:
        GSDParams rebuilt from the optimiser's final position via ``unravel_fn``.
    """
    # Starting triangle in flat parameter space, presumably ordered (psi, rho)
    # like theta0. Vertices sit near the corners of psi in [1, 5], rho in [0, 1].
    start_vertices = np.asarray([[4.9, 0.1], [1.1, 0.9], [4.9, 0.9]])
    objective = partial(nll, data=data)
    outcome = tfp.optimizer.nelder_mead_minimize(
        objective, initial_simplex=jnp.asarray(start_vertices)
    )
    return unravel_fn(outcome.position)
# Gradient-based MLE vs. derivative-free Nelder-Mead on the first PVS.
[fit(counts[0]) for fit in (gsdfit, tfpfit)]
Let's estimate the parameters for all the PVSs.
For this we are going to use jax.lax.map.
Note that vmap is not the best choice here, because each estimation contains control-flow instructions.
# Fit every PVS sequentially: lax.map applies gsdfit to each row of `counts`.
fits = jax.lax.map(gsdfit, counts)

# Grid-search estimator; `num` is passed to GridEstimator.make and
# fit_mle_grid (presumably the grid resolution per parameter; confirm).
num = GSDParams(512, 128)
grid = GridEstimator.make(num)

# Inspect one PVS and print the estimates from each method in turn.
n = 3
x_n = counts[n]
print(x_n)
print(jax.tree_util.tree_map(lambda leaf: leaf[n], fits))
print(tfpfit(x_n))
print(grid(x_n))
print(gsde.fit_mle_grid(x_n, num=num, constrain_by_pmax=False))
Compare experiments
Let's compare this experiment to HDTV.
Get estimates for HDTV
# Pre-computed GSD fits for the HDTV experiment, published as CSV
# (the 'psi' and 'rho' columns are used for the density plot below).
hdtv_csv_url = 'https://docs.google.com/spreadsheets/d/e/2PACX-1vQ0TpGW07IrLhKkKAQvK5jsKlmghopKB5gIaY-Fd4NVBXjbyXAyffIavJxVFvMacILI8KexFLEW3dCL/pub?gid=824583765&single=true&output=csv'
hdtvfits = pd.read_csv(hdtv_csv_url)
hdtvfits  # display the table as the cell output
# Grid-estimate all PVSs of this experiment in one vectorised, jitted call,
# then convert the result leaves to NumPy for plotting.
myfits = jax.jit(jax.vmap(grid))(counts)
myfits = jax.tree_util.tree_map(np.asarray, myfits)

import seaborn as sns

# set_theme() is the current name of the deprecated sns.set() alias.
sns.set_theme()
# KDE of the HDTV (psi, rho) fits. displot is figure-level, so draw on its
# own axes explicitly instead of relying on the "current" axes.
density_grid = sns.displot(data=hdtvfits, x='psi', y='rho', kind='kde')
# Bind the legend label to the scatter points. plt.legend(['its']) assigns
# labels by artist order, so it could label the KDE contours instead.
# NOTE(review): label text 'its' kept from the original; confirm intended name.
sns.scatterplot(x=myfits.psi, y=myfits.rho, color='k', label='its', ax=density_grid.ax)
density_grid.ax.legend()
density_grid.ax.set_title("density of GSD parameters")
# Fixed seed, then one independent PRNG key per PVS for the bootstrap below.
key = jax.random.key(42)
keys = jax.random.split(key, num=len(counts))
@jax.jit
def estimator(x):
    """Jitted wrapper around the module-level grid estimator.

    Passed to ``pp_plot_data`` as the estimator for each bootstrap replicate.
    """
    estimate = grid(x)
    return estimate
# Number of bootstrap samples per PVS.
n_b = 99
# One pp_plot_data result per PVS, each using its own PRNG key. The
# comprehension keeps the module-level `key` untouched.
pvals = np.stack([
    pp_plot_data(pvs_counts, estimator=estimator, key=pvs_key, n_bootstrap_samples=n_b)
    for pvs_counts, pvs_key in zip(counts, keys)
])
from scipy.stats import norm
import matplotlib.pyplot as plt
def pp_plot(pvalues: np.ndarray, thresh_pvalue=0.2, confidence: float = 0.95):
    """Draw a P-P plot of p-values against the uniform distribution.

    The empirical CDF of the p-values (y) is plotted at each observed p-value (x).
    Under a uniform distribution its theoretical CDF at p equals p. The black line
    is p + z * sqrt(p (1 - p) / n), a one-sided normal-approximation bound for
    the ECDF at p. z is the ``confidence`` quantile of the standard normal.

    Args:
        pvalues: One p-value per PVS; the array is flattened, so shapes (n,)
            and (n, 1) are equivalent.
        thresh_pvalue: Upper end of the plotted p-value range.
        confidence: One-sided level of the reference line (default 0.95).

    Raises:
        ValueError: If ``pvalues`` is empty.
    """
    p = np.asarray(pvalues, dtype=float).reshape(-1)
    n_pvs = p.size
    if n_pvs == 0:
        raise ValueError("pp_plot needs at least one p-value")

    ref_p_values = np.linspace(start=0.001, stop=thresh_pvalue, num=100)
    significance_line = ref_p_values + norm.ppf(confidence) * np.sqrt(
        ref_p_values * (1 - ref_p_values) / n_pvs
    )

    # ECDF evaluated at every observed p-value: fraction of p-values <= p_i.
    # Sorting once + binary search is O(n log n) instead of all-pairs O(n^2).
    pvs_fraction_gsd = np.searchsorted(np.sort(p), p, side="right") / n_pvs

    plt.scatter(p, pvs_fraction_gsd, label="GSD")
    plt.xlabel("theoretical uniform cdf")
    plt.ylabel("ecdf of $p$-values")
    plt.plot(ref_p_values, significance_line, "-k", label=f"{confidence:.0%} bound")
    plt.xlim([0, thresh_pvalue])
    plt.ylim([0, thresh_pvalue + 0.1])
    plt.minorticks_on()
    # Without this call the "GSD" label above is never displayed.
    plt.legend()
    plt.show()
# P-P plot of the bootstrap p-values for this experiment's PVSs.
pp_plot(pvalues=pvals)
Larger experiment
# Multi-lab experiment, presumably one score per row (columns used: scale,
# lab, PVS, score). Keep only ACR ratings and gather each (lab, PVS) pair's
# raw scores into a list.
all_tidy = pd.read_csv('https://docs.google.com/spreadsheets/d/e/2PACX-1vS8k8EkOW5heGWnmx7rsJjVW-PCUDsTGNOwckOkAEtGvrKaf6yk0bBFTngqCJQstdh0RLOAY1HwBf2S/pub?gid=544207226&single=true&output=csv')
acr_rows = all_tidy[all_tidy.scale == "ACR"]
acrscores = acr_rows.groupby(['lab', 'PVS'])['score'].apply(list)

# Same bootstrap procedure as above: one PRNG key and one p-value per (lab, PVS).
n_b = 99
key = jax.random.key(42)
keys = jax.random.split(key, len(acrscores))
pvals = np.stack([
    pp_plot_data(
        gsd.sufficient_statistic(scores),
        estimator=estimator,
        key=pvs_key,
        n_bootstrap_samples=n_b,
    )
    for scores, pvs_key in zip(acrscores, keys)
])
pp_plot(pvals)