Logistic Mixture Loss

Maximization of the Area under the PDF Curve





$$X \;\;\sim\;\; \text{logistic}(\mu, s) $$$$P[ X \le x ] \;\;=\;\; \text{sigmoid}( \frac{x-\mu}{s} ) \;\;=\;\; \sigma(\frac{x-\mu}{s})$$
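The CDF above is simply a sigmoid evaluated at the standardized value $(x-\mu)/s$, which can be checked numerically. Below is a minimal sketch; the parameter values are arbitrary and SciPy is assumed to be available purely for the cross-check.

In [ ]:
import torch
from scipy.stats import logistic

mu, s = 1.5, 0.7                               # arbitrary example parameters
x = torch.linspace(-3.0, 3.0, steps=7)

cdf_sigmoid = torch.sigmoid((x - mu) / s)      # sigma((x - mu)/s)
cdf_scipy   = torch.from_numpy(logistic.cdf(x.numpy(), loc=mu, scale=s))

print(torch.allclose(cdf_sigmoid.double(), cdf_scipy))   # True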



Normal/Mid Case { the current value of data point $x$ lies somewhere between the max and min of the entire train_data }

Probability of $X$ taking values in the interval $\;\;[ x - \epsilon \;\;,\;\; x + \epsilon ]$, where $\epsilon \gt 0$ is half of the bin width



\begin{align} P_{\text{mid}} & = P[\;\;(x - \epsilon)\;\; \le \;\;X\;\; \le \;\;(x + \epsilon)\;\; ] \\ & = P[\;\;X\;\;\le\;\;(x + \epsilon)\;\;]\;\; - \;\;P[\;\;X\;\;\le \;\;(x - \epsilon)\;\; ] \\ & = \sigma( \frac{x + \epsilon-\mu}{s} ) - \sigma(\frac{x - \epsilon-\mu}{s} ) \end{align}
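A minimal numeric sketch of this bin probability, using arbitrary example values for $\mu$, $s$, $x$ and the half bin width $\epsilon$:

In [ ]:
import torch

mu, s  = 0.0, 1.0        # arbitrary example parameters
x, eps = 0.3, 0.05       # data point and half bin width (illustrative)

# P_mid = sigma((x + eps - mu)/s) - sigma((x - eps - mu)/s)
p_mid = torch.sigmoid(torch.tensor((x + eps - mu) / s)) \
        - torch.sigmoid(torch.tensor((x - eps - mu) / s))

print(p_mid)             # probability mass assigned to the bin around x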





We assume there are N components in the logistic mixture. If $\mu_i, s_i$ are the parameters of the $i^{th}$ component, the probability of $X$ taking values in the interval $[ x - \epsilon \;\;,\;\; x + \epsilon ]$ under the $i^{th}$ mixture component is given by:

$$ P_{x}^{i} \;=\; P[\;\;(x - \epsilon)\;\; \le \;\;X\;\; \le \;\;(x + \epsilon)\;\; ]_{i^{th}\; \text{mixture component}} \;=\; \sigma( \frac{x + \epsilon - \mu_i}{s_i}) \;\;-\;\; \sigma( \frac{x - \epsilon - \mu_i}{s_i})$$

Collectively, the N mixture components assign the following probability to the interval/bin $[ x - \epsilon \;\;,\;\; x + \epsilon ]$:

$$ P_x \;=\; P[\;\;(x - \epsilon)\;\; \le \;\;X\;\; \le \;\;(x + \epsilon)\;\; ]_{\text{all mixture components}} \;=\; \sum_{i=1}^{N} w_{i}P_{x}^{i} $$

where N is the total number of mixture components and $w_i$ is the weight of the $i^{th}$ mixture component $(w_i\ge0~ \forall~i ~~ \text{and}~~\sum_{i=1}^{N} w_i = 1 )$
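A toy sketch of the mixture bin probability; the number of components, the weights and the component parameters below are made up purely for illustration:

In [ ]:
import torch

# toy mixture with N = 3 components (illustrative values only)
mu = torch.tensor([-1.0, 0.0, 2.0])                         # component means
s  = torch.tensor([ 0.5, 1.0, 0.8])                         # component scales
w  = torch.softmax(torch.tensor([0.2, 1.0, -0.5]), dim=0)   # weights: non-negative, sum to 1

x, eps = 0.3, 0.05                                          # data point and half bin width

# per-component bin probabilities P_x^i
p_i = torch.sigmoid((x + eps - mu) / s) - torch.sigmoid((x - eps - mu) / s)

# mixture bin probability P_x = sum_i w_i * P_x^i
p_x = torch.sum(w * p_i)
print(p_x)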



$$ loss \;=\; - \sum_{k=1}^{K} log( P_{x_k} ) $$

where K is the total number of data samples and $log$ denotes the natural logarithm





Numerically Stable Log-Probability

\begin{align*} P_x & = \sum_{i=1}^{N} w_{i}P_{x}^{i} \\ & = \sum_{i=1}^{N} \; e^{ log\{\;w_{i}\;.\;P_{x}^{i}\;\} } \\ & = \sum_{i=1}^{N} \; e^{ log\{\;w_{i}\;\}\;+\;log\{\;P_{x}^{i}\;\} } \\ & = \sum_{i=1}^{N} \; e^{\alpha_i}\\\\ log(P_x) & = log(\;\sum_{i=1}^{N} \; e^{\alpha_i}\;)\\ & = log(\;e^{\alpha_1}+e^{\alpha_2}+\ldots+e^{\alpha_N}\;)\\ & = \text{log sum exponent}\;( \alpha_1 , \alpha_2, \ldots, \alpha_N)\\ & = \text{LSE}\;( \alpha_1 , \alpha_2, \ldots, \alpha_N)\\\\\\\\ loss & = - \sum_{k=1}^{K} log( P_{x_{k}} ) \end{align*}

where K is the total number of data samples and $log$ denotes the natural logarithm
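Reusing the toy mixture from above (values again purely illustrative), we can check that computing $log(P_x)$ via log-sum-exp of $\alpha_i = log(w_i) + log(P_x^i)$ matches the log of the plain weighted sum, here using PyTorch's built-in torch.logsumexp:

In [ ]:
import torch

mu = torch.tensor([-1.0, 0.0, 2.0])
s  = torch.tensor([ 0.5, 1.0, 0.8])
w  = torch.softmax(torch.tensor([0.2, 1.0, -0.5]), dim=0)
x, eps = 0.3, 0.05

p_i     = torch.sigmoid((x + eps - mu) / s) - torch.sigmoid((x - eps - mu) / s)
alpha_i = torch.log(w) + torch.log(p_i)            # alpha_i = log(w_i) + log(P_x^i)

log_px_lse    = torch.logsumexp(alpha_i, dim=0)    # LSE(alpha_1, ..., alpha_N)
log_px_direct = torch.log(torch.sum(w * p_i))      # log of the plain sum

print(torch.allclose(log_px_lse, log_px_direct))   # True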


Approximation of Normal/Mid Case

This approximation handles the extreme case when 'prob_mid_case' << 1e-5.

This scenario can happen when the network outputs a mean $\mu$ that is far away from the data point $x$,
e.g. the network outputs $\mu$=1000 and s=1 while $max$(train_data) = 5 and $min$(train_data) = -5.

Under these distribution parameters ($\mu$=1000, s=1), an extremely low probability would be assigned to the interval/bin $[ x - \epsilon \;\;,\;\; x + \epsilon ]$.
When we take the log of this extremely low probability, NaNs or -Infinity will occur.

To solve this problem, we approximate the integral (the area under the PDF over the bin) by evaluating the
logistic PDF at the data point $x$ and multiplying it by one bin width on the support of train_data.



\begin{align} f(x) & = \frac{ e^{-(x-\mu)/s} }{ s(1 + e^{-(x-\mu)/s} )^2 } \\ log\{f(x)\} & = -(x-\mu)/s - log(s) - log\{ \;(1 + e^{-(x-\mu)/s} )^2\; \} \\ log\{f(x)\} & = -(x-\mu)/s - log(s) - 2log\{ ( 1 + e^{-(x-\mu)/s} ) \} \\ log\{f(x)\} & = -(x-\mu)/s - log(s) - 2softplus\{ -(x-\mu)/s \} \end{align}


Approximate area under the PDF over one bin-wide interval $\;\approx\; f(x) \times \text{binWidth}$

$\large log\{f(x) \times \text{binWidth} \} = log\{f(x)\} + log\{ \text{binWidth} \} $
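A short sketch comparing this approximation against the exact sigmoid difference for a narrow bin; the parameters and the bin width are arbitrary example values, and F.softplus is torch.nn.functional.softplus:

In [ ]:
import numpy as np
import torch
import torch.nn.functional as F

mu, s = 0.0, 1.0                     # illustrative parameters
x = torch.tensor(0.3)                # data point
bin_width = 0.01
eps = bin_width / 2.0

z = (x - mu) / s

# log f(x) = -(x - mu)/s - log(s) - 2*softplus( -(x - mu)/s )
log_pdf = -z - np.log(s) - 2.0 * F.softplus(-z)

# approximate log bin probability  vs.  exact sigmoid difference
log_p_approx = log_pdf + np.log(bin_width)
log_p_exact  = torch.log(torch.sigmoid((x + eps - mu) / s)
                         - torch.sigmoid((x - eps - mu) / s))

print(log_p_approx, log_p_exact)     # nearly identical for a narrow bin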





Note: the {Right edge and Left edge} cases below are not used in the implementation; they are included for reference only.


Right edge case { the current value of data point $x$ is near the max of the entire train_data }

Probability of $X$ taking values in the interval $\;\;[ x - \epsilon \;\;,\;\; +\infty )$, where $\epsilon \gt 0$ is half of the bin width

$\large\;\; P_{\text{right_edge}} = P[\;\;(x - \epsilon)\;\; \le \;\;X\;\; \lt \;\; +\infty \;\; ]$

$\large = P[\;\;X\;\;\lt\;\;+\infty\;\;]\;\; - \;\;P[\;\;X\;\;\le \;\;(x - \epsilon)\;\; ]$

$\large = \sigma( \frac{+\infty -\mu}{s} ) - \sigma(\frac{x - \epsilon-\mu}{s} )$

$\large = 1 - \sigma(\frac{x - \epsilon-\mu}{s} )$



$\large\;\; log\{P_{\text{right_edge}}\} $

$\large\;\; = log\{1 - \sigma(\frac{x - \epsilon-\mu}{s}) \} $

$\large\;\; = log\{1 - \frac{e^{\frac{x - \epsilon-\mu}{s}}}{1+e^{\frac{x - \epsilon-\mu}{s}}} \} $

$\large\;\; = log\{\frac{1}{1+e^{\frac{x - \epsilon-\mu}{s}}} \} $

$\large\;\; = -log\{1+e^{\frac{x - \epsilon-\mu}{s}} \} $

$\large\;\; = -softplus\{\frac{x - \epsilon-\mu}{s} \} $
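Although this case is not used in the loss below, the identity is easy to sanity-check; the values are made up, with $x$ close to an assumed $max$(train_data) of 5:

In [ ]:
import torch
import torch.nn.functional as F

mu, s  = 0.0, 1.0
x, eps = 4.9, 0.05                    # illustrative: x near max(train_data)

z_minus = torch.tensor((x - eps - mu) / s)

log_p_direct   = torch.log(1.0 - torch.sigmoid(z_minus))
log_p_softplus = -F.softplus(z_minus)

print(log_p_direct, log_p_softplus)   # agree; the softplus form stays stable for large z_minus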




Left edge case { the current value of data point $x$ is near the min of the entire train_data }

Probability of $X$ taking values in the interval $\;\;( -\infty\;\;,\;\; x + \epsilon ]$, where $\epsilon \gt 0$ is half of the bin width

$\large\;\; P_{\text{left_edge}} = P[\;\;-\infty\;\; \lt \;\;X\;\; \le \;\; (x + \epsilon) \;\; ]$

$\large = \;\;P[\;\;X\;\;\le \;\;(x + \epsilon)\;\; ] - P[\;\;X\;\;\lt\;\;-\infty\;\;]\;\; $

$\large = \sigma(\frac{x + \epsilon-\mu}{s}) - 0 $

$\large = \sigma(\frac{x + \epsilon-\mu}{s}) $



$\large\;\; log\{P_{\text{left_edge}}\} $

$\large\;\; = log\{\sigma(\frac{x + \epsilon-\mu}{s}) \} $

$\large\;\; = log\{\frac{e^{\frac{x + \epsilon-\mu}{s}}}{1+e^{\frac{x + \epsilon-\mu}{s}}} \} $

$\large\;\; = log\{e^{\frac{x + \epsilon-\mu}{s}}\} - log\{1+e^{\frac{x + \epsilon-\mu}{s}} \} $

$\large\;\; = \frac{x + \epsilon-\mu}{s}-log\{1+e^{\frac{x + \epsilon-\mu}{s}} \} $

$\large\;\; = \frac{x + \epsilon-\mu}{s} -softplus\{\frac{x + \epsilon-\mu}{s} \} $
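Likewise unused below; the left-edge expression is just the log-sigmoid, so it can also be checked against PyTorch's built-in F.logsigmoid (made-up values again):

In [ ]:
import torch
import torch.nn.functional as F

mu, s  = 0.0, 1.0
x, eps = -4.9, 0.05                   # illustrative: x near min(train_data)

z_plus = torch.tensor((x + eps - mu) / s)

log_p_direct   = torch.log(torch.sigmoid(z_plus))
log_p_softplus = z_plus - F.softplus(z_plus)
log_p_builtin  = F.logsigmoid(z_plus)

print(log_p_direct, log_p_softplus, log_p_builtin)   # all three agree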



In [ ]:
import numpy as np
import torch
import torch.nn.functional as F


def log_sum_exp(x, dim):
    
    ''' Numerically stable log_sum_exp(x) 
    
    LSE = log_sum_exponent
    
    LSE(x1, x2, ... , xn) = log(  exp(x1) + exp(x2) + ... + exp(xn)  )
    LSE(x1, x2, ... , xn) = x* + log(  exp(x1-x*) + exp(x2-x*) + ... + exp(xn-x*)  )
    
    where x*  = max{x1, x2, ... , xn}
    
    '''
    
    x_star, _ = torch.max(x, dim=dim, keepdim=True)
    
    centered_x = x - x_star
    
    exp_centered_x = torch.exp(centered_x)
    
    log_sum_exp_centered_x = torch.log( torch.sum( exp_centered_x , dim=dim, keepdim=True) )
    
    lse = x_star + log_sum_exp_centered_x
    
  
    return lse



def log_softmax(x, dim):
    
    ''' Numerically stable log(softmax(x)) 
    
    x = [x1, x2, ... , xn ]
    
    softmax(x) = [ exp(x_1)/sum(exp(x)) , exp(x_2)/sum(exp(x)) , ... , exp(x_n)/sum(exp(x)) ]
    
    Note: Element-wise take natural log
    
    log( softmax(x) ) = log( [ exp(x_1)/sum(exp(x)) , exp(x_2)/sum(exp(x)) , ... , exp(x_n)/sum(exp(x)) ] )
                      =      [ x1 - LSE(x)   ,   x2 - LSE(x)  ,  ...  ,  xn - LSE(x)]
    
    '''
    
    lse_x = log_sum_exp(x, dim=dim)
    
    log_softmax_x = x - lse_x

    return log_softmax_x
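
A quick sanity check, on a random input of arbitrary shape, that the two helpers above agree with PyTorch's built-in torch.logsumexp and F.log_softmax:

In [ ]:
import torch
import torch.nn.functional as F

x = torch.randn(4, 6)

print(torch.allclose(log_sum_exp(x, dim=1), torch.logsumexp(x, dim=1, keepdim=True)))   # True
print(torch.allclose(log_softmax(x, dim=1), F.log_softmax(x, dim=1)))                   # True
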
In [2]:
def logistic_mixture_loss(model_output, target_data,  target_data_range, n_bins):
        
    ''' 
     
    Inputs Arguments:
    -------------------------------------------------------------------------------------
    1) model_output:        parameters of PDFs i.e. mean, log_scales, logit_weights (output of model after forward pass of input value) :  shape ( batch_size, n_target_nodes , n_mixture , 3 )
    2) target_data:         target data to be used at the end of the Network for loss calculation  :  shape ( batch_size , 1 , n_target_nodes )
    3) target_data_range:   max(target_data) -  min(target_data)
    4) n_bins:              number of bins


    Returns:
    -------------------------------------------------------------------------------------
    1) log-loss

    '''
    
    
    batch_size , n_target_nodes, n_mixtures , _  = model_output.shape   
    
    #print("Batch Size : " , batch_size)
    #print("N_mixtures : " , n_mixtures)
    
    
    # Extract out each of the mixture parameters 
    # model_output[ BatchSize , N_TargetNodes  , N_Mixtures  , 3]
    
    
    m             = model_output[ :, :, :, 0]    # mixture means
    log_scales    = model_output[ :, :, :, 1]    # mixture log_scales
    w_logits      = model_output[ :, :, :, 2]    # mixture raw weights, or logit_weights
    
    
    
    log_s = torch.clamp(log_scales , min=-7)    # imposing constraint on log_scales values, corresponding inv_s = e^(-log(s)) 
                                                #                                                               = e^(-(-7))
                                                #                                                               = 1096.63
                                                # Hence, minimum value of s that can be obtained is (1/1096.63) = 0.0009118
    
    
    inv_s = torch.exp(-log_s)                   # e^(-log(s)) = 1/s
    
    
    
    log_w = log_softmax(w_logits, dim=2)        # w_logits must be converted to valid weights that sum to one via softmax; then take the log of the normalized weights
    
    
    
    x = target_data.reshape((batch_size , n_target_nodes , -1))            # shape ( batch_size , n_target_nodes , 1 )
    
    
    #print("model_output shape :" , model_output.shape)
    #print("Means shape        :" , m.shape)
    #print("Log_Scale shape    :" , log_scales.shape)
    #print("Logit_weight shape :" , w_logits.shape)
    #print("Target Data Shape  :" , x.shape)
    #print("Target Data        :" , x)    
    
    
    # There are total 'n_bins' number of bins. Bin index ranges from '0'  to  'n_bins-1'
    # 
    # range_of_train_data  =  max(train_data) - min(train_data)
    # When support of train_data is divided into equally spaced 'n_bins', then
    # width of one bin on the support of train_data is as follows ;
    
    bin_width      = target_data_range/(n_bins) 
    half_bin_width = bin_width/2.0
          
     
    
    # right and left points of bin around data point x 
    bin_arg_plus  =  ( x + half_bin_width - m )*inv_s   # shape ( batch_size , n_target_nodes, n_mixture )
    bin_arg_minus =  ( x - half_bin_width - m )*inv_s   # shape ( batch_size , n_target_nodes, n_mixture )
    
    
    
    
    # see documentation in markdown cell for derivation
    
    Px_i_mid_case  = torch.sigmoid(bin_arg_plus) - torch.sigmoid(bin_arg_minus) 
       
      
    
    # approximation of Px_i_mid_case when ( prob_mid_case << 1e-5 )   
    # see documentation in markdown cell for derivation and reason 
    log_pdf = -(x-m)*inv_s - log_s - 2*F.softplus( -(x-m)*inv_s  )
      
    # approximate area under one-bin wide PDF curve = pdf*bin_width 
    # log(pdf * bin_width ) = log(pdf) + log( bin_width )    
    log_Px_i_mid_case_approximate = log_pdf + np.log(bin_width) 
    

    
    # see documentation in markdown cells for the following lines
    
    log_Px_i = torch.where( Px_i_mid_case > 1e-5,  torch.log(torch.clamp(Px_i_mid_case, min=1e-12)),                                             
                                                   log_Px_i_mid_case_approximate 
                          )
                          
           
    alpha_i = log_w + log_Px_i 
    
    
    log_Px_k = log_sum_exp(alpha_i, dim=2) 
    
    
    loss = -torch.sum(log_Px_k)
    
    
    return loss 
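
A minimal usage sketch with random tensors: the shapes follow the docstring above, while the batch size, node and mixture counts, data range and bin count are made-up values, just to check that the function runs and returns a finite, differentiable scalar.

In [ ]:
import torch

batch_size, n_target_nodes, n_mixtures = 8, 4, 5

# fake network output: (batch_size, n_target_nodes, n_mixtures, 3) -> means, log_scales, logit_weights
model_output = torch.randn(batch_size, n_target_nodes, n_mixtures, 3, requires_grad=True)

# fake targets in [-5, 5], shape (batch_size, 1, n_target_nodes)
target_data = torch.empty(batch_size, 1, n_target_nodes).uniform_(-5.0, 5.0)

loss = logistic_mixture_loss(model_output, target_data, target_data_range=10.0, n_bins=256)
loss.backward()

print(loss)   # finite scalar; gradients flow back to model_output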
In [7]:
import os 
cwd = os.getcwd()
os.chdir(cwd)

!jupyter nbconvert LogisticMixtureLossFunction.ipynb
[NbConvertApp] Converting notebook LogisticMixtureLossFunction.ipynb to html
[NbConvertApp] Writing 296794 bytes to LogisticMixtureLossFunction.html
In [ ]: