helmholtz_solver_mgmres_cxx_s.f90 Source File


Source Code

module PAccX_m
    !! Module for interfacing with PAccX.
    !! It's unclear whether this should remain here, or whether a Fortran
    !! interface should be provided by the PAccX package itself.
    !! TODO: If the module stays here, then the two parameters should be moved
    !! to the `helmholtz_solver_m` module, and the interfaces should be moved to
    !! the `helmholtz_solver_mgmres_cxx_s` submodule.
    !! NOTE(review): this module declares no named constants of its own; the
    !! BACKEND_* constants are imported from `device_handling_m` and, since
    !! there is no `private` statement, re-exported to users of this module.
    !! Confirm which "two parameters" the TODO above refers to.
    !!
    !! The derived types below are `bind(C)`: their component order and kinds
    !! must match the corresponding C structs on the PAccX side. Do not
    !! reorder or retype components without changing PAccX accordingly.
    use, intrinsic :: iso_c_binding
    use precision_m, only : FP
    use device_handling_m, only : BACKEND_CPU, BACKEND_GPU, BACKEND_ROCALUTION_CPU, BACKEND_ROCALUTION_GPU
    implicit none

    type, bind(C) :: solver_numerical_parameters
        !! Holds solver numerical parameter information
        integer(c_int32_t) :: cycle_type
        !! Type of multigrid cycle to use.
        !! Possible values are 0, 1, 2, 3, 4; these correspond to:
        !! V cycle non recursive, V cycle recursive, V cycle0, W cycle1
        !! and F cycle.
        !! Please see
        !! PAccX/src/csr_solver/multigrid_solver/paccx_multigrid_cycle_type.hxx
        !! for further information
        integer(c_int32_t) :: presmooth
        !! Number of presmooth GSRB passes
        integer(c_int32_t) :: postsmooth
        !! Number of postsmooth GSRB passes
        integer(c_int32_t) :: Krylov_dimension
        !! Krylov dimension
        integer(c_int32_t) :: max_restart
        !! Maximum number of FGMRES restarts
        real(FP)           :: tolerance
        !! Maximum accepted residuum
        !! NOTE(review): for C interoperability FP must be an interoperable
        !! real kind (presumably c_double) -- confirm in precision_m.
    end type solver_numerical_parameters

    type, bind(C) :: helmholtz_matrix_data
        !! Holds information needed to build Helmholtz matrices for 2D problem.
        integer(c_int32_t) :: bnd_type_core
        !! Boundary condition type at the core boundary
        integer(c_int32_t) :: bnd_type_wall
        !! Boundary condition type at the wall boundary
        integer(c_int32_t) :: bnd_type_dome
        !! Boundary condition type at the dome boundary
        integer(c_int32_t) :: bnd_type_out
        !! Boundary condition type at the outer boundary
        integer(c_int32_t) :: backend
        !! Backend identifier (one of the BACKEND_* constants) telling PAccX
        !! where the arrays behind `co`, `xi` and `lambda` reside
        type(c_ptr)        :: co
        !! C address of the `co` coefficient array (may be c_null_ptr)
        type(c_ptr)        :: xi
        !! C address of the `xi` coefficient array (may be c_null_ptr)
        type(c_ptr)        :: lambda
        !! C address of the `lambda` coefficient array (may be c_null_ptr)
    end type helmholtz_matrix_data

    type, bind(C) :: helmholtz_vector_data
        !! Holds right-hand-side of Helmholtz equation and initial guess for
        !! FGMRES algorithm
        integer(c_int32_t) :: backend
        !! Backend identifier telling PAccX where `rhs` and `sol` reside
        type(c_ptr)        :: rhs
        !! C address of the right-hand-side array
        type(c_ptr)        :: sol
        !! C address of the solution array (initial guess on entry)
    end type helmholtz_vector_data

    interface cxx_init
        !! Generic name for `PAccX_CXX_init` (C symbol `paccx_solver_init`).
        integer(c_int32_t) function PAccX_CXX_init( &
                so,                                               &
                backend,                                          &
                snp,                                              &
                mgrid_data,                                       &
                hmd)                                              &
                bind (c, name='paccx_solver_init')
            !! Create the PAccX solver objects in `so`.
            !! Returns a PAccX status code; its meaning is defined by PAccX
            !! and is not visible here.
            use, intrinsic :: iso_c_binding
            use :: precision_m, only : FP
            use :: multigrid_m, only : multigrid_data_t
            use helmholtz_solver_m, only: cxx_solver_objects
            import solver_numerical_parameters
            import helmholtz_matrix_data
            implicit none
            type(cxx_solver_objects), intent(out)                :: so
            !! Handle to the solver objects created on the PAccX side
            integer(c_int32_t), intent(in), value                :: backend
            !! Compute backend (one of the BACKEND_* constants), passed by value
            type(solver_numerical_parameters), intent(in), value :: snp
            !! Numerical parameters, passed by value
            type(multigrid_data_t), intent(in)                   :: mgrid_data
            !! Multigrid data as exposed by the multigrid object
            type(helmholtz_matrix_data), intent(in)              :: hmd
            !! Boundary types and (optionally null) coefficient pointers
        end function PAccX_CXX_init
    end interface

    interface cxx_update
        !! Generic name for `PAccX_CXX_update` (C symbol `paccx_solver_update`).
        integer(c_int32_t) function PAccX_CXX_update( &
                so,                                                 &
                rhd)                                                &
                bind (c, name='paccx_solver_update')
            !! Pass new Helmholtz coefficients / boundary types to PAccX.
            !! Returns a PAccX status code (meaning defined by PAccX).
            use, intrinsic :: iso_c_binding
            use :: precision_m, only : FP
            use helmholtz_solver_m, only: cxx_solver_objects
            import helmholtz_matrix_data
            implicit none
            type(cxx_solver_objects), intent(inout) :: so
            !! Existing PAccX solver objects
            type(helmholtz_matrix_data), intent(in) :: rhd
            !! New boundary types and coefficient array addresses
        end function PAccX_CXX_update
    end interface

    interface cxx_solve
        !! Generic name for `PAccX_CXX_solve` (C symbol `paccx_solver_solve`).
        integer(c_int32_t) function PAccX_CXX_solve( &
                so,                                                &
                hvd,                                               &
                Arnoldi_iterations,                                &
                nrestarts,                                         &
                residuum)                                          &
                bind (c, name='paccx_solver_solve')
            !! Run the FGMRES solve on the vectors referenced by `hvd`.
            !! Returns a PAccX status code (meaning defined by PAccX).
            use, intrinsic :: iso_c_binding
            use :: precision_m, only : FP
            use helmholtz_solver_m, only: cxx_solver_objects
            import helmholtz_vector_data
            implicit none
            type(cxx_solver_objects), intent(in) :: so
            !! Existing PAccX solver objects
            type(helmholtz_vector_data)          :: hvd
            !! Addresses of rhs and solution arrays plus their backend.
            !! NOTE(review): declared without `intent`; presumably the
            !! struct itself is only read -- confirm against PAccX.
            integer(c_int32_t), intent(out)      :: Arnoldi_iterations
            !! Iterations in the last Arnoldi process
            integer(c_int32_t), intent(out)      :: nrestarts
            !! Number of Arnoldi restarts
            real(c_double), intent(out)          :: residuum
            !! Final residuum.
            !! NOTE(review): callers pass real(FP) here, so generic
            !! resolution requires FP == c_double -- confirm.
        end function PAccX_CXX_solve
    end interface

    interface cxx_destroy
        !! Generic name for `PAccX_CXX_destroy` (C symbol `paccx_solver_destroy`).
        integer(c_int32_t) function PAccX_CXX_destroy( &
                so)                                                  &
                bind (c, name='paccx_solver_destroy')
            !! Release the PAccX solver objects held in `so`.
            !! Returns a PAccX status code (meaning defined by PAccX).
            use, intrinsic :: iso_c_binding
            use helmholtz_solver_m, only: cxx_solver_objects
            implicit none
            type(cxx_solver_objects), intent(inout)  :: so
            !! PAccX solver objects to release
        end function PAccX_CXX_destroy
    end interface

    interface cxx_debug_info
        !! Generic name for `PAccX_CXX_debug_info` (C symbol `paccx_solver_debug_info`).
        integer(c_int32_t) function PAccX_CXX_debug_info( &
                so)                                                  &
                bind (c, name='paccx_solver_debug_info')
            !! Ask PAccX for debug information about the solver objects.
            !! Returns a PAccX status code (meaning defined by PAccX).
            use, intrinsic :: iso_c_binding
            use helmholtz_solver_m, only: cxx_solver_objects
            implicit none
            type(cxx_solver_objects), intent(in)  :: so
            !! PAccX solver objects to inspect
        end function PAccX_CXX_debug_info
    end interface

end module

submodule(helmholtz_solver_m) helmholtz_solver_mgmres_cxx_s
    !! Helmholtz solver that uses the PAccX library.
    !! The library uses FGMRES with a multigrid preconditioner.
    !!
    !! All PAccX entry points used by this submodule (`cxx_init`,
    !! `cxx_update`, `cxx_solve`, `cxx_destroy`, `cxx_debug_info`) are
    !! imported once here, so the individual procedures need no
    !! `use PAccX_m` statements of their own.
    use screen_io_m, only : get_stdout
    use multigrid_m, only : multigrid_data_t
    use PAccX_m, only : cxx_init, &
                        cxx_update, &
                        cxx_solve, &
                        cxx_destroy, &
                        cxx_debug_info, &
                        solver_numerical_parameters, &
                        helmholtz_vector_data, &
                        helmholtz_matrix_data, &
                        BACKEND_CPU, &
                        BACKEND_GPU, &
                        BACKEND_ROCALUTION_CPU, &
                        BACKEND_ROCALUTION_GPU

    use, intrinsic :: iso_c_binding
    implicit none

contains

    module subroutine set_backend(self, compute_backend_in, data_backend_in)
        !! Validate and store the compute backend and, optionally, the data
        !! backend. If `data_backend_in` is absent, the data is taken to
        !! reside on the host (`BACKEND_CPU`).
        class(helmholtz_solver_mgmres_cxx_t), intent(inout) :: self
        integer, intent(in) :: compute_backend_in
        !! Which compute device performs the work:
        !! BACKEND_CPU: computation is performed on the host (with
        !!              fgmres_solver_cpu from PAccX)
        !! BACKEND_GPU: computation is performed on a GPU (with
        !!              fgmres_solver_gpu from PAccX)
        integer, optional, intent(in) :: data_backend_in
        !! Where data passed to `init`, `update` and `solve` resides
        !! BACKEND_CPU: Data resides on host.
        !!              If compute backend is different, PAccX
        !!              itself will take care of the transfers.
        !! BACKEND_GPU: Data resides on device.
        !!              Compute backend must also be BACKEND_GPU, otherwise
        !!              segmentation faults will occur when PAccX
        !!              attempts to access the data.

        ! Check and set the compute backend for the kernels
        call is_backend_supported(compute_backend_in)
        self%compute_backend = compute_backend_in

        ! Check and set the data backend
        if (present(data_backend_in)) then
            call is_backend_supported(data_backend_in)

            ! Pure host compute cannot work on data that lives elsewhere.
            ! NOTE(review): only BACKEND_CPU is rejected here; e.g. a compute
            ! backend of BACKEND_ROCALUTION_CPU combined with device-resident
            ! data passes this check -- confirm whether that is intended.
            if (compute_backend_in == BACKEND_CPU .and. &
                data_backend_in /= BACKEND_CPU) then
                call handle_error('compute_backend has to be set to one of &
                                  &the GPU backends if a GPU backend is used &
                                  &for data_backend!', &
                                   PARALLAX_ERR_HELMHOLTZ, __LINE__, __FILE__)
            endif

            self%data_backend = data_backend_in
        else
            self%data_backend = BACKEND_CPU
        endif

    contains

        subroutine is_backend_supported(backend)
            !! Report an error through `handle_error` if `backend` is not one
            !! of the BACKEND_* values known to this solver.
            integer, intent(in) :: backend

            integer, dimension(4) :: supported_backends

            supported_backends = [BACKEND_CPU, BACKEND_GPU, &
                                  BACKEND_ROCALUTION_CPU, BACKEND_ROCALUTION_GPU]

            if (.not. any(supported_backends == backend)) then
                call handle_error('Backend is not supported!', &
                         PARALLAX_ERR_HELMHOLTZ, __LINE__, __FILE__, &
                         additional_info=error_info_t('Backend=',[backend]))
            endif
        end subroutine is_backend_supported

    end subroutine set_backend

    module subroutine create_mgmres_cxx(self, multigrid,              &
                                        bnd_type_core, bnd_type_wall, &
                                        bnd_type_dome, bnd_type_out,  &
                                        co, lambda, xi)
        !! Record problem sizes and boundary types, and obtain the multigrid
        !! data that is later handed to PAccX by `init_mgmres_cxx`.
        !! The backend is reset to BACKEND_CPU (compute and data).
        !! `co`, `lambda` and `xi` are not used here; the coefficient values
        !! are passed to PAccX in `update_mgmres_cxx`.
        class(helmholtz_solver_mgmres_cxx_t), intent(inout) :: self
        type(multigrid_t), intent(inout) :: multigrid
        integer, intent(in) :: bnd_type_core
        integer, intent(in) :: bnd_type_wall
        integer, intent(in) :: bnd_type_dome
        integer, intent(in) :: bnd_type_out
        real(FP), dimension(multigrid%get_np(1)), &
            intent(in) :: co
        real(FP), dimension(multigrid%get_np_inner(1)), &
            intent(in) :: lambda
        real(FP), dimension(multigrid%get_np_inner(1)), &
            intent(in) :: xi

        ! Set backend to CPU by default
        call self%set_backend(BACKEND_CPU)

        ! Sizes on the finest multigrid level
        self%ndim = multigrid%get_np(1)
        self%np_inner = multigrid%get_np_inner(1)

        self%bnd_val_core = bnd_type_core
        self%bnd_val_wall = bnd_type_wall
        self%bnd_val_dome = bnd_type_dome
        self%bnd_val_out  = bnd_type_out

        call multigrid%expose_data( &
            self%multigrid_intermediate_object, self%multigrid_data_object)

    end subroutine create_mgmres_cxx

    module subroutine update_mgmres_cxx(self, co, lambda, xi, &
                                        bnd_type_core, bnd_type_wall, &
                                        bnd_type_dome, bnd_type_out)
        !! Hand new Helmholtz coefficients (and optionally new boundary
        !! types) to PAccX via `cxx_update`.
        !! The boundary types are only replaced when ALL four optional
        !! arguments are present; a partial set is ignored and the stored
        !! values are used.
        !! Only the C addresses of `co`, `xi` and `lambda` are passed;
        !! they are tagged with `self%data_backend`.
        class(helmholtz_solver_mgmres_cxx_t), intent(inout) :: self
        real(FP), dimension(self%ndim),       intent(in), target :: co
        real(FP), dimension(self%np_inner),   intent(in), target :: lambda
        real(FP), dimension(self%np_inner),   intent(in), target :: xi
        integer, intent(in), optional :: bnd_type_core
        integer, intent(in), optional :: bnd_type_wall
        integer, intent(in), optional :: bnd_type_dome
        integer, intent(in), optional :: bnd_type_out

        logical :: update_with_bnds

        integer(c_int32_t)           :: ccall_result
        type(helmholtz_matrix_data)  :: hmd_tmp

        update_with_bnds = present(bnd_type_core) &
                            .and. present(bnd_type_wall) &
                            .and. present(bnd_type_dome) &
                            .and. present(bnd_type_out)
        if (update_with_bnds) then
            self%bnd_val_core = bnd_type_core
            self%bnd_val_wall = bnd_type_wall
            self%bnd_val_dome = bnd_type_dome
            self%bnd_val_out  = bnd_type_out
        endif

        hmd_tmp%bnd_type_core = self%bnd_val_core
        hmd_tmp%bnd_type_wall = self%bnd_val_wall
        hmd_tmp%bnd_type_dome = self%bnd_val_dome
        hmd_tmp%bnd_type_out  = self%bnd_val_out

        hmd_tmp%backend = self%data_backend

        hmd_tmp%co     = c_loc(co)
        hmd_tmp%xi     = c_loc(xi)
        hmd_tmp%lambda = c_loc(lambda)

        if (self%dbgout > 0) then
            ccall_result = cxx_debug_info(self%sobjects)
        endif
        ! NOTE(review): the PAccX status code is only reported in debug mode;
        ! its success/failure convention is not visible here.
        ccall_result = cxx_update(self%sobjects, hmd_tmp)
        if (self%dbgout > 0) then
            write(get_stdout(),*) "cxx_update result is", ccall_result
            ccall_result = cxx_debug_info(self%sobjects)
        endif

    end subroutine update_mgmres_cxx

    module subroutine init_mgmres_cxx(self, rtol, restol_zero,      &
                                      backend_to_use,               &
                                      nrestart, niter_max,          &
                                      mgr_ncycle,                   &
                                      mgr_smoother, mgr_npresmooth, &
                                      mgr_npostsmooth,              &
                                      dbgout)
        !! Store the solver parameters and create the PAccX solver objects.
        !! `backend_to_use` selects the compute backend; the data backend is
        !! reset to BACKEND_CPU. An out-of-range `mgr_ncycle` (outside 0..4)
        !! falls back to cycle type 0. `mgr_smoother` is not used by this
        !! implementation, and `restol_zero` is only stored.
        !! No coefficient data is passed at this stage (null pointers); see
        !! `update_mgmres_cxx`.
        class(helmholtz_solver_mgmres_cxx_t), intent(inout) :: self
        real(FP), intent(in) :: rtol
        real(FP), intent(in) :: restol_zero
        integer, intent(in) :: backend_to_use
        integer, intent(in) :: nrestart
        integer, intent(in) :: niter_max
        integer, intent(in) :: mgr_ncycle
        class(splitting_t), intent(inout), allocatable :: mgr_smoother
        integer, intent(in) :: mgr_npresmooth
        integer, intent(in) :: mgr_npostsmooth
        integer, intent(in) :: dbgout

        integer(c_int32_t)                :: ccall_result
        type(solver_numerical_parameters) :: ccall_snp
        type(helmholtz_matrix_data)       :: hmd_tmp

        ! Set parameters
        self%rtol             = rtol
        self%restol_zero      = restol_zero
        self%nrestart         = nrestart
        self%niter_max        = niter_max
        self%dbgout           = dbgout
        self%mgr_npresmooth   = mgr_npresmooth
        self%mgr_npostsmooth  = mgr_npostsmooth

        call self%set_backend(backend_to_use)

        if ((mgr_ncycle >= 0) .and. (mgr_ncycle <= 4)) then
            ccall_snp%cycle_type = mgr_ncycle
        else
            ccall_snp%cycle_type = 0 !! nonrecursive V cycle
        endif
        ccall_snp%presmooth        = self%mgr_npresmooth
        ccall_snp%postsmooth       = self%mgr_npostsmooth
        ! niter_max is used as the Krylov dimension (Arnoldi steps per restart)
        ccall_snp%Krylov_dimension = self%niter_max
        ccall_snp%max_restart      = self%nrestart
        ccall_snp%tolerance        = self%rtol

        hmd_tmp%bnd_type_core = self%bnd_val_core
        hmd_tmp%bnd_type_wall = self%bnd_val_wall
        hmd_tmp%bnd_type_dome = self%bnd_val_dome
        hmd_tmp%bnd_type_out  = self%bnd_val_out

        hmd_tmp%backend = BACKEND_CPU !! data not used anyway
        hmd_tmp%co     = c_null_ptr
        hmd_tmp%xi     = c_null_ptr
        hmd_tmp%lambda = c_null_ptr

        ccall_result = cxx_init(        &
            self%sobjects,              &
            self%compute_backend,       &
            ccall_snp,                  &
            self%multigrid_data_object, &
            hmd_tmp)
        if (self%dbgout > 0) then
            write(get_stdout(),*) "cxx_init result is", ccall_result
        endif
    end subroutine init_mgmres_cxx

    module subroutine solve_mgmres_cxx(self, rhs, sol, res, info)
        !! Solve the Helmholtz system with PAccX.
        !! `sol` holds the initial guess on entry and the solution on exit,
        !! `res` receives the final residuum reported by PAccX and `info`
        !! the total number of Arnoldi iterations
        !! (nrestarts * niter_max + iterations of the last Arnoldi process).
        use perf_m, only : perf_start, perf_stop

        class(helmholtz_solver_mgmres_cxx_t), intent(inout) :: self
        real(FP), dimension(self%ndim), intent(inout), target :: rhs
        real(FP), dimension(self%ndim), intent(inout), target :: sol
        real(FP), intent(out) :: res
        integer, intent(out) :: info

        integer(c_int32_t) :: ccall_result
        integer(c_int32_t) :: ccall_Arnoldi_iterations
        integer(c_int32_t) :: ccall_nrestarts
        type(helmholtz_vector_data) :: ccall_hvd

        ccall_hvd%rhs = c_loc(rhs)
        ccall_hvd%sol = c_loc(sol)
        ccall_hvd%backend = self%data_backend

        call perf_start('../../libcsr_solve')
        ! NOTE(review): ccall_result is not inspected; the PAccX status
        ! convention is not visible here.
        ccall_result = cxx_solve(     &
            self%sobjects,            &
            ccall_hvd,                &
            ccall_Arnoldi_iterations, &
            ccall_nrestarts,          &
            res)
        call perf_stop('../../libcsr_solve')

        !! PAccX returns "nrestarts" and "Arnoldi iterations"
        !! Please see PAccX documentation for exact details
        !! (abstract_fgmres_solver::apply_abstract).
        !! In brief:
        !! * "nrestarts" refers to the number of times that the Arnoldi process
        !!   is restarted.
        !! * "Arnoldi iterations" is the number of iterations within the
        !!   last-performed Arnoldi process (always smaller than the Krylov dimension).
        info = ccall_nrestarts*self%niter_max + ccall_Arnoldi_iterations

        if (self%dbgout > 0) then
            write(get_stdout(), *) "Information on this call to `solve_mgmres_cxx`:"
            write(get_stdout(), *) "Arnoldi_iterations = ", ccall_Arnoldi_iterations
            write(get_stdout(), *) "nrestarts = ", ccall_nrestarts
            write(get_stdout(), *) "residuum = ", res
        endif

    end subroutine solve_mgmres_cxx

    module subroutine destructor_mgmres_cxx(self)
        !! Release the PAccX solver objects and zero the stored sizes and
        !! scalar solver settings.
        type(helmholtz_solver_mgmres_cxx_t), intent(inout) :: self
        integer(c_int32_t) :: ccall_result

        ccall_result = cxx_destroy(self%sobjects)
        ! dbgout is consulted before it is reset below.
        if (self%dbgout > 0) then
            write(get_stdout(),*) "cxx_destroy result is", ccall_result
        endif

        ! Reset both problem sizes set in create_mgmres_cxx.
        self%ndim = 0
        self%np_inner = 0
        self%rtol = 0.0_FP
        self%restol_zero = 0.0_FP
        self%nrestart = 0
        self%niter_max = 0
        self%dbgout = 0
    end subroutine destructor_mgmres_cxx

end submodule helmholtz_solver_mgmres_cxx_s