rejection_two_classes.f90 Source File


This file depends on

Dependency graph (an arrow means "depends on"):

- rejection_two_classes.f90 -> base.f90, kinds.f90, lists.f90, rejection.f90
- rejection.f90 -> base.f90, kinds.f90, lists.f90
- base.f90 -> kinds.f90
- lists.f90 -> kinds.f90, dynamical_list.f90, fixed_list.f90, maxheap.f90
- dynamical_list.f90 -> kinds.f90
- fixed_list.f90 -> kinds.f90
- maxheap.f90 -> kinds.f90

Files dependent on this one

Reverse dependency graph (an arrow means "depends on"):

- samplers.f90 -> rejection_two_classes.f90
- datastructs_mod.f90 -> samplers.f90

Source Code

!> This module implements rejection sampling with two classes
!> We split the weights into two classes based on a threshold
module datastructs_samplers_rejection_two_classes_mod
    use datastructs_kinds_mod
    use datastructs_lists_mod
    use datastructs_samplers_rejection_mod, only : rejection_t => weighted_sampler_t
    use datastructs_samplers_base_mod
    implicit none
    private

    ! Cut-off below which a weight counts as "no weight": set_weight removes the
    ! index instead of storing it when weight <= EPSILON.
    ! Note: Fortran names are case-insensitive, so inside this module this constant
    ! hides the intrinsic function epsilon().
    real(kind=dp), parameter :: EPSILON = 0.0_dp

    !> Constructor: the generic name weighted_sampler resolves to
    !> weighted_sampler_new(n), which returns a sampler initialized for n weights
    interface weighted_sampler
        module procedure weighted_sampler_new
    end interface weighted_sampler

    !> Two-class rejection sampler, extending the sampler base type
    !> Every stored index lives in one of two internal rejection samplers;
    !> set_weight picks the sampler by comparing the weight with `threshold`
    type, extends(sampler_base_t) :: weighted_sampler_t
        type(rejection_t) :: samplers(2) ! (1): weights below the threshold, (2): all other weights
        ! Default threshold: the largest representable real(dp), so any smaller weight is
        ! routed to samplers(1) until a threshold is set explicitly (see init_w).
        ! huge() must be given a real(dp) value: huge(dp) would be evaluated on the
        ! integer kind constant itself and yield the largest integer instead.
        real(dp) :: threshold = huge(1.0_dp)
        integer(i4), allocatable :: sampler_of_index(:) ! maps index to sampler (1 or 2); 0 means not stored
    contains
        procedure :: init_n => sampler_init
        procedure :: init_w => sampler_init_w
        procedure :: init_w2 => sampler_init_w2
        procedure :: reset        => sampler_reset
        procedure :: set_weight   => sampler_set_weight
        procedure :: set_weight_array => sampler_set_weight_array
        procedure :: add_weight => sampler_add_weight
        procedure :: sample       => sampler_sample
        procedure :: remove => sampler_remove
        procedure :: sum => sampler_sum

        final :: sampler_finalize
    end type weighted_sampler_t

    public :: weighted_sampler, weighted_sampler_t

contains

    !> Constructor: builds a two-class rejection sampler for N weights
    !> Input: n - number of weights
    !> Output: a sampler initialized through the generic init with size n
    function weighted_sampler_new(n) result(sampler)
        integer(i4), intent(in) :: n
        type(weighted_sampler_t) :: sampler

        ! All the set-up work is delegated to the type-bound initializer
        call sampler%init(n)
    end function weighted_sampler_new

    !> Initializes the structure with N weights
    !> Both internal samplers are sized for n entries, every index starts
    !> unassigned (sampler 0) and every weight starts at zero.
    !> Calling it again on an already initialized object re-creates the
    !> bookkeeping arrays of this type with the new size.
    !> Input: n - number of weights
    subroutine sampler_init(this, n)
        class(weighted_sampler_t), intent(inout) :: this
        integer(i4), intent(in) :: n

        this%n = n ! Set the number of weights

        ! init each sampler with size n (any index may end up in either class)
        call this%samplers(1)%init(n)
        call this%samplers(2)%init(n)

        ! Allocating an already allocated array is a runtime error, so release
        ! whatever a previous initialization left behind before allocating again.
        ! NOTE(review): whether the internal samplers tolerate re-initialization
        ! is decided in their own init, which is not visible here.
        if (allocated(this%sampler_of_index)) deallocate(this%sampler_of_index)
        if (allocated(this%weights)) deallocate(this%weights)

        allocate(this%sampler_of_index(n))
        allocate(this%weights(n))

        this%sampler_of_index = 0 ! no sampler assigned yet
        this%weights = 0.0_dp ! all weights start at zero
    end subroutine sampler_init

    !> Initializes with size n and threshold w
    !> The threshold is the value set_weight compares each weight against:
    !> weights strictly below it go to the first internal sampler, all others to the second
    !> Input: n - number of weights
    !>        w - threshold for the weights
    subroutine sampler_init_w(this, n, w)
        class(weighted_sampler_t), intent(inout) :: this
        integer(i4), intent(in) :: n
        real(dp), intent(in) :: w

        call this%init(n)  ! size-only initialization through the generic init
        this%threshold = w ! set the threshold for the weights
    end subroutine sampler_init_w

    !> Initialization taking two real arguments (compat mode)
    !> Forwards n and w1 to the generic init; w2 is never referenced
    !> NOTE(review): w1 is NOT ignored here - presumably the call resolves to
    !> init_w, which would store w1 as the threshold. Confirm that this is intended.
    !> Input: n - number of weights
    !>        w1 - real number passed on to init together with n
    !>        w2 - any real number, not used
    subroutine sampler_init_w2(this, n, w1,w2)
        class(weighted_sampler_t), intent(inout) :: this
        integer(i4), intent(in) :: n
        real(dp), intent(in) :: w1,w2

        call this%init(n,w1)  ! generic init with (n, w1); w2 is dropped
    end subroutine sampler_init_w2

    !> Resets the sampler to its empty state
    !> Both internal samplers are reset, all weights become zero and no index
    !> is mapped to a sampler any more. Sizes and threshold are left untouched.
    subroutine sampler_reset(this)
        class(weighted_sampler_t), intent(inout) :: this
        integer(i4) :: k

        ! Empty the two internal samplers
        do k = 1, 2
            call this%samplers(k)%reset()
        end do

        ! Clear the bookkeeping kept by this type
        this%weights = 0.0_dp
        this%sampler_of_index = 0
    end subroutine sampler_reset

    !> Sets the weight for a given index
    !> A weight that is not larger than EPSILON removes the index instead.
    !> Otherwise the index is stored in the first internal sampler when the
    !> weight is strictly below the threshold, and in the second one when not;
    !> if that changes its class, it is removed from the sampler it was in before.
    !> Input: index - index of the element with a given weight
    !>        weight - weight of the element
    subroutine sampler_set_weight(this, index, weight)
        class(weighted_sampler_t), intent(inout) :: this
        integer(i4), intent(in) :: index
        real(dp), intent(in) :: weight
        integer(i4) :: previous_class, target_class, other_class

        ! Weights at or below the cut-off are not stored: drop the index and stop
        if (weight <= EPSILON) then
            call this%remove(index)
            return
        end if

        ! Class currently holding the index (0 when it is not stored)
        previous_class = this%sampler_of_index(index)

        ! Already stored with exactly this weight: nothing to update
        if (previous_class /= 0) then
            if (weight == this%weights(index)) return
        end if

        ! Class 1 holds weights below the threshold, class 2 everything else
        if (weight < this%threshold) then
            target_class = 1
        else
            target_class = 2
        end if
        other_class = 3 - target_class

        ! Store in the target class first, then clean up the other one if needed
        call this%samplers(target_class)%set_weight(index, weight)
        this%sampler_of_index(index) = target_class
        if (previous_class == other_class) call this%samplers(other_class)%remove(index)

        ! Keep the local copy of the weight in sync
        this%weights(index) = weight
    end subroutine sampler_set_weight

    !> Sets all weights from a full array, position i giving the weight of index i
    !> Each entry goes through set_weight, so entries not larger than EPSILON
    !> remove the corresponding index.
    !> It assumes the sampler was initialized before, with the same size as the array
    !> Input: weights - array with the weights
    subroutine sampler_set_weight_array(this, weights)
        class(weighted_sampler_t), intent(inout) :: this
        real(dp), intent(in) :: weights(:)
        integer(i4) :: idx, n_weights

        n_weights = int(size(weights), i4)

        ! One set_weight call per entry; it decides which internal sampler is used
        do idx = 1, n_weights
            call this%set_weight(idx, weights(idx))
        end do
    end subroutine sampler_set_weight_array

    !> Changes the weight stored at a given index by a given amount
    !> The new weight (current weight plus delta) is applied through set_weight,
    !> so the index may change class or be removed as a result.
    !> Input: index - index of the element
    !>        delta_weight - difference to add to its weight
    subroutine sampler_add_weight(this, index, delta_weight)
        class(weighted_sampler_t), intent(inout) :: this
        integer(i4), intent(in) :: index
        real(dp), intent(in) :: delta_weight

        ! The sum is evaluated before the call, from the weight currently stored
        call this%set_weight(index, this%weights(index) + delta_weight)
    end subroutine sampler_add_weight

    !> Remove an index from the sampler
    !> Important: the index is the original one, not the index used internally by the sampler
    !> An index that is not stored in any internal sampler is silently ignored.
    !> Input: index - index of the element to be removed
    subroutine sampler_remove(this, index)
        class(weighted_sampler_t), intent(inout) :: this
        integer(i4), intent(in) :: index
        integer(i4) :: owner

        ! Internal sampler currently holding the index (0 = none)
        owner = this%sampler_of_index(index)

        if (owner /= 0) then
            call this%samplers(owner)%remove(index) ! take it out of whichever sampler holds it
            this%sampler_of_index(index) = 0        ! unmap the index
            this%weights(index) = 0.0_dp            ! its weight is now zero
        end if
    end subroutine sampler_remove

    !> Samples an index from the sampler
    !> One random number selects the class: the first internal sampler is chosen
    !> with probability (sum of its weights) / (total weight), the second otherwise.
    !> The index itself is then drawn by the chosen internal sampler.
    !> NOTE(review): with a total weight of zero the ratio is undefined; callers
    !> presumably never sample an empty structure - confirm.
    !> Input: gen - random number generator (rndgen-fortran module)
    !> Output: index - sampled index
    function sampler_sample(this, gen) result(index)
        use rndgen_mod
        class(weighted_sampler_t), intent(in) :: this
        class(rndgen), intent(inout) :: gen
        integer(i4) :: index
        real(dp) :: total_weight, share_of_first
        integer(i4) :: chosen

        total_weight = this%sum()

        ! Fraction of the total weight held by the first class
        share_of_first = this%samplers(1)%sum() / total_weight

        ! Pick the class with a single random number
        if (gen%rnd() < share_of_first) then
            chosen = 1
        else
            chosen = 2
        end if

        ! Let the chosen internal sampler draw the index
        index = this%samplers(chosen)%sample(gen)
    end function sampler_sample

    !> Total weight held by the sampler
    !> Output: total_weight - sum reported by the first internal sampler plus
    !>         the sum reported by the second one
    function sampler_sum(this) result(total_weight)
        class(weighted_sampler_t), intent(in) :: this
        real(dp) :: total_weight
        real(dp) :: first_total, second_total

        first_total = this%samplers(1)%sum()
        second_total = this%samplers(2)%sum()

        total_weight = first_total + second_total
    end function sampler_sum

    !> Finalizer: releases the arrays owned by the sampler
    !> Each array is deallocated only if it is currently allocated, so
    !> finalizing an object that was never initialized is harmless.
    subroutine sampler_finalize(this)
        type(weighted_sampler_t), intent(inout) :: this

        if (allocated(this%weights)) deallocate(this%weights)
        if (allocated(this%sampler_of_index)) deallocate(this%sampler_of_index)
    end subroutine sampler_finalize

end module datastructs_samplers_rejection_two_classes_mod