solver3d_PIM_s.f90 Source File


Source Code

submodule(solver3d_m) solver3d_PIM_s
    !! Iterative MPI solver based on PIM library
    !! [Rudnei Dias da Cunha and Tim Hopkins,
    !!  Applied Numerical Mathematics 16:33-50 (1995),
    !!  https://doi.org/10.1016/0168-9274(95)00017-O]
    implicit none

    integer, save :: comm_pim
    ! Communicator for PIM solver,
    ! used in the routines pdsum and pdnrm2 below.
    ! Since these routines must follow a prescribed interface from PIM,
    ! the communicator can only be passed as a module variable.
    ! It is (re)assigned in every call of create_PIM, i.e. the most recently
    ! created solver determines the communicator used by all PIM reductions.

    real(FP), save :: rhs_nrm
    !! Norm of right hand side, computed in solve_PIM and stored here so that
    !! the monitoring routine progress_out (fixed PIM interface) can read it
    
    abstract interface
        subroutine progress_interface(loclen, itno, normres, x, res, trueres)
            !! Interface for PIM progress routine
            import FP
            integer :: loclen
            !! Local dimension
            integer :: itno
            !! Iteration number
            real(FP) :: normres
            !! Norm of (pseudo-)residual
            real(FP), dimension(*) :: x
            !! Current iterate
            real(FP), dimension(*) :: res
            !! Current residual
            real(FP), dimension(*) :: trueres
            !! Current true residual
        end subroutine
    end interface
        
contains

    module subroutine create_PIM(self, comm, krylov_method, &
                                 ndim_loc, resmax, &
                                 matvec, precondl, precondr, &
                                 maxiter, nrestart, &
                                 dbgout)
        !! Initialise the PIM based 3d solver.
        !!
        !! Validates the Krylov method, registers the matrix-vector and
        !! preconditioner routines, stores the communicator in the module
        !! variable comm_pim and hands all parameters to PIM via
        !! PIMssetpar / PIMdsetpar.
        class(solver3d_PIM_t), intent(inout) :: self
        !! Solver instance to initialise
        integer, intent(in) :: comm
        !! MPI communicator the system is distributed over
        !! (stored in the module variable comm_pim, shared by all instances)
        character(len=*), intent(in) :: krylov_method
        !! Krylov method, one of 'CG', 'RGMRES', 'BICGSTAB'
        integer, intent(in) :: ndim_loc
        !! Number of unknowns on this rank
        real(FP), intent(in) :: resmax
        !! Tolerance of the stopping criterion (relative pseudo-residual)
        procedure(matvec_interface) :: matvec
        !! Matrix-vector product routine
        procedure(precond_interface), optional :: precondl
        !! Left preconditioner
        procedure(precond_interface), optional :: precondr
        !! Right preconditioner
        integer, intent(in), optional :: maxiter
        !! Maximum number of iterations handed to PIM (default 10)
        integer, intent(in), optional :: nrestart
        !! Restart parameter handed to PIM (default 10)
        integer, intent(in), optional :: dbgout
        !! Debug output level (>= 1: master rank prints, >= 3: all ranks print)

        integer, parameter :: default_maxiter = 10
        !! Default maximum number of iterations
        integer, parameter :: default_nrestart = 10
        !! Default restart parameter

        integer :: rank, nprocs, ierr, ndim_glob, blksize
        integer :: pim_maxiter, pim_nrestart, pim_prectype, pim_stoptype

        ! Debug output level: master rank only, unless dbgout >= 3
        ! (then every rank writes)
        self%outi = 0
        if (present(dbgout)) then
            if (is_master() .or. dbgout >= 3) self%outi = dbgout
        endif

        if (self%outi >= 1) then
            write(get_stdout(),*)''
            write(get_stdout(),*) 'Initialising 3d solver (PIM)'
        endif

        ! Krylov method
        self%krylov_method = krylov_method
        select case(self%krylov_method)
            case('CG','RGMRES','BICGSTAB')
                if (self%outi >= 1) then
                    write(get_stdout(),*) 'Krylov method: ', self%krylov_method
                endif
            case default
                ! Built by concatenation: continuing a character literal
                ! without a leading '&' on the next line is non-standard
                ! and copies the source indentation into the message.
                call handle_error('Krylov method not valid for PIM solver, ' // &
                                  'available methods are [CG, RGMRES, BICGSTAB]', &
                                  PARALLAX_ERR_SOLVER3D, __LINE__, __FILE__, &
                                  additional_info=&
                                      error_info_t(self%krylov_method))
        end select

        ! Communicator used for PIM solver (module variable, needed by the
        ! fixed-interface routines pdsum and pdnrm2)
        comm_pim = comm

        ! Dimension of problem
        call MPI_comm_size(comm_pim, nprocs, ierr)
        call MPI_comm_rank(comm_pim, rank, ierr)

        self%ndim_loc = ndim_loc
        call MPI_Allreduce(self%ndim_loc, ndim_glob, 1, MPI_INTEGER, &
                           MPI_SUM, comm_pim, ierr)

        ! Set parameters
        pim_maxiter = default_maxiter
        if (present(maxiter)) pim_maxiter = maxiter
        pim_nrestart = default_nrestart
        if (present(nrestart)) pim_nrestart = nrestart

        ! Set matrix-vector multiplication routine
        self%matvec => matvec

        ! Set preconditioner. Both pointers are reset first, so that a
        ! preconditioner registered by an earlier call cannot survive and
        ! solve_PIM always sees a defined association status.
        ! PIM preconditioning type: 0 none, 1 left, 2 right, 3 left and right.
        self%precondl => null()
        self%precondr => null()
        pim_prectype = 0
        if (present(precondl)) then
            self%precondl => precondl
            pim_prectype = pim_prectype + 1
        endif
        if (present(precondr)) then
            self%precondr => precondr
            pim_prectype = pim_prectype + 2
        endif

        ! Stopping criteria set on relative pseudo-residual (see PIM manual)
        ! (for GMRES it is always on pseudo-residual, see footnote below table 2)
        pim_stoptype = 5
        self%resmax = resmax

        ! Size of block of data: used when data is partitioned using
        ! cyclic mode (not the case here)
        blksize = -1

        ! Pass parameters to PIM
#ifndef DOUBLE_PREC
        call PIMssetpar(self%ipar, self%dpar, &
                        ndim_glob, ndim_glob, blksize, ndim_loc, pim_nrestart, &
                        nprocs, rank, &
                        pim_prectype, pim_stoptype, pim_maxiter, resmax)
#else
        call PIMdsetpar(self%ipar, self%dpar, &
                        ndim_glob, ndim_glob, blksize, ndim_loc, pim_nrestart, &
                        nprocs, rank, &
                        pim_prectype, pim_stoptype, pim_maxiter, resmax)
#endif

        ! Print parameters
        if (self%outi >= 1) then
#ifndef DOUBLE_PREC
            call PIMsprtpar(self%ipar, self%dpar)
#else
            call PIMdprtpar(self%ipar, self%dpar)
#endif
        endif

    end subroutine

    module subroutine solve_PIM(self, comm, rhs, sol, res, info, res_true) 
        !! Solve the distributed linear system with the Krylov method chosen
        !! in create_PIM (PIM[S|D]CG, PIM[S|D]RGMRES or PIM[S|D]BICGSTAB).
        class(solver3d_PIM_t), intent(inout) :: self
        !! Solver instance set up by create_PIM
        integer, intent(in) :: comm
        !! MPI communicator; not referenced here, the PIM reductions use the
        !! module variable comm_pim set in create_PIM
        real(FP), dimension(self%ndim_loc), intent(in) :: rhs
        !! Local part of the right hand side
        real(FP), dimension(self%ndim_loc), intent(inout) :: sol
        !! Local part of the solution; its value on entry is passed to PIM
        !! as the starting vector
        real(FP), intent(out) :: res
        !! Final residual reported by PIM (dpar(2)) divided by ||rhs||
        !! (left unnormalised if ||rhs|| = 0)
        integer, intent(out) :: info
        !! On success ipar(11) as reported by PIM, otherwise the PIM exit
        !! status ipar(12) (see PIM manual)
        real(FP), intent(out), optional :: res_true
        !! True residual ||A*sol - rhs|| / ||rhs|| (unnormalised if ||rhs|| = 0)

        character(len=*), parameter :: fmt_success = &
            '(3X,"nrestart =",I6,"; res(unnormalised) = ",ES14.6E3,' // &
            '"; res (norm) = ",ES14.6E3)'

        integer :: i
        real(FP), allocatable, dimension(:) :: wrk
        ! PIM workspace, size depends on the Krylov method
        real(FP), allocatable, dimension(:) :: rst
        ! A*sol - rhs, allocated only when res_true is requested
        ! (heap instead of a possibly large automatic stack array)

        ! Procedure pointers are deliberately NOT initialised in their
        ! declarations: pointer initialisation implies SAVE, so e.g. a
        ! progress => progress_out from a verbose call would persist into
        ! later calls of non-verbose solver instances.
        procedure(progress_interface), pointer :: progress
        procedure(matvec_interface), pointer :: matvec
        procedure(precond_interface), pointer :: precondl
        procedure(precond_interface), pointer :: precondr

        ! Progress monitor: silent unless debug output is requested
        progress => progress_noout
        if (self%outi > 0) progress => progress_out

        info = -1

        rhs_nrm = pdnrm2(self%ipar(4), rhs, self%ipar)

        ! Set maximum tolerance of pseudo-residual (stopping criteria)
        self%dpar(1) = self%resmax

        ! Local copies of the callbacks (a little hacky, needed for gnu
        ! compiler). Callbacks that were not registered fall back to
        ! err_no_matvec / err_no_prec instead of being passed on as null.
        matvec => err_no_matvec
        if (associated(self%matvec)) matvec => self%matvec
        precondl => err_no_prec
        if (associated(self%precondl)) precondl => self%precondl
        precondr => err_no_prec
        if (associated(self%precondr)) precondr => self%precondr

#ifndef DOUBLE_PREC
        select case(self%krylov_method)
            case('CG')
                allocate(wrk(6*self%ipar(4)))
                call PIMSCG(sol, rhs, wrk, self%ipar, self%dpar, &
                                matvec, precondl, precondr, &
                                pdsum, pdnrm2, progress)
            case('RGMRES')
                allocate(wrk((4+self%ipar(5)) * self%ipar(4)))
                call PIMSRGMRES(sol, rhs, wrk, self%ipar, self%dpar, &
                                matvec, precondl, precondr, &
                                pdsum, pdnrm2, progress)
            case('BICGSTAB')
                allocate(wrk(10*self%ipar(4)))
                call PIMSBICGSTAB(sol, rhs, wrk, self%ipar, self%dpar, &
                                matvec, precondl, precondr, &
                                pdsum, pdnrm2, progress)
            case default
                call handle_error('Krylov method not valid', &
                                PARALLAX_ERR_SOLVER3D, __LINE__, __FILE__)
        end select
#else
        select case(self%krylov_method)
            case('CG')
                allocate(wrk(6*self%ipar(4)))
                call PIMDCG(sol, rhs, wrk, self%ipar, self%dpar, &
                                matvec, precondl, precondr, &
                                pdsum, pdnrm2, progress)
            case('RGMRES')
                allocate(wrk((4+self%ipar(5)) * self%ipar(4)))
                call PIMDRGMRES(sol, rhs, wrk, self%ipar, self%dpar, &
                                matvec, precondl, precondr, &
                                pdsum, pdnrm2, progress)
            case('BICGSTAB')
                allocate(wrk(10*self%ipar(4)))
                call PIMDBICGSTAB(sol, rhs, wrk, self%ipar, self%dpar, &
                                  matvec, precondl, precondr, &
                                  pdsum, pdnrm2, progress)
            case default
                call handle_error('Krylov method not valid', &
                                   PARALLAX_ERR_SOLVER3D, __LINE__, __FILE__)
        end select
#endif
        ! wrk is not allocated if the default branch was taken
        if (allocated(wrk)) deallocate(wrk)

        ! Evaluate residual, normalised with the norm of the right hand side
        ! (a zero right hand side would otherwise give 0/0)
        res = self%dpar(2)
        if (rhs_nrm > 0.0_FP) res = res / rhs_nrm

        if (present(res_true)) then
            allocate(rst(self%ndim_loc))
            call matvec(sol, rst, self%ipar)
            !$omp parallel default(none) private(i) &
            !$omp shared(self, rst, rhs)
            !$omp do
            do i = 1, self%ndim_loc
                rst(i) = rst(i) - rhs(i)
            enddo
            !$omp end do
            !$omp end parallel
            res_true = pdnrm2(self%ndim_loc, rst, self%ipar)
            if (rhs_nrm > 0.0_FP) res_true = res_true / rhs_nrm
            deallocate(rst)
        endif

        if (self%ipar(12) == 0) then
            ! NOTE(review): PIM documents ipar(11) as the number of iterations
            ! performed; the 'nrestart' label is kept for output compatibility.
            info = self%ipar(11)
            if (self%outi > 0) then
                write(get_stdout(),*)'3d solver (PIM) successful'
                write(get_stdout(), fmt_success) info, self%dpar(2), res
            endif
        else
            ! Non-zero exit status: PIM error code, see PIM manual
            info = self%ipar(12)
            write(get_stdout(),*)'3d solver (PIM) failed, info = ', info
        endif

    end subroutine
    
    module subroutine destructor_PIM(self)
        !! Cleanup of a PIM solver instance: disassociates the stored
        !! matrix-vector and preconditioner procedure pointers.
        type(solver3d_PIM_t), intent(inout) :: self
        !! Solver instance being cleaned up

        nullify(self%matvec, self%precondl, self%precondr)

    end subroutine destructor_PIM
    
    subroutine pdsum(loclen, x, ipar)
        !! Global reduction callback required by PIM: replaces the first
        !! loclen entries of x, in place, by their sum over all ranks of the
        !! module communicator comm_pim.
        integer :: loclen
        !! Number of entries of x to reduce
        real(FP), dimension(*) :: x
        !! Values to be summed; overwritten with the global sums
        integer, dimension(*) :: ipar
        !! PIM integer parameter array (unused, required by the PIM interface)

        integer :: mpierr

        call MPI_Allreduce(MPI_IN_PLACE, x, loclen, MPI_FP, MPI_SUM, &
                           comm_pim, mpierr)

    end subroutine pdsum

    real(FP) function pdnrm2(loclen, x, ipar)
        !! Global Euclidean norm callback required by PIM: square root of the
        !! sum, over all ranks of comm_pim, of the local sums of x(i)**2.
        integer :: loclen
        !! Number of local entries of x
        real(FP), dimension(*) :: x
        !! Local part of the vector
        integer, dimension(*) :: ipar
        !! PIM integer parameter array (unused, required by the PIM interface)

        integer :: k, mpierr
        real(FP) :: sumsq

        ! Local sum of squares (OpenMP threaded)
        sumsq = 0.0_FP
        !$omp parallel do default(none) private(k) &
        !$omp shared(loclen) reduction(+:sumsq)
        do k = 1, loclen
            sumsq = sumsq + x(k)**2
        end do
        !$omp end parallel do

        ! Combine the partial sums of all ranks, then take the root
        call MPI_Allreduce(MPI_IN_PLACE, sumsq, 1, MPI_FP, &
                           MPI_SUM, comm_pim, mpierr)
        pdnrm2 = sqrt(sumsq)

    end function pdnrm2

    subroutine progress_out(loclen, itno, normres, x, res, trueres)
        !! PIM progress monitor used when debug output is enabled: writes the
        !! iteration number and the absolute and relative (divided by the
        !! module variable rhs_nrm) pseudo-residual to stdout.
        integer :: loclen
        !! Local dimension
        integer :: itno
        !! Iteration number
        real(FP) :: normres
        !! Norm of residual
        real(FP), dimension(*) :: x
        !! Current vector
        real(FP), dimension(*) :: res
        !! Current residual
        real(FP), dimension(*) :: trueres
        !! Current true residual

        ! Standard-conforming format: the former "X'/'" item used a bare X
        ! without repeat count and lacked the comma separator; "1X,'/',1X"
        ! gives the same " / " output.
        character(len=*), parameter :: fmt_progress = &
            '(3X,"#Iteration:",I6,"; absolute / relative pseudo-residual:",' // &
            'ES14.6E3,1X,"/",1X,ES14.6E3)'

        write(get_stdout(), fmt_progress) itno, normres, normres / rhs_nrm

    end subroutine progress_out
    
    subroutine progress_noout(loclen, itno, normres, x, res, trueres)
        !! Silent PIM progress monitor: has the progress_interface signature
        !! and deliberately does nothing (selected when debug output is off).
        integer :: loclen, itno
        !! Local dimension and iteration number
        real(FP) :: normres
        !! Norm of (pseudo)residual
        real(FP), dimension(*) :: x, res, trueres
        !! Current vector, current residual and current true residual

        ! Intentionally empty

    end subroutine progress_noout
    
end submodule