How to apply corrective weights when training a logistic regression on an imbalanced dataset using Apache Spark MLlib?

Some applications, such as spam filtering or online targeting, have an imbalanced dataset: the number of observations associated with one label (the minority class) is very small compared to the number of observations associated with the other labels.

**Overview**

Let's consider the case of an intrusion detection system that identifies a cyber attack through a Virtual Private Network (VPN) in an enterprise. The objective is to classify every session, based on one or several TCP connections, as a potentially harmful intrusion or a normal transaction, knowing that a very small fraction of the attempts to connect and authenticate are actual deliberate attacks (fewer than 1 per million connections).

There are a few points to consider:

- The number of reported attacks is very likely to be very small (< 10). Therefore, the very small set of labeled security breaches is plagued with high variance.
- A very large and expensive labeling operation is required, if even available, to generate a fairly stable minority class (security breach).
- The predictor can appear extremely accurate by always classifying any new attempt to connect to the VPN as harmless. In the case of 1 reported attack per 1 million VPN sessions, the prediction would be correct 99.9999% of the time, even though no attack is ever detected. Plain accuracy is therefore misleading here; the F1-score or the area under the ROC or PR curve are more informative measures.
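The last point can be checked with a few lines of arithmetic. The sketch below is only an illustration: the object name `AlwaysHarmlessBaseline` and the counts (1 attack among 1,000,000 sessions) are assumptions, not data from an actual VPN log. It computes accuracy, recall and F1-score for a predictor that labels every session as harmless.

```scala
// Metrics for a predictor that labels every session as harmless.
// Assumption for illustration: 1 attack among 1,000,000 sessions.
object AlwaysHarmlessBaseline {
  // tp/fn count attacks (label = 1), tn/fp count harmless sessions (label = 0)
  final case class Counts(tp: Long, fp: Long, tn: Long, fn: Long)

  def accuracy(c: Counts): Double =
    (c.tp + c.tn).toDouble / (c.tp + c.fp + c.tn + c.fn)

  // Fraction of actual attacks that are detected; 0.0 if there is no attack
  def recall(c: Counts): Double =
    if (c.tp + c.fn == 0L) 0.0 else c.tp.toDouble / (c.tp + c.fn)

  // Fraction of predicted attacks that are real; 0.0 if nothing is flagged
  def precision(c: Counts): Double =
    if (c.tp + c.fp == 0L) 0.0 else c.tp.toDouble / (c.tp + c.fp)

  def f1(c: Counts): Double = {
    val p = precision(c)
    val r = recall(c)
    if (p + r == 0.0) 0.0 else 2.0 * p * r / (p + r)
  }

  // Nothing is ever flagged: the single attack becomes a false negative
  val counts: Counts = Counts(tp = 0L, fp = 0L, tn = 999999L, fn = 1L)

  def main(args: Array[String]): Unit = {
    println(s"accuracy = ${accuracy(counts)}")
    println(s"recall   = ${recall(counts)}")
    println(s"f1       = ${f1(counts)}")
  }
}
```

The accuracy is 999,999 / 1,000,000 while recall and F1-score are both 0, which is why accuracy alone cannot be used to judge a model on such a dataset.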

There are several approaches to address the problem of imbalanced training and validation sets. The most common techniques include:

- Sub-sampling the majority class (i.e. harmless VPN sessions): this reduces the number of labeled normal sessions while preserving the labeled reported attacks.
- Over-sampling the minority class: this technique generates synthetic samples from bootstrapped samples using the k-nearest neighbors algorithm.
- Applying differential weights to the logistic loss used in the training of the model.
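For comparison with the weighting approach, the first technique (sub-sampling the majority class) can be sketched in plain Scala. This is a minimal illustration, not part of the implementation described in this post: the `Session` record, the `subSample` helper and the `keepFraction` parameter are hypothetical names introduced here.

```scala
import scala.util.Random

object MajoritySubSampler {
  // Hypothetical minimal record: label 1.0 = attack, 0.0 = harmless session
  final case class Session(label: Double, features: Vector[Double])

  // Keeps every minority observation (label = 1) and a random fraction
  // of the majority observations (label = 0).
  def subSample(data: Seq[Session], keepFraction: Double, seed: Long): Seq[Session] = {
    require(keepFraction > 0.0 && keepFraction <= 1.0, "keepFraction must be in (0, 1]")
    val rng = new Random(seed)
    data.filter(s => s.label == 1.0 || rng.nextDouble() < keepFraction)
  }

  def main(args: Array[String]): Unit = {
    val harmless = Seq.fill(1000)(Session(0.0, Vector(0.1, 0.2)))
    val attacks = Seq.fill(2)(Session(1.0, Vector(0.9, 0.8)))
    val sampled = subSample(harmless ++ attacks, keepFraction = 0.1, seed = 42L)
    println(s"kept ${sampled.count(_.label == 0.0)} harmless sessions " +
      s"and ${sampled.count(_.label == 1.0)} attacks")
  }
}
```

The drawback is that most of the majority observations are thrown away, which is one reason to prefer re-weighting the loss when the full dataset can be processed.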

It is assumed the reader is somewhat familiar with **Apache Spark**, data frames and its machine learning module **MLlib**.

**Weighting the logistic loss**

Let's consider the logistic loss commonly used in training a binary model *f* with features *x* and label *y*:

\[logLoss = - \sum_{i=0}^{n-1}[y_i \log(f(x_i)) + (1 - y_i) \log(1 - f(x_i))]\]

The first component of the loss function is related to the minority observations (security breach through a VPN: label = 1), while the second component represents the loss related to the majority observations (harmless VPN sessions: label = 0).

Let's consider the ratio *r* of the number of observations related to the majority label over the total number of observations:

\[r = \frac{|\{i : y_i = 0\}|}{n}\]

The logistic loss can then be rebalanced as

\[logLoss = -\sum_{i=0}^{n-1} [r \, y_i \log(f(x_i)) + (1-r) (1-y_i) \log(1 - f(x_i))]\]

The next step is to implement the weighting of the binomial logistic regression classifier using Apache Spark.
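To make the rebalanced loss concrete, here is a small sketch in plain Scala, independent of Spark. The names `WeightedLogLoss`, `majorityRatio` and `weightedLogLoss`, as well as the clipping constant `eps`, are illustrative assumptions. It takes *r* as the fraction of observations with label = 0, so that the rare label = 1 observations receive the larger weight.

```scala
object WeightedLogLoss {
  // Guards against log(0); the value of eps is an illustrative choice
  private val eps = 1e-15
  private def clip(p: Double): Double = math.min(1.0 - eps, math.max(eps, p))

  // r: fraction of observations carrying the majority label (0)
  def majorityRatio(labels: Seq[Double]): Double =
    labels.count(_ == 0.0).toDouble / labels.size

  // Standard (unweighted) logistic loss
  def logLoss(labels: Seq[Double], probs: Seq[Double]): Double = {
    val total = labels.zip(probs).map { case (y, p) =>
      val q = clip(p)
      y * math.log(q) + (1.0 - y) * math.log(1.0 - q)
    }.sum
    -total
  }

  // Rebalanced loss: r weights the label = 1 term, (1 - r) the label = 0 term
  def weightedLogLoss(labels: Seq[Double], probs: Seq[Double], r: Double): Double = {
    val total = labels.zip(probs).map { case (y, p) =>
      val q = clip(p)
      r * y * math.log(q) + (1.0 - r) * (1.0 - y) * math.log(1.0 - q)
    }.sum
    -total
  }

  def main(args: Array[String]): Unit = {
    val labels = Seq(1.0, 0.0, 0.0, 0.0)
    val probs = Seq(0.3, 0.2, 0.1, 0.4) // f(x_i), made-up values
    val r = majorityRatio(labels)
    println(s"r = $r")
    println(s"logLoss = ${logLoss(labels, probs)}")
    println(s"weighted logLoss = ${weightedLogLoss(labels, probs, r)}")
  }
}
```

With a balanced ratio *r* = 0.5, the weighted loss is exactly half of the unweighted loss, so both are minimized by the same model; any other ratio shifts the emphasis between the two classes.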

**Apache Spark** is an open source framework for in-memory processing of large datasets (think of it as *Hadoop on steroids*). The Apache Spark framework contains a machine learning module known as **MLlib**. The objective is to modify/override the **train** method of the **LogisticRegression** class.

One simple option is to sub-class the *LogisticRegression* class in the *mllib/ml* package. However, the logistic loss is actually computed in the private class *LogisticAggregator*, which cannot be overridden.

However, if you browse through the *ml.classification.LogisticRegression.train* Scala code, you notice that the class *Instance*, which encapsulates labeled data for the computation of the loss and the gradient of the loss, has three parameters:

- label: Double
- feature: linalg.Vector
- weight: Double

The plan is to use this 3rd parameter, *weight*, as the balancing weight ratio *r*. This is simply accomplished by adding an extra column **weightCol** to the input data frame *dataset* and defining its value as

- balancingRatio *r* for label = 1
- 1 - balancingRatio (*1 - r*) for label = 0

```scala
import java.util.UUID

import org.apache.spark.Logging
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.lit

object WeightedLogisticRegression {
  final val BalancingWeightColumnName: String = "weightCol"
  final val BalanceRatioBounds = (0.001, 0.999)
  final val ImportanceWeightNormalization = 2.0
}

class WeightedLogisticRegression(balanceRatio: Double = 0.5)
    extends LogisticRegression(UUID.randomUUID.toString) with Logging {
  import WeightedLogisticRegression._

  this.setWeightCol(BalancingWeightColumnName)

  // Constrain the ratio to the acceptable range BalanceRatioBounds
  private[this] val balancingRatio =
    if (balanceRatio < BalanceRatioBounds._1) BalanceRatioBounds._1
    else if (balanceRatio > BalanceRatioBounds._2) BalanceRatioBounds._2
    else balanceRatio

  override protected def train(dataset: DataFrame): LogisticRegressionModel = {
    // Only used to obtain a schema that includes the extra weight column
    val newDataset = dataset.withColumn(BalancingWeightColumnName, lit(0.0))

    // Assumes column 0 holds the label and column 1 the features vector
    val weightedDataset: RDD[Row] = dataset.map(row => {
      val w = if (row.getAs[Double](0) == 1.0) balancingRatio else 1.0 - balancingRatio
      Row(row.getAs[Double](0), row.getAs[Vector](1), ImportanceWeightNormalization * w)
    })

    val weightedDataFrame = dataset.sqlContext
      .createDataFrame(weightedDataset, newDataset.schema)
    super.train(weightedDataFrame)
  }
}
```

__Notes__

- The balancing ratio has to be normalized by a factor *ImportanceWeightNormalization = 2*: the factor is required to produce a weight of 1 for both the majority and minority classes from a fully balanced ratio of 0.5.
- The actual balancing ratio needs to be constrained within an acceptable range, *BalanceRatioBounds*, to prevent the minority class from having an outsized influence on the weights of the logistic regression model. In the extreme case, there may not even be a single observation in the minority class (security breach through the VPN). These minimum and maximum values are highly dependent on the type of application.
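The effect of the two notes above can be seen on a couple of values. The following sketch reuses the constants *BalanceRatioBounds* and *ImportanceWeightNormalization* from the class above; the object name `BalancingWeights` and the helpers `clamp` and `weights` are hypothetical, introduced only for this illustration.

```scala
object BalancingWeights {
  val BalanceRatioBounds: (Double, Double) = (0.001, 0.999)
  val ImportanceWeightNormalization: Double = 2.0

  // Constrains the ratio to [0.001, 0.999]
  def clamp(ratio: Double): Double =
    math.max(BalanceRatioBounds._1, math.min(BalanceRatioBounds._2, ratio))

  // Returns (weight for label = 1, weight for label = 0)
  def weights(balanceRatio: Double): (Double, Double) = {
    val r = clamp(balanceRatio)
    (ImportanceWeightNormalization * r, ImportanceWeightNormalization * (1.0 - r))
  }

  def main(args: Array[String]): Unit = {
    println(weights(0.5))      // fully balanced dataset
    println(weights(0.999999)) // 1 attack per million sessions, before clamping
  }
}
```

A balanced ratio of 0.5 yields a weight of 1 for both classes. A raw ratio of 0.999999 is first clamped to 0.999, which gives a weight of about 1.998 to the attacks and about 0.002 to the harmless sessions.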

Here is an example of application of the *WeightedLogisticRegression* on a training data frame *labeledData*. The number of data points (or observations) associated with *label = 1* is extracted through a simple filter.

```scala
// Number of observations in the minority class (label = 1)
val numPositives = labeledData.filter("label == 1.0").count
val datasetSize = labeledData.count
// Fraction of observations in the majority class (label = 0)
val balanceRatio = (datasetSize - numPositives).toDouble / datasetSize

val lr = new WeightedLogisticRegression(balanceRatio)
  .setMaxIter(20)
  .setElasticNetParam(0)
  .setFitIntercept(true)
  .setTol(1E-6)
  .setRegParam(1E-2)

val model = lr.fit(labeledData)
```

One simple validation of this implementation of the weighted logistic regression on Apache Spark is to verify that the weights of the model generated by the *WeightedLogisticRegression* with a balancing ratio of 0.5 are very similar to the weights generated by the logistic regression model of Apache Spark MLlib (*ml.classification.LogisticRegressionModel*).

Note: This implementation relies on Apache Spark 1.6.0. The latest implementation of the logistic regression in the beta release of Apache Spark 2.0 does not allow overriding *LogisticRegression* outside the scope of Apache Spark MLlib.
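When sub-classing is not possible, the same weighting can probably be obtained without overriding anything, by adding the weight column to the data frame before calling *fit* on a stock *LogisticRegression*. The sketch below is an assumption-laden alternative, not the implementation described above: it relies only on the public *setWeightCol* setter, assumes a numeric column named `label`, and the names `WeightColumnSketch`, `withBalancingWeights` and `weightedEstimator` are hypothetical. The ratio *r* is expected to be already constrained to the acceptable range.

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, lit, when}

object WeightColumnSketch {
  // Adds a per-row weight: 2r for label = 1 and 2(1 - r) for label = 0.
  // Assumes the data frame has a numeric "label" column.
  def withBalancingWeights(labeled: DataFrame, r: Double): DataFrame =
    labeled.withColumn(
      "weightCol",
      when(col("label") === 1.0, lit(2.0 * r)).otherwise(lit(2.0 * (1.0 - r)))
    )

  // A stock LogisticRegression that reads the instance weights from "weightCol"
  def weightedEstimator(): LogisticRegression =
    new LogisticRegression()
      .setWeightCol("weightCol")
      .setMaxIter(20)
      .setRegParam(1E-2)
}
```

The model would then be trained with `WeightColumnSketch.weightedEstimator().fit(WeightColumnSketch.withBalancingWeights(labeledData, balanceRatio))`, provided *labeledData* also contains the features column expected by the estimator.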

**References**