! @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ !
! @@                                                                @@ !
! @@       PROGRAM  ASCOT 2014.440 (ver.4.53)                       @@ !
! @@     "Abinitio Simulation Code for Quantum Transport"           @@ !
! @@                                                                @@ !
! @@                                                                @@ !
! @@  AUTHOR(S): Naoki WATANABE, Nobutaka NISHIKAWA (Mizuho I.R.)   @@ !
! @@             Hisashi KONDO (Univ. Tokyo)                        @@ !
! @@                                                09/May/2014     @@ !
! @@                                                                @@ !
! @@  Contact address: Phase System Consortium                      @@ !
! @@                                                                @@ !
! @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ !

module ac_mpi_module
  use ac_parameter
  implicit none
  include "mpif.h"

  ! Process-decomposition state shared by every routine of this module
  ! (the single instance is the module variable MPI).
  ! The pointer components are default-initialised to null() so that the
  ! associated() tests in MPI__setup and MPI__Finalize are well defined even
  ! before the tables have been allocated: a pointer without initialisation
  ! has an undefined association status.
  type MPI_type
     integer :: rank2, size2   ! rank / size in MPI_COMM_WORLD
     integer :: rankE, sizeE   ! rank / size within commE (ranks sharing rankM)
     integer :: rankM, sizeM   ! rank / size within commM and the 1 x sizeM BLACS grid
     integer :: rankA, sizeA   ! rank / size used for the column distribution (commA)
     integer :: info           ! status of the most recent MPI/BLACS/ScaLAPACK call
     logical :: root           ! .true. on rank 0 of MPI_COMM_WORLD

     integer :: npao           ! number of matrix columns owned by this rank
     integer :: isatom         ! atom containing this rank's first column
     integer :: ieatom         ! atom containing this rank's last column

     integer,pointer :: ispao(:) => null() ! per atom: first local PAO index in this rank's range
     integer,pointer :: iepao(:) => null() ! per atom: last local PAO index in this rank's range
     integer,pointer :: vnpao(:) => null() ! per rank: number of columns owned

     integer,pointer :: vsmat(:) => null() ! per rank: first global column
     integer,pointer :: vemat(:) => null() ! per rank: last global column
     integer,pointer :: vnmat(:) => null() ! per rank: number of columns

     integer         ::  smat      ! this rank's first global column
     integer         ::  emat      ! this rank's last global column
     integer         ::  nmat      ! this rank's number of columns

     integer,pointer :: vmatcount(:) => null() ! MPI_AlltoAllv counts per rank
     integer,pointer :: vmatdispl(:) => null() ! MPI_AlltoAllv displacements per rank

     integer :: SL_ID          ! BLACS context of the 1 x sizeM grid
     integer :: SL_DESC(9)     ! ScaLAPACK descriptor for npao x npao matrices
     integer :: SL_DESCLS(9)   ! ScaLAPACK descriptor for 2*npao x 2*npao matrices

     integer :: localsize1     ! column block size used in SL_DESC
     integer :: localsize2     ! column block size used in SL_DESCLS (2*localsize1)

     integer :: group2, comm2  ! group of MPI_COMM_WORLD (comm2 is not set in this module's visible code)
     integer :: groupM 
     integer :: groupE 
     integer :: commM  
     integer :: commE  
     integer :: groupA, commA  ! commA is a duplicate of commM

     ! Extra real workspace for PZHEGVX; rescaled in MPI__setup
     ! (> 0: absolute size, [-1,0]: fraction of an upper bound, < -1: none).
     real(8) :: num_sc_work2space = -1000.d0
     ! Eigenvector orthonormality is checked every num_check_sc iterations.
     integer :: num_check_sc = 1

  end type MPI_type

  ! The single shared instance of the decomposition state.
  type(MPI_type), public, save :: MPI 
  ! Accumulated elapsed times in seconds.  In this module, time_la is
  ! incremented by the ScaLAPACK wrappers and time_mpi by the communication
  ! wrappers.  All of these, including the *0 copies, are zeroed in
  ! MPI__Initialize; presumably the others are updated elsewhere.
  real(8) :: time_la, time_fft, time_mpi, time_as, time_total
  real(8) :: time_la0, time_fft0, time_mpi0, time_as0, time_total0
  ! NOTE(review): not referenced in the visible part of this module.
  integer :: time_scount

  ! Layout of a matrix distributed by column blocks over the BLACS grid
  ! MPI%SL_ID.  It is filled by MPI__setupMatDesc and released by
  ! MPI__unsetMatDesc.
  ! NOTE(review): the pointer components have no default initialisation, so
  ! their association status is undefined until MPI__setupMatDesc has run.
  type MPI_MatDesc
     integer :: nrow ! global number of rows
     integer :: ncol ! global number of columns

     integer :: srow ! first local row (MPI__setupMatDesc sets 1: rows are not split)
     integer :: erow ! last local row (MPI__setupMatDesc sets nrow)
     integer :: scol ! first global column stored on this rank
     integer :: ecol ! last global column stored on this rank
     integer, pointer :: vsele(:) ! per rank: 0-based offset of its first element
     integer, pointer :: vnele(:) ! per rank: number of elements

     integer :: SL_DESC(9) ! ScaLAPACK array descriptor
  end type MPI_MatDesc

contains

  ! Fill DESC for an nrow x ncol matrix.  Columns are split into blocks of
  ! ceiling(ncol/sizeM) over the sizeM ranks of the M grid; rows are not
  ! distributed.
  !   desc : [out] descriptor; vsele/vnele are freshly allocated
  !          (release them with MPI__unsetMatDesc)
  !   nrow : global number of rows
  !   ncol : global number of columns
  subroutine MPI__setupMatDesc( desc, nrow, ncol )
    implicit none

    type(MPI_MatDesc), intent(out) :: desc
    integer, intent(in) :: nrow, ncol

    integer :: blk       ! columns per rank (ScaLAPACK block size)
    integer :: per_rank  ! elements in one full column block
    integer :: total     ! elements in the whole matrix
    integer :: r

    desc%nrow = nrow
    desc%ncol = ncol

    blk = int(ceiling(dble(ncol) / dble(MPI%sizeM)))

    ! blk x blk blocks, first block on process (0,0), leading dimension nrow.
    call DescInit( desc%SL_DESC, nrow, ncol, blk, blk, &
         0, 0, MPI%SL_ID, nrow, MPI%info )

    ! Every rank stores all rows.
    desc%srow = 1
    desc%erow = nrow

    ! Column range of this rank, clipped to the matrix.  A rank that lies
    ! entirely past the last column is given the single column ncol.
    desc%scol = MPI%rankM*blk + 1
    desc%ecol = (MPI%rankM+1)*blk
    if( desc%ecol > ncol ) then
       if( desc%scol > ncol ) desc%scol = ncol
       desc%ecol = ncol
    end if

    allocate( desc%vsele(0:MPI%sizeM-1) )
    allocate( desc%vnele(0:MPI%sizeM-1) )

    total    = nrow*ncol
    per_rank = nrow*blk
    do r = 0, MPI%sizeM-1
       desc%vsele(r) = per_rank*r
       desc%vnele(r) = per_rank
       if( desc%vsele(r) >= total ) then
          ! NOTE(review): only the offset is clipped here; vnele keeps the
          ! full block size, so vsele+vnele exceeds total.  Confirm that
          ! callers never use such entries as real buffer extents.
          desc%vsele(r) = total-1
       else if( desc%vsele(r) + desc%vnele(r) > total-1 ) then
          desc%vnele(r) = total - desc%vsele(r)
       end if
    end do

    return
  end subroutine MPI__setupMatDesc

  ! Release the per-rank element tables allocated by MPI__setupMatDesc and
  ! clear the ScaLAPACK descriptor.
  !   desc : [inout] descriptor previously filled by MPI__setupMatDesc
  ! DESC must be intent(inout).  With intent(out), its pointer components
  ! would be undefined on entry, so the associated() tests below would be
  ! invalid and the arrays could leak.
  subroutine MPI__unsetMatDesc( desc )
    implicit none
    type(MPI_MatDesc), intent(inout) :: desc

    if( associated(desc%vsele) ) deallocate( desc%vsele )
    if( associated(desc%vnele) ) deallocate( desc%vnele )

    desc%SL_DESC(:) = 0

    return
  end subroutine MPI__unsetMatDesc


  ! Start MPI and record this process's rank and size in MPI_COMM_WORLD.
  ! Rank 0 is flagged as root.  The split sizes are reset to 1 (sizeE and
  ! sizeM must be set before MPI__setup, which checks sizeE*sizeM), and all
  ! timers are zeroed.
  subroutine MPI__Initialize
    implicit none

    call MPI_Init( MPI%info )
    call MPI_Comm_rank( MPI_COMM_WORLD, MPI%rank2, MPI%info )
    call MPI_Comm_size( MPI_COMM_WORLD, MPI%size2, MPI%info )
    call MPI_Comm_group( MPI_COMM_WORLD, MPI%group2, MPI%info )

    MPI%root = ( MPI%rank2 == 0 )

    ! Default 1 x 1 split.
    MPI%sizeE = 1
    MPI%sizeM = 1
    MPI%sizeA = 1

    ! No BLACS grid yet.
    MPI%SL_ID = 0

    ! Current and saved timers.
    time_la0    = 0.0d0
    time_la     = 0.0d0
    time_fft0   = 0.0d0
    time_fft    = 0.0d0
    time_mpi0   = 0.0d0
    time_mpi    = 0.0d0
    time_as0    = 0.0d0
    time_as     = 0.0d0
    time_total0 = 0.0d0
    time_total  = 0.0d0

    return
  end subroutine MPI__Initialize

  ! Leave the BLACS grid, shut down MPI and free every distribution table
  ! that MPI__setup allocates (including vmatcount/vmatdispl, which were
  ! previously never released).
  subroutine MPI__Finalize
    implicit none

    ! NOTE(review): confirm that a valid BLACS context can never be 0;
    ! if it can, GRIDEXIT is skipped here for that context.
    if( MPI%SL_ID /= 0 ) then
       call BLACS_GRIDEXIT( MPI%SL_ID )
    end if

    call MPI_Finalize ( MPI%info )

    if( associated(MPI%vnmat) ) then
       deallocate( MPI%vnmat, MPI%vsmat, MPI%vemat )
    end if

    if( associated(MPI%ispao) ) then
       deallocate(MPI%ispao,MPI%iepao,MPI%vnpao)
    end if

    ! Allocated together in MPI__setup.
    if( associated(MPI%vmatcount) ) then
       deallocate( MPI%vmatcount, MPI%vmatdispl )
    end if

    return
  end subroutine MPI__Finalize

  ! Build the process decomposition used by the other routines:
  !  * checks that there are no more processes than PAOs and that
  !    sizeE*sizeM equals the number of processes (finalize + stop otherwise);
  !  * splits MPI_COMM_WORLD into commM (ranks sharing rankE, also mapped onto
  !    a 1 x sizeM BLACS grid) and commE (ranks sharing rankM); commA is a
  !    duplicate of commM;
  !  * distributes the npao columns in blocks of localsize1 over sizeA ranks
  !    and records the per-rank ranges and MPI_AlltoAllv count/displacement
  !    tables;
  !  * locates the atoms/PAOs at the edges of this rank's column range;
  !  * converts num_sc_work2space into the extra PZHEGVX real workspace.
  ! NOTE(review): sizeA is not set here, while rankA and commA come from the
  ! M split; presumably sizeA must equal sizeM.  Confirm where it is assigned.
  subroutine MPI__setup
    implicit none
    integer :: p, a, i, n
    integer, allocatable :: rankM(:), rankE(:)
    real(8) :: work2_max   ! upper bound for the extra PZHEGVX workspace

    if( Param%Data%npao < MPI%size2 ) then
       if( MPI%root ) then
          write(*,'(a)') "# Error : too many processes for MPI"
          write(*,*) "# number of processes exceeds number of PAO (npao, np) ", &
               Param%Data%npao, MPI%size2
       end if
       call MPI__Finalize
       stop
    end if

    if( MPI%sizeE*MPI%sizeM /= MPI%size2 ) then
       if( MPI%root ) then
          write(*,'(a)') "# Error : not appropriate number of parallelization."
          write(*,*) "# -ne", MPI%sizeE,  " -nm", MPI%sizeM
          write(*,*) "# ne*nm is not equal to number of processes np", &
               MPI%sizeE*MPI%sizeM, MPI%size2
       end if
       call MPI__Finalize
       stop
    end if

    ! (Re)allocate the distribution tables.
    if( associated(MPI%vmatcount) ) then
       deallocate( MPI%vmatcount, MPI%vmatdispl )
    end if
    allocate( MPI%vmatcount(0:MPI%sizeA-1), MPI%vmatdispl(0:MPI%sizeA-1) )

    if( associated(MPI%vnmat) ) then
       deallocate( MPI%vnmat, MPI%vsmat, MPI%vemat )
    end if
    allocate( MPI%vnmat(0:MPI%sizeA-1), MPI%vsmat(0:MPI%sizeA-1), MPI%vemat(0:MPI%sizeA-1) )

    if( associated(MPI%vnpao) ) then
       deallocate(MPI%vnpao, MPI%ispao, MPI%iepao)
    end if
    allocate( MPI%vnpao(0:MPI%sizeA-1), MPI%ispao(Param%Data%natom), MPI%iepao(Param%Data%natom) )

    ! Position in the sizeE x sizeM process grid (rank2 = rankE*sizeM + rankM).
    MPI%rankE = MPI%rank2/MPI%sizeM
    MPI%rankM = mod(MPI%rank2,MPI%sizeM)

    allocate( rankM(0:MPI%sizeM-1), rankE(0:MPI%sizeE-1) )

    do p=0, MPI%sizeM-1
       rankM(p) = p + MPI%rankE*MPI%sizeM
    end do
    do p=0, MPI%sizeE-1
       rankE(p) = p*MPI%sizeM + MPI%rankM
    end do

    call BLACS_GET( 0, 0, MPI%SL_ID )
    call BLACS_GRIDMAP( MPI%SL_ID, rankM, 1, 1, MPI%sizeM )

    call MPI_Group_incl( MPI%group2, MPI%sizeM, rankM, MPI%groupM, MPI%info )
    call MPI_Comm_create( MPI_COMM_WORLD, MPI%groupM, MPI%commM, MPI%info )

    call MPI_Group_incl( MPI%group2, MPI%sizeE, rankE, MPI%groupE, MPI%info )
    call MPI_Comm_create( MPI_COMM_WORLD, MPI%groupE, MPI%commE, MPI%info )

    deallocate(rankM,rankE)

    MPI%rankA = mod(MPI%rank2,MPI%sizeM)
    call mpi_comm_dup(MPI%commM,MPI%commA,MPI%info)

    MPI%localsize1 = int(ceiling(dble(Param%Data%npao) / dble(MPI%sizeA)))
    MPI%localsize2 = 2*int(ceiling(dble(Param%Data%npao) / dble(MPI%sizeA)))

    ! The last rank would receive no column.  The condition is identical on
    ! every rank, so only root writes the message.
    if( MPI%localsize1*(MPI%sizeA-1) >= Param%Data%npao ) then
       if( MPI%root ) then
          open(unit=16,file=Param%Option%file_ac_tempout,position='append')
          write(16,999)  MPI%sizeA
          close(16)
       end if
       call MPI__Finalize
       stop
    end if
999 format("      ++++++ Sorry : This program can not be parallelized by processors of ",i5)

    call DescInit( MPI%SL_DESC, Param%Data%npao, Param%Data%npao, &
         MPI%localsize1, MPI%localsize1, 0, 0, MPI%SL_ID, Param%Data%npao, MPI%info )
    call DescInit( MPI%SL_DESCLS, 2*Param%Data%npao, 2*Param%Data%npao, &
         MPI%localsize2, MPI%localsize2, 0, 0, MPI%SL_ID, 2*Param%Data%npao, MPI%info )

    ! Columns per rank: full blocks of localsize1, remainder on the last.
    n=0
    do p=0, MPI%sizeA-1
       if( n + MPI%localsize1 > Param%Data%npao ) then
          MPI%vnpao(p) = Param%Data%npao - n
       else
          MPI%vnpao(p) = MPI%localsize1
       end if
       n = n + MPI%vnpao(p)
    end do
    MPI%npao = MPI%vnpao(MPI%rankA)

    i=0
    do p=0, MPI%sizeA-1
       MPI%vsmat(p) = i+1
       MPI%vemat(p) = MPI%vsmat(p) + MPI%vnpao(p) - 1
       MPI%vnmat(p) = MPI%vnpao(p)
       i = i + MPI%vnpao(p)
    end do
    MPI%smat = MPI%vsmat(MPI%rankA)
    MPI%emat = MPI%vemat(MPI%rankA)
    MPI%nmat = MPI%vnmat(MPI%rankA)

    ! AlltoAllv: this rank sends nmat columns x vnmat(p) rows to rank p.
    do p=0, MPI%sizeA-1
       MPI%vmatcount(p) = MPI%nmat * MPI%vnmat(p)
       MPI%vmatdispl(p) = MPI%nmat * MPI%vsmat(p) - MPI%nmat * MPI%vsmat(0)
    end do

    do a=1, Param%Data%natom
       MPI%ispao(a) = 0+1
       MPI%iepao(a) = Param%Data%vnpao(a)-1+1
    end do

    ! Atom holding this rank's first column and the local PAO index there.
    n=0
    do a=1, Param%Data%natom
       n=n+Param%Data%vnpao(a)
       if( MPI%vsmat(MPI%rankA)-1 < n ) then
          MPI%isatom = a
          MPI%ispao(a) = (MPI%vsmat(MPI%rankA)-1) - (Param%Data%vipao(a)-1) + 1
          exit
       end if
    end do

    ! Atom holding this rank's last column and the local PAO index there.
    n=0
    do a=1, Param%Data%natom
       n=n+Param%Data%vnpao(a)
       if( MPI%vemat(MPI%rankA)-1 < n ) then
          MPI%ieatom = a
          MPI%iepao(a) = (MPI%vemat(MPI%rankA)-1) - (Param%Data%vipao(a)-1) + 1
          exit
       end if
    end do

    if( MPI%num_sc_work2space >= -1.d0 ) then
       ! Upper bound (N/sqrt(np) - 1)*N with N = npao (2*npao with spin-orbit).
       ! It is evaluated in real(8) and truncated with aint(), so a large N
       ! cannot overflow a default integer.  The result equals the former
       ! integer truncation whenever that did not overflow.
       if( Param%Option%spin_orbit ) then
          work2_max = aint( (2*Base%npao/sqrt(dble(MPI%size2))-1.d0)*2*Base%npao )
       else
          work2_max = aint( (Base%npao/sqrt(dble(MPI%size2))-1.d0)*Base%npao )
       end if
       if( MPI%num_sc_work2space > 0.d0 ) then
          ! Absolute size, capped at the bound.
          MPI%num_sc_work2space = min( MPI%num_sc_work2space, work2_max )
       else
          ! Value in [-1,0]: fraction of the bound.
          MPI%num_sc_work2space = min( -work2_max * MPI%num_sc_work2space, work2_max )
       end if
    else
       MPI%num_sc_work2space=0.d0
    end if

    return
  end subroutine MPI__setup

  ! Broadcast the 64-character input-file name from rank 0 of MPI_COMM_WORLD
  ! to all ranks.  Does nothing when running on a single process.
  subroutine MPI__Bcast_Inputfile( filename )
    implicit none
    character(64), intent(inout) :: filename

    if( MPI%size2 /= 1 ) then
       call MPI_Bcast( filename, len(filename), MPI_CHARACTER, &
            0, MPI_COMM_WORLD, MPI%info )
    end if

    return
  end subroutine MPI__Bcast_Inputfile

  ! Sum the complex density grid over the ranks of commA, leaving the total
  ! on every rank.  The reduction is done in place (MPI_IN_PLACE), so no
  ! second grid-sized buffer is allocated or copied.  Does nothing (and
  ! records no time) when sizeA == 1; otherwise the elapsed time is added to
  ! time_mpi.
  subroutine MPI__Allreduce_DensityLS( densityLS )
    implicit none
    complex(8), intent(inout) :: densityLS(Param%Option%nspin,Param%Cell%Na,Param%Cell%Nb,Param%Cell%Nc)

    integer scount, ecount, rate, count_max

    call system_clock(scount, rate, count_max )

    if( MPI%sizeA == 1 ) return

    call MPI_Allreduce( &
         MPI_IN_PLACE, densityLS, size(densityLS), &
         MPI_DOUBLE_COMPLEX, &
         MPI_SUM, &
         MPI%commA, MPI%info )

    call system_clock(ecount, rate, count_max )
    time_mpi = time_mpi + real(ecount-scount)/rate

    return
  end subroutine MPI__Allreduce_DensityLS

  ! Sum the real density grid over the ranks of commA, leaving the total on
  ! every rank.  The reduction is done in place (MPI_IN_PLACE), so no second
  ! grid-sized buffer is allocated or copied.  Does nothing (and records no
  ! time) when sizeA == 1; otherwise the elapsed time is added to time_mpi.
  subroutine MPI__Allreduce_Density( density )
    implicit none
    real(8), intent(inout) :: density (Param%Option%nspin,Param%Cell%Na,Param%Cell%Nb,Param%Cell%Nc)

    integer scount, ecount, rate, count_max

    call system_clock(scount, rate, count_max )

    if( MPI%sizeA == 1 ) return

    call MPI_Allreduce( &
         MPI_IN_PLACE, density, size(density), &
         MPI_DOUBLE_PRECISION, &
         MPI_SUM, &
         MPI%commA, MPI%info )

    call system_clock(ecount, rate, count_max )
    time_mpi = time_mpi + real(ecount-scount)/rate

    return
  end subroutine MPI__Allreduce_Density

  ! Replace a square npao x npao matrix, distributed by column blocks (this
  ! rank holds columns smat:emat), by
  !     A(r,c) <- conj(A(r,c)) + A(c,r)
  ! for every stored element; the result is Hermitian.
  ! On a single rank the update is done in place.  Otherwise the transposed
  ! blocks are exchanged with MPI_AlltoAllv over commA, using the
  ! vmatcount/vmatdispl tables.  Elapsed time is added to time_mpi in the
  ! multi-rank case only.
  subroutine MPI__AlltoAll_Matrix( matrix )
    implicit none
    complex(8), intent(inout) :: matrix(Param%Data%npao,MPI%smat:MPI%emat)
    complex(8), allocatable :: sendbuf(:)
    complex(8), allocatable :: recvbuf(:)
    complex(8) :: upper
    integer :: row, col, q, k

    integer scount, ecount, rate, count_max

    call system_clock(scount, rate, count_max )

    single_rank: if( MPI%sizeA == 1 ) then
       ! Each pair (row,col)/(col,row) with row <= col is visited exactly once,
       ! so the traversal order does not affect the result.
       do col=1, Param%Data%npao
          do row=1, col
             upper = dconjg(matrix(row,col)) + matrix(col,row)
             matrix(row,col) = upper
             matrix(col,row) = dconjg(upper)
          end do
       end do
       return
    end if single_rank

    allocate( sendbuf(MPI%nmat*Param%Data%npao), recvbuf(MPI%nmat*Param%Data%npao) )

    ! Pack: for each destination q, my columns (outer) x q's row range (inner).
    k=0
    do q=0, MPI%sizeA-1
       do col=MPI%smat, MPI%emat
          do row=MPI%vsmat(q), MPI%vemat(q)
             k=k+1
             sendbuf(k) = matrix(row,col)
          end do
       end do
    end do

    call MPI_AlltoAllv( &
         sendbuf, MPI%vmatcount, MPI%vmatdispl, MPI_DOUBLE_COMPLEX, &
         recvbuf, MPI%vmatcount, MPI%vmatdispl, MPI_DOUBLE_COMPLEX, &
         MPI%commA, MPI%info )

    ! Unpack: the block from q arrives transposed (q's columns outer).
    k=0
    do q=0, MPI%sizeA-1
       do row=MPI%vsmat(q), MPI%vemat(q)
          do col=MPI%smat, MPI%emat
             k=k+1
             matrix(row,col) = dconjg(matrix(row,col)) + recvbuf(k)
          end do
       end do
    end do

    deallocate( sendbuf, recvbuf )

    call system_clock(ecount, rate, count_max )
    time_mpi = time_mpi + real(ecount-scount)/rate

    return
  end subroutine MPI__AlltoAll_Matrix

  ! Coupled conjugate-transpose update of two npao x npao matrices that are
  ! distributed by column blocks (this rank holds columns smat:emat):
  !     matrix2(r,c) <- conj(matrix2(r,c)) + matrix3(c,r)
  !     matrix3(r,c) <- conj(matrix3(r,c)) + matrix2_old(c,r)
  ! so that on exit matrix3 = conjg(transpose(matrix2)).
  ! The multi-rank path depends on statement order: matrix2 is packed BEFORE
  ! it is updated, and pack_send/pack_recv are reused for both exchanges.
  ! Elapsed time is added to time_mpi in the multi-rank case only.
  subroutine MPI__AlltoAll_Matrix2( matrix2, matrix3 )
    implicit none
    complex(8), intent(inout) :: matrix2(Param%Data%npao,MPI%smat:MPI%emat)
    complex(8), intent(inout) :: matrix3(Param%Data%npao,MPI%smat:MPI%emat)
    complex(8), allocatable :: pack_send(:)
    complex(8), allocatable :: pack_recv(:)
    integer :: i, j, p, index

    integer scount, ecount, rate, count_max

    call system_clock(scount, rate, count_max )

    ! Single rank: each (i,j) touches only matrix2(i,j) and matrix3(j,i),
    ! so the iterations are independent.
    if( MPI%sizeA == 1 ) then
       do i=1, Param%Data%npao
          do j=1, Param%Data%npao
             matrix2(i,j) = dconjg(matrix2(i,j)) + matrix3(j,i)
             matrix3(j,i) = dconjg(matrix2(i,j))
          end do
       end do

       return
    end if

    allocate(pack_send(MPI%nmat*Param%Data%npao))
    allocate(pack_recv(MPI%nmat*Param%Data%npao))

    ! Pack matrix3: for each destination p, my columns (outer) x p's rows (inner).
    index=0
    do p=0, MPI%sizeA-1
       do i=MPI%smat, MPI%emat
          do j=MPI%vsmat(p), MPI%vemat(p)
             index=index+1
             pack_send(index) = matrix3(j,i)
          end do
       end do
    end do

    ! pack_recv now holds the transposed matrix3 blocks.
    call MPI_AlltoAllv( &
         pack_send, MPI%vmatcount, MPI%vmatdispl, MPI_DOUBLE_COMPLEX, &
         pack_recv, MPI%vmatcount, MPI%vmatdispl, MPI_DOUBLE_COMPLEX, &
         MPI%commA, MPI%info )

    ! Pack the ORIGINAL matrix2 (before the update below) for the second exchange.
    index=0
    do p=0, MPI%sizeA-1
       do i=MPI%smat, MPI%emat
          do j=MPI%vsmat(p), MPI%vemat(p)
             index=index+1
             pack_send(index) = matrix2(j,i)
          end do
       end do
    end do

    ! matrix2(j,i) <- conj(matrix2(j,i)) + matrix3(i,j)
    index=0
    do p=0, MPI%sizeA-1
       do j=MPI%vsmat(p), MPI%vemat(p)
          do i=MPI%smat, MPI%emat
             index=index+1
             matrix2(j,i) = dconjg(matrix2(j,i)) + pack_recv(index)
          end do
       end do
    end do

    ! pack_recv now holds the transposed original matrix2 blocks.
    call MPI_AlltoAllv( &
         pack_send, MPI%vmatcount, MPI%vmatdispl, MPI_DOUBLE_COMPLEX, &
         pack_recv, MPI%vmatcount, MPI%vmatdispl, MPI_DOUBLE_COMPLEX, &
         MPI%commA, MPI%info )


    ! matrix3(j,i) <- conj(matrix3(j,i)) + matrix2_old(i,j)
    index=0
    do p=0, MPI%sizeA-1
       do j=MPI%vsmat(p), MPI%vemat(p)
          do i=MPI%smat, MPI%emat
             index=index+1
             matrix3(j,i) = dconjg(matrix3(j,i)) + pack_recv(index)
          end do
       end do
    end do

    deallocate( pack_send, pack_recv )

    call system_clock(ecount, rate, count_max )
    time_mpi = time_mpi + real(ecount-scount)/rate

    return
  end subroutine MPI__AlltoAll_Matrix2

  ! Sum each atom's 3-component force over the ranks of commA, so every rank
  ! holds the totals.  Does nothing when sizeA == 1; otherwise the elapsed
  ! time is added to time_mpi.
  subroutine MPI__Allreduce_Force( vatom )
    implicit none
    type(Atom_type), intent(inout) :: vatom(Param%Data%natom)
    real(8), allocatable :: partial(:,:), total(:,:)
    integer :: ia

    integer scount, ecount, rate, count_max

    call system_clock(scount, rate, count_max )

    if( MPI%sizeA /= 1 ) then
       allocate( partial(3,Param%Data%natom) )
       allocate( total(3,Param%Data%natom) )

       ! Gather into a contiguous 3 x natom buffer.
       do ia=1, Param%Data%natom
          partial(:,ia) = vatom(ia)%force(:)
       end do

       call MPI_Allreduce( partial, total, size(partial), MPI_DOUBLE_PRECISION, &
            MPI_SUM, MPI%commA, MPI%info )

       do ia=1, Param%Data%natom
          vatom(ia)%force(:) = total(:,ia)
       end do

       deallocate( partial )
       deallocate( total )

       call system_clock(ecount, rate, count_max )
       time_mpi = time_mpi + real(ecount-scount)/rate
    end if

    return
  end subroutine MPI__Allreduce_Force

  ! Sum the mix_history x mix_history history matrix A over the ranks of
  ! commA.  Does nothing when sizeA == 1; otherwise the elapsed time is
  ! added to time_mpi.
  subroutine MPI__Allreduce_HistoryPulay( A )
    implicit none
    real(8), intent(inout) :: A(Param%SCF%mix_history,Param%SCF%mix_history)
    real(8), allocatable :: local_A(:,:)

    integer scount, ecount, rate, count_max

    call system_clock(scount, rate, count_max )

    if( MPI%sizeA /= 1 ) then
       allocate( local_A(Param%SCF%mix_history,Param%SCF%mix_history) )
       local_A(:,:) = A(:,:)

       call MPI_Allreduce( local_A, A, size(A), &
            MPI_DOUBLE_PRECISION, MPI_SUM, MPI%commA, MPI%info )

       deallocate( local_A )

       call system_clock(ecount, rate, count_max )
       time_mpi = time_mpi + real(ecount-scount)/rate
    end if

    return
  end subroutine MPI__Allreduce_HistoryPulay

  ! Sum the 2 x 3 Anderson history array A(2:3,3) over the ranks of commA.
  ! Does nothing when sizeA == 1; otherwise the elapsed time is added to
  ! time_mpi.
  subroutine MPI__Allreduce_HistoryAnderson( A )
    implicit none
    real(8), intent(inout) :: A(2:3,3)
    real(8) :: contrib(2:3,3)

    integer scount, ecount, rate, count_max

    call system_clock(scount, rate, count_max )

    if( MPI%sizeA /= 1 ) then
       contrib(:,:) = A(:,:)
       call MPI_Allreduce( contrib, A, size(contrib), &
            MPI_DOUBLE_PRECISION, MPI_SUM, MPI%commA, MPI%info )

       call system_clock(ecount, rate, count_max )
       time_mpi = time_mpi + real(ecount-scount)/rate
    end if

    return
  end subroutine MPI__Allreduce_HistoryAnderson

  ! Replace dEden by its sum over the ranks of commA.  Does nothing when
  ! sizeA == 1; otherwise the elapsed time is added to time_mpi.
  subroutine MPI__Allreduce_Error( dEden )
    implicit none
    real(8), intent(inout) :: dEden
    real(8) :: mine

    integer scount, ecount, rate, count_max

    call system_clock(scount, rate, count_max )

    if( MPI%sizeA /= 1 ) then
       mine = dEden
       call MPI_Allreduce( mine, dEden, 1, MPI_DOUBLE_PRECISION, MPI_SUM, &
            MPI%commA, MPI%info )

       call system_clock(ecount, rate, count_max )
       time_mpi = time_mpi + real(ecount-scount)/rate
    end if

    return
  end subroutine MPI__Allreduce_Error

  ! Replace dEden by its maximum over the ranks of commA.  Does nothing when
  ! sizeA == 1; otherwise the elapsed time is added to time_mpi.
  subroutine MPI__Allreduce_Error_Max( dEden )
    implicit none
    real(8), intent(inout) :: dEden
    real(8) :: mine

    integer scount, ecount, rate, count_max

    call system_clock(scount, rate, count_max )

    if( MPI%sizeA /= 1 ) then
       mine = dEden
       call MPI_Allreduce( mine, dEden, 1, MPI_DOUBLE_PRECISION, MPI_MAX, &
            MPI%commA, MPI%info )

       call system_clock(ecount, rate, count_max )
       time_mpi = time_mpi + real(ecount-scount)/rate
    end if

    return
  end subroutine MPI__Allreduce_Error_Max

  ! Synchronise all processes of MPI_COMM_WORLD.  Skipped on a single
  ! process; otherwise the waiting time is added to time_mpi.
  subroutine MPI__Barrier
    implicit none

    integer scount, ecount, rate, count_max

    call system_clock(scount, rate, count_max )

    if( MPI%size2 /= 1 ) then
       call MPI_Barrier(MPI_COMM_WORLD, MPI%info)

       call system_clock(ecount, rate, count_max )
       time_mpi = time_mpi + real(ecount-scount)/rate
    end if

    return
  end subroutine MPI__Barrier

  ! Solve the generalized Hermitian eigenproblem H x = e S x with ScaLAPACK
  ! PZHEGVX (IBTYPE=1, all eigenpairs, lower triangles referenced).  The
  ! matrices are size x size, distributed with MPI%SL_DESC (this rank holds
  ! columns smat:emat).
  !   size    : global matrix order
  !   matrixH : [inout] H on entry, the eigenvectors (local columns) on exit
  !   matrixS : [in]    overlap matrix S
  !   vectorE : [out]   eigenvalues
  !   iter    : iteration counter; selects when Q^H S Q = 1 is checked
  ! The real workspace from the PZHEGVX size query is enlarged by
  ! MPI%num_sc_work2space.  The orthonormality check runs at iter == 1 and
  ! whenever mod(iter,num_check_sc) == 0.  num_check_sc == 0 disables the
  ! periodic part instead of dividing by zero.  Deviations above 1e-10 are
  ! written to Param%Option%file_ac_tempout.  Unlike MPI__ZHEGVLS, this
  ! routine only reports them and continues.
  ! NOTE(review): the INFO returned by PZHEGVX is not inspected.
  subroutine MPI__ZHEGV( size, matrixH, matrixS, vectorE, iter )
    implicit none

    integer, intent(in)       :: size
    complex(8), intent(inout) :: matrixH(size,MPI%smat:MPI%emat)
    complex(8), intent(in)    :: matrixS(size,MPI%smat:MPI%emat)
    real(8),    intent(out)   :: vectorE(size)

    integer, intent(in) :: iter

    real(8) :: VL, VU   ! not referenced for RANGE='A'
    integer :: IL, IU   ! not referenced for RANGE='A'
    integer :: M,  NZ 
    real(8), parameter :: ABSTOL = -1.0d0
    real(8), parameter :: ORFAC  = -1.0d0
    complex(8), parameter :: cone  = (1.0d0,0.0d0)
    complex(8), parameter :: czero = (0.0d0,0.0d0)
    integer, allocatable :: icluster(:)
    real(8), allocatable :: gap(:)
    integer, allocatable :: ifail(:)

    complex(8), allocatable :: tempH(:,:)
    complex(8), allocatable :: tempS(:,:)
    complex(8), allocatable :: tempQ(:,:)
    complex(8), allocatable :: work1(:)
    real(8),    allocatable :: work2(:)
    integer,    allocatable :: work3(:)
    integer                 :: lwork1
    integer                 :: lwork2
    integer                 :: lwork3
    logical                 :: check_now

    integer :: i, j

    integer scount, ecount, rate, count_max

    call system_clock(scount, rate, count_max )

    ! Local work copies padded to the full block width localsize1.
    allocate( tempH(size,MPI%smat:MPI%smat+MPI%localsize1-1) )
    allocate( tempS(size,MPI%smat:MPI%smat+MPI%localsize1-1) )
    allocate( tempQ(size,MPI%smat:MPI%smat+MPI%localsize1-1) )
    tempH(:,:)=0.d0
    tempS(:,:)=0.d0
    tempQ(:,:)=0.d0
    do i=MPI%smat, MPI%emat
       do j=1, size
          tempH(j,i) = matrixH(j,i)
          tempS(j,i) = matrixS(j,i)
       end do
    end do

    allocate( icluster(2*MPI%sizeA), gap(MPI%sizeA), ifail(size) )

    ! Workspace size query (lwork = -1).
    allocate( work1(1), work2(3), work3(1) )
    lwork1 = -1
    lwork2 = -1
    lwork3 = -1

    call PZHEGVX( 1, 'V', 'A', 'L', size, &
         tempH, 1, 1, MPI%SL_DESC, &
         tempS, 1, 1, MPI%SL_DESC, &
         VL, VU, IL, IU, ABSTOL, M, NZ, &
         vectorE, ORFAC, &
         tempQ, 1, 1, MPI%SL_DESC, &
         work1, lwork1, work2, lwork2, work3, lwork3, &
         ifail, icluster, gap, MPI%info )

    lwork1 = int(work1(1)) 
    lwork2 = int(work2(1)) + int(MPI%num_sc_work2space) 
    lwork3 = int(work3(1)) 

    deallocate( work1, work2, work3 )
    allocate( work1(lwork1), work2(lwork2), work3(lwork3) )

    call PZHEGVX( 1, 'V', 'A', 'L', size, &
         tempH, 1, 1, MPI%SL_DESC, &
         tempS, 1, 1, MPI%SL_DESC, &
         VL, VU, IL, IU, ABSTOL, M, NZ, &
         vectorE, ORFAC, &
         tempQ, 1, 1, MPI%SL_DESC, &
         work1, lwork1, work2, lwork2, work3, lwork3, &
         ifail, icluster, gap, MPI%info )

    do i=MPI%smat, MPI%emat
       do j=1, size
          matrixH(j,i)=tempQ(j,i)
       end do
    end do

    deallocate( work1, work2, work3 )
    deallocate( icluster, gap, ifail )

    ! Fortran does not short-circuit .and., so guard mod() separately.
    check_now = ( iter == 1 )
    if( MPI%num_check_sc /= 0 ) then
       if( mod(iter,MPI%num_check_sc) == 0 ) check_now = .true.
    end if

    if( check_now ) then
       if( iter /= 0 ) then
          open(unit=16,file=Param%Option%file_ac_tempout,position='append')
          write(16,*) '                               +++++ check: ScaLapack(PZHEGVX)'
          close(16)
       end if

       ! tempH <- S; tempS <- Q^H S; tempH <- (Q^H S) Q, which should be 1.
       do i=MPI%smat, MPI%emat
          do j=1, size
             tempH(j,i)=matrixS(j,i)
          end do
       end do
       call MPI__ZGEMM('C', 'N', size, cone, tempQ, tempH, czero, tempS )
       call MPI__ZGEMM('N', 'N', size, cone, tempS, tempQ, czero, tempH )
       do i=MPI%smat, MPI%emat
          do j=1, size
             if( i == j ) then
                if( abs(tempH(j,i)-1.d0) > 1.d-10 ) then
                   open(unit=16,file=Param%Option%file_ac_tempout,position='append')
                   write(16,999) i,j,tempH(j,i)
                   write(16,998)
                   close(16)
                end if
             else
                if( abs(tempH(j,i)) > 1.d-10 ) then
                   open(unit=16,file=Param%Option%file_ac_tempout,position='append')
                   write(16,999) i,j,tempH(j,i)
                   write(16,998)
                   close(16)
                end if
             end if
          end do
       end do
999    format('error: MPI__ZHEGV',2i5,2e20.10,'  -> stop')
998    format('       ====> increase: //mpi_condition/num_sc_work2space//')
    end if

    deallocate( tempH, tempS, tempQ )

    call system_clock(ecount, rate, count_max )
    time_la = time_la + real(ecount-scount)/rate

    return
  end subroutine MPI__ZHEGV

  ! Solve the generalized Hermitian eigenproblem H x = e S x of order 2*size
  ! with ScaLAPACK PZHEGVX (IBTYPE=1, all eigenpairs, lower triangles
  ! referenced).  The matrices are distributed with MPI%SL_DESCLS (this rank
  ! holds columns 2*smat-1:2*emat).
  !   size    : half of the global matrix order
  !   matrixH : [inout] H on entry, the eigenvectors (local columns) on exit
  !   matrixS : [in]    overlap matrix S
  !   vectorE : [out]   the 2*size eigenvalues
  !   iter    : iteration counter; selects when Q^H S Q = 1 is checked
  ! The check runs at iter == 1 and whenever mod(iter,num_check_sc) == 0
  ! (num_check_sc == 0 disables the periodic part).  On the first deviation
  ! above 1e-10, the element is written to Param%Option%file_ac_tempout and
  ! the whole job is aborted with MPI_Abort.  A plain STOP on one rank could
  ! leave the other ranks blocked in their next collective call.
  ! NOTE(review): the INFO returned by PZHEGVX is not inspected.
  subroutine MPI__ZHEGVLS( size, matrixH, matrixS, vectorE, iter )
    implicit none

    integer, intent(in)       :: size 
    complex(8), intent(inout) :: matrixH(2*size,2*MPI%smat-1:2*MPI%emat-0) 
    complex(8), intent(in)    :: matrixS(2*size,2*MPI%smat-1:2*MPI%emat-0) 
    real(8),    intent(out)   :: vectorE(2*size) 

    integer, intent(in) :: iter

    real(8) :: VL, VU   ! not referenced for RANGE='A'
    integer :: IL, IU   ! not referenced for RANGE='A'
    integer :: M,  NZ 
    real(8), parameter :: ABSTOL = -1.0d0
    real(8), parameter :: ORFAC  = -1.0d0
    complex(8), parameter :: cone  = (1.0d0,0.0d0)
    complex(8), parameter :: czero = (0.0d0,0.0d0)
    integer, allocatable :: icluster(:)
    real(8), allocatable :: gap(:)
    integer, allocatable :: ifail(:)

    complex(8), allocatable :: tempH(:,:)
    complex(8), allocatable :: tempS(:,:)
    complex(8), allocatable :: tempQ(:,:)
    complex(8), allocatable :: work1(:)
    real(8),    allocatable :: work2(:)
    integer,    allocatable :: work3(:)
    integer                 :: lwork1
    integer                 :: lwork2
    integer                 :: lwork3
    logical                 :: check_now

    integer :: i,j

    integer scount, ecount, rate, count_max

    call system_clock(scount, rate, count_max )

    ! Local work copies padded to the full block width localsize2.
    allocate( tempH(2*size,2*MPI%smat-1:2*MPI%smat-1+MPI%localsize2-1) )
    allocate( tempS(2*size,2*MPI%smat-1:2*MPI%smat-1+MPI%localsize2-1) )
    allocate( tempQ(2*size,2*MPI%smat-1:2*MPI%smat-1+MPI%localsize2-1) )
    tempH(:,:)=0.d0
    tempS(:,:)=0.d0
    tempQ(:,:)=0.d0
    do i=2*MPI%smat-1,2*MPI%emat-0
       do j=1, 2*size
          tempH(j,i) = matrixH(j,i)
          tempS(j,i) = matrixS(j,i)
       end do
    end do

    allocate( icluster(2*MPI%sizeA), gap(MPI%sizeA), ifail(2*size) )

    ! Workspace size query (lwork = -1).
    allocate( work1(1), work2(3), work3(1) )
    lwork1 = -1
    lwork2 = -1
    lwork3 = -1

    call PZHEGVX( 1, 'V', 'A', 'L', 2*size, &
         tempH, 1, 1, MPI%SL_DESCLS, &
         tempS, 1, 1, MPI%SL_DESCLS, &
         VL, VU, IL, IU, ABSTOL, M, NZ, &
         vectorE, ORFAC, &
         tempQ, 1, 1, MPI%SL_DESCLS, &
         work1, lwork1, work2, lwork2, work3, lwork3, &
         ifail, icluster, gap, MPI%info )

    lwork1 = int(work1(1)) 
    lwork2 = int(work2(1)) + int(MPI%num_sc_work2space) 
    lwork3 = int(work3(1)) 

    deallocate( work1, work2, work3 )
    allocate( work1(lwork1), work2(lwork2), work3(lwork3) )

    call PZHEGVX( 1, 'V', 'A', 'L', 2*size, &
         tempH, 1, 1, MPI%SL_DESCLS, &
         tempS, 1, 1, MPI%SL_DESCLS, &
         VL, VU, IL, IU, ABSTOL, M, NZ, &
         vectorE, ORFAC, &
         tempQ, 1, 1, MPI%SL_DESCLS, &
         work1, lwork1, work2, lwork2, work3, lwork3, &
         ifail, icluster, gap, MPI%info )

    do i=2*MPI%smat-1,2*MPI%emat-0
       do j=1, 2*size
          matrixH(j,i)=tempQ(j,i)
       end do
    end do

    deallocate( work1, work2, work3 )
    deallocate( icluster, gap, ifail )

    ! Fortran does not short-circuit .and., so guard mod() separately.
    check_now = ( iter == 1 )
    if( MPI%num_check_sc /= 0 ) then
       if( mod(iter,MPI%num_check_sc) == 0 ) check_now = .true.
    end if

    if( check_now ) then
       if( iter /= 0 ) then
          open(unit=16,file=Param%Option%file_ac_tempout,position='append')
          write(16,*) '                               +++++ check: ScaLapack(PZHEGVX)'
          close(16)
       end if

       ! tempH <- S; tempS <- Q^H S; tempH <- (Q^H S) Q, which should be 1.
       do i=2*MPI%smat-1,2*MPI%emat-0
          do j=1, 2*size
             tempH(j,i)=matrixS(j,i)
          end do
       end do
       call MPI__ZGEMMLS('C', 'N', size, cone, tempQ, tempH, czero, tempS )
       call MPI__ZGEMMLS('N', 'N', size, cone, tempS, tempQ, czero, tempH )
       do i=2*MPI%smat-1,2*MPI%emat-0
          do j=1, 2*size
             if( i == j ) then
                if( abs(tempH(j,i)-1.d0) > 1.d-10 ) then
                   open(unit=16,file=Param%Option%file_ac_tempout,position='append')
                   write(16,999) i,j,tempH(j,i)
                   write(16,998)
                   close(16)
                   call MPI_Abort( MPI_COMM_WORLD, 1, MPI%info )
                   stop
                end if
             else
                if( abs(tempH(j,i)) > 1.d-10 ) then
                   open(unit=16,file=Param%Option%file_ac_tempout,position='append')
                   write(16,999) i,j,tempH(j,i)
                   write(16,998)
                   close(16)
                   call MPI_Abort( MPI_COMM_WORLD, 1, MPI%info )
                   stop
                end if
             end if
          end do
       end do
999    format('error: MPI__ZHEGVLS',2i5,2e20.10,'  -> stop')
998    format('       ====> increase: //mpi_condition/num_sc_work2space//')
    end if

    deallocate( tempH, tempS, tempQ )

    call system_clock(ecount, rate, count_max )
    time_la = time_la + real(ecount-scount)/rate

    return
  end subroutine MPI__ZHEGVLS

  ! C <- alpha*op(A)*op(B) + beta*C for size x size matrices distributed with
  ! MPI%SL_DESC (local columns smat:smat+localsize1-1), via PZGEMM.
  ! A and B are only read, so they are intent(in).  C stays intent(inout)
  ! because PZGEMM reads it when beta /= 0.  Elapsed time is added to time_la.
  subroutine MPI__ZGEMM( transa, transb, size, &
       alpha, matrixA, matrixB, beta, matrixC )
    implicit none

    character, intent(in)     :: transa, transb
    integer, intent(in)       :: size
    complex(8), intent(in)    :: matrixA(size,MPI%smat:MPI%smat+MPI%localsize1-1)
    complex(8), intent(in)    :: matrixB(size,MPI%smat:MPI%smat+MPI%localsize1-1)
    complex(8), intent(inout) :: matrixC(size,MPI%smat:MPI%smat+MPI%localsize1-1)
    complex(8), intent(in)    :: alpha, beta

    integer scount, ecount, rate, count_max

    call system_clock(scount, rate, count_max )

    call PZGEMM( transa, transb, size, size, size, &
         alpha, matrixA, 1, 1, MPI%SL_DESC, matrixB, 1, 1, MPI%SL_DESC, &
         beta,  matrixC, 1, 1, MPI%SL_DESC )

    call system_clock(ecount, rate, count_max )
    time_la = time_la + real(ecount-scount)/rate

    return
  end subroutine MPI__ZGEMM

  ! C <- alpha*op(A)*op(B) + beta*C for 2*size x 2*size matrices distributed
  ! with MPI%SL_DESCLS (local columns 2*smat-1 : 2*smat-1+localsize2-1), via
  ! PZGEMM.  A and B are only read, so they are intent(in).  C stays
  ! intent(inout) because PZGEMM reads it when beta /= 0.  Elapsed time is
  ! added to time_la.
  subroutine MPI__ZGEMMLS( transa, transb, size, &
       alpha, matrixA, matrixB, beta, matrixC )
    implicit none

    character, intent(in)     :: transa, transb
    integer, intent(in)       :: size
    complex(8), intent(in)    :: matrixA(2*size,2*MPI%smat-1:2*MPI%smat-1+MPI%localsize2-1)
    complex(8), intent(in)    :: matrixB(2*size,2*MPI%smat-1:2*MPI%smat-1+MPI%localsize2-1)
    complex(8), intent(inout) :: matrixC(2*size,2*MPI%smat-1:2*MPI%smat-1+MPI%localsize2-1)
    complex(8), intent(in)    :: alpha, beta

    integer scount, ecount, rate, count_max

    call system_clock(scount, rate, count_max )

    call PZGEMM( transa, transb, 2*size, 2*size, 2*size, &
         alpha, matrixA, 1, 1, MPI%SL_DESCLS, matrixB, 1, 1, MPI%SL_DESCLS, &
         beta,  matrixC, 1, 1, MPI%SL_DESCLS )

    call system_clock(ecount, rate, count_max )
    time_la = time_la + real(ecount-scount)/rate

    return
  end subroutine MPI__ZGEMMLS

  ! C <- alpha*op(A)*op(B) + beta*C for three matrices that share the
  ! layout DESC (see MPI__setupMatDesc), via PZGEMM with
  ! M = desc%nrow, N = K = desc%ncol.
  ! matrixC is intent(inout), not intent(out): PZGEMM reads C when
  ! beta /= 0, and intent(out) would make its entry values undefined.
  ! Elapsed time is added to time_la.
  subroutine MPI__ZGEMM_ASCOT( transa, transb, desc, &
       alpha, matrixA, matrixB, beta, matrixC )
    implicit none

    character, intent(in)     :: transa, transb
    type(MPI_MatDesc), intent(in) :: desc
    complex(8), intent(in)    :: matrixA(desc%nrow,desc%scol:desc%ecol)
    complex(8), intent(in)    :: matrixB(desc%nrow,desc%scol:desc%ecol)
    complex(8), intent(inout) :: matrixC(desc%nrow,desc%scol:desc%ecol)
    complex(8), intent(in)    :: alpha, beta

    integer scount, ecount, rate, count_max

    call system_clock(scount, rate, count_max )

    call PZGEMM( transa, transb, desc%nrow, desc%ncol, desc%ncol, &
         alpha, matrixA, 1, 1, desc%SL_DESC, matrixB, 1, 1, desc%SL_DESC, &
         beta,  matrixC, 1, 1, desc%SL_DESC )

    call system_clock(ecount, rate, count_max )
    time_la = time_la + real(ecount-scount)/rate

    return
  end subroutine MPI__ZGEMM_ASCOT

  ! C <- alpha*op(A)*op(B) + beta*C, with each matrix having its own layout
  ! descriptor (see MPI__setupMatDesc), via PZGEMM.
  ! M and N come from descC.  The inner dimension K is the column count of
  ! op(A): descA%ncol when transa is 'N', and descA%nrow when A is
  ! (conjugate-)transposed.  Previously K was always descA%ncol, which is
  ! wrong for a transposed non-square A.
  ! matrixC is intent(inout) because PZGEMM reads it when beta /= 0.
  ! Elapsed time is added to time_la.
  subroutine MPI__ZGEMM2_ASCOT( transa, transb, descA, descB, descC, &
       alpha, matrixA, matrixB, beta, matrixC )
    implicit none

    character, intent(in)     :: transa, transb
    type(MPI_MatDesc), intent(in) :: descA, descB, descC
    complex(8), intent(in)    :: matrixA(descA%nrow,descA%scol:descA%ecol)
    complex(8), intent(in)    :: matrixB(descB%nrow,descB%scol:descB%ecol)
    complex(8), intent(inout) :: matrixC(descC%nrow,descC%scol:descC%ecol)
    complex(8), intent(in)    :: alpha, beta

    integer :: kdim   ! inner dimension of op(A)*op(B)

    integer scount, ecount, rate, count_max

    call system_clock(scount, rate, count_max )

    if( transa == 'N' .or. transa == 'n' ) then
       kdim = descA%ncol
    else
       kdim = descA%nrow
    end if

    call PZGEMM( transa, transb, descC%nrow, descC%ncol, kdim, &
         alpha, matrixA, 1, 1, descA%SL_DESC, matrixB, 1, 1, descB%SL_DESC, &
         beta,  matrixC, 1, 1, descC%SL_DESC )

    call system_clock(ecount, rate, count_max )
    time_la = time_la + real(ecount-scount)/rate

    return
  end subroutine MPI__ZGEMM2_ASCOT

  ! Conjugate-transpose accumulation on a distributed matrix via PBLAS PZTRANC:
  !   matrixA <- alpha*matrixA + beta*conjg(transpose(matrixB))
  ! PZTRANC computes C := BETA*C + ALPHA*A**H. This wrapper therefore passes
  ! its own "beta" in PZTRANC's ALPHA slot, where it scales matrixB. It
  ! passes its own "alpha" in PZTRANC's BETA slot, where it scales matrixA,
  ! the output. Both operands use the same descriptor.
  ! Elapsed wall time is accumulated into time_la.
  subroutine MPI__ZTRANC_ASCOT( desc, alpha, matrixA, beta, matrixB )
    implicit none

    type(MPI_MatDesc), intent(in) :: desc
    complex(8), intent(in)    :: alpha
    complex(8), intent(in)    :: beta
    complex(8), intent(in)    :: matrixB(desc%nrow,desc%scol:desc%ecol)
    complex(8), intent(inout) :: matrixA(desc%nrow,desc%scol:desc%ecol)

    integer :: tick_begin, tick_end, ticks_per_sec, tick_limit

    call system_clock( tick_begin, ticks_per_sec, tick_limit )

    ! Source operand (scaled by beta) first, destination (scaled by alpha) last.
    call PZTRANC( desc%nrow, desc%ncol,              &
         beta,  matrixB, 1, 1, desc%SL_DESC,         &
         alpha, matrixA, 1, 1, desc%SL_DESC )

    call system_clock( tick_end, ticks_per_sec, tick_limit )

    time_la = time_la + real(tick_end - tick_begin)/ticks_per_sec
  end subroutine MPI__ZTRANC_ASCOT

  ! Trace of the distributed matrix matrixA, computed with ScaLAPACK PZLATRA.
  ! The trace covers the global desc%nrow x desc%nrow sub-matrix starting
  ! at (1,1). Per the ScaLAPACK documentation, the result is left on every
  ! process of the grid.
  ! Elapsed wall time is accumulated into time_la.
  subroutine MPI__ZLATRA_ASCOT( desc, matrixA, trace )
    implicit none

    type(MPI_MatDesc), intent(in) :: desc
    complex(8), intent(in)  :: matrixA(desc%nrow,desc%scol:desc%ecol)
    complex(8), intent(out) :: trace

    complex(8), external :: PZLATRA
    integer :: clk_start, clk_stop, clk_rate, clk_max

    call system_clock( clk_start, clk_rate, clk_max )
    trace = PZLATRA( desc%nrow, matrixA, 1, 1, desc%SL_DESC )
    call system_clock( clk_stop, clk_rate, clk_max )

    time_la = time_la + real(clk_stop - clk_start)/clk_rate
  end subroutine MPI__ZLATRA_ASCOT

  ! Invert the distributed square matrix matrixA in place with ScaLAPACK.
  ! It first computes the LU factorisation (PZGETRF), then inverts from the
  ! LU factors (PZGETRI).
  !
  ! PZGETRI is called twice:
  !   - first as a workspace query (lwork1 = lwork2 = -1), with the optimal
  !     sizes returned in work1(1) and work2(1);
  !   - then with real workspace, its queried sizes padded to (size+100)*2
  !     as a safety margin.
  !
  ! On return MPI%info holds the INFO of the final PZGETRI call. The PZGETRF
  ! status is overwritten and is not checked here.
  ! Elapsed wall time is accumulated into time_la.
  subroutine MPI__ZGETRI_ASCOT( desc, matrixA )
    implicit none

    type(MPI_MatDesc), intent(in) :: desc
    complex(8), intent(inout) :: matrixA(desc%nrow,desc%scol:desc%ecol)

    integer, allocatable :: ipiv(:)
    complex(8), allocatable :: work1(:)
    integer, allocatable :: work2(:)
    integer :: lwork1, lwork2

    integer scount, ecount, rate, count_max

    call system_clock(scount, rate, count_max )

    ! ScaLAPACK documents IPIV with dimension LOCr(M_A)+MB_A. desc%nrow is
    ! the global row count M_A, so it bounds LOCr(M_A). SL_DESC(5) is MB_A,
    ! the row-block-size entry of a ScaLAPACK array descriptor.
    allocate( ipiv(desc%nrow + max(desc%SL_DESC(5), 0)) )

    call PZGETRF( desc%nrow, desc%ncol, matrixA, 1, 1, desc%SL_DESC, &
         ipiv, MPI%info )

    ! Workspace query: sizes come back in work1(1) / work2(1).
    allocate( work1(1), work2(1) )
    lwork1=-1; lwork2=-1

    call PZGETRI( desc%nrow, matrixA, 1, 1, desc%SL_DESC, &
         ipiv, work1, lwork1, work2, lwork2, MPI%info )

    lwork1 = int(work1(1))
    lwork2 = int(work2(1))
    lwork1 = (lwork1+100)*2
    lwork2 = (lwork2+100)*2
    deallocate( work1, work2 )

    allocate( work1(lwork1), work2(lwork2) )

    call PZGETRI( desc%nrow, matrixA, 1, 1, desc%SL_DESC, &
         ipiv, work1, lwork1, work2, lwork2, MPI%info )

    deallocate( ipiv )
    deallocate( work1, work2 )

    call system_clock(ecount, rate, count_max )
    time_la = time_la + real(ecount-scount)/rate

    return
  end subroutine MPI__ZGETRI_ASCOT

  ! Assemble the full desc%nrow x desc%ncol matrixA on every rank of
  ! MPI%commM from the contiguous pieces each rank owns.
  ! The contribution of rank r is desc%vnele(r) elements, starting at the
  ! 0-based column-major element offset desc%vsele(r) inside matrixA.
  ! Presumably vnele/vsele are indexed 0..sizeM-1 by rankM; verify against
  ! the setup code.
  !
  ! The operation is done in place with MPI_IN_PLACE. MPI forbids the send
  ! and receive buffers of a collective from aliasing. With MPI_IN_PLACE,
  ! each rank's contribution is taken from its own slot
  ! (displacement vsele(rankM), count vnele(rankM)) of the receive buffer,
  ! which is exactly where this rank's piece already lives.
  ! Returns immediately when commM has a single rank. Otherwise the elapsed
  ! wall time is accumulated into time_mpi.
  subroutine MPI__Allgather_MatrixM_ASCOT( desc, matrixA )
    implicit none
    type(MPI_MatDesc), intent(in) :: desc
    complex(8), intent(inout) :: matrixA(desc%nrow,desc%ncol)

    integer scount, ecount, rate, count_max

    if( MPI%sizeM == 1 ) return

    call system_clock(scount, rate, count_max )

    ! sendcount/sendtype are ignored by MPI when MPI_IN_PLACE is used.
    call MPI_Allgatherv( &
         MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, &
         matrixA, desc%vnele, desc%vsele, &
         MPI_DOUBLE_COMPLEX, MPI%commM, MPI%info )

    call system_clock(ecount, rate, count_max )
    time_mpi = time_mpi + real(ecount-scount)/rate

    return
  end subroutine MPI__Allgather_MatrixM_ASCOT

  ! Element-wise sum of the local block matrixA across all ranks of
  ! MPI%commE. The result overwrites matrixA on every rank.
  ! The block is first copied into a scratch buffer so that the send and
  ! receive buffers passed to MPI_Allreduce do not alias.
  ! Does nothing when commE has a single rank. Otherwise the elapsed wall
  ! time is accumulated into time_mpi.
  subroutine MPI__Allreduce_MatrixE_ASCOT( desc, matrixA )
    implicit none
    type(MPI_MatDesc), intent(in) :: desc
    complex(8), intent(inout) :: matrixA(desc%nrow,desc%scol:desc%ecol)

    complex(8), allocatable :: sendbuf(:,:)
    integer :: nelem
    integer :: t_begin, t_end, clk_rate, clk_max

    if( MPI%sizeE /= 1 ) then
       call system_clock( t_begin, clk_rate, clk_max )

       nelem = desc%nrow*(desc%ecol-desc%scol+1)

       allocate( sendbuf(desc%nrow,desc%scol:desc%ecol) )
       sendbuf(:,:) = matrixA(:,:)

       call MPI_Allreduce( sendbuf, matrixA, nelem, &
            MPI_DOUBLE_COMPLEX, MPI_SUM, MPI%commE, MPI%info )

       deallocate( sendbuf )

       call system_clock( t_end, clk_rate, clk_max )
       time_mpi = time_mpi + real(t_end - t_begin)/clk_rate
    end if
  end subroutine MPI__Allreduce_MatrixE_ASCOT

  ! Element-wise sum of the full desc%nrow x desc%ncol matrixA over
  ! MPI_COMM_WORLD. Every rank receives the total in matrixA.
  ! The matrix is copied into a scratch buffer first so that the send and
  ! receive buffers do not alias.
  ! NOTE(review): the single-rank short-circuit tests MPI%size2, but the
  ! reduction runs over MPI_COMM_WORLD. Presumably size2 is the size of the
  ! world communicator; confirm against the setup code.
  ! Otherwise the elapsed wall time is accumulated into time_mpi.
  subroutine MPI__Allreduce_MatrixT_ASCOT( desc, matrixA )
    implicit none
    type(MPI_MatDesc), intent(in) :: desc
    complex(8), intent(inout) :: matrixA(desc%nrow,desc%ncol)

    complex(8), allocatable :: sendbuf(:,:)
    integer :: nelem
    integer :: t_begin, t_end, clk_rate, clk_max

    if( MPI%size2 == 1 ) return

    call system_clock( t_begin, clk_rate, clk_max )

    nelem = desc%nrow*desc%ncol

    allocate( sendbuf(desc%nrow,desc%ncol) )
    sendbuf(:,:) = matrixA(:,:)

    call MPI_Allreduce( sendbuf, matrixA, nelem, &
         MPI_DOUBLE_COMPLEX, MPI_SUM, MPI_COMM_WORLD, MPI%info )

    deallocate( sendbuf )

    call system_clock( t_end, clk_rate, clk_max )
    time_mpi = time_mpi + real(t_end - t_begin)/clk_rate
  end subroutine MPI__Allreduce_MatrixT_ASCOT

  ! Element-wise sum of the local block matrixA across all ranks of
  ! MPI%commM. The result overwrites matrixA on every rank.
  ! The block is first copied into a scratch buffer so that the send and
  ! receive buffers passed to MPI_Allreduce do not alias.
  ! Does nothing when commM has a single rank. Otherwise the elapsed wall
  ! time is accumulated into time_mpi.
  subroutine MPI__Allreduce_MatrixM_ASCOT( desc, matrixA )
    implicit none
    type(MPI_MatDesc), intent(in) :: desc
    complex(8), intent(inout) :: matrixA(desc%nrow,desc%scol:desc%ecol)

    complex(8), allocatable :: sendbuf(:,:)
    integer :: nelem
    integer :: t_begin, t_end, clk_rate, clk_max

    if( MPI%sizeM == 1 ) return

    call system_clock( t_begin, clk_rate, clk_max )

    nelem = desc%nrow*(desc%ecol-desc%scol+1)

    allocate( sendbuf(desc%nrow,desc%scol:desc%ecol) )
    sendbuf(:,:) = matrixA(:,:)

    call MPI_Allreduce( sendbuf, matrixA, nelem, &
         MPI_DOUBLE_COMPLEX, MPI_SUM, MPI%commM, MPI%info )

    deallocate( sendbuf )

    call system_clock( t_end, clk_rate, clk_max )
    time_mpi = time_mpi + real(t_end - t_begin)/clk_rate
  end subroutine MPI__Allreduce_MatrixM_ASCOT

end module ac_mpi_module
