15 RuleFit

RuleFit is a powerful algorithm for regression and classification, which uses gradient boosting and the LASSO to train a highly accurate and interpretable model.

Given a dataset X and an outcome y:

  1. Train a Gradient Boosting model on the raw inputs X to predict y.
  2. Convert every decision tree base learner from step 1 into a list of rules R by following each path from the root node to a leaf node. Each rule is a conjunction of split conditions, so the rules act as a transformation of the raw input features.
  3. Train a LASSO model on the ruleset R to predict y.
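Step 2 can be sketched on a single toy tree. The sketch below assumes a hypothetical nested-list tree structure (not the format gbm or rtemis use internally) and a hypothetical helper, `tree_to_rules()`, that walks every root-to-leaf path and returns one rule per leaf.

```r
# A toy tree written as nested lists (hypothetical structure, not the format
# used internally by gbm or rtemis). Internal nodes hold a split; the left
# branch means "var <= split" and the right branch means "var > split".
tree <- list(
  var = "PPE", split = 0.18,
  left = list(
    var = "D2", split = 2.03,
    left = list(value = -0.4),
    right = list(value = 0.1)
  ),
  right = list(value = 0.6)
)

# Walk every root-to-leaf path and return one rule (a conjunction of the
# conditions met along the path) per leaf
tree_to_rules <- function(node, conditions = character(0)) {
  if (!is.null(node$value)) {
    # Leaf reached: join the accumulated conditions into a rule string
    return(paste(conditions, collapse = " & "))
  }
  c(
    tree_to_rules(node$left, c(conditions, paste0(node$var, "<=", node$split))),
    tree_to_rules(node$right, c(conditions, paste0(node$var, ">", node$split)))
  )
}

rules <- tree_to_rules(tree)
rules
```

The three leaves give the rules "PPE<=0.18 & D2<=2.03", "PPE<=0.18 & D2>2.03" and "PPE>0.18". Repeating this over all trees of the boosted ensemble and pooling the results gives the full rule list; many rules recur across trees, which is why the log in section 15.2 reports 1200 extracted rules but only 423 unique ones.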

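Thanks to the LASSO’s variable selection, step 3 usually reduces the large number of rules in R to a small subset, often with little or no loss of accuracy (in the example below, 423 unique rules are reduced to 24). RuleFit can even outperform the gradient boosting model it started from.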
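The selection effect of step 3 can be illustrated with a small simulation. In rtemis, s_RuleFit fits this step with glmnet (the 15.2 log shows an elastic net tuned with alpha = 1, i.e. the LASSO). The sketch below is a swapped-in, minimal base-R version: it assumes squared-error loss, uses plain cyclic coordinate descent with soft-thresholding, and runs on simulated binary rule features where only the first two rules carry signal. The helper names `soft_threshold()` and `lasso_cd()` are hypothetical.

```r
# Soft-thresholding operator: closed-form LASSO update for one coefficient
soft_threshold <- function(z, g) sign(z) * pmax(abs(z) - g, 0)

# Minimise (1/(2n)) * ||y - a - X b||^2 + lambda * ||b||_1 by cyclic coordinate descent
lasso_cd <- function(X, y, lambda, n_iter = 200) {
  n <- nrow(X)
  # Centre X and y so the intercept a drops out of the updates
  Xc <- sweep(X, 2, colMeans(X))
  r <- y - mean(y)                # current residual (all coefficients start at 0)
  b <- rep(0, ncol(X))
  col_ss <- colSums(Xc^2) / n     # (1/n) * ||x_j||^2 for each centred column
  for (it in seq_len(n_iter)) {
    for (j in seq_len(ncol(X))) {
      # Correlation of x_j with the partial residual that excludes feature j
      rho <- sum(Xc[, j] * r) / n + col_ss[j] * b[j]
      b_new <- soft_threshold(rho, lambda) / col_ss[j]
      r <- r - Xc[, j] * (b_new - b[j])   # keep the residual in sync
      b[j] <- b_new
    }
  }
  b
}

set.seed(2019)
n <- 200
p <- 20
# Simulated binary rule features: R[i, j] = 1 if case i satisfies rule j
R <- matrix(rbinom(n * p, 1, 0.5), nrow = n, ncol = p)
# Only rules 1 and 2 affect the outcome
y <- 2 * R[, 1] - 1.5 * R[, 2] + rnorm(n, sd = 0.5)

b <- lasso_cd(R, y, lambda = 0.1)
which(b != 0)   # indices of the rules the LASSO keeps
```

The 18 uninformative rules receive coefficients of exactly zero, which is what lets RuleFit discard most of the rule set. A larger lambda removes more rules; s_RuleFit picks lambda by cross-validation (`which.cv.lambda: lambda.1se` in the log).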

[Figure: RuleFit summary]

15.1 Data

Let’s read in the Parkinsons dataset from the UCI repository:

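library(rtemis)
# "status" is 1 for Parkinson's disease and 0 for healthy
parkinsons <- read.csv("https://archive.ics.uci.edu/ml/machine-learning-databases/parkinsons/parkinsons.data")
# Recode the outcome as a factor with "1" as the first level (reported below as the positive class);
# as a new column, Status ends up last, which is the column resample() stratifies on
parkinsons$Status <- factor(parkinsons$status, levels = c(1, 0))
# Drop the original numeric outcome and the recording identifier
parkinsons$status <- NULL
parkinsons$name <- NULL
check_data(parkinsons)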
  parkinsons: A data.table with 195 rows and 23 columns

  Data types
  * 22 numeric features
  * 0 integer features
  * 1 factor, which is not ordered
  * 0 character features
  * 0 date features

  Issues
  * 0 constant features
  * 0 duplicate cases
  * 0 missing values

  Recommendations
  * Everything looks good 

15.1.1 Resample

res <- resample(parkinsons, seed = 2019)
06-30-24 10:57:35 Input contains more than one columns; will stratify on last [resample]
.:Resampling Parameters
    n.resamples: 10 
      resampler: strat.sub 
   stratify.var: y 
        train.p: 0.75 
   strat.n.bins: 4 
06-30-24 10:57:35 Using max n bins possible = 2 [strat.sub]
06-30-24 10:57:35 Created 10 stratified subsamples [resample]

# Train on the first stratified subsample; the remaining cases form the test set
park.train <- parkinsons[res$Subsample_1, ]
park.test <- parkinsons[-res$Subsample_1, ]
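The idea behind a stratified subsample (`resampler: strat.sub`, `train.p: 0.75` in the log above) can be sketched in base R: sample the same fraction of cases within each outcome class. This is a hypothetical helper, not the rtemis implementation, and it will not select the same cases. It assumes the class sizes implied by the confusion matrices in section 15.2 (110 + 37 = 147 cases of class 1, 36 + 12 = 48 of class 0) and rounds each class's share down.

```r
# Hypothetical helper (not rtemis code): draw one stratified subsample by
# sampling a fixed fraction of the cases within each outcome class
strat_subsample <- function(y, train.p = 0.75) {
  idx <- lapply(split(seq_along(y), y), function(i) {
    i[sample.int(length(i), floor(train.p * length(i)))]
  })
  sort(unlist(idx, use.names = FALSE))
}

set.seed(2019)
# Outcome with the same class sizes as the Parkinsons data: 147 cases of 1, 48 of 0
y <- factor(rep(c(1, 0), times = c(147, 48)), levels = c(1, 0))
train <- strat_subsample(y)
test <- setdiff(seq_along(y), train)
table(y[train])
length(test)
```

With these class sizes the sketch keeps 110 and 36 training cases per class (146 in total) and leaves 49 for testing, the same counts reported in the training and testing summaries below; stratifying keeps the class ratio roughly equal in both sets.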

15.2 RuleFit

park.rf <- s_RuleFit(park.train, park.test)
06-30-24 10:57:35 Hello, egenn [s_RuleFit]
06-30-24 10:57:36 Running 1 Gradient Boosting Model... [s_RuleFit]
06-30-24 10:57:36 Hello, egenn [s_GBM]

06-30-24 10:57:36 Imbalanced classes: using Inverse Frequency Weighting [prepare_data]

.:Classification Input Summary
Training features: 146 x 22 
 Training outcome: 146 x 1 
 Testing features: Not available
  Testing outcome: Not available
06-30-24 10:57:36 Distribution set to bernoulli [s_GBM]

06-30-24 10:57:36 Running Gradient Boosting Classification with a bernoulli loss function [s_GBM]

.:Parameters
             n.trees: 300 
   interaction.depth: 3 
           shrinkage: 0.1 
        bag.fraction: 1 
      n.minobsinnode: 5 
             weights: NULL 
06-30-24 10:57:36 Training GBM on full training set... [s_GBM]

.:GBM Classification Training Summary
                   Reference 
        Estimated  1    0   
                1  110   0
                0    0  36

                   Overall  
      Sensitivity  1.0000 
      Specificity  1.0000 
Balanced Accuracy  1.0000 
              PPV  1.0000 
              NPV  1.0000 
               F1  1.0000 
         Accuracy  1.0000 
              AUC  1.0000 
      Brier Score  1.9e-08

  Positive Class:  1 
06-30-24 10:57:36 Completed in 8.8e-04 minutes (Real: 0.05; User: 0.05; System: 4e-03) [s_GBM]
06-30-24 10:57:36 Collecting Gradient Boosting Rules (Trees)... [s_RuleFit]
1200 rules (length<=3) were extracted from the first 300 trees.
06-30-24 10:57:36 Extracted 1200 rules... [s_RuleFit]
06-30-24 10:57:36 ...and kept 423 unique rules [s_RuleFit]
06-30-24 10:57:36 Matching 423 rules to 146 cases... [matchCasesByRules]
06-30-24 10:57:36 Running LASSO on GBM rules... [s_RuleFit]
06-30-24 10:57:36 Hello, egenn

06-30-24 10:57:36 Imbalanced classes: using Inverse Frequency Weighting [prepare_data]

.:Classification Input Summary
Training features: 146 x 423 
 Training outcome: 146 x 1 
 Testing features: Not available
  Testing outcome: Not available

06-30-24 10:57:36 Running grid search... [gridSearchLearn]
.:Resampling Parameters
    n.resamples: 5 
      resampler: kfold 
   stratify.var: y 
   strat.n.bins: 4 
06-30-24 10:57:36 Using max n bins possible = 2 [kfold]
06-30-24 10:57:36 Created 5 independent folds [resample]
.:Search parameters
    grid.params:  
                 alpha: 1 
   fixed.params:  
                             .gs: TRUE 
                 which.cv.lambda: lambda.1se 
06-30-24 10:57:36 Tuning Elastic Net by exhaustive grid search. [gridSearchLearn]
06-30-24 10:57:36 5 inner resamples; 5 models total; running on 8 workers (aarch64-apple-darwin20) [gridSearchLearn]
06-30-24 10:57:37 Extracting best lambda from GLMNET models... [gridSearchLearn]
.:Best parameters to maximize Balanced Accuracy
   best.tune:  
              lambda: 0.0370390609993396 
               alpha: 1 
06-30-24 10:57:37 Completed in 0.01 minutes (Real: 0.62; User: 0.09; System: 0.08) [gridSearchLearn]

.:Parameters
    alpha: 1 
   lambda: 0.0370390609993396 

06-30-24 10:57:37 Training elastic net model...

.:GLMNET Classification Training Summary
                   Reference 
        Estimated  1    0   
                1  110   0
                0    0  36

                   Overall  
      Sensitivity  1.0000 
      Specificity  1.0000 
Balanced Accuracy  1.0000 
              PPV  1.0000 
              NPV  1.0000 
               F1  1.0000 
         Accuracy  1.0000 
              AUC  1.0000 
      Brier Score  0.0106 

  Positive Class:  1 
06-30-24 10:57:37 Completed in 0.01 minutes (Real: 0.78; User: 0.23; System: 0.10)

.:RuleFit Classification Training Summary
                   Reference 
        Estimated  1    0   
                1  110   0
                0    0  36

                   Overall  
      Sensitivity  1.0000 
      Specificity  1.0000 
Balanced Accuracy  1.0000 
              PPV  1.0000 
              NPV  1.0000 
               F1  1.0000 
         Accuracy  1.0000 
              AUC  1.0000 
      Brier Score  0.0106 

  Positive Class:  1 
06-30-24 10:57:37 Matching cases to rules... [predict.rulefit]
06-30-24 10:57:37 Matching 423 rules to 49 cases... [matchCasesByRules]

.:RuleFit Classification Testing Summary
                   Reference 
        Estimated  1   0  
                1  34  3
                0   3  9

                   Overall  
      Sensitivity  0.9189 
      Specificity  0.7500 
Balanced Accuracy  0.8345 
              PPV  0.9189 
              NPV  0.7500 
               F1  0.9189 
         Accuracy  0.8776 
              AUC  0.9212 
      Brier Score  0.0966 

  Positive Class:  1 
06-30-24 10:57:37 Completed in 0.03 minutes (Real: 1.54; User: 0.95; System: 0.12) [s_RuleFit]
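The threshold-based test metrics above follow directly from the test-set confusion matrix. As a check, the base-R snippet below recomputes them from the four counts (34, 3, 3, 9). AUC and Brier score are left out because they require the predicted probabilities, which are not printed here.

```r
# Test-set confusion matrix from the summary above (rows: estimated, columns: reference)
cm <- matrix(c(34, 3, 3, 9), nrow = 2,
             dimnames = list(Estimated = c("1", "0"), Reference = c("1", "0")))

tp <- cm["1", "1"]   # estimated 1, truly 1
fp <- cm["1", "0"]   # estimated 1, truly 0
fn <- cm["0", "1"]   # estimated 0, truly 1
tn <- cm["0", "0"]   # estimated 0, truly 0

sensitivity <- tp / (tp + fn)
specificity <- tn / (tn + fp)
ppv <- tp / (tp + fp)
npv <- tn / (tn + fn)

metrics <- c(
  Sensitivity = sensitivity,
  Specificity = specificity,
  "Balanced Accuracy" = (sensitivity + specificity) / 2,
  PPV = ppv,
  NPV = npv,
  F1 = 2 * ppv * sensitivity / (ppv + sensitivity),
  Accuracy = (tp + tn) / sum(cm)
)
round(metrics, 4)
```

Rounded to four decimals, these reproduce the summary: sensitivity 34/37 = 0.9189, specificity 9/12 = 0.75, balanced accuracy 0.8345 and accuracy 43/49 = 0.8776.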

15.2.1 RuleFit Output

Let’s explore the algorithm output. The selected rules, with their associated coefficients and empirical risk, are stored in park.rf$mod$rules.selected.coef.er.

15.2.2 R-readable rules

We can also access the R-readable rules directly:

park.rf$mod$rules.selected
 [1] "MDVP.Fo.Hz.>133.131 & MDVP.Fhi.Hz.<=204.673 & PPE<=0.184526"      
 [2] "D2>2.0310275 & PPE>0.1345545"                                     
 [3] "D2>2.0459605 & PPE>0.150139"                                      
 [4] "MDVP.Fo.Hz.<=208.8315 & D2>2.224767"                              
 [5] "MDVP.Fhi.Hz.<=205.209 & spread2<=0.193501 & PPE<=0.1643355"       
 [6] "MDVP.Fhi.Hz.>205.209 & spread2<=0.193501"                         
 [7] "MDVP.Flo.Hz.>77.8015 & Shimmer.APQ5>0.006365 & spread1<=-5.618097"
 [8] "Shimmer.APQ3>0.00745 & MDVP.APQ<=0.019775"                        
 [9] "MDVP.Fo.Hz.<=208.8315 & NHR>0.004855 & RPDE>0.339766"             
[10] "DFA>0.6888735 & DFA<=0.7325515 & spread1<=-5.618097"              
[11] "MDVP.Fhi.Hz.<=229.1795 & NHR>0.004855 & RPDE>0.339766"            
[12] "MDVP.Fo.Hz.>117.548 & MDVP.Shimmer.dB.<=0.1875 & PPE>0.1345545"   
[13] "Shimmer.APQ3>0.00745 & PPE<=0.184526"                             
[14] "MDVP.Fo.Hz.>117.548 & MDVP.Shimmer.dB.<=0.1895 & PPE>0.1345545"   
[15] "MDVP.Fhi.Hz.<=229.1795 & MDVP.RAP>0.001885 & RPDE>0.339766"       
[16] "MDVP.Fo.Hz.<=133.131 & HNR<=26.804 & spread1<=-5.618097"          
[17] "MDVP.Fhi.Hz.<=204.673 & MDVP.Jitter...<=0.003315 & PPE<=0.184526" 
[18] "Shimmer.APQ3>0.00843 & Shimmer.APQ3>0.008815 & spread1>-6.4767615"
[19] "MDVP.Fhi.Hz.<=206.449 & MDVP.Jitter...<=0.003315"                 
[20] "MDVP.Fo.Hz.<=118.7465 & MDVP.Jitter...<=0.00449 & D2<=2.0310275"  
[21] "RPDE>0.367432 & PPE<=0.184526"                                    
[22] "MDVP.Jitter...<=0.00449 & D2<=2.0310275"                          
[23] "MDVP.Jitter...>0.00411 & Jitter.DDP<=0.00563"                     
[24] "MDVP.Fo.Hz.>118.3085 & MDVP.APQ<=0.013695 & spread1>-6.512704"    
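Because each rule is a valid R expression over the column names, a case can be matched to a rule by evaluating the rule string against the data. The sketch below assumes a toy data frame with made-up values for two of the Parkinsons features; `match_rules()` is a hypothetical helper and not necessarily how rtemis's `matchCasesByRules()` works.

```r
# Toy cases using two feature names from the Parkinsons data (made-up values)
cases <- data.frame(D2 = c(1.9, 2.5, 2.8, 2.04),
                    PPE = c(0.20, 0.10, 0.30, 0.16))

# Two of the selected rules shown above
rules <- c("D2>2.0310275 & PPE>0.1345545",
           "D2>2.0459605 & PPE>0.150139")

# Evaluate each rule string inside the data frame; the result is a
# cases x rules 0/1 matrix, the kind of input the LASSO step works on
match_rules <- function(rules, data) {
  sapply(rules, function(r) as.integer(eval(parse(text = r), envir = data)))
}

M <- match_rules(rules, cases)
M
```

Only the third case satisfies both rules; the fourth case (D2 = 2.04) satisfies the first rule but falls just below the second rule's D2 threshold.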

15.2.3 Format rules

We can also convert the rules to a more human-readable form. Instead of the split thresholds used in the decision trees, each feature in a rule is described by its median (for continuous features) or mode (for categorical features) and its range among the training cases that satisfy the rule:

rules2medmod(park.rf$mod$rules.selected, park.train)
06-30-24 10:57:37 Matching 24 rules to 146 cases... [matchCasesByRules]
06-30-24 10:57:37 Converting rules... [rules2medmod]
06-30-24 10:57:37 Done [rules2medmod]

 [1] "MDVP.Fo.Hz. = 153.42 (136.93:187.73) & MDVP.Fhi.Hz. = 163.43 (154.61:202.45) & PPE = 0.14 (0.09:0.18)"      
 [2] "D2 = 2.52 (2.03:3.67) & PPE = 0.23 (0.14:0.53)"                                                             
 [3] "D2 = 2.54 (2.06:3.67) & PPE = 0.24 (0.15:0.53)"                                                             
 [4] "MDVP.Fo.Hz. = 151.09 (88.33:208.52) & D2 = 2.55 (2.23:3.67)"                                                
 [5] "MDVP.Fhi.Hz. = 160.27 (113.84:196.54) & spread2 = 0.15 (0.06:0.18) & PPE = 0.14 (0.09:0.16)"                
 [6] "MDVP.Fhi.Hz. = 244.42 (206.90:581.29) & spread2 = 0.13 (0.01:0.19)"                                         
 [7] "MDVP.Flo.Hz. = 133.75 (77.97:239.17) & Shimmer.APQ5 = 0.01 (0.01:0.04) & spread1 = -6.44 (-7.68:-5.62)"     
 [8] "Shimmer.APQ3 = 0.01 (0.01:0.02) & MDVP.APQ = 0.01 (0.01:0.02)"                                              
 [9] "MDVP.Fo.Hz. = 139.17 (88.33:202.80) & NHR = 0.02 (4.9e-03:0.31) & RPDE = 0.54 (0.35:0.69)"                  
[10] "DFA = 0.71 (0.69:0.73) & spread1 = -6.37 (-7.11:-5.63)"                                                     
[11] "MDVP.Fhi.Hz. = 158.06 (102.14:227.38) & NHR = 0.02 (4.9e-03:0.31) & RPDE = 0.55 (0.35:0.69)"                
[12] "MDVP.Fo.Hz. = 155.53 (118.75:217.12) & MDVP.Shimmer.dB. = 0.15 (0.09:0.19) & PPE = 0.18 (0.14:0.25)"        
[13] "Shimmer.APQ3 = 0.01 (0.01:0.03) & PPE = 0.14 (0.07:0.18)"                                                   
[14] "MDVP.Fo.Hz. = 158.22 (118.75:217.12) & MDVP.Shimmer.dB. = 0.15 (0.09:0.19) & PPE = 0.17 (0.14:0.25)"        
[15] "MDVP.Fhi.Hz. = 157.30 (102.14:227.38) & MDVP.RAP = 3.7e-03 (1.9e-03:0.02) & RPDE = 0.55 (0.34:0.69)"        
[16] "MDVP.Fo.Hz. = 117.99 (110.74:129.34) & HNR = 25.28 (17.37:26.55) & spread1 = -6.01 (-7.07:-5.62)"           
[17] "MDVP.Fhi.Hz. = 160.27 (113.84:202.45) & MDVP.Jitter... = 2.9e-03 (1.7e-03:3.3e-03) & PPE = 0.14 (0.09:0.18)"
[18] "Shimmer.APQ3 = 0.02 (0.01:0.06) & spread1 = -5.25 (-6.47:-2.43)"                                            
[19] "MDVP.Fhi.Hz. = 160.27 (113.84:202.45) & MDVP.Jitter... = 2.9e-03 (1.7e-03:3.3e-03)"                         
[20] "MDVP.Fo.Hz. = 115.88 (110.74:117.23) & MDVP.Jitter... = 3.5e-03 (3.3e-03:4.2e-03) & D2 = 1.93 (1.85:2.03)"  
[21] "RPDE = 0.45 (0.37:0.64) & PPE = 0.14 (0.06:0.18)"                                                           
[22] "MDVP.Jitter... = 3.3e-03 (1.8e-03:4.4e-03) & D2 = 1.88 (1.51:2.03)"                                         
[23] "MDVP.Jitter... = 4.4e-03 (4.2e-03:0.01) & Jitter.DDP = 4.7e-03 (3.7e-03:0.01)"                              
[24] "MDVP.Fo.Hz. = 151.80 (118.75:217.12) & MDVP.APQ = 0.01 (0.01:0.01) & spread1 = -5.95 (-6.49:-4.67)"
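The conversion for continuous features can be sketched in base R: keep the cases that satisfy the rule, then report the median and range of each feature the rule mentions. `rule_medrange()` is a hypothetical helper, not the rtemis implementation; it handles only continuous features (no mode for categorical ones) and always prints two decimals, whereas rules2medmod() switches to scientific notation for small values.

```r
# Hypothetical helper (not rtemis code): describe one rule by the median and
# range of each of its features among the cases that satisfy it
rule_medrange <- function(rule, data, digits = 2) {
  # Keep only the cases that satisfy the rule
  matched <- data[eval(parse(text = rule), envir = data), , drop = FALSE]
  # Split the rule into its conditions, e.g. "D2>2.0310275"
  conditions <- strsplit(rule, " & ", fixed = TRUE)[[1]]
  # Feature name = everything before the first comparison operator;
  # unique() merges repeated conditions on the same feature
  features <- unique(sub("(<=|>=|<|>).*$", "", conditions))
  fmt <- function(v) formatC(v, format = "f", digits = digits)
  parts <- vapply(features, function(f) {
    x <- matched[[f]]
    paste0(f, " = ", fmt(median(x)), " (", fmt(min(x)), ":", fmt(max(x)), ")")
  }, character(1))
  paste(parts, collapse = " & ")
}

# Toy cases with made-up values for two Parkinsons features
toy <- data.frame(D2 = c(1.9, 2.5, 2.8, 2.04, 3.1),
                  PPE = c(0.20, 0.10, 0.30, 0.16, 0.25))
out <- rule_medrange("D2>2.0310275 & PPE>0.1345545", toy)
out
```

Here three of the five toy cases satisfy the rule, so the result is "D2 = 2.80 (2.04:3.10) & PPE = 0.25 (0.16:0.30)". As in the output above, each range lies inside the bounds set by the rule's thresholds.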