!=======================================================================
!
!  PROGRAM  PHASE/0 2014.03 ($Rev: 409 $)
!
!  MODULE: m_Wannier90
!
!  AUTHOR(S): T. Yamamoto   May/05/2008
!
!  Contact address :  Phase System Consortium
!                     E-mail: phase_system@nims.go.jp URL https://azuma.nims.go.jp
!
!
!
!=======================================================================
!
!     The original version of this set of the computer programs "PHASE"
!  was developed by the members of the Theory Group of Joint Research
!  Center for Atom Technology (JRCAT), based in Tsukuba, in the period
!  1993-2001.
!
!     Since 2002, this set has been tuned and new functions have been
!  added to it as a part of the national project "Frontier Simulation
!  Software for Industrial Science (FSIS)",  which is supported by
!  the IT program of the Ministry of Education, Culture, Sports,
!  Science and Technology (MEXT) of Japan.
!     Since 2006, this program set has been developed as a part of the
!  national project "Revolutionary Simulation Software (RSS21)", which
!  is supported by the next-generation IT program of MEXT of Japan.
!   Since 2013, this program set has been further developed centering on PHASE System
!  Consortium.
!   The activity of development of this program set has been supervised by Takahisa Ohno.
!
!
module m_Wannier90
  use m_Const_Parameters,   only : DP,PAI2,PAI4,SKIP,EXECUT,GAMMA,BUCS,CRDTYP
  use m_Control_Parameters, only : printable, kimg, nspin, neg, ipri &
    &                            , wan90_seedname, nb_wan90, LEN_TITLE
  use m_Files,              only : nfwannier 
  use m_Timing,             only : tstatc0_begin, tstatc0_end
  use m_Kpoints,            only : kv3,vkxyz,k_symmetry
  use m_PlaneWaveBasisSet,  only : ngabc,kg1,kgp,nbase,iba,nbase_gamma
  use m_Electronic_Structure,only: neordr,zaj_l,eko_l,m_ES_WF_in_Rspace
  use m_FFT,                only : fft_box_size_WF, nfft, m_FFT_alloc_WF_work &
       &                         , m_FFT_dealloc_WF_work
  use m_Crystal_Structure,  only : univol,rltv
  use m_Parallelization,    only : mpi_comm_group,npes,mype &
       &                         , nrank_e,myrank_e,map_e,ista_e,iend_e,istep_e,idisp_e &
       &                         , map_z,np_e,mpi_k_world,myrank_k,map_k,ista_k,iend_k &
       &                         , ista_snl, iend_snl, ierr

  ! Every entity of this module (and of the contained procedures that do not
  ! repeat the statement themselves) must be declared explicitly.
  implicit none

  ! ---- Contents of the <seedname>.nnkp file (filled by m_Wan90_init) ----
  ! First record of the file.
  character(len=LEN_TITLE) :: comment
  ! Logical flag read from the "calc_only_A" record.
  logical :: calc_only_A
  ! Lattice blocks, stored as they are read: record i -> column (1:3,i).
  real(kind=DP) :: real_lattice(3,3), recip_lattice(3,3)
  ! num_kp    : number of k-points listed in the file
  ! n_proj    : number of projection functions
  ! n_exclude : number of excluded bands
  ! nntot     : number of neighbours listed per k-point
  ! num_bands : nb_wan90 - n_exclude (bands handed to Wannier90)
  integer :: num_kp, n_proj, n_exclude, nntot, num_bands
  ! Second dimension of the per-l work arrays used by the projector builder
  ! (slot il1 is used together with the spherical Bessel function of order il1-1).
  integer, parameter :: lmax1_pj = 3

  real(kind=DP), allocatable :: kvec(:,:) ! d(3,num_kp)
  ! Projection definitions: centre, local z and x axes.
  real(kind=DP), allocatable :: centre(:,:), zaxis(:,:), xaxis(:,:) ! d(3,n_proj)
  ! Radial decay parameter of each projection (multiplied by the Bohr radius
  ! in Angstrom inside radial_function before use).
  real(kind=DP), allocatable :: zona(:) ! d(n_proj)
  ! Angular selector (l), orbital selector (mr) and radial selector (irf).
  integer, allocatable :: lang(:), mr(:), irf(:) ! d(n_proj)
  integer, allocatable :: exclude_bands(:) ! d(n_exclude)
  ! Neighbour tables: nnlist(k,nn) is the neighbouring k-point index and
  ! nncell(1:3,k,nn) the integer triple read with it.
  integer, allocatable :: nnlist(:,:) ! d(num_kp,nntot)
  integer, allocatable :: nncell(:,:,:) ! d(3,num_kp,nntot)
  ! Band indices (1..nb_wan90) that are NOT excluded, in ascending order.
  integer, allocatable :: ib_inc(:) ! d(num_bands)

  ! Normalised projection functions on the plane-wave basis; allocated and
  ! released inside m_Wan90_gen_amat.
  real(kind=DP), allocatable :: projfunc(:,:,:,:) ! d(kg1,n_proj,ista_snl:iend_snl,kimg)

  include 'mpif.h'                                      ! MPI
  integer :: istatus(mpi_status_size)                   ! MPI

contains

  ! Read "<wan90_seedname>.nnkp" and fill the module tables (lattices,
  ! k-points, projections, neighbour lists, excluded bands), then build the
  ! list of included bands ib_inc and echo everything to unit nfout.
  !
  !   nfout : unit that receives the echo written by write_down_input and
  !           the diagnostic messages issued before an abort.
  !
  ! The file is read strictly record by record; the bare "skip" reads step
  ! over the records that carry no data for this module (the layout is the
  ! one implied by the sequence of reads below).
  !
  ! The routine stops when
  !   * a k-point index in the neighbour block is outside 1..num_kp
  !     (it is used directly as an array index), or
  !   * the number of bands in 1..nb_wan90 that are not excluded differs from
  !     nb_wan90 - n_exclude (duplicate or out-of-range entries in the
  !     exclusion list), which would otherwise overrun or under-fill ib_inc.
  subroutine m_Wan90_init(nfout)
    implicit none
    integer, intent(in) :: nfout

    character(len=15) :: dummy

    integer :: i, j, n
    integer :: kp, nnkp, ng1,ng2,ng3

    open(nfwannier,file=trim(wan90_seedname)//".nnkp",status="old",form="formatted")
    read(nfwannier,'(a80)') comment

    call skip_records(1)

    read(nfwannier,'(a15,l)') dummy,calc_only_A

    call skip_records(2)

    ! --- real lattice ---
    do i=1,3
       read(nfwannier,*) real_lattice(1:3,i)
    end do
    call skip_records(3)

    ! --- reciprocal lattice ---
    do i=1,3
       read(nfwannier,*) recip_lattice(1:3,i)
    end do
    call skip_records(3)

    ! --- k-points ---
    read(nfwannier,*) num_kp
    if(ipri>=2) write(*,'("num_kp=",i5)') num_kp
    allocate(kvec(3,num_kp))
    do i=1,num_kp
       read(nfwannier,*) kvec(1:3,i)
    end do
    call skip_records(3)

    ! --- projections: two records per projection ---
    read(nfwannier,*) n_proj
    if(ipri>=2) write(*,'("n_proj=",i5)') n_proj
    allocate(centre(3,n_proj))
    allocate(lang(n_proj))
    allocate(mr(n_proj))
    allocate(irf(n_proj))
    allocate(zaxis(3,n_proj))
    allocate(xaxis(3,n_proj))
    allocate(zona(n_proj))
    do i=1,n_proj
       read(nfwannier,*) centre(1:3,i), lang(i), mr(i), irf(i)
       read(nfwannier,*) zaxis(1:3,i), xaxis(1:3,i), zona(i)
    end do
    call skip_records(3)

    ! --- neighbour list: nntot records for each of the num_kp k-points ---
    read(nfwannier,*) nntot
    if(ipri>=2) write(*,'("nntot=",i5)') nntot
    allocate(nnlist(num_kp,nntot))
    allocate(nncell(3,num_kp,nntot))
    do j=1,num_kp
       do i=1,nntot
          read(nfwannier,*) kp, nnkp, ng1,ng2,ng3
          ! kp comes from the file and is used as an index: validate it.
          if(kp < 1 .or. kp > num_kp) then
             write(nfout,'("m_Wan90_init: k-point index ",i8," is outside 1..",i8)') kp, num_kp
             stop 'm_Wan90_init: invalid k-point index in nnkp file'
          end if
          nnlist(kp,i) = nnkp
          nncell(1,kp,i) = ng1
          nncell(2,kp,i) = ng2
          nncell(3,kp,i) = ng3
       end do
    end do
    call skip_records(3)

    ! --- excluded bands ---
    read(nfwannier,*) n_exclude
    if(ipri>=2) write(*,'("n_exclude=",i5)') n_exclude
    allocate(exclude_bands(n_exclude))
    do i=1,n_exclude
       read(nfwannier,*) exclude_bands(i)
    end do
    call skip_records(1)

    close(nfwannier)

    num_bands = nb_wan90 - n_exclude

    ! Count the bands that survive the exclusion list before sizing ib_inc.
    n = 0
    do i=1,nb_wan90
       if(.not.excluding_band(i)) n = n + 1
    end do
    if(n /= num_bands) then
       write(nfout,'("m_Wan90_init: ",i8," bands are included but nb_wan90-n_exclude = ",i8)') &
            & n, num_bands
       stop 'm_Wan90_init: exclude_bands is inconsistent with nb_wan90'
    end if

    allocate(ib_inc(num_bands))
    n = 0
    do i=1,nb_wan90
       if(excluding_band(i)) cycle
       n = n + 1
       ib_inc(n) = i
    end do

    call write_down_input(nfout)

  contains

    ! Step over nrec records of the nnkp file without reading any data.
    subroutine skip_records(nrec)
      integer, intent(in) :: nrec
      integer :: irec
      do irec=1,nrec
         read(nfwannier,*)
      end do
    end subroutine skip_records

  end subroutine m_Wan90_init

  ! Echo the tables read from the nnkp file to unit nfout, in the order
  ! lattices, k-points, projections, neighbour list, excluded bands and
  ! finally the list of included bands.
  subroutine write_down_input(nfout)
    implicit none
    integer, intent(in) :: nfout

    integer :: irow, ikp, ipj, inn, ibd

    ! Header
    write(nfout,'("Wannier90")')
    write(nfout,'("comment: ",a)') trim(comment)
    write(nfout,'("calc_only_A: ",l)') calc_only_A

    ! Lattices, one stored column per output line
    write(nfout,'("real_lattice")')
    do irow=1,3
       write(nfout,'(3(1x,f20.8))') real_lattice(:,irow)
    end do
    write(nfout,'("recip_lattice")')
    do irow=1,3
       write(nfout,'(3(1x,f20.8))') recip_lattice(:,irow)
    end do

    ! k-points
    write(nfout,'("k vectors:")')
    write(nfout,'("num_kp: ",i5)') num_kp
    do ikp=1,num_kp
       write(nfout,'("kvec(",i5,") =",3(1x,f20.8))') ikp,kvec(:,ikp)
    end do

    ! Projections, two lines each
    write(nfout,'("n_proj: ",i5)') n_proj
    do ipj=1,n_proj
       write(nfout,'(3(1x,f20.8),3(1x,i5))') centre(:,ipj), lang(ipj), mr(ipj), irf(ipj)
       write(nfout,'(7(1x,f20.8))') zaxis(:,ipj), xaxis(:,ipj), zona(ipj)
    end do

    ! Neighbour list
    write(nfout,'("kp nnlist nncell(1:3)")')
    do ikp=1,num_kp
       do inn=1,nntot
          write(nfout,'(i5,1x,i5,3(1x,i5))') ikp,nnlist(ikp,inn),nncell(:,ikp,inn)
       end do
    end do

    ! Band bookkeeping
    write(nfout,'("excluded bands ")')
    write(nfout,'("n_exclude: ",i5)') n_exclude
    do ibd=1,n_exclude
       write(nfout,'(i5)') exclude_bands(ibd)
    end do
    write(nfout,'("included bands ")')
    do ibd=1,num_bands
       write(nfout,'(i5)') ib_inc(ibd)
    end do

  end subroutine write_down_input

  ! Build the projection matrix
  !
  !    Amn(k) = <Psi_mk|g_n>
  !
  ! from the plane-wave coefficients zaj_l and the projection functions
  ! produced by m_Wan90_Projectors, sum the partial results over
  ! mpi_comm_group and let rank 0 write them to "<wan90_seedname>.amn".
  !
  ! Each rank only fills the (band,k) entries it passes the map_k / map_e
  ! ownership tests for; all other entries stay zero until the reduction.
  subroutine m_Wan90_gen_amat(nfout)
    implicit none
    integer, intent(in) :: nfout

    integer :: ik, ikpj, iproj, iband, ib, ig, mi, npw
    real(kind=DP) :: p_re, p_im, z_re, z_im
    real(kind=DP) :: acc_re, acc_im      ! running real / imaginary sums
    real(kind=DP), allocatable :: a_mat(:,:,:,:) ! d(num_bands,n_proj,kv3,2)
    real(kind=DP), allocatable :: a_mat_mpi(:,:,:,:) ! d(num_bands,n_proj,kv3,2)

    allocate(projfunc(kg1,n_proj,ista_snl:iend_snl,kimg))
    call m_Wan90_Projectors(nfout,kv3,vkxyz)

    allocate(a_mat(num_bands,n_proj,kv3,2))
    a_mat = 0.d0

    kpoints: do ik = 1, kv3
       if(map_k(ik) /= myrank_k) cycle kpoints ! MPI
       ikpj = (ik-1)/nspin + 1      ! spin-independent index into projfunc
       npw  = iba(ik)
       projections: do iproj=1,n_proj
          bands: do iband=1,num_bands
             mi = ib_inc(iband)
             ib = neordr(mi,ik)
             if(map_e(ib) /= myrank_e) cycle bands ! MPI
             if(ipri>=2) write(*,'("m=",i5," n=",i5," ik=",i5)') iband,iproj,ik

             acc_re = 0.d0
             acc_im = 0.d0
             if(kimg==1) then
                ! real coefficients only
                do ig=1,npw
                   p_re = projfunc(ig,iproj,ikpj,1)
                   z_re = zaj_l(ig,map_z(ib),ik,1)
                   acc_re = acc_re + z_re*p_re
                end do
             else if(k_symmetry(ik) == GAMMA) then
                ! elements 2..npw are counted twice, element 1 once (real part only)
                do ig=2,npw
                   p_re = projfunc(ig,iproj,ikpj,1)
                   p_im = projfunc(ig,iproj,ikpj,2)
                   z_re = zaj_l(ig,map_z(ib),ik,1)
                   z_im = zaj_l(ig,map_z(ib),ik,2)
                   acc_re = acc_re + z_re*p_re + z_im*p_im
                end do
                p_re = projfunc(1,iproj,ikpj,1)
                z_re = zaj_l(1,map_z(ib),ik,1)
                acc_re = acc_re*2.d0 + z_re*p_re
             else
                ! general complex case: conjg(z) * p
                do ig=1,npw
                   p_re = projfunc(ig,iproj,ikpj,1)
                   p_im = projfunc(ig,iproj,ikpj,2)
                   z_re = zaj_l(ig,map_z(ib),ik,1)
                   z_im = zaj_l(ig,map_z(ib),ik,2)
                   acc_re = acc_re + z_re*p_re + z_im*p_im
                   acc_im = acc_im + z_re*p_im - z_im*p_re
                end do
             end if
             a_mat(iband,iproj,ik,1) = acc_re
             a_mat(iband,iproj,ik,2) = acc_im
          end do bands
       end do projections
    end do kpoints
    deallocate(projfunc)

    ! Combine the per-rank partial matrices.
    if(npes>1) then
       allocate(a_mat_mpi(num_bands,n_proj,kv3,2))
       a_mat_mpi = a_mat
       a_mat = 0.d0
       call mpi_allreduce(a_mat_mpi,a_mat,num_bands*n_proj*kv3*2,mpi_double_precision,mpi_sum,mpi_comm_group,ierr)
       deallocate(a_mat_mpi)
    end if

    ! Only rank 0 writes the file.
    if(mype==0) then
       open(nfwannier,file=trim(wan90_seedname)//".amn",form="formatted")
       write(nfwannier,*) 'Generated by PHASE'
       write(nfwannier,*) num_bands, kv3, n_proj
       do ik = 1, kv3
          do iproj=1,n_proj
             do iband=1,num_bands
                write(nfwannier,'(3(1x,i5),2(1x,f18.12))') iband, iproj, ik, a_mat(iband,iproj,ik,1:2)
             end do
          end do
       end do
       close(nfwannier)
    end if

    deallocate(a_mat)
  end subroutine m_Wan90_gen_amat

  ! Build the normalised projection functions projfunc(:,ip,ikpj,:) on the
  ! plane-wave basis for every projection ip and every k-point owned by this
  ! rank (one entry per spin-independent k-point index ikpj).
  !
  !   nfout : output unit (only referenced by the disabled debug output)
  !   kv3   : number of k-point entries in vkxyz
  !   vkxyz : k-point coordinates, d(kv3,3,CRDTYP)
  !
  ! For each (k, projection):
  !   1. tabulate the radial function on the mesh from radr_and_wos,
  !   2. get the angular factors angf(:,il1) from angular_function,
  !   3. integrate  fac*wos*r**2*radfunc(r)*[dsjnv output of order il1-1]
  !      over the mesh and multiply by angf  -> pf(:,il1),
  !   4. apply  i^(-l) * exp(-i(k+G)*R)  with l = il1-1,
  !   5. normalise to unit Euclidean norm over the iba(ik) plane waves.
  !
  ! projfunc must already be allocated by the caller.
  !
  ! NOTE(review): step 5 sums every plane wave once, also when
  ! k_symmetry(ik) == GAMMA, whereas m_Wan90_gen_amat doubles elements
  ! 2..iba(ik) in that case - confirm the intended normalisation.
  subroutine m_Wan90_Projectors(nfout,kv3,vkxyz)
    implicit none
    integer, intent(in) :: nfout,kv3
    real(kind=DP), intent(in) :: vkxyz(kv3,3,CRDTYP)

    integer :: ik,ikpj,il1,ip,n,i
    integer, parameter :: nmesh = 1501
    real(kind=DP)       :: fac, facr, dnorm
    real(kind=DP), allocatable, dimension(:) :: qx,qy,qz,vlength,wka,wkb
    real(kind=DP), allocatable, dimension(:) :: zcos,zsin
    real(kind=DP), allocatable, dimension(:,:) :: pf
    real(kind=DP), allocatable, dimension(:,:) :: angf
    real(kind=DP), allocatable, dimension(:) :: radr,wos,radfunc
    integer, allocatable :: lmin_pj(:) ! d(n_proj)
    integer, allocatable :: lmax_pj(:) ! d(n_proj)

    ! Timer handle; it has to keep its value between calls.
    integer, save       :: id_sname = -1
    call tstatc0_begin('m_Wan90_Projectors ',id_sname,1)

    allocate(lmin_pj(n_proj))
    allocate(lmax_pj(n_proj))
    allocate(qx(kg1),qy(kg1),qz(kg1),vlength(kg1),wka(kg1),wkb(kg1))
    allocate(pf(kg1,lmax1_pj),angf(kg1,lmax1_pj))
    allocate(radr(nmesh),wos(nmesh),radfunc(nmesh))
    allocate(zcos(kg1),zsin(kg1))

    do ip=1,n_proj
       call set_lmin_lmax(lang(ip),lmin_pj(ip),lmax_pj(ip))
       if(ipri>=2) write(*,'("ip=",i5," lmin=",i5," lmax=",i5)') ip, lmin_pj(ip), lmax_pj(ip) 
    end do

    call radr_and_wos(nmesh,radr,wos) ! --> radr, wos
    fac = PAI4/dsqrt(univol)
    do ik = 1, kv3, nspin
       if(map_k(ik) /= myrank_k) cycle                     ! MPI
       call k_plus_G_vectors(ik,kgp,kg1,kv3,iba,nbase,vkxyz,ngabc,rltv&
                           &,qx,qy,qz,vlength) ! ->(bottom_Subr.)
!!       call G_vectors(kgp,iba(ik),nbase(1,ik),ngabc,rltv,qx,qy,qz,vlength)
       ikpj = (ik-1)/nspin + 1
       do ip=1,n_proj
          if(ipri>=2) write(*,'("ikpj=",i5," ip=",i5)') ikpj,ip

          call radial_function(irf(ip),zona(ip),nmesh,radr,radfunc)
! Debug
         !! if(ik==1) then
         !! write(1000+ip,'("radr radfunc")')
         !! do n=1,nmesh
         !!    write(1000+ip,'(f20.8,1x,f20.8)') radr(n), radfunc(n)
         !! end do
         !! end if
! end Debug
          if(ipri>=2) write(*,'("debug 1")')
          ! One call fills every angular slot this projection uses; the call
          ! does not depend on the slot index, so it is not repeated per slot.
          angf = 0.d0
          call angular_function(lang(ip),mr(ip),iba(ik),qx,qy,qz,angf)
          if(ipri>=2) write(*,'("debug 2")')

          ! Radial integration on the mesh, slot by slot.
          pf = 0.d0
          do n=1,nmesh
             facr = fac*wos(n)*radr(n)**2*radfunc(n)
             do i=1,iba(ik)
                wka(i) = vlength(i)*radr(n)
             end do
             do il1=lmin_pj(ip)+1,lmax_pj(ip)+1
                call dsjnv(il1-1,iba(ik),wka,wkb)     ! -(bottom_Subr.)
                do i=1,iba(ik)
                   pf(i,il1) = pf(i,il1) + facr*wkb(i)*angf(i,il1)
                end do
             end do
          end do
          if(ipri>=2) write(*,'("debug 3")')

          !! Projector(k+G) = sum_l i^(-l) * exp(-i(k+G)*R) * PF_l(k+G) 
          call exp_KpG_dot_R(centre(1,ip),ik,kgp,iba,nbase,ngabc,vkxyz,zcos,zsin)

          if(ipri>=2) write(*,'("debug 4")')
          projfunc(1:kg1,ip,ikpj,1:kimg) = 0.d0
          do il1=lmin_pj(ip)+1,lmax_pj(ip)+1
             select case(il1)
             case(1)   ! l=0 : factor  1 * (cos - i sin)
                do i=1,iba(ik)
                   projfunc(i,ip,ikpj,1) = projfunc(i,ip,ikpj,1) + pf(i,il1) * zcos(i)
                end do
                if(kimg/=1) then
                   do i=1,iba(ik)
                      projfunc(i,ip,ikpj,2) = projfunc(i,ip,ikpj,2) - pf(i,il1) * zsin(i)
                   end do
                end if
             case(2)   ! l=1 : factor -i * (cos - i sin)
                do i=1,iba(ik)
                   projfunc(i,ip,ikpj,1) = projfunc(i,ip,ikpj,1) - pf(i,il1) * zsin(i)
                end do
                if(kimg/=1) then
                   do i=1,iba(ik)
                      projfunc(i,ip,ikpj,2) = projfunc(i,ip,ikpj,2) - pf(i,il1) * zcos(i)
                   end do
                end if
             case(3)   ! l=2 : factor -1 * (cos - i sin)
                do i=1,iba(ik)
                   projfunc(i,ip,ikpj,1) = projfunc(i,ip,ikpj,1) - pf(i,il1) * zcos(i)
                end do
                if(kimg/=1) then
                   do i=1,iba(ik)
                      projfunc(i,ip,ikpj,2) = projfunc(i,ip,ikpj,2) + pf(i,il1) * zsin(i)
                   end do
                end if
             case(4)   ! l=3 : factor +i * (cos - i sin)
                do i=1,iba(ik)
                   projfunc(i,ip,ikpj,1) = projfunc(i,ip,ikpj,1) + pf(i,il1) * zsin(i)
                end do
                if(kimg/=1) then
                   do i=1,iba(ik)
                      projfunc(i,ip,ikpj,2) = projfunc(i,ip,ikpj,2) + pf(i,il1) * zcos(i)
                   end do
                end if
             end select
          end do
          if(ipri>=2) write(*,'("debug 5")')
! Normalization
          ! The imaginary component exists only when kimg /= 1; the last
          ! dimension of projfunc has extent kimg, so it must not be indexed
          ! with 2 in the real-only case.
          dnorm = 0.d0
          if(kimg==1) then
             do i=1,iba(ik)
                dnorm = dnorm + projfunc(i,ip,ikpj,1)**2
             end do
          else
             do i=1,iba(ik)
                dnorm = dnorm + projfunc(i,ip,ikpj,1)**2 + projfunc(i,ip,ikpj,2)**2
             end do
          end if
! Debug
!!          write(nfout,'("ip=",i5," ikpj=",i5," dnorm=",f25.10," centre=",3(1x,f10.5))') &
!!           & ip,ikpj,dnorm,centre(1:3,ip)
! End Debug
          ! A vanishing norm would turn the whole projector into Inf/NaN.
          if(dnorm <= 0.d0) stop 'm_Wan90_Projectors: projection function has zero norm'
          dnorm = 1.d0/sqrt(dnorm)
          if(kimg==1) then
             do i=1,iba(ik)
                projfunc(i,ip,ikpj,1) = projfunc(i,ip,ikpj,1) * dnorm
             end do
          else
             do i=1,iba(ik)
                projfunc(i,ip,ikpj,1) = projfunc(i,ip,ikpj,1) * dnorm
                projfunc(i,ip,ikpj,2) = projfunc(i,ip,ikpj,2) * dnorm
             end do
          end if
       end do
    end do

    deallocate(lmin_pj)
    deallocate(lmax_pj)
    deallocate(qx,qy,qz,vlength,wka,wkb)
    deallocate(pf,angf)
    deallocate(zcos,zsin)
    deallocate(radr,wos,radfunc)

    call tstatc0_end(id_sname)
  end subroutine m_Wan90_Projectors

  ! Produce the radial mesh radr(1:nm) and the matching integration weights
  ! wos(1:nm) by calling rmeshs and coef_simpson_integration with the fixed
  ! parameters mesh_xh and mesh_rmax below.
  subroutine radr_and_wos(nm,radr,wos)
    implicit none
    integer, intent(in) :: nm
    real(kind=DP), intent(out) :: radr(nm), wos(nm)

    real(kind=DP), parameter :: mesh_xh   = 96.d0
    real(kind=DP), parameter :: mesh_rmax = 60.d0
    real(kind=DP) :: hn_unused   ! returned by rmeshs, not needed here

    ! mesh points
    call rmeshs(nm,nm,mesh_xh,mesh_rmax,radr,hn_unused) ! -(b_PP)
    ! integration weights on that mesh
    call coef_simpson_integration(nm,nm,mesh_xh,radr,wos) ! -(b_PP)
  end subroutine radr_and_wos

  ! Evaluate one real angular function, selected by the pair (l,m), at the nq
  ! vectors (qx,qy,qz) by forwarding to sphr with the matching sphr index.
  !
  !   (l,m) -> sphr index
  !   l=0 : m=1 s                                        -> 1
  !   l=1 : m=1 pz, 2 px, 3 py                           -> 4, 2, 3
  !   l=2 : m=1 dz2, 2 dxz, 3 dyz, 4 dx2-y2, 5 dxy       -> 5, 9, 8, 6, 7
  !   l=3 : m=1 fz3, 2 fxz2, 3 fyz2, 4 fz(x2-y2), 5 fxyz,
  !         6 fx(x2-3y2), 7 fy(3x2-y2)                   -> 10 .. 16
  !
  ! For any other (l,m) nothing is called and sphfunc is left untouched.
  subroutine spherical_function(l,m,nq,qx,qy,qz,sphfunc)
    integer, intent(in) :: l,m,nq
    real(kind=DP), intent(in) :: qx(nq), qy(nq), qz(nq)
    real(kind=DP), intent(out) :: sphfunc(nq)

    integer, parameter :: idx_p(3) = (/ 4, 2, 3 /)
    integer, parameter :: idx_d(5) = (/ 5, 9, 8, 6, 7 /)
    integer, parameter :: idx_f(7) = (/ 10, 11, 12, 13, 14, 15, 16 /)
    integer :: isph

    ! 0 means "no matching function"
    isph = 0
    select case(l)
    case(0)
       if(m==1) isph = 1
    case(1)
       if(m>=1 .and. m<=3) isph = idx_p(m)
    case(2)
       if(m>=1 .and. m<=5) isph = idx_d(m)
    case(3)
       if(m>=1 .and. m<=7) isph = idx_f(m)
    end select

    if(isph > 0) then
       call sphr(nq,isph,qx,qy,qz,sphfunc)        ! -(bottom_Subr.)
    end if
  end subroutine spherical_function

  ! Fill the angular factors of one projection.
  !
  !   l, m    : angular selector and orbital selector of the projection
  !   nq      : number of vectors
  !   qx,qy,qz: the vectors the angular functions are evaluated at
  !   angfunc : d(kg1,lmax1_pj); only rows 1..nq of the columns listed below
  !             are written, everything else keeps the value it had on entry
  !             (hence intent(inout) - the caller clears the array first).
  !
  !   l >= 0 (single function, m = 1..2l+1, l <= 3): column 1
  !   l = -1 (sp ), m = 1..2 : column 1 = s part, column 2 = p part
  !   l = -2 (sp2), m = 1..3 : column 1 = s part, column 2 = p part
  !   l = -3 (sp3), m = 1..4 : column 1 = s part, column 2 = p part
  !
  ! Any other (l,m) pair stops the program; previously such a pair fell
  ! through silently and left the output unset.
  subroutine angular_function(l,m,nq,qx,qy,qz,angfunc)
    implicit none
    integer, intent(in) :: l,m,nq
    real(kind=DP), intent(in) :: qx(nq), qy(nq), qz(nq)
    real(kind=DP), intent(inout) :: angfunc(kg1,lmax1_pj)

    ! Signs of the (px,py,pz) components of the hybrids sp3-1 .. sp3-4.
    real(kind=DP), parameter :: sp3_sign(3,4) = reshape( (/ &
         &  1.d0,  1.d0,  1.d0, &
         &  1.d0, -1.d0, -1.d0, &
         & -1.d0,  1.d0, -1.d0, &
         & -1.d0, -1.d0,  1.d0 /), (/ 3, 4 /) )

    real(kind=DP), allocatable :: sphfunc(:)

    allocate(sphfunc(nq))

    select case(l)
    case(0:)
       ! spherical_function knows l = 0..3 with m = 1..2l+1 only.
       if(l>3 .or. m<1 .or. m>2*l+1) stop 'angular_function: unsupported (l,m)'
       call spherical_function(l,m,nq,qx,qy,qz,sphfunc)
       angfunc(1:nq,1) = sphfunc(1:nq)

    case(-1)   ! sp hybrids: (s +/- px)/sqrt(2)
       if(m<1 .or. m>2) stop 'angular_function: unsupported (l,m)'
       call spherical_function(0,1,nq,qx,qy,qz,sphfunc) ! s
       angfunc(1:nq,1) = sphfunc(1:nq)/sqrt(2.d0)
       call spherical_function(1,2,nq,qx,qy,qz,sphfunc) ! px
       if(m==1) then
          angfunc(1:nq,2) = sphfunc(1:nq)/sqrt(2.d0)
       else
          angfunc(1:nq,2) = -sphfunc(1:nq)/sqrt(2.d0)
       end if

    case(-2)   ! sp2 hybrids
       if(m<1 .or. m>3) stop 'angular_function: unsupported (l,m)'
       call spherical_function(0,1,nq,qx,qy,qz,sphfunc) ! s
       angfunc(1:nq,1) = sphfunc(1:nq)/sqrt(3.d0)
       call spherical_function(1,2,nq,qx,qy,qz,sphfunc) ! px
       if(m==3) then
          ! s/sqrt(3) + 2 px/sqrt(6)
          angfunc(1:nq,2) = sphfunc(1:nq)*(2.d0/sqrt(6.d0))
       else
          ! s/sqrt(3) - px/sqrt(6) +/- py/sqrt(2)
          angfunc(1:nq,2) = -sphfunc(1:nq)/sqrt(6.d0)
          call spherical_function(1,3,nq,qx,qy,qz,sphfunc) ! py
          if(m==1) then
             angfunc(1:nq,2) = angfunc(1:nq,2) + sphfunc(1:nq)/sqrt(2.d0)
          else
             angfunc(1:nq,2) = angfunc(1:nq,2) - sphfunc(1:nq)/sqrt(2.d0)
          end if
       end if

    case(-3)   ! sp3 hybrids: (s +/- px +/- py +/- pz)/2
       if(m<1 .or. m>4) stop 'angular_function: unsupported (l,m)'
       call spherical_function(0,1,nq,qx,qy,qz,sphfunc) ! s
       angfunc(1:nq,1) = sphfunc(1:nq)*0.5d0
       call spherical_function(1,2,nq,qx,qy,qz,sphfunc) ! px
       angfunc(1:nq,2) = sp3_sign(1,m)*sphfunc(1:nq)*0.5d0
       call spherical_function(1,3,nq,qx,qy,qz,sphfunc) ! py
       angfunc(1:nq,2) = angfunc(1:nq,2) + sp3_sign(2,m)*sphfunc(1:nq)*0.5d0
       call spherical_function(1,1,nq,qx,qy,qz,sphfunc) ! pz
       angfunc(1:nq,2) = angfunc(1:nq,2) + sp3_sign(3,m)*sphfunc(1:nq)*0.5d0

    case default   ! l <= -4
       stop 'Not supported for l<=-4'
    end select

    deallocate(sphfunc)
  end subroutine angular_function

  ! Tabulate the radial part of a projection on the mesh radr(1:nm).
  !
  !   irf     : radial selector (1, 2 or 3)
  !   zona    : decay parameter; it is multiplied by the Bohr radius in
  !             Angstrom before use (alph = zona*bohr)
  !   nm      : number of mesh points
  !   radr    : mesh points
  !   radfunc : result, with a = alph and r = radr(i)
  !       irf=1 : 2 a^(3/2) exp(-a r)
  !       irf=2 : a^(3/2)/sqrt(8) (2 - a r) exp(-a r/2)
  !       irf=3 : sqrt(4/27) a^(3/2) (1 - 2 a r/3 + 2 (a r)^2/27) exp(-a r/3)
  !
  ! Any other irf stops the program instead of returning an unset radfunc.
  subroutine radial_function(irf,zona,nm,radr,radfunc)
    implicit none
    integer, intent(in) :: irf, nm
    real(kind=DP), intent(in) :: zona
    real(kind=DP), intent(in) :: radr(nm)
    real(kind=DP), intent(out) :: radfunc(nm)

    integer :: i
    real(kind=DP), parameter :: bohr = 0.5291772480d0 ! Angstrom
    real(kind=DP) :: fac, ar, alph

    alph = zona * bohr

    select case(irf)
    case(1)
       fac = 2.d0 * alph**(3.d0/2.d0)
       do i=1,nm
          radfunc(i) = fac * exp(-alph*radr(i))
       end do
    case(2)
       fac = alph**(3.d0/2.d0) / sqrt(8.d0)
       do i=1,nm
          ar = alph * radr(i)
          radfunc(i) = fac * ( 2.d0 - ar ) * exp(-0.5d0*ar)
       end do
    case(3)
       fac = alph**(3.d0/2.d0) * sqrt(4.d0/27.d0)
       do i=1,nm
          ar = alph * radr(i)
          radfunc(i) = fac * ( 1.d0 - (2.d0/3.d0)*ar + (2.d0/27.d0)*ar*ar ) * exp(-ar/3.d0)
       end do
    case default
       stop 'radial_function: irf must be 1, 2 or 3'
    end select
  end subroutine radial_function

  ! Return the range lmin_pj..lmax_pj associated with the angular selector l:
  !   l >= 0        -> 0..0
  !   l = -1,-2,-3  -> 0..1
  !   l <= -4       -> the program stops (not supported)
  !
  ! NOTE(review): for l >= 0 the range is 0..0 whatever l is, so a caller that
  ! loops over lmin_pj+1..lmax_pj+1 only ever visits slot 1 for p, d and f
  ! projections as well - confirm that this is intended.
  subroutine set_lmin_lmax(l,lmin_pj,lmax_pj)
    implicit none
    integer, intent(in) :: l
    integer, intent(out) :: lmin_pj, lmax_pj

    select case(l)
    case(0:)          ! single angular function
       lmin_pj = 0
       lmax_pj = 0
    case(-3:-1)       ! hybrids with two slots
       lmin_pj = 0
       lmax_pj = 1
    case default      ! l <= -4
       stop 'Not supported for l<=-4'
    end select
  end subroutine set_lmin_lmax

  ! For k-point ik, compute cos and sin of the phase
  !     PAI2 * centre . ( vkxyz(ik,:,BUCS) + ngabc(nbase(i,ik),:) )
  ! for the plane waves i = 1..iba(ik).  Elements of zcos/zsin beyond
  ! iba(ik) are not assigned.
  subroutine exp_KpG_dot_R(centre,ik,kgp,iba,nbase,ngabc,vkxyz,zcos,zsin)
    implicit none
    real(kind=DP), intent(in) :: centre(3)
    integer, intent(in) :: ik, kgp
    integer, intent(in) :: iba(kv3), nbase(kg1,kv3), ngabc(kgp,3)
    real(kind=DP), intent(in) :: vkxyz(kv3,3,CRDTYP)
    real(kind=DP), intent(out) :: zcos(kg1), zsin(kg1)

    integer :: ipw, ig
    real(kind=DP) :: kpg(3), phase

    do ipw=1,iba(ik)
       ig = nbase(ipw,ik)                          ! row of ngabc for this plane wave
       kpg(1:3) = vkxyz(ik,1:3,BUCS)+ngabc(ig,1:3) ! k + G
       phase = PAI2 * dot_product(centre,kpg)
       zcos(ipw) = cos(phase)
       zsin(ipw) = sin(phase)
    end do
  end subroutine exp_KpG_dot_R

  ! Write the eigenvalues of the included bands, multiplied by `hartree`
  ! (27.21139615, eV), to "<wan90_seedname>.eig" as records "band k-point value".
  !
  ! eko_l and neordr are addressed exactly as in m_Wan90_gen_amat, i.e. only
  ! for (k-point, band) pairs that pass the map_k / map_e ownership tests on
  ! this rank.  Each rank therefore fills its own entries of a zeroed table,
  ! the tables are summed over mpi_comm_group and rank 0 writes the result.
  ! Previously rank 0 read every entry itself, including pairs it does not
  ! pass the ownership tests for.
  !
  ! Because of the reduction the routine has to be called by all ranks of
  ! mpi_comm_group when npes > 1.
  !
  !   nfout : output unit (not used here)
  subroutine m_Wan90_wd_eig(nfout)
    implicit none
    integer, intent(in) :: nfout

    integer :: ik,m,ib,mi
    real(kind=DP), parameter :: hartree = 27.21139615d0 ! eV
    real(kind=DP), allocatable :: eig(:,:)     ! d(num_bands,kv3)
    real(kind=DP), allocatable :: eig_mpi(:,:) ! d(num_bands,kv3)

    allocate(eig(num_bands,kv3))
    eig = 0.d0
    do ik = 1, kv3
       if(map_k(ik) /= myrank_k) cycle ! MPI
       do m=1,num_bands
          mi = ib_inc(m)
          ib = neordr(mi,ik)
          if(map_e(ib) /= myrank_e) cycle ! MPI
          eig(m,ik) = eko_l(map_z(ib),ik) * hartree
       end do
    end do

    if(npes>1) then
       allocate(eig_mpi(num_bands,kv3))
       eig_mpi = eig
       eig = 0.d0
       call mpi_allreduce(eig_mpi,eig,num_bands*kv3,mpi_double_precision,mpi_sum,mpi_comm_group,ierr)
       deallocate(eig_mpi)
    end if

    if(mype==0) then
       open(nfwannier,file=trim(wan90_seedname)//".eig",form="formatted")
       do ik = 1, kv3
          do m=1,num_bands
             write(nfwannier,'(2(1x,i5),1x,e25.16)') m, ik, eig(m,ik)
          end do
       end do
       close(nfwannier)
    end if

    deallocate(eig)
  end subroutine m_Wan90_wd_eig

  ! Build the overlap matrix
  !
  !    Mmn(k,b) = <u_m(k)|u_n(k+b)>
  !
  ! on the real-space wave-function grid and let rank 0 write it to
  ! "<wan90_seedname>.mmn".
  !
  ! For each neighbour nn of k-point ik1 the partner is ik2 = nnlist(ik1,nn).
  ! When nncell(:,ik1,nn) is non-zero the partner function is multiplied by
  ! exp(-i*2*pi*nncell.r) on the grid before the overlap is taken.
  ! Real-space functions come from m_ES_WF_in_Rspace as interleaved
  ! (real,imaginary) pairs; the sum over the nfft/2 pairs is scaled by
  ! 1/product(fft_box_size_WF(1:3,1)).
  !
  ! The matrix is dimensioned with num_bands (the number of included bands),
  ! the same extent that is used for the reduction buffer, the element count
  ! of the all-reduce and the output loops.  It used to be allocated with neg,
  ! which made buffer shape and memory layout disagree whenever
  ! neg /= num_bands.
  !
  ! NOTE(review): no map_k / map_e ownership test is applied in the loops
  ! below, yet the result is summed over mpi_comm_group when npes > 1.
  ! Whether that is correct depends on what m_ES_WF_in_Rspace returns on
  ! ranks that do not hold the requested band - please verify.
  !
  !   nfout : output unit (not used here)
  subroutine m_Wan90_gen_mmat(nfout)
    implicit none
    integer, intent(in) :: nfout

    integer :: i,ik1,ik2,nn,n,m,ni,mi,ib1,ib2,nnc(3),ir,im
    integer :: nffth,ip,r1,r2,r3,ngrid
    integer :: id1,id2,id12
    real(kind=DP) :: rgrid(3),da(3),wr,wi,dvol,ph,sumr,sumi
    real(kind=DP), allocatable :: wf1(:), wf2(:), zcos(:), zsin(:)
    real(kind=DP), allocatable :: m_mat(:,:,:,:,:) ! d(num_bands,num_bands,kv3,nntot,2)
    real(kind=DP), allocatable :: m_mat_mpi(:,:,:,:,:) ! d(num_bands,num_bands,kv3,nntot,2)
    logical :: shift_k

    nffth = nfft/2
    ngrid = product(fft_box_size_WF(1:3,1))
    dvol = 1.d0/dble(ngrid)

    ! Grid geometry used for the phase factors (independent of k and nn).
    id1 = fft_box_size_WF(1,0)
    id2 = fft_box_size_WF(2,0)
    id12 = id1*id2
    da(1:3) = 1.d0/fft_box_size_WF(1:3,1)

    allocate(m_mat(num_bands,num_bands,kv3,nntot,2)); m_mat = 0.d0
    allocate(wf1(nfft),wf2(nfft))
    allocate(zcos(nffth),zsin(nffth))

    call m_FFT_alloc_WF_work()

    do nn=1,nntot
       do ik1=1,kv3
          ik2 = nnlist(ik1,nn)
          nnc(1:3) = nncell(1:3,ik1,nn)
          if(all(nnc(1:3) == 0)) then
             shift_k = .false.
          else
             shift_k = .true.
             ! Phase factors, exp(iG*r)
             do i=1,nffth
                ! linear grid index -> (r1,r2,r3) with leading dimensions id1, id2
                ip = i-1
                r3 = ip/id12
                r2 = (ip-r3*id12)/id1
                r1 = ip-r2*id1-r3*id12
                rgrid(1) = dble(r1)*da(1)
                rgrid(2) = dble(r2)*da(2)
                rgrid(3) = dble(r3)*da(3)
                ph = PAI2*dot_product(nnc,rgrid)
                zcos(i) = cos(ph)
                zsin(i) = sin(ph)
             end do
          end if
          do n=1,num_bands
             ni = ib_inc(n)
             ib2 = neordr(ni,ik2)
             call m_ES_WF_in_Rspace(ik2,ib2,wf2) !-> u_n(k+b) or u_n(k')
             if(shift_k) then
                ! u_n(k+b) = exp(-iG*r) * u_n(k')
                ! k+b = k' + G
                do i=1,nffth
                   ir = 2*i-1
                   im = ir+1
                   wr = wf2(ir)
                   wi = wf2(im)
                   wf2(ir) = zcos(i)*wr + zsin(i)*wi
                   wf2(im) = zcos(i)*wi - zsin(i)*wr
                end do
             end if
             do m=1,num_bands
                mi = ib_inc(m)
                ib1 = neordr(mi,ik1)
                call m_ES_WF_in_Rspace(ik1,ib1,wf1) !-> u_m(k)
                ! <u_m(k)|u_n(k+b)>
                sumr = 0.d0
                sumi = 0.d0
                do i=1,nffth
                   ir = 2*i-1
                   im = ir+1
                   sumr = sumr + wf1(ir)*wf2(ir) + wf1(im)*wf2(im)
                   sumi = sumi + wf1(ir)*wf2(im) - wf1(im)*wf2(ir)
                end do
                m_mat(m,n,ik1,nn,1) = sumr * dvol
                m_mat(m,n,ik1,nn,2) = sumi * dvol
             end do
          end do
       end do
    end do

    call m_FFT_dealloc_WF_work()

    deallocate(wf1,wf2)
    deallocate(zcos,zsin)

    if(npes>1) then
       allocate(m_mat_mpi(num_bands,num_bands,kv3,nntot,2))
       m_mat_mpi = m_mat
       m_mat = 0.d0
       call mpi_allreduce(m_mat_mpi,m_mat,num_bands*num_bands*kv3*nntot*2,mpi_double_precision,mpi_sum,mpi_comm_group,ierr)
       deallocate(m_mat_mpi)
    end if

    if(mype==0) then
       open(nfwannier,file=trim(wan90_seedname)//".mmn",form="formatted")
       write(nfwannier,*) 'Generated by PHASE'
       write(nfwannier,*) num_bands, kv3, nntot
       do ik1 = 1, kv3
          do nn=1,nntot
             ik2 = nnlist(ik1,nn)
             nnc(1:3) = nncell(1:3,ik1,nn)
             write(nfwannier,'(5(1x,i5))') ik1,ik2,nnc(1:3)
             do n=1,num_bands
                do m=1,num_bands
                   write(nfwannier,'(2(1x,f16.12))') m_mat(m,n,ik1,nn,1:2)
                end do
             end do
          end do
       end do
       close(nfwannier)
    end if

    deallocate(m_mat)
  end subroutine m_Wan90_gen_mmat

  ! Write the real-space functions of the included bands to unformatted files
  ! "UNKkkkkk.s" (kkkkk = spin-independent k-point index, s = spin index),
  ! one file per (k-point, spin); only rank 0 writes.
  !
  ! Each file holds a header record
  !     fft_box_size_WF(1:3,1), k-point index, num_bands
  ! followed by one record per band with the complex values on the grid.
  !
  ! Two fixes with respect to the former version:
  !   * the record array has exactly the extents announced in the header
  !     (fft_box_size_WF(1:3,1)).  It used to be dimensioned with the leading
  !     dimensions fft_box_size_WF(1:3,0), so whenever these are larger the
  !     record also contained never-assigned padding elements and had a
  !     different layout from the one the header describes;
  !   * cmplx() is given kind=DP.  Without it the value is first rounded to
  !     default (single) precision.
  !
  ! NOTE(review): m_Wan90_gen_mmat brackets its m_ES_WF_in_Rspace calls with
  ! m_FFT_alloc_WF_work / m_FFT_dealloc_WF_work, this routine does not, and
  ! it calls m_ES_WF_in_Rspace on rank 0 only - confirm that both are valid
  ! for that helper.
  !
  !   nfout : output unit (not used here)
  subroutine m_Wan90_write_unk(nfout)
    implicit none
    integer, intent(in) :: nfout

    integer :: ispin, iksnl, ik, n, ni, ib
    integer :: id1, id2, id12, ip
    integer :: nd1, nd2, nd3
    integer :: i1, i2, i3
    character(len=10) :: wfnname
    real(kind=DP), allocatable :: wf(:)
    complex(kind=DP), allocatable :: zwf(:,:,:)

    ! leading dimensions of the linear grid index
    id1 = fft_box_size_WF(1,0)
    id2 = fft_box_size_WF(2,0)
    id12 = id1 * id2

    ! number of grid points actually written
    nd1 = fft_box_size_WF(1,1)
    nd2 = fft_box_size_WF(2,1)
    nd3 = fft_box_size_WF(3,1)

    if(mype==0) then
       allocate(wf(nfft))
       allocate(zwf(nd1,nd2,nd3))

       do ispin = 1, nspin
          do iksnl = 1, kv3/nspin
             ik = nspin*(iksnl-1) + ispin
             write(wfnname,'("UNK",i5.5,".",i1)') iksnl,ispin
             open(nfwannier,file=trim(wfnname),form="unformatted")
             write(nfwannier) fft_box_size_WF(1:3,1), iksnl, num_bands
             do n=1,num_bands
                ni = ib_inc(n)
                ib = neordr(ni,ik)
                call m_ES_WF_in_Rspace(ik,ib,wf)
                ! wf holds interleaved (real,imaginary) pairs
                do i3=1,nd3 
                   do i2=1,nd2
                      do i1=1,nd1
                         ip = i1 + (i2-1)*id1 + (i3-1) * id12
                         zwf(i1,i2,i3) = cmplx(wf(2*ip-1),wf(2*ip),kind=DP)
                      end do
                   end do
                end do
                write(nfwannier) zwf
             end do
             close(nfwannier)
          end do
       end do

       deallocate(wf)
       deallocate(zwf)
    end if

  end subroutine m_Wan90_write_unk

  ! .true. when band index m occurs among the first n_exclude entries of
  ! exclude_bands, .false. otherwise (also when n_exclude is zero).
  logical function excluding_band(m)
    implicit none
    integer, intent(in) :: m

    excluding_band = any(exclude_bands(1:n_exclude) == m)
  end function excluding_band

end module m_Wannier90
