Probability of $X$ taking values in the interval $[\,x - \epsilon,\; x + \epsilon\,]$, where $\epsilon \gt 0$ is half of the bin width:
\begin{align} P_{\text{mid}} & = P[\;\;(x - \epsilon)\;\; \le \;\;X\;\; \le \;\;(x + \epsilon)\;\; ] \\ & = P[\;\;X\;\;\le\;\;(x + \epsilon)\;\;]\;\; - \;\;P[\;\;X\;\;\le \;\;(x - \epsilon)\;\; ] \\ & = \sigma( \frac{x + \epsilon-\mu}{s} ) - \sigma(\frac{x - \epsilon-\mu}{s} ) \end{align}
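A minimal numeric sketch of this bin probability, assuming made-up values $\mu = 0$, $s = 1$, $x = 0.5$ and a bin width of $0.1$:

import torch

# hypothetical values, chosen only to illustrate the formula above
mu, s  = torch.tensor(0.0), torch.tensor(1.0)
x, eps = torch.tensor(0.5), torch.tensor(0.05)   # eps = half of the bin width (bin width = 0.1)

# P_mid = sigma((x + eps - mu)/s) - sigma((x - eps - mu)/s)
p_mid = torch.sigmoid((x + eps - mu) / s) - torch.sigmoid((x - eps - mu) / s)
print(p_mid)   # probability mass the logistic CDF assigns to [x - eps, x + eps]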
We assume there are $N$ components in the logistic mixture. If $\mu_i, s_i$ are the parameters of the $i^{th}$ component, the probability of $X$ taking values in the interval $[\,x - \epsilon,\; x + \epsilon\,]$ under the $i^{th}$ mixture component is given by:
$$ P_{x}^{i} \;=\; P[\;\;(x - \epsilon)\;\; \le \;\;X\;\; \le \;\;(x + \epsilon)\;\; ]_{i^{th}\; \text{mixture component}} \;=\; \sigma( \frac{x + \epsilon - \mu_i}{s_i}) \;\;-\;\; \sigma( \frac{x - \epsilon - \mu_i}{s_i})$$
Collectively, the $N$ mixture components with weights $w_i$ assign the following probability to the interval/bin $[\,x - \epsilon,\; x + \epsilon\,]$:
$$ P_x \;=\; P[\;\;(x - \epsilon)\;\; \le \;\;X\;\; \le \;\;(x + \epsilon)\;\; ]_{\text{all mixture components}} \;=\; \sum_{i=1}^{N} w_{i}P_{x}^{i} $$
Given $K$ training data points $x_1, x_2, \ldots, x_K$, the loss is the negative log-likelihood:
$$ loss \;=\; - \sum_{k=1}^{K} log( P_{x_k} ) $$
\begin{align*}
P_x & = \sum_{i=1}^{N} w_{i} P_{x}^{i} \\
& = \sum_{i=1}^{N} e^{\, log\{\, w_{i} \cdot P_{x}^{i} \,\} } \\
& = \sum_{i=1}^{N} e^{\, log\{ w_{i} \} + log\{ P_{x}^{i} \} } \\
& = \sum_{i=1}^{N} e^{\alpha_i} \\[2ex]
log(P_x) & = log\Big(\sum_{i=1}^{N} e^{\alpha_i}\Big) \\
& = log\big(e^{\alpha_1} + e^{\alpha_2} + \ldots + e^{\alpha_N}\big) \\
& = \text{log sum exponent}\,( \alpha_1 , \alpha_2, \ldots, \alpha_N) \\
& = \text{LSE}\,( \alpha_1 , \alpha_2, \ldots, \alpha_N) \\[2ex]
loss & = - \sum_{k=1}^{K} log( P_{x_{k}} )
\end{align*}
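A small sketch of why the LSE form is useful; the mixture weights and per-component bin probabilities below are made up, and torch.logsumexp is used purely as a reference implementation of LSE:

import torch

log_w   = torch.log(torch.tensor([0.2, 0.5, 0.3]))      # log mixture weights (sum to 1)
log_Pxi = torch.log(torch.tensor([1e-4, 3e-3, 2e-2]))   # hypothetical per-component bin probabilities
alpha   = log_w + log_Pxi                                # alpha_i = log(w_i) + log(P_x^i)

log_Px_naive  = torch.log(torch.sum(torch.exp(alpha)))  # direct sum: underflows when the alphas are very negative
log_Px_stable = torch.logsumexp(alpha, dim=0)           # LSE(alpha_1, ..., alpha_N)
print(log_Px_naive, log_Px_stable)                      # identical here; only the LSE form stays finite in extreme cases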
To handle the extreme case when the bin probability (`Px_i_mid_case` in the code below) is very small, i.e. $\ll$ 1e-5:
This scenario can happen when the network predicts a mean $\mu$ that is far away from the data point $x$,
e.g. the network outputs $\mu = 1000$ and $s = 1$, but max(train_data) = 5 and min(train_data) = -5.
Under these distribution parameters ($\mu = 1000$, $s = 1$), an extremely low probability is assigned to the interval/bin $[\,x - \epsilon,\; x + \epsilon\,]$.
Taking the log of this extremely low probability produces $-\infty$ (and subsequently NaNs in the gradients).
To solve this problem, we approximate the integral (the area under the PDF over the bin) by evaluating the logistic PDF at the bin center $x$ and multiplying it by one bin width on the support of train_data.
\begin{align} f(x) & = \frac{ e^{-(x-\mu)/s} }{ s(1 + e^{-(x-\mu)/s} )^2 } \\ log\{f(x)\} & = -(x-\mu)/s - log(s) - log\{ \;(1 + e^{-(x-\mu)/s} )^2\; \} \\ log\{f(x)\} & = -(x-\mu)/s - log(s) - 2log\{ ( 1 + e^{-(x-\mu)/s} ) \} \\ log\{f(x)\} & = -(x-\mu)/s - log(s) - 2softplus\{ -(x-\mu)/s \} \end{align}
Approximate area under the PDF over a one-bin-wide interval $\approx f(x) \times \text{binWidth}$
$\large log\{f(x) \times \text{binWidth} \} = log\{f(x)\} + log\{ \text{binWidth} \} $
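A quick numeric check of this approximation with arbitrary values of $\mu$, $s$, $x$ and bin width: for a narrow bin, the exact CDF difference and $f(x) \times \text{binWidth}$ agree closely.

import torch
import torch.nn.functional as F

mu, s = torch.tensor(0.0), torch.tensor(1.0)          # hypothetical distribution parameters
x, bin_width = torch.tensor(2.0), torch.tensor(0.01)  # hypothetical data point and bin width
eps = bin_width / 2

exact = torch.sigmoid((x + eps - mu) / s) - torch.sigmoid((x - eps - mu) / s)
z = -(x - mu) / s
log_pdf = z - torch.log(s) - 2 * F.softplus(z)        # log f(x) from the derivation above
approx = torch.exp(log_pdf + torch.log(bin_width))    # f(x) * binWidth
print(exact, approx)                                  # both approximately 1.05e-3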
Probability of $X$ taking values in the right-edge interval $[\,x - \epsilon,\; +\infty\,)$, where $\epsilon \gt 0$ is half of the bin width:
\begin{align}
P_{\text{right edge}} & = P[\;(x - \epsilon) \le X \lt +\infty\;] \\
& = P[\;X \lt +\infty\;] - P[\;X \le (x - \epsilon)\;] \\
& = \sigma\left( \frac{+\infty - \mu}{s} \right) - \sigma\left(\frac{x - \epsilon - \mu}{s}\right) \\
& = 1 - \sigma\left(\frac{x - \epsilon - \mu}{s}\right) \\[1.5ex]
log\{P_{\text{right edge}}\} & = log\left\{1 - \sigma\left(\frac{x - \epsilon - \mu}{s}\right)\right\} \\
& = log\left\{1 - \frac{e^{\frac{x - \epsilon - \mu}{s}}}{1 + e^{\frac{x - \epsilon - \mu}{s}}}\right\} \\
& = log\left\{\frac{1}{1 + e^{\frac{x - \epsilon - \mu}{s}}}\right\} \\
& = -log\left\{1 + e^{\frac{x - \epsilon - \mu}{s}}\right\} \\
& = -softplus\left\{\frac{x - \epsilon - \mu}{s}\right\}
\end{align}
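A numerical sanity check (with arbitrary arguments $z = \frac{x - \epsilon - \mu}{s}$) that the stable form $-softplus(z)$ matches $log\{1 - \sigma(z)\}$ and avoids $-\infty$ for large $z$:

import torch
import torch.nn.functional as F

z = torch.tensor([-3.0, 0.0, 3.0, 30.0])   # arbitrary test arguments
print(torch.log(1 - torch.sigmoid(z)))     # direct form: hits log(0) = -inf for large z
print(-F.softplus(z))                      # stable right-edge form derived above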
Probability of $X$ taking values in the left-edge interval $(\,-\infty,\; x + \epsilon\,]$, where $\epsilon \gt 0$ is half of the bin width:
\begin{align}
P_{\text{left edge}} & = P[\;-\infty \lt X \le (x + \epsilon)\;] \\
& = P[\;X \le (x + \epsilon)\;] - P[\;X \lt -\infty\;] \\
& = \sigma\left(\frac{x + \epsilon - \mu}{s}\right) - 0 \\
& = \sigma\left(\frac{x + \epsilon - \mu}{s}\right) \\[1.5ex]
log\{P_{\text{left edge}}\} & = log\left\{\sigma\left(\frac{x + \epsilon - \mu}{s}\right)\right\} \\
& = log\left\{\frac{e^{\frac{x + \epsilon - \mu}{s}}}{1 + e^{\frac{x + \epsilon - \mu}{s}}}\right\} \\
& = log\left\{e^{\frac{x + \epsilon - \mu}{s}}\right\} - log\left\{1 + e^{\frac{x + \epsilon - \mu}{s}}\right\} \\
& = \frac{x + \epsilon - \mu}{s} - log\left\{1 + e^{\frac{x + \epsilon - \mu}{s}}\right\} \\
& = \frac{x + \epsilon - \mu}{s} - softplus\left\{\frac{x + \epsilon - \mu}{s}\right\}
\end{align}
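Similarly, a sanity check (with arbitrary arguments $z = \frac{x + \epsilon - \mu}{s}$) that $z - softplus(z)$ matches $log\{\sigma(z)\}$ and avoids $-\infty$ for very negative $z$:

import torch
import torch.nn.functional as F

z = torch.tensor([-200.0, -3.0, 0.0, 3.0])  # arbitrary test arguments
print(torch.log(torch.sigmoid(z)))          # direct form: hits log(0) = -inf for very negative z
print(z - F.softplus(z))                    # stable left-edge form derived above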
import torch


def log_sum_exp(x, dim):
    ''' Numerically stable log_sum_exp(x)
        LSE = log_sum_exponent
        LSE(x1, x2, ... , xn) = log( exp(x1) + exp(x2) + ... + exp(xn) )
        LSE(x1, x2, ... , xn) = x* + log( exp(x1-x*) + exp(x2-x*) + ... + exp(xn-x*) )
        where x* = max{x1, x2, ... , xn}
    '''
    x_star, idx = torch.max(x, dim=dim, keepdim=True)    # x* = max along 'dim'
    centered_x = x - x_star                              # shift so the largest entry is 0
    exp_centered_x = torch.exp(centered_x)               # exponentials can no longer overflow
    log_sum_exp_centered_x = torch.log(torch.sum(exp_centered_x, dim=dim, keepdim=True))
    lse = x_star + log_sum_exp_centered_x                # add the shift back
    return lse
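Usage sketch: on a random tensor, the result should match PyTorch's built-in torch.logsumexp (used here only as a reference):

x_test = torch.randn(4, 3, 5)
print(torch.allclose(log_sum_exp(x_test, dim=2),
                     torch.logsumexp(x_test, dim=2, keepdim=True)))   # expected: True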
def log_softmax(x, dim):
    ''' Numerically stable log(softmax(x))
        x = [x1, x2, ... , xn ]
        softmax(x) = [ exp(x_1)/sum(exp(x)) , exp(x_2)/sum(exp(x)) , ... , exp(x_n)/sum(exp(x)) ]
        Note: Element-wise take natural log
        log( softmax(x) ) = log( [ exp(x_1)/sum(exp(x)) , exp(x_2)/sum(exp(x)) , ... , exp(x_n)/sum(exp(x)) ] )
                          = [ x1 - LSE(x) , x2 - LSE(x) , ... , xn - LSE(x) ]
    '''
    lse_x = log_sum_exp(x, dim=dim)    # LSE(x), with 'dim' kept so broadcasting works
    log_softmax_x = x - lse_x          # log(softmax(x)) = x - LSE(x), element-wise
    return log_softmax_x
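Usage sketch: comparison against torch.nn.functional.log_softmax on a random tensor (reference only):

import torch.nn.functional as F

x_test = torch.randn(4, 3, 5)
print(torch.allclose(log_softmax(x_test, dim=2),
                     F.log_softmax(x_test, dim=2)))   # expected: True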
import numpy as np
import torch.nn.functional as F


def logistic_mixture_loss(model_output, target_data, target_data_range, n_bins):
    '''
    Input Arguments:
    -------------------------------------------------------------------------------------
    1) model_output: parameters of the mixture PDFs, i.e. means, log_scales, logit_weights
                     (output of the model after a forward pass) : shape ( batch_size , n_target_nodes , n_mixtures , 3 )
    2) target_data: target data used at the end of the network for the loss calculation : shape ( batch_size , 1 , n_target_nodes )
    3) target_data_range: max(target_data) - min(target_data)
    4) n_bins: number of bins

    Returns:
    -------------------------------------------------------------------------------------
    1) log-loss
    '''
    batch_size, n_target_nodes, n_mixtures, _ = model_output.shape
    # print("Batch Size : ", batch_size)
    # print("N_mixtures : ", n_mixtures)

    # Extract each of the mixture parameters
    # model_output[ BatchSize , N_TargetNodes , N_Mixtures , 3 ]
    m          = model_output[:, :, :, 0]    # mixture means
    log_scales = model_output[:, :, :, 1]    # mixture log_scales
    w_logits   = model_output[:, :, :, 2]    # mixture raw weights, i.e. logit_weights

    # Impose a lower bound on log_scales. The corresponding inv_s = e^(-log(s)) = e^(-(-7)) = 1096.63,
    # hence the minimum value of s that can be obtained is 1/1096.63 = 0.0009118
    log_s = torch.clamp(log_scales, min=-7)
    inv_s = torch.exp(-log_s)                # e^(-log(s)) = 1/s

    # w_logits must be converted to valid weights that sum to one using softmax;
    # here we directly take the log of the normalized weights
    log_w = log_softmax(w_logits, dim=2)

    x = target_data.reshape((batch_size, n_target_nodes, -1))   # shape ( batch_size , n_target_nodes , 1 )
    # print("model_output shape :", model_output.shape)
    # print("Means shape        :", m.shape)
    # print("Log_Scale shape    :", log_scales.shape)
    # print("Logit_weight shape :", w_logits.shape)
    # print("Target Data Shape  :", x.shape)
    # print("Target Data        :", x)

    # There are 'n_bins' bins in total; bin indices range from 0 to n_bins-1.
    # range_of_train_data = max(train_data) - min(train_data)
    # When the support of train_data is divided into 'n_bins' equally spaced bins,
    # the width of one bin on the support of train_data is:
    bin_width = target_data_range / n_bins
    half_bin_width = bin_width / 2.0

    # right and left edges of the bin around data point x
    bin_arg_plus  = (x + half_bin_width - m) * inv_s   # shape ( batch_size , n_target_nodes , n_mixtures )
    bin_arg_minus = (x - half_bin_width - m) * inv_s   # shape ( batch_size , n_target_nodes , n_mixtures )

    # see documentation in the markdown cell for the derivation
    Px_i_mid_case = torch.sigmoid(bin_arg_plus) - torch.sigmoid(bin_arg_minus)

    # approximation of Px_i_mid_case when Px_i_mid_case << 1e-5
    # (see documentation in the markdown cell for the derivation and the reason)
    log_pdf = -(x - m) * inv_s - log_s - 2 * F.softplus(-(x - m) * inv_s)

    # approximate area under the one-bin-wide PDF curve = pdf * bin_width
    # log(pdf * bin_width) = log(pdf) + log(bin_width)
    log_Px_i_mid_case_approximate = log_pdf + np.log(bin_width)

    # see documentation in the markdown cells for the following lines
    log_Px_i = torch.where(Px_i_mid_case > 1e-5,
                           torch.log(torch.clamp(Px_i_mid_case, min=1e-12)),
                           log_Px_i_mid_case_approximate)

    alpha_i = log_w + log_Px_i                  # alpha_i = log(w_i) + log(P_x^i)
    log_Px_k = log_sum_exp(alpha_i, dim=2)      # log(P_x) = LSE(alpha_1, ..., alpha_N)
    loss = -torch.sum(log_Px_k)                 # negative log-likelihood over the batch
    return loss
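A usage sketch with random stand-in tensors; the shapes follow the docstring above, and the batch size, number of mixtures, data range and number of bins are all made-up values for illustration:

batch_size, n_target_nodes, n_mixtures = 8, 1, 5
model_output = torch.randn(batch_size, n_target_nodes, n_mixtures, 3, requires_grad=True)
target_data  = torch.empty(batch_size, 1, n_target_nodes).uniform_(-5, 5)

loss = logistic_mixture_loss(model_output, target_data, target_data_range=10.0, n_bins=256)
loss.backward()          # scalar loss, so gradients flow back to the mixture parameters
print(loss.item())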