device_handling_m.f90 Source File


Source Code

module device_handling_m
    !! Functionality to choose accelerator device (GPU)
    !! When PARALLAX is compiled stand-alone, `get_device_count` always returns
    !! 1 (i.e. the CPU), and `set_device_id` is a sanity check confirming that
    !! it's called with backend `BACKEND_CPU` and id 0.
    !!
    !! When PARALLAX is linked to the PAccX library, these
    !! subroutines are simple wrappers around
    !! `paccx_solver_get_device_count`,
    !! `paccx_solver_set_device` and `paccx_solver_sync_device`.
    !! In turn, PAccX may be linked to different backends.
    !! As of 2025-03-14 a CXX (i.e. CPU) backend is always present, and an
    !! optional GPU backend is also available (CUDA or HIP, chosen at CMake
    !! configure time).
    !! **Note**: these subroutines control the global device settings of the
    !! calling process/thread.
    !! If your code links against PARALLAX and another accelerator-capable
    !! library, please test thoroughly and ensure the code behaves as desired.
    use MPI, only : MPI_comm_rank
    use comm_handling_m, only: get_communicator
    use screen_io_m, only : get_stderr, get_stdout
    use status_codes_m, only : PARALLAX_ERR_CCALL, PARALLAX_ERR_PARAMETERS
    use error_handling_m, only: handle_error
    use, intrinsic :: iso_c_binding
    implicit none

    enum, bind(C)
        !! Enumerator defining the backends
        !! These values MUST agree with the corresponding enum in
        !! `paccx.hxx` from PAccX
        enumerator :: BACKEND_CPU = 0
        enumerator :: BACKEND_GPU = 1
        enumerator :: BACKEND_ROCALUTION_CPU = 2
        enumerator :: BACKEND_ROCALUTION_GPU = 3
    end enum

#ifdef ENABLE_PACCX
    interface cxx_get_device_count
        integer(c_int32_t) function PAccX_CXX_get_device_count( &
                backend,                                                      &
                count)                                                        &
                bind (c, name='paccx_solver_get_device_count')
            use, intrinsic :: iso_c_binding
            implicit none
            integer(c_int32_t), intent(in)  :: backend
            integer(c_int32_t), intent(out) :: count
        end function PAccX_CXX_get_device_count
    end interface

    interface cxx_set_device
        integer(c_int32_t) function PAccX_CXX_set_device( &
                backend,                                                &
                device_id)                                              &
                bind (c, name='paccx_solver_set_device')
            use, intrinsic :: iso_c_binding
            implicit none
            integer(c_int32_t), intent(in) :: backend
            integer(c_int32_t), intent(in) :: device_id
        end function PAccX_CXX_set_device
    end interface

    interface cxx_sync_device
        integer(c_int32_t) function PAccX_CXX_sync_device( &
                backend)                                              &
                bind (c, name='paccx_solver_sync_device')
            use, intrinsic :: iso_c_binding
            implicit none
            integer(c_int32_t), intent(in) :: backend
        end function PAccX_CXX_sync_device
    end interface
#endif

contains

    subroutine get_device_count(backend, count)
        !! Get the number of available devices for the backend specified.
        !! With PAccX the count is queried from `paccx_solver_get_device_count`;
        !! stand-alone builds only accept `BACKEND_CPU` and always report 1.
        integer, intent(in)  :: backend
        !! Backend to use. Should be one of
        !!     * BACKEND_CPU
        !!     * BACKEND_GPU
        !!     * BACKEND_ROCALUTION_CPU (PAccX builds only)
        !!     * BACKEND_ROCALUTION_GPU (PAccX builds only)
        integer, intent(out) :: count
        !! Number of available devices
#ifdef ENABLE_PACCX
        integer(c_int32_t)   :: ccall_result
        integer(c_int32_t)   :: c_count
        !! C-interoperable receiver for the count; converted to the default
        !! integer kind afterwards so this works even when default integer
        !! is not 32 bits wide.

        ! if PAccX is used, we ask it for device info
        ccall_result = cxx_get_device_count(int(backend, c_int32_t), c_count)
        if (ccall_result /= 0) then
            call handle_error(                            &
                "Error retrieving the number of devices", &
                PARALLAX_ERR_CCALL,                       &
                __LINE__,                                 &
                __FILE__)
        endif
        count = int(c_count)
#else
        ! if we are in stand-alone fortran installation:
        ! backend must be CPU
        if (backend /= BACKEND_CPU) then
            call handle_error(                                                               &
                "Must use BACKEND_CPU in device handling if ENABLE_PACCX=OFF", &
                PARALLAX_ERR_PARAMETERS,                                                     &
                __LINE__,                                                                    &
                __FILE__)
        endif
        ! number of devices is 1
        count = 1
#endif
    end subroutine get_device_count

    subroutine set_device_id(backend, device_id)
        !! Given a specific backend, set the device to use.
        !! With PAccX this forwards to `paccx_solver_set_device`; stand-alone
        !! builds only validate that backend is `BACKEND_CPU` and id is 0.
        integer, intent(in) :: backend
        !! Backend to use. Should be one of
        !!     * BACKEND_CPU
        !!     * BACKEND_GPU
        !!     * BACKEND_ROCALUTION_CPU (PAccX builds only)
        !!     * BACKEND_ROCALUTION_GPU (PAccX builds only)
        integer, intent(in) :: device_id
        !! Desired device id to use.
        !! If there are `count` devices, `device_id` ranges from 0 to `count-1`.
#ifdef ENABLE_PACCX
        integer(c_int32_t)  :: ccall_result

        ! Explicit kind conversions keep the call matching the C interface
        ! regardless of the default integer kind.
        ccall_result = cxx_set_device(int(backend, c_int32_t), &
                                      int(device_id, c_int32_t))
        if (ccall_result /= 0) then
            call handle_error(                  &
                "Error setting the device", &
                PARALLAX_ERR_CCALL,             &
                __LINE__,                       &
                __FILE__)
        endif
#else
        ! if we are in stand-alone fortran installation:
        ! backend must be CPU
        if (backend /= BACKEND_CPU) then
            call handle_error(                                                               &
                "Must use BACKEND_CPU in device handling if ENABLE_PACCX=OFF", &
                PARALLAX_ERR_PARAMETERS,                                                     &
                __LINE__,                                                                    &
                __FILE__)
        endif
        ! get_device_count reports exactly one device here, so the only
        ! valid id is 0.
        if (device_id /= 0) then
            call handle_error(                                               &
                "Device id must be 0 in device handling if ENABLE_PACCX=OFF", &
                PARALLAX_ERR_PARAMETERS,                                     &
                __LINE__,                                                    &
                __FILE__)
        endif
#endif
    end subroutine set_device_id

    subroutine impose_default_device_affinity(object_type)
        !! Assign devices to MPI ranks in round-robin order.
        !! The device assigned to rank rr (rank in the communicator returned by
        !! `get_communicator`) is mod(rr, total_number_of_devices).
        !! With fewer ranks than devices, the higher-numbered devices stay
        !! unused; with more ranks than devices, devices are shared.
        character(len=*), intent(in) :: object_type
        !! Object type. For now this means "solver_type".
        integer :: ierr, my_mpi_rank
        integer :: number_of_devices

        integer :: backend

        select case(object_type)
        case('DIRECT')
            backend = BACKEND_CPU

        case('MGMRES')
            backend = BACKEND_CPU

#ifdef ENABLE_PACCX
        case('MGMRES_CXX')
            backend = BACKEND_CPU

        case('MGMRES_GPU')
            backend = BACKEND_GPU

        case('ROCALUTION_CPU')
            backend = BACKEND_ROCALUTION_CPU

        case('ROCALUTION_GPU')
            backend = BACKEND_ROCALUTION_GPU
#endif

#ifdef ENABLE_PETSC
        case('PETSC_PCMG')
            backend = BACKEND_CPU

        case('PETSC_PCRC')
            backend = BACKEND_CPU
#endif
        case default
            call handle_error(                    &
                "Selected object type not valid", &
                PARALLAX_ERR_PARAMETERS,          &
                __LINE__,                         &
                __FILE__)
            ! `backend` is undefined at this point: never continue with it,
            ! even if handle_error hands control back to the caller.
            return

        end select

        call get_device_count(backend, number_of_devices)
        if (number_of_devices < 1) then
            ! Guard the mod() below, which would otherwise divide by zero.
            call handle_error(                                         &
                "No devices available for the selected object type", &
                PARALLAX_ERR_PARAMETERS,                               &
                __LINE__,                                              &
                __FILE__)
            return
        endif

        call MPI_comm_rank(get_communicator(), my_mpi_rank, ierr)
        call set_device_id(backend, mod(my_mpi_rank, number_of_devices))
    end subroutine impose_default_device_affinity

    subroutine sync_device(backend)
        !! Synchronize the device of the specified backend.
        !! With PAccX this forwards to `paccx_solver_sync_device` (presumably
        !! waits for outstanding device work — confirm against PAccX);
        !! stand-alone builds only validate that backend is `BACKEND_CPU`.
        integer, intent(in)  :: backend
        !! Backend to use. Should be one of
        !!     * BACKEND_CPU
        !!     * BACKEND_GPU
        !!     * BACKEND_ROCALUTION_CPU (PAccX builds only)
        !!     * BACKEND_ROCALUTION_GPU (PAccX builds only)
#ifdef ENABLE_PACCX
        integer(c_int32_t)   :: ccall_result

        ! if PAccX is used, we ask it to synchronize the device
        ccall_result = cxx_sync_device(int(backend, c_int32_t))
        if (ccall_result /= 0) then
            call handle_error(                 &
                "Error synchronizing device.", &
                PARALLAX_ERR_CCALL,            &
                __LINE__,                      &
                __FILE__)
        endif
#else
        ! if we are in stand-alone fortran installation:
        ! backend must be CPU
        if (backend /= BACKEND_CPU) then
            call handle_error(                                                               &
                "Must use BACKEND_CPU in device handling if ENABLE_PACCX=OFF", &
                PARALLAX_ERR_PARAMETERS,                                                     &
                __LINE__,                                                                    &
                __FILE__)
        endif
#endif
    end subroutine sync_device

end module device_handling_m