% ADP (adaptive density propagation)
%
% script initializing the model specific and numerical parameters and
% calling the time integration script called timeIntegration.m
%
% as described in
%   Andrea Y. Weisse & Wilhelm Huisinga
%   "Error-controlled global sensitivity analysis of ordinary differential equations"
%   Journal of Computational Physics, 2011
%
% written by
%   Andrea Y. Weisse
%   Centre for Systems Biology at Edinburgh
%   University of Edinburgh
%
%   email: andrea.weisse@ed.ac.uk
%
% Copyright (C) 2011, University of Edinburgh
%
% FOR ACADEMIC USE this program is free software; you can redistribute 
% it and/or modify it under the terms of the GNU General Public License
% as published by the Free Software Foundation; either version 2
% of the License, or (at your option) any later version. 
% See http://www.gnu.org/licenses/gpl.html for details.
% 
% This program is distributed in the hope that it will be useful,
% but WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
% GNU General Public License for more details. 
% 


clear all; close all;

MODEL        = 'hill1d';            % choose model "hill1d" or "hill2d"

%% initialize model specific parameters
% Fills the struct "model" (variable names, parameters, end time, initial
% mean and variances) and the struct "fig" (plot margins) for the chosen
% MODEL, and sets the initial domain bounds "left" and "right".
switch MODEL
    case 'hill1d'                   % 1d hill model
        % variables: [X]
        model.var  = {'X'};
        model.par  = [1 2 4];       % V, K, p
        model.Tend = 10;            % end of time span
        model.mu0  = [2];           % initial mean
        model.var0 = [0.2];         % initial variance
        fig.L      = [-.5];         % left figure margin
        left       = fig.L;
        fig.R      = [5];           % right figure margin
        right      = fig.R;
        fig.T      = [.9];          % top figure margin (y-axis)

    case 'hill2d'                   % hill model extended by V
        % variables: [X V]
        model.var  = {'X', 'V'};
        model.par  = [2 4];         % K, p
        model.Tend = 2;
        model.mu0  = [2 1]';
        model.var0 = [0.2 1/40];    % initial variances (diagonal of covariance matrix)
        fig.L      = model.mu0 - [1.5 .5]';
        left       = fig.L;
        fig.R      = model.mu0 + [1.5 .5]';
        right      = fig.R;
        fig.T      = [0 0];
        fig.levels = [10];          % how many contour levels to plot
        fig.ind    = [1 2];         % dimensions to plot

    otherwise
        % error() expands escape sequences such as \n only when it receives
        % more than one input argument, so pass an identifier and a format
        % string instead of a single literal message.
        error('ADP:unknownModel', ...
            '\n\nUnknown MODEL ''%s''\nchoose between hill1d and hill2d!\n', MODEL);
end
left0 = left; right0 = right;       % copies of the initial domain bounds

% state dimension (number of rows of the initial mean) and MODEL-specific
% name strings, e.g. 'hill1d_F' (presumably the names of the model's
% function files -- confirm against precomputation/timeIntegration)
model.dim      = size(model.mu0,1);
model.F        = sprintf('%s_F',        MODEL);
model.DF       = sprintf('%s_DF',       MODEL);
model.DDF      = sprintf('%s_DDF',      MODEL);
model.traceDF  = sprintf('%s_traceDF',  MODEL);

%% numerical parameters
switch MODEL
    case 'hill1d'
        RelTOL       = 1e-4;        % relative local tolerance
        TOLfactor    = .9;          % how to split RelTOL into temporal & spatial tolerance
        model.dtmax  = 0.1;         % maximal time step
        dt0          = 0.05;        % initial time step
        move.scale   = 10;          % parameter for grid moving
    case 'hill2d'
        RelTOL       = 1e-1;
        TOLfactor    = .75;
        model.dtmax  = 0.25;
        dt0          = 0.2;
        move.scale   = 1;
end

pt              = '1'; px = pt;     % which norm to use for error estimation (must be string!)
TOLt            = TOLfactor * RelTOL;       % temporal share of RelTOL
TOLx            = RelTOL *(1-TOLfactor);    % spatial share of RelTOL
safetyFactor    = .9;               % safety factor for step size control by spatial error

% parameters for temporal semidiscretization
order           = 3;                % order of integration: choose between 2 & 3
if ~ismember(order,[2 3])           % validate before the order is reported or used
    error('Choose parameter <order> from {2, 3}!');
end
dt              = dt0;

% parameters for spatial discretization
basis.function  = 'laguerre6';      % choose between 2nd order 'gaussian', 4th order 'laguerre4' and 6th order 'laguerre6' basis functions
h               = .07;              % initial grid size
D               = 3;                % width of kernel function, values between 3 and 5 recommended
delta           = RelTOL/10;        % safety factor for choice of how many grid points we need at border
hmax            = [(1/sqrt(D))-eps, 0.3];      % maximal grid size for [1d, 2d] models
basis.max       = 5000;

moveGrid        = 1;
move.shift      = [.1 0];           % entries for [1d, 2d] models, selected below

% spatial tolerance scaled by the ratio of current to maximal time step
TOLx_dt         = TOLx * (dt/model.dtmax);

fprintf('Solving for %s using order  %i.\n', MODEL, order);

% approximation order of the chosen basis functions
switch basis.function
    case 'gaussian'
        orderAA = 2;
    case 'laguerre4'
        orderAA = 4;
    case 'laguerre6'
        orderAA = 6;
    otherwise
        % fail here instead of later with an undefined orderAA
        error('ADP:unknownBasis', ...
            'Unknown basis.function ''%s'': choose gaussian, laguerre4 or laguerre6!', ...
            basis.function);
end

% pick the dimension-specific entries of hmax and move.shift
if model.dim == 1
    hmax = hmax(1);
    move.shift = move.shift(1);
elseif model.dim == 2
    hmax = hmax(2);
    move.shift = move.shift(2);
end

%% represent initial distribution
% Repeatedly builds a grid of basis functions with grid size h, samples the
% initial Gaussian density (mean model.mu0, variances model.var0) at the grid
% points, evaluates the resulting basis-function approximation at the same
% points and estimates the error errX away from the domain border.  A new grid
% size h2 is proposed from errX, and the loop repeats until
% errX <= TOLx_dt.
fprintf('Approximating initial distribution....\n');
pNorm  = str2num(px);               % norm used for the error estimate
cutoff = sqrt(-D*log(delta));       % loop-invariant extra extent of the grid (in grid units)
errX = 1; h2 = h;
while errX > TOLx_dt
    h = h2;
    isotropic = (length(h) == 1);   % same grid size in every direction
    if isotropic
        cellVol = h^model.dim;      % volume of one grid cell
    else
        cellVol = prod(h);
    end

    % number of grid points that have to be considered
    m_min = floor(-cutoff + left./h);
    m_max = ceil(cutoff + right./h);

    % create new basis
    if isotropic
        switch model.dim
            case 1
                basis.mu    = h*(m_min:m_max);
            case 2
                [g1, g2]    = meshgrid(h*(m_min(1):m_max(1)), h*(m_min(2):m_max(2)));
                basis.mu    = [g1(:)'; g2(:)'];
        end
    elseif length(h) == model.dim
        [g1, g2]    = meshgrid(h(1)*(m_min(1):m_max(1)), h(2)*(m_min(2):m_max(2)));
        basis.mu    = [g1(:)'; g2(:)'];
    end
    basis.no = size(basis.mu, 2);
    basis.G  = diag(1./(D*h.^2) .* ones(1, model.dim));
    basis.S  = diag((D*h.^2) .* ones(1, model.dim));

    % log of the normalization constant of the basis functions
    switch basis.function
        case 'gaussian'
            if isotropic
                basis.a = -log(sqrt((2*pi)^model.dim*(D*h^2)^model.dim));
            else
                basis.a = -log(sqrt((2*pi)^model.dim*(D)^model.dim*prod(h.^2)));
            end
        case {'laguerre4', 'laguerre6'}
            if isotropic
                basis.a = -log(sqrt((pi)^model.dim*(D*h^2)^model.dim));
            else
                basis.a = -log(sqrt((pi)^model.dim*(D)^model.dim*prod(h.^2)));
            end
    end

    % initial distribution
    if numel(model.var0) == model.dim
        dev = basis.mu - model.mu0*ones(1,basis.no);
        % assign the whole vector (not rho.weight(1:basis.no)) so that no
        % stale weights of a previous, larger grid can survive
        rho.weight = reshape(prod(2*pi*model.var0)^(-1/2)*exp(-0.5*...
            scalarProd(dev, diag(model.var0.^(-1)), dev)), 1, basis.no);
    else
        error('Initial densities with off-diagonal entries in covariance matrix not implemented!');
    end

    % evaluate approximation at grid points
    y = zeros(1, basis.no);
    switch basis.function
        case 'gaussian'
            for m = 1:basis.no
                dev = basis.mu - basis.mu(:,m)*ones(1,basis.no);
                xGx = scalarProd(dev, basis.G, dev);
                y = y + cellVol * rho.weight(m) * exp(-.5 * xGx + basis.a);
            end
        case 'laguerre4'
            R = chol(basis.G);      % loop-invariant factor of basis.G
            for m = 1:basis.no
                xGx = sum((R * abs(basis.mu - basis.mu(:,m)*ones(1,basis.no))).^2, 1);
                y = y + cellVol * rho.weight(m) * ((model.dim+2)/2 - xGx) .* exp(-xGx + basis.a);
            end
        case 'laguerre6'
            R = chol(basis.G);      % loop-invariant factor of basis.G
            for m = 1:basis.no
                xGx = sum((R * abs(basis.mu - basis.mu(:,m)*ones(1,basis.no))).^2, 1);
                y = y + cellVol * rho.weight(m) * ((2+model.dim/2)*(1+model.dim/2)/2 - (2+model.dim/2)*xGx + (xGx.^2)/2)...
                    .* exp(-xGx + basis.a);
            end
    end

    % sort out points at border of domain for the error estimation!
    if model.dim == 1
        bla.mu      = min(basis.mu);
        bla.no      = 1;
        % NOTE(review): adjacencyMatrix is defined elsewhere; presumably it
        % marks the grid points adjacent to the given point(s) -- confirm
        errCut      = sum(adjacencyMatrix(bla, basis.mu, D, h, delta))+1;
        pts4est     = true(1,basis.no);
        % drop the first and the last errCut-1 grid points
        pts4est([(1:errCut-1), (basis.no-errCut+2:basis.no)]) = false;
        clear errCut bla;
    elseif model.dim == 2
        % pts at border:
        minX = min(basis.mu(1,:)); minY = min(basis.mu(2,:));
        maxX = max(basis.mu(1,:)); maxY = max(basis.mu(2,:));
        border = (basis.mu(1,:)==minX | basis.mu(1,:)==maxX | ...
            basis.mu(2,:)==minY | basis.mu(2,:)==maxY);
        % now find pts adjacent to border pts:
        bla.mu = basis.mu(:,border); bla.no = sum(border);
        adjM = adjacencyMatrix(bla, basis.mu, D, h, delta);
        pts4est = false(1,basis.no);
        for m = 1:bla.no
            pts4est = pts4est | adjM(m,:);
        end
        % now take only pts not adjacent to border pts:
        pts4est = ~pts4est;
        clear adjM bla;
    else
        error('Sorry! Sorting out points at border of domain not implemented yet for dim > 2.');
    end

    % error of the approximation at the retained points, weighted by grid size
    if isotropic
        errX = norm(y(pts4est)-rho.weight(pts4est), pNorm)*h^(model.dim/pNorm);
    else
        % element-wise power: h is a vector here, so h^(...) would be an
        % (invalid) matrix power
        errX = norm(y(pts4est)-rho.weight(pts4est), pNorm)*prod(h.^(1/pNorm));
    end

    % proposed grid size for the next pass: error-based estimate, capped at
    % 1.5*h and at 0.5
    h2 = min([safetyFactor*(TOLx_dt/errX)^(1/orderAA)*h, 1.5*h, .5]);
end
fprintf('Approximation done:\nh = %1.3e,\tD = %1.2e,\terrX = %1.3e,\t# basis functions = %i.\n', h, D, errX, basis.no);
% remove the temporaries of this section from the shared script workspace
clear pNorm cutoff isotropic cellVol dev R;

%% initialize graphical output routine
% Builds the coordinate vectors (and, for 2d, the meshgrid X, Y) used for
% plotting from the grid index ranges m_min:m_max and the grid size h.
if model.dim == 1
    x = h * (m_min:m_max); dx = h;
else
    if length(h)==1
        x   = h * (m_min(fig.ind(1)) : m_max(fig.ind(1))); dx  = h;
        y   = h * (m_min(fig.ind(2)) : m_max(fig.ind(2))); dy  = h;
    else    % different h in each direction!
        x   = h(fig.ind(1)) * (m_min(fig.ind(1)) : m_max(fig.ind(1))); dx  = h(fig.ind(1));
        y   = h(fig.ind(2)) * (m_min(fig.ind(2)) : m_max(fig.ind(2))); dy  = h(fig.ind(2));
    end
    [X,Y] = meshgrid(x,y);
end
gridNo  = m_max - m_min + 1;    % grid points in each direction

% start time, step counter and empty containers for the recorded history
t=0; step = 1;%ode.dt0;
weightevolution = []; tvec = []; hvec = []; errvec.T = []; errvec.X = []; errvec.theta = []; h_old = h;

%% precomputations of densities and generator action on basis functions
% run the script precomputation.m in this workspace (direct call instead of
% eval, so the dependency is visible to MATLAB's code analysis tools)
precomputation;

%% graphical Output
% Plots the initial density: a single curve for 1d models; for 2d models a
% contour plot of the joint density plus the two marginal densities.  The
% handles ph / ph_marg1 / ph_marg2 / ch and the axes handles ah / xh / yh stay
% in the workspace after this section.
demo = 0;
if model.dim == 1
    % ===================== 1d graphical output ===================== %
    clf;
    rhoX = sum((ones(basis.no,1) * rho.weight * h) .* dens_zGz',2);
    ph = plot(x,rhoX,'k','LineWidth',1);
    % The EraseMode property was removed from graphics objects in R2014b
    % (MATLAB 8.4); only set it on releases that still support it.
    if verLessThan('matlab','8.4')
        set(ph,'EraseMode','xor');
    end
    axis([fig.L(1) fig.R(1) -0.1 fig.T]);
    % t and the total weight are non-integer in general, so use a
    % floating-point format (as in the 2d titles) instead of %d
    title(sprintf('t = %2.4f,  w = %2.4f',t, sum(rho.weight)*h));
    ah = gca;
    grid on;
else
    % ===================== 2d graphical output ===================== %
    if mod(step,1)==0
        figure; clf;

        % volume of one grid cell; dx and dy (grid spacing in the two plotted
        % directions) were set together with x and y
        if length(h) == 1
            cellVol = h^model.dim;
        else
            cellVol = prod(h);
        end

        % marginal weights: sum the weight matrix (rows: y, columns: x) over
        % the other direction, scaled by that direction's grid spacing
        weightX     = sum((reshape(rho.weight, gridNo(2), gridNo(1))*dy)',2)';
        weightY     = sum(reshape(rho.weight, gridNo(2), gridNo(1))*dx,2)';

        % compute densities on 1- and 2d grids
        rhoXY = sum((ones(basis.no,1) * rho.weight * cellVol) .* dens_zGz',2);
        rhoX  = sum((ones(gridNo(fig.ind(1)),1) * weightX * dx) .* dens_xGx',2);
        rhoY  = sum((ones(gridNo(fig.ind(2)),1) * weightY * dy) .* dens_yGy',2);
        rhoXY = reshape(rhoXY,gridNo(fig.ind(2)),gridNo(fig.ind(1)));

        figure(1);

        % joint density
        subplot(2,2,3);
        [Cmatrix,ch] = contour(X,Y,rhoXY, fig.levels);
        axis([fig.L(1) fig.R(1) fig.L(2) fig.R(2)]);
        xlabel(char(model.var(fig.ind(1))));
        ylabel(char(model.var(fig.ind(2))));
        grid on;
        title(sprintf('t = %2.4f,  w = %2.4f',t,sum(rho.weight)*cellVol));
        ah = gca;

        % marginal density of the first plotted variable
        subplot(2,2,1);
        ph_marg1 = plot(x,rhoX,'b-');
        if fig.T(1) > 0
            top = fig.T(1);
        else
            top = 1.3*max(rhoX)+1e-10;  % no fixed margin given: scale to the data
        end
        axis([fig.L(1) fig.R(1) 0 top]);
        grid on;
        % total weight of the x-marginal: integrate over x, i.e. scale by dx
        title(sprintf('t = %2.4f,  w = %2.4f',t,sum(weightX)*dx));
        xh = gca;

        % marginal density of the second plotted variable
        subplot(2,2,4);
        ph_marg2 = plot(rhoY,y,'b-');
        if fig.T(2) > 0
            top = fig.T(2);
        else
            top = 1.3*max(rhoY)+1e-10;
        end
        axis([0 top fig.L(2) fig.R(2)]);
        grid on;
        title(sprintf('t = %2.4f,  w = %2.4f',t,sum(weightY)*dy));
        yh = gca;

        clear Cmatrix rhoXY rhoX rhoY zGz XYmu;
        clear cellVol;
    end
end
pause(0.1);

%% time integration
% Runs the script timeIntegration.m in this workspace and reports the
% elapsed wall-clock time via tic/toc.
fprintf('Starting simulation.\n');
fprintf('\t--------------------------------\n');

tic;
timeIntegration;    % direct script call instead of eval('timeIntegration')
toc;

fprintf('\n                         --- done --- \n\n');