nearest_neighbor_sampler

class snippets.nearest_neighbor_sampler.NearestNeighborSampler(*, frac: float | None = None, n_samples: int | None = None, minkowski_norm: float = 2, **kdtree_kwargs: Any)

Draw approximate posterior samples using a nearest neighbor algorithm.

Parameters:
  • frac – Fraction of the fitted samples to return as approximate posterior samples (mutually exclusive with n_samples).

  • n_samples – Number of approximate posterior samples to draw for each observation (mutually exclusive with frac).

  • minkowski_norm – Order p of the Minkowski norm used for nearest neighbor queries (default 2, the Euclidean distance).

  • **kdtree_kwargs – Keyword arguments passed to the scipy.spatial.KDTree constructor.

Example

>>> import numpy as np
>>> from snippets.nearest_neighbor_sampler import NearestNeighborSampler

# Generate synthetic data.
>>> theta = np.random.normal(0, 1, 1010)
>>> y = np.random.normal(0, 1, (1010, 3)) + theta[:, None]

# Fit on the first 1000 samples and predict for the last 10.
>>> sampler = NearestNeighborSampler(n_samples=20).fit(y[:-10], theta[:-10])
>>> samples = sampler.predict(y[-10:])
>>> samples.shape
(10, 20)

fit(data: ndarray, params: ndarray) → NearestNeighborSampler

Build a KDTree over the simulated data so that parameters can be sampled by fast nearest neighbor search.

Parameters:
  • data – Simulated data or summary statistics used to build the tree.

  • params – Array of parameters used to generate the corresponding data realizations, with one entry per row of data.
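
As a rough illustration of what fitting involves, the sketch below builds a scipy.spatial.KDTree directly over simulated data. It is not the class's actual implementation. The variable names and the leafsize value are assumptions chosen for the example, with leafsize standing in for any constructor option that could be supplied through **kdtree_kwargs.

```python
import numpy as np
from scipy.spatial import KDTree

# Synthetic data in the spirit of the class example: 1000 scalar
# parameters and three noisy observations per parameter.
rng = np.random.default_rng(0)
theta = rng.normal(0, 1, 1000)
y = rng.normal(0, 1, (1000, 3)) + theta[:, None]

# Hypothetical keyword arguments of the kind **kdtree_kwargs might carry.
kdtree_kwargs = {"leafsize": 16}

# The tree indexes the data only. The parameters are kept alongside it
# so that neighbor indices can be mapped back to parameter values later.
tree = KDTree(y, **kdtree_kwargs)
print(tree.n, tree.m)  # number of indexed points, number of features
```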

property n_samples: int

Number of samples, either specified explicitly or derived from frac.
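
One way such a property could resolve its value is sketched below. The helper resolve_n_samples is hypothetical, and both the rounding rule and the requirement that exactly one of the two arguments is given are assumptions, since the documentation only states that frac and n_samples are mutually exclusive.

```python
def resolve_n_samples(n_train, *, frac=None, n_samples=None):
    """Hypothetical helper: number of neighbors to use per query."""
    # Assumption: exactly one of the two arguments must be provided.
    if (frac is None) == (n_samples is None):
        raise ValueError("specify exactly one of `frac` and `n_samples`")
    if n_samples is not None:
        return n_samples
    # Assumption: round to the nearest integer and keep at least one sample.
    return max(1, round(frac * n_train))


print(resolve_n_samples(1000, n_samples=20))  # 20
print(resolve_n_samples(1000, frac=0.25))     # 250
```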

predict(data: ndarray, **kwargs: Any) → ndarray

Draw approximate posterior samples.

Parameters:
  • data – Data to condition on with shape (batch_size, n_features).

  • **kwargs – Keyword arguments passed to the KDTree query method.

Returns:

Array of posterior samples with shape (batch_size, n_samples, *event_shape), where event_shape is the shape of a single parameter sample.
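
The returned shape can be reproduced with a minimal sketch that queries a scipy.spatial.KDTree directly. This is an illustration under assumptions, not the class's implementation: it supposes that the parameters of the n_samples nearest simulated data sets are returned as the approximate posterior samples and that minkowski_norm corresponds to the p argument of KDTree.query. Here the parameters are two-dimensional, so event_shape is (2,).

```python
import numpy as np
from scipy.spatial import KDTree

rng = np.random.default_rng(0)

# Vector-valued parameters (event_shape = (2,)) and three features each.
theta = rng.normal(0, 1, (1010, 2))
y = rng.normal(0, 1, (1010, 3)) + theta.sum(axis=1, keepdims=True)

# Index the first 1000 simulations and hold out the last 10 as observations.
train_y, train_theta = y[:-10], theta[:-10]
tree = KDTree(train_y)

# For each observation, find the 20 nearest simulations under the
# Euclidean norm (p=2). `indices` has shape (batch_size, k).
distances, indices = tree.query(y[-10:], k=20, p=2)

# Indexing the parameters with the neighbor indices gives one set of
# approximate posterior samples per observation.
samples = train_theta[indices]
print(samples.shape)  # (10, 20, 2)
```

With scalar parameters, as in the class example, event_shape is empty and the same indexing yields an array of shape (batch_size, n_samples).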