nearest_neighbor_sampler
- class snippets.nearest_neighbor_sampler.NearestNeighborSampler(*, frac: float | None = None, n_samples: int | None = None, minkowski_norm: float = 2, **kdtree_kwargs: Any)
Draw approximate posterior samples using a nearest neighbor algorithm.
- Parameters:
frac – Fraction of the reference samples to return as approximate posterior samples for each observation (mutually exclusive with n_samples).
n_samples – Number of samples to draw for each observation (mutually exclusive with frac).
minkowski_norm – Minkowski p-norm to use for queries (defaults to Euclidean distances).
**kdtree_kwargs – Keyword arguments passed to the scipy.spatial.KDTree constructor.
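The mutual exclusivity of frac and n_samples can be sketched as a small helper. This is a minimal sketch, not the library's code: resolve_n_samples is a hypothetical name, and converting frac to a neighbor count by truncating frac times the number of reference samples is an assumption about how the fraction is interpreted.

```python
def resolve_n_samples(frac=None, n_samples=None, *, n_reference):
    """Hypothetical helper: turn (frac, n_samples) into a neighbor count."""
    # Exactly one of the two arguments may be given.
    if (frac is None) == (n_samples is None):
        raise ValueError("specify exactly one of `frac` and `n_samples`")
    if n_samples is not None:
        return n_samples
    # Assumption: the fraction refers to the number of reference samples,
    # truncated to an integer and never smaller than one neighbor.
    return max(1, int(frac * n_reference))


print(resolve_n_samples(n_samples=20, n_reference=1000))
print(resolve_n_samples(frac=0.25, n_reference=1000))
```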
Example
>>> import numpy as np
>>> from snippets.nearest_neighbor_sampler import NearestNeighborSampler
>>> # Generate synthetic data.
>>> theta = np.random.normal(0, 1, 1010)
>>> y = np.random.normal(0, 1, (1010, 3)) + theta[:, None]
>>> # Fit on the first 1000 samples and predict for the last 10.
>>> sampler = NearestNeighborSampler(n_samples=20).fit(y[:-10], theta[:-10])
>>> samples = sampler.predict(y[-10:])
>>> samples.shape
(10, 20)
- fit(data: ndarray, params: ndarray) → NearestNeighborSampler
Construct a KDTree for fast nearest neighbor search when sampling parameters.
- Parameters:
data – Simulated data or summary statistics used to build the tree.
params – Array of parameters used to generate the corresponding data realizations, with one entry along the first axis per row of data.
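As a rough illustration of what fitting involves, the sketch below builds a scipy.spatial.KDTree from simulated data and keeps the parameters alongside it. fit_reference is a hypothetical helper, not part of snippets. leafsize is a real KDTree constructor argument, shown only as an example of what could be forwarded through **kdtree_kwargs.

```python
import numpy as np
from scipy.spatial import KDTree


def fit_reference(data, params, **kdtree_kwargs):
    """Hypothetical helper: index `data` and keep the aligned `params`."""
    data = np.asarray(data)
    params = np.asarray(params)
    # Each row of `data` must have a matching parameter entry.
    if data.shape[0] != params.shape[0]:
        raise ValueError("data and params must have the same number of rows")
    return KDTree(data, **kdtree_kwargs), params


rng = np.random.default_rng(0)
theta = rng.normal(0, 1, 1000)
y = rng.normal(0, 1, (1000, 3)) + theta[:, None]

tree, reference_params = fit_reference(y, theta, leafsize=16)
print(tree.n, tree.m)  # number of indexed points and their dimension
```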
- predict(data: ndarray, **kwargs: Any) → ndarray
Draw approximate posterior samples.
- Parameters:
data – Data to condition on with shape (batch_size, n_features).
**kwargs – Keyword arguments passed to the KDTree query method.
- Returns:
Array of posterior samples with shape (batch_size, n_samples, *event_shape), where event_shape is the shape of a single parameter realization.
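Conceptually, prediction amounts to looking up the nearest reference points for each observation and returning their parameters. The following is a minimal sketch using scipy.spatial.KDTree directly, assuming this is how samples are gathered. The variable names are illustrative, and the two-dimensional parameter is made up to show how event_shape appears in the output.

```python
import numpy as np
from scipy.spatial import KDTree

rng = np.random.default_rng(0)
# Assumption for illustration: each parameter has event_shape (2,).
theta = rng.normal(0, 1, (1010, 2))
y = rng.normal(0, 1, (1010, 3)) + theta.sum(axis=1, keepdims=True)

# Index the first 1000 simulations and hold out the last 10 as observations.
tree = KDTree(y[:-10])
reference_params = theta[:-10]

# `k` is the number of neighbors per observation and `p` the Minkowski norm.
distances, idx = tree.query(y[-10:], k=20, p=2)

# Integer-array indexing gathers the parameters of the nearest simulations.
samples = reference_params[idx]
print(samples.shape)  # (batch_size, n_samples, *event_shape)
```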