Source: Machine Heart Editorial Team
Why do tree-based machine learning methods such as XGBoost and Random Forest outperform deep learning on tabular data? This article examines the reasons behind this phenomenon: the authors select 45 open datasets, define a new benchmark to compare tree-based models with deep models, and summarize three key findings that explain it.
Deep learning has made significant advances in fields such as images, language, and even audio. However, it performs rather poorly on tabular data. Tabular data has characteristics such as heterogeneous features, small sample sizes, and extreme values, which make it difficult to find suitable invariances.
Tree-based models are non-differentiable and cannot be jointly trained with deep learning modules, making the creation of table-specific deep learning architectures a very active research area. Many studies claim to outperform or match tree-based models, but their findings have been met with skepticism.
In fact, there is a lack of established benchmarks for learning from tabular data, which gives researchers a lot of freedom when evaluating their methods. Additionally, compared to benchmarks in other machine learning subdomains, most publicly available tabular datasets are small, making evaluation more challenging.
To alleviate these concerns, researchers from institutions such as the French National Institute for Research in Computer Science and Automation and Sorbonne University proposed a benchmark for tabular data that can evaluate state-of-the-art deep learning models and show that tree-based models still hold the SOTA on medium-sized tabular datasets.
The article provides solid evidence for this conclusion, indicating that using tree-based methods on tabular data is more likely to yield good predictions than deep learning (even modern architectures), and the researchers explored the reasons behind this.
Paper link: https://hal.archives-ouvertes.fr/hal-03723551/document
It is worth mentioning that one of the authors of the paper is Gaël Varoquaux, who is one of the leaders of the Scikit-learn project. This project has become one of the most popular machine learning libraries on GitHub. The paper “Scikit-learn: Machine learning in Python,” co-authored by Gaël Varoquaux, has been cited 58,949 times.
The contributions of this article can be summarized as follows:
The study created a new benchmark for tabular data (selecting 45 open datasets) and shared these datasets through OpenML, making them easy to use.
The study compared deep learning models and tree-based models under various settings of tabular data, considering the cost of selecting hyperparameters. The study also shared the raw results of random search, enabling researchers to cheaply test new algorithms within a fixed hyperparameter optimization budget.
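As an illustration of how OpenML-hosted tabular datasets can be pulled into a Python workflow, the sketch below uses scikit-learn's fetch_openml; the dataset name "electricity" is only a placeholder example, not an authoritative pointer to the benchmark suite.

```python
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split

# Fetch a tabular dataset hosted on OpenML by name (placeholder example;
# the benchmark's own dataset identifiers are listed in the paper's OpenML suite).
data = fetch_openml("electricity", version=1, as_frame=True, parser="auto")
X, y = data.data, data.target

# Simple train/test split as a starting point for benchmarking a model.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=0
)
print(X_train.shape, X_test.shape)
```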
Tree-based models still outperform deep learning methods on tabular data
The new benchmark references 45 tabular datasets, with the following selection criteria:
- Heterogeneous columns: columns correspond to different types of features; image and signal datasets are excluded.
- Low dimensionality: the d/n ratio of each dataset is below 1/10 (a small filter sketch for the numeric criteria follows this list).
- Undocumented datasets removed: datasets with little available information are excluded.
- I.I.D. (independent and identically distributed) data: stream-like and time-series datasets are removed.
- Real-world data: artificial datasets are removed, while some simulated datasets are retained.
- Not too small: datasets with too few features (< 4) or too few samples (< 3,000) are removed.
- Overly simple datasets are removed.
- Datasets from games like poker and chess are removed, as their targets are deterministic.
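Purely for illustration, the quantitative criteria above (dimensionality ratio, minimum feature and sample counts) could be expressed as a filter like the one below; the thresholds mirror the list, but the function and its signature are hypothetical, not part of the paper.

```python
def keep_dataset(n_samples: int, n_features: int) -> bool:
    """Illustrative filter for the numeric selection criteria listed above."""
    if n_features < 4:                    # too few features
        return False
    if n_samples < 3000:                  # too few samples
        return False
    if n_features / n_samples > 1 / 10:   # not low-dimensional enough (d/n > 1/10)
        return False
    return True

# Example: a 5,000-sample, 20-feature dataset passes the numeric criteria.
print(keep_dataset(n_samples=5000, n_features=20))  # True
```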
Among tree-based models, the researchers selected three SOTA models: Random Forest from Scikit-learn, Gradient Boosting Trees (GBT), and XGBoost.
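For concreteness, here is a minimal sketch of how these three tree-based baselines might be instantiated in Python; the synthetic data and hyperparameters are placeholders, not the tuned configurations used in the benchmark.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier

# Synthetic stand-in for a tabular classification task.
X, y = make_classification(n_samples=5000, n_features=20, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Three tree-based baselines of the kind benchmarked in the study
# (hyperparameters are placeholders, not the tuned values from the paper).
models = {
    "RandomForest": RandomForestClassifier(n_estimators=250, random_state=0),
    "GradientBoosting": GradientBoostingClassifier(random_state=0),
    "XGBoost": XGBClassifier(n_estimators=250, learning_rate=0.1, random_state=0),
}

for name, model in models.items():
    model.fit(X_train, y_train)
    print(name, model.score(X_test, y_test))  # held-out accuracy
```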
The study benchmarked the following deep models: MLP, ResNet, FT-Transformer, and SAINT.
Figures 1 and 2 provide benchmark results for different types of datasets.
Empirical Investigation: Why Tree-Based Models Still Outperform Deep Learning on Tabular Data
Inductive Bias. Tree-based models outperformed neural networks across various hyperparameter selections. In fact, the best approaches for handling tabular data share two common properties: they are ensemble methods, either bagging (Random Forest) or boosting (XGBoost, GBT), and the weak learners used in these methods are decision trees.
Finding 1: Neural Networks (NNs) Are Biased Toward Overly Smooth Solutions
As shown in Figure 3, smoothing the target function on the training set at small length-scales significantly reduces the accuracy of tree-based models but has little effect on NNs. These results indicate that the target functions in these datasets are not smooth, and that NNs struggle to fit such irregular functions compared to tree-based models. This aligns with the findings of Rahaman et al., who found that NNs are biased toward low-frequency functions. Decision-tree-based models, which learn piecewise-constant functions, do not exhibit this bias.
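To make the setup concrete, here is a much-simplified, regression-flavored sketch of the smoothing experiment, assuming scikit-learn; the Nadaraya-Watson kernel smoother, synthetic dataset, and length-scales are illustrative stand-ins rather than the paper's exact protocol.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.neural_network import MLPRegressor
from sklearn.model_selection import train_test_split

def smooth_targets(X, y, lengthscale):
    # Replace each training target by a Gaussian-kernel-weighted average
    # of all training targets (Nadaraya-Watson smoothing).
    d2 = ((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1)
    w = np.exp(-d2 / (2 * lengthscale ** 2))
    return (w @ y) / w.sum(axis=1)

X, y = make_regression(n_samples=1000, n_features=10, noise=5.0, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Train on increasingly smoothed targets and score against the original test targets.
for lengthscale in (0.1, 1.0, 5.0):
    y_smooth = smooth_targets(X_train, y_train, lengthscale)
    for model in (RandomForestRegressor(random_state=0),
                  MLPRegressor(max_iter=1000, random_state=0)):
        score = model.fit(X_train, y_smooth).score(X_test, y_test)
        print(type(model).__name__, lengthscale, round(score, 3))
```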
Finding 2: Uninformative Features Affect MLP-like NNs More
Tabular datasets often contain many uninformative features. For each dataset, the study dropped an increasing proportion of features according to feature importance (ranked by a Random Forest). As shown in Figure 4, removing more than half of the features has little impact on the classification accuracy of GBT.
Figure 5 shows that removing uninformative features (5a) narrows the performance gap between MLPs (ResNets) and the other models (FT-Transformer and tree-based models), while adding uninformative features widens the gap, indicating that MLPs are less robust to uninformative features. In Figure 5a, when a larger proportion of features is removed, some useful informative features are removed as well. Figure 5b shows that the accuracy drop from losing these useful features is compensated by discarding the uninformative ones, and this trade-off benefits MLPs more than the other models (the study also removed redundant features, which did not affect model performance).
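The sketch below illustrates the general shape of this protocol, assuming scikit-learn: rank features with a Random Forest, drop an increasing share of the least important ones, and re-score a GBT. The dataset and fractions are arbitrary placeholders, not the paper's setup.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.model_selection import train_test_split

# Synthetic tabular task with many uninformative features.
X, y = make_classification(n_samples=3000, n_features=30, n_informative=8,
                           random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Rank features by Random Forest importance, most important first.
ranking = np.argsort(
    RandomForestClassifier(random_state=0).fit(X_train, y_train).feature_importances_
)[::-1]

# Keep a shrinking fraction of the most important features and re-evaluate a GBT.
for keep_fraction in (1.0, 0.5, 0.25):
    kept = ranking[: max(1, int(keep_fraction * len(ranking)))]
    gbt = GradientBoostingClassifier(random_state=0)
    acc = gbt.fit(X_train[:, kept], y_train).score(X_test[:, kept], y_test)
    print(f"kept {keep_fraction:.0%} of features -> accuracy {acc:.3f}")
```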
Finding 3: Data Are Not Invariant Under Rotation
Why are MLPs more easily affected by uninformative features than the other models? One answer is that MLPs are rotationally invariant: the procedure of learning an MLP on the training set and evaluating it on the test set is unchanged when the same rotation is applied to the features of both sets. In fact, any rotationally invariant learning procedure has a worst-case sample complexity that grows at least linearly in the number of irrelevant features. Intuitively, to discard useless features, a rotationally invariant algorithm must first recover the original orientation of the features and then select the least informative ones.
Figure 6a shows the change in test accuracy when the datasets are randomly rotated, confirming that only ResNets are rotationally invariant. Notably, random rotations reverse the performance order: NNs now rank above tree-based models and ResNets above FT-Transformer, indicating that rotation invariance is undesirable. Indeed, the columns of tabular data carry individual meanings, such as age or weight, and mixing them through rotations discards this structure.
Figure 6b shows that removing the least important half of the features in each dataset (before rotation) lowers the performance of all models except ResNets, but the drop is smaller than when all features are kept.
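A minimal sketch of this rotation check, assuming scikit-learn and SciPy: apply one random orthogonal matrix to the features of both the training and test sets and compare how each model family reacts. The synthetic dataset and the two models are stand-ins for the benchmark's models, not the study's exact configuration.

```python
import numpy as np
from scipy.stats import ortho_group
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import train_test_split

# Synthetic tabular task with a few informative features.
X, y = make_classification(n_samples=3000, n_features=20, n_informative=5,
                           random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Random orthogonal matrix applied identically to train and test features.
R = ortho_group.rvs(dim=X.shape[1], random_state=0)

for label, (Xtr, Xte) in {
    "original": (X_train, X_test),
    "rotated": (X_train @ R, X_test @ R),
}.items():
    for model in (RandomForestClassifier(random_state=0),
                  MLPClassifier(max_iter=1000, random_state=0)):
        acc = model.fit(Xtr, y_train).score(Xte, y_test)
        print(label, type(model).__name__, round(acc, 3))
```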
Original link: https://twitter.com/GaelVaroquaux/status/1549422403889
