Bhattacharyya Distance: Dirichlet and Multinomial Distributions

Derivation and Code
distribution
dirichlet
statistics
Author

Deebul Nair

Published

February 28, 2023

Open In Colab

Logarithmic form of the Bhattacharyya distance between Dirichlet densities

  • [1] T. W. Rauber, A. Conci, T. Braun and K. Berns, “Bhattacharyya probabilistic distance of the Dirichlet density and its application to Split-and-Merge image segmentation,” 2008 15th International Conference on Systems, Signals and Image Processing, Bratislava, Slovakia, 2008, pp. 145-148, doi: 10.1109/IWSSIP.2008.4604388.

  • [2] F. Nielsen and R. Nock, “Cumulant-free closed-form formulas for some common (dis)similarities between densities of an exponential family,” arXiv abs/2003.02469 (2020), p. 12.

Equation from [1]

\[ \eqalignno{ & \qquad B(\alpha_{a},\alpha_{b}) = -\ln\rho(\alpha_{a},\alpha_{b})\cr & \quad= \ln\Gamma\Big(\sum_{k=1}^{d}{\alpha_{ak}+\alpha_{bk}\over 2}\Big) + {1\over 2}\Big\{\sum_{k=1}^{d}\ln\Gamma(\alpha_{ak})+ \sum_{k=1}^{d}\ln\Gamma(\alpha_{bk})\Big\} - \sum_{k=1}^{d}\ln\Gamma\Big({\alpha_{ak}+\alpha_{bk}\over 2}\Big) - {1\over 2}\Big\{\ln\Gamma(\vert\alpha_{a}\vert)+\ln\Gamma(\vert\alpha_{b}\vert)\Big\} } \]

Same Equation from [2]

\[ B(\alpha,\beta) = B(\alpha) + B(\beta) - B(\alpha+\beta) \] where \[ B(\alpha) = \sum_{k}\log\Gamma(\alpha_{k}) - \log\Gamma\Big(\sum_{k}\alpha_{k}\Big) \]

The equation from [2] is easier to code and debug than [1], so we implemented it.

import numpy as np
from scipy.special import gammaln
from scipy.special import digamma

def logBalpha(alpha: np.ndarray):
  return gammaln(alpha).sum() - gammaln(alpha.sum())

def bhattacharya_distance_dirichlet(alpha: np.ndarray, beta: np.ndarray):
  #todo check if values are positive and above 1
  return logBalpha(alpha) + logBalpha(beta) - logBalpha(alpha+beta)
for i in range(10):
  print(logBalpha(np.ones(i)))
-inf
0.0
0.0
-0.6931471805599453
-1.791759469228055
-3.1780538303479458
-4.787491742782046
-6.579251212010101
-8.525161361065415
-10.60460290274525
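As a sanity check (my addition, not part of the original post): for two dimensions the multivariate Beta function reduces to the ordinary Beta function, so `logBalpha` should agree with SciPy’s `scipy.special.betaln`:

```python
import numpy as np
from scipy.special import gammaln, betaln

def logBalpha(alpha: np.ndarray):
    # log of the multivariate Beta function B(alpha)
    return gammaln(alpha).sum() - gammaln(alpha.sum())

# For d = 2, logBalpha([a, b]) should match betaln(a, b) = ln B(a, b)
a, b = 2.0, 3.0
print(logBalpha(np.array([a, b])), betaln(a, b))
```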

Test 1: Minimum distance for different number of alphas

The minimum distance starts at 0 for a single class and increases as the number of classes increases.

If one class is confident, the distance between the two distributions can drop below this no-information baseline. See Test 4, where for 5 classes the distance goes below the 6.45 seen here.

for i in range(1, 10): 
  print (i, np.ones(i), bhattacharya_distance_dirichlet(np.ones(i), np.ones(i))) 
1 [1.] 0.0
2 [1. 1.] 1.791759469228055
3 [1. 1. 1.] 3.401197381662155
4 [1. 1. 1. 1.] 4.941642422609305
5 [1. 1. 1. 1. 1.] 6.445719819385578
6 [1. 1. 1. 1. 1. 1.] 7.927324360309795
7 [1. 1. 1. 1. 1. 1. 1.] 9.393661429103219
8 [1. 1. 1. 1. 1. 1. 1. 1.] 10.848948661710065
9 [1. 1. 1. 1. 1. 1. 1. 1. 1.] 12.295867644646389
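A small derivation I have added (not in the original post): for d classes with all alphas equal to one, logBalpha gives \(-\ln\Gamma(d)\), and for the doubled alphas it gives \(-\ln\Gamma(2d)\) since \(\ln\Gamma(2)=0\). The baseline distance therefore simplifies to \(\ln\Gamma(2d) - 2\ln\Gamma(d)\), which reproduces the table above:

```python
import numpy as np
from scipy.special import gammaln

# Closed form for the uniform-alpha baseline: B(1_d, 1_d) = ln G(2d) - 2 ln G(d)
for d in range(2, 10):
    closed_form = gammaln(2 * d) - 2 * gammaln(d)
    print(d, closed_form)  # d = 2 gives ln 6, matching 1.7917... above
```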

Test 2: Effect of high alpha values on the distance

Here we observe that the distance between identical alphas keeps decreasing as the alpha values increase.

The higher the alpha, the lower the distance; it even becomes negative.

for i in range(1,100, 10):
  print (i, np.ones(5)*i, bhattacharya_distance_dirichlet(np.ones(5)*i, np.ones(5)*i))
1 [1. 1. 1. 1. 1.] 6.445719819385578
11 [11. 11. 11. 11. 11.] 1.1255028731251286
21 [21. 21. 21. 21. 21.] -0.19370880228223086
31 [31. 31. 31. 31. 31.] -0.9818529925794337
41 [41. 41. 41. 41. 41.] -1.5457429140878958
51 [51. 51. 51. 51. 51.] -1.9851193053912084
61 [61. 61. 61. 61. 61.] -2.3451443271551398
71 [71. 71. 71. 71. 71.] -2.650141672341306
81 [81. 81. 81. 81. 81.] -2.914723500702621
91 [91. 91. 91. 91. 91.] -3.1483581907859843
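Worth noting (my addition): the formula from [1] averages the parameters, \( \tfrac{1}{2}\ln B(\alpha) + \tfrac{1}{2}\ln B(\beta) - \ln B\big(\tfrac{\alpha+\beta}{2}\big) \), and that averaged form is exactly zero for identical alphas, unlike the \(B(\alpha)+B(\beta)-B(\alpha+\beta)\) variant coded above, which goes negative here. A sketch of the averaged form (the function name `bhattacharyya_dirichlet_averaged` is my own, hypothetical):

```python
import numpy as np
from scipy.special import gammaln

def logB(alpha: np.ndarray):
    # log of the multivariate Beta function
    return gammaln(alpha).sum() - gammaln(alpha.sum())

def bhattacharyya_dirichlet_averaged(alpha, beta):
    # Averaged form from [1]: 0.5*lnB(a) + 0.5*lnB(b) - lnB((a+b)/2)
    return 0.5 * (logB(alpha) + logB(beta)) - logB((alpha + beta) / 2)

for i in range(1, 100, 10):
    a = np.ones(5) * i
    print(i, bhattacharyya_dirichlet_averaged(a, a))  # zero for identical alphas
```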

Test 3: Distance increases with confidence in one class

When one alpha value is increased, modelling growing confidence in a single class, the distance to the no-information Dirichlet keeps increasing.

for i in range(1,100, 10):
  p1 = np.ones(5)
  p2 = p1.copy()
  p2[0] = i
  print (i, p1, p2, bhattacharya_distance_dirichlet(p1, p2))
1 [1. 1. 1. 1. 1.] [1. 1. 1. 1. 1.] 6.445719819385578
11 [1. 1. 1. 1. 1.] [11.  1.  1.  1.  1.] 8.572713901314495
21 [1. 1. 1. 1. 1.] [21.  1.  1.  1.  1.] 10.249733300984314
31 [1. 1. 1. 1. 1.] [31.  1.  1.  1.  1.] 11.438891683612685
41 [1. 1. 1. 1. 1.] [41.  1.  1.  1.  1.] 12.356846899935736
51 [1. 1. 1. 1. 1.] [51.  1.  1.  1.  1.] 13.10383713518198
61 [1. 1. 1. 1. 1.] [61.  1.  1.  1.  1.] 13.733421146565181
71 [1. 1. 1. 1. 1.] [71.  1.  1.  1.  1.] 14.277449847443624
81 [1. 1. 1. 1. 1.] [81.  1.  1.  1.  1.] 14.75637687058833
91 [1. 1. 1. 1. 1.] [91.  1.  1.  1.  1.] 15.18411005351209
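A quick property check I have added: the distance is symmetric in its arguments, since \(B(\alpha)+B(\beta)-B(\alpha+\beta)\) is unchanged when \(\alpha\) and \(\beta\) are swapped:

```python
import numpy as np
from scipy.special import gammaln

def logBalpha(alpha: np.ndarray):
    return gammaln(alpha).sum() - gammaln(alpha.sum())

def bhattacharya_distance_dirichlet(alpha, beta):
    return logBalpha(alpha) + logBalpha(beta) - logBalpha(alpha + beta)

p1 = np.ones(5)
p2 = np.array([91.0, 1, 1, 1, 1])
print(bhattacharya_distance_dirichlet(p1, p2))
print(bhattacharya_distance_dirichlet(p2, p1))  # same value either way
```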

Test 4: Distance to a single confident class

As the two distributions converge to the same alpha values, the distance decreases.

for i in range(1,100, 10):
  p1 = np.ones(5)
  p2 = p1.copy()
  p2[0] = i
  p1[0] = 91
  print (i, p1, p2, bhattacharya_distance_dirichlet(p1, p2))
1 [91.  1.  1.  1.  1.] [1. 1. 1. 1. 1.] 15.18411005351209
11 [91.  1.  1.  1.  1.] [11.  1.  1.  1.  1.] 9.072449120087896
21 [91.  1.  1.  1.  1.] [21.  1.  1.  1.  1.] 7.434934213898558
31 [91.  1.  1.  1.  1.] [31.  1.  1.  1.  1.] 6.625978812347682
41 [91.  1.  1.  1.  1.] [41.  1.  1.  1.  1.] 6.165455265460906
51 [91.  1.  1.  1.  1.] [51.  1.  1.  1.  1.] 5.889528847642993
61 [91.  1.  1.  1.  1.] [61.  1.  1.  1.  1.] 5.7237287914577735
71 [91.  1.  1.  1.  1.] [71.  1.  1.  1.  1.] 5.628589843464567
81 [91.  1.  1.  1.  1.] [81.  1.  1.  1.  1.] 5.581062260628471
91 [91.  1.  1.  1.  1.] [91.  1.  1.  1.  1.] 5.566743973084726

Bhattacharyya Distance: Multinomial Distribution

def bhattacharya_distiance_multinomial(p1, p2):
  return -np.log(sum(np.sqrt(p1*p2)))
for i in range(1,10):
  print (i, np.ones(i)/i, bhattacharya_distiance_multinomial(np.ones(i)/i, np.ones(i)/i))
1 [1.] -0.0
2 [0.5 0.5] -0.0
3 [0.33333333 0.33333333 0.33333333] -0.0
4 [0.25 0.25 0.25 0.25] -0.0
5 [0.2 0.2 0.2 0.2 0.2] -0.0
6 [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] 1.1102230246251565e-16
7 [0.14285714 0.14285714 0.14285714 0.14285714 0.14285714 0.14285714
 0.14285714] 2.2204460492503136e-16
8 [0.125 0.125 0.125 0.125 0.125 0.125 0.125 0.125] -0.0
9 [0.11111111 0.11111111 0.11111111 0.11111111 0.11111111 0.11111111
 0.11111111 0.11111111 0.11111111] -2.2204460492503128e-16
for i in range(1,100, 10):
  p = (np.ones(5)*i )/ sum(np.ones(5)*i)
  print (i, p, bhattacharya_distiance_multinomial(p, p))
1 [0.2 0.2 0.2 0.2 0.2] -0.0
11 [0.2 0.2 0.2 0.2 0.2] -0.0
21 [0.2 0.2 0.2 0.2 0.2] -0.0
31 [0.2 0.2 0.2 0.2 0.2] -0.0
41 [0.2 0.2 0.2 0.2 0.2] -0.0
51 [0.2 0.2 0.2 0.2 0.2] -0.0
61 [0.2 0.2 0.2 0.2 0.2] -0.0
71 [0.2 0.2 0.2 0.2 0.2] -0.0
81 [0.2 0.2 0.2 0.2 0.2] -0.0
91 [0.2 0.2 0.2 0.2 0.2] -0.0
for i in range(1,100, 10):
  p1 = np.ones(5)
  p2 = p1.copy()
  p2[0] = i
  p1 = p1/5
  p2 = p2/sum(p2)
  print (i, p1, p2, bhattacharya_distiance_multinomial(p1, p2))
1 [0.2 0.2 0.2 0.2 0.2] [0.2 0.2 0.2 0.2 0.2] -0.0
11 [0.2 0.2 0.2 0.2 0.2] [0.73333333 0.06666667 0.06666667 0.06666667 0.06666667] 0.1685949293453037
21 [0.2 0.2 0.2 0.2 0.2] [0.84 0.04 0.04 0.04 0.04] 0.26442280265424684
31 [0.2 0.2 0.2 0.2 0.2] [0.88571429 0.02857143 0.02857143 0.02857143 0.02857143] 0.32399341768738754
41 [0.2 0.2 0.2 0.2 0.2] [0.91111111 0.02222222 0.02222222 0.02222222 0.02222222] 0.36594403262269887
51 [0.2 0.2 0.2 0.2 0.2] [0.92727273 0.01818182 0.01818182 0.01818182 0.01818182] 0.3977150973709143
61 [0.2 0.2 0.2 0.2 0.2] [0.93846154 0.01538462 0.01538462 0.01538462 0.01538462] 0.4229448201200781
71 [0.2 0.2 0.2 0.2 0.2] [0.94666667 0.01333333 0.01333333 0.01333333 0.01333333] 0.44365990820723616
81 [0.2 0.2 0.2 0.2 0.2] [0.95294118 0.01176471 0.01176471 0.01176471 0.01176471] 0.46109522700067157
91 [0.2 0.2 0.2 0.2 0.2] [0.95789474 0.01052632 0.01052632 0.01052632 0.01052632] 0.47605403848305355
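Two reference values I have added for checking the multinomial distance: for distributions with disjoint support the Bhattacharyya coefficient is 0, so the distance is infinite; and for p = [0.5, 0.5] versus q = [1, 0] the coefficient is \(\sqrt{0.5}\), so the distance is \(\tfrac{1}{2}\ln 2 \approx 0.3466\):

```python
import numpy as np

def bhattacharya_distiance_multinomial(p1, p2):
    # -ln of the Bhattacharyya coefficient sum(sqrt(p*q))
    return -np.log(sum(np.sqrt(p1 * p2)))

# Disjoint support: coefficient 0, distance infinite
print(bhattacharya_distiance_multinomial(np.array([1.0, 0.0]), np.array([0.0, 1.0])))

# Half overlap: -ln(sqrt(0.5)) = 0.5 * ln(2)
print(bhattacharya_distiance_multinomial(np.array([0.5, 0.5]), np.array([1.0, 0.0])))
```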

End of Blog


Trial code (not working)

a = np.array([1, 1, 1])
b = np.array([100,1,1])
print (-gammaln(((a + b )/2 ).sum()))


(gammaln(a).sum() + gammaln(b).sum())/2

-gammaln((a + b )/2).sum()

-(gammaln(p1).sum() + gammaln(p2).sum() )/2
-154.38281063467164
-12.918839126312893
a = np.array([1, 1, 1])
b = np.array([100,1,1])
print (-gammaln(sum((a + b )/2 )))


print ((gammaln(a).sum() + gammaln(b).sum())/2)

print (-gammaln((a + b )/2).sum())

print (-(gammaln(a) + gammaln(b) )/2)
-154.38281063467164
179.5671026847877
-146.51925549072064
[-179.56710268   -0.           -0.        ]
import numpy as np
from scipy.special import gammaln
from scipy.special import digamma

def bhattacharya_distance_dirichlet(alpha: np.ndarray, beta: np.ndarray,):

  #return -np.log2(sum(np.add(alpha, beta)/2)) + \
  #        0.5*sum(np.log2(alpha)) + 0.5*sum(np.log2(beta)) - \
  #        sum(np.log2(np.add(alpha, beta) / 2)) - \
  #        0.5 * np.log2(alpha).sum() + \
  #       0.5 * np.log2(beta).sum()

  #return  - gammaln(sum(np.add(alpha, beta)/2)) + \
  #  0.5*(sum(gammaln(alpha)) + sum(gammaln(beta))) - \
  #  sum(gammaln(np.add(alpha, beta) / 2)) -  \
  #  0.5*(gammaln(alpha).sum() + gammaln(beta).sum())

  return (gammaln(alpha).sum() + gammaln(beta).sum())/2 - \
         gammaln((alpha + beta )/2).sum() - \
         (gammaln(alpha) + gammaln(beta) )/2
a1 = [1, 1, 1, 1, 1]
a2 = [1, 1, 1, 1, 1]

bhattacharya_distance_dirichlet(a1, a2)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-124-8efb0820f155> in <module>
      2 a2 = [1, 1, 1, 1, 1]
      3 
----> 4 bhattacharya_distance_dirichlet(a1, a2)

<ipython-input-123-aa14d3c37007> in bhattacharya_distance_dirichlet(alpha, beta)
     17 
     18   return (gammaln(alpha).sum() + gammaln(beta).sum())/2 - \
---> 19          gammaln((alpha + beta )/2).sum() - \
     20          (gammaln(alpha) + gammaln(beta) )/2

TypeError: unsupported operand type(s) for /: 'list' and 'int'
a3 = [100, 1, 1, 1, 1]
bhattacharya_distance_dirichlet(a1, a3)
0.0
a3 = [100, 100, 100, 100, 100]
bhattacharya_distance_dirichlet(a1, a3)
0.0
for i in range(1,10):
  print (i, np.ones(i), bhattacharya_distance_dirichlet(np.ones(i), np.ones(i)))
1 [1.] [0.]
2 [1. 1.] [0. 0.]
3 [1. 1. 1.] [0. 0. 0.]
4 [1. 1. 1. 1.] [0. 0. 0. 0.]
5 [1. 1. 1. 1. 1.] [0. 0. 0. 0. 0.]
6 [1. 1. 1. 1. 1. 1.] [0. 0. 0. 0. 0. 0.]
7 [1. 1. 1. 1. 1. 1. 1.] [0. 0. 0. 0. 0. 0. 0.]
8 [1. 1. 1. 1. 1. 1. 1. 1.] [0. 0. 0. 0. 0. 0. 0. 0.]
9 [1. 1. 1. 1. 1. 1. 1. 1. 1.] [0. 0. 0. 0. 0. 0. 0. 0. 0.]
for i in range(1,100, 10):
  print (i, np.ones(5)*i, bhattacharya_distance_dirichlet(np.ones(5)*i, np.ones(5)*i))
1 [1. 1. 1. 1. 1.] [0. 0. 0. 0. 0.]
11 [11. 11. 11. 11. 11.] [-15.10441257 -15.10441257 -15.10441257 -15.10441257 -15.10441257]
21 [21. 21. 21. 21. 21.] [-42.33561646 -42.33561646 -42.33561646 -42.33561646 -42.33561646]
31 [31. 31. 31. 31. 31.] [-74.65823635 -74.65823635 -74.65823635 -74.65823635 -74.65823635]
41 [41. 41. 41. 41. 41.] [-110.32063971 -110.32063971 -110.32063971 -110.32063971 -110.32063971]
51 [51. 51. 51. 51. 51.] [-148.47776695 -148.47776695 -148.47776695 -148.47776695 -148.47776695]
61 [61. 61. 61. 61. 61.] [-188.62817342 -188.62817342 -188.62817342 -188.62817342 -188.62817342]
71 [71. 71. 71. 71. 71.] [-230.43904357 -230.43904357 -230.43904357 -230.43904357 -230.43904357]
81 [81. 81. 81. 81. 81.] [-273.67312429 -273.67312429 -273.67312429 -273.67312429 -273.67312429]
91 [91. 91. 91. 91. 91.] [-318.15263962 -318.15263962 -318.15263962 -318.15263962 -318.15263962]
import scipy 
from scipy.special import gammaln
from scipy.special import digamma

def KL_divergence(alpha, beta):
  return gammaln(sum(alpha)) - \
        gammaln(sum(beta)) - \
        sum(gammaln(alpha)) + \
        sum(gammaln(beta)) +  \
        (alpha - beta) * (digamma(alpha) - digamma(sum(alpha)))
for i in range(1,10):
  print (i, np.ones(i), KL_divergence(np.ones(i), np.ones(i)))
1 [1.] [0.]
2 [1. 1.] [0. 0.]
3 [1. 1. 1.] [0. 0. 0.]
4 [1. 1. 1. 1.] [0. 0. 0. 0.]
5 [1. 1. 1. 1. 1.] [0. 0. 0. 0. 0.]
6 [1. 1. 1. 1. 1. 1.] [0. 0. 0. 0. 0. 0.]
7 [1. 1. 1. 1. 1. 1. 1.] [0. 0. 0. 0. 0. 0. 0.]
8 [1. 1. 1. 1. 1. 1. 1. 1.] [0. 0. 0. 0. 0. 0. 0. 0.]
9 [1. 1. 1. 1. 1. 1. 1. 1. 1.] [0. 0. 0. 0. 0. 0. 0. 0. 0.]
for i in range(1,100, 10):
  print (i, np.ones(5)*i, KL_divergence(np.ones(5)*i, np.ones(5)*i))
1 [1. 1. 1. 1. 1.] [0. 0. 0. 0. 0.]
11 [11. 11. 11. 11. 11.] [0. 0. 0. 0. 0.]
21 [21. 21. 21. 21. 21.] [0. 0. 0. 0. 0.]
31 [31. 31. 31. 31. 31.] [0. 0. 0. 0. 0.]
41 [41. 41. 41. 41. 41.] [0. 0. 0. 0. 0.]
51 [51. 51. 51. 51. 51.] [0. 0. 0. 0. 0.]
61 [61. 61. 61. 61. 61.] [0. 0. 0. 0. 0.]
71 [71. 71. 71. 71. 71.] [0. 0. 0. 0. 0.]
81 [81. 81. 81. 81. 81.] [0. 0. 0. 0. 0.]
91 [91. 91. 91. 91. 91.] [0. 0. 0. 0. 0.]
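The KL divergence above returns a vector because its last term is never summed. A corrected sketch (my fix, assuming the standard closed form for KL between two Dirichlet densities; the name `KL_divergence_dirichlet` is my own):

```python
import numpy as np
from scipy.special import gammaln, digamma

def KL_divergence_dirichlet(alpha: np.ndarray, beta: np.ndarray) -> float:
    # KL( Dir(alpha) || Dir(beta) ): the elementwise last term is summed to a scalar
    return (gammaln(alpha.sum()) - gammaln(beta.sum())
            - gammaln(alpha).sum() + gammaln(beta).sum()
            + ((alpha - beta) * (digamma(alpha) - digamma(alpha.sum()))).sum())

print(KL_divergence_dirichlet(np.ones(5), np.ones(5)))            # zero for identical
print(KL_divergence_dirichlet(np.array([91.0, 1, 1, 1, 1]), np.ones(5)))
```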
scipy.d
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-11-4addd521ddf1> in <module>
----> 1 scipy.special.digamma

AttributeError: module 'scipy' has no attribute 'special'
p1 = np.ones(5)/5

bhattacharya_distiance_multinomial(p1,p1)
-0.0
def bhattacharya_distiance_multinomial_log(p1, p2):

  return -(gammaln(p1).sum() + gammaln(p2).sum() )/2
for i in range(1,10):
  print (i, np.ones(i), bhattacharya_distiance_multinomial_log(np.ones(i), np.ones(i)))
1 [1.] -0.0
2 [1. 1.] -0.0
3 [1. 1. 1.] -0.0
4 [1. 1. 1. 1.] -0.0
5 [1. 1. 1. 1. 1.] -0.0
6 [1. 1. 1. 1. 1. 1.] -0.0
7 [1. 1. 1. 1. 1. 1. 1.] -0.0
8 [1. 1. 1. 1. 1. 1. 1. 1.] -0.0
9 [1. 1. 1. 1. 1. 1. 1. 1. 1.] -0.0
for i in range(1,100, 10):
  p1 = np.ones(5)
  p2 = p1.copy()
  p2[0] = i
  p1 = p1/5
  p2 = p2/sum(p2)
  print (i, p1, p2, bhattacharya_distiance_multinomial_log(p1, p2))
1 [0.2 0.2 0.2 0.2 0.2] [0.2 0.2 0.2 0.2 0.2] -7.620319112153923
11 [0.2 0.2 0.2 0.2 0.2] [0.73333333 0.06666667 0.06666667 0.06666667 0.06666667] -9.26724932985341
21 [0.2 0.2 0.2 0.2 0.2] [0.84 0.04 0.04 0.04 0.04] -10.261942523017305
31 [0.2 0.2 0.2 0.2 0.2] [0.88571429 0.02857143 0.02857143 0.02857143 0.02857143] -10.927875722571079
41 [0.2 0.2 0.2 0.2 0.2] [0.91111111 0.02222222 0.02222222 0.02222222 0.02222222] -11.427687210720562
51 [0.2 0.2 0.2 0.2 0.2] [0.92727273 0.01818182 0.01818182 0.01818182 0.01818182] -11.827621143269788
61 [0.2 0.2 0.2 0.2 0.2] [0.93846154 0.01538462 0.01538462 0.01538462 0.01538462] -12.160926600816564
71 [0.2 0.2 0.2 0.2 0.2] [0.94666667 0.01333333 0.01333333 0.01333333 0.01333333] -12.44662759830056
81 [0.2 0.2 0.2 0.2 0.2] [0.95294118 0.01176471 0.01176471 0.01176471 0.01176471] -12.6966207038721
91 [0.2 0.2 0.2 0.2 0.2] [0.95789474 0.01052632 0.01052632 0.01052632 0.01052632] -12.918839126312893
import numpy as np


def bhattacharyya_distance(repr1: np.ndarray, repr2: np.ndarray) -> float:
    """Calculates Bhattacharyya distance (https://en.wikipedia.org/wiki/Bhattacharyya_distance)."""
    temp = np.sum([np.sqrt(p*q) for (p, q) in zip(repr1, repr2)])
    print (temp)
    value = - np.log(temp)
    if np.isinf(value):
        return 0
    return value
for i in range(1,10):
  print (i, np.ones(i)/i, bhattacharyya_distance(np.ones(i)/i, np.ones(i)/i))
1.0
1 [1.] -0.0
1.0
2 [0.5 0.5] -0.0
1.0
3 [0.33333333 0.33333333 0.33333333] -0.0
1.0
4 [0.25 0.25 0.25 0.25] -0.0
1.0
5 [0.2 0.2 0.2 0.2 0.2] -0.0
0.9999999999999999
6 [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667] 1.1102230246251565e-16
0.9999999999999998
7 [0.14285714 0.14285714 0.14285714 0.14285714 0.14285714 0.14285714
 0.14285714] 2.2204460492503136e-16
1.0
8 [0.125 0.125 0.125 0.125 0.125 0.125 0.125 0.125] -0.0
1.0
9 [0.11111111 0.11111111 0.11111111 0.11111111 0.11111111 0.11111111
 0.11111111 0.11111111 0.11111111] -0.0
p1 = np.ones(5)/5
np.sqrt(p1*p1)
array([0.2, 0.2, 0.2, 0.2, 0.2])
np.sum([np.sqrt(p*q) for (p, q) in zip(p1, p1)])
1.0