• 703-743-9010
• info@oneoffcoder.com
• 7526 Old Linton Hall Rd, Gainesville VA, 20155

### Conditional Multivariate Gaussian Distribution

Learn how to estimate the expected values of a subset of variables given (or conditioned on) another subset with a conditional multivariate gaussian distribution.

# Intro¶

In this notebook we will learn about the conditional multivariate normal (MVN) distribution. In particular, we want to estimate the expected value (or the mean) of some subset of variables given that another subset has been conditioned on. Though the notation is quasi-dense, it is not terribly difficult to produce a conditional MVN from a marginal MVN distribution.

# Case 1¶

• $X_0 \rightarrow X_1$
In [1]:
import numpy as np
from numpy.random import normal

In [2]:
N = 10000
x0 = normal(0, 1, N)
x1 = normal(1 + 2 * x0, 1, N)

X = np.hstack([x0.reshape(-1, 1), x1.reshape(-1, 1)])
M = np.mean(X, axis=0)
S = np.cov(X.T)

print(X.shape)
print(M.shape)
print(S.shape)
print('mean', M)
print('cov', S)

(10000, 2)
(2,)
(2, 2)
mean [-0.01217851  0.97154671]
cov [[0.99624904 1.98586371]
[1.98586371 4.94146349]]

In [3]:
M[0] + S[0,1] / S[1,1] * (0.5 - M[1])

Out[3]:
-0.20168259547735531
In [4]:
M[1] + S[1,0] / S[0,0] * (0.5 - M[0])

Out[4]:
1.9924929599471097
In [5]:
S[0,0] - S[0,1] / S[1,1] * S[1,0]

Out[5]:
0.19817481229753187
In [6]:
S[1,1] - S[1,0] / S[0,0] * S[1,0]

Out[6]:
0.9829606417418284

# Case 2¶

• $X_0 \rightarrow X_1 \rightarrow X_2$
In [7]:
from collections import namedtuple
from numpy.linalg import inv
import warnings

warnings.filterwarnings('ignore')
COV = namedtuple('COV', 'C11 C12 C21 C22 C22I')

def to_row_indices(indices):
return [[i] for i in indices]

def to_col_indices(indices):
return indices

def get_covariances(i1, i2, S):
r = to_row_indices(i1)
c = to_col_indices(i1)
C11 = S[r,c]

r = to_row_indices(i1)
c = to_col_indices(i2)
C12 = S[r,c]

r = to_row_indices(i2)
c = to_col_indices(i1)
C21 = S[r,c]

r = to_row_indices(i2)
c = to_col_indices(i2)
C22 = S[r,c]

C22I = inv(C22)

return COV(C11, C12, C21, C22, C22I)

def compute_means(a, M, C, i1, i2):
a = np.array([2.0])
return M[i1] + C.C12.dot(C.C22I).dot(a - M[i2])

def compute_covs(C):
return C.C11 - C.C12.dot(C.C22I).dot(C.C21)

def update_mean(m, a, M, i1, i2):
v = np.copy(M)
for i, mu in zip(i1, m):
v[i] = mu
for i, mu in zip(i2, a):
v[i] = mu
return v

def update_cov(c, S, i1, i2):
m = np.copy(S)
rows, cols = c.shape
for row in range(rows):
for col in range(cols):
m[i1[row],i1[col]] = c[row,col]
for i in i2:
m[i,i] = 0.01
return m

def update_mean_cov(v, iv, M, S):
if v is None or iv is None or len(v) == 0 or len(iv) == 0:
return np.copy(M), np.copy(S)
i2 = iv.copy()
i1 = [i for i in range(S.shape[0]) if i not in i2]

C = get_covariances(i1, i2, S)
m = compute_means(v, M, C, i1, i2)
c = compute_covs(C)
M_u = update_mean(m, v, M, i1, i2)
S_u = update_cov(c, S, i1, i2)
return M_u, S_u

In [8]:
N = 10000
x0 = normal(0, 1, N)
x1 = normal(1 + 2 * x0, 1, N)
x2 = normal(1 + 2 * x1, 1, N)

X = np.hstack([x0.reshape(-1, 1), x1.reshape(-1, 1), x2.reshape(-1, 1)])
M = np.mean(X, axis=0)
S = np.cov(X.T)

print('mean', M)
print('>')
print('cov', S)
print('>')
print('corr', np.corrcoef(X.T))

mean [-0.01313569  0.96903015  2.93044396]
>
cov [[ 0.99331094  1.96882311  3.92097634]
[ 1.96882311  4.87809341  9.73351352]
[ 3.92097634  9.73351352 20.43568728]]
>
corr [[1.         0.89441492 0.87027596]
[0.89441492 1.         0.9748773 ]
[0.87027596 0.9748773  1.        ]]

In [9]:
M_u, S_u = update_mean_cov(np.array([2.0]), [1], M, S)

print('mean', M_u)
print('>')
print('cov', S_u)
print('>')
print('corr', np.corrcoef(np.random.multivariate_normal(M_u, S_u, N*10).T))

mean [0.40296894 2.         4.98759174]
>
cov [[ 1.98683996e-01  1.96882311e+00 -7.51882725e-03]
[ 1.96882311e+00  1.00000000e-02  9.73351352e+00]
[-7.51882725e-03  9.73351352e+00  1.01390146e+00]]
>
corr [[ 1.         -0.02477834  0.77462271]
[-0.02477834  1.          0.05241713]
[ 0.77462271  0.05241713  1.        ]]


# Case 2¶

• $X_0 \leftarrow X_1 \rightarrow X_2$
In [10]:
N = 10000

x1 = normal(0, 1, N)
x0 = normal(1 + 4.0 * x1, 1, N)
x2 = normal(1 + 2.0 * x1, 1, N)

X = np.hstack([x0.reshape(-1, 1), x1.reshape(-1, 1), x2.reshape(-1, 1)])
M = np.mean(X, axis=0)
S = np.cov(X.T)

print('mean', M)
print('>')
print('cov', S)
print('>')
print('corr', np.corrcoef(X.T))

mean [ 0.96073372 -0.01011371  0.98507138]
>
cov [[16.71912352  3.92724763  7.87120405]
[ 3.92724763  0.9809372   1.96828134]
[ 7.87120405  1.96828134  4.9446487 ]]
>
corr [[1.         0.96975255 0.86569856]
[0.96975255 1.         0.8937146 ]
[0.86569856 0.8937146  1.        ]]

In [11]:
M_u, S_u = update_mean_cov(np.array([2.0]), [1], M, S)

print('mean', M_u)
print('>')
print('cov', S_u)
print('>')
print('corr', np.corrcoef(np.random.multivariate_normal(M_u, S_u, N*10).T))

mean [9.00835823 2.         5.01842773]
>
cov [[ 0.99612527  3.92724763 -0.00894181]
[ 3.92724763  0.01        1.96828134]
[-0.00894181  1.96828134  0.99523029]]
>
corr [[1.         0.11353019 0.5455407 ]
[0.11353019 1.         0.08049836]
[0.5455407  0.08049836 1.        ]]


# Case 2¶

• $X_0 \rightarrow X_1 \leftarrow X_2$
In [12]:
N = 10000

x0 = normal(0, 1, N)
x2 = normal(0, 1, N)
x1 = normal(1 + 2 * x0 + 3 * x2, 1, N)

X = np.hstack([x0.reshape(-1, 1), x1.reshape(-1, 1), x2.reshape(-1, 1)])
M = np.mean(X, axis=0)
S = np.cov(X.T)

print('mean', M)
print('>')
print('cov', S)
print('>')
print('corr', np.corrcoef(X.T))

mean [3.11610448e-03 1.00024244e+00 4.90651340e-05]
>
cov [[ 9.88652679e-01  1.94567309e+00 -7.66954037e-03]
[ 1.94567309e+00  1.36958157e+01  2.92099886e+00]
[-7.66954037e-03  2.92099886e+00  9.73747382e-01]]
>
corr [[ 1.          0.52875446 -0.00781672]
[ 0.52875446  1.          0.79986056]
[-0.00781672  0.79986056  1.        ]]

In [13]:
M_u, S_u = update_mean_cov(np.array([2.0]), [1], M, S)

print('mean', M_u)
print('>')
print('cov', S_u)
print('>')
print('corr', np.corrcoef(np.random.multivariate_normal(M_u, S_u, N*10).T))

mean [0.14514499 2.         0.21327409]
>
cov [[ 0.71224389  1.94567309 -0.42263635]
[ 1.94567309  0.01        2.92099886]
[-0.42263635  2.92099886  0.35076628]]
>
corr [[1.         0.01068818 0.53011777]
[0.01068818 1.         0.00570366]
[0.53011777 0.00570366 1.        ]]

In [17]:
for m in M_u:
print(m)

0.00311610448033705
1.0002424415942128
4.906513396201674e-05

In [ ]: