helmholtz_solver_factory_m.f90 Source File


Source Code

module helmholtz_solver_factory_m
    !! Parameters and factory routine to create Helmholtz solver
    use precision_m, only : FP, FP_EPS
    use screen_io_m, only : get_stdout, get_stderr
    use comm_handling_m, only: is_master

    use multigrid_m,            only : multigrid_t
    use direct_solver_m,        only : direct_solver_t, direct_solver_MKL_t
    use splitting_m,            only : splitting_t, splitting_jacobi_cpu_t, &
                                       splitting_gauss_seidel_cpu_t, &
                                       splitting_gauss_seidel_redblack_cpu_t
    use helmholtz_solver_m,     only : helmholtz_solver_t, &
                                       helmholtz_solver_direct_t, &
                                       helmholtz_solver_mgmres_cpu_t
#ifdef ENABLE_PACCX
    use helmholtz_solver_m,     only : helmholtz_solver_mgmres_cxx_t
    use device_handling_m,      only : BACKEND_CPU, &
                                       BACKEND_GPU, &
                                       BACKEND_ROCALUTION_CPU, &
                                       BACKEND_ROCALUTION_GPU

#endif
#ifdef ENABLE_PETSC
    use helmholtz_solver_m,     only : helmholtz_solver_petsc_t
#endif
    implicit none

    public :: helmholtz_solver_factory
    !! Factory routine for 2D Helmholtz solver
    private :: factory_allocatable, factory_direct, factory_mgmres
#ifdef ENABLE_PACCX
    private :: factory_mgmres_cxx
#endif
#ifdef ENABLE_PETSC
    ! Keep the specific PETSc factory private like all other specifics;
    ! it is reachable through the generic helmholtz_solver_factory.
    private :: factory_petsc
#endif

    interface helmholtz_solver_factory
        ! The specific routine is selected from the actual arguments:
        ! a string solver_type plus allocatable class(helmholtz_solver_t)
        ! selects factory_allocatable, otherwise the concrete hsolver type
        ! (and the extra backend / pc_sel arguments) decide.
        procedure :: factory_allocatable
        procedure :: factory_direct
        procedure :: factory_mgmres
#ifdef ENABLE_PACCX
        procedure :: factory_mgmres_cxx
#endif
#ifdef ENABLE_PETSC
        procedure :: factory_petsc
#endif
    end interface helmholtz_solver_factory

    private :: get_outi, allocate_direct_solver, allocate_smoother

    type, public :: parameters_helmholtz_solver_factory
        !! Parameters for factory routine, with defaults
        character(len=16), public :: dirsolver_type     = 'MKL'
        !! Library used for direct solver
        !! (MKL; PAccX-based solvers additionally accept SLU and CUSPARSE)
        character(len=16), public :: smoother_type      = 'GSRB'
        !! Algorithm for smoother within multigrid (JAC, GS, GSRB)
        real(FP), public :: rtol                        = 1.0E-8_FP
        !! Relative tolerance on residuum
        !! res := ||Ax - b|| / (||b|| + restol_zero)  <  rtol
        real(FP), public :: restol_zero                 = FP_EPS
        !! Tolerance for zero of right hand side in residuum computation
        integer, public :: gmres_maxiter                = 15
        !! Maximum number of GMRES iterations
        integer, public :: gmres_nrestart               = 15
        !! Number of iterations after which restart of GMRES is done
        character(len=1), public :: mgrid_cycletype     = 'V'
        !! Multigrid cycle type: V or W cycle
        !! (PAccX-based solvers additionally accept F, 0 and 3)
        integer, public :: mgrid_npresmooth             = 5
        !! Number of pre-smoothing steps
        integer, public :: mgrid_npostsmooth            = 5
        !! Number of post-smoothing steps
        integer, public :: dbgout                       = 0
        !! Debug output level
    contains
        procedure, public :: display => &
                             display_parameters_helmholtz_solver_factory
    end type parameters_helmholtz_solver_factory

contains

    function get_outi(dbgout) result(outi)
        !! Set verbosity level of factory routines.
        !! Returns dbgout on the master rank and 0 on all other ranks,
        !! except for dbgout >= 3, where every rank returns dbgout.
        integer, intent(in) :: dbgout
        !! Requested debug output level
        integer :: outi
        !! Effective verbosity level on the calling rank

        outi = 0
        if (is_master()) then
            outi = dbgout
        endif
        if (dbgout >= 3) then
            ! Every rank writes
            outi = dbgout
        endif
    end function get_outi

    subroutine allocate_direct_solver(dirsolver_type, outi, dirsolver)
        !! Allocate the direct solver selected by dirsolver_type.
        !! Only 'MKL' is supported; any other value stops with an error.
        character(len=*), intent(in) :: dirsolver_type
        !! Selection string (par%dirsolver_type)
        integer, intent(in) :: outi
        !! Verbosity level as returned by get_outi
        class(direct_solver_t), allocatable, intent(out) :: dirsolver
        !! Direct solver allocated with the selected dynamic type

        select case(dirsolver_type)
        case('MKL')
            if (outi >= 1) then
                write(get_stdout(), '(5X, A)') &
                    'MKL used as direct solver library'
            endif
            allocate(direct_solver_MKL_t :: dirsolver)
        case default
            write(get_stderr(), *) &
                'error(helmholtz_solver_factory): &
                &selected solver not valid', dirsolver_type
            error stop
        end select
    end subroutine allocate_direct_solver

    subroutine allocate_smoother(smoother_type, outi, caller, smoother)
        !! Allocate the multigrid smoother selected by smoother_type
        !! (JAC, GS or GSRB); any other value stops with an error.
        character(len=*), intent(in) :: smoother_type
        !! Selection string (par%smoother_type)
        integer, intent(in) :: outi
        !! Verbosity level as returned by get_outi
        character(len=*), intent(in) :: caller
        !! Routine name used as prefix of the error message
        class(splitting_t), allocatable, intent(out) :: smoother
        !! Smoother allocated with the selected dynamic type

        select case(smoother_type)
        case('JAC')
            if (outi >= 1) then
                write(get_stdout(), '(5X, A)') 'Jacobi relaxation smoother'
            endif
            allocate(splitting_jacobi_cpu_t :: smoother)
        case('GS')
            if (outi >= 1) then
                write(get_stdout(), '(5X, A)') 'Gauss-Seidel smoother'
            endif
            allocate(splitting_gauss_seidel_cpu_t :: smoother)
        case('GSRB')
            if (outi >= 1) then
                write(get_stdout(), '(5X, A)') 'Gauss-Seidel red-black smoother'
            endif
            allocate(splitting_gauss_seidel_redblack_cpu_t :: smoother)
        case default
            write(get_stderr(), *) &
                'error('//caller//'): selected smoother not valid', &
                smoother_type
            error stop
        end select
    end subroutine allocate_smoother

    subroutine factory_allocatable(multigrid,                    &
                                   bnd_type_core, bnd_type_wall, &
                                   bnd_type_dome, bnd_type_out,  &
                                   co, lambda, xi,               &
                                   par, solver_type, hsolver     )
        !! Factory routine for 2D Helmholtz solver, including type allocation
        !! of the hsolver specified by string solver_type
        type(multigrid_t), intent(inout) :: multigrid
        !! Multigrid on which to solve the Helmholtz problem (not changed)
        integer, intent(in) :: bnd_type_core
        !! Boundary descriptor for core boundary
        integer, intent(in) :: bnd_type_wall
        !! Boundary descriptor for wall boundary
        integer, intent(in) :: bnd_type_dome
        !! Boundary descriptor for dome boundary
        integer, intent(in) :: bnd_type_out
        !! Boundary descriptor for outer(mask) boundary
        real(FP), dimension(multigrid%get_np(1)), intent(in) :: co
        !! Coefficient within Helmholtz operator
        real(FP), dimension(multigrid%get_np_inner(1)), intent(in) :: lambda
        !! Lambda within Helmholtz operator
        real(FP), dimension(multigrid%get_np_inner(1)), intent(in) :: xi
        !! Xi within Helmholtz operator
        type(parameters_helmholtz_solver_factory), intent(in) :: par
        !! Parameters and selection of Helmholtz solver
        character(len=*), intent(in) :: solver_type
        !! Desired type of Helmholtz solver: DIRECT, MGMRES;
        !! with ENABLE_PACCX also MGMRES_CXX, MGMRES_GPU, ROCALUTION_GPU;
        !! with ENABLE_PETSC also PETSC_PCMG, PETSC_PCRC
        class(helmholtz_solver_t), allocatable, intent(out) :: hsolver
        !! Created Helmholtz solver

        integer :: outi

        outi = get_outi(par%dbgout)

        if (outi >= 1) then
            write(get_stdout(), *) ''
            write(get_stdout(), '(A80)') &
                ' Factory for Helmholtz solver '//repeat('-', 80)
            call par%display()
        endif

        select case(solver_type)
        case('DIRECT')
            allocate(helmholtz_solver_direct_t :: hsolver)

            select type(hsolver)
            type is(helmholtz_solver_direct_t)
                call factory_direct(multigrid,                    &
                                    bnd_type_core, bnd_type_wall, &
                                    bnd_type_dome, bnd_type_out,  &
                                    co, lambda, xi, par, hsolver  )
            end select

        case('MGMRES')
            allocate(helmholtz_solver_mgmres_cpu_t :: hsolver)

            select type(hsolver)
            type is(helmholtz_solver_mgmres_cpu_t)
                call factory_mgmres(multigrid,                    &
                                    bnd_type_core, bnd_type_wall, &
                                    bnd_type_dome, bnd_type_out,  &
                                    co, lambda, xi, par, hsolver  )
            end select

#ifdef ENABLE_PACCX
        case('MGMRES_CXX')
            allocate(helmholtz_solver_mgmres_cxx_t :: hsolver)

            select type(hsolver)
            type is(helmholtz_solver_mgmres_cxx_t)
                call factory_mgmres_cxx(multigrid,                    &
                                        BACKEND_CPU,                  &
                                        bnd_type_core, bnd_type_wall, &
                                        bnd_type_dome, bnd_type_out,  &
                                        co, lambda, xi, par, hsolver  )
            end select
        case('MGMRES_GPU')
            allocate(helmholtz_solver_mgmres_cxx_t :: hsolver)

            select type(hsolver)
            type is(helmholtz_solver_mgmres_cxx_t)
                call factory_mgmres_cxx(multigrid,                    &
                                        BACKEND_GPU,                  &
                                        bnd_type_core, bnd_type_wall, &
                                        bnd_type_dome, bnd_type_out,  &
                                        co, lambda, xi, par, hsolver  )
            end select
        case('ROCALUTION_GPU')
            allocate(helmholtz_solver_mgmres_cxx_t :: hsolver)

            select type(hsolver)
            type is(helmholtz_solver_mgmres_cxx_t)
                call factory_mgmres_cxx(multigrid,                    &
                                        BACKEND_ROCALUTION_GPU,       &
                                        bnd_type_core, bnd_type_wall, &
                                        bnd_type_dome, bnd_type_out,  &
                                        co, lambda, xi, par, hsolver  )
            end select

#endif

#ifdef ENABLE_PETSC
        case('PETSC_PCMG')
            allocate(helmholtz_solver_petsc_t :: hsolver)

            select type(hsolver)
            type is(helmholtz_solver_petsc_t)
                call factory_petsc(multigrid,                    &
                                   bnd_type_core, bnd_type_wall, &
                                   bnd_type_dome, bnd_type_out,  &
                                   co, lambda, xi, par, hsolver, &
                                   'PCMG')
            end select
        case('PETSC_PCRC')
            allocate(helmholtz_solver_petsc_t :: hsolver)

            select type(hsolver)
            type is(helmholtz_solver_petsc_t)
                call factory_petsc(multigrid,                    &
                                   bnd_type_core, bnd_type_wall, &
                                   bnd_type_dome, bnd_type_out,  &
                                   co, lambda, xi, par, hsolver, &
                                   'PCRC')
            end select
#endif
        case default
            write(get_stderr(), *) &
                'error(helmholtz_solver_factory): &
                &selected solver type not valid', solver_type
            error stop
        end select

    end subroutine factory_allocatable

    subroutine factory_direct(multigrid,                    &
                              bnd_type_core, bnd_type_wall, &
                              bnd_type_dome, bnd_type_out,  &
                              co, lambda, xi, par, hsolver  )
        !! Factory routine for 2D Helmholtz solver, type DIRECT. Only the
        !! finest mesh of the multigrid will be used.
        type(multigrid_t), intent(inout) :: multigrid
        integer, intent(in) :: bnd_type_core
        integer, intent(in) :: bnd_type_wall
        integer, intent(in) :: bnd_type_dome
        integer, intent(in) :: bnd_type_out
        real(FP), dimension(multigrid%get_np(1)), intent(in) :: co
        real(FP), dimension(multigrid%get_np_inner(1)), intent(in) :: lambda
        real(FP), dimension(multigrid%get_np_inner(1)), intent(in) :: xi
        type(parameters_helmholtz_solver_factory), intent(in) :: par
        type(helmholtz_solver_direct_t), intent(inout) :: hsolver

        class(direct_solver_t), allocatable :: dirsolver
        integer :: outi

        outi = get_outi(par%dbgout)

        if (outi >= 1) then
            write(get_stdout(), *) 'Creating DIRECT solver'
        endif

        call allocate_direct_solver(par%dirsolver_type, outi, dirsolver)

        call hsolver%create(multigrid, bnd_type_core, bnd_type_wall, &
                            bnd_type_dome, bnd_type_out, co, lambda, xi)
        ! NOTE: Init deallocates dirsolver
        call hsolver%init(dirsolver)

        if (outi >= 1) then
            write(get_stdout(), '(A80)') repeat('-', 80)
            write(get_stdout(), *) ''
        endif
    end subroutine factory_direct

    subroutine factory_mgmres(multigrid,                    &
                              bnd_type_core, bnd_type_wall, &
                              bnd_type_dome, bnd_type_out,  &
                              co, lambda, xi, par, hsolver  )
        !! Factory routine for 2D Helmholtz solver, type MGMRES
        type(multigrid_t), intent(inout) :: multigrid
        integer, intent(in) :: bnd_type_core
        integer, intent(in) :: bnd_type_wall
        integer, intent(in) :: bnd_type_dome
        integer, intent(in) :: bnd_type_out
        real(FP), dimension(multigrid%get_np(1)), intent(in) :: co
        real(FP), dimension(multigrid%get_np_inner(1)), intent(in) :: lambda
        real(FP), dimension(multigrid%get_np_inner(1)), intent(in) :: xi
        type(parameters_helmholtz_solver_factory), intent(in) :: par
        type(helmholtz_solver_mgmres_cpu_t), intent(inout) :: hsolver

        class(direct_solver_t), allocatable :: dirsolver
        class(splitting_t), allocatable :: smoother
        integer :: icycletype, outi

        outi = get_outi(par%dbgout)

        if (outi >= 1) then
            write(get_stdout(), *) 'Creating MGMRES solver'
        endif

        call allocate_direct_solver(par%dirsolver_type, outi, dirsolver)
        call allocate_smoother(par%smoother_type, outi, &
                               'helmholtz_solver_factory', smoother)

        ! Map cycle type character to the integer code expected by init
        select case(par%mgrid_cycletype)
        case('V')
            if (outi >= 1) then
                write(get_stdout(), '(5X, A)') 'V-cycle'
            endif
            icycletype = 1
        case('W')
            if (outi >= 1) then
                write(get_stdout(), '(5X, A)') 'W-cycle'
            endif
            icycletype = 2
        case default
            write(get_stderr(), *) &
                'error(helmholtz_solver_factory): &
                &selected cycletype not valid', par%mgrid_cycletype
            error stop
        end select

        call hsolver%create(multigrid, bnd_type_core, bnd_type_wall, &
                            bnd_type_dome, bnd_type_out, co, lambda, xi)
        ! NOTE: Init deallocates dirsolver and smoother
        call hsolver%init(par%rtol, par%restol_zero, &
                          par%gmres_nrestart,        &
                          par%gmres_maxiter,         &
                          icycletype,                &
                          smoother,                  &
                          par%mgrid_npresmooth,      &
                          par%mgrid_npostsmooth,     &
                          dirsolver,                 &
                          par%dbgout                 )

        if (outi >= 1) then
            write(get_stdout(), '(A80)') repeat('-', 80)
            write(get_stdout(), *) ''
        endif
    end subroutine factory_mgmres


    subroutine display_parameters_helmholtz_solver_factory(self)
        !! Displays parameters (master rank only)
        class(parameters_helmholtz_solver_factory), intent(in) :: self

        if (.not.is_master()) then
            return
        endif
        write(get_stdout(), *) ''
        write(get_stdout(), *) 'parameters_helmholtz_solver_factory:'
        write(get_stdout(), 206) &
                         self%dirsolver_type,              &
                         self%smoother_type,               &
                         self%rtol,                        &
                         self%restol_zero,                 &
                         self%gmres_maxiter,               &
                         self%gmres_nrestart,              &
                         self%mgrid_cycletype,             &
                         self%mgrid_npresmooth,            &
                         self%mgrid_npostsmooth,           &
                         self%dbgout

 206        FORMAT(3X,'dirsolver_type       = ',A16         /, &
                   3X,'smoother_type        = ',A16         /, &
                   3X,'rtol                 = ',ES14.6E3    /, &
                   3X,'restol_zero          = ',ES14.6E3    /, &
                   3X,'gmres_maxiter        = ',I8          /, &
                   3X,'gmres_nrestart       = ',I8          /, &
                   3X,'mgrid_cycletype      = ',A1          /, &
                   3X,'mgrid_npresmooth     = ',I8          /, &
                   3X,'mgrid_npostsmooth    = ',I8          /, &
                   3X,'dbgout               = ',I8             )

        write(get_stdout(), *) ''

    end subroutine display_parameters_helmholtz_solver_factory

#ifdef ENABLE_PETSC
    subroutine factory_petsc(multigrid, &
                             bnd_type_core, bnd_type_wall, &
                             bnd_type_dome, bnd_type_out,  &
                             co, lambda, xi, par, hsolver, pc_sel)
        !! Factory routine for 2D Helmholtz solver, type PETSC. Only the
        !! finest mesh of the multigrid will be used.
        type(multigrid_t), intent(inout) :: multigrid
        integer, intent(in) :: bnd_type_core
        integer, intent(in) :: bnd_type_wall
        integer, intent(in) :: bnd_type_dome
        integer, intent(in) :: bnd_type_out
        real(FP), dimension(multigrid%get_np(1)), intent(in) :: co
        real(FP), dimension(multigrid%get_np_inner(1)), intent(in) :: lambda
        real(FP), dimension(multigrid%get_np_inner(1)), intent(in) :: xi
        type(parameters_helmholtz_solver_factory), intent(in) :: par
        type(helmholtz_solver_petsc_t), intent(inout) :: hsolver
        character(len=*), intent(in) :: pc_sel
        !! Type of Preconditioner
        !! PCMG: Multigrid with operators from parallax-multigrid
        !! PCRC: As specified in file petscrc
        integer :: outi

        outi = get_outi(par%dbgout)

        if (outi >= 1) then
            write(get_stdout(), *) 'Creating PETSc solver'
        endif

        call hsolver%create(multigrid, bnd_type_core, bnd_type_wall, &
                            bnd_type_dome, bnd_type_out, co, lambda, xi)

        call hsolver%init(par%rtol, par%restol_zero, par%gmres_maxiter, pc_sel)

        if (outi >= 1) then
            write(get_stdout(), '(A80)') repeat('-', 80)
            write(get_stdout(), *) ''
        endif

    end subroutine factory_petsc
#endif

#ifdef ENABLE_PACCX
    subroutine factory_mgmres_cxx(multigrid,                    &
                                  compute_backend_to_use,       &
                                  bnd_type_core, bnd_type_wall, &
                                  bnd_type_dome, bnd_type_out,  &
                                  co, lambda, xi, par, hsolver, &
                                  data_backend_to_use)
        !! Factory routine for 2D Helmholtz solver, type MGMRES (PAccX)
        type(multigrid_t), intent(inout) :: multigrid
        integer, intent(in) :: compute_backend_to_use
        !! Backend passed to hsolver%init (e.g. BACKEND_CPU, BACKEND_GPU)
        integer, intent(in) :: bnd_type_core
        integer, intent(in) :: bnd_type_wall
        integer, intent(in) :: bnd_type_dome
        integer, intent(in) :: bnd_type_out
        real(FP), dimension(multigrid%get_np(1)), intent(in) :: co
        real(FP), dimension(multigrid%get_np_inner(1)), intent(in) :: lambda
        real(FP), dimension(multigrid%get_np_inner(1)), intent(in) :: xi
        type(parameters_helmholtz_solver_factory), intent(in) :: par
        type(helmholtz_solver_mgmres_cxx_t), intent(inout) :: hsolver
        integer, optional, intent(in) :: data_backend_to_use
        !! If present, hsolver%set_backend is called after init with the
        !! compute and this data backend

        class(splitting_t), allocatable :: smoother
        integer :: icycletype, outi

        outi = get_outi(par%dbgout)

        if (outi >= 1) then
            write(get_stdout(), *) &
                'Creating mgmres_cxx solver with backend ', &
                compute_backend_to_use
        endif

        ! The direct solver selection is only validated and reported here;
        ! it is not forwarded to hsolver%init below.
        select case(par%dirsolver_type)
        case('MKL')
            if (outi >= 1) then
                write(get_stdout(), '(5X, A)') &
                    'MKL used as direct solver library'
            endif
        case('SLU')
            if (outi >= 1) then
                write(get_stdout(), '(5X, A)') &
                    'SuperLU used as direct solver library'
            endif
        case('CUSPARSE')
            if (outi >= 1) then
                write(get_stdout(), '(5X, A)') &
                    'CUSPARSE used as direct solver library'
            endif
        case default
            write(get_stderr(), *) &
                'error(helmholtz_solver_factory): &
                &selected solver not valid', par%dirsolver_type
            error stop
        end select

        call allocate_smoother(par%smoother_type, outi, &
                               'factory_mgmres_cxx', smoother)

        ! PAccX offers more cycle types.
        ! Please see paccx/src/multigrid_solver/paccx_multigrid_cycle_type.hxx
        select case(par%mgrid_cycletype)
        case('V')
            if (outi >= 1) then
                write(get_stdout(), '(5X, A)') 'V-cycle'
            endif
            icycletype = 1 ! PACCX_MULTIGRID_VCYCLE_RECURSIVE
        case('W')
            if (outi >= 1) then
                write(get_stdout(), '(5X, A)') 'W-cycle'
            endif
            ! "2" corresponds to "PACCX_MULTIGRID_WCYCLE0",
            ! which numerically should agree with the
            ! fortran implementation.
            icycletype = 2 ! PACCX_MULTIGRID_WCYCLE0
        case('F')
            if (outi >= 1) then
                write(get_stdout(), '(5X, A)') 'F-cycle'
            endif
            icycletype = 4 ! PACCX_MULTIGRID_FCYCLE
        case('0')
            if (outi >= 1) then
                write(get_stdout(), '(5X, A)') 'V-cycle nonrecursive'
            endif
            icycletype = 0 ! PACCX_MULTIGRID_VCYCLE_NONRECURSIVE
        case('3')
            if (outi >= 1) then
                write(get_stdout(), '(5X, A)') 'W-cycle alternate'
            endif
            icycletype = 3 ! PACCX_MULTIGRID_WCYCLE1
        case default
            write(get_stderr(), *) &
                'error(factory_mgmres_cxx): &
                &selected cycletype not valid', par%mgrid_cycletype
            error stop
        end select

        ! Device must be chosen BEFORE allocating solver
        call hsolver%create(multigrid, bnd_type_core, bnd_type_wall, &
                            bnd_type_dome, bnd_type_out, co, lambda, xi)
        ! NOTE: Init deallocates smoother
        call hsolver%init(par%rtol, par%restol_zero, &
                          compute_backend_to_use,    &
                          par%gmres_nrestart,        &
                          par%gmres_maxiter,         &
                          icycletype,                &
                          smoother,                  &
                          par%mgrid_npresmooth,      &
                          par%mgrid_npostsmooth,     &
                          par%dbgout                 )
        if (present(data_backend_to_use)) then
            call hsolver%set_backend(compute_backend_to_use, &
                                     data_backend_to_use)
        endif

        if (outi >= 1) then
            write(get_stdout(), '(A80)') repeat('-', 80)
            write(get_stdout(), *) ''
        endif
    end subroutine factory_mgmres_cxx
#endif

end module helmholtz_solver_factory_m