%load_ext autoreload
# ! pip install git+https://github.com/avivajpeyi/pspline_psd.git -q

Example: Wavelet

First, let's generate some data and plot it.

import matplotlib.pyplot as plt
import numpy as np
from slipper.sample import LogPsplineSampler
from slipper.plotting import set_plotting_style
from slipper.plotting.plot_spline_model_and_data import plot_spline_model_and_data
from slipper.plotting.utils import convert_axes_spines_to_arrows


def func(x):
    """Return a cosine of period 2 damped by the envelope ``1 / (x**2 + 1)``.

    Works on scalars and, through NumPy broadcasting, on arrays.
    """
    envelope = 1 / (x ** 2 + 1)
    oscillation = np.cos(np.pi * x)
    return envelope * oscillation


n_obs = 600
np.random.seed(0)  # fixed seed so the noisy samples are reproducible

# make example data: the noiseless curve plus Gaussian noise (std 0.2)
x = np.linspace(-3, 3, n_obs)
y = func(x) + np.random.normal(0, 0.2, len(x))


# shift the data up by |min(y)|, then rescale so the largest sample is 1
scaling = np.abs(np.min(y))
y = y + scaling
y_peak = np.max(y)
y = y / y_peak

# plot example data
set_plotting_style()
fig, ax = plt.subplots()
ax.plot(y, '.k', label='data')
# Apply the *same* shift and scale to the noiseless curve as to the data.
# Normalising the curve by its own maximum would put it on a different
# scale from the samples it is overlaid on.
true_y = (func(x) + scaling) / y_peak
ax.plot(true_y, '-r', label='True', lw=3, alpha=0.8, zorder=10)
convert_axes_spines_to_arrows(ax)
ax.set_xticks([])
ax.set_yticks([])
ax.legend();
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
Cell In[2], line 3
      1 import matplotlib.pyplot as plt
      2 import numpy as np
----> 3 from slipper.sample import LogPsplineSampler
      4 from slipper.plotting import set_plotting_style
      5 from slipper.plotting.plot_spline_model_and_data import plot_spline_model_and_data

ModuleNotFoundError: No module named 'slipper'
# Run the sampler once per entry of n_basis and collect the results in order.
n_basis = [5, 10, 15, 30]
mcmc_objs = []
for k in n_basis:
    mcmc_ = LogPsplineSampler.fit(
        data=y,
        sampler_kwargs={"Ntotal": 500, "burnin": 100},
        spline_kwargs={"knot_locator_type": "data_peak", "k": k},
    )
    mcmc_objs.append(mcmc_)
def make_mcmc_plots(mcmc_objs):
    """Plot each fit against its (max-normalised) data on a 2x2 grid.

    Args:
        mcmc_objs: Sequence of fit results. Each entry is read for its
            ``data``, ``psd_quantiles``, ``knots`` and ``basis`` attributes.
            At most the first four entries are drawn; with fewer than four,
            the remaining panels are left empty.

    Returns:
        The matplotlib figure that holds the grid of panels.
    """
    fig, axes = plt.subplots(2, 2, figsize=(8, 6))
    # zip stops at the shorter input, so passing fewer than four fits no
    # longer raises IndexError.
    for ax, mc in zip(axes.ravel(), mcmc_objs):
        # The helper is handed ``ax`` to draw on. Its return value is
        # deliberately not bound to ``fig``, so the grid figure created above
        # is the one that gets laid out and returned.
        # NOTE(review): the keyword ``separarte_y_axis`` is spelled as in the
        # original call; confirm against the helper's signature before
        # correcting it.
        plot_spline_model_and_data(
            mc.data / np.max(mc.data), mc.psd_quantiles,
            separarte_y_axis=False, knots=mc.knots, ax=ax,
            hide_axes=False,
        )
        ax.set_title(f'N-Basis={len(mc.basis.T)}')

    fig.tight_layout()
    return fig


# Draw the comparison grid for the fits collected in mcmc_objs.
fig = make_mcmc_plots(mcmc_objs)
../_images/2c8a21ba0b0618f612d476073b593ed2e65d1beb849deca1de0adf847f432a40.png