module m_SpinOrbit_ForceTheorem
  use m_Const_Parameters,  only : DP, CMPLDP, PARABOLIC, TETRAHEDRON, COLD, &
       &                          FERMI_DIRAC, GAMMA, &
       &                          Neglected, ByPawPot, ByProjector, ZeffApprox, &
       &                          ReadFromPP, OFF, ON
  use m_Control_Parameters,  only : neg, nspin, m_CtrlP_way_of_smearing, &
       &                            ndim_spinor, ndim_chgpot, kimg, ekmode, &
       &                            SpinOrbit_mode
  use m_Files,              only : nfout
  use m_Kpoints,          only : kv3, k_symmetry
  use m_Electronic_Structure,  only : eko_l, fsr_l, fsi_l, occup_l
  use m_Ionic_System,       only : ityp, natm
  use m_PseudoPotential,   only : nlmt, ltp, taup, nlmta, ilmt, lmta

  use m_Parallelization,  only : map_k, map_e, map_z, myrank_k, myrank_e, mpi_k_world, &
       &                         ista_e, iend_e, istep_e, np_e, npes, ierr, &
       &                         mpi_comm_group, mype, ista_k, iend_k

  use m_ES_Occup,            only : m_ESoc_fermi_parabolic, m_ESoc_fermi_tetrahedron, &
       &                            m_ESoc_fermi_ColdSmearing, m_ESoc_fermi_Dirac

  use m_SpinOrbit_Potential,     only : m_SO_alloc_dsoc, dsoc, &
       &                                m_SO_set_Dsoc_potential1, &
       &                                m_SO_calc_MatLS_orb_s_to_f, &
       &                                m_SO_diagonalize_MatLS, &
       &                                m_SO_set_MatU_ylm_RC, &
       &                                m_SO_alloc_Mat_SOC_strenth, &
       &                                m_SO_set_Dsoc_potential2, &
       &                                m_SO_dealloc_Mat_SOC_strenth, &
       &                                m_SO_dealloc_dsoc
  use m_SpinOrbit_RadInt,       only : m_SO_calc_SOC_strength_pawpot, &
       &                               m_SO_calc_SOC_strength_zeff, &
       &                               m_SO_check_mode_Builtin, &
       &                               m_SO_check_mode_Pawpot, &
       &                               m_SO_check_mode_Zeff
  use m_SpinOrbit_FromFile,     only : m_SO_set_SOC_strength_from_PP

  implicit none
  ! MPI constants used below (mpi_double_precision, mpi_sum) come from this include.
  include 'mpif.h'
!
  ! Module-level neg_doubled. Note: m_SO_calc_band_energy_fth declares its own
  ! local variable with the same name, so this one is not referenced by the
  ! subroutines of this module.
  integer :: neg_doubled
  ! Values of ndim_spinor / ndim_chgpot saved by m_SO_init_Force_Theorem and
  ! written back by m_SO_finalize_Force_Theorem.
  integer :: ndim_spinor_prev, ndim_chgpot_prev
!
contains
  
  !> Set up the spin-orbit data used by the force-theorem calculation.
  !!
  !! The current ndim_spinor and ndim_chgpot are stored in the module
  !! variables ndim_spinor_prev / ndim_chgpot_prev; ndim_spinor is then set
  !! to 2 and ndim_chgpot to ndim_spinor**2.  After the two mode checks, the
  !! setup routines matching SpinOrbit_Mode are called.  When compiled with
  !! USE_ASMS_SPINORBIT, ASMS_SpinOrbit_setup is called last.
  subroutine m_SO_init_Force_Theorem

    ! Remember the incoming dimensions so that the finalize routine can
    ! restore them.
    ndim_spinor_prev = ndim_spinor
    ndim_chgpot_prev = ndim_chgpot

    ndim_spinor = 2
    ndim_chgpot = ndim_spinor **2

    call m_SO_check_mode_Pawpot
    call m_SO_check_mode_Zeff

    ! Performed for every mode except Neglected.
    if ( .not. ( SpinOrbit_Mode == Neglected ) ) then
       call m_SO_set_MatU_ylm_RC
       call m_SO_calc_MatLS_orb_s_to_f
       call m_SO_diagonalize_MatLS
    endif

    ! The next three modes differ only in the routine that fills the SOC
    ! strength; the allocation before it and the dsoc setup after it are
    ! identical.
    if ( SpinOrbit_Mode == ByPawPot ) then
       call m_SO_alloc_Mat_SOC_strenth
       call m_SO_calc_SOC_strength_pawpot
       call setup_dsoc_after_strength
    endif

    if ( SpinOrbit_Mode == ZeffApprox ) then
       call m_SO_alloc_Mat_SOC_strenth
       call m_SO_calc_SOC_strength_zeff
       call setup_dsoc_after_strength
    endif

    if ( SpinOrbit_Mode == ReadFromPP ) then
       call m_SO_alloc_Mat_SOC_strenth
       call m_SO_set_SOC_strength_from_PP
       call setup_dsoc_after_strength
    endif

    ! ByProjector needs no SOC strength matrix and uses the "potential1"
    ! variant of the dsoc setup.
    if ( SpinOrbit_Mode == ByProjector ) then
       call m_SO_alloc_Dsoc
       call m_SO_set_Dsoc_potential1
    endif

#ifdef USE_ASMS_SPINORBIT
    call ASMS_SpinOrbit_setup( nfout, nspin, neg, kv3, &
         &                     mype, npes, mpi_comm_group, &
         &                     np_e, ista_e, iend_e, istep_e, &
         &                     ista_k, iend_k, myrank_k )
#endif

  contains

    ! Common tail of the ByPawPot / ZeffApprox / ReadFromPP branches:
    ! allocate dsoc, then call the "potential2" dsoc setup.
    subroutine setup_dsoc_after_strength
      call m_SO_alloc_Dsoc
      call m_SO_set_Dsoc_potential2
    end subroutine setup_dsoc_after_strength

  end subroutine m_SO_init_Force_Theorem

  !> Undo m_SO_init_Force_Theorem.
  !!
  !! Writes the saved ndim_spinor_prev / ndim_chgpot_prev back into
  !! ndim_spinor / ndim_chgpot and calls the deallocation routines that
  !! correspond to the current SpinOrbit_Mode.
  subroutine m_SO_finalize_Force_Theorem
    logical :: strength_based_mode

    ndim_spinor = ndim_spinor_prev
    ndim_chgpot = ndim_chgpot_prev

    ! ByPawPot, ZeffApprox and ReadFromPP are the modes for which the init
    ! routine allocated the SOC strength matrix in addition to dsoc.
    strength_based_mode = ( SpinOrbit_Mode == ByPawPot )   .or. &
         &                ( SpinOrbit_Mode == ZeffApprox ) .or. &
         &                ( SpinOrbit_Mode == ReadFromPP )

    if ( strength_based_mode ) call m_SO_dealloc_Mat_SOC_strenth

    ! dsoc is released for the strength-based modes and for ByProjector.
    if ( strength_based_mode .or. SpinOrbit_Mode == ByProjector ) call m_SO_dealloc_Dsoc

  end subroutine m_SO_finalize_Force_Theorem
   
  !> Assemble, for each k-point handled by this process, the matrix MatH in
  !! the basis of the existing band states (both spinor components).
  !!
  !! Returns immediately when SpinOrbit_mode == Neglected.  Otherwise, for
  !! every ik = 1, kv3, nspin with map_k(ik) == myrank_k, the
  !! (neg*ndim_spinor) x (neg*ndim_spinor) complex matrix MatH is filled with
  !!   * the dsoc contribution  conjg(f(ib1)) * dsoc * f(ib2)
  !!     (set_contrib_from_dsoc), and
  !!   * the band energies eko_l on the diagonal (add_diagonal_elements).
  !! MatH is passed on (to ASMS_SpinOrbit_renew_level_FTH) only when the code
  !! is compiled with USE_ASMS_SPINORBIT; without that macro the matrix is
  !! built and then discarded.  When ekmode == OFF, FermiEnergyLevel is
  !! called afterwards, followed by ASMS_SpinOrbit_bandenergy_FTH under the
  !! same macro.
  subroutine m_SO_calc_band_energy_fth
    ! Local variable; it hides the module-level variable of the same name
    ! and is the one seen by the internal procedures below.
    integer :: neg_doubled
    integer :: ik
    complex(kind=CMPLDP), allocatable :: MatH(:,:)

    if ( SpinOrbit_mode == Neglected ) return

    neg_doubled = neg *ndim_spinor
    allocate( MatH( neg_doubled, neg_doubled ) ); MatH = 0.0d0
    
    ! ik advances by nspin; the spin index is added as an offset to ik
    ! inside the internal procedures when nspin == 2.
    Do ik=1, kv3, nspin
       if ( map_k(ik) /= myrank_k ) cycle
       
       MatH = 0.0d0
       call set_contrib_from_dsoc
       call add_diagonal_elements
#ifdef USE_ASMS_SPINORBIT
       call ASMS_SpinOrbit_renew_level_FTH( ik, ndim_spinor, neg_doubled, MatH, &
            &                               map_z, eko_l )
#endif
    End Do
    deallocate( MatH )
!
    if ( ekmode == OFF ) then
       call FermiEnergyLevel
#ifdef USE_ASMS_SPINORBIT
       call ASMS_SpinOrbit_bandenergy_FTH( map_k, occup_l, eko_l )
#endif
    endif

  contains

    !> Add the band energies eko_l of k-point ik to the diagonal of MatH.
    !! Spinor component is and band ib1 map to row/column (is-1)*neg + ib1.
    subroutine add_diagonal_elements
      integer :: is, ib1, itmp
      real(kind=DP), allocatable :: eko_wk(:,:), eko_mpi(:,:)

      allocate( eko_wk( neg, ndim_spinor ) ); eko_wk = 0.0d0
      Do is=1, ndim_spinor
         ! With nspin == 2 each spinor component reads its own spin entry
         ! (ik, ik+1); otherwise both components read entry ik.
         itmp = 1
         if ( nspin == 2 ) itmp = is
         
         ! Only the bands ista_e..iend_e (step istep_e) are filled here;
         ! all other entries stay zero.
         Do ib1=ista_e, iend_e, istep_e
            eko_wk( ib1, is ) = eko_l( map_z(ib1),ik +itmp-1 )
         End do
      End Do
      
      ! Summing over mpi_k_world(myrank_k) combines the locally filled
      ! entries of all processes in that communicator.
      if ( npes > 1 ) then
         allocate( eko_mpi( neg, ndim_spinor ) ); eko_mpi = 0.0d0
         call mpi_allreduce( eko_wk, eko_mpi, neg*ndim_spinor, &
              &              mpi_double_precision, mpi_sum, &
              &              mpi_k_world(myrank_k), ierr )
         eko_wk = eko_mpi
         deallocate( eko_mpi )
      endif
!
      Do is=1, ndim_spinor
         Do ib1=1, neg
            itmp = ( is -1 )*neg +ib1
            MatH( itmp,itmp ) = MatH( itmp,itmp ) +eko_wk( ib1,is )
         End do
      End Do
      deallocate( eko_wk )

    end subroutine add_diagonal_elements

    !> Accumulate into MatH, for k-point ik, the terms
    !!   conjg( f(ib1,lmta1,is1) ) * dsoc(lmt1,lmt2,ia,istmp) * f(ib2,lmta2,is2)
    !! summed over atoms ia and the index pairs (lmt1,lmt2) of each atom,
    !! where f is taken from fsr_l / fsi_l.  Each process handles the rows
    !! ib1 = ista_e..iend_e (step istep_e); the partial matrices are summed
    !! over mpi_k_world(myrank_k) at the end.
    subroutine set_contrib_from_dsoc
      integer :: ib1, ib2, is1, is2, itmp1
      integer :: ia, it, istmp, jb1, jb2
      integer :: lmt1, lmt2, lmta1, lmta2
      logical :: real_projections
      complex(kind=CMPLDP) :: z1, wf1, wf2
      complex(kind=CMPLDP), allocatable :: wk_fsri(:,:), MatH_mpi(:,:)

      ! When this condition holds only fsr_l is read and fsi_l is never
      ! touched.  It does not change inside this routine, so it is
      ! evaluated once instead of in the innermost loop.
      ! NOTE(review): presumably fsi_l is not available in this
      ! Gamma-only case -- confirm against m_Electronic_Structure.
      real_projections = ( kv3/nspin == 1 .and. k_symmetry(1) == GAMMA &
           &                              .and. kimg == 2 )

      allocate( wk_fsri( nlmta, ndim_spinor ) ); wk_fsri = 0.0d0

      Do ib2=1, neg
         wk_fsri = 0.0d0
         
         ! The process for which map_e(ib2) == myrank_e packs the data of
         ! band ib2 and broadcasts it to the others.
         if ( map_e(ib2) == myrank_e ) then
            Do is1=1, ndim_spinor
               itmp1 = 1
               if ( nspin == 2 ) itmp1 = is1
                
               if ( real_projections ) then
                  wk_fsri(:,is1) = fsr_l( map_z(ib2),:,ik+itmp1-1 )
               else
                  wk_fsri(:,is1) = dcmplx( fsr_l( map_z(ib2),:,ik+itmp1-1 ), &
                       &                   fsi_l( map_z(ib2),:,ik+itmp1-1 ) )
               endif
            End do
         endif
         ! The complex buffer is sent as pairs of double-precision reals,
         ! hence the factor 2 in the element count.
         ! (assumes CMPLDP is the double-precision complex kind -- TODO confirm)
         call mpi_bcast( wk_fsri, 2*nlmta*ndim_spinor, mpi_double_precision, &
              &          map_e(ib2), mpi_k_world(myrank_k), ierr )
         !
         DO ib1=ista_e, iend_e, istep_e
            
            Do ia=1, natm
               it = ityp(ia)
               
               Do lmt1=1, ilmt(it)
                  lmta1 = lmta( lmt1,ia )
                  
                  Do lmt2=1, ilmt(it)
                     lmta2 = lmta( lmt2,ia )
                     
                     Do is1=1, ndim_spinor
                        itmp1 = 1
                        if ( nspin == 2 ) itmp1 = is1

                        ! wf1 and the row index jb1 depend only on
                        ! (ib1, lmta1, is1), so they are set once per is1
                        ! rather than once per (is1, is2) pair.
                        if ( real_projections ) then
                           wf1 = fsr_l( map_z(ib1), lmta1, ik +itmp1 -1 )
                        else
                           wf1 = dcmplx( fsr_l( map_z(ib1), lmta1, ik +itmp1 -1 ), &
                                &        fsi_l( map_z(ib1), lmta1, ik +itmp1 -1 ) )
                        endif
                        jb1 = neg *( is1 -1 )+ib1
                        
                        Do is2=1, ndim_spinor
                           wf2 = wk_fsri( lmta2, is2 )
                           
                           ! Fourth index of dsoc: the spinor pair
                           ! (is1,is2) flattened to 1..ndim_spinor**2.
                           istmp = ( is1 -1 )*ndim_spinor +is2
                           z1 = dconjg(wf1) *dsoc( lmt1, lmt2, ia, istmp ) *wf2
                           
                           jb2 = neg *( is2 -1 )+ib2
                           MatH(jb1,jb2) = MatH (jb1,jb2) +z1
                        End Do
                     End do
                  End Do
                  
               End Do
            End do
         End DO
      End Do
!
      ! Each process filled only its own rows; the sum over the
      ! communicator yields the complete matrix on every process.  As in
      ! the broadcast above, complex data are reduced as pairs of reals.
      if ( npes > 1 ) then
         allocate( MatH_mpi( neg_doubled, neg_doubled ) ); MatH_mpi = 0.0d0
         call mpi_allreduce( MatH, MatH_mpi, 2*neg_doubled**2, &
               &             mpi_double_precision, mpi_sum, &
               &             mpi_k_world(myrank_k), ierr )
         MatH = MatH_mpi
         deallocate( MatH_mpi )
      endif
      deallocate( wk_fsri )

    end subroutine set_contrib_from_dsoc

  end subroutine m_SO_calc_band_energy_fth

  !> Call the m_ES_Occup routine that matches the smearing scheme returned
  !! by m_CtrlP_way_of_smearing(), passing the output unit nfout.
  !!
  !! PARABOLIC, COLD and FERMI_DIRAC are dispatched to their respective
  !! m_ESoc_fermi_* routine.  TETRAHEDRON is not supported here and stops
  !! the program.  For any other value nothing is called.
  subroutine FermiEnergyLevel()
    integer :: way_of_smearing

    way_of_smearing = m_CtrlP_way_of_smearing()
    if(way_of_smearing == PARABOLIC) then
       call m_ESoc_fermi_parabolic(nfout)
    else if(way_of_smearing == TETRAHEDRON) then
       ! The program stops here; the call to m_ESoc_fermi_tetrahedron that
       ! used to follow this statement could never be reached and has been
       ! removed.
       stop "Not supported"
    else if(way_of_smearing == COLD) then
       call m_ESoc_fermi_ColdSmearing(nfout)
! ================ KT_add ========================= 13.0E
    else if(way_of_smearing == FERMI_DIRAC) then
       call m_ESoc_fermi_Dirac(nfout)
! ================================================= 13.0E
    end if
  end subroutine FermiEnergyLevel

end module m_SpinOrbit_ForceTheorem
