program simple_solve
  ! Distributed AMGX solve of a 5-point stencil system (4 on the diagonal,
  ! -1 for each existing grid neighbour).
  !
  ! Every MPI rank owns nx*ny consecutive rows of a global matrix of size
  ! nx*ny*nb_processus; ranks are stacked along the grid-row direction, so
  ! the south/north couplings of a rank's first/last grid row refer to rows
  ! owned by the previous/next rank. Rows are assembled as 0-based CSR with
  ! global column indices, uploaded with AMGX_matrix_upload_all_global_32,
  ! and solved with the configuration read from JACOBI.json for a
  ! right-hand side of ones and a zero initial guess.

  use, intrinsic :: iso_c_binding
  use amgx_fortran
  use mpi

  implicit none

  ! AMGX mode code given to the solver/matrix/vector constructors.
  ! NOTE(review): 8193 is presumably AMGX_mode_dDDI -- confirm against amgx_c.h.
  integer, parameter :: amgx_mode = 8193
  ! Number of GPUs reserved per node; rank r uses device modulo(r, n_gpus).
  integer, parameter :: n_gpus = 4
  ! A 5-point stencil row holds at most 5 nonzeros.
  integer, parameter :: max_nnz_per_row = 5

  type(c_ptr) :: cfg, res, solver
  type(c_ptr) :: A, b, x

  ! Input/output files (NUL-terminated for the C API)
  character(len=64), target :: cfg_file

  ! Resources (MPI communication and GPU)
  integer :: rang, AMGX_comm, nb_processus, code
  type(c_ptr) :: AMGX_c_comm
  ! TARGET is required because the array is passed through c_loc.
  integer(c_int), dimension(1), target :: devices_list = [ 0_c_int ]

  ! Matrix system parameters: each rank owns an nx-by-ny block of grid points
  integer, parameter :: nx=5000, ny=5000
  integer :: size_x, size_y, nrings

  integer :: n, nnz, nmtx
  ! CSR arrays and vectors handed to AMGX through c_loc, hence TARGET and C kinds.
  integer(c_int), dimension(:), allocatable, target :: row_ptrs, col_indices
  real(c_double), dimension(:), allocatable, target :: values, rhs, sol


  call MPI_init(code)

  ! AMGX communicator (a duplicate of MPI_COMM_WORLD)
  call MPI_Comm_dup(MPI_comm_world, AMGX_comm, code)

  call MPI_comm_rank(AMGX_comm, rang, code)
  call MPI_comm_size(AMGX_comm, nb_processus, code)
  print *,'Coucou, je suis le processus ', rang, '/', nb_processus

  ! Rows owned by each rank. Global row/column indices reach
  ! n*nb_processus - 1 and are passed as 32-bit integers, so refuse
  ! process counts that would overflow them.
  n = nx*ny
  if (nb_processus > huge(n) / n) then
    if (rang == 0) print *, 'Error: global system size overflows 32-bit indices'
    call MPI_Abort(AMGX_comm, 1, code)
  end if

  AMGX_c_comm = MPI_comm_f2c(AMGX_comm)
  call AMGX_register_print_callback(AMGX_c_comm)

  ! One device per rank, assigned round-robin over the n_gpus reserved GPUs
  devices_list = [ int(modulo(rang, n_gpus), c_int) ]

  ! Initialize
  call AMGX_initialize()
  call AMGX_initialize_plugins()

  ! Read config file
  cfg_file = "JACOBI.json"//C_NULL_CHAR
  call AMGX_config_create_from_file(cfg, c_loc(cfg_file))

  call AMGX_config_get_default_number_of_rings(cfg, nrings)

  ! Create resources based on config
  call MPI_Barrier(AMGX_comm, code)
  call AMGX_resources_create(res, cfg, AMGX_c_comm, 1, c_loc(devices_list))

  ! Create solver object, A, x, b with the chosen mode/precision
  call MPI_Barrier(AMGX_comm, code)
  call AMGX_solver_create(solver, res, amgx_mode, cfg)
  call AMGX_matrix_create(A, res, amgx_mode)
  call AMGX_vector_create(b, res, amgx_mode)
  call AMGX_vector_create(x, res, amgx_mode)

  ! Construct the local rows of the linear system
  allocate(row_ptrs(0:n))
  allocate(col_indices(0:max_nnz_per_row*n-1))
  allocate(values(0:max_nnz_per_row*n-1))
  call assemble_local_rows(rang, nb_processus, row_ptrs, col_indices, values, nnz)

  call AMGX_matrix_upload_all_global_32(A, n*nb_processus, n, nnz, 1, 1,      &
                                        c_loc(row_ptrs), c_loc(col_indices), &
                                        c_loc(values), c_null_ptr,           &
                                        nrings, nrings, c_null_ptr)
  call MPI_Barrier(AMGX_comm, code)

  ! Host CSR arrays are not used again by this program after the upload.
  ! NOTE(review): assumes AMGX copies the data during upload -- confirm.
  deallocate(row_ptrs, col_indices, values)

  allocate(rhs(n), sol(n))
  rhs(:) = 1.0_c_double
  sol(:) = 0.0_c_double

  call AMGX_vector_bind(x, A)
  call AMGX_vector_bind(b, A)
  call AMGX_vector_upload(x, n, 1, c_loc(sol))
  call AMGX_vector_upload(b, n, 1, c_loc(rhs))

  call AMGX_matrix_get_size(A, nmtx, size_x, size_y)
  print *,'Taille du systeme ', nmtx, size_x, size_y


  ! Setup and solve
  call MPI_Barrier(AMGX_comm, code)
  call AMGX_solver_setup(solver, A)
  call AMGX_solver_solve(solver, b, x)


  ! Copy the solution vector back into the host array
  call AMGX_vector_download(x, c_loc(sol))

  ! Release AMGX objects and host memory
  call MPI_Barrier(AMGX_comm, code)
  call AMGX_matrix_destroy(A)
  call AMGX_vector_destroy(b)
  call AMGX_vector_destroy(x)
  call AMGX_solver_destroy(solver)
  call AMGX_resources_destroy(res)
  call AMGX_config_destroy(cfg)
  ! NOTE(review): the AMGX C API pairs AMGX_initialize/AMGX_initialize_plugins
  ! with AMGX_finalize_plugins/AMGX_finalize; add those calls here if the
  ! amgx_fortran wrapper exposes them.

  deallocate(rhs, sol)
  ! Free the communicator duplicated at start-up.
  call MPI_Comm_free(AMGX_comm, code)

  call MPI_finalize(code)

contains

  ! Fill 0-based CSR arrays with global column indices for the rows owned by
  ! rank rank_id out of n_ranks. Local row i is grid point
  ! (row i/ny, column modulo(i, ny)). Each row gets 4 on the diagonal and -1
  ! for every neighbour that exists, emitted in increasing column order:
  ! south (i-ny), west (i-1), diagonal, east (i+1), north (i+ny).
  ! South/north couplings reach into the previous/next rank, except on the
  ! first grid row of rank 0 and the last grid row of the last rank.
  !
  ! ptrs must be dimensioned (0:nrows); cols and vals must hold at least
  ! max_nnz_per_row*nrows entries. nnz_out returns the number of entries written.
  subroutine assemble_local_rows(rank_id, n_ranks, ptrs, cols, vals, nnz_out)
    integer, intent(in) :: rank_id, n_ranks
    integer(c_int), intent(out) :: ptrs(0:)
    integer(c_int), intent(out) :: cols(0:)
    real(c_double), intent(out) :: vals(0:)
    integer, intent(out) :: nnz_out

    integer :: i, nrows, first_row

    nrows = size(ptrs) - 1
    first_row = rank_id*nrows
    nnz_out = 0

    do i = 0, nrows-1
      ptrs(i) = nnz_out

      ! South: absent only on grid row 0 of rank 0 (local rows 0..ny-1).
      if (rank_id > 0 .or. i >= ny) then
        cols(nnz_out) = first_row + i - ny
        vals(nnz_out) = -1.0_c_double
        nnz_out = nnz_out + 1
      end if

      ! West: absent on column 0.
      if (modulo(i, ny) /= 0) then
        cols(nnz_out) = first_row + i - 1
        vals(nnz_out) = -1.0_c_double
        nnz_out = nnz_out + 1
      end if

      ! Diagonal
      cols(nnz_out) = first_row + i
      vals(nnz_out) = 4.0_c_double
      nnz_out = nnz_out + 1

      ! East: absent on the last column (ny-1); otherwise i+1 would wrap
      ! into the next grid row.
      if (modulo(i+1, ny) /= 0) then
        cols(nnz_out) = first_row + i + 1
        vals(nnz_out) = -1.0_c_double
        nnz_out = nnz_out + 1
      end if

      ! North: absent only on the last grid row of the last rank.
      if (rank_id /= n_ranks-1 .or. i/ny /= nx-1) then
        cols(nnz_out) = first_row + i + ny
        vals(nnz_out) = -1.0_c_double
        nnz_out = nnz_out + 1
      end if
    end do

    ptrs(nrows) = nnz_out
  end subroutine assemble_local_rows

end program simple_solve
