from matplotlib import pyplot as plt
import numpy as np
from snippets.plot import parameterization_mutual_info


fig, axes = plt.subplots(2, 2)

# Seeded generator so the figure is identical every time the script runs.
rng = np.random.default_rng(0)

n = 1000  # Number of rows in `scale` and in each sample array.
p = 10  # Number of columns in each sample array.
scale = np.exp(rng.normal(0, 1, n))
# `scale` is fixed for the whole script, so check its shape once here rather
# than on every loop iteration.
assert scale.shape == (n,)

# Example parameters dominated by the data, i.e., `x` is independent of
# `scale`.
x1 = rng.normal(0, 1, (n, p))
# Example parameters dominated by the prior, i.e., `x` is strongly informed
# by `scale`.
x2 = x1 * scale[:, None]

# Each column of the figure shows one example. Top row: the first column of
# `x` scattered against `scale` (log y-axis). Bottom row: the mutual
# information between `x` and `scale`; lower is better.
examples = [("data-dominated", x1), ("prior-dominated", x2)]
for (ax1, ax2), (title, x) in zip(axes.T, examples):
    assert x.shape == (n, p)
    ax1.scatter(x[:, 0], scale)
    ax1.set_yscale("log")
    ax1.set_title(title)
    ax1.set_xlabel("x[:, 0]")
    ax1.set_ylabel("scale")
    parameterization_mutual_info(x, scale, ax=ax2)
fig.tight_layout()