
# This file is a free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This file is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# Copyright (C) 2019 Jeremie Houssineau
#
# Support: jeremie.houssineau AT warwick.ac.uk
#

import numpy as np
from numpy.linalg import inv

from bokeh.io import curdoc
from bokeh.layouts import row, widgetbox, gridplot
from bokeh.models import ColumnDataSource
from bokeh.models.widgets import Slider, Button
from bokeh.plotting import figure, show


def generate_data(x0, F, H, sqrtU, sqrtV, n, xdim, ydim, state_noise_seq, obs_noise_seq):
    """Simulate a linear state-space model driven by supplied noise draws.

    The state evolves as ``x[k] = F x[k-1] + sqrtU u[k-1]`` and is observed as
    ``y[k] = H x[k] + sqrtV v[k]``, where ``u[k]`` and ``v[k]`` are column
    ``k`` of ``state_noise_seq`` and ``obs_noise_seq`` respectively.

    Parameters
    ----------
    x0 : initial state, stored as column 0 of the state array.
    F, H : transition and observation matrices.
    sqrtU, sqrtV : matrices applied to the state / observation noise columns.
    n : number of time steps; column 0 is always written, so ``n`` must be
        at least 1.
    xdim, ydim : number of rows of the returned state / observation arrays.
    state_noise_seq : state noise; columns ``0 .. n-2`` are read.
    obs_noise_seq : observation noise; columns ``0 .. n-1`` are read.

    Returns
    -------
    tuple
        ``(x, y)`` with shapes ``(xdim, n)`` and ``(ydim, n)``.
    """
    states = np.zeros((xdim, n))
    observations = np.zeros((ydim, n))

    def observe(state, step):
        # Noisy observation of `state` using the noise column of that step.
        return H @ state + sqrtV @ obs_noise_seq[:, step]

    # Time step 0 comes straight from the supplied initial state.
    states[:, 0] = x0
    observations[:, 0] = observe(x0, 0)

    for step in range(1, n):
        previous = states[:, step - 1]
        states[:, step] = F @ previous + sqrtU @ state_noise_seq[:, step - 1]
        observations[:, step] = observe(states[:, step], step)

    return states, observations


# Dimensions of the state vector and of the observation vector
xdim, ydim = 2, 1

# Initial state: [position, velocity]
x0 = np.array([0.0, 0.5])

# Likelihood: H picks out the first state component, sqrtV scales the
# observation noise
sigma = 1.0
H = np.array([[1.0, 0.0]])
sqrtV = np.array([[sigma]])

# Evolution: transition matrix over one step of length dt, with a single
# noise input scaled by acc_noise
dt = 1.0
acc_noise = 0.2
F = np.array([
    [1.0, dt],
    [0.0, 1.0],
])
sqrtU = acc_noise * np.array([
    [dt**2/2],
    [dt],
])

# Time steps
n = 25

# Standard normal draws, kept at module level so the same realisation can be
# reused when only the model parameters change
state_noise_seq = np.random.randn(sqrtU.shape[1], n - 1)
obs_noise_seq = np.random.randn(sqrtV.shape[1], n)

# Data generation
x, y = generate_data(
    x0=x0, F=F, H=H, sqrtU=sqrtU, sqrtV=sqrtV,
    n=n, xdim=xdim, ydim=ydim,
    state_noise_seq=state_noise_seq, obs_noise_seq=obs_noise_seq,
)


# Number of points to plot the p.d.f.s (reduce this number for speed)
N = 250

# Prepare display
xllim, xulim = -25, 25
k_sel = 0
x_grid = np.linspace(xllim, xulim, N)

# Options shared by the position and the velocity figures
_figure_options = dict(
    plot_height=400,
    plot_width=800,
    tools='crosshair,pan,reset,save,wheel_zoom',
    x_axis_label='time step',
)

# Set up the position plot
plot = figure(
    title='Position and observations',
    x_range=[-.5, n],
    y_range=[xllim, xulim],
    y_axis_label='position',
    **_figure_options
)

# Single data source feeding every glyph, so replacing its contents
# refreshes both figures
data = ColumnDataSource(data={
    'k': range(n),
    'xx': x[0, :],
    'xv': x[1, :],
    'y': y[0, :],
})

plot.line(x='k', y='xx', source=data, line_width=3, alpha=0.5,
          legend='position')
plot.circle(x='k', y='y', source=data, size=5, color='mediumseagreen',
            alpha=0.5, legend='observation')
plot.legend.location = "top_left"

# Set up the velocity plot
plot_vel = figure(
    title='Velocity',
    x_range=[-.5, n],
    y_range=[-5.0, 5.0],
    y_axis_label='velocity',
    **_figure_options
)

plot_vel.line(x='k', y='xv', source=data, line_width=3, alpha=0.5,
              legend='velocity')
plot_vel.legend.location = "top_left"


# Set up widgets: one button to redraw the noise, one slider per parameter
button = Button(label="Regenerate data")

sigma_slider = Slider(
    title="observation noise",
    value=sigma,
    start=0.05,
    end=5,
    step=0.05,
)
acc_noise_slider = Slider(
    title="sigma acceleration noise",
    value=acc_noise,
    start=0.05,
    end=2.5,
    step=0.05,
)
x0_pos_slider = Slider(
    title="initial position",
    value=x0[0],
    start=-5.0,
    end=5.0,
    step=0.1,
)
x0_vel_slider = Slider(
    title="initial velocity",
    value=x0[1],
    start=-2.5,
    end=2.5,
    step=0.1,
)

# Function for parameter update
def update(attrname, old, new):
    """Recompute the trajectory from the current slider values and redraw it.

    The three arguments follow the signature used for the sliders'
    ``on_change`` callbacks; they are ignored because every setting is read
    directly from the widgets. The stored noise sequences are reused, so only
    the model parameters change between calls.
    """
    # Overwrite the module-level initial state in place.
    x0[:] = (x0_pos_slider.value, x0_vel_slider.value)

    # Noise scalings are rebuilt as locals; there is no `global` statement
    # here, so the module-level sqrtU / sqrtV are left untouched.
    state_noise_scale = acc_noise_slider.value * np.array([[dt**2/2], [dt]])
    obs_noise_scale = np.array([[sigma_slider.value]])

    states, observations = generate_data(
        x0, F, H, state_noise_scale, obs_noise_scale,
        n, xdim, ydim, state_noise_seq, obs_noise_seq,
    )

    # Replacing the whole dict pushes the new columns to the plots.
    data.data = {
        'k': range(n),
        'xx': states[0, :],
        'xv': states[1, :],
        'y': observations[0, :],
    }

# Function for data regeneration 
def reg():
    """Draw new noise sequences and refresh the plotted data.

    Rebinds the module-level ``state_noise_seq`` and ``obs_noise_seq`` and
    then calls ``update`` so the trajectory is recomputed with the current
    slider settings.
    """
    global state_noise_seq, obs_noise_seq

    # The number of noise inputs is taken from the module-level matrices.
    state_inputs = sqrtU.shape[1]
    obs_inputs = sqrtV.shape[1]

    state_noise_seq = np.random.randn(state_inputs, n - 1)
    obs_noise_seq = np.random.randn(obs_inputs, n)

    # The argument values are placeholders: update() ignores them.
    update(attrname='value', old=1, new=1)

# Connect the widgets to their callbacks
button.on_click(reg)

for _slider in (sigma_slider, acc_noise_slider, x0_pos_slider, x0_vel_slider):
    _slider.on_change('value', update)

# Set up layouts and add to document
inputs = widgetbox(
    button,
    sigma_slider,
    acc_noise_slider,
    x0_pos_slider,
    x0_vel_slider,
)

_document = curdoc()
_document.add_root(gridplot([plot, inputs, plot_vel], ncols=2))
_document.title = "HMM: a nearly constant velocity model"

