Why are tree-based machine learning methods, such as XGBoost and random forests, superior to deep learning on tabular data? This article digs into the reasons behind this phenomenon: the researchers selected 45 open datasets, defined a new benchmark to compare tree-based models with deep models, and summarized three findings that explain the gap.
Deep learning has made significant progress on images, language, and even audio, but it performs comparatively poorly on tabular data. Tabular data typically has heterogeneous features, small sample sizes, and extreme values (outliers), which makes it hard to find the kinds of invariances that deep architectures exploit.
Tree-based models are not differentiable and cannot be jointly trained with deep learning modules, so designing deep learning architectures tailored to tabular data has become a very active research area. Many studies claim to outperform or match tree-based models, but this research has been met with considerable skepticism.
In fact, there is a lack of established benchmarks for learning from tabular data, giving researchers a lot of freedom when evaluating their methods. Additionally, most publicly available tabular datasets are small compared to benchmarks in other subdomains of machine learning, making evaluation even more difficult.
To alleviate these concerns, researchers from the French National Institute for Research in Computer Science and Automation (Inria), Sorbonne University, and other institutions proposed a benchmark for tabular data that evaluates the latest deep learning models and shows that tree-based models remain SOTA on medium-sized tabular datasets.
In support of this conclusion, the article provides solid evidence that, on tabular data, it is easier to obtain good predictions with tree-based methods than with deep learning (even modern architectures), and the researchers explore the reasons behind this.
It is worth mentioning that one of the paper's authors is Gaël Varoquaux, a leader of the Scikit-learn project, which has become one of the most popular machine learning libraries on GitHub. The paper "Scikit-learn: Machine learning in Python," co-authored by Gaël Varoquaux, has been cited 58,949 times.
The contributions of this article can be summarized as:
The study created a new benchmark for tabular data (selecting 45 open datasets) and shared these datasets via OpenML, making them easy to use.
The study compared deep learning models and tree-based models under various settings for tabular data, considering the cost of hyperparameter selection. The study also shared the raw results of random searches, allowing researchers to cheaply test new algorithms within a fixed hyperparameter optimization budget.
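To make the budgeted-comparison idea concrete, here is a minimal sketch of random-search tuning under a fixed budget using scikit-learn's RandomizedSearchCV. The synthetic dataset and the search space are illustrative assumptions, not the paper's protocol or its shared random-search results.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import RandomizedSearchCV, train_test_split

# Synthetic stand-in for one medium-sized tabular classification task.
X, y = make_classification(n_samples=5_000, n_features=20, n_informative=8, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Fixed budget of n_iter random hyperparameter draws; the search space is a guess.
search = RandomizedSearchCV(
    GradientBoostingClassifier(random_state=0),
    param_distributions={
        "n_estimators": [100, 300, 1000],
        "learning_rate": list(np.logspace(-3, 0, 10)),
        "max_depth": [2, 3, 5, 8],
    },
    n_iter=20,
    cv=3,
    random_state=0,
)
search.fit(X_train, y_train)
print(search.best_params_, round(search.score(X_test, y_test), 3))
```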
Tree-based models still outperform deep learning methods on tabular data
The new benchmark comprises 45 tabular datasets, selected according to the following criteria:
- Heterogeneous columns: columns correspond to features of different natures, which excludes image and signal datasets.
- Not high-dimensional: only datasets with a d/n ratio below 1/10 are kept.
- Documented datasets: datasets with too little available information are removed.
- I.I.D. (independent and identically distributed) data: stream-like and time-series datasets are removed.
- Real-world data: artificial datasets are removed, but some simulated datasets are kept.
- Not too small: datasets with too few features (< 4) or too few samples (< 3,000) are removed (see the filtering sketch after this list).
- Not too easy: overly simple datasets are removed.
- Not deterministic: datasets from games such as poker and chess are removed, because their targets are deterministic.
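As a rough illustration of the size and dimensionality filters above, the snippet below screens the OpenML dataset catalogue with the openml Python package; it only reproduces the quantitative thresholds from the list, not the authors' full (partly manual) selection.

```python
import openml

# List all OpenML datasets with their meta-features (returns a pandas DataFrame).
candidates = openml.datasets.list_datasets(output_format="dataframe")

keep = candidates[
    (candidates["NumberOfInstances"] >= 3000)                                    # not too few samples
    & (candidates["NumberOfFeatures"] >= 4)                                      # not too few features
    & (candidates["NumberOfFeatures"] / candidates["NumberOfInstances"] < 0.1)   # d/n < 1/10
]
print(len(keep), "datasets pass the size and dimensionality filters")
```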
For the tree-based models, the researchers selected three SOTA models: Scikit-learn's RandomForest, GradientBoostingTrees (GBTs), and XGBoost.
The study benchmarked the following deep models: MLP, ResNet, FT-Transformer, and SAINT.
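The toy sketch below runs such a comparison on a synthetic tabular task, with a plain scikit-learn MLP standing in for the deep models above (the study's ResNet, FT-Transformer, and SAINT are not reproduced here), so the numbers it prints are only illustrative.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from xgboost import XGBClassifier  # assumes the xgboost package is installed

X, y = make_classification(n_samples=10_000, n_features=20, n_informative=8, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

models = {
    "RandomForest": RandomForestClassifier(random_state=0),
    "GradientBoostingTrees": GradientBoostingClassifier(random_state=0),
    "XGBoost": XGBClassifier(random_state=0),
    "MLP": make_pipeline(StandardScaler(), MLPClassifier(max_iter=500, random_state=0)),
}
for name, model in models.items():
    model.fit(X_train, y_train)
    print(f"{name}: test accuracy {model.score(X_test, y_test):.3f}")
```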
Figures 1 and 2 present benchmarking results for different types of datasets.
Empirical Investigation: Why Tree-Based Models Still Outperform Deep Learning on Tabular Data
Inductive bias. Tree-based models outperform neural networks across a wide range of hyperparameter choices. In fact, the best methods for handling tabular data share two properties: they are ensemble methods, either bagging (random forests) or boosting (XGBoost, GBTs), and the weak learners used in these ensembles are decision trees.
Finding 1: Neural Networks (NNs) Tend Toward Overly Smooth Solutions
As shown in Figure 3, smoothing the target function on the training set significantly reduces the accuracy of tree-based models at small smoothing scales, but has almost no effect on NNs. These results suggest that the target functions in these datasets are not smooth, and that NNs struggle to fit such irregular functions compared to tree-based models. This is consistent with the findings of Rahaman et al., who observed that NNs are biased toward low-frequency functions. Tree-based models, which learn piece-wise constant functions, do not exhibit such a bias.
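A toy version of this smoothing experiment is sketched below: each training label is replaced by a Gaussian-kernel-weighted average of all training labels (then re-binarized), and a tree-based model and an MLP are refit on the smoothed targets. The data, length-scales, and kernel details are assumptions for illustration, not the paper's setup.

```python
import numpy as np
from scipy.spatial.distance import cdist
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X, y = make_classification(n_samples=3000, n_features=10, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

def smooth_targets(X, y, length_scale):
    """Gaussian kernel smoother over feature space, re-thresholded to binary labels."""
    weights = np.exp(-cdist(X, X, "sqeuclidean") / (2 * length_scale ** 2))
    return (weights @ y / weights.sum(axis=1) > 0.5).astype(int)

for scale in [0.1, 1.0, 5.0]:
    y_smooth = smooth_targets(X_train, y_train, scale)
    for name, model in [
        ("RandomForest", RandomForestClassifier(random_state=0)),
        ("MLP", make_pipeline(StandardScaler(), MLPClassifier(max_iter=500, random_state=0))),
    ]:
        acc = model.fit(X_train, y_smooth).score(X_test, y_test)
        print(f"length_scale={scale} {name}: {acc:.3f}")
```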
Finding 2: Uninformative Features Affect MLP-like NNs More
Tabular datasets contain many uninformative features. For each dataset, the study drops an increasing proportion of features according to feature importance (ranked by a random forest). As shown in Figure 4, removing more than half of the features has little impact on the classification accuracy of a GBT.
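The sketch below mimics this protocol on synthetic data with known uninformative features: features are ranked by a random forest's importances, the least important fraction is dropped, and a GBT and an MLP are retrained. The thresholds and dataset are illustrative assumptions, not the paper's.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# 30 features, only 5 of them informative.
X, y = make_classification(n_samples=5000, n_features=30, n_informative=5, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Rank features from least to most important with a random forest.
importances = RandomForestClassifier(random_state=0).fit(X_train, y_train).feature_importances_
ranking = np.argsort(importances)

for drop_frac in [0.0, 0.5, 0.8]:
    keep = ranking[int(drop_frac * len(ranking)):]  # keep the most important features
    for name, model in [
        ("GBT", GradientBoostingClassifier(random_state=0)),
        ("MLP", make_pipeline(StandardScaler(), MLPClassifier(max_iter=500, random_state=0))),
    ]:
        acc = model.fit(X_train[:, keep], y_train).score(X_test[:, keep], y_test)
        print(f"dropped {drop_frac:.0%} of features, {name}: {acc:.3f}")
```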
Figure 5 shows that removing uninformative features (5a) narrows the performance gap between MLPs (and ResNets) and the other models (FT-Transformer and tree-based models), while adding uninformative features widens it, indicating that MLPs are less robust to uninformative features. In Figure 5a, as researchers remove a larger proportion of features, some informative features are removed as well. Figure 5b shows that the resulting accuracy drop is compensated by the removal of uninformative features, which helps MLPs more than the other models (the study also removed redundant features, which did not affect model performance).
Finding 3: Data Are Not Invariant Under Rotation
Why are MLPs more susceptible to uninformative features than other models? One answer is that MLPs are rotationally invariant: the procedure of learning an MLP on the training set and evaluating it on the test set is unchanged when the same rotation is applied to the features of both sets. In fact, any rotationally invariant learning procedure has a worst-case sample complexity that grows at least linearly with the number of irrelevant features. Intuitively, to discard useless features, a rotationally invariant algorithm must first recover the original orientation of the features and then select the least informative ones.
Figure 6a shows how test accuracy changes when the datasets are randomly rotated, confirming that only ResNets are rotationally invariant. Notably, random rotation reverses the performance order: NNs now rank above tree-based models, and ResNets above FT-Transformer, suggesting that rotation invariance is an undesirable property. Indeed, the features of tabular data each carry an individual meaning, such as age or weight, and mixing them by rotation destroys that meaning.
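Below is a minimal sketch of the rotation experiment behind Figure 6a: the same random orthogonal matrix is applied to the features of both the training and test sets before fitting, and the accuracy of a tree-based model and an MLP is compared with and without the rotation. The dataset and models are stand-ins, not the paper's benchmark.

```python
from scipy.stats import ortho_group
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X, y = make_classification(n_samples=5000, n_features=20, n_informative=6, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Random orthogonal matrix, applied identically to train and test features.
R = ortho_group.rvs(X.shape[1], random_state=0)

for label, (Xtr, Xte) in {
    "original": (X_train, X_test),
    "rotated": (X_train @ R, X_test @ R),
}.items():
    for name, model in [
        ("GBT", GradientBoostingClassifier(random_state=0)),
        ("MLP", make_pipeline(StandardScaler(), MLPClassifier(max_iter=500, random_state=0))),
    ]:
        acc = model.fit(Xtr, y_train).score(Xte, y_test)
        print(f"{label:8s} {name}: {acc:.3f}")
```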
Figure 6b shows that removing the least important half of the features in each dataset (before rotation) still degrades the performance of all models except ResNets, but the drop is smaller than when all features are kept.
Original link: https://twitter.com/GaelVaroquaux/status/1549422403889
Edit / Fan Ruiqiang
Review / Fan Ruiqiang
Verification / Fan Ruiqiang