rejection_maxheap_composition.f90 Source File


This file depends on

Dependency graph (A -> B means "A uses B"): rejection_maxheap_composition.f90 -> base.f90, btree.f90, kinds.f90, lists.f90, rejection_maxheap.f90; base.f90 -> kinds.f90; btree.f90 -> base.f90, kinds.f90; lists.f90 -> kinds.f90, dynamical_list.f90, fixed_list.f90, maxheap.f90; rejection_maxheap.f90 -> base.f90, kinds.f90, lists.f90; dynamical_list.f90 -> kinds.f90; fixed_list.f90 -> kinds.f90; maxheap.f90 -> kinds.f90.

Files dependent on this one

Dependent files graph (A -> B means "A uses B"): samplers.f90 -> rejection_maxheap_composition.f90; datastructs_mod.f90 -> samplers.f90.

Source Code

!> This module implements a rejection composition sampler.
!> It is based on St-Onge, G., Young, J. G., Hébert-Dufresne, L., & Dubé, L. J. (2019),
!> "Efficient sampling of spreading processes on complex networks using a composition and rejection algorithm",
!> Computer Physics Communications, 240, 30-37.
!> <https://doi.org/10.1016/j.cpc.2019.02.008>
module datastructs_samplers_rejection_maxheap_composition_mod
    use datastructs_kinds_mod
    use datastructs_lists_mod
    use datastructs_samplers_rejection_maxheap_mod, only : rejection_maxheap_t => weighted_sampler_t
    use datastructs_samplers_btree_mod, only : btree_t => weighted_sampler_t
    use datastructs_samplers_base_mod
    implicit none
    private

    !> Threshold below which a weight counts as "no weight": sampler_set_weight
    !> removes the index instead of storing any weight <= EPSILON.
    !> Note: inside this module the name hides the intrinsic function `epsilon`.
    real(kind=dp), parameter :: EPSILON = 0.0_dp

    !> Constructor: the generic name `weighted_sampler` resolves to
    !> weighted_sampler_new, which takes the number of weights `n`
    interface weighted_sampler
        module procedure weighted_sampler_new
    end interface weighted_sampler

    !> Composition-rejection sampler type, extending the sampler base type.
    !> Indices are distributed over `q` groups according to their weight; each
    !> group has its own rejection sampler, and a btree holds one total per group
    !> so that a group can be selected before an index is drawn inside it.
    type, extends(sampler_base_t) :: weighted_sampler_t
        ! one rejection sampler per group (allocated with size q in sampler_init)
        type(rejection_maxheap_t), allocatable :: samplers(:)
        integer(i4) :: q = 0 ! number of groups; sampler_init stops while it is still 0
        type(btree_t) :: btree ! btree for the samplers selection (size q)
        integer(i4), allocatable :: sampler_of_index(:) ! maps index to sampler (0 = not in any sampler)
        ! weight bounds stored by init_w2; wmin anchors the group boundaries in sampler_pos
        real(dp) :: wmin, wmax
    contains
        ! initializers (presumably the specifics behind a generic `init` of the base type - confirm)
        procedure :: init_n => sampler_init
        procedure :: init_w => sampler_init_w
        procedure :: init_w2 => sampler_init_w2
        ! state management and weight updates
        procedure :: reset        => sampler_reset
        procedure :: set_weight   => sampler_set_weight
        procedure :: set_weight_array => sampler_set_weight_array
        procedure :: add_weight => sampler_add_weight
        ! queries
        procedure :: sample       => sampler_sample
        procedure :: remove => sampler_remove
        procedure :: sum => sampler_sum
        ! helper mapping a weight to its group number
        procedure :: sampler_pos => sampler_sampler_pos

        final :: sampler_finalize
    end type

    public :: weighted_sampler, weighted_sampler_t

contains

    !> Constructor: returns a sampler set up for N weights through the `init` binding
    !> NOTE(review): a freshly created object has q = 0 and sampler_init stops when
    !> q == 0, so this path presumably always aborts - confirm how `init` resolves.
    !> Input: n - number of weights
    function weighted_sampler_new(n) result(this)
        integer(kind=i4), intent(in) :: n
        type(weighted_sampler_t) :: this

        call this%init(n)
    end function weighted_sampler_new

    !> Allocates the internal storage for N weights, all of them starting at zero
    !> The number of groups `q` must already be set (init_w2 does that); the
    !> routine stops with an error otherwise
    !> Input: n - number of weights
    subroutine sampler_init(this, n)
        class(weighted_sampler_t), intent(inout) :: this
        integer(kind=i4), intent(in) :: n
        integer(kind=i4) :: group

        this%n = n

        ! DEBUG
        if (this%q == 0) error stop 'Error: q must be set before initializing the sampler.'

        ! one rejection sampler per group, plus the btree that selects among the groups
        allocate(this%samplers(this%q))
        call this%btree%init(this%q)

        ! any index may end up in any group, so every group sampler is sized for n
        do group = 1, this%q
            call this%samplers(group)%init(n)
        end do

        ! no index belongs to a group yet (0) and every weight is zero
        allocate(this%sampler_of_index(n), source=0_i4)
        allocate(this%weights(n))
        this%weights = 0.0_dp
    end subroutine sampler_init

    !> Compatibility initializer taking a single weight argument
    !> The weight is ignored and the call is forwarded to `init` with n only
    !> Input: n - number of weights
    !>        w - any real number, not used
    subroutine sampler_init_w(this, n, w)
        class(weighted_sampler_t), intent(inout) :: this
        integer(kind=i4), intent(in) :: n
        real(kind=dp), intent(in) :: w

        ! w is deliberately left untouched
        call this%init(n)
    end subroutine sampler_init_w

    !> Initializes the sampler from the smallest and largest weight it will hold
    !> The number of groups is q = max(1, ceiling(log2(w2 / w1)))
    !> (see https://arxiv.org/pdf/1808.05859); the storage is then set up via `init`
    !> Input: n - number of weights
    !>        w1 - minimum weight (must be larger than zero)
    !>        w2 - maximum weight
    subroutine sampler_init_w2(this, n, w1,w2)
        class(weighted_sampler_t), intent(inout) :: this
        integer(kind=i4), intent(in) :: n
        real(kind=dp), intent(in) :: w1, w2 ! omega min and max
        real(kind=dp) :: n_doublings

        ! the logarithm below needs a strictly positive minimum weight
        if (w1 <= 0.0_dp) error stop 'Error: min weight must be larger than zero.'

        this%wmin = w1
        this%wmax = w2

        ! how many times wmin has to be doubled to reach wmax
        n_doublings = log(this%wmax / this%wmin) / log(2.0_dp)
        this%q = max(1, ceiling(n_doublings))

        !write(*, fmt_general) 'Number of groups (q): ', this%q

        call this%init(n)
    end subroutine sampler_init_w2

    !> Puts the sampler back in its empty state: the stored weights, the
    !> index-to-group map, every group sampler and the selection btree are cleared
    subroutine sampler_reset(this)
        class(weighted_sampler_t), intent(inout) :: this
        integer(kind=i4) :: group

        ! forget the per-index bookkeeping
        this%weights = 0.0_dp
        this%sampler_of_index = 0

        ! empty each group sampler, then the btree holding the group totals
        do group = 1, this%q
            call this%samplers(group)%reset()
        end do
        call this%btree%reset()
    end subroutine sampler_reset

    !> Sets the weight for a given index
    !> A weight <= EPSILON removes the index. Otherwise the index is stored in the
    !> group sampler matching the weight and, if it was in a different group before,
    !> removed from that one. The btree totals of every touched group are refreshed.
    !> Input: index - index of the element with a given weight
    !>        weight - weight of the element
    subroutine sampler_set_weight(this, index, weight)
        class(weighted_sampler_t), intent(inout) :: this
        integer(i4), intent(in) :: index
        real(dp), intent(in) :: weight
        integer(i4) :: sampler_pos
        integer(i4) :: old_sampler_pos

        ! if the weight is zero, remove the index from the list
        if (weight <= EPSILON) then
            call this%remove(index)
            return
        end if

        ! group currently holding the index (0 = none)
        old_sampler_pos = this%sampler_of_index(index)

        ! if the weight is the same and is already in the list, just return
        if ((old_sampler_pos /= 0) .and. (weight == this%weights(index))) return

        sampler_pos = this%sampler_pos(weight) ! group matching the new weight

        call this%samplers(sampler_pos)%set_weight(index, weight) ! store the weight in its group sampler
        this%sampler_of_index(index) = sampler_pos ! map index to the corresponding sampler

        ! the index changed group: take it out of the old one and refresh the
        ! btree total of that group as well, otherwise the btree keeps a stale
        ! (too large) weight for it and both sum() and the group selection are off
        if ((old_sampler_pos /= 0) .and. (old_sampler_pos /= sampler_pos)) then
            call this%samplers(old_sampler_pos)%remove(index)
            call this%btree%set_weight(old_sampler_pos, this%samplers(old_sampler_pos)%sum())
        end if

        ! refresh the btree total of the group that now holds the index
        call this%btree%set_weight(sampler_pos, this%samplers(sampler_pos)%sum())

        ! update the weights array
        this%weights(index) = weight
    end subroutine sampler_set_weight

    !> Sets the weights from an array (full), using its indexes
    !> The sampler is reset first, then element i of the array becomes the weight
    !> of index i (weights <= EPSILON leave the index out of the sampler).
    !> The sampler must have been initialized and the array must not be longer
    !> than the number of weights it was initialized with; both are checked.
    !> Input: weights - array with the weights
    subroutine sampler_set_weight_array(this, weights)
        class(weighted_sampler_t), intent(inout) :: this
        real(dp), intent(in) :: weights(:)
        integer(i4) :: i

        ! fail with a clear message instead of touching unallocated storage
        if (.not. allocated(this%sampler_of_index)) then
            error stop 'Error: the sampler must be initialized before setting the weights.'
        end if

        ! a longer array would make set_weight write past the end of the internal arrays
        if (size(weights) > size(this%sampler_of_index)) then
            error stop 'Error: the weights array is larger than the sampler.'
        end if

        call this%reset() ! Reset the sampler before setting weights

        ! loop over all elements and set the weights
        do i = 1, size(weights)
            call this%set_weight(i, weights(i)) ! Set the weight for each index
        end do

    end subroutine sampler_set_weight_array

    !> Changes the weight of an index by a given amount
    !> The new weight is the stored weight plus the difference; it is applied
    !> through set_weight
    !> Input: index - index of the element
    !>        delta_weight - difference to add to its weight
    subroutine sampler_add_weight(this, index, delta_weight)
        class(weighted_sampler_t), intent(inout) :: this
        integer(kind=i4), intent(in) :: index
        real(kind=dp), intent(in) :: delta_weight

        call this%set_weight(index, this%weights(index) + delta_weight)
    end subroutine sampler_add_weight

    !> Remove an index from the sampler
    !> Important: the index is the original one, not the index used internally by the sampler
    !> An index that is not in any group is silently ignored.
    !> Input: index - index of the element to be removed
    subroutine sampler_remove(this, index)
        class(weighted_sampler_t), intent(inout) :: this
        integer(i4), intent(in) :: index
        integer(i4) :: sampler_pos

        sampler_pos = this%sampler_of_index(index) ! group currently holding the index

        if (sampler_pos == 0) return ! It is not mapped to any sampler, just ignore

        call this%samplers(sampler_pos)%remove(index) ! take the index out of its group sampler
        this%sampler_of_index(index) = 0 ! the index no longer belongs to any group

        ! update the weights array
        this%weights(index) = 0.0_dp

        ! refresh the btree total of the group the index was taken from
        call this%btree%set_weight(sampler_pos, this%samplers(sampler_pos)%sum())

    end subroutine sampler_remove

    !> Draws an index in two steps: the btree picks a group, then the rejection
    !> sampler of that group picks the index
    !> Input: gen - random number generator (rndgen-fortran module)
    !> Output: index - sampled index
    function sampler_sample(this, gen) result(index)
        use rndgen_mod
        class(weighted_sampler_t), intent(in) :: this
        class(rndgen), intent(inout) :: gen
        integer(kind=i4) :: index

        ! the group is drawn first; the same generator is then used inside it
        associate (group => this%btree%sample(gen))
            index = this%samplers(group)%sample(gen)
        end associate
    end function sampler_sample

    !> Total weight held by the sampler, as reported by the btree of group totals
    !> Output: total_weight - sum of all weights
    function sampler_sum(this) result(total_weight)
        class(weighted_sampler_t), intent(in) :: this
        real(kind=dp) :: total_weight

        total_weight = this%btree%sum()
    end function sampler_sum

    !> Finalizer: releases the index-to-group map and the weights array when
    !> they are allocated
    subroutine sampler_finalize(this)
        type(weighted_sampler_t), intent(inout) :: this

        if (allocated(this%sampler_of_index)) deallocate(this%sampler_of_index)
        if (allocated(this%weights)) deallocate(this%weights)
    end subroutine sampler_finalize

    !> Auxiliary function (private) to get the sampler position for a given weight
    !> Group k (k < q) holds the weights in [2**(k-1) * wmin, 2**k * wmin) and the
    !> last group holds everything from 2**(q-1) * wmin upwards. With
    !> q = max(1, ceiling(log2(wmax / wmin))) as set by init_w2, every group then
    !> spans at most a factor of two for weights in [wmin, wmax], which is the
    !> grouping of the composition and rejection scheme this module is based on.
    !> Weights below wmin are put in the first group.
    !> Input: weight - the weight to find the sampler position for
    !> Output: pos - group number, between 1 and q
    function sampler_sampler_pos(this, weight) result(pos)
        class(weighted_sampler_t), intent(in) :: this
        real(dp), intent(in) :: weight
        integer(i4) :: pos

        if (weight < this%wmin) then
            pos = 1
        else
            ! floor (not ceiling): with ceiling only weight == wmin would land in
            ! group 1 and the last group could span a factor of four.
            ! min() keeps weight == wmax (or anything larger) in the last group,
            ! max() guards the lower end against rounding.
            pos = floor(log(weight / this%wmin) / log(2.0_dp), kind=i4) + 1_i4
            pos = max(1_i4, min(this%q, pos))
        end if

    end function sampler_sampler_pos

end module datastructs_samplers_rejection_maxheap_composition_mod