!> Radiation coupling driver.
!!
!! This module owns MPI/data plumbing for radiation so model developers do not
!! need to touch MPI.  Physics kernels live under `src/radiation/models/` and
!! operate on `mesh`, `radiation_context_t`, `radiation_state_t`, and
!! `radiation_source_t` only.
module mod_radiation
   use mpi_f08
   use mod_precision, only : rk, zero, tiny_safe, name_len, output_unit, fatal_error, lowercase
   use mod_mesh, only : mesh_t
   use mod_input, only : case_params_t, max_species
   use mod_mpi_flow, only : flow_mpi_t, flow_allgather_owned_scalar, flow_allgather_owned_matrix
   use mod_mpi_radiation, only : radiation_mpi_t
   use mod_energy, only : energy_fields_t
   use mod_profiling, only : profiler_start, profiler_stop
   use mod_species_registry, only : species_index_of
   use mod_bc, only : bc_set_t
   use mod_radiation_types, only : radiation_context_t, radiation_state_t, radiation_source_t, &
                                   allocate_radiation_state, free_radiation_state, &
                                   allocate_radiation_source, free_radiation_source
   use mod_radiation_none, only : radiation_compute_none
   use mod_radiation_spectral_test, only : radiation_compute_spectral_test
   use mod_radiation_external, only : radiation_compute_external
   use mod_radiation_p1, only : radiation_compute_p1
   use mod_radiation_dom, only : radiation_compute_dom
   implicit none

   private
   public :: radiation_context_t
   public :: radiation_initialize, radiation_update_source, radiation_finalize

contains

   !> Populate the radiation context from the parsed case parameters.
   !!
   !! Any previous contents of `context` are discarded first through
   !! `radiation_finalize`.  Communicator layout, mesh sizes, flags and output
   !! paths are always copied.  The species/scalar selections and the model
   !! name are only processed and validated when radiation is enabled, after
   !! which the setup summary file is written.
   !!
   !! Calls `fatal_error` when a selection count exceeds the number of name
   !! entries available in `params`, when a selected name is blank, when a
   !! radiation species is not found in the active species list, or when the
   !! model name is not one of the supported models.
   !!
   !! @param mesh     Mesh whose cell and face counts are cached in the context.
   !! @param rad      Radiation communicator layout (rank, size, wavenumber range).
   !! @param params   Parsed case parameters.
   !! @param context  Radiation context to (re)initialize.
   subroutine radiation_initialize(mesh, rad, params, context)
      type(mesh_t), intent(in) :: mesh
      type(radiation_mpi_t), intent(in) :: rad
      type(case_params_t), intent(in) :: params
      type(radiation_context_t), intent(inout) :: context
      integer :: j, idx

      call radiation_finalize(context)

      context%enabled = params%enable_radiation
      context%model = trim(lowercase(params%radiation_source_model))
      context%rad_rank = rad%rank
      context%rad_size = rad%nprocs
      context%n_wavenumbers = params%radiation_n_wavenumbers
      context%wn_first = rad%first_wavenumber
      context%wn_last = rad%last_wavenumber
      context%nlocal_wavenumbers = rad%nlocal_wavenumbers
      context%ncells = mesh%ncells
      context%nfaces = mesh%nfaces
      context%pressure_source = trim(lowercase(params%radiation_pressure_source))
      context%mesh_cached = params%enable_radiation
      context%debug = params%radiation_debug
      context%write_diagnostics = params%write_radiation_diagnostics
      context%setup_file = trim(params%output_dir)//'/radiation_setup.txt'
      context%diagnostics_file = trim(params%output_dir)//'/diagnostics/radiation_diagnostics.csv'

      if (.not. context%enabled) return

      context%n_species = params%radiation_n_species
      ! Reject a count larger than the name list before indexing into it.
      if (context%n_species > size(params%radiation_species_name)) then
         call fatal_error('radiation', &
                          'radiation_n_species exceeds the number of radiation_species_name entries')
      end if
      if (context%n_species > 0) then
         allocate(context%species_name(context%n_species), context%species_index(context%n_species))
         do j = 1, context%n_species
            context%species_name(j) = trim(params%radiation_species_name(j))
            if (len_trim(context%species_name(j)) == 0) then
               call fatal_error('radiation', 'blank radiation_species_name entry')
            end if
            ! Map the radiation species name onto the solver's species list.
            idx = species_index_of(params%species_name, params%nspecies, context%species_name(j))
            if (idx <= 0) then
               call fatal_error('radiation', 'radiation species "'//trim(context%species_name(j))// &
                                '" is not in the active species list')
            end if
            context%species_index(j) = idx
         end do
      end if

      context%n_scalars = params%radiation_n_scalars
      ! Same guard for the scalar name list.
      if (context%n_scalars > size(params%radiation_scalar_name)) then
         call fatal_error('radiation', &
                          'radiation_n_scalars exceeds the number of radiation_scalar_name entries')
      end if
      if (context%n_scalars > 0) then
         allocate(context%scalar_name(context%n_scalars))
         do j = 1, context%n_scalars
            context%scalar_name(j) = trim(params%radiation_scalar_name(j))
            if (len_trim(context%scalar_name(j)) == 0) then
               call fatal_error('radiation', 'blank radiation_scalar_name entry')
            end if
         end do
      end if

      ! Only the models dispatched by radiation_update_source are accepted.
      select case (trim(context%model))
      case ('none', 'spectral_test', 'external', 'p1', 'dom')
         continue
      case default
         call fatal_error('radiation', 'unknown radiation_source_model: '//trim(context%model))
      end select

      call write_radiation_setup(mesh, rad, params, context)
   end subroutine radiation_initialize


   !> Return a radiation context to its pristine, disabled state.
   !!
   !! Frees the species/scalar selection tables if they are allocated and
   !! resets every scalar field of the context to its default value.  Safe to
   !! call on a context that was never initialized.
   !!
   !! @param context  Radiation context to reset.
   subroutine radiation_finalize(context)
      type(radiation_context_t), intent(inout) :: context

      ! Selection tables.
      if (allocated(context%scalar_name)) deallocate(context%scalar_name)
      if (allocated(context%species_index)) deallocate(context%species_index)
      if (allocated(context%species_name)) deallocate(context%species_name)
      context%n_scalars = 0
      context%n_species = 0

      ! Model selection and switches.
      context%model = 'none'
      context%enabled = .false.
      context%debug = .false.
      context%mesh_cached = .false.
      context%write_diagnostics = .true.

      ! Communicator layout and wavenumber range (empty range is 0 .. -1).
      context%rad_size = 0
      context%rad_rank = -1
      context%n_wavenumbers = 0
      context%nlocal_wavenumbers = 0
      context%wn_first = 0
      context%wn_last = -1

      ! Cached mesh sizes.
      context%nfaces = 0
      context%ncells = 0

      ! Output bookkeeping.
      context%diagnostics_file = ''
      context%setup_file = ''
      context%diagnostics_initialized = .false.
      context%setup_written = .false.
   end subroutine radiation_finalize


    !> Refresh the radiative energy source `energy%qrad`.
    !!
    !! The call is a no-op unless radiation is enabled in both `context` and
    !! `params`, `energy%qrad` is allocated, `params%radiation_update_interval`
    !! is positive, and `step - 1` is a multiple of that interval.
    !!
    !! Otherwise the routine gathers the radiation state, runs the selected
    !! model, sums the per-rank partial sources over `rad%comm`, subtracts the
    !! emission term for the 'dom' and 'p1' models, stores the result in
    !! `energy%qrad` and writes one diagnostics row.
    !!
    !! @param species_Y  Optional species array, forwarded to the state gather.
    subroutine radiation_update_source(mesh, bc, flow, rad, context, params, energy, species_Y, step, time, dt)
      type(mesh_t), intent(in) :: mesh
      type(bc_set_t), intent(in) :: bc
      type(flow_mpi_t), intent(inout) :: flow
      type(radiation_mpi_t), intent(in) :: rad
      type(radiation_context_t), intent(inout) :: context
      type(case_params_t), intent(in) :: params
      type(energy_fields_t), intent(inout) :: energy
      real(rk), intent(in), optional :: species_Y(:,:)
      integer, intent(in) :: step
      real(rk), intent(in) :: time, dt

      real(rk), parameter :: sigma = 5.670374419e-8_rk
      type(radiation_state_t) :: gathered
      type(radiation_source_t) :: rank_source, summed_source
      real(rk) :: wall_begin, wall_gathered, wall_computed, wall_reduced
      real(rk) :: gather_time, compute_time, reduce_time
      real(rk) :: debug_T_diff, debug_Y_diff
      real(rk) :: kappa
      integer :: icell, ierr

      ! Nothing to do when radiation is off or there is nowhere to store the result.
      if (.not. (context%enabled .and. params%enable_radiation .and. allocated(energy%qrad))) return
      ! The interval test must come before mod() so the divisor is known to be positive.
      if (params%radiation_update_interval <= 0) return
      if (step /= 1) then
         if (mod(step-1, params%radiation_update_interval) /= 0) return
      end if

      call profiler_start('Radiation_Source_Update')
      wall_begin = real(MPI_Wtime(), rk)

      call allocate_radiation_state(gathered, mesh%ncells, context%n_species, context%n_scalars)
      call allocate_radiation_source(rank_source, mesh%ncells)
      call allocate_radiation_source(summed_source, mesh%ncells)

      ! Stage 1: assemble the state handed to the model.
      call profiler_start('Radiation_State_Gather')
      call prepare_radiation_state(mesh, flow, context, params, energy, species_Y, step, time, dt, gathered, &
                                   debug_T_diff, debug_Y_diff)
      call profiler_stop('Radiation_State_Gather')
      wall_gathered = real(MPI_Wtime(), rk)
      gather_time = wall_gathered - wall_begin

      ! Stage 2: run the selected physics kernel on this rank's share.
      call profiler_start('Radiation_Model_Compute')
      select case (trim(context%model))
      case ('dom')
         call radiation_compute_dom(mesh, bc, context, gathered, rank_source, params)
      case ('p1')
         call radiation_compute_p1(mesh, bc, context, gathered, rank_source, params)
      case ('external')
         call radiation_compute_external(mesh, context, gathered, rank_source)
      case ('spectral_test')
         call radiation_compute_spectral_test(mesh, context, gathered, rank_source, params%radiation_source_scale)
      case ('none')
         call radiation_compute_none(mesh, context, gathered, rank_source)
      case default
         call fatal_error('radiation', 'unknown radiation_source_model: '//trim(context%model))
      end select
      call profiler_stop('Radiation_Model_Compute')
      wall_computed = real(MPI_Wtime(), rk)
      compute_time = wall_computed - wall_gathered

      ! Stage 3: sum the per-rank partial sources over the radiation communicator.
      summed_source%qrad = zero
      call profiler_start('Radiation_Source_Reduce')
      call MPI_Allreduce(rank_source%qrad, summed_source%qrad, mesh%ncells, MPI_DOUBLE_PRECISION, MPI_SUM, &
                         rad%comm, ierr)
      if (ierr /= MPI_SUCCESS) call fatal_error('radiation', 'MPI failure reducing radiation qrad')
      call profiler_stop('Radiation_Source_Reduce')

      ! [X-1 Contract]: the 'dom' and 'p1' model modules return only the partial
      ! absorption contribution kappa*G (or kappa*G_local) in source%qrad.  The
      ! driver reduces it above and then subtracts the isotropic emission term
      ! 4*kappa*sigma*T^4 on every rank, giving the net source
      ! q_rad = kappa*G - 4*kappa*sigma*T^4.
      select case (trim(context%model))
      case ('dom', 'p1')
         kappa = max(params%radiation_absorption_coeff, 1.0e-6_rk)
         do icell = 1, mesh%ncells
            summed_source%qrad(icell) = summed_source%qrad(icell) &
                                        - kappa * 4.0_rk * sigma * (gathered%temperature(icell)**4)
         end do
      end select

      energy%qrad = summed_source%qrad
      wall_reduced = real(MPI_Wtime(), rk)
      reduce_time = wall_reduced - wall_computed

      call write_radiation_diagnostics(mesh, rad, context, gathered, summed_source, step, time, dt, &
                                       gather_time, compute_time, reduce_time, debug_T_diff, debug_Y_diff)

      call free_radiation_source(summed_source)
      call free_radiation_source(rank_source)
      call free_radiation_state(gathered)
      call profiler_stop('Radiation_Source_Update')
   end subroutine radiation_update_source


   !> Fill the radiation state handed to the physics kernels.
   !!
   !! Records `step`, `time` and `dt`, then fills temperature, pressure and the
   !! selected species mass fractions.  When
   !! `params%enable_chemistry_load_balancing` is set, `energy%T` and
   !! `species_Y` are copied as they are (presumably they already cover every
   !! cell in that mode; confirm against the caller).  Otherwise the owned
   !! ranges are combined through the `flow_allgather_owned_*` helpers.
   !! Scalars are zeroed, and the gather checksums are computed last.
   !!
   !! Calls `fatal_error` for an unsupported pressure source, or when species
   !! are selected but `species_Y` is absent.
   !!
   !! @param species_Y     Optional array indexed (solver species, cell).
   !! @param state         Radiation state to fill; expected to be allocated.
   !! @param debug_T_diff  Temperature checksum mismatch (zero unless debugging).
   !! @param debug_Y_diff  Largest species checksum mismatch (zero unless debugging).
   subroutine prepare_radiation_state(mesh, flow, context, params, energy, species_Y, step, time, dt, state, &
                                      debug_T_diff, debug_Y_diff)
      type(mesh_t), intent(in) :: mesh
      type(flow_mpi_t), intent(inout) :: flow
      type(radiation_context_t), intent(in) :: context
      type(case_params_t), intent(in) :: params
      type(energy_fields_t), intent(in) :: energy
      real(rk), intent(in), optional :: species_Y(:,:)
      integer, intent(in) :: step
      real(rk), intent(in) :: time, dt
      type(radiation_state_t), intent(inout) :: state
      real(rk), intent(out) :: debug_T_diff, debug_Y_diff

      real(rk), allocatable :: selected_local(:,:)
      integer :: s, c

      state%step = step
      state%time = time
      state%dt = dt
      if (params%enable_chemistry_load_balancing) then
         state%temperature = energy%T
      else
         call flow_allgather_owned_scalar(flow, energy%T, state%temperature)
      end if

      ! Populate radiation thermodynamic pressure from the selected source.
      ! Note: in this low-Mach solver fields%p is a projection correction
      ! potential, NOT an absolute thermodynamic pressure.  Both 'background'
      ! and 'system' therefore resolve to params%background_press (the uniform
      ! thermodynamic p0 passed to Cantera and the energy module).
      select case (trim(context%pressure_source))
      case ('background', 'system')
         state%pressure = params%background_press
      case default
         call fatal_error('radiation', 'unsupported radiation_pressure_source: ' // &
                          trim(context%pressure_source))
      end select

      if (context%n_species > 0) then
         if (.not. present(species_Y)) then
            call fatal_error('radiation', 'radiation selected species but species_Y was not passed')
         end if
         ! Both arrays are indexed (species, cell).  Fortran stores the first
         ! index contiguously, so the cell loop goes outside and the species
         ! loop inside.
         if (params%enable_chemistry_load_balancing) then
            do c = 1, mesh%ncells
               do s = 1, context%n_species
                  state%Y(s, c) = species_Y(context%species_index(s), c)
               end do
            end do
         else
            ! Only the owned cell range is filled; the rest stays zero and is
            ! supplied by the other ranks in the allgather.
            allocate(selected_local(context%n_species, mesh%ncells))
            selected_local = zero
            do c = flow%first_cell, flow%last_cell
               do s = 1, context%n_species
                  selected_local(s, c) = species_Y(context%species_index(s), c)
               end do
            end do
            call flow_allgather_owned_matrix(flow, selected_local, state%Y)
            deallocate(selected_local)
         end if
      end if

      ! Generic scalar gathering is intentionally reserved for the next scalar
      ! registry patch.  The state allocation keeps the API stable.
      if (allocated(state%scalars)) state%scalars = zero

      call compute_gather_debug(flow, context, state, debug_T_diff, debug_Y_diff)
   end subroutine prepare_radiation_state


   !> Checksum test of the gathered radiation state.
   !!
   !! Both outputs are zero unless `context%debug` is set.  When it is, each
   !! rank sums its owned cell range of the gathered array, the owned sums are
   !! added over `flow%comm`, and the result is compared with the sum of the
   !! whole gathered array.  The largest absolute mismatch over all ranks is
   !! returned for temperature and, over all species, for the mass fractions.
   !!
   !! @param debug_T_diff  Largest temperature checksum mismatch.
   !! @param debug_Y_diff  Largest species checksum mismatch.
   subroutine compute_gather_debug(flow, context, state, debug_T_diff, debug_Y_diff)
      type(flow_mpi_t), intent(in) :: flow
      type(radiation_context_t), intent(in) :: context
      type(radiation_state_t), intent(in) :: state
      real(rk), intent(out) :: debug_T_diff, debug_Y_diff
      real(rk) :: local_part, reduced_part, full_total, t_mismatch, y_mismatch
      integer :: ierr, icell, ispec

      debug_T_diff = zero
      debug_Y_diff = zero

      if (context%debug) then
         associate (lo => flow%first_cell, hi => flow%last_cell)
            ! Temperature checksum.
            local_part = zero
            do icell = lo, hi
               local_part = local_part + state%temperature(icell)
            end do
            call MPI_Allreduce(local_part, reduced_part, 1, MPI_DOUBLE_PRECISION, MPI_SUM, flow%comm, ierr)
            if (ierr /= MPI_SUCCESS) call fatal_error('radiation', 'MPI failure reducing T checksum')
            full_total = sum(state%temperature)
            t_mismatch = abs(full_total - reduced_part)
            call MPI_Allreduce(t_mismatch, debug_T_diff, 1, MPI_DOUBLE_PRECISION, MPI_MAX, flow%comm, ierr)
            if (ierr /= MPI_SUCCESS) call fatal_error('radiation', 'MPI failure reducing T checksum diff')

            ! Species checksums: keep the worst mismatch over all species.
            y_mismatch = zero
            if (allocated(state%Y)) then
               do ispec = 1, state%n_species
                  local_part = zero
                  do icell = lo, hi
                     local_part = local_part + state%Y(ispec, icell)
                  end do
                  call MPI_Allreduce(local_part, reduced_part, 1, MPI_DOUBLE_PRECISION, MPI_SUM, flow%comm, ierr)
                  if (ierr /= MPI_SUCCESS) call fatal_error('radiation', 'MPI failure reducing Y checksum')
                  full_total = sum(state%Y(ispec, :))
                  y_mismatch = max(y_mismatch, abs(full_total - reduced_part))
               end do
            end if
            call MPI_Allreduce(y_mismatch, debug_Y_diff, 1, MPI_DOUBLE_PRECISION, MPI_MAX, flow%comm, ierr)
            if (ierr /= MPI_SUCCESS) call fatal_error('radiation', 'MPI failure reducing Y checksum diff')
         end associate
      end if
   end subroutine compute_gather_debug


   !> Write the human-readable radiation setup summary.
   !!
   !! Only radiation rank 0 writes, and only when radiation is enabled and
   !! diagnostics output is requested.  The file at `context%setup_file` is
   !! replaced.  It lists the model, communicator size, mesh sizes, the
   !! wavenumber range of every rank, the selected species and scalars, and the
   !! background pressure.  On success `context%setup_written` is set.
   !!
   !! Calls `fatal_error` if the file cannot be opened.
   subroutine write_radiation_setup(mesh, rad, params, context)
      type(mesh_t), intent(in) :: mesh
      type(radiation_mpi_t), intent(in) :: rad
      type(case_params_t), intent(in) :: params
      type(radiation_context_t), intent(inout) :: context
      integer :: unit_id, r, first, last, nlocal, s, ios
      character(len=256) :: io_msg

      if (.not. context%enabled) return
      if (.not. context%write_diagnostics) return
      if (context%rad_rank /= 0) return

      ! Report an unwritable path (e.g. missing output directory) through the
      ! solver's own error path instead of an unexplained runtime abort.
      io_msg = ''
      open(newunit=unit_id, file=trim(context%setup_file), status='replace', action='write', &
           iostat=ios, iomsg=io_msg)
      if (ios /= 0) then
         call fatal_error('radiation', 'cannot open '//trim(context%setup_file)//': '//trim(io_msg))
         return
      end if

      write(unit_id,'(a)') 'Radiation setup'
      write(unit_id,'(a)') '---------------'
      write(unit_id,'(a,l1)') 'enabled: ', context%enabled
      write(unit_id,'(a,a)') 'model: ', trim(context%model)
      write(unit_id,'(a,i0)') 'radiation communicator size: ', rad%nprocs
      write(unit_id,'(a,i0)') 'n_wavenumbers: ', context%n_wavenumbers
      write(unit_id,'(a,i0)') 'ncells: ', mesh%ncells
      write(unit_id,'(a,i0)') 'nfaces: ', mesh%nfaces
      write(unit_id,'(a,a)') 'pressure_source: ', trim(context%pressure_source)
      write(unit_id,'(a,l1)') 'mesh_cached: ', context%mesh_cached
      write(unit_id,'(a,l1)') 'debug_checksums: ', context%debug
      write(unit_id,'(a)') ''
      write(unit_id,'(a)') 'Wavenumber decomposition:'
      ! The per-rank ranges are recomputed here rather than communicated.
      do r = 0, max(0, rad%nprocs-1)
         call local_wavenumber_bounds(context%n_wavenumbers, rad%nprocs, r, first, last, nlocal)
         write(unit_id,'(a,i0,a,i0,a,i0,a,i0)') '  rank ', r, ': ', first, '-', last, ' count=', nlocal
      end do
      write(unit_id,'(a)') ''
      write(unit_id,'(a)') 'Selected radiation species:'
      if (context%n_species == 0) then
         write(unit_id,'(a)') '  <none>'
      else
         do s = 1, context%n_species
            write(unit_id,'(a,a,a,i0)') '  ', trim(context%species_name(s)), &
               ' -> solver species index ', context%species_index(s)
         end do
      end if
      write(unit_id,'(a)') ''
      write(unit_id,'(a)') 'Selected radiation scalars:'
      if (context%n_scalars == 0) then
         write(unit_id,'(a)') '  <none currently gathered; scalar registry pending>'
      else
         do s = 1, context%n_scalars
            write(unit_id,'(a,a)') '  ', trim(context%scalar_name(s))
         end do
      end if
      write(unit_id,'(a)') ''
      write(unit_id,'(a)') 'Developer contract:'
      write(unit_id,'(a)') '  Implement physics kernels under src/radiation/models/mod_radiation_<modelname>.f90.'
      write(unit_id,'(a)') '  Models receive full mesh/state and assigned wavenumber range; do not call MPI in model files.'
      write(unit_id,'(a,es16.8)') 'background_press: ', params%background_press
      close(unit_id)
      context%setup_written = .true.
   end subroutine write_radiation_setup


   !> Append one CSV row of radiation diagnostics.
   !!
   !! Only radiation rank 0 writes, and only when diagnostics output is
   !! requested.  The first call replaces the file at
   !! `context%diagnostics_file` and writes the header; later calls append.
   !! Each row holds the step data, min/max/mean of temperature, pressure and
   !! `source%qrad`, the volume-weighted sum of `source%qrad`, the stage
   !! timings, the gather checksums and min/max/mean of each selected species.
   !! The means are plain arithmetic means over cells, not volume-weighted.
   !!
   !! Calls `fatal_error` if the file cannot be opened.
   subroutine write_radiation_diagnostics(mesh, rad, context, state, source, step, time, dt, &
                                          gather_time, compute_time, reduce_time, debug_T_diff, debug_Y_diff)
      type(mesh_t), intent(in) :: mesh
      type(radiation_mpi_t), intent(in) :: rad
      type(radiation_context_t), intent(inout) :: context
      type(radiation_state_t), intent(in) :: state
      type(radiation_source_t), intent(in) :: source
      integer, intent(in) :: step
      real(rk), intent(in) :: time, dt, gather_time, compute_time, reduce_time, debug_T_diff, debug_Y_diff
      integer :: unit_id, s, c, ios
      character(len=256) :: io_msg
      real(rk) :: cell_count, T_avg, P_avg, q_avg, q_sum
      real(rk) :: ymin, ymax, yavg

      if (.not. context%write_diagnostics) return
      if (context%rad_rank /= 0) return

      io_msg = ''
      if (.not. context%diagnostics_initialized) then
         open(newunit=unit_id, file=trim(context%diagnostics_file), status='replace', action='write', &
              iostat=ios, iomsg=io_msg)
         if (ios /= 0) then
            call fatal_error('radiation', 'cannot create '//trim(context%diagnostics_file)//': '//trim(io_msg))
            return
         end if
         ! Header: fixed columns first, then three columns per selected species.
         write(unit_id,'(a)', advance='no') &
            'step,time,dt,rad_size,n_wavenumbers,wn_first,wn_last,ncells,' // &
            'T_min,T_max,T_avg,P_min,P_max,P_avg,' // &
            'qrad_min,qrad_max,qrad_avg,qrad_sum,' // &
            'gather_time,compute_time,reduce_time,' // &
            'debug_T_sum_diff,debug_Y_sum_maxdiff'
         do s = 1, context%n_species
            write(unit_id,'(a,a,a)', advance='no') ',Y_', trim(context%species_name(s)), '_min'
            write(unit_id,'(a,a,a)', advance='no') ',Y_', trim(context%species_name(s)), '_max'
            write(unit_id,'(a,a,a)', advance='no') ',Y_', trim(context%species_name(s)), '_avg'
         end do
         write(unit_id,*)
         context%diagnostics_initialized = .true.
      else
         open(newunit=unit_id, file=trim(context%diagnostics_file), status='old', position='append', &
              action='write', iostat=ios, iomsg=io_msg)
         if (ios /= 0) then
            call fatal_error('radiation', 'cannot append to '//trim(context%diagnostics_file)//': '//trim(io_msg))
            return
         end if
      end if

      ! Divisor for the arithmetic means; at least 1 so an empty mesh cannot divide by zero.
      cell_count = real(max(1, mesh%ncells), rk)
      T_avg = sum(state%temperature) / cell_count
      P_avg = sum(state%pressure) / cell_count
      q_avg = sum(source%qrad) / cell_count
      ! Volume-weighted sum of the source over all cells.
      q_sum = zero
      do c = 1, mesh%ncells
         q_sum = q_sum + source%qrad(c) * mesh%cells(c)%volume
      end do

      ! One integer, two reals, five integers, fifteen reals, comma separated.
      write(unit_id,'(i0,2(",",es16.8),5(",",i0),15(",",es16.8))', advance='no') &
         step, time, dt, &
         rad%nprocs, context%n_wavenumbers, context%wn_first, context%wn_last, mesh%ncells, &
         minval(state%temperature), maxval(state%temperature), T_avg, &
         minval(state%pressure), maxval(state%pressure), P_avg, &
         minval(source%qrad), maxval(source%qrad), q_avg, q_sum, &
         gather_time, compute_time, reduce_time, debug_T_diff, debug_Y_diff

      if (allocated(state%Y)) then
         do s = 1, context%n_species
            ymin = minval(state%Y(s, :))
            ymax = maxval(state%Y(s, :))
            yavg = sum(state%Y(s, :)) / cell_count
            write(unit_id,'(3(",",es16.8))', advance='no') ymin, ymax, yavg
         end do
      end if
      ! Terminate the record started by the non-advancing writes.
      write(unit_id,*)
      close(unit_id)
   end subroutine write_radiation_diagnostics


   !> Compute the contiguous 1-based wavenumber range assigned to one rank.
   !!
   !! The `n_wavenumbers` entries are split into `nprocs` contiguous chunks in
   !! rank order; the first `mod(n_wavenumbers, nprocs)` ranks receive one
   !! extra entry.  A rank with nothing assigned gets the empty range
   !! `first = 0`, `last = -1`, `nlocal = 0`.  The same empty range is returned
   !! for invalid input: non-positive `nprocs` or `n_wavenumbers`, or `rank`
   !! outside `0 .. nprocs-1`.
   !!
   !! @param n_wavenumbers  Total number of wavenumbers to distribute.
   !! @param nprocs         Number of ranks sharing them.
   !! @param rank           Zero-based rank being queried.
   !! @param first          First owned wavenumber (0 when empty).
   !! @param last           Last owned wavenumber (-1 when empty).
   !! @param nlocal         Number of owned wavenumbers (never negative).
   pure subroutine local_wavenumber_bounds(n_wavenumbers, nprocs, rank, first, last, nlocal)
      integer, intent(in) :: n_wavenumbers, nprocs, rank
      integer, intent(out) :: first, last, nlocal
      integer :: base, rem

      ! Start from the empty range; every early exit below returns it.
      first = 0
      last = -1
      nlocal = 0
      if (nprocs <= 0 .or. n_wavenumbers <= 0) return
      if (rank < 0 .or. rank >= nprocs) return

      base = n_wavenumbers / nprocs
      rem = mod(n_wavenumbers, nprocs)
      nlocal = base + merge(1, 0, rank < rem)
      ! More ranks than wavenumbers: the trailing ranks own nothing.
      if (nlocal <= 0) return

      ! Every lower rank owns `base` entries, and the first `rem` own one more.
      first = rank * base + min(rank, rem) + 1
      last = first + nlocal - 1
   end subroutine local_wavenumber_bounds

end module mod_radiation
