conjugate_gradient_solver Subroutine

public subroutine conjugate_gradient_solver(n, rhs, x, eel, matvec, precnd, arg_tol, arg_n_iter)

Uses

  • Uses graph: conjugate_gradient_solver uses mod_constants and mod_memory; mod_constants uses iso_c_binding; mod_memory uses mod_constants, iso_c_binding and mod_io; mod_io uses mod_constants.

Conjugate gradient solver (TODO)

matvec: routine that performs the matrix-vector product. precnd: preconditioner routine.

Arguments

Type | Intent | Optional | Attributes | Name
integer(kind=ip), intent(in) :: n

Size of the matrix

real(kind=rp), intent(in), dimension(n) :: rhs

Right hand side of the linear system

real(kind=rp), intent(inout), dimension(n) :: x

In input, initial guess for the solver, in output the solution

type(ommp_electrostatics_type), intent(in) :: eel

Electrostatics data structure

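external :: matvec

Routine that performs the matrix-vector product (declared `external` in the source; no data type applies).

external :: precnd

Preconditioner routine (declared `external` in the source; no data type applies).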
real(kind=rp), intent(in), optional :: arg_tol

Optional convergence criterion (input); if not present, OMMP_DEFAULT_SOLVER_TOL is used.

integer(kind=ip), intent(in), optional :: arg_n_iter

Optional maximum number of iterations for the solver (input); if not present, OMMP_DEFAULT_SOLVER_ITER is used.


Calls

Calls graph: conjugate_gradient_solver calls ommp_message, mallocate, mfree and fatal_error. The generic interface mallocate resolves to r_alloc1, r_alloc2, r_alloc3, i_alloc1, i_alloc2, i_alloc3, l_alloc1 and l_alloc2, each of which calls chk_alloc and memory_init. The generic interface mfree resolves to r_free1, r_free2, r_free3, i_free1, i_free2, i_free3, l_free1 and l_free2, each of which calls chk_free. Both chk_alloc and chk_free call fatal_error; fatal_error calls ommp_message and close_output; close_output calls ommp_message.

Called by

Called-by graph: conjugate_gradient_solver is called by polarization. polarization is called by polelec_geomgrad, ommp_get_polelec_energy and ommp_set_external_field. polelec_geomgrad is called by ommp_polelec_geomgrad and ommp_full_geomgrad; ommp_get_polelec_energy is called by C_ommp_get_polelec_energy and ommp_get_full_ele_energy; ommp_set_external_field is called by ommp_set_external_field_nomm, C_ommp_set_external_field_nomm and C_ommp_set_external_field. Further up the chain: ommp_polelec_geomgrad is called by C_ommp_polelec_geomgrad; ommp_full_geomgrad by C_ommp_full_geomgrad; ommp_get_full_ele_energy by C_ommp_get_full_ele_energy and ommp_get_full_energy; ommp_get_full_energy by C_ommp_get_full_energy.

Contents


Source Code

    subroutine conjugate_gradient_solver(n, rhs, x, eel, matvec, precnd, &
                                         arg_tol, arg_n_iter)
        !! Preconditioned conjugate gradient solver for the linear system
        !! A x = rhs, where the product A v is computed by the external
        !! routine `matvec` and the preconditioner is applied by the external
        !! routine `precnd`. (As for any CG method, A is presumably symmetric
        !! positive definite -- TODO confirm against the callers.)
        !!
        !! Convergence is declared when sqrt(|r . z| / n) < tol, where r is
        !! the current residual and z = precnd(r).
        !! If the input guess has (numerically) zero norm, the preconditioner
        !! applied to `rhs` is used as the starting point.
        !! A fatal error is raised when the iteration limit is exhausted
        !! without convergence. An early exit because the search direction
        !! satisfies p . A p ~ 0 is reported but not treated as an error.
        ! TODO add more printing

        use mod_constants, only: eps_rp
        use mod_memory, only: mallocate, mfree

        implicit none

        integer(ip), intent(in) :: n
        !! Size of the matrix
        real(rp), intent(in), optional :: arg_tol
        !! Optional convergence criterion in input, if not present
        !! OMMP_DEFAULT_SOLVER_TOL is used.
        integer(ip), intent(in), optional :: arg_n_iter
        !! Optional maximum number of iterations for the solver, if not present
        !! OMMP_DEFAULT_SOLVER_ITER is used.
        real(rp), dimension(n), intent(in) :: rhs
        !! Right hand side of the linear system
        real(rp), dimension(n), intent(inout) :: x
        !! In input, initial guess for the solver, in output the solution
        type(ommp_electrostatics_type), intent(in) :: eel
        !! Electrostatics data structure
        external :: matvec
        !! Routine to perform matrix-vector product, called here as
        !! matvec(eel, v, Av, .true.); the meaning of the logical flag is
        !! defined by the supplied routine (TODO document).
        external :: precnd
        !! Preconditioner routine, called here as precnd(eel, v, Mv).

        real(rp) :: tol
        !! Convergence criterion, it is required that RMS norm < tol
        integer(ip) :: n_iter
        !! Maximum number of iterations for the solver

        integer(ip) :: it
        real(rp) :: rms_norm, alpha, beta, gnew, gold, pap
        logical :: converged, breakdown
        real(rp), allocatable :: r(:), p(:), h(:), z(:)
        character(len=OMMP_STR_CHAR_MAX) :: msg

        ! Optional arguments handling
        if(present(arg_tol)) then
            tol = arg_tol
        else
            tol = OMMP_DEFAULT_SOLVER_TOL
        end if

        if(present(arg_n_iter)) then
            n_iter = arg_n_iter
        else
            n_iter = OMMP_DEFAULT_SOLVER_ITER
        end if

        call ommp_message("Solving linear system with CG solver", OMMP_VERBOSE_LOW)
        ! I0 (minimal width) avoids the '****' overflow a fixed-width integer
        ! descriptor produces for large iteration limits.
        write(msg, "(A, 1X, I0)") "Max iter:", n_iter
        call ommp_message(msg, OMMP_VERBOSE_LOW)
        write(msg, "(A, E8.1)") "Tolerance: ", tol
        call ommp_message(msg, OMMP_VERBOSE_LOW)

        call mallocate('conjugate_gradient_solver [r]', n, r)
        call mallocate('conjugate_gradient_solver [p]', n, p)
        call mallocate('conjugate_gradient_solver [h]', n, h)
        call mallocate('conjugate_gradient_solver [z]', n, z)

        ! Compute a guess, if required; the squared norm is sufficient to
        ! detect a (numerically) zero input vector.
        if(dot_product(x, x) < eps_rp) then
            call ommp_message("Input guess has zero norm, generating a guess&
                              & from preconditioner.", OMMP_VERBOSE_HIGH)
            ! Precondition the right-hand side: preconditioning the zero
            ! input guess would just return zero, and passing x as both the
            ! input and the output argument would alias two dummy arguments.
            call precnd(eel, rhs, x)
        else
            call ommp_message("Using input guess as a starting point for&
                              & iterative solver.", OMMP_VERBOSE_HIGH)
        end if

        ! Initial residual r = rhs - A x, preconditioned residual z and first
        ! search direction p = z.
        call matvec(eel, x, z, .true.)
        r = rhs - z
        call precnd(eel, r, z)
        p = z
        gold = dot_product(r, z)

        ! abs() guards against a tiny negative r.z caused by roundoff, which
        ! would otherwise turn the norm into NaN and hide non-convergence.
        rms_norm = sqrt(abs(gold) / real(n, rp))
        converged = rms_norm < tol
        breakdown = .false.

        if(converged) then
            call ommp_message("Input guess already satisfies the required &
                              &convergence threshold.", OMMP_VERBOSE_HIGH)
        else
            do it = 1, n_iter
                ! Step length along the current search direction.
                call matvec(eel, p, h, .true.)
                pap = dot_product(h, p)

                ! Unlikely quick return: p . A p vanishes, no step is possible.
                if(abs(pap) < eps_rp) then
                    call ommp_message("Direction vector with zero norm, exiting &
                                      &iterative solver.", OMMP_VERBOSE_HIGH)
                    breakdown = .true.
                    exit
                end if

                alpha = gold / pap
                x = x + alpha * p
                r = r - alpha * h

                ! Apply the preconditioner and measure the new residual.
                call precnd(eel, r, z)
                gnew = dot_product(r, z)
                rms_norm = sqrt(abs(gnew) / real(n, rp))

                write(msg, "('iter=',i0,' residual rms norm: ', d14.4)") it, rms_norm
                call ommp_message(msg, OMMP_VERBOSE_HIGH)

                ! Check convergence
                if(rms_norm < tol) then
                    call ommp_message("Required convergence threshold reached, &
                                      &exiting iterative solver.", OMMP_VERBOSE_HIGH)
                    converged = .true.
                    exit
                end if

                ! Next search direction; beta is kept separate from p . A p so
                ! the final convergence diagnostic is not affected by it.
                beta = gnew / gold
                p = beta * p + z
                gold = gnew
            end do
        end if

        call mfree('conjugate_gradient_solver [r]', r)
        call mfree('conjugate_gradient_solver [p]', p)
        call mfree('conjugate_gradient_solver [h]', h)
        call mfree('conjugate_gradient_solver [z]', z)

        ! Fail only when the iteration limit was exhausted: neither converged
        ! nor stopped on a vanishing search direction (this also catches a
        ! NaN residual, which never compares below tol).
        if(.not. (converged .or. breakdown)) then
            call fatal_error("Iterative solver did not converge")
        end if

    end subroutine conjugate_gradient_solver