From 63614255b0e959c6ab6da37fb5124a59be4a346d Mon Sep 17 00:00:00 2001
From: Katharine Xiao <2405771+katxiao@users.noreply.github.com>
Date: Wed, 28 Jul 2021 19:08:20 -0400
Subject: [PATCH] Add argument to drop column in ColumnFormula

---
 sdv/constraints/tabular.py             |  8 +++++--
 tests/unit/constraints/test_tabular.py | 30 ++++++++++++++++++++++++++
 2 files changed, 36 insertions(+), 2 deletions(-)

diff --git a/sdv/constraints/tabular.py b/sdv/constraints/tabular.py
index c646fcc6f..03a7070f2 100644
--- a/sdv/constraints/tabular.py
+++ b/sdv/constraints/tabular.py
@@ -483,11 +483,14 @@ class ColumnFormula(Constraint):
         handling_strategy (str):
             How this Constraint should be handled, which can be ``transform``
             or ``reject_sampling``. Defaults to ``transform``.
+        drop_column(str):
+            Whether or not to drop the constraint column.
     """
 
-    def __init__(self, column, formula, handling_strategy='transform'):
+    def __init__(self, column, formula, handling_strategy='transform', drop_column=True):
         self._column = column
         self._formula = import_object(formula)
+        self._drop_column = drop_column
         super().__init__(handling_strategy, fit_columns_model=False)
 
     def is_valid(self, table_data):
@@ -519,7 +522,8 @@ def transform(self, table_data):
                 Transformed data.
         """
         table_data = table_data.copy()
-        del table_data[self._column]
+        if self._drop_column:
+            del table_data[self._column]
 
         return table_data
 
diff --git a/tests/unit/constraints/test_tabular.py b/tests/unit/constraints/test_tabular.py
index 7744b5213..a8c5392a2 100644
--- a/tests/unit/constraints/test_tabular.py
+++ b/tests/unit/constraints/test_tabular.py
@@ -1802,6 +1802,36 @@ def test_transform(self):
         })
         pd.testing.assert_frame_equal(expected_out, out)
 
+    def test_transform_without_dropping_column(self):
+        """Test the ``ColumnFormula.transform`` method without dropping the column.
+
+        If `drop_column` is false, expect to not drop the constraint column.
+
+        Input:
+        - Table data (pandas.DataFrame)
+        Output:
+        - Table data with the indicated column (pandas.DataFrame)
+        """
+        # Setup
+        column = 'c'
+        instance = ColumnFormula(column=column, formula=new_column, drop_column=False)
+
+        # Run
+        table_data = pd.DataFrame({
+            'a': [1, 2, 3],
+            'b': [4, 5, 6],
+            'c': [5, 7, 9]
+        })
+        out = instance.transform(table_data)
+
+        # Assert
+        expected_out = pd.DataFrame({
+            'a': [1, 2, 3],
+            'b': [4, 5, 6],
+            'c': [5, 7, 9]
+        })
+        pd.testing.assert_frame_equal(expected_out, out)
+
     def test_reverse_transform(self):
         """Test the ``ColumnFormula.reverse_transform`` method.