// MultinomialDistribution.java
/*
* Copyright 2013 University of Glasgow.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package broadwick.statistics.distributions;
import broadwick.math.Factorial;
import java.io.Serializable;
import lombok.extern.slf4j.Slf4j;
/**
 * Defines the multinomial distribution with parameters <i>numSamples</i> and
 * <i>probabilities</i><sub>1</sub>, ..., <i>probabilities</i><sub>d</sub>. The probability mass function is
 * <blockquote>
 * <i>P</i>[<i>X</i> = (<i>x</i><sub>1</sub>, ..., <i>x</i><sub>d</sub>)] =
 * <i>numSamples</i>! ∏<sub>i=1</sub><sup>d</sup> <i>probabilities</i><sub>i</sub><sup>x<sub>i</sub></sup>
 * / (∏<sub>i=1</sub><sup>d</sup> <i>x</i><sub>i</sub>!),
 * </blockquote>
 * where ∑<sub>i=1</sub><sup>d</sup> <i>x</i><sub>i</sub> = <i>numSamples</i> and
 * ∑<sub>i=1</sub><sup>d</sup> <i>probabilities</i><sub>i</sub> = 1.
 * <p>
 * The probability vector is copied on construction and {@link #getP()} returns a copy, so the parameters of an
 * instance cannot be changed through its public methods once it has been created.
 */
@Slf4j
public class MultinomialDistribution implements Serializable {

    /** Largest deviation from 1 that is tolerated in the sum of the probability vector. */
    private static final double PROBABILITY_SUM_TOLERANCE = 1E-15;

    // NOTE: the instance fields are deliberately left non-final and no serialVersionUID is declared; either
    // change would alter the default stream identifier of this Serializable class and make previously
    // serialized instances unreadable.

    /** The number of samples (trials), always strictly positive. */
    private int numSamples;

    /** Private copy of the outcome probabilities; every entry lies in [0, 1]. */
    private double[] probabilities;

    /** The number of outcomes, i.e. the length of {@link #probabilities}. */
    private int dimension;

    /**
     * Creates a {@code MultinomialDistribution} with parameters <i>numSamples</i> = {@code n} and
     * (<i>probabilities</i><sub>1</sub>, ..., <i>probabilities</i><sub>d</sub>) = {@code p}, where
     * <i>probabilities</i><sub>i</sub> = {@code p[i-1]} and the entries of {@code p} sum to 1.
     * The array is copied, so later changes to {@code p} do not affect this object.
     *
     * @param n the number of samples; must be strictly positive.
     * @param p the array of probabilities; must hold at least two entries, each in [0, 1], summing to 1.
     * @throws IllegalArgumentException if {@code n <= 0}, if {@code p} has fewer than two entries, if any entry
     *                                  is NaN or outside [0, 1], or if the entries do not sum to 1 (within 1E-15).
     */
    public MultinomialDistribution(final int n, final double[] p) {
        if (n <= 0) {
            throw new IllegalArgumentException("n <= 0");
        }
        if (p.length < 2) {
            throw new IllegalArgumentException("p.length < 2");
        }
        this.numSamples = n;
        this.dimension = p.length;
        // Validate the private copy rather than the caller's array so that what is checked is what is stored.
        this.probabilities = p.clone();

        double sumP = 0.0;
        for (final double probability : this.probabilities) {
            // Written as a negated range test on purpose: every comparison with NaN is false, so the
            // form "(v < 0) || (v > 1)" would let NaN through and it would then also slip past the sum check.
            if (!(probability >= 0.0 && probability <= 1.0)) {
                throw new IllegalArgumentException("p is not a probability vector");
            }
            sumP += probability;
        }
        if (Math.abs(sumP - 1.0) > PROBABILITY_SUM_TOLERANCE) {
            throw new IllegalArgumentException("p is not a probability vector");
        }
    }

    /**
     * Returns the parameter <i>N</i> of this object.
     *
     * @return the total number of samples.
     */
    public final int getN() {
        return numSamples;
    }

    /**
     * Returns the parameters (<i>probabilities</i><sub>1</sub>, ..., <i>probabilities</i><sub>d</sub>) of this
     * object.
     *
     * @return a copy of the array of probabilities; modifying it does not affect this object.
     */
    public final double[] getP() {
        return probabilities.clone();
    }

    /**
     * Returns the probability mass function
     * <i>p</i>(<i>x</i><sub>1</sub>, <i>x</i><sub>2</sub>, ..., <i>x</i><sub>d</sub>), which should be a real
     * number in [0, 1]. An outcome whose probability is zero contributes a factor of 1 when its count is zero
     * and makes the whole result 0 when its count is positive.
     *
     * @param x value at which the mass function must be evaluated; one non-negative count per outcome, with the
     *          counts adding up to the number of samples.
     * @return the mass function evaluated at {@code x}.
     * @throws IllegalArgumentException if {@code x} does not have the same length as the probability vector,
     *                                  if any count is negative, or if the counts do not sum to the number of
     *                                  samples.
     */
    public final double prob(final int[] x) {
        if (x.length != probabilities.length) {
            throw new IllegalArgumentException("x and p must have the same dimension");
        }

        // Accumulate in log space: ln P = ln(n!) - sum ln(x_i!) + sum x_i ln(p_i).
        // (Factorial.lnFactorial(k) is taken to return ln(k!).)
        long sumX = 0L; // long, so that a sum of large counts cannot wrap around and match numSamples
        double sumLnXFactorial = 0.0;
        double sumXLnP = 0.0;
        for (int i = 0; i < probabilities.length; i++) {
            final int count = x[i];
            if (count < 0) {
                throw new IllegalArgumentException(
                        String.format("Multinomial distribution error: x[%d] = %d is negative.", i, count));
            }
            sumX += count;
            sumLnXFactorial += Factorial.lnFactorial(count);
            // Skip the term only when the count is zero (p^0 = 1, and 0 * ln(0) would be NaN).
            // For a positive count with p == 0, Math.log yields -Infinity, which correctly drives the
            // final exp(...) to 0 instead of silently ignoring an impossible outcome.
            if (count > 0) {
                sumXLnP += count * Math.log(probabilities[i]);
            }
        }

        if (sumX != numSamples) {
            throw new IllegalArgumentException(String.format("Multinomial distribution error: Sum_x [%d] != number of samples [%d].", sumX, numSamples));
        }
        return Math.exp(Factorial.lnFactorial(numSamples) - sumLnXFactorial + sumXLnP);
    }

    /**
     * Returns a textual form of the probability vector: the prefix {@code "p(x) = "} followed by each
     * probability and a single space.
     *
     * @return the string representation of this distribution's probabilities.
     */
    @Override
    public final String toString() {
        final StringBuilder sb = new StringBuilder();
        sb.append("p(x) = ");
        for (final double probability : probabilities) {
            sb.append(probability).append(" ");
        }
        return sb.toString();
    }
}