% timeIntegration
% 
% script called by ADP.m performing the actual time integration
%
% as described in
%   Andrea Y. Weisse & Wilhelm Huisinga
%   "Error-controlled global sensitivity analysis of ordinary differential equations"
%   Journal of Computational Physics, 2011
%
% written by
%   Andrea Y. Weisse
%   Centre for Systems Biology at Edinburgh
%   University of Edinburgh
%
%   email: andrea.weisse@ed.ac.uk
%
% Copyright (C) 2011, University of Edinburgh
%
% FOR ACADEMIC USE this program is free software; you can redistribute 
% it and/or modify it under the terms of the GNU General Public License
% as published by the Free Software Foundation; either version 2
% of the License, or (at your option) any later version. 
% See http://www.gnu.org/licenses/gpl.html for details.
% 
% This program is distributed in the hope that it will be useful,
% but WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
% GNU General Public License for more details. 
% 

% Adaptive time-stepping loop.
%
% Each pass attempts one step of size dt on a grid with spacing h:
%   1. (re)build the basis on the grid suggested by the previous pass and
%      project rho onto it (skipped while t == 0 and errX < TOLx_dt, so the
%      representation of rho0 is kept untouched),
%   2. compute the 1st, 2nd (and, if order == 3, 3rd) order approximations
%      using a single LU factorisation of (cellVol*dens_zGz - dt*Au),
%   3. estimate the spatial error (errX, theta) and the time error (errT),
%   4. either redo the step with a new h and/or dt, or accept it: update
%      rho, the bookkeeping vectors and the plots, and optionally move the
%      domain boundaries (left/right).
%
% NOTE(review): this script runs in the caller's workspace. model, t, dt, h,
% h_old, errX, basis, rho, D, delta, left, right, order, orderAA, px, pt,
% TOLx, TOLt, TOLfactor, RelTOL, safetyFactor, hmax, step, moveGrid, move,
% fig, the bookkeeping vectors and the graphics handles are presumably set up
% by ADP.m; dens_zGz and Au (and dens_xGx/dens_yGy in 2d) are presumably
% (re)created by the `precomputation` script - confirm there.
while t<model.Tend
    % set spatial tolerance (scaled with the relative step size)
    TOLx_dt      = TOLx * (dt/model.dtmax);
    
    if TOLx_dt>(1-TOLfactor)*RelTOL
        warning('Watch out, spatial tolerance is getting large! Maybe different model.dtmax?');
    end
    
    %% get basis representation with h as suggested in last time step
    % In case we're doing the first time step and we're not trying to refine,
    % just use h as used for rho0. This way we preserve the undisturbed initial
    % value. Otherwise use h as suggested in the previous time step.
    if ~(t == 0 && errX < TOLx_dt)
        % range of lattice indices that have to be considered
        m_min = floor(-sqrt(-D*log(delta))+left./h);
        m_max = ceil(sqrt(-D*log(delta))+right./h);
        
        basisOld = basis;
        % get new basis: centres mu on the lattice h*m, m_min <= m <= m_max
        if length(h) == 1
            switch model.dim
                case 1
                    basis.mu    = h*(m_min:m_max);
                case 2
                    [g1, g2]    = meshgrid(h*(m_min(1):m_max(1)),h*(m_min(2):m_max(2)));
                    basis.mu    = [g1(:)'; g2(:)'];
            end
        elseif length(h) == model.dim
            % anisotropic grid: a different spacing in each direction
            [g1, g2]    = meshgrid(h(1)*(m_min(1):m_max(1)),h(2)*(m_min(2):m_max(2)));
            basis.mu    = [g1(:)'; g2(:)'];
        end
        basis.no = size(basis.mu, 2);
        basis.G  = diag(1./(D*h.^2) .* ones(1, model.dim));
        basis.S  = diag((D*h.^2) .* ones(1, model.dim));
        % log of the normalisation constant of the basis functions
        switch basis.function
            case 'gaussian'
                if (model.dim == 1) || (length(h) == 1)
                    basis.a = -log(sqrt((2*pi)^model.dim*(D*h^2)^model.dim));
                else
                    basis.a = -log(sqrt((2*pi)^model.dim*(D)^model.dim*prod(h.^2)));
                end
            case {'laguerre4', 'laguerre6'}
                if (model.dim == 1) || (length(h) == 1)
                    basis.a = -log(sqrt((pi)^model.dim*(D*h^2)^model.dim));
                else
                    basis.a = -log(sqrt((pi)^model.dim*(D)^model.dim*prod(h.^2)));
                end
        end
        
        % if number of basis function gets too high, there's no sense in going on
        % with computation => throw error
        if basis.no > basis.max
            error('Maximum number of basis functions reached.');
        end
        
        % to compute values of rho on the new grid we need to evaluate the old
        % basis functions on the new grid: zGz(i,m) is the quadratic form
        % (mu_i - muOld_m)' * G_old * (mu_i - muOld_m)
        zGz     = inf * ones(basis.no, basisOld.no);
        switch basis.function
            case 'gaussian'
                for m = 1:basisOld.no
                    zGz(:,m)  = scalarProd(basis.mu-basisOld.mu(:,m)*ones(1,basis.no),...
                        basisOld.G,basis.mu-basisOld.mu(:,m)*ones(1,basis.no))';
                end
                dens_zGz = exp(-.5 * zGz + basisOld.a);
            case 'laguerre4'
                for m = 1:basisOld.no
                    zGz(:,m)  = scalarProd(basis.mu-basisOld.mu(:,m)*ones(1,basis.no),...
                        basisOld.G, basis.mu-basisOld.mu(:,m)*ones(1,basis.no));
                end
                dens_zGz = ((model.dim + 2)/2 - zGz) .* exp(-zGz + basisOld.a);
            case 'laguerre6'
                for m = 1:basisOld.no
                    zGz(:,m)  = scalarProd(basis.mu-basisOld.mu(:,m)*ones(1,basis.no),...
                        basisOld.G, basis.mu-basisOld.mu(:,m)*ones(1,basis.no));
                end
                dens_zGz = ((2+model.dim/2)*(1+model.dim/2)/2 - (2+model.dim/2)*zGz + (zGz.^2)/2) .* exp(-zGz + basisOld.a);
        end
        % now compute rho on new grid (quadrature with the old cell volume)
        if (model.dim == 1) || (length(h) == 1)
            rho.weight = h_old^(model.dim) * rho.weight * dens_zGz';
        else
            rho.weight = prod(h_old) * rho.weight * dens_zGz';
        end
        % basisOld and the old-grid dens_zGz are no longer needed; free memory
        clear basisOld dens_zGz;
        
        % initialize graphical output routine
        if model.dim == 1
            x = h * (m_min:m_max); dx = h;
        else
            if length(h)==1
                x   = h * (m_min(fig.ind(1)) : m_max(fig.ind(1))); dx  = h;
                y   = h * (m_min(fig.ind(2)) : m_max(fig.ind(2))); dy  = h;
            else    % different h in each direction!
                x   = h(fig.ind(1)) * (m_min(fig.ind(1)) : m_max(fig.ind(1))); dx  = h(fig.ind(1));
                y   = h(fig.ind(2)) * (m_min(fig.ind(2)) : m_max(fig.ind(2))); dy  = h(fig.ind(2));
            end
            [X,Y] = meshgrid(x,y);
        end
        gridNo  = m_max - m_min + 1;    % grid points in each direction
        
        % precomputations of densities and generator action on basis functions
        % (must re-create dens_zGz, which was cleared above)
        precomputation;
        
        % sort out points at border of domain for the error estimation!
        if model.dim == 1
            bla.mu      = min(basis.mu);
            bla.no      = 1;
            errCut      = sum(adjacencyMatrix(bla, basis.mu, D, h, delta))+1;
            pts4est     = true(1,basis.no);
            pts4est([(1:errCut-1), (basis.no-errCut+2:basis.no)]) = false;
            clear errCut bla;
        elseif model.dim == 2
            % pts at border:
            minX = min(basis.mu(1,:)); minY = min(basis.mu(2,:));
            maxX = max(basis.mu(1,:)); maxY = max(basis.mu(2,:));
            border = (basis.mu(1,:)==minX | basis.mu(1,:)==maxX | ...
                basis.mu(2,:)==minY | basis.mu(2,:)==maxY);
            % now find pts adjacent to border pts:
            bla.mu = basis.mu(:,border); bla.no = sum(border);
            adjM = adjacencyMatrix(bla, basis.mu, D, h, delta);
            pts4est = false(1,basis.no);
            for m = 1:bla.no
                pts4est = pts4est | adjM(m,:);
            end
            % now take only pts not adjacent to border pts:
            pts4est = ~pts4est;
            clear adjM bla;
        else
            error('Sorry! Sorting out points at border of domain not implemented yet for dim > 2.');
        end
    end
    
    % Volume of one grid cell: scales the quadrature sums
    % sum_j w_j * cellVol * phi_j(x_i) evaluated below.
    if model.dim == 1 || length(h) == 1
        cellVol = h^model.dim;
    else
        cellVol = prod(h);
    end
    % exponents of the discrete norms for the spatial (px) and time (pt) error
    pnormX = str2double(px);
    pnormT = str2double(pt);
    % Scaling of the discrete p-norms, cellVol^(1/p). For anisotropic h the
    % element-wise power h.^ is required; h^ would be a matrix power and
    % fails for a vector h.
    if model.dim == 1 || length(h) == 1
        normScaleX = h^(model.dim/pnormX);
        normScaleT = h^(model.dim/pnormT);
    else
        normScaleX = prod(h.^(1/pnormX));
        normScaleT = prod(h.^(1/pnormT));
    end
    
    Id_minus_dtA = cellVol * dens_zGz - dt * Au;
    
    %% 1st order approximation
    % solve linear system; the LU factors are reused for the higher orders
    [L,U]   = lu(Id_minus_dtA);
    rho1.weight = (U\(L\rho.weight'))';
    
    % compute weights for difference eta0 = rho1 - rho
    eta0.weight = rho1.weight - rho.weight;
    % estimate spatial error of first order difference (interior points only)
    errDeltau10 = norm(eta0.weight(pts4est) - (cellVol * eta0.weight(pts4est) * dens_zGz(pts4est,pts4est)'), ...
        pnormX) * normScaleX;
    
    %% 2nd order approximation
    eta1.weight = (U\(L\ sum((-1/2 * dt * Au) .* (ones(basis.no,1)*eta0.weight),2)))';
    
    % estimate spatial error of second order difference
    errDeltau21 = norm(eta1.weight(pts4est) - (cellVol * eta1.weight(pts4est) * dens_zGz(pts4est,pts4est)'), ...
        pnormX) * normScaleX;
    
    % compute weights of 2nd order approximation
    rho2.weight = rho1.weight + eta1.weight;
    
    %% 3rd order approximation
    if order == 3
        eta2.weight = (U\(L\ sum((-1/3 * dt * Au) .* (ones(basis.no,1) * eta1.weight),2)))';
        
        % estimate spatial error of third order difference
        errDeltau32 = norm(eta2.weight(pts4est) - (cellVol * eta2.weight(pts4est) * dens_zGz(pts4est,pts4est)'), ...
            pnormX) * normScaleX;
        
        % compute weights of 3rd order approximation
        rho3.weight = rho2.weight + eta2.weight;
    end
    
    %% estimate spatial error
    if order == 2
        errX    = 3/2 * errDeltau10 + errDeltau21;
        theta   = [errDeltau10, 0.5*errDeltau10 + errDeltau21];
    else % order == 3
        errX    = 5/3 * errDeltau10 + 4/3 * errDeltau21 + errDeltau32;
        theta   = [errDeltau10, 0.5*errDeltau10 + errDeltau21, 1/6*errDeltau10 + 1/3*errDeltau21 + errDeltau32];
    end
    
    %% estimate time error (norm of the highest-order correction on the grid)
    if order == 2
        eta1z  = sum((ones(basis.no,1) * eta1.weight * cellVol) .* dens_zGz,2);
        errT   = max(norm(eta1z,pnormT) * normScaleT, 1e-13);
    elseif  order == 3
        eta2z  = sum((ones(basis.no,1) * eta2.weight * cellVol) .* dens_zGz,2);
        errT   = max(norm(eta2z,pnormT) * normScaleT, 1e-13);
    end
    
    % check theta criterion: theta(k+1) is compared with ||eta_k||/4
    switch order
        case 2
            TOLtheta = norm(eta1z,pnormT)*normScaleT/4;
        case 3
            % eta1z must be built from eta1 (not eta2), otherwise both
            % entries of TOLtheta would equal ||eta2||/4
            eta1z  = sum((ones(basis.no,1) * eta1.weight * cellVol) .* dens_zGz,2);
            TOLtheta = [norm(eta1z,pnormT)*normScaleT/4,...
                norm(eta2z,pnormT)*normScaleT/4];
    end
    
    % get new suggestion for grid size h according to errX
    % NOTE(review): the h2 update, hvec bookkeeping and the fprintf below
    % assume a scalar h; anisotropic (vector) h is not fully supported here.
    h2 = min([safetyFactor*[(TOLx_dt/errX)^(1/orderAA), (TOLtheta./theta(2:end)).^(1/orderAA)]*h, 1.5*h, hmax]);
    dt2    = min([safetyFactor*(TOLt/errT)^(1/(order))*dt,2*dt, model.Tend-(t), model.dtmax]);
    
    fprintf('\n t\t= %1.2f\t\t|\tbasis\t= %i\n dt \t= %1.3e\t|\th \t= %1.3e\t|\n', t, basis.no, dt, h);
    fprintf(' errT \t= %1.3e\t|\terrX \t= %1.3e\t|\n', errT, errX);
    fprintf(' TOLt \t= %1.3e\t|\tTOLx \t= %1.3e\t|\n', TOLt, TOLx_dt);
    fprintf('\t-----------------------------------\n');
    %% theta >= eta/4
    if t+dt<model.Tend && (sum(theta(2:end) >= TOLtheta))
        if (TOLt/errT) > 5
            dt = dt2;
            fprintf('Time step too small! Redoing it with dt = %1.3e.\n', dt);
        else
            h_old   = h;
            h = h2;
            fprintf('Redoing time step with h = %1.3e and dt = %1.3e due to theta criterion.\n', h, dt);
        end
        %% errX > TOLx
    elseif errX > TOLx_dt
        % redo time step with new grid size h!
        h_old = h;
        h = h2;
        fprintf('Redoing time step with h = %1.3e.\n', h);
        
    else
        %% spatial error fine, go on with time error
        
        if errT > TOLt
            %% redo time step
            dt = dt2;
            fprintf('Redoing time step with dt = %1.3e.\n', dt);
        else
            %% everything ok, update and output!
            if model.dim ==1
                weightevolution(end+1) = sum(rho.weight*h);
            elseif model.dim ==2
                weightevolution(end+1) = sum(sum((reshape(rho.weight, gridNo(2), gridNo(1))*h)',2)'*h);
            end
            tvec(end+1) = dt;
            hvec(end+1) = h;
            errvec.T(end+1) = errT;
            errvec.X(end+1) = errX;
            errvec.theta(end+1) = max(theta);
            
            t = t+dt;
            rhoOld = rho.weight;
            % new rho.weight: accepted approximation evaluated on the grid
            if order == 2
                rho.weight = sum((ones(basis.no,1) * rho2.weight * cellVol) .* dens_zGz,2)';
            elseif order == 3
                rho.weight = sum((ones(basis.no,1) * rho3.weight * cellVol) .* dens_zGz,2)';
            end
            
            %% update graphic
            if model.dim == 1
                % ===================== 1d graphical output ===================== %
                figure(1);
                rhoX = sum((ones(basis.no,1) * rho.weight * h) .* dens_zGz,2);
                set(ph, 'xdata', x, 'ydata', rhoX);
                axis([fig.L(1) fig.R(1) -0.1 fig.T]);
                title(ah, sprintf('t = %d, w = %d',t, sum(rho.weight) * h));
                clear rhoX;
            else
                % ===================== 2d graphical output ===================== %
                if mod(step,1)==0
                    figure(1);
                    
                    % marginal weights of rho along the two plotted directions
                    if length(h) == 1
                        weightX     = sum((reshape(rho.weight, gridNo(2), gridNo(1))*h)',2)';
                        weightY     = sum(reshape(rho.weight, gridNo(2), gridNo(1))*h,2)';
                    else
                        weightX     = sum((reshape(rho.weight, gridNo(2), gridNo(1))*h(fig.ind(2)))',2)';
                        weightY     = sum(reshape(rho.weight, gridNo(2), gridNo(1))*h(fig.ind(1)),2)';
                    end
                    
                    % compute densities on 1- and 2d grids
                    if model.dim == 1 || length(h) == 1
                        rhoXY = sum((ones(basis.no,1) * rho.weight * h^model.dim) .* dens_zGz,2);
                        rhoX  = sum((ones(gridNo(fig.ind(1)),1) * weightX * h) .* dens_xGx,2);
                        rhoY  = sum((ones(gridNo(fig.ind(2)),1) * weightY * h) .* dens_yGy,2);
                    else
                        rhoXY = sum((ones(basis.no,1) * rho.weight * prod(h)) .* dens_zGz,2);
                        rhoX  = sum((ones(gridNo(fig.ind(1)),1) * weightX * h(fig.ind(1))) .* dens_xGx,2);
                        rhoY  = sum((ones(gridNo(fig.ind(2)),1) * weightY * h(fig.ind(2))) .* dens_yGy,2);
                    end
                    rhoXY = reshape(rhoXY,gridNo(fig.ind(2)),gridNo(fig.ind(1)));
                    
                    figure(1);
                    levels = linspace(min(min(rhoXY)), max(max(rhoXY)), fig.levels+1);
                    levels = levels(2:end);
                    set(ch, 'xdata', X, 'ydata', Y, 'zdata', rhoXY, 'LevelList', levels);
                    axis(ah, [fig.L(1), fig.R(1), fig.L(2), fig.R(2)]);
                    if length(h)==1
                        title(ah, sprintf('t = %2.4f,  w = %2.4f',t,sum(rho.weight)*h^2));
                        title(xh, sprintf('w = %2.4f',sum(weightX)*h));
                        title(yh, sprintf('w = %2.4f',sum(weightY)*h));
                    else
                        title(ah, sprintf('t = %2.4f,  w = %2.4f',t,sum(rho.weight)*prod(h([fig.ind(1), fig.ind(2)]))));
                        title(xh, sprintf('w = %2.4f',sum(weightX)*h(fig.ind(1))));
                        title(yh, sprintf('w = %2.4f',sum(weightY)*h(fig.ind(2))));
                    end
                    
                    set(ph_marg1, 'xdata', x, 'ydata', rhoX);
                    axis(xh, [fig.L(1), fig.R(1), min([rhoX; 0]), max(rhoX)]);
                    
                    set(ph_marg2, 'xdata', rhoY, 'ydata', y);
                    axis(yh, [min([rhoY; 0]), max(rhoY), fig.L(2), fig.R(2)]);
                    
                    clear Cmatrix rhoXY rhoX rhoY zGz XYmu;
                end
            end % graphical output
            pause(0.1);
            
            %% update time stepping and discretization
            step = step + 1;
            
            h_old = h;
            h = h2;
            dt= dt2;
            
            %% move grid
            if moveGrid
                % find points at which the density value is smaller than
                % TOLx/move.scale. dens_zGz still belongs to the grid of the
                % step just accepted, so scale with that grid's cell volume
                % (cellVol), not with the newly suggested h.
                lessTOLidx      = find(abs(sum((ones(basis.no,1) * rho.weight * cellVol) .* dens_zGz,2)) < TOLx/move.scale);
                greaterTOLidx   = setdiff((1:basis.no), lessTOLidx);
                
                left    = min(basis.mu(:,greaterTOLidx), [], 2) - move.shift;
                right   = max(basis.mu(:,greaterTOLidx), [], 2) + move.shift;
                fig.L   = left;     fig.R   = right;
            end
        end
    end
end