// TruncatedNormalDistribution.java

/*
 * Copyright 2015 University of Glasgow.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package broadwick.statistics.distributions;

import broadwick.rng.RNG;
import lombok.Getter;

/**
 * An implementation of the truncated normal distribution, i.e. a normal distribution whose values are bounded
 * below, above, or both.
 * <p>
 * References:</p>
 * <ul>
 * <li><a href="http://mathworld.wolfram.com/NormalDistribution.html">Normal Distribution</a></li>
 * </ul>
 */
public class TruncatedNormalDistribution implements ContinuousDistribution {

    /**
     * Create an instance of the normal distribution truncated to the given limits.
     * @param mean the mean of the underlying (untruncated) normal distribution; must be finite.
     * @param sd   the standard deviation of the underlying normal distribution; must be finite and strictly
     *             positive.
     * @param lb   the lower bound of the distribution, no values lower than this will be returned. May be
     *             Double.NEGATIVE_INFINITY for a distribution that is only bounded above.
     * @param ub   the upper bound of the distribution, no values higher than this will be returned. May be
     *             Double.POSITIVE_INFINITY for a distribution that is only bounded below.
     * @throws IllegalArgumentException if any argument is NaN, if the mean is infinite, if the standard deviation
     *                                  is infinite or not strictly positive, if the lower bound is greater than
     *                                  the upper bound, or if the bounds describe an empty interval at infinity.
     */
    public TruncatedNormalDistribution(final double mean, final double sd,
                                       final double lb, final double ub) {

        // NaN never compares equal to anything (not even itself), so "x == Double.NaN" is always false;
        // Double.isNaN is the only reliable test.
        if (Double.isNaN(mean) || Double.isNaN(sd)
            || Double.isNaN(lb) || Double.isNaN(ub)) {
            throw new IllegalArgumentException("Invalid argument: TruncatedNormalDistribution cannot take NaN as argument");
        }

        // The sampler standardises the bounds by dividing by sd, so sd must be a usable scale.
        if (Double.isInfinite(mean) || Double.isInfinite(sd) || sd <= 0.0) {
            throw new IllegalArgumentException("Invalid argument: mean must be finite and standard deviation must be finite and strictly positive");
        }

        if (lb > ub) {
            throw new IllegalArgumentException("Invalid argument: lower bound greater than upper bound");
        }

        // An interval such as [+inf, +inf] contains no finite value, so no sample could ever be accepted.
        if (lb == Double.POSITIVE_INFINITY || ub == Double.NEGATIVE_INFINITY) {
            throw new IllegalArgumentException("Invalid argument: bounds do not contain any finite value");
        }

        this.mean = mean;
        this.sd = sd;
        this.lower = lb;
        this.upper = ub;
        this.generator = new RNG(RNG.Generator.Well19937c);
    }

    /**
     * Draw a sample from the truncated normal distribution.
     * <p>
     * For a finite interval this uses the accept-reject method of C.P. Robert (doi: 10.1007/BF00143942,
     * arXiv:0907.4010 [stat.CO]) with a uniform proposal on the standardised interval. When either bound is
     * infinite a uniform proposal is not possible and {@link #rejectionSample()} is used instead.</p>
     * @return a value in the closed interval [lower bound, upper bound].
     */
    @Override
    public double sample() {

        // A uniform proposal needs a finite interval; one-sided truncation falls back to plain rejection.
        if (Double.isInfinite(lower) || Double.isInfinite(upper)) {
            return rejectionSample();
        }

        // Work on the standard normal scale: z = (x - mean) / sd restricted to [a, b].
        final double a = (lower - mean) / sd;
        final double b = (upper - mean) / sd;

        // The acceptance probability is phi(z) / max_{[a,b]} phi, where the maximum of the standard normal
        // density on [a, b] is at 0 if the interval contains 0, otherwise at the bound closest to 0.
        // The three cases are exhaustive, so a bound lying exactly on 0 is handled as well.
        final double modeSquared;
        if (b < 0.0) {
            modeSquared = b * b;
        } else if (a > 0.0) {
            modeSquared = a * a;
        } else {
            modeSquared = 0.0;
        }

        double z;
        double rho;
        do {
            z = generator.getDouble(a, b);
            rho = Math.exp((modeSquared - (z * z)) / 2.0);
        } while (generator.getDouble() > rho);

        // Map back to the original scale; clamp to guard against rounding pushing the value past a bound.
        final double x = mean + sd * z;
        return Math.min(upper, Math.max(lower, x));
    }

    /**
     * Draw a sample using simple rejection sampling: values are drawn from the untruncated normal distribution
     * with this distribution's mean and standard deviation until one falls inside the bounds.
     * <p>
     * The number of draws needed is unbounded and grows as the probability mass of the untruncated normal
     * inside the bounds shrinks.</p>
     * @return a value in the closed interval [lower bound, upper bound].
     */
    public double rejectionSample() {
        // use a simple rejection sampling
        final Normal dist = new Normal(mean, sd);
        double val = dist.sample();

        while (val < lower || val > upper) {
            val = dist.sample();
        }
        return val;
    }

    /** The mean of the underlying (untruncated) normal distribution. */
    @Getter
    private final double mean;
    /** The standard deviation of the underlying (untruncated) normal distribution. */
    @Getter
    private final double sd;
    /** The lower truncation bound. */
    private final double lower;
    /** The upper truncation bound. */
    private final double upper;
    /** The random number generator used by {@link #sample()}. */
    private final RNG generator;
}