TruncatedMultivariateNormalDistribution.java

/*
 * Copyright 2015 University of Glasgow.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package broadwick.statistics.distributions;

import broadwick.math.Matrix;
import broadwick.math.Vector;

import java.util.Objects;

/**
 * Sample from a truncated Multivariate Normal (Gaussian) Distribution, i.e. a multivariate normal distribution
 * whose components are each bounded below, above or both (an axis-aligned box).
 * <p>
 * Samples are generated with the Gibbs sampler of Robert (1995). A chain is started at the mean vector, which the
 * constructor checks lies inside the bounds. Each component is then repeatedly redrawn from its univariate normal
 * full conditional, truncated to that component's bounds. The full conditionals are computed from the precision
 * matrix {@code Q = inverse(covariances)}:
 * </p>
 * <pre>
 *   x_i | x_{-i} ~ N( mu_i - (1/Q_ii) * sum_{j != i} Q_ij * (x_j - mu_j),  1/Q_ii )
 * </pre>
 * <p>
 * Because this is a Markov-chain method, each draw is only approximately distributed according to the target.
 * The approximation improves with the number of Gibbs sweeps performed per draw (see
 * {@link #DEFAULT_GIBBS_SWEEPS} and the five-argument constructor). Every call to {@link #sample()} runs an
 * independent chain, so successive draws do not share state.
 * </p>
 * <p>References:</p>
 * <ul>
 * <li>C.P. Robert, "Simulation of truncated normal variables", Statistics and Computing, pp. 121-125 (1995),
 * arXiv:0907.4010 [stat.CO].</li>
 * <li><a href="http://en.wikipedia.org/wiki/Multivariate_normal_distribution">Multivariate normal
 * distribution</a></li>
 * </ul>
 */
public class TruncatedMultivariateNormalDistribution implements ContinuousMultivariateDistribution {

    /**
     * Number of Gibbs sweeps (one sweep redraws every component once) performed per call to {@link #sample()}
     * when it is not given explicitly.
     */
    public static final int DEFAULT_GIBBS_SWEEPS = 10;

    /**
     * Create an instance of a multivariate distribution that is bounded, using {@link #DEFAULT_GIBBS_SWEEPS}
     * Gibbs sweeps per sample.
     * @param means       the [mathematical] vector of the means of each variable.
     * @param covariances the [mathematical] matrix of the covariances of each variable.
     * @param lb          the [mathematical] vector of the lower bounds of each variable.
     * @param ub          the [mathematical] vector of the upper bounds of each variable.
     * @throws IllegalArgumentException if the dimensions are inconsistent, a lower bound exceeds its upper bound,
     *                                  a mean lies outside its bounds or the covariance matrix is not positive
     *                                  definite.
     */
    public TruncatedMultivariateNormalDistribution(final Vector means, final Matrix covariances,
                                                   final Vector lb, final Vector ub) {
        this(means, covariances, lb, ub, DEFAULT_GIBBS_SWEEPS);
    }

    /**
     * Create an instance of a multivariate distribution that is bounded.
     * @param means       the [mathematical] vector of the means of each variable.
     * @param covariances the [mathematical] matrix of the covariances of each variable.
     * @param lb          the [mathematical] vector of the lower bounds of each variable.
     * @param ub          the [mathematical] vector of the upper bounds of each variable.
     * @param sweeps      the number of Gibbs sweeps performed per call to {@link #sample()}; more sweeps give a
     *                    more accurate (but slower) draw. Must be at least 1.
     * @throws IllegalArgumentException if the dimensions are inconsistent, a lower bound exceeds its upper bound,
     *                                  a mean lies outside its bounds, the covariance matrix is not positive
     *                                  definite or {@code sweeps < 1}.
     */
    public TruncatedMultivariateNormalDistribution(final Vector means, final Matrix covariances,
                                                   final Vector lb, final Vector ub, final int sweeps) {
        Objects.requireNonNull(means, "means");
        Objects.requireNonNull(covariances, "covariances");
        Objects.requireNonNull(lb, "lb");
        Objects.requireNonNull(ub, "ub");

        if (means.length() != lb.length() || lb.length() != ub.length()
            || ub.length() != covariances.rows() || covariances.rows() != covariances.columns()) {
            throw new IllegalArgumentException("The lengths of the input vectors must be equal and the covariances matrix must be square with the same size as the input vectors.");
        }
        if (sweeps < 1) {
            throw new IllegalArgumentException("The number of Gibbs sweeps must be at least 1, got " + sweeps);
        }

        this.n = means.length();
        this.sweeps = sweeps;

        // Take primitive copies so later mutation of the caller's vectors cannot change this distribution.
        this.mu = new double[n];
        this.lowerBounds = new double[n];
        this.upperBounds = new double[n];
        for (int i = 0; i < n; i++) {
            final double lower = lb.element(i);
            final double upper = ub.element(i);
            final double mean = means.element(i);

            // Check the bounds first. If lb > ub, no mean can lie between them, so the mean check would
            // otherwise always fire and hide this more specific error.
            if (lower > upper) {
                throw new IllegalArgumentException("The lower bound of the distribution must be less than the upper bound");
            }
            if (mean < lower || mean > upper) {
                throw new IllegalArgumentException("The means of the distribution must lie between the lower and upper bounds");
            }

            this.mu[i] = mean;
            this.lowerBounds[i] = lower;
            this.upperBounds[i] = upper;
        }

        // Precompute the full conditionals from the precision matrix once. They do not depend on the chain
        // state, and this avoids inverting an (n-1)x(n-1) sub-matrix for every component of every draw.
        final Matrix precision = covariances.inverse();
        this.conditionalSd = new double[n];
        this.conditionalWeights = new double[n][n];
        for (int i = 0; i < n; i++) {
            final double qii = precision.element(i, i);
            // For a positive-definite covariance every diagonal entry of its inverse is strictly positive.
            if (!(qii > 0.0) || Double.isInfinite(qii)) {
                throw new IllegalArgumentException("The covariances matrix must be positive definite");
            }
            this.conditionalSd[i] = Math.sqrt(1.0 / qii);
            for (int j = 0; j < n; j++) {
                // This is the weight of (x_j - mu_j) in E[x_i | x_{-i}]. Component i itself does not contribute.
                this.conditionalWeights[i][j] = (j == i) ? 0.0 : -precision.element(i, j) / qii;
            }
        }
    }

    /**
     * Draw a sample by running a fresh Gibbs chain for the configured number of sweeps.
     * <p>
     * The chain starts at the mean vector, which is feasible. In every sweep each component is redrawn, in turn,
     * from its normal full conditional truncated to {@code [lb_i, ub_i]}. The draw uses the most recent values of
     * all other components.
     * </p>
     * @return a vector whose i-th entry lies within the i-th lower and upper bounds.
     */
    @Override
    public Vector sample() {
        final double[] x = mu.clone();

        for (int sweep = 0; sweep < sweeps; sweep++) {
            for (int i = 0; i < n; i++) {
                // E[x_i | x_{-i}] = mu_i - (1/Q_ii) * sum_{j != i} Q_ij (x_j - mu_j), using the current state x.
                final double[] weights = conditionalWeights[i];
                double conditionalMean = mu[i];
                for (int j = 0; j < n; j++) {
                    conditionalMean += weights[j] * (x[j] - mu[j]);
                }

                // Draw from the 1-d normal truncated to [lb, ub].
                final TruncatedNormalDistribution dist = new TruncatedNormalDistribution(conditionalMean,
                                                                                         conditionalSd[i],
                                                                                         lowerBounds[i],
                                                                                         upperBounds[i]);
                x[i] = dist.sample();
            }
        }

        final Vector result = new Vector(n);
        for (int i = 0; i < n; i++) {
            result.setEntry(i, x[i]);
        }
        return result;
    }

    /** Dimension of the distribution. */
    private final int n;
    /** Number of Gibbs sweeps performed per call to {@link #sample()}. */
    private final int sweeps;
    /** Copy of the mean vector; also the (feasible) starting point of every chain. */
    private final double[] mu;
    /** Copy of the per-component lower bounds. */
    private final double[] lowerBounds;
    /** Copy of the per-component upper bounds. */
    private final double[] upperBounds;
    /** Standard deviation of each full conditional, sqrt(1 / Q_ii). */
    private final double[] conditionalSd;
    /** conditionalWeights[i][j] = -Q_ij / Q_ii for j != i and 0 on the diagonal. */
    private final double[][] conditionalWeights;
}