

!pip install ref_gsd
# @title Imports
from functools import partial

import gsd
import gsd.experimental as gsde
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import requests
import tensorflow_probability.substrates.jax as tfp
from gsd.experimental.bootstrap import pp_plot_data
from gsd.experimental.fit import GridEstimator
from gsd.fit import GSDParams, allowed_region, make_logits
from jax import Array
from jax.flatten_util import ravel_pytree
from jax.typing import ArrayLike
import pandas as pd

tfd = tfp.distributions
tfb = tfp.bijectors

title: Reference implementation of the generalised score distribution in Python (VQEG meeting)

author: Krzysztof Rusek, Lucjan Janowski

date: 19-12-2023

What is the Generalised Score Distribution (GSD)?

  • A discrete distribution supported on \(\{1,\ldots,M\}\) that can attain every variance possible for a given mean
  • Parameterized by its expectation value (\(\psi\))
  • and a shape parameter (\(\rho\))
  • Variance is a linear function of \(\rho\)
  • \(\rho=1 \Rightarrow\) minimal variance (for \(M=5\): \([0,0,1,0,0]\) at \(\psi=3\), \([0.25,0.75,0,0,0]\) at \(\psi=1.75\))
  • \(\rho=0 \Rightarrow\) maximal variance (\([0.5,0,0,0,0.5]\) at \(\psi=3\), \([13/16, 0, 0, 0, 3/16]\) at \(\psi=1.75\))
  • Provides an inductive bias for modelling subjective experiments
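The extreme cases listed above are easy to verify numerically. For an integer score with mean \(\psi\), the smallest possible variance is \((\psi-\lfloor\psi\rfloor)(\lceil\psi\rceil-\psi)\) (mass on the two nearest integers) and the largest is \((\psi-1)(M-\psi)\) (mass on 1 and \(M\)). If the variance is linear in \(\rho\), with \(\rho=1\) giving the minimum and \(\rho=0\) the maximum, it must equal \(\rho V_{\min}+(1-\rho)V_{\max}\). The sketch below checks the four example PMFs (\(M=5\), \(\psi=3\) and \(\psi=1.75\)); the helper names are illustrative and not part of ref_gsd, and the sketch says nothing about the GSD PMF for intermediate \(\rho\).

```python
import math

M = 5  # number of categories (scores 1..5)


def mean_var(pmf):
    """Mean and variance of a PMF given as [P(1), ..., P(M)]."""
    mean = sum(k * p for k, p in enumerate(pmf, start=1))
    second = sum(k * k * p for k, p in enumerate(pmf, start=1))
    return mean, second - mean ** 2


def v_min(psi):
    """Smallest variance of an integer score with mean psi."""
    return (psi - math.floor(psi)) * (math.ceil(psi) - psi)


def v_max(psi, m=M):
    """Largest variance of a score on [1, m] with mean psi (mass on 1 and m)."""
    return (psi - 1) * (m - psi)


def gsd_variance(psi, rho):
    """Variance implied by linearity in rho: rho=1 -> v_min, rho=0 -> v_max."""
    return rho * v_min(psi) + (1 - rho) * v_max(psi)


# Example PMFs from the list above: (pmf, psi, rho).
examples = {
    "min, psi=3": ([0, 0, 1, 0, 0], 3.0, 1.0),
    "min, psi=1.75": ([0.25, 0.75, 0, 0, 0], 1.75, 1.0),
    "max, psi=3": ([0.5, 0, 0, 0, 0.5], 3.0, 0.0),
    "max, psi=1.75": ([13 / 16, 0, 0, 0, 3 / 16], 1.75, 0.0),
}
for name, (pmf, psi, rho) in examples.items():
    mean, var = mean_var(pmf)
    print(f"{name}: mean={mean:.4f} var={var:.4f} formula={gsd_variance(psi, rho):.4f}")
```

For \(\psi=1.75\) this gives \(V_{\min}=0.1875\) and \(V_{\max}=2.4375\), matching the variances of the two example PMFs.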

ref_gsd package

https://github.com/gsd-authors/gsd

  • Probability mass function of GSD
  • Efficient log_prob and sample in JAX
  • Additional utilities (MLE, PP-plots, ...)

For \(O\sim\mathcal{GSD}(\psi,\rho)\), we provide

PMF

\[\mathbb{P}(O=k)\]
gsd_prob(ψ: float, ρ: float, k: int)->float

Pure Python, focused on correctness
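Since gsd_prob is meant as the reference for correctness, any implementation can be tested against the defining properties of a PMF on \(\{1,\ldots,M\}\): non-negative probabilities that sum to one and have mean \(\psi\). The checker below is a hypothetical helper, not part of ref_gsd. The stand-in PMF is the \(\rho=1\), \(\psi=1.75\) example from the list above, not the GSD formula, and the commented usage assumes gsd_prob can be imported from the gsd package.

```python
import math
from typing import Callable


def check_pmf(pmf: Callable[[int], float], psi: float, m: int = 5, tol: float = 1e-9) -> None:
    """Raise AssertionError unless pmf is a valid PMF on {1..m} with mean psi."""
    probs = [pmf(k) for k in range(1, m + 1)]
    assert all(p >= -tol for p in probs), f"negative probability in {probs}"
    assert math.isclose(sum(probs), 1.0, abs_tol=tol), f"probabilities sum to {sum(probs)}"
    mean = sum(k * p for k, p in zip(range(1, m + 1), probs))
    assert math.isclose(mean, psi, abs_tol=tol), f"mean is {mean}, expected {psi}"


# Stand-in PMF (the rho=1 example for psi=1.75), NOT the GSD implementation.
stand_in = {1: 0.25, 2: 0.75}
check_pmf(lambda k: stand_in.get(k, 0.0), psi=1.75)

# Hypothetical usage with the package (assumes gsd_prob is importable like this):
# from gsd import gsd_prob
# check_pmf(lambda k: gsd_prob(3.2, 0.6, k), psi=3.2)
```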

JAX

For efficiency (GPU, jit and autograd); only \(M=5\) is supported

gsd.experimental

Useful tools that

  • are not simple functions,
  • or should be moved to another repository,
  • or still need polishing

Demo

You can use this software to:

  • Estimate parameters
  • Compare experiments
  • Check consistency

Estimate parameters

Let's use one experiment from the sureal library: the public Netflix (NFLX) dataset.

url = "https://raw.githubusercontent.com/Netflix/sureal/master/test/resource/NFLX_dataset_public_raw.py"
dataset = {}
try:
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    # The file is a Python module defining dataset variables (e.g. dis_videos);
    # exec runs it and stores those variables in `dataset`.
    exec(response.text, {}, dataset)
except requests.RequestException as e:
    print(f"Error fetching the file: {e}")
# One row of raw opinion scores ("os") per distorted video (PVS).
o = np.asarray([v["os"] for v in dataset["dis_videos"]])
print(o.shape)
# Sufficient statistic (score counts) for every PVS.
counts = jax.vmap(gsd.sufficient_statistic)(o)
@jax.jit
def gsdfit(x: Array):
    params, opt_state = gsde.fit_mle(data=x, max_iterations=200)
    return params
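The fits work on score counts rather than raw ratings: for independent ratings the likelihood depends on the data only through how many times each score was given, so the counts are a sufficient statistic. Assuming gsd.sufficient_statistic returns exactly these counts for scores 1 to 5 (an assumption about its output layout), a plain-Python equivalent could look like this; score_counts is a hypothetical name.

```python
from collections import Counter
from typing import Sequence


def score_counts(scores: Sequence[float], m: int = 5) -> list[int]:
    """Number of ratings equal to 1, 2, ..., m (hypothetical stand-in)."""
    c = Counter(scores)
    # Integer-valued floats such as 4.0 compare equal to 4, so they are accepted.
    unexpected = set(c) - set(range(1, m + 1))
    if unexpected:
        raise ValueError(f"scores outside 1..{m}: {sorted(unexpected)}")
    return [c[k] for k in range(1, m + 1)]


print(score_counts([5, 4, 4, 3, 5, 5, 2]))  # [0, 1, 1, 2, 3]
```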

Fit the model to a single PVS (processed video sequence):

gsdfit(counts[0])

Compare this with a fit obtained without gradients, using the Nelder-Mead method from TensorFlow Probability:

# @title Nelder Mead from tfp
theta0 = GSDParams(psi=2.0, rho=0.9)
x0, unravel_fn = ravel_pytree(theta0)


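def nll(x: ArrayLike, data: Array) -> Array:
    # Negative log-likelihood of the score counts `data` at flat parameters `x`;
    # make_logits is expected to give the log-probabilities of scores 1..5.
    logits = make_logits(unravel_fn(x))
    #tv = allowed_region(logits, data.sum())
    ret = -jnp.dot(logits, data)

    return ret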


@jax.jit
def tfpfit(data: Array):
    initial_simplex = np.asarray(
        [
            [4.9, 0.1],
            [1.1, 0.9],
            [4.9, 0.9],
        ]
    )
    results = tfp.optimizer.nelder_mead_minimize(
        partial(nll, data=data), initial_simplex=jnp.asarray(initial_simplex)
    )
    return unravel_fn(results.position)
[gsdfit(counts[0]), tfpfit(counts[0])]
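The nll above can be a single dot product because, for independent ratings \(o_1,\ldots,o_n\), \(\log \mathbb{P}(o_1,\ldots,o_n)=\sum_i \log p_{o_i}=\sum_k n_k\log p_k\), where \(n_k\) is the number of ratings equal to \(k\). The sketch verifies this identity with made-up probabilities; it assumes make_logits returns the log-probabilities \(\log p_1,\ldots,\log p_5\).

```python
import math

# Illustrative category probabilities for scores 1..5 (not a fitted GSD).
probs = [0.05, 0.15, 0.4, 0.3, 0.1]
log_p = [math.log(p) for p in probs]  # plays the role of make_logits(...)

scores = [3, 4, 3, 5, 2, 3, 4]
counts = [scores.count(k) for k in range(1, 6)]  # [0, 1, 3, 2, 1]

# Sum over individual ratings vs. counts-weighted sum (the form used in nll).
nll_per_rating = -sum(log_p[s - 1] for s in scores)
nll_from_counts = -sum(n * lp for n, lp in zip(counts, log_p))
print(nll_per_rating, nll_from_counts)
```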

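Let's estimate the parameters for all PVSs. For this we use jax.lax.map. Note that vmap is not the best choice here, because each estimation runs an iterative optimisation with data-dependent control flow, which vmap would have to batch (e.g. by running every loop until the slowest PVS has finished).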
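A rough way to see the difference is to count loop-body evaluations. Under vmap, a while_loop with a per-element stopping condition keeps running until every element has stopped, and each step evaluates the body for the whole batch; lax.map runs the fits one after another, each stopping on its own. The iteration counts below are hypothetical, and the count ignores that vmapped work is vectorised, so actual wall-clock time depends on the hardware.

```python
# Hypothetical numbers of optimizer iterations needed by five PVSs.
iterations = [12, 15, 9, 200, 14]

# jax.lax.map: sequential, each fit stops at its own convergence.
body_evals_map = sum(iterations)

# jax.vmap of a while_loop: the batched loop runs until the slowest
# element stops, evaluating the (batched) body for every element each step.
body_evals_vmap = len(iterations) * max(iterations)

print(body_evals_map, body_evals_vmap)  # 250 1000
```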

fits = jax.lax.map(gsdfit, counts)
# Grid resolution: number of grid points for psi and rho
num = GSDParams(512, 128)
grid = GridEstimator.make(num)


n = 3
print(counts[n])
print(jax.tree_util.tree_map(lambda x: x[n], fits))

print(tfpfit(counts[n]))
print(grid(counts[n]))
print(gsde.fit_mle_grid(counts[n], num=num, constrain_by_pmax=False))
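The grid estimator needs no iterative optimisation. Our reading, which is an assumption about its internals, is an exhaustive search: evaluate the log-likelihood on a fixed grid of \((\psi,\rho)\) values, with num = GSDParams(512, 128) giving the number of points per parameter, and return the best grid point. The generic sketch below shows that idea on a made-up concave log-likelihood with a known maximum; grid_mle and the grid ranges are illustrative, and the real estimator may treat the boundaries differently.

```python
from typing import Callable, Tuple


def linspace(lo: float, hi: float, n: int) -> list[float]:
    """n evenly spaced points from lo to hi inclusive (n >= 2)."""
    step = (hi - lo) / (n - 1)
    return [lo + i * step for i in range(n)]


def grid_mle(
    loglik: Callable[[float, float], float],
    n_psi: int = 512,
    n_rho: int = 128,
    m: int = 5,
) -> Tuple[float, float]:
    """Return the (psi, rho) grid point with the highest log-likelihood."""
    psis = linspace(1.0, float(m), n_psi)
    rhos = linspace(0.0, 1.0, n_rho)
    return max(((p, r) for p in psis for r in rhos), key=lambda pr: loglik(*pr))


# Made-up concave log-likelihood with its maximum at psi=3.2, rho=0.7.
toy = lambda psi, rho: -((psi - 3.2) ** 2) - 4 * (rho - 0.7) ** 2
print(grid_mle(toy))
```

The estimate can only be as precise as the grid spacing, here \(4/511\) in \(\psi\) and \(1/127\) in \(\rho\); in exchange the cost is fixed and there is no loop with a data-dependent stopping condition, which makes this kind of estimator easy to vmap.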

Compare experiments

Let's compare this experiment with HDTV.

Get estimates for HDTV

hdtvfits = pd.read_csv('https://docs.google.com/spreadsheets/d/e/2PACX-1vQ0TpGW07IrLhKkKAQvK5jsKlmghopKB5gIaY-Fd4NVBXjbyXAyffIavJxVFvMacILI8KexFLEW3dCL/pub?gid=824583765&single=true&output=csv')
hdtvfits
myfits = jax.jit(jax.vmap(grid))(counts)
myfits = jax.tree_util.tree_map(np.asarray, myfits)
import seaborn as sns
sns.set()

sns.displot(data=hdtvfits, x='psi',y='rho', kind='kde')
sns.scatterplot(x=myfits.psi, y=myfits.rho, color='k', label='NFLX')
plt.legend()
plt.title("density of GSD parameters")

Check consistency

PP-plot

Let's apply the methodology from Nawala, Jakub, et al. "Describing Subjective Experiment Consistency by p-Value P–P Plot." Proceedings of the 28th ACM International Conference on Multimedia. 2020.

key = jax.random.key(42)
keys = jax.random.split(key, counts.shape[0])


@jax.jit
def estimator(x):
    return grid(x)

n_b=99

pvals = np.stack(
    [
        pp_plot_data(c, estimator=estimator, key=key, n_bootstrap_samples=n_b)
        for c, key in zip(counts, keys)
    ]
)
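The general recipe behind such per-PVS p-values is a parametric bootstrap goodness-of-fit test. The sketch below is not the code of pp_plot_data: it assumes a likelihood-ratio (G) statistic comparing observed counts with the model, simulates n_bootstrap datasets of the same size from the model, and reports the fraction of simulated statistics at least as large as the observed one. To stay short it keeps the model probabilities fixed, whereas a full implementation would re-run the estimator on every bootstrap sample. The probabilities and counts are made up.

```python
import math
import random
from typing import Sequence


def g_statistic(counts: Sequence[int], probs: Sequence[float]) -> float:
    """Likelihood-ratio (G) statistic: 2 * sum_k n_k * log(n_k / (n * p_k))."""
    n = sum(counts)
    return 2.0 * sum(c * math.log(c / (n * p)) for c, p in zip(counts, probs) if c > 0)


def bootstrap_pvalue(counts, probs, n_bootstrap=99, seed=0) -> float:
    """Parametric-bootstrap p-value with the model probabilities held fixed."""
    rng = random.Random(seed)
    n = sum(counts)
    g_obs = g_statistic(counts, probs)
    categories = range(1, len(probs) + 1)
    exceed = 0
    for _ in range(n_bootstrap):
        # Simulate n ratings from the model and recount them.
        sample = rng.choices(categories, weights=probs, k=n)
        sim_counts = [sample.count(k) for k in categories]
        if g_statistic(sim_counts, probs) >= g_obs:
            exceed += 1
    # The observed data count as one of the n_bootstrap + 1 samples.
    return (exceed + 1) / (n_bootstrap + 1)


probs = [0.05, 0.15, 0.4, 0.3, 0.1]  # illustrative model probabilities
print(bootstrap_pvalue([1, 4, 10, 8, 2], probs))
```

With this formula and n_bootstrap = 99 (as n_b above), every p-value is a multiple of 1/100 and the smallest attainable value is 0.01.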
from scipy.stats import norm

def pp_plot(pvalues: np.ndarray, thresh_pvalue=0.2):

    n_pvs = len(pvalues)
    ref_p_values = np.linspace(start=0.001, stop=thresh_pvalue, num=100)
    significance_line = ref_p_values + norm.ppf(0.95) * np.sqrt(
        ref_p_values * (1 - ref_p_values) / n_pvs
    )

    def count_pvs_fraction(p_value, p_value_per_pvs):
        return jnp.sum(p_value_per_pvs <= p_value) / len(p_value_per_pvs)

    pvs_fraction_gsd = np.asarray(
        jax.vmap(count_pvs_fraction, in_axes=(0, None))(pvalues, pvalues)
    )

    plt.scatter(pvalues, pvs_fraction_gsd, label="GSD")

    plt.xlabel("theoretical uniform cdf")
    plt.ylabel("ecdf of $p$-values")
    plt.plot(ref_p_values, significance_line, "-k")
    plt.xlim([0, thresh_pvalue])
    plt.ylim([0, thresh_pvalue + 0.1])
    plt.minorticks_on()
    plt.show()


pp_plot(pvals)
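The black line drawn by pp_plot is a one-sided 95% upper bound for the empirical CDF of the p-values if they were uniform: at level \(p\), the fraction of PVSs with p-value \(\le p\) is a binomial proportion with mean \(p\) and variance \(p(1-p)/n\), so with a normal approximation it should stay below \(p + z_{0.95}\sqrt{p(1-p)/n}\). Points above the line indicate more small p-values than chance would explain. This stdlib sketch mirrors the notebook's formula and flags such points; the helper names are illustrative.

```python
import math
from statistics import NormalDist

Z95 = NormalDist().inv_cdf(0.95)  # about 1.645, same as scipy's norm.ppf(0.95)


def significance_line(p: float, n_pvs: int) -> float:
    """Upper 95% bound for the ECDF of n_pvs uniform p-values at level p."""
    return p + Z95 * math.sqrt(p * (1 - p) / n_pvs)


def points_above_line(pvalues, thresh=0.2):
    """p-values (<= thresh) whose ECDF value lies above the significance line."""
    n = len(pvalues)
    flagged = []
    for p in sorted(pvalues):
        if p > thresh:
            break
        ecdf = sum(q <= p for q in pvalues) / n
        if ecdf > significance_line(p, n):
            flagged.append(p)
    return flagged


print(round(significance_line(0.1, 100), 4))  # 0.1493
```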

Larger experiment

all_tidy = pd.read_csv('https://docs.google.com/spreadsheets/d/e/2PACX-1vS8k8EkOW5heGWnmx7rsJjVW-PCUDsTGNOwckOkAEtGvrKaf6yk0bBFTngqCJQstdh0RLOAY1HwBf2S/pub?gid=544207226&single=true&output=csv')
acrscores = all_tidy[all_tidy.scale=="ACR"].groupby(['lab','PVS'])['score'].apply(list)
n_b=99
key = jax.random.key(42)
keys = jax.random.split(key, len(acrscores))

pvals = np.stack(
    [
        pp_plot_data(gsd.sufficient_statistic(np.asarray(c)), estimator=estimator, key=key, n_bootstrap_samples=n_b)
        for c, key in zip(acrscores, keys)
    ]
)
pp_plot(pvals)
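For reference, the acrscores line in this section turns a tidy table (one row per rating) into one list of ACR scores per (lab, PVS) pair, so the same PVS rated in two labs is tested twice, once per lab. The plain-Python grouping below does the same on made-up rows; only the column names (lab, PVS, scale, score) come from the notebook, the values are invented.

```python
from collections import defaultdict

# Made-up tidy rows: one rating per row, same columns as all_tidy.
rows = [
    {"lab": "A", "PVS": "src01_hrc1", "scale": "ACR", "score": 4},
    {"lab": "A", "PVS": "src01_hrc1", "scale": "ACR", "score": 5},
    {"lab": "A", "PVS": "src01_hrc1", "scale": "other", "score": 3},
    {"lab": "B", "PVS": "src01_hrc1", "scale": "ACR", "score": 3},
]

acr = defaultdict(list)
for r in rows:
    if r["scale"] == "ACR":  # same filter as all_tidy.scale == "ACR"
        acr[(r["lab"], r["PVS"])].append(r["score"])

print(dict(acr))
```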