! @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ !
! @@                                                                @@ !
! @@       PROGRAM  ASCOT 2014.440 (ver.4.53)                       @@ !
! @@     "Abinitio Simulation Code for Quantum Transport"           @@ !
! @@                                                                @@ !
! @@                                                                @@ !
! @@  AUTHOR(S): Naoki WATANABE, Nobutaka NISHIKAWA (Mizuho I.R.)   @@ !
! @@             Hisashi KONDO (Univ. Tokyo)                        @@ !
! @@                                                09/May/2014     @@ !
! @@                                                                @@ !
! @@  Contact address: Phase System Consortium                      @@ !
! @@                                                                @@ !
! @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ !

module ac_mpi_module
  use ac_misc_module
  implicit none
  include "mpif.h"

  ! Process layout of the run: world rank/size, the split of the world into
  ! "E" and "M" sub-communicators (MPI__setup requires size == sizeE*sizeM),
  ! and the column distribution of npao x npao matrices over the M ranks.
  type MPI_type
     integer :: size            ! number of processes in MPI_COMM_WORLD
     integer :: rank            ! rank in MPI_COMM_WORLD
     integer :: sizeE, sizeM    ! extents of the E and M process directions
     integer :: rankE, rankM    ! rankE = rank/sizeM, rankM = mod(rank,sizeM)
     integer :: info            ! status from the most recent MPI/LAPACK call
     logical :: root            ! .true. on world rank 0

     integer :: npao            ! number of matrix columns owned by this rank
     integer :: isatom          ! first atom whose orbitals reach this rank's columns
     integer :: ieatom          ! last atom whose orbitals reach this rank's columns

     ! Null-initialised so the associated() tests in MPI__setup and
     ! MPI__Finalize are well defined before the first allocation.
     integer,pointer :: ispao(:) => null() ! per-atom first orbital index (1-based)
     integer,pointer :: iepao(:) => null() ! per-atom last orbital index (1-based)
     integer,pointer :: vnpao(:) => null() ! per-M-rank number of owned columns

     integer,pointer :: vsmat(:) => null() ! per-M-rank first owned column
     integer,pointer :: vemat(:) => null() ! per-M-rank last owned column
     integer,pointer :: vnmat(:) => null() ! per-M-rank number of owned columns

     integer         ::  smat    ! first column owned by this rank
     integer         ::  emat    ! last column owned by this rank
     integer         ::  nmat    ! number of columns owned by this rank

     ! Per-M-rank element counts/displacements for the MPI_Alltoallv exchanges.
     integer,pointer :: vmatcount(:) => null()
     integer,pointer :: vmatdispl(:) => null()

     integer :: SL_ID         ! context handle passed to DescInit (not assigned in this module)
     integer :: SL_DESC(9)    ! descriptor for npao x npao matrices
     integer :: SL_DESCLS(9)  ! descriptor for 2*npao x 2*npao matrices

     integer :: localsize1    ! ceiling(npao/sizeM)
     integer :: localsize2    ! 2*ceiling(npao/sizeM)

     integer :: group  ! group of MPI_COMM_WORLD
     integer :: groupM ! group of the sizeM ranks sharing this rankE
     integer :: groupE ! group of the sizeE ranks sharing this rankM
     integer :: commM  ! communicator built from groupM
     integer :: commE  ! communicator built from groupE

     ! Values >= -1 are converted by MPI__setup into an absolute size
     ! (negative = fraction of the computed limit); values < -1 become 0.
     real(8) :: num_sc_work2space = -1000.d0
     integer :: num_check_sc = 1   ! not referenced in this module


  end type MPI_type

  type(MPI_type), public, save :: MPI 

  ! Column-block distribution of an nrow x ncol matrix over the M ranks.
  type MPI_MatDesc
     integer :: nrow ! global number of rows
     integer :: ncol ! global number of columns

     integer :: srow ! first local row (always 1)
     integer :: erow ! last local row (always nrow)
     integer :: scol ! first local column
     integer :: ecol ! last local column
     ! Per-M-rank 0-based element displacement / element count inside the
     ! column-major nrow x ncol array, as used by MPI_Allgatherv.
     integer, allocatable :: vsele(:)
     integer, allocatable :: vnele(:)

     integer :: SL_DESC(9) ! descriptor filled by DescInit
  end type MPI_MatDesc

contains
  subroutine DescInit( DESC, M, N, MB, NB, IRSRC, ICSRC, ICTXT, LLD, INFO )
    implicit none
    ! Placeholder descriptor initialiser. The argument list presumably
    ! mirrors ScaLAPACK DESCINIT (TODO confirm). Every argument except DESC
    ! is accepted only for call compatibility and ignored; the descriptor is
    ! cleared to zero.
    integer, intent(out) :: DESC(9)
    integer, intent(in)  :: M, N, MB, NB, IRSRC, ICSRC, ICTXT, LLD, INFO

    DESC = 0
  end subroutine DescInit


  subroutine MPI__setupMatDesc( desc, nrow, ncol )
    implicit none
    ! Build the column-block distribution of an nrow x ncol matrix over the
    ! MPI%sizeM ranks of the M direction. Each rank gets a contiguous block
    ! of ceiling(ncol/sizeM) columns (the last block may be shorter). The
    ! per-rank element displacements/counts are filled in for MPI_Allgatherv.
    !
    ! desc : (out) descriptor; its allocatable components are reallocated
    ! nrow : global number of rows (all rows are local on every rank)
    ! ncol : global number of columns

    type(MPI_MatDesc), intent(out) :: desc
    integer, intent(in) :: nrow, ncol

    integer :: ncol_local, p
    integer :: nele_total, nele_chunk

    desc%nrow = nrow
    desc%ncol = ncol

    ncol_local = int(ceiling(dble(ncol) / dble(MPI%sizeM)))

    call DescInit( desc%SL_DESC, nrow, ncol, ncol_local, ncol_local, &
         0, 0, MPI%SL_ID, nrow, MPI%info )

    desc%srow = 1
    desc%erow = nrow

    desc%scol = 1+(MPI%rankM+0)*ncol_local
    desc%ecol = 0+(MPI%rankM+1)*ncol_local

    if( desc%scol <= ncol .and. desc%ecol > ncol ) then
       ! last non-empty block: truncate to the matrix width
       desc%ecol = ncol
    else if( desc%scol > ncol .and. desc%ecol > ncol ) then
       ! NOTE(review): ranks beyond the last column are pinned to column ncol
       ! (a one-column local view) rather than an empty range; confirm callers
       ! rely on this before changing it.
       desc%scol = ncol
       desc%ecol = ncol
    end if

    allocate( desc%vsele(0:MPI%sizeM-1) )
    allocate( desc%vnele(0:MPI%sizeM-1) )

    nele_total = nrow*ncol
    nele_chunk = nrow*ncol_local
    do p=0, MPI%sizeM-1
       desc%vsele(p) = nele_chunk * p
       desc%vnele(p) = nele_chunk

       if( desc%vsele(p) > nele_total-1 ) then
          ! Rank p owns no elements. Its count must be zero, otherwise the
          ! gather writes a full chunk past the end of the receive buffer.
          desc%vsele(p) = max(nele_total-1, 0)
          desc%vnele(p) = 0
       else if( desc%vsele(p) + desc%vnele(p) > nele_total ) then
          ! short final block
          desc%vnele(p) = nele_total - desc%vsele(p)
       end if
    end do

  end subroutine MPI__setupMatDesc

  subroutine MPI__unsetMatDesc( desc )
    implicit none
    ! Release the per-rank displacement/count arrays of a descriptor and
    ! clear its SL_DESC. intent(inout) keeps the deallocation explicit; with
    ! intent(out) the allocatable components were already freed on entry and
    ! every other field became undefined.
    type(MPI_MatDesc), intent(inout) :: desc

    if( allocated(desc%vsele) ) deallocate( desc%vsele )
    if( allocated(desc%vnele) ) deallocate( desc%vnele )

    desc%SL_DESC(:) = 0

  end subroutine MPI__unsetMatDesc

  subroutine MPI__Initialize
    implicit none
    ! Start MPI and record this process's world rank, the world size and
    ! the world group in the module variable MPI.

    call MPI_Init( MPI%info )
    call MPI_Comm_rank( MPI_COMM_WORLD, MPI%rank, MPI%info )
    call MPI_Comm_size( MPI_COMM_WORLD, MPI%size, MPI%info )
    call MPI_Comm_group( MPI_COMM_WORLD, MPI%group, MPI%info )

    ! world rank 0 is the root process
    MPI%root = ( MPI%rank == 0 )

  end subroutine MPI__Initialize

  subroutine MPI__Finalize
    implicit none
    ! Shut down MPI, then release every bookkeeping array that MPI__setup
    ! allocates (previously vmatcount/vmatdispl were leaked).

    call MPI_Finalize ( MPI%info )

    if( associated(MPI%vmatcount) ) then
       deallocate( MPI%vmatcount, MPI%vmatdispl )
    end if

    if( associated(MPI%vnmat) ) then
       deallocate( MPI%vnmat, MPI%vsmat, MPI%vemat )
    end if

    if( associated(MPI%ispao) ) then
       deallocate(MPI%ispao,MPI%iepao,MPI%vnpao)
    end if

  end subroutine MPI__Finalize

  subroutine MPI__setup( sizeE, sizeM )
    implicit none
    ! Split MPI_COMM_WORLD into a sizeE x sizeM process grid and distribute
    ! the npao matrix columns in contiguous blocks over the M direction.
    ! Builds commM (ranks sharing rankE) and commE (ranks sharing rankM), the
    ! per-rank column ranges and Alltoallv counts, the per-atom local orbital
    ! ranges, and clamps MPI%num_sc_work2space.
    !
    ! sizeE, sizeM : grid extents; their product must equal MPI%size
    integer, intent(in) :: sizeE, sizeM 
    integer :: p, a, i, n
    integer, allocatable :: rankM(:), rankE(:)
    real(8) :: work_max   ! upper bound for MPI%num_sc_work2space

    ! Test the requested extent: MPI%sizeM is only assigned below, so
    ! reading it here would use a stale or undefined value.
    if( Param%Data%npao < sizeM ) then
       write(*,'(a)') "# warning! : too many processors for MPI"
    end if

    MPI%sizeE = sizeE
    MPI%sizeM = sizeM

    ! (Re)allocate the per-rank and per-atom bookkeeping arrays.
    if( associated(MPI%vmatcount) ) then
       deallocate( MPI%vmatcount, MPI%vmatdispl )
    end if
    allocate( MPI%vmatcount(0:MPI%sizeM-1), MPI%vmatdispl(0:MPI%sizeM-1) )

    if( associated(MPI%vnmat) ) then
       deallocate( MPI%vnmat, MPI%vsmat, MPI%vemat )
    end if
    allocate( MPI%vnmat(0:MPI%sizeM-1), MPI%vsmat(0:MPI%sizeM-1), MPI%vemat(0:MPI%sizeM-1) )

    if( associated(MPI%vnpao) ) then
       deallocate(MPI%vnpao, MPI%ispao, MPI%iepao)
    end if
    allocate( MPI%vnpao(0:MPI%sizeM-1), MPI%ispao(Param%Data%natom), MPI%iepao(Param%Data%natom) )

    if( MPI%size /= MPI%sizeE * MPI%sizeM ) then
       write(*,*) "# Error : MPI%size mismatch ", &
            MPI%size, MPI%sizeE, MPI%sizeM
       ! Every rank sees the same sizes and takes this branch, so shutting
       ! MPI down collectively is safe (matches the other fatal exit below).
       call MPI__Finalize
       stop
    end if

    MPI%rankE = int(MPI%rank/MPI%sizeM)
    MPI%rankM = mod(MPI%rank,MPI%sizeM)

    ! World ranks belonging to this rank's M group (same rankE) and E group
    ! (same rankM).
    allocate( rankM(0:MPI%sizeM-1), rankE(0:MPI%sizeE-1) )

    do p=0, MPI%sizeM-1
       rankM(p) = p + MPI%rankE*MPI%sizeM
    end do
    do p=0, MPI%sizeE-1
       rankE(p) = p*MPI%sizeM + MPI%rankM
    end do

    call MPI_Group_incl( MPI%group, MPI%sizeM, rankM, MPI%groupM, MPI%info )
    call MPI_Comm_create( MPI_COMM_WORLD, MPI%groupM, MPI%commM, MPI%info )

    call MPI_Group_incl( MPI%group, MPI%sizeE, rankE, MPI%groupE, MPI%info )
    call MPI_Comm_create( MPI_COMM_WORLD, MPI%groupE, MPI%commE, MPI%info )

    deallocate(rankM,rankE)

    MPI%localsize1 = int(ceiling(dble(Param%Data%npao) / dble(MPI%sizeM)))
    MPI%localsize2 = 2*int(ceiling(dble(Param%Data%npao) / dble(MPI%sizeM)))

    ! Reject decompositions in which the last M rank would get no columns.
    if( MPI%localsize1*(MPI%sizeM-1) >= Param%Data%npao ) then
       open(unit=16,file=Param%Option%file_ac_tempout,position='append')
       write(16,'("      ++++++ Sorry : This program can not be parallelized by processors of ",i5)') &
            MPI%sizeM
       close(16)
       call MPI__Finalize
       stop
    end if

    call DescInit( MPI%SL_DESC, Param%Data%npao, Param%Data%npao, &
         MPI%localsize1, MPI%localsize1, 0, 0, MPI%SL_ID, Param%Data%npao, MPI%info )
    call DescInit( MPI%SL_DESCLS, 2*Param%Data%npao, 2*Param%Data%npao, &
         MPI%localsize2, MPI%localsize2, 0, 0, MPI%SL_ID, 2*Param%Data%npao, MPI%info )

    ! Columns per M rank: full blocks of localsize1, remainder on the tail.
    n=0
    do p=0, MPI%sizeM-1
       if( n + MPI%localsize1 > Param%Data%npao ) then
          MPI%vnpao(p) = Param%Data%npao - n
       else
          MPI%vnpao(p) = MPI%localsize1
       end if
       n = n + MPI%vnpao(p)
    end do
    MPI%npao = MPI%vnpao(MPI%rankM)

    ! First/last/count of columns per M rank.
    i=0
    do p=0, MPI%sizeM-1
       MPI%vsmat(p) = i+1
       MPI%vemat(p) = MPI%vsmat(p) + MPI%vnpao(p) - 1
       MPI%vnmat(p) = MPI%vnpao(p)
       i = i + MPI%vnpao(p)
    end do
    MPI%smat = MPI%vsmat(MPI%rankM)
    MPI%emat = MPI%vemat(MPI%rankM)
    MPI%nmat = MPI%vnmat(MPI%rankM)

    ! Alltoallv counts/displacements for exchanging (p-rows x my-columns) blocks.
    do p=0, MPI%sizeM-1
       MPI%vmatcount(p) = MPI%vnmat(p) * MPI%nmat
       MPI%vmatdispl(p) = MPI%vsmat(p) * MPI%nmat - MPI%vsmat(0) * MPI%nmat
    end do

    ! Default: every atom's full orbital range.
    do a=1, Param%Data%natom
       MPI%ispao(a) = 0+1
       MPI%iepao(a) = Param%Data%vnpao(a)-1+1
    end do

    ! Atom containing this rank's first column, and the offset inside it.
    n=0
    do a=1, Param%Data%natom
       n=n+Param%Data%vnpao(a)
       if( MPI%vsmat(MPI%rankM)-1 < n ) then
          MPI%isatom = a
          MPI%ispao(a) = (MPI%vsmat(MPI%rankM)-1) - (Param%Data%vipao(a)-1) + 1
          exit
       end if
    end do

    ! Atom containing this rank's last column, and the offset inside it.
    n=0
    do a=1, Param%Data%natom
       n=n+Param%Data%vnpao(a)
       if( MPI%vemat(MPI%rankM)-1 < n ) then
          MPI%ieatom = a
          MPI%iepao(a) = (MPI%vemat(MPI%rankM)-1) - (Param%Data%vipao(a)-1) + 1
          exit
       end if
    end do

    if( MPI%num_sc_work2space >= -1.d0 ) then
       ! The limit is kept in real(8): the product can exceed the range of a
       ! default integer for large basis sets. aint() reproduces the former
       ! truncation toward zero of the integer assignment.
       if( Param%Option%spin_orbit ) then
          work_max = aint( (2*Base%npao/sqrt(dble(MPI%sizeM))-1.d0)*2*Base%npao )
       else
          work_max = aint( (Base%npao/sqrt(dble(MPI%sizeM))-1.d0)*Base%npao )
       end if
       if( MPI%num_sc_work2space > 0.d0 ) then
          ! positive: absolute request, clamped to the limit
          MPI%num_sc_work2space = min( MPI%num_sc_work2space, work_max )
       else
          ! in [-1,0]: fraction of the limit
          MPI%num_sc_work2space = min( -work_max * MPI%num_sc_work2space, work_max )
       end if
    else
       MPI%num_sc_work2space=0.d0
    end if

  end subroutine MPI__setup


  subroutine MPI__Bcast_Inputfile( filename )
    implicit none
    ! Broadcast the input file name from world rank 0 to every rank.
    ! Assumed length: the whole actual string is sent, not just 64 chars.
    ! All ranks must pass an actual argument of the same length.
    character(len=*), intent(inout) :: filename

    if( MPI%size == 1 ) return

    call MPI_Bcast( filename, len(filename), MPI_CHARACTER, 0, MPI_COMM_WORLD, MPI%info )

  end subroutine MPI__Bcast_Inputfile

  subroutine MPI__Allreduce_DensityLS( densityLS )
    implicit none
    ! Element-wise sum of the complex grid array densityLS over all ranks
    ! of MPI_COMM_WORLD; every rank receives the total.
    ! NOTE(review): the early exit tests only MPI%sizeM, while the reduction
    ! spans all sizeE*sizeM ranks; confirm this is intended when sizeE > 1.
    complex(8), intent(inout) :: densityLS(Param%Option%nspin,Param%Cell%Na,Param%Cell%Nb,Param%Cell%Nc)
    complex(8), allocatable :: local_part(:,:,:,:)

    if( MPI%sizeM /= 1 ) then
       allocate( local_part, source=densityLS )
       call MPI_Allreduce( local_part, densityLS, size(densityLS), &
            MPI_DOUBLE_COMPLEX, MPI_SUM, MPI_COMM_WORLD, MPI%info )
       deallocate( local_part )
    end if

  end subroutine MPI__Allreduce_DensityLS

  subroutine MPI__Allreduce_Density( density )
    implicit none
    ! Element-wise sum of the real grid array density over all ranks of
    ! MPI_COMM_WORLD; every rank receives the total.
    ! NOTE(review): the early exit tests only MPI%sizeM, while the reduction
    ! spans all sizeE*sizeM ranks; confirm this is intended when sizeE > 1.
    real(8), intent(inout) :: density (Param%Option%nspin,Param%Cell%Na,Param%Cell%Nb,Param%Cell%Nc)
    real(8), allocatable :: local_part(:,:,:,:)

    if( MPI%sizeM /= 1 ) then
       allocate( local_part, source=density )
       call MPI_Allreduce( local_part, density, size(density), &
            MPI_DOUBLE_PRECISION, MPI_SUM, MPI_COMM_WORLD, MPI%info )
       deallocate( local_part )
    end if

  end subroutine MPI__Allreduce_Density

  subroutine MPI__AlltoAll_Matrix( matrix )
    implicit none
    ! Combine a column-distributed npao x npao matrix with its transpose:
    !     matrix(a,b) <- conjg(matrix(a,b)) + matrix(b,a)
    ! This rank holds columns MPI%smat:MPI%emat. When the columns are split
    ! over the M direction, the transposed blocks are exchanged with
    ! MPI_Alltoallv inside MPI%commM. The count/displacement arrays have one
    ! entry per M rank, so MPI_COMM_WORLD would be wrong whenever sizeE > 1.
    complex(8), intent(inout) :: matrix(Param%Data%npao,MPI%smat:MPI%emat)
    complex(8), allocatable :: pack_send(:)
    complex(8), allocatable :: pack_recv(:)
    integer :: i, j, p, index

    if( MPI%sizeM == 1 ) then
       ! Whole matrix is local: update each (i,j)/(j,i) pair once.
       do i=1, Param%Data%npao
          do j=i, Param%Data%npao
             matrix(i,j) = dconjg(matrix(i,j)) + matrix(j,i)
             matrix(j,i) = dconjg(matrix(i,j))
          end do
       end do

       return
    end if

    allocate(pack_send(MPI%nmat*Param%Data%npao))
    allocate(pack_recv(MPI%nmat*Param%Data%npao))

    ! For each destination p: rows owned by p, from my columns.
    index=0
    do p=0, MPI%sizeM-1
       do i=MPI%smat, MPI%emat
          do j=MPI%vsmat(p), MPI%vemat(p)
             index=index+1
             pack_send(index) = matrix(j,i)
          end do
       end do
    end do

    call MPI_AlltoAllv( &
         pack_send, MPI%vmatcount, MPI%vmatdispl, MPI_DOUBLE_COMPLEX, &
         pack_recv, MPI%vmatcount, MPI%vmatdispl, MPI_DOUBLE_COMPLEX, &
         MPI%commM, MPI%info )

    ! Received blocks arrive transposed relative to the local layout.
    index=0
    do p=0, MPI%sizeM-1
       do j=MPI%vsmat(p), MPI%vemat(p)
          do i=MPI%smat, MPI%emat
             index=index+1
             matrix(j,i) = dconjg(matrix(j,i)) + pack_recv(index)
          end do
       end do
    end do

    deallocate( pack_send, pack_recv )

  end subroutine MPI__AlltoAll_Matrix

  subroutine MPI__AlltoAll_Matrix2( matrix2, matrix3 )
    implicit none
    ! Cross-combine two column-distributed npao x npao matrices:
    !     matrix2 <- conjg(matrix2) + transpose(matrix3)
    !     matrix3 <- conjg(matrix3) + transpose(matrix2 on entry)
    ! This rank holds columns MPI%smat:MPI%emat. The transposed blocks are
    ! exchanged with MPI_Alltoallv inside MPI%commM. The count/displacement
    ! arrays are sized for the M direction only, so MPI_COMM_WORLD would be
    ! wrong whenever sizeE > 1.
    complex(8), intent(inout) :: matrix2(Param%Data%npao,MPI%smat:MPI%emat)
    complex(8), intent(inout) :: matrix3(Param%Data%npao,MPI%smat:MPI%emat)
    complex(8), allocatable :: pack_send(:)
    complex(8), allocatable :: pack_recv(:)
    integer :: i, j, p, index

    if( MPI%sizeM == 1 ) then
       do i=1, Param%Data%npao
          do j=1, Param%Data%npao
             matrix2(i,j) = dconjg(matrix2(i,j)) + matrix3(j,i)
             matrix3(j,i) = dconjg(matrix2(i,j))
          end do
       end do

       return
    end if

    allocate(pack_send(MPI%nmat*Param%Data%npao))
    allocate(pack_recv(MPI%nmat*Param%Data%npao))

    ! First exchange: transposed blocks of matrix3.
    index=0
    do p=0, MPI%sizeM-1
       do i=MPI%smat, MPI%emat
          do j=MPI%vsmat(p), MPI%vemat(p)
             index=index+1
             pack_send(index) = matrix3(j,i)
          end do
       end do
    end do

    call MPI_AlltoAllv( &
         pack_send, MPI%vmatcount, MPI%vmatdispl, MPI_DOUBLE_COMPLEX, &
         pack_recv, MPI%vmatcount, MPI%vmatdispl, MPI_DOUBLE_COMPLEX, &
         MPI%commM, MPI%info )

    ! Pack matrix2 before it is modified, for the second exchange.
    index=0
    do p=0, MPI%sizeM-1
       do i=MPI%smat, MPI%emat
          do j=MPI%vsmat(p), MPI%vemat(p)
             index=index+1
             pack_send(index) = matrix2(j,i)
          end do
       end do
    end do

    index=0
    do p=0, MPI%sizeM-1
       do j=MPI%vsmat(p), MPI%vemat(p)
          do i=MPI%smat, MPI%emat
             index=index+1
             matrix2(j,i) = dconjg(matrix2(j,i)) + pack_recv(index)
          end do
       end do
    end do

    ! Second exchange: transposed blocks of the original matrix2.
    call MPI_AlltoAllv( &
         pack_send, MPI%vmatcount, MPI%vmatdispl, MPI_DOUBLE_COMPLEX, &
         pack_recv, MPI%vmatcount, MPI%vmatdispl, MPI_DOUBLE_COMPLEX, &
         MPI%commM, MPI%info )


    index=0
    do p=0, MPI%sizeM-1
       do j=MPI%vsmat(p), MPI%vemat(p)
          do i=MPI%smat, MPI%emat
             index=index+1
             matrix3(j,i) = dconjg(matrix3(j,i)) + pack_recv(index)
          end do
       end do
    end do

    deallocate( pack_send, pack_recv )

  end subroutine MPI__AlltoAll_Matrix2

  subroutine MPI__Allreduce_Force( vatom )
    implicit none
    ! Sum each atom's force vector over all ranks of MPI_COMM_WORLD and
    ! store the total back into vatom(:)%force on every rank.
    type(Atom_type), intent(inout) :: vatom(Param%Data%natom)
    real(8), allocatable :: f_local(:,:), f_total(:,:)
    integer :: ia, natom

    if( MPI%size == 1 ) return

    natom = Param%Data%natom
    allocate( f_local(3,natom) )
    allocate( f_total(3,natom) )

    ! gather the per-atom vectors into one contiguous buffer
    do ia=1, natom
       f_local(:,ia) = vatom(ia)%force(:)
    end do

    call MPI_Allreduce( f_local, f_total, 3*natom, MPI_DOUBLE_PRECISION, &
         MPI_SUM, MPI_COMM_WORLD, MPI%info )

    ! scatter the totals back
    do ia=1, natom
       vatom(ia)%force(:) = f_total(:,ia)
    end do

    deallocate( f_local )
    deallocate( f_total )

  end subroutine MPI__Allreduce_Force


  subroutine MPI__Allreduce_HistoryPulay( A )
    implicit none
    ! Element-wise sum of the mix_history x mix_history matrix A over
    ! MPI_COMM_WORLD (skipped when MPI%sizeM == 1).
    real(8), intent(inout) :: A(Param%SCF%mix_history,Param%SCF%mix_history)
    real(8), allocatable :: A_local(:,:)

    if( MPI%sizeM /= 1 ) then
       allocate( A_local, source=A )
       call MPI_Allreduce( A_local, A, Param%SCF%mix_history * Param%SCF%mix_history, &
            MPI_DOUBLE_PRECISION, MPI_SUM, MPI_COMM_WORLD, MPI%info )
       deallocate( A_local )
    end if

  end subroutine MPI__Allreduce_HistoryPulay

  subroutine MPI__Allreduce_HistoryAnderson( A )
    implicit none
    ! Element-wise sum of the 2x3 Anderson history block over
    ! MPI_COMM_WORLD (skipped when MPI%sizeM == 1).
    real(8), intent(inout) :: A(2:3,3)
    real(8) :: A_local(2:3,3)

    if( MPI%sizeM /= 1 ) then
       A_local(:,:) = A(:,:)
       call MPI_Allreduce( A_local, A, size(A), &
            MPI_DOUBLE_PRECISION, MPI_SUM, MPI_COMM_WORLD, MPI%info )
    end if

  end subroutine MPI__Allreduce_HistoryAnderson

  subroutine MPI__Allreduce_Error( dEden )
    implicit none
    ! Sum the scalar dEden over MPI_COMM_WORLD (skipped when MPI%sizeM == 1).
    real(8), intent(inout) :: dEden
    real(8) :: partial

    if( MPI%sizeM /= 1 ) then
       partial = dEden
       call MPI_Allreduce( partial, dEden, 1, &
            MPI_DOUBLE_PRECISION, MPI_SUM, MPI_COMM_WORLD, MPI%info )
    end if

  end subroutine MPI__Allreduce_Error

  subroutine MPI__Barrier
    implicit none
    ! Synchronise all ranks of MPI_COMM_WORLD; a no-op when MPI%sizeM == 1.

    if( MPI%sizeM /= 1 ) call MPI_Barrier( MPI_COMM_WORLD, MPI%info )

  end subroutine MPI__Barrier

  subroutine MPI__ZHEGV( size, matrixH, matrixS, vectorE, iter )
    implicit none
    ! Solve the generalised Hermitian eigenproblem H x = e S x with LAPACK
    ! ZHEGV (itype 1, eigenvectors, lower triangle). On return matrixH holds
    ! the eigenvectors and vectorE the eigenvalues. matrixS is preserved by
    ! factorising a copy. iter is accepted but not used. Status goes to MPI%info.

    integer, intent(in)       :: size 
    complex(8), intent(inout) :: matrixH(size,size) 
    complex(8), intent(in)    :: matrixS(size,size) 
    real(8),    intent(out)   :: vectorE(size) 
    integer, intent(in)       :: iter

    complex(8), allocatable   :: overlap(:,:), zwork(:)
    real(8),    allocatable   :: rwork(:)
    integer :: lzwork

    lzwork = 4*size
    allocate( zwork(lzwork), rwork(7*size) )
    allocate( overlap, source=matrixS )

    call ZHEGV( 1, 'V', 'L', size, matrixH, size, overlap, size, &
         vectorE, zwork, lzwork, rwork, MPI%info )

    deallocate( overlap, zwork, rwork )

  end subroutine MPI__ZHEGV

  subroutine MPI__ZHEGVLS( size, matrixH, matrixS, vectorE, iter )
    implicit none
    ! Same as MPI__ZHEGV, but for 2*size x 2*size matrices: solve
    ! H x = e S x with ZHEGV (itype 1, eigenvectors, lower triangle) on a
    ! copy of S. iter is accepted but not used. Status goes to MPI%info.

    integer, intent(in)       :: size 
    complex(8), intent(inout) :: matrixH(2*size,2*size) 
    complex(8), intent(in)    :: matrixS(2*size,2*size) 
    real(8),    intent(out)   :: vectorE(2*size) 
    integer, intent(in)       :: iter

    complex(8), allocatable :: overlap(:,:), zwork(:)
    real(8),    allocatable :: rwork(:)
    integer :: n2, lzwork

    n2 = 2*size
    lzwork = 4*n2
    allocate( zwork(lzwork), rwork(7*n2) )
    allocate( overlap, source=matrixS )

    call ZHEGV( 1, 'V', 'L', n2, matrixH, n2, overlap, n2, &
         vectorE, zwork, lzwork, rwork, MPI%info )

    deallocate( overlap, zwork, rwork )

  end subroutine MPI__ZHEGVLS

  subroutine MPI__ZGEMM( transa, transb, size, &
       alpha, matrixA, matrixB, beta, matrixC )
    implicit none
    ! Square complex matrix product via BLAS:
    !     C <- alpha * op(A) * op(B) + beta * C,  all size x size.

    character, intent(in)     :: transa, transb
    integer, intent(in)       :: size 
    complex(8), intent(in)    :: alpha, beta 
    complex(8), intent(in)    :: matrixA(size,size) 
    complex(8), intent(in)    :: matrixB(size,size) 
    complex(8), intent(out)   :: matrixC(size,size) 

    call ZGEMM( transa, transb, size, size, size, alpha, &
         matrixA, size, &
         matrixB, size, beta, &
         matrixC, size )

  end subroutine MPI__ZGEMM



  subroutine MPI__ZGEMM_ASCOT( transa, transb, desc, &
       alpha, matrixA, matrixB, beta, matrixC )
    implicit none
    ! C <- alpha * op(A) * op(B) + beta * C with BLAS ZGEMM, using the
    ! global sizes from desc (M = nrow, N = K = ncol, leading dimension nrow).
    ! NOTE(review): ZGEMM is given all desc%ncol columns while the dummies
    ! declare only desc%scol:desc%ecol. This is consistent only when the
    ! local block spans every column (e.g. MPI%sizeM == 1); confirm callers.

    character, intent(in)     :: transa, transb
    type(MPI_MatDesc), intent(in) :: desc
    complex(8), intent(inout) :: matrixA(desc%nrow,desc%scol:desc%ecol)
    complex(8), intent(inout) :: matrixB(desc%nrow,desc%scol:desc%ecol)
    complex(8), intent(inout) :: matrixC(desc%nrow,desc%scol:desc%ecol)
    complex(8), intent(in)    :: alpha, beta
    integer :: m, n

    m = desc%nrow
    n = desc%ncol
    call ZGEMM( transa, transb, m, n, n, alpha, &
         matrixA, m, &
         matrixB, m, beta, &
         matrixC, m )

  end subroutine MPI__ZGEMM_ASCOT

  subroutine MPI__ZGEMM2_ASCOT( transa, transb, descA, descB, descC, &
       alpha, matrixA, matrixB, beta, matrixC )
    implicit none
    ! C <- alpha * op(A) * op(B) + beta * C with BLAS ZGEMM for matrices
    ! described by separate descriptors: M = descA%nrow, N = descB%ncol and
    ! K = descA%ncol. The leading dimensions are the respective nrow values.
    ! NOTE(review): M and K are taken from A's untransposed shape, so a
    ! transposed A appears to be supported only for square A; confirm callers.

    character, intent(in)     :: transa, transb
    type(MPI_MatDesc), intent(in) :: descA, descB, descC
    complex(8), intent(inout) :: matrixA(descA%nrow,descA%scol:descA%ecol)
    complex(8), intent(inout) :: matrixB(descB%nrow,descB%scol:descB%ecol)
    complex(8), intent(inout) :: matrixC(descC%nrow,descC%scol:descC%ecol)
    complex(8), intent(in)    :: alpha, beta
    integer :: m, n, k

    m = descA%nrow
    n = descB%ncol
    k = descA%ncol
    call ZGEMM( transa, transb, m, n, k, alpha, &
         matrixA, descA%nrow, &
         matrixB, descB%nrow, beta, &
         matrixC, descC%nrow )

  end subroutine MPI__ZGEMM2_ASCOT

  subroutine MPI__ZTRANC_ASCOT( desc, alpha, matrixA, beta, matrixB )
    implicit none
    ! Over this rank's columns:  A <- alpha*A + beta*conjg(transpose(B)).
    ! NOTE(review): matrixB is indexed (column, row), i.e. row index from
    ! scol..ecol and column index from 1..nrow. That is in bounds only when
    ! nrow == ncol and the local block spans all columns; confirm callers.

    type(MPI_MatDesc), intent(in) :: desc
    complex(8), intent(inout) :: matrixA(desc%nrow,desc%scol:desc%ecol)
    complex(8), intent(in)    :: matrixB(desc%nrow,desc%scol:desc%ecol)
    complex(8), intent(in)    :: beta, alpha
    integer :: irow, jcol

    do jcol=desc%scol, desc%ecol
       do irow=1, desc%nrow
          matrixA(irow,jcol) = alpha*matrixA(irow,jcol) &
               + beta*dconjg(matrixB(jcol,irow))
       end do
    end do

  end subroutine MPI__ZTRANC_ASCOT

  subroutine MPI__ZLATRA_ASCOT( desc, matrixA, trace )
    implicit none
    ! Trace of a column-distributed square matrix. Each rank sums the
    ! diagonal entries in its own column range, and the partial sums are
    ! combined over MPI%commM. Every M rank receives the result.

    type(MPI_MatDesc), intent(in) :: desc
    complex(8), intent(in)    :: matrixA(desc%nrow,desc%scol:desc%ecol)
    complex(8), intent(out)   :: trace
    complex(8) :: partial
    integer :: k

    partial = (0.0d0, 0.0d0)
    do k=desc%scol, desc%ecol
       partial = partial + matrixA(k,k)
    end do

    call MPI_Allreduce( partial, trace, 1, MPI_DOUBLE_COMPLEX, MPI_SUM, &
         MPI%commM, MPI%info )

  end subroutine MPI__ZLATRA_ASCOT

  subroutine MPI__ZGETRI_ASCOT( desc, matrixA )
    implicit none
    ! Invert a column-distributed square matrix in place. Steps:
    !  1. gather all columns on every M rank;
    !  2. LU-factorise with LAPACK ZGETRF;
    !  3. invert with ZGETRI;
    !  4. copy this rank's columns back.
    ! MPI%info holds the LAPACK status. If ZGETRF reports a singular matrix
    ! (info > 0), the inversion is skipped and matrixA receives the LU factors.
    !
    ! LAPACK ZGETRI takes exactly (N, A, LDA, IPIV, WORK, LWORK, INFO). The
    ! former 9-argument call made INFO land in an integer work array.

    type(MPI_MatDesc), intent(in) :: desc
    complex(8), intent(inout) :: matrixA(desc%nrow,desc%scol:desc%ecol)

    integer, allocatable :: ipiv(:)
    complex(8), allocatable :: work(:)
    complex(8) :: work_query(1)
    integer :: lwork
    complex(8), allocatable :: tempA(:,:)

    allocate( tempA(desc%nrow,desc%ncol) )
    tempA(:,desc%scol:desc%ecol ) = matrixA(:,desc%scol:desc%ecol )

    call MPI__Allgather_MatrixM_ASCOT( desc, tempA )

    allocate( ipiv(desc%nrow) )

    call ZGETRF( desc%nrow, desc%ncol, tempA, desc%nrow, &
         ipiv, MPI%info )

    if( MPI%info == 0 ) then
       ! workspace query, then the actual inversion
       lwork = -1
       call ZGETRI( desc%nrow, tempA, desc%nrow, &
            ipiv, work_query, lwork, MPI%info )

       lwork = max( 1, desc%nrow, int(dble(work_query(1))) )
       allocate( work(lwork) )

       call ZGETRI( desc%nrow, tempA, desc%nrow, &
            ipiv, work, lwork, MPI%info )

       deallocate( work )
    end if

    deallocate( ipiv )

    matrixA(:,desc%scol:desc%ecol ) = tempA(:,desc%scol:desc%ecol )

    deallocate( tempA )

  end subroutine MPI__ZGETRI_ASCOT

  subroutine MPI__Allgather_MatrixM_ASCOT( desc, matrixA )
    implicit none
    ! Complete a full nrow x ncol matrix on every M rank, given that each rank
    ! has filled its own column block. Rank p's block starts at 0-based
    ! element desc%vsele(p) and holds desc%vnele(p) elements, already in
    ! place in matrixA.
    !
    ! The exchange happens inside MPI%commM: desc%vsele/vnele have one entry
    ! per M rank, so MPI_COMM_WORLD would be wrong whenever sizeE > 1.
    ! MPI_IN_PLACE is used because MPI forbids overlapping send and receive
    ! buffers; the previous call passed a slice of matrixA as both.
    type(MPI_MatDesc), intent(in) :: desc
    complex(8), intent(inout) :: matrixA(desc%nrow,desc%ncol)

    if( MPI%sizeM == 1 ) return

    ! send count/type are ignored with MPI_IN_PLACE
    call MPI_Allgatherv( &
         MPI_IN_PLACE, 0, MPI_DOUBLE_COMPLEX, &
         matrixA, desc%vnele, desc%vsele, &
         MPI_DOUBLE_COMPLEX, MPI%commM, MPI%info )

  end subroutine MPI__Allgather_MatrixM_ASCOT

  subroutine MPI__Allreduce_MatrixE_ASCOT( desc, matrixA )
    implicit none
    ! Sum this rank's local column block of matrixA over the E direction
    ! (MPI%commE); skipped when MPI%sizeE == 1.
    type(MPI_MatDesc), intent(in) :: desc
    complex(8), intent(inout) :: matrixA(desc%nrow,desc%scol:desc%ecol)
    complex(8), allocatable :: block_copy(:,:)
    integer :: nelem

    if( MPI%sizeE /= 1 ) then
       nelem = desc%nrow*(desc%ecol-desc%scol+1)
       allocate( block_copy, source=matrixA )
       call MPI_Allreduce( block_copy, matrixA, nelem, &
            MPI_DOUBLE_COMPLEX, MPI_SUM, MPI%commE, MPI%info )
       deallocate( block_copy )
    end if

  end subroutine MPI__Allreduce_MatrixE_ASCOT

  subroutine MPI__Allreduce_MatrixT_ASCOT( desc, matrixA )
    implicit none
    ! Sum the full nrow x ncol matrix over every rank of MPI_COMM_WORLD;
    ! skipped when there is a single process.
    type(MPI_MatDesc), intent(in) :: desc
    complex(8), intent(inout) :: matrixA(desc%nrow,desc%ncol)
    complex(8), allocatable   :: full_copy(:,:)
    integer :: nelem

    if( MPI%size /= 1 ) then
       nelem = desc%nrow*desc%ncol
       allocate( full_copy, source=matrixA )
       call MPI_Allreduce( full_copy, matrixA, nelem, &
            MPI_DOUBLE_COMPLEX, MPI_SUM, MPI_COMM_WORLD, MPI%info )
       deallocate( full_copy )
    end if

  end subroutine MPI__Allreduce_MatrixT_ASCOT


end module ac_mpi_module
