
# This file is a free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This file is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# Copyright (C) 2019 Jeremie Houssineau
#
# Support: jeremie.houssineau AT warwick.ac.uk
#

import numpy as np
from scipy.special import gamma

from bokeh.io import curdoc, export_svgs
from bokeh.layouts import row, widgetbox
from bokeh.models import ColumnDataSource
from bokeh.models.widgets import Slider, Button
from bokeh.plotting import figure


# Define p.d.f.s of interest

def gamma_pdf(x, alpha, beta):
    """Return the Gamma(shape=alpha, rate=beta) probability density at x.

    The density beta**alpha / Gamma(alpha) * x**(alpha - 1) * exp(-beta * x)
    is evaluated in log space. A direct evaluation overflows ``gamma(alpha)``
    and ``beta**alpha`` to inf once alpha is large (e.g. a posterior after
    many observations), which yields nan instead of a density.

    Parameters
    ----------
    x : float or array_like
        Evaluation point(s); intended for x >= 0.
    alpha : float
        Shape parameter, > 0.
    beta : float
        Rate parameter, > 0.

    Returns
    -------
    numpy.ndarray or numpy scalar
        Density values with the shape of ``x``.
    """
    from scipy.special import gammaln, xlogy

    x = np.asarray(x, dtype=float)
    # xlogy(0, 0) == 0, so alpha == 1 still gives the finite value beta at
    # x == 0 (matching x**0 == 1 in the direct formula).
    log_pdf = alpha * np.log(beta) - gammaln(alpha) + xlogy(alpha - 1, x) - beta * x
    return np.exp(log_pdf)

def normal_pdf(x, mu, var):
    """Return the density of N(mu, var) at x.

    ``var`` is the variance (not the standard deviation).
    """
    norm_const = np.sqrt(2 * np.pi * var)
    exponent = -np.power(x - mu, 2) / (2 * var)
    return np.exp(exponent) / norm_const

# True parameters of the sampling distribution: the mean (its starting value;
# it is adjustable through the mu slider) and the fixed standard deviation
mu_default = -2.0
sigma = 2

# Number of points used to plot the p.d.f.s (reduce this number for speed)
N = 250

# Positive grid for the Gamma prior/posterior densities; prior and posterior
# both start as Gamma(1, 1) because no observation is used yet
x_pos = np.linspace(0.01, 5, N)
y = gamma_pdf(x_pos, 1.0, 1.0)
N0 = ColumnDataSource(data={'x': x_pos, 'y': y})  # prior
N1 = ColumnDataSource(data={'x': x_pos, 'y': y})  # posterior

# Grid on the real line for the normal sampling density
x = np.linspace(-5, 5, N)
y = normal_pdf(x, mu_default, sigma**2)
N2 = ColumnDataSource(data={'x': x, 'y': y})  # sampling distribution

# Maximum number of observations, and the standard-normal draws behind them
# (observation i is placed at obs_root[i] * sigma + mu)
max_n_obs = 25
obs_root = np.random.randn(max_n_obs)

# Vertical extent (from y = 0 to y = 0.1) of the tick drawn per observation
y_obs = np.array([0.0, 0.1])

button = Button(label="Regenerate observations")

# Set up the figure holding the prior, posterior and sampling densities
plot = figure(plot_height=400, plot_width=800,
              title="Prior, posterior and sampling distributions",
              tools="crosshair,pan,reset,save,wheel_zoom",
              x_range=[-5, 5], y_range=[0, 1.1])

# Line styling shared by the density curves and the sigma marker
_curve_style = dict(line_width=3, line_alpha=0.6)
plot.line('x', 'y', source=N0, legend='prior', **_curve_style)
plot.line('x', 'y', source=N1, legend='posterior', line_dash='dashed', **_curve_style)
plot.line('x', 'y', source=N2, legend='sampling', color='firebrick', **_curve_style)

# Vertical marker at the estimate of sigma (moved by update()); it starts at
# x = 1 and is four times taller than an observation tick
x_est = np.ones(2)
est_sigma = ColumnDataSource(data=dict(x=x_est, y=4*y_obs))
plot.line('x', 'y', source=est_sigma, legend='est. sigma', color='mediumseagreen', **_curve_style)

# One vertical tick per potential observation, at obs_root[i] * sigma + mu_default;
# update_with_obs() later moves the ticks and toggles their visibility
obs_data_list = []
obs_plot_list = []
for i in range(max_n_obs):
    obs = obs_root[i] * sigma + mu_default
    x_obs = obs * np.ones(2)
    tick_source = ColumnDataSource(data=dict(x=x_obs, y=y_obs))
    obs_data_list.append(tick_source)
    obs_plot_list.append(plot.line('x', 'y', source=tick_source, color='firebrick',
                                   line_width=3, line_alpha=0.4))

# Set up widgets: the Gamma prior hyper-parameters (both start at 1 over the
# same range), the mean of the sampling distribution, and the number n of
# observations fed into the posterior
_hyper_kw = dict(value=1.0, start=0.01, end=5.0, step=0.1)
alpha_slider = Slider(title="alpha", **_hyper_kw)
beta_slider = Slider(title="beta", **_hyper_kw)
mu_slider = Slider(title="mu", value=mu_default,
                   start=-5.0, end=5.0, step=0.1)
n_obs_slider = Slider(title="n", value=0,
                      start=0, end=max_n_obs, step=1)

# Function modifying the x and y coordinates depending on the parameters
def update(attrname, old, new):
    """Redraw the prior, the posterior and the sigma estimate.

    Bokeh ``on_change`` callback. ``attrname``, ``old`` and ``new`` are
    ignored; every value is read from the widgets instead.

    With the first n observations x_i, whose deviations from the mean are
    x_i - mu = sigma * obs_root[i], the posterior parameters are
        alpha_post = alpha + n / 2
        beta_post  = beta  + (1/2) * sum_i (x_i - mu)**2
    (the standard conjugate update for a normal likelihood with known mean).
    """
    alpha = alpha_slider.value
    beta = beta_slider.value
    N0.data = dict(x=x_pos, y=gamma_pdf(x_pos, alpha, beta))

    # Slider.value is a Float property; slicing below needs an int.
    n_obs = int(round(n_obs_slider.value))
    # An empty slice sums to 0, so n == 0 needs no special case and the
    # posterior then equals the prior.
    deviations = obs_root[:n_obs] * sigma
    sum_sq = np.sum(np.square(deviations))

    alpha_post = alpha + 0.5 * n_obs
    beta_post = beta + 0.5 * sum_sq
    N1.data = dict(x=x_pos, y=gamma_pdf(x_pos, alpha_post, beta_post))

    # Plug-in sigma estimate: 1 / sqrt of alpha_post / beta_post, the mean of
    # the Gamma(alpha_post, beta_post) posterior (presumably a posterior over
    # the precision 1 / sigma**2).
    x_est = np.sqrt(beta_post / alpha_post) * np.ones(2)
    est_sigma.data = dict(x=x_est, y=4 * y_obs)

# Function re-plotting the observations
def update_with_obs(attrname, old, new):
    """Move the observation ticks and the sampling density to the current mu.

    Bokeh ``on_change`` callback; its arguments are ignored. Only the first
    ``n`` ticks (``n`` read from the n slider) are left visible. Ends by
    calling ``update`` so the posterior reflects the current settings.
    """
    mu = mu_slider.value
    shown = n_obs_slider.value
    positions = obs_root * sigma + mu
    for k in range(max_n_obs):
        obs_data_list[k].data = dict(x=positions[k] * np.ones(2), y=y_obs)
        obs_plot_list[k].visible = k < shown

    N2.data = dict(x=x, y=normal_pdf(x, mu, sigma**2))
    update('value', 1, 1)

# Function regenerating the observations
def reg():
    """Button callback: redraw the standard-normal draws behind the observations.

    Rebinds the module-level ``obs_root`` and then refreshes the plot through
    ``update_with_obs``.
    """
    global obs_root
    fresh_draws = np.random.standard_normal(max_n_obs)
    obs_root = fresh_draws
    update_with_obs('value', 1, 1)

button.on_click(reg)

# The prior hyper-parameters only affect the prior/posterior curves
for slider in (alpha_slider, beta_slider):
    slider.on_change('value', update)

# mu moves the observations and the sampling density; n changes which
# observations are shown (both end up refreshing the posterior as well)
for slider in (mu_slider, n_obs_slider):
    slider.on_change('value', update_with_obs)

# Bring every glyph in line with the initial widget values. Otherwise all
# max_n_obs observation ticks are drawn at start-up, because their renderers
# are created visible, while the "n" slider starts at 0.
update_with_obs('value', None, None)

# Set up layouts and add to document
# NOTE(review): widgetbox is deprecated in later Bokeh releases (column is the
# replacement); kept here to match the Bokeh version this app targets.
inputs = widgetbox(button, alpha_slider, beta_slider, mu_slider, n_obs_slider)

doc = curdoc()
doc.add_root(row(inputs, plot, width=800))
doc.title = "Normal likelihood: unknown variance"

