Version 1.0.0 #171

Merged: 150 commits, Dec 4, 2024

Commits
9998d38
include MambAttn
AnFreTh Sep 4, 2024
de0886e
adapt mabattn config
AnFreTh Sep 4, 2024
d39c056
pruning for hpo
AnFreTh Sep 6, 2024
c1467bd
include config mapper for hpo
AnFreTh Sep 6, 2024
85adc31
include gp_minimize in sklearn base regressor
AnFreTh Sep 6, 2024
6a6a51b
fix bug in TabTransformer embedding_layer
AnFreTh Sep 11, 2024
29811d6
mlp basemodel convenience fix
AnFreTh Sep 11, 2024
1cb93c5
resnet convenience fix
AnFreTh Sep 11, 2024
4ef4b5d
config convenience fix
AnFreTh Sep 13, 2024
bf90683
add efficiency scripts
AnFreTh Sep 13, 2024
2c4bcb8
remove citation for anonymity
AnFreTh Sep 13, 2024
797b624
add hpo to classifier and lss
AnFreTh Sep 28, 2024
5b05664
minor pooling error in mambular
AnFreTh Sep 28, 2024
14663e0
add conv layer to rnn for positional invariance
AnFreTh Sep 28, 2024
4bd83e9
add convrnn to base class
AnFreTh Sep 28, 2024
b3327a1
adjust config of RNN
AnFreTh Sep 28, 2024
3d76805
include optimizer args in taskmodel
AnFreTh Sep 28, 2024
81ae17f
adapt sklearn classes to allow optimizer kwargs
AnFreTh Sep 28, 2024
bce26cb
adjust default optimizer
AnFreTh Sep 28, 2024
28d56b6
include skopt in requirements
AnFreTh Sep 28, 2024
ed5a0f3
Merge pull request #138 from basf/attn
AnFreTh Sep 28, 2024
4723942
include pscan and original mamba_ssm triton version
AnFreTh Oct 21, 2024
df60c1c
fix import
AnFreTh Oct 21, 2024
85a468b
fix import and config
AnFreTh Oct 21, 2024
3fdfa96
adjust attn_config
AnFreTh Oct 21, 2024
949bdb1
fix d_interemediate-d_state error
AnFreTh Oct 21, 2024
161ceeb
input of original ssm forward pass fix
AnFreTh Oct 21, 2024
141b76f
forward pass original mamba tryout
AnFreTh Oct 21, 2024
c0322bc
create ResidualBlock for Original Mamba-ssm
AnFreTh Oct 21, 2024
0dd1b37
fix residualblock **kwargs
AnFreTh Oct 21, 2024
4b9066e
fix RMSNorm init
AnFreTh Oct 21, 2024
0f57b26
try-out Mamba2
AnFreTh Oct 24, 2024
c10d40a
adjust config to new params
AnFreTh Oct 24, 2024
c90c9fd
adjust mamba-ssm import
AnFreTh Oct 24, 2024
b424e8b
adjust _lazy_import in ResidualBlock
AnFreTh Oct 24, 2024
9a04905
include configs in package
AnFreTh Oct 24, 2024
31ddf22
adjust mambatab config to use mamba-ssm and mamba2
AnFreTh Oct 24, 2024
8b08f8e
adjust docstrings to document new params
AnFreTh Oct 24, 2024
426e6c6
delete unnecessary imports in mambatab
AnFreTh Oct 24, 2024
617d6f3
include new models in readme
AnFreTh Oct 24, 2024
10881eb
include mamba-ssm installation in readme
AnFreTh Oct 24, 2024
95e87dd
Merge pull request #145 from basf/efficiency
AnFreTh Oct 24, 2024
ef04001
fix preprocessor kwargs typo
AnFreTh Nov 5, 2024
1c529f0
restructure utils and add neural decision tree
AnFreTh Nov 5, 2024
de77ed5
Adjust new imports
AnFreTh Nov 5, 2024
7322839
add neural Decision Forest base architecture
AnFreTh Nov 5, 2024
3f25fd6
add ndtf to new models in __init__
AnFreTh Nov 5, 2024
385f2dd
add new configs. include mLSTM/sLSTM in rnn config
AnFreTh Nov 5, 2024
54ca398
Merge pull request #149 from basf/NDTF_LSTM
AnFreTh Nov 5, 2024
bdac22b
add ntdf config in init
AnFreTh Nov 5, 2024
9eb5d42
add sparsemax
AnFreTh Nov 5, 2024
ab3abbf
data-aware initialization module
AnFreTh Nov 5, 2024
473db6b
utils func for checking if tensor or np.array
AnFreTh Nov 5, 2024
e59bfcb
add ODST and DenseBlock
AnFreTh Nov 5, 2024
8d362df
add node into basemodels - includes tabular MLP head
AnFreTh Nov 5, 2024
8b12c61
add default config for NODE model
AnFreTh Nov 5, 2024
90e1476
add Node to models and __init__
AnFreTh Nov 5, 2024
9bdbff3
refactor normalization layer -> get_normalization_layer included in _…
AnFreTh Nov 5, 2024
b0c0bf4
add nodeconfig in __init__
AnFreTh Nov 5, 2024
bea4bc3
fix typo in docstrings
AnFreTh Nov 5, 2024
61bb9b3
Merge pull request #151 from basf/NODE
AnFreTh Nov 5, 2024
4e3bcda
adapt config for normalization layer in rnn
AnFreTh Nov 5, 2024
bb11e81
Merge pull request #152 from basf/rnn_fix
AnFreTh Nov 5, 2024
fb74d8e
adjust readme and include new models
AnFreTh Nov 11, 2024
ef1166d
LinearBatchEnsemlbe layer as used in TabM paper
AnFreTh Nov 11, 2024
1bb2dc4
only use config in embedding layer as arg
AnFreTh Nov 11, 2024
5998fd3
allow for None as input
AnFreTh Nov 11, 2024
abf741d
rename MLp to MLPhead and only use config as input
AnFreTh Nov 11, 2024
d0440bb
use config as input in ConvRNN and introduce batchEnsemble RNN layer
AnFreTh Nov 11, 2024
6fb11fa
only use config in TransformerEncoder Layer
AnFreTh Nov 11, 2024
fd74037
include pooling and init pooling in basemodel class
AnFreTh Nov 11, 2024
3212cc5
new arch from utils -> only config as arg
AnFreTh Nov 11, 2024
230b276
adjust all models to new embeddinglayer and new layer utils
AnFreTh Nov 11, 2024
d80d16d
include TabM as introduce in paper
AnFreTh Nov 11, 2024
a570166
batch Ensemble RNN -> todo bidirectional
AnFreTh Nov 11, 2024
1ba1064
include tabm and batchtabrnn configs
AnFreTh Nov 11, 2024
51d71d1
delete bidirectional from config
AnFreTh Nov 11, 2024
10dca1f
add layer_norm_eps to config
AnFreTh Nov 11, 2024
b33b2aa
include batchtabrnn for reg/class/lss
AnFreTh Nov 11, 2024
38edb67
new model
AnFreTh Nov 11, 2024
bafcde3
remove default values for lr related params in fit
AnFreTh Nov 11, 2024
2414fa3
delete lr related default params in fit
AnFreTh Nov 11, 2024
8091909
lr realted param adjustments
AnFreTh Nov 11, 2024
9a236a5
Merge pull request #154 from basf/TabM
AnFreTh Nov 11, 2024
ea184ce
make usable even when params not in config
AnFreTh Nov 11, 2024
063b8dd
adapt embedding layer to plr encodings
AnFreTh Nov 11, 2024
786e4d2
PLR layer inclusion
AnFreTh Nov 11, 2024
b900b71
minor fix in embedding layer creation
AnFreTh Nov 11, 2024
0986c65
adjust defaults
AnFreTh Nov 11, 2024
058bad9
include new models in init
AnFreTh Nov 11, 2024
4d1f787
Merge pull request #155 from basf/TabM
AnFreTh Nov 11, 2024
7de4982
fix validation dataset bug
AnFreTh Nov 12, 2024
74ad3aa
Merge pull request #156 from basf/val_data-fix
AnFreTh Nov 12, 2024
097c6f4
original_mamba dt_rank fix
AnFreTh Nov 12, 2024
254827c
Merge pull request #157 from basf/original_mamba-fix
AnFreTh Nov 13, 2024
9650d24
including tab_mini
AnFreTh Nov 13, 2024
0cc9501
fix multiple predictions in ensemble loss
AnFreTh Nov 13, 2024
2a9fc96
adjust how values are retrieved from config
AnFreTh Nov 14, 2024
8f3994b
save config as hparams
AnFreTh Nov 14, 2024
c823188
fix batch-mini rnn arch
AnFreTh Nov 14, 2024
6ca631d
adjust models to new self.hparams arch
AnFreTh Nov 14, 2024
e8b123c
include ensembling loss
AnFreTh Nov 14, 2024
98b873f
adjust configs to new params
AnFreTh Nov 14, 2024
3cfa7ba
sklearn classes for new ensembling logic
AnFreTh Nov 14, 2024
bab7f11
fix proba prediction for ensembles
AnFreTh Nov 14, 2024
6a83537
add positional_invariance layer
AnFreTh Nov 14, 2024
2bb90f5
fix new mamba2 version
AnFreTh Nov 18, 2024
4d82e0e
include new ensemble method for rnn
AnFreTh Nov 22, 2024
a791bea
adapt readme to new models
AnFreTh Nov 22, 2024
fe632ab
adjust default parameters in configs for hpo
AnFreTh Nov 22, 2024
c662fca
adjust categorical embedding
AnFreTh Nov 28, 2024
f0f7c9b
make norm in resnet to layernorm for hpo
AnFreTh Nov 28, 2024
8a8e9be
add ensemble configs
AnFreTh Nov 28, 2024
9722342
adjust configs for better readibility
AnFreTh Nov 28, 2024
f90eb1f
adding attention batch-ensemble
AnFreTh Nov 28, 2024
6b2b2ac
Merge branch 'tab-mini' of https://github.com/basf/mamba-tabular into…
AnFreTh Nov 29, 2024
b1eccd6
adding tabular cnn
AnFreTh Dec 2, 2024
650c911
adding transformer BE layer
AnFreTh Dec 2, 2024
3b8e585
adjusting embedding layer
AnFreTh Dec 2, 2024
27ff556
adjusting embedding layer
AnFreTh Dec 2, 2024
578ff77
including ftet
AnFreTh Dec 2, 2024
fe161b1
adjusting default configs with new doc structure
AnFreTh Dec 2, 2024
4f91525
including new models and adjusting sklearn classes to hpo
AnFreTh Dec 2, 2024
fe3c1d7
add verbose to preprocessing
AnFreTh Dec 2, 2024
c0a62ca
.
AnFreTh Dec 2, 2024
3bbe703
adjust import
AnFreTh Dec 2, 2024
0709050
remove normalization layer
AnFreTh Dec 2, 2024
ca9dff0
adjust mambular config to new params
AnFreTh Dec 2, 2024
6d5f843
fix hpo bug
AnFreTh Dec 2, 2024
c6b266f
imrpove hpo config mapper defaults
AnFreTh Dec 2, 2024
a35f5c0
adapt models to new preprocessor
AnFreTh Dec 3, 2024
f3c218d
include new preprocessor funcs and adapt get_info logic
AnFreTh Dec 3, 2024
112cd96
adjust embedding layer to new preprocessing
AnFreTh Dec 3, 2024
142d4b2
delete model class
AnFreTh Dec 3, 2024
627d48a
adjust configs
AnFreTh Dec 3, 2024
a696987
include cat_preprocessing in preprocessor arg names
AnFreTh Dec 3, 2024
c354c00
include float preprocessing
AnFreTh Dec 3, 2024
328bf8a
Merge pull request #163 from basf/tab-mini
AnFreTh Dec 3, 2024
eadafe1
adapt readme
AnFreTh Dec 3, 2024
1ee79ea
Merge pull request #164 from basf/temp_develop
AnFreTh Dec 3, 2024
cf5f64b
adapt paper link in readme
AnFreTh Dec 3, 2024
e5c2892
Merge pull request #165 from basf/hot-fix_develop
AnFreTh Dec 3, 2024
99a309b
add prepro-args to sklearn hpo
AnFreTh Dec 4, 2024
978c49e
Merge pull request #166 from basf/gridsearch
AnFreTh Dec 4, 2024
4e4cde8
version bump
AnFreTh Dec 4, 2024
a4b624c
Merge pull request #167 from basf/version_bump
AnFreTh Dec 4, 2024
64417d1
add ndtf in readme
AnFreTh Dec 4, 2024
d127835
Merge pull request #168 from basf/ndtf-readme-fix
AnFreTh Dec 4, 2024
872404c
Resolved conflicts between master and develop
AnFreTh Dec 4, 2024
7e22659
Merge pull request #170 from basf/resolved-develop
AnFreTh Dec 4, 2024
README.md: 230 changes (156 additions, 74 deletions)
@@ -16,21 +16,21 @@
</div>

<div style="text-align: center;">
<h1>Mambular: Tabular Deep Learning (with Mamba)</h1>
<h1>Mambular: Tabular Deep Learning Made Simple</h1>
</div>

Mambular is a Python library for tabular deep learning. It includes models that leverage the Mamba (State Space Model) architecture, as well as other popular models like TabTransformer, FTTransformer, and tabular ResNets. Check out our paper `Mambular: A Sequential Model for Tabular Deep Learning`, available [here](https://arxiv.org/abs/2408.06291).
Mambular is a Python library for tabular deep learning. It includes models that leverage the Mamba (State Space Model) architecture, as well as other popular models like TabTransformer, FTTransformer, TabM and tabular ResNets. Check out our paper `Mambular: A Sequential Model for Tabular Deep Learning`, available [here](https://arxiv.org/abs/2408.06291). Also check out our paper introducing [TabulaRNN](https://arxiv.org/pdf/2411.17207) and analyzing the efficiency of NLP-inspired tabular models.

<h3> Table of Contents </h3>

- [🏃 Quickstart](#-quickstart)
- [📖 Introduction](#-introduction)
- [🤖 Models](#-models)
- [🏆 Results](#-results)
- [📚 Documentation](#-documentation)
- [🛠️ Installation](#️-installation)
- [🚀 Usage](#-usage)
- [💻 Implement Your Own Model](#-implement-your-own-model)
- [Custom Training](#custom-training)
- [🏷️ Citation](#️-citation)
- [License](#license)

@@ -55,75 +55,24 @@
Mambular is a Python package that brings the power of advanced deep learning architectures to tabular data.

| Model | Description |
| ---------------- | --------------------------------------------------------------------------------------------------------------------------------------------------- |
| `Mambular` | A sequential model using Mamba blocks [Gu and Dao](https://arxiv.org/pdf/2312.00752) specifically designed for various tabular data tasks. |
| `Mambular` | A sequential model using Mamba blocks specifically designed for various tabular data tasks introduced [here](https://arxiv.org/abs/2408.06291). |
| `TabM` | Batch Ensembling for an MLP, as introduced by [Gorishniy et al.](https://arxiv.org/abs/2410.24210) |
| `NODE` | Neural Oblivious Decision Ensembles as introduced by [Popov et al.](https://arxiv.org/abs/1909.06312) |
| `FTTransformer` | A model leveraging transformer encoders, as introduced by [Gorishniy et al.](https://arxiv.org/abs/2106.11959), for tabular data. |
| `MLP` | A classical Multi-Layer Perceptron (MLP) model for handling tabular data tasks. |
| `ResNet` | An adaptation of the ResNet architecture for tabular data applications. |
| `TabTransformer` | A transformer-based model for tabular data introduced by [Huang et al.](https://arxiv.org/abs/2012.06678), enhancing feature learning capabilities. |
| `MambaTab` | A tabular model using a Mamba-Block on a joint input representation described [here](https://arxiv.org/abs/2401.08867). Not a sequential model. |
| `TabulaRNN` | A Recurrent Neural Network for Tabular data. Not yet included in the benchmarks |
| `TabulaRNN` | A Recurrent Neural Network for Tabular data, introduced [here](https://arxiv.org/pdf/2411.17207). |
| `MambAttention` | A combination between Mamba and Transformers, also introduced [here](https://arxiv.org/pdf/2411.17207). |
| `NDTF` | A neural decision forest using soft decision trees. See [Kontschieder et al.](https://openaccess.thecvf.com/content_iccv_2015/html/Kontschieder_Deep_Neural_Decision_ICCV_2015_paper.html) for inspiration. |



All models are available for `regression`, `classification` and distributional regression, denoted by `LSS`.
Hence, they are available as, e.g., `MambularRegressor`, `MambularClassifier` or `MambularLSS`.
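
For illustration, here is a minimal sketch of that naming scheme (it only instantiates the estimators with their default configurations):

```python
from mambular.models import MambularRegressor, MambularClassifier, MambularLSS

# One architecture, three task-specific estimators
reg = MambularRegressor()   # regression
clf = MambularClassifier()  # classification
lss = MambularLSS()         # distributional regression (LSS)
```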


# 🏆 Results
Detailed results for the available methods can be found [here](https://arxiv.org/abs/2408.06291).
Note that these results were achieved with default hyperparameters and our data splits. Performing hyperparameter optimization could improve the performance of all models.

The table of average ranks over all models and all datasets is given below:

<div align="center">

<table>
<tr>
<th style="text-align:center;">Model</th>
<th style="text-align:center;">Avg. Rank</th>
</tr>
<tr>
<td style="text-align:center;"><strong>Mambular</strong></td>
<td style="text-align:center;"><strong>2.083</strong> <sub>±1.037</sub></td>
</tr>
<tr>
<td style="text-align:center;">FT-Transformer</td>
<td style="text-align:center;">2.417 <sub>±1.256</sub></td>
</tr>
<tr>
<td style="text-align:center;">XGBoost</td>
<td style="text-align:center;">3.167 <sub>±2.577</sub></td>
</tr>
<tr>
<td style="text-align:center;">MambaTab*</td>
<td style="text-align:center;">4.333 <sub>±1.374</sub></td>
</tr>
<tr>
<td style="text-align:center;">ResNet</td>
<td style="text-align:center;">4.750 <sub>±1.639</sub></td>
</tr>
<tr>
<td style="text-align:center;">TabTransformer</td>
<td style="text-align:center;">6.222 <sub>±1.618</sub></td>
</tr>
<tr>
<td style="text-align:center;">MLP</td>
<td style="text-align:center;">6.500 <sub>±1.500</sub></td>
</tr>
<tr>
<td style="text-align:center;">MambaTab</td>
<td style="text-align:center;">6.583 <sub>±1.801</sub></td>
</tr>
<tr>
<td style="text-align:center;">MambaTab<sup>T</sup></td>
<td style="text-align:center;">7.917 <sub>±1.187</sub></td>
</tr>
</table>

</div>




# 📚 Documentation

You can find the Mamba-Tabular API documentation [here](https://mambular.readthedocs.io/en/latest/).
@@ -135,6 +84,19 @@
Install Mambular using pip:

```sh
pip install mambular
```

If you want to use the original mamba and mamba2 implementations, additionally install mamba-ssm via:

```sh
pip install mamba-ssm
```

Be careful to use torch and CUDA versions that are compatible with mamba-ssm, for example:

```sh
pip install torch==2.0.0+cu118 torchvision==0.15.0+cu118 torchaudio==2.0.0+cu118 -f https://download.pytorch.org/whl/cu118/torch_stable.html
pip install mamba-ssm
```

# 🚀 Usage

<h2> Preprocessing </h2>
@@ -143,12 +105,18 @@
Mambular simplifies data preprocessing with a range of tools designed for easy transformation of tabular data; a short configuration sketch follows the feature list below.

<h3> Data Type Detection and Transformation </h3>

- **Ordinal & One-Hot Encoding**: Automatically transforms categorical data into numerical formats.
- **Binning**: Discretizes numerical features; can use decision trees for optimal binning.
- **Normalization & Standardization**: Scales numerical data appropriately.
- **Periodic Linear Encoding (PLE)**: Encodes periodicity in numerical data.
- **Quantile & Spline Transformations**: Applies advanced transformations to handle nonlinearity and distributional shifts.
- **Polynomial Features**: Generates polynomial and interaction terms to capture complex relationships.
- **Ordinal & One-Hot Encoding**: Automatically transforms categorical data into numerical formats using continuous ordinal encoding or one-hot encoding. Includes options for transforming outputs to `float` for compatibility with downstream models.
- **Binning**: Discretizes numerical features into bins, with support for both fixed binning strategies and optimal binning derived from decision tree models.
- **MinMax**: Scales numerical data to a specific range, such as [-1, 1], using Min-Max scaling or similar techniques.
- **Standardization**: Centers and scales numerical features to have a mean of zero and unit variance for better compatibility with certain models.
- **Quantile Transformations**: Normalizes numerical data to follow a uniform or normal distribution, handling distributional shifts effectively.
- **Spline Transformations**: Captures nonlinearity in numerical features using spline-based transformations, ideal for complex relationships.
- **Piecewise Linear Encodings (PLE)**: Captures complex numerical patterns by applying piecewise linear encoding, suitable for data with periodic or nonlinear structures.
- **Polynomial Features**: Automatically generates polynomial and interaction terms for numerical features, enhancing the ability to capture higher-order relationships.
- **Box-Cox & Yeo-Johnson Transformations**: Performs power transformations to stabilize variance and normalize distributions.
- **Custom Binning**: Enables user-defined bin edges for precise discretization of numerical data.
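
As a rough configuration sketch (the argument names mirror the `Fit a Model` example further below and are illustrative rather than exhaustive):

```python
from mambular.models import MambularRegressor

# Illustrative only: the numerical preprocessing is chosen via the estimator's
# constructor, here piecewise linear encoding (PLE) with 50 bins.
model = MambularRegressor(
    numerical_preprocessing="ple",
    n_bins=50,
)
```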




<h2> Fit a Model </h2>
@@ -159,9 +127,10 @@

```python
from mambular.models import MambularClassifier
# Initialize and fit your model
model = MambularClassifier(
d_model=64,
n_layers=8,
n_layers=4,
numerical_preprocessing="ple",
n_bins=50
n_bins=50,
d_conv=8
)

# X can be a DataFrame or anything that can easily be converted into a pd.DataFrame, e.g. a np.array
@@ -177,6 +146,59 @@
preds = model.predict(X)
preds = model.predict_proba(X)
```

<h3> Hyperparameter Optimization</h3>
Since all of the models are sklearn base estimators, you can use the built-in hyperparameter optimization from sklearn.

```python
from sklearn.model_selection import RandomizedSearchCV
from scipy.stats import randint, uniform  # distributions used in param_dist below

param_dist = {
'd_model': randint(32, 128),
'n_layers': randint(2, 10),
'lr': uniform(1e-5, 1e-3)
}

random_search = RandomizedSearchCV(
estimator=model,
param_distributions=param_dist,
n_iter=50, # Number of parameter settings sampled
cv=5, # 5-fold cross-validation
scoring='accuracy', # Metric to optimize
random_state=42
)

fit_params = {"max_epochs":5, "rebuild":False}

# Fit the model
random_search.fit(X, y, **fit_params)

# Best parameters and score
print("Best Parameters:", random_search.best_params_)
print("Best Score:", random_search.best_score_)
```
Note that with this approach you can also optimize the preprocessing. Just use the prefix ``prepro__`` when specifying the preprocessor arguments you want to optimize:
```python
param_dist = {
'd_model': randint(32, 128),
'n_layers': randint(2, 10),
'lr': uniform(1e-5, 1e-3),
"prepro__numerical_preprocessing": ["ple", "standardization", "box-cox"]
}

```


Since early stopping is integrated and the best model with respect to the validation loss is returned, setting max_epochs to a large number is sensible.


Alternatively, use the built-in Bayesian HPO simply by running:

```python
best_params = model.optimize_hparams(X, y)
```

This automatically sets the search space based on the default config from ``mambular.configs``. See the documentation for all parameters of ``optimize_hparams()``. Note, however, that the preprocessor arguments are fixed and cannot be optimized here.
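
As a hypothetical follow-up, and assuming the returned dictionary maps directly onto the estimator's constructor arguments (check the documentation for the exact return format), the tuned values could be reused like this:

```python
from mambular.models import MambularClassifier

# Hypothetical sketch: refit a fresh estimator with the tuned hyperparameters.
best_params = model.optimize_hparams(X, y)
tuned = MambularClassifier(**best_params)
tuned.fit(X, y, max_epochs=200)
```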


<h2> ⚖️ Distributional Regression with MambularLSS </h2>

@@ -260,6 +282,7 @@
Here's how you can implement a custom model with Mambular:

```python
from mambular.base_models import BaseModel
from mambular.utils.get_feature_dimensions import get_feature_dimensions
import torch
import torch.nn as nn

@@ -275,11 +298,7 @@
super().__init__(**kwargs)
self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])

input_dim = 0
for feature_name, input_shape in num_feature_info.items():
input_dim += input_shape
for feature_name, input_shape in cat_feature_info.items():
input_dim += 1
input_dim = get_feature_dimensions(num_feature_info, cat_feature_info)

self.linear = nn.Linear(input_dim, num_classes)

@@ -311,6 +330,59 @@
regressor.fit(X_train, y_train, max_epochs=50)
```

# Custom Training
If you prefer to set up custom training, preprocessing and evaluation, you can simply use the `mambular.base_models`.
Just be aware that all base models expect lists of features as inputs: more precisely, a list of tensors for the numerical features and a list for the categorical features. A custom training loop with random data could look like this.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from mambular.base_models import Mambular
from mambular.configs import DefaultMambularConfig

# Dummy data and configuration
cat_feature_info = {
"cat1": {
"preprocessing": "imputer -> continuous_ordinal",
"dimension": 1,
"categories": 4,
}
} # Example categorical feature information
num_feature_info = {
"num1": {"preprocessing": "imputer -> scaler", "dimension": 1, "categories": None}
} # Example numerical feature information
num_classes = 1
config = DefaultMambularConfig() # Use the desired configuration

# Initialize model, loss function, and optimizer
model = Mambular(cat_feature_info, num_feature_info, num_classes, config)
criterion = nn.MSELoss() # Use MSE for regression; change as appropriate for your task
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Example training loop
for epoch in range(10): # Number of epochs
model.train()
optimizer.zero_grad()

# Dummy Data
num_features = [torch.randn(32, 1) for _ in num_feature_info]
cat_features = [torch.randint(0, 5, (32,)) for _ in cat_feature_info]
labels = torch.randn(32, num_classes)

# Forward pass
outputs = model(num_features, cat_features)
loss = criterion(outputs, labels)

# Backward pass and optimization
loss.backward()
optimizer.step()

# Print loss for monitoring
print(f"Epoch [{epoch+1}/10], Loss: {loss.item():.4f}")

```

# 🏷️ Citation

If you find this project useful in your research, please consider citing:
@@ -323,6 +395,16 @@
}
```

If you use TabulaRNN, please consider citing:
```BibTeX
@article{thielmann2024efficiency,
title={On the Efficiency of NLP-Inspired Methods for Tabular Deep Learning},
author={Thielmann, Anton Frederik and Samiee, Soheila},
journal={arXiv preprint arXiv:2411.17207},
year={2024}
}
```

# License

The entire codebase is under MIT license.