module device_handling_m
  !! Functionality to choose the accelerator device (GPU).
  !!
  !! When PARALLAX is compiled stand-alone, `get_device_count` always returns
  !! 1 (i.e. the CPU), and `set_device_id` is a sanity check confirming that
  !! it is called with `BACKEND_CPU` and device id 0.
  !!
  !! When PARALLAX is linked to the PAccX library, these subroutines are thin
  !! wrappers around `paccx_solver_get_device_count`,
  !! `paccx_solver_set_device` and `paccx_solver_sync_device`.
  !! In turn, PAccX may be linked to different backends.
  !! As of 2025-03-14 a CXX (i.e. CPU) backend is always present, and an
  !! optional GPU backend is also available (CUDA or HIP, chosen at CMake
  !! configure time).
  !!
  !! **Note**: these subroutines control the global device settings of the
  !! calling process/thread.
  !! If your code links against PARALLAX and another accelerator-capable
  !! library, please test thoroughly and ensure the code behaves as desired.
  use MPI, only : MPI_comm_rank
  use comm_handling_m, only : get_communicator
  use screen_io_m, only : get_stderr, get_stdout
  use status_codes_m, only : PARALLAX_ERR_CCALL, PARALLAX_ERR_PARAMETERS
  use error_handling_m, only : handle_error
  use, intrinsic :: iso_c_binding

  implicit none

  enum, bind(C)
    !! Enumerator defining the backends.
    !! These values MUST agree with the corresponding enum in
    !! `paccx.hxx` from PAccX.
    enumerator :: BACKEND_CPU = 0
    enumerator :: BACKEND_GPU = 1
    enumerator :: BACKEND_ROCALUTION_CPU = 2
    enumerator :: BACKEND_ROCALUTION_GPU = 3
  end enum

#ifdef ENABLE_PACCX
  interface cxx_get_device_count
    integer(c_int32_t) function PAccX_CXX_get_device_count( &
        backend, &
        count) &
        bind(c, name='paccx_solver_get_device_count')
      use, intrinsic :: iso_c_binding
      implicit none
      integer(c_int32_t), intent(in) :: backend
      integer(c_int32_t), intent(out) :: count
    end function PAccX_CXX_get_device_count
  end interface cxx_get_device_count

  interface cxx_set_device
    integer(c_int32_t) function PAccX_CXX_set_device( &
        backend, &
        device_id) &
        bind(c, name='paccx_solver_set_device')
      use, intrinsic :: iso_c_binding
      implicit none
      integer(c_int32_t), intent(in) :: backend
      integer(c_int32_t), intent(in) :: device_id
    end function PAccX_CXX_set_device
  end interface cxx_set_device

  interface cxx_sync_device
    integer(c_int32_t) function PAccX_CXX_sync_device( &
        backend) &
        bind(c, name='paccx_solver_sync_device')
      use, intrinsic :: iso_c_binding
      implicit none
      integer(c_int32_t), intent(in) :: backend
    end function PAccX_CXX_sync_device
  end interface cxx_sync_device
#endif

contains

  subroutine get_device_count(backend, count)
    !! Get the number of available devices for the backend specified.
    integer, intent(in) :: backend
      !! Backend to use. Should be one of
      !! * BACKEND_CPU
      !! * BACKEND_GPU
      !! * BACKEND_ROCALUTION_CPU
      !! * BACKEND_ROCALUTION_GPU
      !! Without PAccX (ENABLE_PACCX=OFF) only BACKEND_CPU is accepted.
    integer, intent(out) :: count
      !! Number of available devices. Set to 0 if the PAccX query fails.
#ifdef ENABLE_PACCX
    integer(c_int32_t) :: ccall_result
    integer(c_int32_t) :: c_count

    ! Ask PAccX for the device count. Arguments are converted explicitly to
    ! the C kind so the call stays correct when default integers are not
    ! 32-bit.
    ccall_result = cxx_get_device_count(int(backend, c_int32_t), c_count)
    if (ccall_result /= 0) then
      ! c_count may not have been written; report zero devices so callers
      ! never act on an undefined value if handle_error returns.
      count = 0
      call handle_error( &
        "Error retrieving the number of devices", &
        PARALLAX_ERR_CCALL, &
        __LINE__, &
        __FILE__)
      return
    endif
    count = int(c_count)
#else
    ! Stand-alone Fortran installation: the CPU is the only device, so the
    ! backend must be BACKEND_CPU.
    if (backend /= BACKEND_CPU) then
      call handle_error( &
        "Must use BACKEND_CPU in device handling if ENABLE_PACCX=OFF", &
        PARALLAX_ERR_PARAMETERS, &
        __LINE__, &
        __FILE__)
    endif
    count = 1
#endif
  end subroutine get_device_count

  subroutine set_device_id(backend, device_id)
    !! Given a specific backend, set the device to use.
    integer, intent(in) :: backend
      !! Backend to use. Should be one of
      !! * BACKEND_CPU
      !! * BACKEND_GPU
      !! * BACKEND_ROCALUTION_CPU
      !! * BACKEND_ROCALUTION_GPU
      !! Without PAccX (ENABLE_PACCX=OFF) only BACKEND_CPU is accepted.
    integer, intent(in) :: device_id
      !! Desired device id to use.
      !! If there are `count` devices (see `get_device_count`), `device_id`
      !! ranges from 0 to `count-1`. Without PAccX it must be 0.
#ifdef ENABLE_PACCX
    integer(c_int32_t) :: ccall_result

    ccall_result = cxx_set_device(int(backend, c_int32_t), &
                                  int(device_id, c_int32_t))
    if (ccall_result /= 0) then
      call handle_error( &
        "Error setting the device", &
        PARALLAX_ERR_CCALL, &
        __LINE__, &
        __FILE__)
    endif
#else
    ! Stand-alone Fortran installation: this is only a sanity check that the
    ! single CPU device (BACKEND_CPU, id 0) is requested.
    if (backend /= BACKEND_CPU) then
      call handle_error( &
        "Must use BACKEND_CPU in device handling if ENABLE_PACCX=OFF", &
        PARALLAX_ERR_PARAMETERS, &
        __LINE__, &
        __FILE__)
    endif
    if (device_id /= 0) then
      call handle_error( &
        "Device id must be 0 in device handling if ENABLE_PACCX=OFF", &
        PARALLAX_ERR_PARAMETERS, &
        __LINE__, &
        __FILE__)
    endif
#endif
  end subroutine set_device_id

  subroutine impose_default_device_affinity(object_type)
    !! Assign devices to MPI ranks in round-robin order.
    !! The backend is derived from `object_type`. The device assigned to
    !! rank `rr` is `mod(rr, number_of_devices)`. With more ranks than
    !! devices, several ranks share a device; with fewer ranks, some devices
    !! remain unused. At least one device must be available.
    character(len=*), intent(in) :: object_type
      !! Object type. For now this means "solver_type".
    integer :: ierr, my_mpi_rank
    integer :: number_of_devices
    integer :: backend

    select case (object_type)
    case ('DIRECT', 'MGMRES')
      backend = BACKEND_CPU
#ifdef ENABLE_PACCX
    case ('MGMRES_CXX')
      backend = BACKEND_CPU
    case ('MGMRES_GPU')
      backend = BACKEND_GPU
    case ('ROCALUTION_CPU')
      backend = BACKEND_ROCALUTION_CPU
    case ('ROCALUTION_GPU')
      backend = BACKEND_ROCALUTION_GPU
#endif
#ifdef ENABLE_PETSC
    case ('PETSC_PCMG', 'PETSC_PCRC')
      backend = BACKEND_CPU
#endif
    case default
      call handle_error( &
        "Selected object type not valid", &
        PARALLAX_ERR_PARAMETERS, &
        __LINE__, &
        __FILE__)
      ! backend is undefined here: never continue if handle_error returns.
      return
    end select

    call get_device_count(backend, number_of_devices)
    if (number_of_devices < 1) then
      ! mod(rank, 0) is processor dependent (typically a crash), so refuse
      ! to assign a device when none is available.
      call handle_error( &
        "No devices available for the selected object type", &
        PARALLAX_ERR_PARAMETERS, &
        __LINE__, &
        __FILE__)
      return
    endif

    call MPI_comm_rank(get_communicator(), my_mpi_rank, ierr)
    call set_device_id(backend, mod(my_mpi_rank, number_of_devices))
  end subroutine impose_default_device_affinity

  subroutine sync_device(backend)
    !! Synchronize the device for the backend specified.
    !! With PAccX this forwards to `paccx_solver_sync_device`; in a
    !! stand-alone build it only validates that `backend` is BACKEND_CPU.
    integer, intent(in) :: backend
      !! Backend to use. Should be one of
      !! * BACKEND_CPU
      !! * BACKEND_GPU
      !! * BACKEND_ROCALUTION_CPU
      !! * BACKEND_ROCALUTION_GPU
      !! Without PAccX (ENABLE_PACCX=OFF) only BACKEND_CPU is accepted.
#ifdef ENABLE_PACCX
    integer(c_int32_t) :: ccall_result

    ccall_result = cxx_sync_device(int(backend, c_int32_t))
    if (ccall_result /= 0) then
      call handle_error( &
        "Error synchronizing device.", &
        PARALLAX_ERR_CCALL, &
        __LINE__, &
        __FILE__)
    endif
#else
    ! Stand-alone Fortran installation: nothing to synchronize, but the
    ! backend must still be BACKEND_CPU.
    if (backend /= BACKEND_CPU) then
      call handle_error( &
        "Must use BACKEND_CPU in device handling if ENABLE_PACCX=OFF", &
        PARALLAX_ERR_PARAMETERS, &
        __LINE__, &
        __FILE__)
    endif
#endif
  end subroutine sync_device

end module device_handling_m