diff --git a/README.md b/README.md
index d2cbd8d712..415bdeb6a0 100644
--- a/README.md
+++ b/README.md
@@ -348,6 +348,8 @@ Join IM discussion groups:
 | OpenPAI | [![Build Status](https://msrasrg.visualstudio.com/NNIOpenSource/_apis/build/status/integration%20test%20-%20openpai%20-%20linux?branchName=master)](https://msrasrg.visualstudio.com/NNIOpenSource/_build/latest?definitionId=65&branchName=master) |
 | Frameworkcontroller | [![Build Status](https://msrasrg.visualstudio.com/NNIOpenSource/_apis/build/status/integration%20test%20-%20frameworkcontroller?branchName=master)](https://msrasrg.visualstudio.com/NNIOpenSource/_build/latest?definitionId=70&branchName=master) |
 | Kubeflow | [![Build Status](https://msrasrg.visualstudio.com/NNIOpenSource/_apis/build/status/integration%20test%20-%20kubeflow?branchName=master)](https://msrasrg.visualstudio.com/NNIOpenSource/_build/latest?definitionId=69&branchName=master) |
+| Hybrid | [![Build Status](https://msrasrg.visualstudio.com/NNIOpenSource/_apis/build/status/integration%20test%20-%20hybrid?branchName=master)](https://msrasrg.visualstudio.com/NNIOpenSource/_build/latest?definitionId=79&branchName=master) |
+| AzureML | [![Build Status](https://msrasrg.visualstudio.com/NNIOpenSource/_apis/build/status/integration%20test%20-%20aml?branchName=master)](https://msrasrg.visualstudio.com/NNIOpenSource/_build/latest?definitionId=78&branchName=master) |
 
 ## Related Projects
 
diff --git a/dependencies/recommended.txt b/dependencies/recommended.txt
index 79a148c3d4..8aaffd4423 100644
--- a/dependencies/recommended.txt
+++ b/dependencies/recommended.txt
@@ -10,3 +10,5 @@ pytorch-lightning >= 1.1.1
 onnx
 peewee
 graphviz
+gym
+tianshou >= 0.4.1
diff --git a/dependencies/recommended_legacy.txt b/dependencies/recommended_legacy.txt
index 19680a870a..5d9c36b6c7 100644
--- a/dependencies/recommended_legacy.txt
+++ b/dependencies/recommended_legacy.txt
@@ -11,3 +11,5 @@ keras == 2.1.6
 onnx
 peewee
 graphviz
+gym
+tianshou >= 0.4.1
diff --git a/docs/en_US/Compression/CompressionReference.rst b/docs/en_US/Compression/CompressionReference.rst
index b616b87a9e..50dcc12876 100644
--- a/docs/en_US/Compression/CompressionReference.rst
+++ b/docs/en_US/Compression/CompressionReference.rst
@@ -34,7 +34,7 @@ Weight Masker
 ..  autoclass:: nni.algorithms.compression.pytorch.pruning.weight_masker.WeightMasker
     :members:
 
-..  autoclass:: nni.algorithms.compression.pytorch.pruning.structured_pruning.StructuredWeightMasker
+..  autoclass:: nni.algorithms.compression.pytorch.pruning.structured_pruning_masker.StructuredWeightMasker
     :members:
 
 
@@ -43,40 +43,40 @@ Pruners
 ..  autoclass:: nni.algorithms.compression.pytorch.pruning.sensitivity_pruner.SensitivityPruner
     :members:
 
-..  autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot.OneshotPruner
+..  autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot_pruner.OneshotPruner
     :members:
 
-..  autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot.LevelPruner
+..  autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot_pruner.LevelPruner
     :members:
 
-..  autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot.SlimPruner
+..  autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot_pruner.L1FilterPruner
     :members:
 
-..  autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot.L1FilterPruner
+..  autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot_pruner.L2FilterPruner
     :members:
 
-..  autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot.L2FilterPruner
+..  autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot_pruner.FPGMPruner
     :members:
 
-..  autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot.FPGMPruner
+..  autoclass:: nni.algorithms.compression.pytorch.pruning.iterative_pruner.IterativePruner
     :members:
 
-..  autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot.TaylorFOWeightFilterPruner
+..  autoclass:: nni.algorithms.compression.pytorch.pruning.iterative_pruner.SlimPruner
     :members:
 
-..  autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot.ActivationAPoZRankFilterPruner
+..  autoclass:: nni.algorithms.compression.pytorch.pruning.iterative_pruner.TaylorFOWeightFilterPruner
     :members:
 
-..  autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot.ActivationMeanRankFilterPruner
+..  autoclass:: nni.algorithms.compression.pytorch.pruning.iterative_pruner.ActivationAPoZRankFilterPruner
     :members:
 
-..  autoclass:: nni.algorithms.compression.pytorch.pruning.lottery_ticket.LotteryTicketPruner
+..  autoclass:: nni.algorithms.compression.pytorch.pruning.iterative_pruner.ActivationMeanRankFilterPruner
     :members:
 
-..  autoclass:: nni.algorithms.compression.pytorch.pruning.agp.AGPPruner
+..  autoclass:: nni.algorithms.compression.pytorch.pruning.iterative_pruner.AGPPruner
     :members:
 
-..  autoclass:: nni.algorithms.compression.pytorch.pruning.admm_pruner.ADMMPruner
+..  autoclass:: nni.algorithms.compression.pytorch.pruning.iterative_pruner.ADMMPruner
     :members:
 
 ..  autoclass:: nni.algorithms.compression.pytorch.pruning.auto_compress_pruner.AutoCompressPruner
@@ -88,6 +88,9 @@ Pruners
 ..  autoclass:: nni.algorithms.compression.pytorch.pruning.simulated_annealing_pruner.SimulatedAnnealingPruner
     :members:
 
+..  autoclass:: nni.algorithms.compression.pytorch.pruning.lottery_ticket.LotteryTicketPruner
+    :members:
+
 
 Quantizers
 ^^^^^^^^^^
diff --git a/docs/en_US/Compression/CustomizeCompressor.rst b/docs/en_US/Compression/CustomizeCompressor.rst
index f2f4f260c1..103bff818c 100644
--- a/docs/en_US/Compression/CustomizeCompressor.rst
+++ b/docs/en_US/Compression/CustomizeCompressor.rst
@@ -28,7 +28,7 @@ An implementation of ``weight masker`` may look like this:
            # mask = ...
            return {'weight_mask': mask}
 
-You can reference nni provided :githublink:`weight masker <nni/algorithms/compression/pytorch/pruning/structured_pruning.py>` implementations to implement your own weight masker.
+You can reference nni provided :githublink:`weight masker <nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py>` implementations to implement your own weight masker.
 
 A basic ``pruner`` looks like this:
 
@@ -52,7 +52,7 @@ A basic ``pruner`` looks likes this:
                wrapper.if_calculated = True
                return masks
 
-Reference nni provided :githublink:`pruner <nni/algorithms/compression/pytorch/pruning/one_shot.py>` implementations to implement your own pruner class.
+Reference nni provided :githublink:`pruner <nni/algorithms/compression/pytorch/pruning/one_shot_pruner.py>` implementations to implement your own pruner class.
 
 ----
 
diff --git a/docs/en_US/Compression/Overview.rst b/docs/en_US/Compression/Overview.rst
index 262d9631f1..788aa0ac84 100644
--- a/docs/en_US/Compression/Overview.rst
+++ b/docs/en_US/Compression/Overview.rst
@@ -14,10 +14,19 @@ NNI provides a model compression toolkit to help user compress and speed up thei
 * Provide friendly and easy-to-use compression utilities for users to dive into the compression process and results.
 * Concise interface for users to customize their own compression algorithms.
 
+
+Compression Pipeline
+--------------------
+
+.. image:: ../../img/compression_flow.jpg
+   :target: ../../img/compression_flow.jpg
+   :alt: 
+
+The figure above shows the overall compression pipeline in NNI. To compress a pretrained model, pruning and quantization can be used alone or in combination.
+
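+For example, pruning followed by quantization can be chained as below (a minimal sketch; it assumes a pretrained PyTorch ``model``, suitable ``pruning_config`` / ``quant_config`` lists, and an ``optimizer`` are defined elsewhere):
+
+.. code-block:: python
+
+   from nni.algorithms.compression.pytorch.pruning import L1FilterPruner
+   from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer
+
+   pruner = L1FilterPruner(model, pruning_config)
+   model = pruner.compress()            # prune: weights are masked
+   # ... finetune, export masks, and run model speedup here ...
+   quantizer = QAT_Quantizer(model, quant_config, optimizer)
+   quantizer.compress()                 # then quantize the pruned model
+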
 .. note::
  NNI compression algorithms only simulate compression with masks, while the NNI speedup tool can truly compress the model and reduce latency. To obtain a truly compact model, users should conduct `model speedup <./ModelSpeedup.rst>`__. The interface and APIs are unified for both PyTorch and TensorFlow; currently only the PyTorch version is supported, and the TensorFlow version will be supported in the future.
 
-
 Supported Algorithms
 --------------------
 
@@ -26,7 +35,7 @@ The algorithms include pruning algorithms and quantization algorithms.
 Pruning Algorithms
 ^^^^^^^^^^^^^^^^^^
 
-Pruning algorithms compress the original network by removing redundant weights or channels of layers, which can reduce model complexity and address the over-fitting issue. 
+Pruning algorithms compress the original network by removing redundant weights or channels of layers, which can reduce model complexity and mitigate the over-fitting issue.
 
 .. list-table::
    :header-rows: 1
@@ -96,6 +105,7 @@ Model Speedup
 
 The final goal of model compression is to reduce inference latency and model size. However, existing model compression algorithms mainly use simulation to check the performance (e.g., accuracy) of the compressed model, for example, using masks for pruning algorithms, and storing quantized values still in float32 for quantization algorithms. Given the output masks and quantization bits produced by those algorithms, NNI can really speed up the model. The detailed tutorial of Masked Model Speedup can be found `here <./ModelSpeedup.rst>`__, and the detailed tutorial of Mixed Precision Quantization Model Speedup can be found `here <./QuantizationSpeedup.rst>`__.
 
+
 Compression Utilities
 ---------------------
 
@@ -110,7 +120,6 @@ NNI model compression leaves simple interface for users to customize a new compr
 Reference and Feedback
 ----------------------
 
-
 * To `report a bug <https://github.com/microsoft/nni/issues/new?template=bug-report.rst>`__ for this feature in GitHub;
 * To `file a feature or improvement request <https://github.com/microsoft/nni/issues/new?template=enhancement.rst>`__ for this feature in GitHub;
 * To know more about `Feature Engineering with NNI <../FeatureEngineering/Overview.rst>`__\ ;
diff --git a/docs/en_US/Compression/Pruner.rst b/docs/en_US/Compression/Pruner.rst
index 304a56e43e..eb9c32c875 100644
--- a/docs/en_US/Compression/Pruner.rst
+++ b/docs/en_US/Compression/Pruner.rst
@@ -1,15 +1,11 @@
 Supported Pruning Algorithms on NNI
 ===================================
 
-We provide several pruning algorithms that support fine-grained weight pruning and structural filter pruning. **Fine-grained Pruning** generally results in  unstructured models, which need specialized hardware or software to speed up the sparse network. **Filter Pruning** achieves acceleration by removing the entire filter. Some pruning algorithms use one-shot method that prune weights at once based on an importance metric. Other pruning algorithms control the **pruning schedule** that prune weights during optimization, including some automatic pruning algorithms.
+We provide several pruning algorithms that support fine-grained weight pruning and structural filter pruning. **Fine-grained Pruning** generally results in unstructured models, which need specialized hardware or software to speed up the sparse network. **Filter Pruning** achieves acceleration by removing the entire filter. Some pruning algorithms use a one-shot method that prunes weights at once based on an importance metric (the model then needs to be finetuned to compensate for the loss of accuracy). Other pruning algorithms **iteratively** prune weights during optimization, controlling the pruning schedule; these include some automatic pruning algorithms.
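+
+For instance, a one-shot pruner is applied like this (a minimal sketch; it assumes a pretrained PyTorch ``model`` is already defined):
+
+.. code-block:: python
+
+   from nni.algorithms.compression.pytorch.pruning import LevelPruner
+
+   config_list = [{'sparsity': 0.5, 'op_types': ['default']}]
+   pruner = LevelPruner(model, config_list)
+   model = pruner.compress()  # weights are masked according to the config
+   # ... then finetune the masked model to recover accuracy ...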
 
 
-**Fine-grained Pruning**
-
-* `Level Pruner <#level-pruner>`__
-
-**Filter Pruning**
-
+**One-shot Pruning**
+
+* `Level Pruner <#level-pruner>`__ (fine-grained pruning)
 * `Slim Pruner <#slim-pruner>`__
 * `FPGM Pruner <#fpgm-pruner>`__
 * `L1Filter Pruner <#l1filter-pruner>`__
@@ -18,7 +14,7 @@ We provide several pruning algorithms that support fine-grained weight pruning a
 * `Activation Mean Rank Filter Pruner <#activationmeanrankfilter-pruner>`__
 * `Taylor FO On Weight Pruner <#taylorfoweightfilter-pruner>`__
 
-**Pruning Schedule**
+**Iterative Pruning**
 
 * `AGP Pruner <#agp-pruner>`__
 * `NetAdapt Pruner <#netadapt-pruner>`__
@@ -26,10 +22,9 @@ We provide several pruning algorithms that support fine-grained weight pruning a
 * `AutoCompress Pruner <#autocompress-pruner>`__
 * `AMC Pruner <#amc-pruner>`__
 * `Sensitivity Pruner <#sensitivity-pruner>`__
+* `ADMM Pruner <#admm-pruner>`__
 
 **Others**
 
-* `ADMM Pruner <#admm-pruner>`__
 * `Lottery Ticket Hypothesis <#lottery-ticket-hypothesis>`__
 
 Level Pruner
@@ -382,11 +377,7 @@ PyTorch code
 
    from nni.algorithms.compression.pytorch.pruning import AGPPruner
    config_list = [{
-       'initial_sparsity': 0,
-       'final_sparsity': 0.8,
-       'start_epoch': 0,
-       'end_epoch': 10,
-       'frequency': 1,
+       'sparsity': 0.8,
        'op_types': ['default']
    }]
 
diff --git a/docs/en_US/NAS/retiarii/Advanced.rst b/docs/en_US/NAS/retiarii/Advanced.rst
index 18146ab5a0..12568b8ab6 100644
--- a/docs/en_US/NAS/retiarii/Advanced.rst
+++ b/docs/en_US/NAS/retiarii/Advanced.rst
@@ -1,7 +1,17 @@
 Advanced Tutorial
 =================
 
-This document includes two parts. The first part explains the design decision of ``@basic_unit`` and ``serializer``. The second part is the tutorial of how to write a model space with mutators.
+Pure-python execution engine (experimental)
+-------------------------------------------
+
+If you are experiencing issues with TorchScript or with the model code generated by Retiarii, you can try another execution engine, the pure-Python execution engine, which doesn't need code-graph conversion. In most cases this does not affect models and strategies, but customized mutation might not be supported.
+
+This will become the default execution engine in a future version of Retiarii.
+
+Currently, two steps are needed to enable this engine, as sketched below.
+
+1. Add ``@nni.retiarii.model_wrapper`` decorator outside the whole PyTorch model.
+2. Add ``config.execution_engine = 'py'`` to ``RetiariiExeConfig``.
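+
+A minimal sketch of the two steps (assuming ``exp_config`` is a ``RetiariiExeConfig`` created elsewhere; the model below is illustrative):
+
+.. code-block:: python
+
+   import torch.nn as nn
+   from nni.retiarii import model_wrapper
+
+   @model_wrapper          # step 1: decorate the whole PyTorch model
+   class Net(nn.Module):
+       def __init__(self):
+           super().__init__()
+           self.fc = nn.Linear(32, 10)
+
+       def forward(self, x):
+           return self.fc(x)
+
+   exp_config.execution_engine = 'py'   # step 2: select the pure-python engine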
 
 ``@basic_unit`` and ``serializer``
 ----------------------------------
diff --git a/docs/en_US/NAS/retiarii/ApiReference.rst b/docs/en_US/NAS/retiarii/ApiReference.rst
index 43a86d8fa8..9d8cd03059 100644
--- a/docs/en_US/NAS/retiarii/ApiReference.rst
+++ b/docs/en_US/NAS/retiarii/ApiReference.rst
@@ -18,6 +18,12 @@ Inline Mutation APIs
 ..  autoclass:: nni.retiarii.nn.pytorch.ChosenInputs
     :members:
 
+..  autoclass:: nni.retiarii.nn.pytorch.Repeat
+    :members:
+
+..  autoclass:: nni.retiarii.nn.pytorch.Cell
+    :members:
+
 Graph Mutation APIs
 -------------------
 
diff --git a/docs/en_US/hpo_benchmark.rst b/docs/en_US/hpo_benchmark.rst
new file mode 100644
index 0000000000..a82baab79f
--- /dev/null
+++ b/docs/en_US/hpo_benchmark.rst
@@ -0,0 +1,237 @@
+
+Benchmark for Tuners
+====================
+
+We provide a benchmarking tool to compare the performance of tuners provided by NNI (and users' custom tuners) on different tasks. The implementation of this tool is based on the automlbenchmark repository (https://github.com/openml/automlbenchmark), which provides a service for running different *frameworks* against different *benchmarks*, each consisting of multiple *tasks*. The tool is located in ``examples/trials/benchmarking/automlbenchmark``. This document provides a brief introduction to the tool and its usage.
+
+Terminology
+^^^^^^^^^^^
+
+
+* **task**\ : a task can be thought of as (dataset, evaluator). It provides a dataset split into (train, valid, test), and the evaluator computes a given metric (e.g., mse for regression, f1 for classification) on the received predictions.
+* **benchmark**\ : a benchmark is a set of tasks, along with other external constraints such as time and resource. 
+* **framework**\ : given a task, a framework produces predictions for the proposed regression or classification problem. Note that the automlbenchmark framework does not pose any restrictions on the hypothesis space of a framework. In our implementation in this folder, each framework is a tuple (tuner, architecture), where the architecture provides the hypothesis space (and the search space for the tuner), and the tuner determines the strategy of hyperparameter optimization.
+* **tuner**\ : a tuner or advisor defined in the hpo folder, or a custom tuner provided by the user. 
+* **architecture**\ : an architecture is a specific method for solving the tasks, along with a set of hyperparameters to optimize (i.e., the search space). In our implementation, the architecture calls the tuner multiple times to obtain possible hyperparameter configurations, and produces the final prediction for a task. See ``./nni/extensions/NNI/architectures`` for examples.
+
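+A schematic of how an architecture drives a tuner (a sketch; ``budget`` and ``train_and_evaluate`` are illustrative placeholders, while ``generate_parameters`` and ``receive_trial_result`` are standard NNI tuner APIs):
+
+.. code-block:: python
+
+   for parameter_id in range(budget):
+       params = tuner.generate_parameters(parameter_id)
+       # e.g., fit a random forest with these hyperparameters and score it on the validation set
+       score = train_and_evaluate(params)
+       tuner.receive_trial_result(parameter_id, params, score)
+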
+Setup
+^^^^^
+
+Due to some incompatibilities between automlbenchmark and Python 3.8, Python 3.7 is recommended for running the experiments contained in this folder. First, run the following shell script to clone the automlbenchmark repository. Note: it is recommended to perform the following steps in a separate virtual environment, as the setup code may install several packages.
+
+.. code-block:: bash
+
+   ./setup.sh
+
+Run predefined benchmarks on existing tuners
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. code-block:: bash
+
+   ./runbenchmark_nni.sh [tuner-names]
+
+This script runs the benchmark 'nnivalid', which consists of a regression task, a binary classification task, and a multi-class classification task. After the script finishes, you can find a summary of the results in the folder ``results_[time]/reports/``. To run on other predefined benchmarks, change the ``benchmark`` variable in ``runbenchmark_nni.sh``. Some benchmarks are defined in ``/examples/trials/benchmarking/automlbenchmark/nni/benchmarks``\ , and others are defined in ``/examples/trials/benchmarking/automlbenchmark/automlbenchmark/resources/benchmarks/``. One example of a larger benchmark is "nnismall", which consists of 8 regression tasks, 8 binary classification tasks, and 8 multi-class classification tasks.
+
+By default, the script runs the benchmark on all built-in tuners in NNI. If a list of tuners is given in ``[tuner-names]``, it only runs the tuners in the list. Currently, the following tuner names are supported: "TPE", "Random", "Anneal", "Evolution", "SMAC", "GPTuner", "MetisTuner", "Hyperband", "BOHB". It is also possible to evaluate custom tuners. See the next sections for details.
+
+By default, the script runs the specified tuners against the specified benchmark one by one. To run all the experiments simultaneously in the background, set the "serialize" flag to false in ``runbenchmark_nni.sh``. 
+
+Note: the SMAC tuner and the BOHB advisor have to be manually installed before any experiments can be run on them. Please refer to `this page <https://nni.readthedocs.io/en/stable/Tuner/BuiltinTuner.html?highlight=nni>`_ for more details on installing SMAC and BOHB.
+
+Run customized benchmarks on existing tuners
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+To run customized benchmarks, add a benchmark_name.yaml file in the folder ``./nni/benchmarks``\ , and change the ``benchmark`` variable in ``runbenchmark_nni.sh``. See ``./automlbenchmark/resources/benchmarks/`` for some examples of defining a custom benchmark.
+
+Run benchmarks on custom tuners
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+To use custom tuners, first make sure that the tuner inherits from ``nni.tuner.Tuner`` and correctly implements the required APIs. For more information on implementing a custom tuner, please refer to `here <https://nni.readthedocs.io/en/stable/Tuner/CustomizeTuner.html>`_\ ; a minimal skeleton is also sketched after the steps below. Next, perform the following steps:
+
+
+#. Install the custom tuner with command ``nnictl algo register``. Check `this document <https://nni.readthedocs.io/en/stable/Tutorial/Nnictl.html>`_ for details. 
+#. In ``./nni/frameworks.yaml``\ , add a new framework extending the base framework NNI. Make sure that the parameter ``tuner_type`` corresponds to the "builtinName" of the tuner installed in step 1.
+#. Run the following command
+
+.. code-block:: bash
+
+      ./runbenchmark_nni.sh new-tuner-builtinName
+
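+For reference, a minimal custom tuner skeleton implementing the core ``Tuner`` APIs might look like this (a sketch; ``MyTuner`` and the method bodies are placeholders):
+
+.. code-block:: python
+
+   from nni.tuner import Tuner
+
+   class MyTuner(Tuner):
+       def update_search_space(self, search_space):
+           # called when the search space is set or updated
+           self.search_space = search_space
+
+       def generate_parameters(self, parameter_id, **kwargs):
+           # return the next hyperparameter configuration to evaluate
+           return {}
+
+       def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
+           # record the metric achieved by ``parameters``
+           pass
+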
+A Benchmark Example 
+^^^^^^^^^^^^^^^^^^^
+
+As an example, we ran the "nnismall" benchmark on the following 8 tuners: "TPE", "Random", "Anneal", "Evolution", "SMAC", "GPTuner", "MetisTuner", "DNGOTuner". (The DNGOTuner was not available as a built-in tuner at the time of writing.) As some of the tasks contain a considerable amount of training data, it took about 2 days to run the whole benchmark on one tuner using a single CPU core. For a more detailed description of the tasks, please check ``/examples/trials/benchmarking/automlbenchmark/nni/benchmarks/nnismall_description.txt``.
+
+After the script finishes, the final scores of each tuner are summarized in the file ``results[time]/reports/performances.txt``. Since the file is large, we only show the following screenshot and summarize other important statistics instead.
+
+.. image:: ../img/hpo_benchmark/performances.png
+   :target: ../img/hpo_benchmark/performances.png
+   :alt: 
+
+When the results are parsed, the tuners are ranked based on their final performance. ``results[time]/reports/rankings.txt`` presents the average ranking of the tuners for each metric (logloss, rmse, auc), as well as the rankings grouped by tuner (another view of the same data).
+
+Average rankings for metric rmse:
+
+.. list-table::
+   :header-rows: 1
+
+   * - Tuner Name
+     - Average Ranking
+   * - Anneal
+     - 3.75
+   * - Random
+     - 4.00
+   * - Evolution
+     - 4.44
+   * - DNGOTuner
+     - 4.44
+   * - SMAC
+     - 4.56
+   * - TPE
+     - 4.94
+   * - GPTuner
+     - 4.94
+   * - MetisTuner
+     - 4.94
+
+Average rankings for metric auc:
+
+.. list-table::
+   :header-rows: 1
+
+   * - Tuner Name
+     - Average Ranking
+   * - SMAC
+     - 3.67
+   * - GPTuner
+     - 4.00
+   * - Evolution
+     - 4.22
+   * - Anneal
+     - 4.39
+   * - MetisTuner
+     - 4.39
+   * - TPE
+     - 4.67
+   * - Random
+     - 5.33
+   * - DNGOTuner
+     - 5.33
+
+Average rankings for metric logloss:
+
+.. list-table::
+   :header-rows: 1
+
+   * - Tuner Name
+     - Average Ranking
+   * - Random
+     - 3.36
+   * - DNGOTuner
+     - 3.50
+   * - SMAC
+     - 3.93
+   * - GPTuner
+     - 4.64
+   * - TPE
+     - 4.71
+   * - Anneal
+     - 4.93
+   * - Evolution
+     - 5.00
+   * - MetisTuner
+     - 5.93
+
+Average rankings for tuners:
+
+.. list-table::
+   :header-rows: 1
+
+   * - Tuner Name
+     - rmse
+     - auc
+     - logloss
+   * - TPE
+     - 4.94
+     - 4.67
+     - 4.71
+   * - Random
+     - 4.00
+     - 5.33
+     - 3.36
+   * - Anneal
+     - 3.75
+     - 4.39
+     - 4.93
+   * - Evolution
+     - 4.44
+     - 4.22
+     - 5.00
+   * - GPTuner
+     - 4.94
+     - 4.00
+     - 4.64
+   * - MetisTuner
+     - 4.94
+     - 4.39
+     - 5.93
+   * - SMAC
+     - 4.56
+     - 3.67
+     - 3.93
+   * - DNGOTuner
+     - 4.44
+     - 5.33
+     - 3.50
+
+Besides these reports, our script also generates two graphs for each fold of each task. The first graph presents the best score seen by each tuner until trial x, and the second graph shows the score of each tuner in trial x. These two graphs can give some information regarding how the tuners are "converging". We found that for "nnismall", tuners on the random forest model with the search space defined in ``/examples/trials/benchmarking/automlbenchmark/nni/extensions/NNI/architectures/run_random_forest.py`` generally converge to the final solution after 40 to 60 trials. As there are too many graphs to include in a single report (96 graphs in total), we only present 10 of them here.
+
+.. image:: ../img/hpo_benchmark/car_fold1_1.jpg
+   :target: ../img/hpo_benchmark/car_fold1_1.jpg
+   :alt: 
+
+
+.. image:: ../img/hpo_benchmark/car_fold1_2.jpg
+   :target: ../img/hpo_benchmark/car_fold1_2.jpg
+   :alt: 
+
+
+.. image:: ../img/hpo_benchmark/christine_fold0_1.jpg
+   :target: ../img/hpo_benchmark/christine_fold0_1.jpg
+   :alt: 
+
+
+.. image:: ../img/hpo_benchmark/christine_fold0_2.jpg
+   :target: ../img/hpo_benchmark/christine_fold0_2.jpg
+   :alt: 
+
+
+.. image:: ../img/hpo_benchmark/cnae-9_fold0_1.jpg
+   :target: ../img/hpo_benchmark/cnae-9_fold0_1.jpg
+   :alt: 
+
+
+.. image:: ../img/hpo_benchmark/cnae-9_fold0_2.jpg
+   :target: ../img/hpo_benchmark/cnae-9_fold0_2.jpg
+   :alt: 
+
+
+.. image:: ../img/hpo_benchmark/credit-g_fold1_1.jpg
+   :target: ../img/hpo_benchmark/credit-g_fold1_1.jpg
+   :alt: 
+
+
+.. image:: ../img/hpo_benchmark/credit-g_fold1_2.jpg
+   :target: ../img/hpo_benchmark/credit-g_fold1_2.jpg
+   :alt: 
+
+
+.. image:: ../img/hpo_benchmark/titanic_2_fold1_1.jpg
+   :target: ../img/hpo_benchmark/titanic_2_fold1_1.jpg
+   :alt: 
+
+
+.. image:: ../img/hpo_benchmark/titanic_2_fold1_2.jpg
+   :target: ../img/hpo_benchmark/titanic_2_fold1_2.jpg
+   :alt: 
+
diff --git a/docs/en_US/hyperparameter_tune.rst b/docs/en_US/hyperparameter_tune.rst
index 69078abf71..3f37b10c60 100644
--- a/docs/en_US/hyperparameter_tune.rst
+++ b/docs/en_US/hyperparameter_tune.rst
@@ -24,4 +24,5 @@ according to their needs.
     Examples <examples>
     WebUI <Tutorial/WebUI>
     How to Debug <Tutorial/HowToDebug>
-    Advanced <hpo_advanced>
\ No newline at end of file
+    Advanced <hpo_advanced>
+    Benchmark for Tuners <hpo_benchmark>
diff --git a/docs/img/compression_flow.jpg b/docs/img/compression_flow.jpg
new file mode 100644
index 0000000000..18c6a0d22e
Binary files /dev/null and b/docs/img/compression_flow.jpg differ
diff --git a/docs/img/hpo_benchmark/car_fold1_1.jpg b/docs/img/hpo_benchmark/car_fold1_1.jpg
new file mode 100644
index 0000000000..db30b6252a
Binary files /dev/null and b/docs/img/hpo_benchmark/car_fold1_1.jpg differ
diff --git a/docs/img/hpo_benchmark/car_fold1_2.jpg b/docs/img/hpo_benchmark/car_fold1_2.jpg
new file mode 100644
index 0000000000..16701e3667
Binary files /dev/null and b/docs/img/hpo_benchmark/car_fold1_2.jpg differ
diff --git a/docs/img/hpo_benchmark/christine_fold0_1.jpg b/docs/img/hpo_benchmark/christine_fold0_1.jpg
new file mode 100644
index 0000000000..eb2549c63f
Binary files /dev/null and b/docs/img/hpo_benchmark/christine_fold0_1.jpg differ
diff --git a/docs/img/hpo_benchmark/christine_fold0_2.jpg b/docs/img/hpo_benchmark/christine_fold0_2.jpg
new file mode 100644
index 0000000000..104019c9a3
Binary files /dev/null and b/docs/img/hpo_benchmark/christine_fold0_2.jpg differ
diff --git a/docs/img/hpo_benchmark/cnae-9_fold0_1.jpg b/docs/img/hpo_benchmark/cnae-9_fold0_1.jpg
new file mode 100644
index 0000000000..01e0b7137e
Binary files /dev/null and b/docs/img/hpo_benchmark/cnae-9_fold0_1.jpg differ
diff --git a/docs/img/hpo_benchmark/cnae-9_fold0_2.jpg b/docs/img/hpo_benchmark/cnae-9_fold0_2.jpg
new file mode 100644
index 0000000000..db8c2e1a16
Binary files /dev/null and b/docs/img/hpo_benchmark/cnae-9_fold0_2.jpg differ
diff --git a/docs/img/hpo_benchmark/credit-g_fold1_1.jpg b/docs/img/hpo_benchmark/credit-g_fold1_1.jpg
new file mode 100644
index 0000000000..a71ecdcfe1
Binary files /dev/null and b/docs/img/hpo_benchmark/credit-g_fold1_1.jpg differ
diff --git a/docs/img/hpo_benchmark/credit-g_fold1_2.jpg b/docs/img/hpo_benchmark/credit-g_fold1_2.jpg
new file mode 100644
index 0000000000..4bb1982290
Binary files /dev/null and b/docs/img/hpo_benchmark/credit-g_fold1_2.jpg differ
diff --git a/docs/img/hpo_benchmark/performances.png b/docs/img/hpo_benchmark/performances.png
new file mode 100644
index 0000000000..3a3d75653c
Binary files /dev/null and b/docs/img/hpo_benchmark/performances.png differ
diff --git a/docs/img/hpo_benchmark/titanic_2_fold1_1.jpg b/docs/img/hpo_benchmark/titanic_2_fold1_1.jpg
new file mode 100644
index 0000000000..97124b3874
Binary files /dev/null and b/docs/img/hpo_benchmark/titanic_2_fold1_1.jpg differ
diff --git a/docs/img/hpo_benchmark/titanic_2_fold1_2.jpg b/docs/img/hpo_benchmark/titanic_2_fold1_2.jpg
new file mode 100644
index 0000000000..65f29cb11b
Binary files /dev/null and b/docs/img/hpo_benchmark/titanic_2_fold1_2.jpg differ
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 04c0633cba..47c66e982d 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -22,5 +22,7 @@ prettytable
 psutil
 ruamel.yaml
 ipython
+gym
+tianshou
 https://download.pytorch.org/whl/cpu/torch-1.7.1%2Bcpu-cp37-cp37m-linux_x86_64.whl
 https://download.pytorch.org/whl/cpu/torchvision-0.8.2%2Bcpu-cp37-cp37m-linux_x86_64.whl
diff --git a/examples/model_compress/.gitignore b/examples/model_compress/.gitignore
new file mode 100644
index 0000000000..c2e41e6b0e
--- /dev/null
+++ b/examples/model_compress/.gitignore
@@ -0,0 +1,6 @@
+*.pth
+*.tar.gz
+data/
+MNIST/
+cifar-10-batches-py/
+experiment_data/
\ No newline at end of file
diff --git a/examples/model_compress/end2end_compression.py b/examples/model_compress/end2end_compression.py
new file mode 100644
index 0000000000..062d6351d6
--- /dev/null
+++ b/examples/model_compress/end2end_compression.py
@@ -0,0 +1,300 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""
+NNI example for combined pruning and quantization to compress a model.
+In this example, we show a compression pipeline that first prunes a model and then quantizes the pruned model.
+
+"""
+import argparse
+import os
+import time
+import torch
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.optim.lr_scheduler import StepLR
+from torchvision import datasets, transforms
+
+from nni.compression.pytorch.utils.counter import count_flops_params
+from nni.compression.pytorch import ModelSpeedup
+
+from nni.algorithms.compression.pytorch.pruning import L1FilterPruner
+from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer
+
+from models.mnist.naive import NaiveModel
+from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT
+
+
+def get_model_time_cost(model, dummy_input):
+    model.eval()
+    n_times = 100
+    time_list = []
+    for _ in range(n_times):
+        torch.cuda.synchronize()
+        tic = time.time()
+        _ = model(dummy_input)
+        torch.cuda.synchronize()
+        time_list.append(time.time()-tic)
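+    # drop the first 10 measurements to exclude warm-up overhead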
+    time_list = time_list[10:]
+    return sum(time_list) / len(time_list)
+
+
+def train(args, model, device, train_loader, criterion, optimizer, epoch):
+    model.train()
+    for batch_idx, (data, target) in enumerate(train_loader):
+        data, target = data.to(device), target.to(device)
+        optimizer.zero_grad()
+        output = model(data)
+        loss = criterion(output, target)
+        loss.backward()
+
+        optimizer.step()
+        if batch_idx % args.log_interval == 0:
+            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
+                epoch, batch_idx * len(data), len(train_loader.dataset),
+                100. * batch_idx / len(train_loader), loss.item()))
+            if args.dry_run:
+                break
+
+
+def test(args, model, device, criterion, test_loader):
+    model.eval()
+    test_loss = 0
+    correct = 0
+    with torch.no_grad():
+        for data, target in test_loader:
+            data, target = data.to(device), target.to(device)
+            output = model(data)
+            test_loss += criterion(output, target).item()
+            pred = output.argmax(dim=1, keepdim=True)
+            correct += pred.eq(target.view_as(pred)).sum().item()
+    test_loss /= len(test_loader.dataset)
+    acc = 100 * correct / len(test_loader.dataset)
+
+    print('Test Loss: {:.6f}  Accuracy: {}%\n'.format(
+        test_loss, acc))
+    return acc
+
+def test_trt(engine, test_loader):
+    test_loss = 0
+    correct = 0
+    time_elapsed = 0
+    for data, target in test_loader:
+        output, infer_time = engine.inference(data)
+        test_loss += F.nll_loss(output, target, reduction='sum').item()
+        pred = output.argmax(dim=1, keepdim=True)
+        correct += pred.eq(target.view_as(pred)).sum().item()
+        time_elapsed += infer_time
+    test_loss /= len(test_loader.dataset)
+
+    print('Loss: {}  Accuracy: {}%'.format(
+        test_loss, 100 * correct / len(test_loader.dataset)))
+    print("Inference elapsed_time (whole dataset): {}s".format(time_elapsed))
+
+def main(args):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    os.makedirs(args.experiment_data_dir, exist_ok=True)
+
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.1307,), (0.3081,))
+        ])
+
+    train_loader = torch.utils.data.DataLoader(
+        datasets.MNIST('data', train=True, download=True, transform=transform),
+        batch_size=64,)
+    test_loader = torch.utils.data.DataLoader(
+        datasets.MNIST('data', train=False, transform=transform),
+        batch_size=1000)
+
+    # Step1. Model Pretraining
+    model = NaiveModel().to(device)
+    criterion = torch.nn.NLLLoss()
+    optimizer = optim.Adadelta(model.parameters(), lr=args.pretrain_lr)
+    scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
+    flops, params, _ = count_flops_params(model, (1, 1, 28, 28), verbose=False)
+
+    if args.pretrained_model_dir is None:
+        args.pretrained_model_dir = os.path.join(args.experiment_data_dir, f'pretrained.pth')
+
+        best_acc = 0
+        for epoch in range(args.pretrain_epochs):
+            train(args, model, device, train_loader, criterion, optimizer, epoch)
+            scheduler.step()
+            acc = test(args, model, device, criterion, test_loader)
+            if acc > best_acc:
+                best_acc = acc
+                state_dict = model.state_dict()
+
+        model.load_state_dict(state_dict)
+        torch.save(state_dict, args.pretrained_model_dir)
+        print(f'Model saved to {args.pretrained_model_dir}')
+    else:
+        state_dict = torch.load(args.pretrained_model_dir)
+        model.load_state_dict(state_dict)
+        best_acc = test(args, model, device, criterion, test_loader)
+
+    dummy_input = torch.randn([1000, 1, 28, 28]).to(device)
+    time_cost = get_model_time_cost(model, dummy_input)
+
+    # 125.49 M, 0.85M, 93.29, 1.1012
+    print(f'Pretrained model FLOPs {flops/1e6:.2f} M, #Params: {params/1e6:.2f}M, Accuracy: {best_acc: .2f}, Time Cost: {time_cost}')
+
+    # Step2. Model Pruning
+    config_list = [{
+        'sparsity': args.sparsity,
+        'op_types': ['Conv2d']
+    }]
+
+    kw_args = {}
+    if args.dependency_aware:
+        dummy_input = torch.randn([1000, 1, 28, 28]).to(device)
+        print('Enable the dependency_aware mode')
+        # note that, not all pruners support the dependency_aware mode
+        kw_args['dependency_aware'] = True
+        kw_args['dummy_input'] = dummy_input
+
+    pruner = L1FilterPruner(model, config_list, **kw_args)
+    model = pruner.compress()
+    pruner.get_pruned_weights()
+
+    mask_path = os.path.join(args.experiment_data_dir, 'mask.pth')
+    model_path = os.path.join(args.experiment_data_dir, 'pruned.pth')
+    pruner.export_model(model_path=model_path, mask_path=mask_path)
+    pruner._unwrap_model()  # unwrap all modules to normal state
+
+    # Step3. Model Speedup
+    m_speedup = ModelSpeedup(model, dummy_input, mask_path, device)
+    m_speedup.speedup_model()
+    print('model after speedup', model)
+
+    flops, params, _ = count_flops_params(model, dummy_input, verbose=False)
+    acc = test(args, model, device, criterion, test_loader)
+    time_cost = get_model_time_cost(model, dummy_input)
+    print(f'Pruned model FLOPs {flops/1e6:.2f} M, #Params: {params/1e6:.2f}M, Accuracy: {acc: .2f}, Time Cost: {time_cost}')
+
+    # Step4. Model Finetuning
+    optimizer = optim.Adadelta(model.parameters(), lr=args.pretrain_lr)
+    scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
+
+    best_acc = 0
+    for epoch in range(args.finetune_epochs):
+        train(args, model, device, train_loader, criterion, optimizer, epoch)
+        scheduler.step()
+        acc = test(args, model, device, criterion, test_loader)
+        if acc > best_acc:
+            best_acc = acc
+            state_dict = model.state_dict()
+
+    model.load_state_dict(state_dict)
+    save_path = os.path.join(args.experiment_data_dir, f'finetuned.pth')
+    torch.save(state_dict, save_path)
+
+    flops, params, _ = count_flops_params(model, dummy_input, verbose=True)
+    time_cost = get_model_time_cost(model, dummy_input)
+
+    # FLOPs 28.48 M, #Params: 0.18M, Accuracy:  89.03, Time Cost: 1.03
+    print(f'Finetuned model FLOPs {flops/1e6:.2f} M, #Params: {params/1e6:.2f}M, Accuracy: {best_acc: .2f}, Time Cost: {time_cost}')
+    print(f'Model saved to {save_path}')
+
+    # Step5. Model Quantization via QAT
+    config_list = [{
+        'quant_types': ['weight', 'output'],
+        'quant_bits': {'weight': 8, 'output': 8},
+        'op_names': ['conv1']
+    }, {
+        'quant_types': ['output'],
+        'quant_bits': {'output':8},
+        'op_names': ['relu1']
+    }, {
+        'quant_types': ['weight', 'output'],
+        'quant_bits': {'weight': 8, 'output': 8},
+        'op_names': ['conv2']
+    }, {
+        'quant_types': ['output'],
+        'quant_bits': {'output': 8},
+        'op_names': ['relu2']
+    }]
+
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
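+    # QAT simulates quantization during training, so the quantizer also takes the training optimizer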
+    quantizer = QAT_Quantizer(model, config_list, optimizer)
+    quantizer.compress()
+
+    # Step6. Quantization Aware Training
+    best_acc = 0
+    for epoch in range(1):
+        train(args, model, device, train_loader, criterion, optimizer, epoch)
+        scheduler.step()
+        acc = test(args, model, device, criterion, test_loader)
+        if acc > best_acc:
+            best_acc = acc
+            state_dict = model.state_dict()
+
+    calibration_path = os.path.join(args.experiment_data_dir, 'calibration.pth')
+    calibration_config = quantizer.export_model(model_path, calibration_path)
+    print("calibration_config: ", calibration_config)
+
+    # Step7. Model Speedup
+    batch_size = 32
+    input_shape = (batch_size, 1, 28, 28)
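+    # build a TensorRT engine using the calibration config exported by the quantizer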
+    engine = ModelSpeedupTensorRT(model, input_shape, config=calibration_config, batchsize=batch_size)
+    engine.compress()
+
+    test_trt(engine, test_loader)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='PyTorch Example for model compression')
+
+    # dataset and model
+    # parser.add_argument('--dataset', type=str, default='mnist',
+    #                     help='dataset to use, mnist, cifar10 or imagenet')
+    # parser.add_argument('--data-dir', type=str, default='./data/',
+    #                     help='dataset directory')
+    parser.add_argument('--pretrained-model-dir', type=str, default=None,
+                        help='path to pretrained model')
+    parser.add_argument('--pretrain-epochs', type=int, default=10,
+                        help='number of epochs to pretrain the model')
+    parser.add_argument('--pretrain-lr', type=float, default=1.0,
+                        help='learning rate to pretrain the model')
+
+    parser.add_argument('--experiment-data-dir', type=str, default='./experiment_data',
+                        help='For saving output checkpoints')
+    parser.add_argument('--log-interval', type=int, default=100, metavar='N',
+                        help='how many batches to wait before logging training status')
+    parser.add_argument('--dry-run', action='store_true', default=False,
+                        help='quickly check a single pass')
+    # parser.add_argument('--multi-gpu', action='store_true', default=False,
+    #                     help='run on mulitple gpus')
+    # parser.add_argument('--test-only', action='store_true', default=False,
+    #                     help='run test only')
+
+    # pruner
+    # parser.add_argument('--pruner', type=str, default='l1filter',
+    #                     choices=['level', 'l1filter', 'l2filter', 'slim', 'agp',
+    #                              'fpgm', 'mean_activation', 'apoz', 'admm'],
+    #                     help='pruner to use')
+    parser.add_argument('--sparsity', type=float, default=0.5,
+                        help='overall target sparsity')
+    parser.add_argument('--dependency-aware', action='store_true', default=False,
+                        help='toggle dependency aware mode')
+
+    # finetuning
+    parser.add_argument('--finetune-epochs', type=int, default=5,
+                        help='epochs to fine tune')
+    # parser.add_argument('--kd', action='store_true', default=False,
+    #                     help='quickly check a single pass')
+    # parser.add_argument('--kd_T', type=float, default=4,
+    #                     help='temperature for KD distillation')
+    # parser.add_argument('--finetune-lr', type=float, default=0.5,
+    #                     help='learning rate to finetune the model')
+
+    # speedup
+    # parser.add_argument('--speed-up', action='store_true', default=False,
+    #                     help='whether to speed-up the pruned model')
+
+    # parser.add_argument('--nni', action='store_true', default=False,
+    #                     help="whether to tune the pruners using NNi tuners")
+
+    args = parser.parse_args()
+    main(args)
diff --git a/examples/model_compress/pruning/models/cifar10/resnet.py b/examples/model_compress/models/cifar10/resnet.py
similarity index 100%
rename from examples/model_compress/pruning/models/cifar10/resnet.py
rename to examples/model_compress/models/cifar10/resnet.py
diff --git a/examples/model_compress/pruning/models/cifar10/vgg.py b/examples/model_compress/models/cifar10/vgg.py
similarity index 100%
rename from examples/model_compress/pruning/models/cifar10/vgg.py
rename to examples/model_compress/models/cifar10/vgg.py
diff --git a/examples/model_compress/pruning/models/mnist/lenet.py b/examples/model_compress/models/mnist/lenet.py
similarity index 100%
rename from examples/model_compress/pruning/models/mnist/lenet.py
rename to examples/model_compress/models/mnist/lenet.py
diff --git a/examples/model_compress/models/mnist/naive.py b/examples/model_compress/models/mnist/naive.py
new file mode 100644
index 0000000000..4609862527
--- /dev/null
+++ b/examples/model_compress/models/mnist/naive.py
@@ -0,0 +1,27 @@
+import torch
+import torch.nn.functional as F
+
+class NaiveModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
+        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
+        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
+        self.fc2 = torch.nn.Linear(500, 10)
+        self.relu1 = torch.nn.ReLU6()
+        self.relu2 = torch.nn.ReLU6()
+        self.relu3 = torch.nn.ReLU6()
+        self.max_pool1 = torch.nn.MaxPool2d(2, 2)
+        self.max_pool2 = torch.nn.MaxPool2d(2, 2)
+
+    def forward(self, x):
+        x = self.relu1(self.conv1(x))
+        x = self.max_pool1(x)
+        x = self.relu2(self.conv2(x))
+        x = self.max_pool2(x)
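+        # flatten all dimensions except the batch dimension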
+        x = x.view(-1, x.size()[1:].numel())
+        x = self.relu3(self.fc1(x))
+        x = self.fc2(x)
+        return F.log_softmax(x, dim=1)
\ No newline at end of file
diff --git a/examples/model_compress/pruning/models/mobilenet.py b/examples/model_compress/models/mobilenet.py
similarity index 100%
rename from examples/model_compress/pruning/models/mobilenet.py
rename to examples/model_compress/models/mobilenet.py
diff --git a/examples/model_compress/pruning/models/mobilenet_v2.py b/examples/model_compress/models/mobilenet_v2.py
similarity index 100%
rename from examples/model_compress/pruning/models/mobilenet_v2.py
rename to examples/model_compress/models/mobilenet_v2.py
diff --git a/examples/model_compress/pruning/amc/amc_search.py b/examples/model_compress/pruning/amc/amc_search.py
index 6e10f554b9..5c861a8887 100644
--- a/examples/model_compress/pruning/amc/amc_search.py
+++ b/examples/model_compress/pruning/amc/amc_search.py
@@ -12,7 +12,7 @@
 from data import get_split_dataset
 from utils import AverageMeter, accuracy
 
-sys.path.append('../models')
+sys.path.append('../../models')
 
 def parse_args():
     parser = argparse.ArgumentParser(description='AMC search script')
diff --git a/examples/model_compress/pruning/amc/amc_train.py b/examples/model_compress/pruning/amc/amc_train.py
index 732d3bbae9..eb02c7020a 100644
--- a/examples/model_compress/pruning/amc/amc_train.py
+++ b/examples/model_compress/pruning/amc/amc_train.py
@@ -22,7 +22,7 @@
 from data import get_dataset
 from utils import AverageMeter, accuracy, progress_bar
 
-sys.path.append('../models')
+sys.path.append('../../models')
 from mobilenet import MobileNet
 from mobilenet_v2 import MobileNetV2
 
diff --git a/examples/model_compress/pruning/auto_pruners_torch.py b/examples/model_compress/pruning/auto_pruners_torch.py
index d9e0f53824..f32faccfa8 100644
--- a/examples/model_compress/pruning/auto_pruners_torch.py
+++ b/examples/model_compress/pruning/auto_pruners_torch.py
@@ -13,14 +13,16 @@
 from torch.optim.lr_scheduler import StepLR, MultiStepLR
 from torchvision import datasets, transforms
 
-from models.mnist.lenet import LeNet
-from models.cifar10.vgg import VGG
-from models.cifar10.resnet import ResNet18, ResNet50
 from nni.algorithms.compression.pytorch.pruning import L1FilterPruner, L2FilterPruner, FPGMPruner
 from nni.algorithms.compression.pytorch.pruning import SimulatedAnnealingPruner, ADMMPruner, NetAdaptPruner, AutoCompressPruner
 from nni.compression.pytorch import ModelSpeedup
 from nni.compression.pytorch.utils.counter import count_flops_params
 
+import sys
+sys.path.append('../models')
+from mnist.lenet import LeNet
+from cifar10.vgg import VGG
+from cifar10.resnet import ResNet18, ResNet50
 
 def get_data(dataset, data_dir, batch_size, test_batch_size):
     '''
@@ -67,7 +69,7 @@ def get_data(dataset, data_dir, batch_size, test_batch_size):
     return train_loader, val_loader, criterion
 
 
-def train(args, model, device, train_loader, criterion, optimizer, epoch, callback=None):
+def train(args, model, device, train_loader, criterion, optimizer, epoch):
     model.train()
     for batch_idx, (data, target) in enumerate(train_loader):
         data, target = data.to(device), target.to(device)
@@ -75,9 +77,6 @@ def train(args, model, device, train_loader, criterion, optimizer, epoch, callba
         output = model(data)
         loss = criterion(output, target)
         loss.backward()
-        # callback should be inserted between loss.backward() and optimizer.step()
-        if callback:
-            callback()
         optimizer.step()
         if batch_idx % args.log_interval == 0:
             print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
@@ -198,8 +197,8 @@ def short_term_fine_tuner(model, epochs=1):
         for epoch in range(epochs):
             train(args, model, device, train_loader, criterion, optimizer, epoch)
 
-    def trainer(model, optimizer, criterion, epoch, callback):
-        return train(args, model, device, train_loader, criterion, optimizer, epoch=epoch, callback=callback)
+    def trainer(model, optimizer, criterion, epoch):
+        return train(args, model, device, train_loader, criterion, optimizer, epoch=epoch)
 
     def evaluator(model):
         return test(model, device, criterion, val_loader)
@@ -264,7 +263,7 @@ def evaluator(model):
                 }]
         else:
             raise ValueError('Example only implemented for LeNet.')
-        pruner = ADMMPruner(model, config_list, trainer=trainer, num_iterations=2, training_epochs=2)
+        pruner = ADMMPruner(model, config_list, trainer=trainer, num_iterations=2, epochs_per_iteration=2)
     elif args.pruner == 'SimulatedAnnealingPruner':
         pruner = SimulatedAnnealingPruner(
             model, config_list, evaluator=evaluator, base_algo=args.base_algo,
@@ -273,7 +272,7 @@ def evaluator(model):
         pruner = AutoCompressPruner(
             model, config_list, trainer=trainer, evaluator=evaluator, dummy_input=dummy_input,
             num_iterations=3, optimize_mode='maximize', base_algo=args.base_algo,
-            cool_down_rate=args.cool_down_rate, admm_num_iterations=30, admm_training_epochs=5,
+            cool_down_rate=args.cool_down_rate, admm_num_iterations=30, admm_epochs_per_iteration=5,
             experiment_data_dir=args.experiment_data_dir)
     else:
         raise ValueError(
diff --git a/examples/model_compress/pruning/basic_pruners_torch.py b/examples/model_compress/pruning/basic_pruners_torch.py
index c3225353f4..51c2068aa3 100644
--- a/examples/model_compress/pruning/basic_pruners_torch.py
+++ b/examples/model_compress/pruning/basic_pruners_torch.py
@@ -12,25 +12,24 @@
 
 import argparse
 import os
-import time
+import sys
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.optim as optim
 from torch.optim.lr_scheduler import StepLR, MultiStepLR
 from torchvision import datasets, transforms
 
-from models.mnist.lenet import LeNet
-from models.cifar10.vgg import VGG
+sys.path.append('../models')
+from mnist.lenet import LeNet
+from cifar10.vgg import VGG
 
 from nni.compression.pytorch.utils.counter import count_flops_params
 
 import nni
-from nni.compression.pytorch import apply_compression_results, ModelSpeedup
+from nni.compression.pytorch import ModelSpeedup
 from nni.algorithms.compression.pytorch.pruning import (
     LevelPruner,
     SlimPruner,
     FPGMPruner,
+    TaylorFOWeightFilterPruner,
     L1FilterPruner,
     L2FilterPruner,
     AGPPruner,
@@ -38,7 +37,6 @@
     ActivationAPoZRankFilterPruner
 )
 
-
 _logger = logging.getLogger('mnist_example')
 _logger.setLevel(logging.INFO)
 
@@ -50,7 +48,8 @@
     'agp': AGPPruner,
     'fpgm': FPGMPruner,
     'mean_activation': ActivationMeanRankFilterPruner,
-    'apoz': ActivationAPoZRankFilterPruner
+    'apoz': ActivationAPoZRankFilterPruner,
+    'taylorfo': TaylorFOWeightFilterPruner
 }
 
 def get_dummy_input(args, device):
@@ -60,53 +59,6 @@ def get_dummy_input(args, device):
         dummy_input = torch.randn([args.test_batch_size, 3, 32, 32]).to(device)
     return dummy_input
 
-def get_pruner(model, pruner_name, device, optimizer=None, dependency_aware=False):
-
-    pruner_cls = str2pruner[pruner_name]
-
-    if pruner_name == 'level':
-        config_list = [{
-            'sparsity': args.sparsity,
-            'op_types': ['default']
-        }]
-    elif pruner_name in ['l1filter', 'mean_activation', 'apoz']:
-        # Reproduced result in paper 'PRUNING FILTERS FOR EFFICIENT CONVNETS',
-        # Conv_1, Conv_8, Conv_9, Conv_10, Conv_11, Conv_12 are pruned with 50% sparsity, as 'VGG-16-pruned-A'
-        config_list = [{
-            'sparsity': args.sparsity,
-            'op_types': ['Conv2d'],
-            'op_names': ['feature.0', 'feature.24', 'feature.27', 'feature.30', 'feature.34', 'feature.37']
-        }]
-    elif pruner_name == 'slim':
-        config_list = [{
-            'sparsity': args.sparsity,
-            'op_types': ['BatchNorm2d'],
-        }]
-    elif pruner_name == 'agp':
-        config_list = [{
-            'initial_sparsity': 0.,
-            'final_sparsity': 0.8,
-            'start_epoch': 0,
-            'end_epoch': 10,
-            'frequency': 1,
-            'op_types': ['Conv2d']
-        }]
-    else:
-        config_list = [{
-            'sparsity': args.sparsity,
-            'op_types': ['Conv2d']
-        }]
-
-    kw_args = {}
-    if dependency_aware:
-        dummy_input = get_dummy_input(args, device)
-        print('Enable the dependency_aware mode')
-        # note that, not all pruners support the dependency_aware mode
-        kw_args['dependency_aware'] = True
-        kw_args['dummy_input'] = dummy_input
-
-    pruner = pruner_cls(model, config_list, optimizer, **kw_args)
-    return pruner
 
 def get_data(dataset, data_dir, batch_size, test_batch_size):
     kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {
@@ -174,7 +126,7 @@ def get_model_optimizer_scheduler(args, device, train_loader, test_loader, crite
         print('start pre-training...')
         best_acc = 0
         for epoch in range(args.pretrain_epochs):
-            train(args, model, device, train_loader, criterion, optimizer, epoch, sparse_bn=True if args.pruner == 'slim' else False)
+            train(args, model, device, train_loader, criterion, optimizer, epoch)
             scheduler.step()
             acc = test(args, model, device, criterion, test_loader)
             if acc > best_acc:
@@ -198,12 +150,7 @@ def get_model_optimizer_scheduler(args, device, train_loader, test_loader, crite
     print('Pretrained model acc:', best_acc)
     return model, optimizer, scheduler
 
-def updateBN(model):
-    for m in model.modules():
-        if isinstance(m, nn.BatchNorm2d):
-            m.weight.grad.data.add_(0.0001 * torch.sign(m.weight.data))
-
-def train(args, model, device, train_loader, criterion, optimizer, epoch, sparse_bn=False):
+def train(args, model, device, train_loader, criterion, optimizer, epoch):
     model.train()
     for batch_idx, (data, target) in enumerate(train_loader):
         data, target = data.to(device), target.to(device)
@@ -211,11 +158,6 @@ def train(args, model, device, train_loader, criterion, optimizer, epoch, sparse
         output = model(data)
         loss = criterion(output, target)
         loss.backward()
-
-        if sparse_bn:
-            # L1 regularization on BN layer
-            updateBN(model)
-
         optimizer.step()
         if batch_idx % args.log_interval == 0:
             print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
@@ -256,64 +198,99 @@ def main(args):
     flops, params, results = count_flops_params(model, dummy_input)
     print(f"FLOPs: {flops}, params: {params}")
 
-    print('start pruning...')
+    print(f'start {args.pruner} pruning...')
+
+    def trainer(model, optimizer, criterion, epoch):
+        return train(args, model, device, train_loader, criterion, optimizer, epoch=epoch)
+
+    pruner_cls = str2pruner[args.pruner]
+
+    kw_args = {}
+    config_list = [{
+        'sparsity': args.sparsity,
+        'op_types': ['Conv2d']
+    }]
+
+    if args.pruner == 'level':
+        config_list = [{
+            'sparsity': args.sparsity,
+            'op_types': ['default']
+        }]
+
+    else:
+        if args.dependency_aware:
+            dummy_input = get_dummy_input(args, device)
+            print('Enable the dependency_aware mode')
+            # note that, not all pruners support the dependency_aware mode
+            kw_args['dependency_aware'] = True
+            kw_args['dummy_input'] = dummy_input
+        if args.pruner not in ('l1filter', 'l2filter', 'fpgm'):
+            # these settings only apply to training-aware pruners
+            kw_args['trainer'] = trainer
+            kw_args['optimizer'] = optimizer
+            kw_args['criterion'] = criterion
+
+        if args.pruner in ('slim', 'mean_activation', 'apoz', 'taylorfo'):
+            kw_args['sparsity_training_epochs'] = 5
+
+        if args.pruner == 'agp':
+            kw_args['pruning_algorithm'] = 'l1'
+            kw_args['num_iterations'] = 5
+            kw_args['epochs_per_iteration'] = 1
+
+        # For the filter pruners, reproduce the result in the paper 'PRUNING FILTERS FOR EFFICIENT CONVNETS':
+        # Conv_1, Conv_8, Conv_9, Conv_10, Conv_11, Conv_12 are pruned with 50% sparsity, as 'VGG-16-pruned-A'
+        if args.pruner == 'slim':
+            config_list = [{
+                'sparsity': args.sparsity,
+                'op_types': ['BatchNorm2d'],
+            }]
+        else:
+            config_list = [{
+                'sparsity': args.sparsity,
+                'op_types': ['Conv2d'],
+                'op_names': ['feature.0', 'feature.24', 'feature.27', 'feature.30', 'feature.34', 'feature.37']
+            }]
+
+    pruner = pruner_cls(model, config_list, **kw_args)
+
+    # Pruner.compress() returns the masked model
+    model = pruner.compress()
+    pruner.get_pruned_weights()
+
+    # export the pruned model masks for model speedup
     model_path = os.path.join(args.experiment_data_dir, 'pruned_{}_{}_{}.pth'.format(
         args.model, args.dataset, args.pruner))
     mask_path = os.path.join(args.experiment_data_dir, 'mask_{}_{}_{}.pth'.format(
         args.model, args.dataset, args.pruner))
-
-    pruner = get_pruner(model, args.pruner, device, optimizer, args.dependency_aware)
-    model = pruner.compress()
-
-    if args.multi_gpu and torch.cuda.device_count() > 1:
-        model = nn.DataParallel(model)
+    pruner.export_model(model_path=model_path, mask_path=mask_path)
 
     if args.test_only:
         test(args, model, device, criterion, test_loader)
 
+    # Unwrap all modules to normal state
+    pruner._unwrap_model() 
+    m_speedup = ModelSpeedup(model, dummy_input, mask_path, device)
+    m_speedup.speedup_model()
+
+    print('start finetuning...')
     best_top1 = 0
+    save_path = os.path.join(args.experiment_data_dir, 'finetuned.pth')
     for epoch in range(args.fine_tune_epochs):
-        pruner.update_epoch(epoch)
         print('# Epoch {} #'.format(epoch))
         train(args, model, device, train_loader, criterion, optimizer, epoch)
         scheduler.step()
         top1 = test(args, model, device, criterion, test_loader)
         if top1 > best_top1:
             best_top1 = top1
-            # Export the best model, 'model_path' stores state_dict of the pruned model,
-            # mask_path stores mask_dict of the pruned model
-            pruner.export_model(model_path=model_path, mask_path=mask_path)
+            torch.save(model.state_dict(), save_path)
+
+    flops, params, results = count_flops_params(model, dummy_input)
+    print(f'Finetuned model FLOPs: {flops/1e6:.2f}M, #Params: {params/1e6:.2f}M, Accuracy: {best_top1:.2f}')
 
     if args.nni:
         nni.report_final_result(best_top1)
 
-    if args.speed_up:
-        # reload the best checkpoint for speed-up
-        args.pretrained_model_dir = model_path
-        model, _, _ = get_model_optimizer_scheduler(args, device, train_loader, test_loader, criterion)
-        model.eval()
-
-        apply_compression_results(model, mask_path, device)
-
-        # test model speed
-        start = time.time()
-        for _ in range(32):
-            use_mask_out = model(dummy_input)
-        print('elapsed time when use mask: ', time.time() - start)
-
-        m_speedup = ModelSpeedup(model, dummy_input, mask_path, device)
-        m_speedup.speedup_model()
-
-        flops, params, results = count_flops_params(model, dummy_input)
-        print(f"FLOPs: {flops}, params: {params}")
-
-        start = time.time()
-        for _ in range(32):
-            use_speedup_out = model(dummy_input)
-        print('elapsed time when use speedup: ', time.time() - start)
-
-        top1 = test(args, model, device, criterion, test_loader)
-
 if __name__ == '__main__':
 
     parser = argparse.ArgumentParser(description='PyTorch Example for model compression')
@@ -352,17 +329,13 @@ def main(args):
                         help='toggle dependency aware mode')
     parser.add_argument('--pruner', type=str, default='l1filter',
                         choices=['level', 'l1filter', 'l2filter', 'slim', 'agp',
-                                 'fpgm', 'mean_activation', 'apoz'],
+                                 'fpgm', 'mean_activation', 'apoz', 'taylorfo'],
                         help='pruner to use')
 
     # fine-tuning
     parser.add_argument('--fine-tune-epochs', type=int, default=160,
                         help='epochs to fine tune')
 
-    # speed-up
-    parser.add_argument('--speed-up', action='store_true', default=False,
-                        help='whether to speed-up the pruned model')
-
     parser.add_argument('--nni', action='store_true', default=False,
                         help="whether to tune the pruners using NNi tuners")
 
diff --git a/examples/model_compress/pruning/finetune_kd_torch.py b/examples/model_compress/pruning/finetune_kd_torch.py
index 10fccd3484..68c96b4ba3 100644
--- a/examples/model_compress/pruning/finetune_kd_torch.py
+++ b/examples/model_compress/pruning/finetune_kd_torch.py
@@ -20,8 +20,11 @@
 from torch.optim.lr_scheduler import MultiStepLR, StepLR
 from torchvision import datasets, transforms
 from basic_pruners_torch import get_data
-from models.cifar10.vgg import VGG
-from models.mnist.lenet import LeNet
+
+import sys
+sys.path.append('../models')
+from cifar10.vgg import VGG
+from mnist.lenet import LeNet
 
 class DistillKL(nn.Module):
     """Distilling the Knowledge in a Neural Network"""
diff --git a/examples/model_compress/pruning/lottery_torch_mnist_fc.py b/examples/model_compress/pruning/lottery_torch_mnist_fc.py
index 7a46c79834..215bc5f5f7 100644
--- a/examples/model_compress/pruning/lottery_torch_mnist_fc.py
+++ b/examples/model_compress/pruning/lottery_torch_mnist_fc.py
@@ -20,7 +20,7 @@ class fc1(nn.Module):
     def __init__(self, num_classes=10):
         super(fc1, self).__init__()
         self.classifier = nn.Sequential(
-            nn.Linear(28*28, 300),
+            nn.Linear(28 * 28, 300),
             nn.ReLU(inplace=True),
             nn.Linear(300, 100),
             nn.ReLU(inplace=True),
diff --git a/examples/model_compress/pruning/model_speedup.py b/examples/model_compress/pruning/model_speedup.py
index 48aff8702c..bec053542a 100644
--- a/examples/model_compress/pruning/model_speedup.py
+++ b/examples/model_compress/pruning/model_speedup.py
@@ -5,8 +5,12 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from torchvision import datasets, transforms
-from models.cifar10.vgg import VGG
-from models.mnist.lenet import LeNet
+
+import sys
+sys.path.append('../models')
+from cifar10.vgg import VGG
+from mnist.lenet import LeNet
+
 from nni.compression.pytorch import apply_compression_results, ModelSpeedup
 
 torch.manual_seed(0)
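
model_speedup.py contrasts masked inference (zeroed weights, unchanged tensor shapes) with a genuinely shrunk network. A condensed sketch of that comparison, assuming `model`, `dummy_input`, `device` and a `mask.pth` file as in the example:

```python
import time
from nni.compression.pytorch import apply_compression_results, ModelSpeedup

def bench(model, dummy_input, runs=32):
    start = time.time()
    for _ in range(runs):
        model(dummy_input)
    return time.time() - start

# masked inference: weights are zeroed but tensor shapes are unchanged
apply_compression_results(model, 'mask.pth', device)
print('elapsed time with mask:', bench(model, dummy_input))

# speedup: masked channels are removed and dependent layers are resized
ModelSpeedup(model, dummy_input, 'mask.pth', device).speedup_model()
print('elapsed time with speedup:', bench(model, dummy_input))
```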
diff --git a/examples/model_compress/pruning/naive_prune_torch.py b/examples/model_compress/pruning/naive_prune_torch.py
index 5509db9aa5..88ff3df6d9 100644
--- a/examples/model_compress/pruning/naive_prune_torch.py
+++ b/examples/model_compress/pruning/naive_prune_torch.py
@@ -10,15 +10,16 @@
 
 import argparse
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
 from torchvision import datasets, transforms
 from torch.optim.lr_scheduler import StepLR
-from models.mnist.lenet import LeNet
+
 from nni.algorithms.compression.pytorch.pruning import LevelPruner
 
-import nni
+import sys
+sys.path.append('../models')
+from mnist.lenet import LeNet
 
 _logger = logging.getLogger('mnist_example')
 _logger.setLevel(logging.INFO)
@@ -108,7 +109,7 @@ def main(args):
         'op_types': ['default'],
     }]
 
-    pruner = LevelPruner(model, prune_config, optimizer_finetune)
+    pruner = LevelPruner(model, prune_config)
     model = pruner.compress()
 
     # fine-tuning
@@ -149,5 +150,4 @@ def main(args):
                         help='overall target sparsity')
     args = parser.parse_args()
 
-
-    main(args)
\ No newline at end of file
+    main(args)
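
With the change above, `LevelPruner` no longer receives the fine-tuning optimizer; it only needs the model and a pruning config. A minimal sketch on a toy network (names illustrative):

```python
import torch.nn as nn
from nni.algorithms.compression.pytorch.pruning import LevelPruner

model = nn.Sequential(nn.Linear(784, 300), nn.ReLU(), nn.Linear(300, 10))
prune_config = [{'sparsity': 0.8, 'op_types': ['default']}]

pruner = LevelPruner(model, prune_config)   # no optimizer argument anymore
model = pruner.compress()
# ... fine-tune `model`, then export weights and masks:
pruner.export_model(model_path='model.pth', mask_path='mask.pth')
```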
diff --git a/examples/model_compress/quantization/BNN_quantizer_cifar10.py b/examples/model_compress/quantization/BNN_quantizer_cifar10.py
index 1615a289a4..f6d4c27316 100644
--- a/examples/model_compress/quantization/BNN_quantizer_cifar10.py
+++ b/examples/model_compress/quantization/BNN_quantizer_cifar10.py
@@ -31,7 +31,6 @@ def __init__(self, num_classes=1000):
             nn.BatchNorm2d(256, eps=1e-4, momentum=0.1),
             nn.Hardtanh(inplace=True),
 
-
             nn.Conv2d(256, 512, kernel_size=3, padding=1, bias=False),
             nn.BatchNorm2d(512, eps=1e-4, momentum=0.1),
             nn.Hardtanh(inplace=True),
diff --git a/examples/model_compress/quantization/DoReFaQuantizer_torch_mnist.py b/examples/model_compress/quantization/DoReFaQuantizer_torch_mnist.py
index 18cd059556..10de852570 100644
--- a/examples/model_compress/quantization/DoReFaQuantizer_torch_mnist.py
+++ b/examples/model_compress/quantization/DoReFaQuantizer_torch_mnist.py
@@ -3,27 +3,9 @@
 from torchvision import datasets, transforms
 from nni.algorithms.compression.pytorch.quantization import DoReFaQuantizer
 
-
-class Mnist(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
-        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
-        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
-        self.fc2 = torch.nn.Linear(500, 10)
-        self.relu1 = torch.nn.ReLU6()
-        self.relu2 = torch.nn.ReLU6()
-        self.relu3 = torch.nn.ReLU6()
-
-    def forward(self, x):
-        x = self.relu1(self.conv1(x))
-        x = F.max_pool2d(x, 2, 2)
-        x = self.relu2(self.conv2(x))
-        x = F.max_pool2d(x, 2, 2)
-        x = x.view(-1, 4 * 4 * 50)
-        x = self.relu3(self.fc1(x))
-        x = self.fc2(x)
-        return F.log_softmax(x, dim=1)
+import sys
+sys.path.append('../models')
+from mnist.naive import NaiveModel
 
 
 def train(model, quantizer, device, train_loader, optimizer):
@@ -66,7 +48,7 @@ def main():
         datasets.MNIST('data', train=False, transform=trans),
         batch_size=1000, shuffle=True)
 
-    model = Mnist()
+    model = NaiveModel()
     model = model.to(device)
     configure_list = [{
         'quant_types': ['weight'],
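
The hunk truncates the quantization config, so for reference here is a sketch of a complete `configure_list` and the quantizer call, following the v2-style constructor used by the examples of this era (the bit width and the exact signature are assumptions):

```python
import torch
from nni.algorithms.compression.pytorch.quantization import DoReFaQuantizer

device = torch.device('cpu')
model = NaiveModel().to(device)                 # NaiveModel as imported above
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

configure_list = [{
    'quant_types': ['weight'],
    'quant_bits': {'weight': 8},                # illustrative bit width
    'op_types': ['Conv2d', 'Linear'],
}]
quantizer = DoReFaQuantizer(model, configure_list, optimizer)
quantizer.compress()                            # wraps matched modules with fake quantization
```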
diff --git a/examples/model_compress/quantization/QAT_torch_quantizer.py b/examples/model_compress/quantization/QAT_torch_quantizer.py
index ef14ff5ce0..4ccbe34eb0 100644
--- a/examples/model_compress/quantization/QAT_torch_quantizer.py
+++ b/examples/model_compress/quantization/QAT_torch_quantizer.py
@@ -3,28 +3,9 @@
 from torchvision import datasets, transforms
 from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer
 
-
-class Mnist(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
-        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
-        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
-        self.fc2 = torch.nn.Linear(500, 10)
-        self.relu1 = torch.nn.ReLU6()
-        self.relu2 = torch.nn.ReLU6()
-        self.relu3 = torch.nn.ReLU6()
-
-    def forward(self, x):
-        x = self.relu1(self.conv1(x))
-        x = F.max_pool2d(x, 2, 2)
-        x = self.relu2(self.conv2(x))
-        x = F.max_pool2d(x, 2, 2)
-        x = x.view(-1, 4 * 4 * 50)
-        x = self.relu3(self.fc1(x))
-        x = self.fc2(x)
-        return F.log_softmax(x, dim=1)
-
+import sys
+sys.path.append('../models')
+from mnist.naive import NaiveModel
 
 def train(model, quantizer, device, train_loader, optimizer):
     model.train()
@@ -66,7 +47,7 @@ def main():
         datasets.MNIST('data', train=False, transform=trans),
         batch_size=1000, shuffle=True)
 
-    model = Mnist()
+    model = NaiveModel()
     '''you can change this to DoReFaQuantizer to implement it
     DoReFaQuantizer(configure_list).compress(model)
     '''
diff --git a/examples/model_compress/quantization/mixed_precision_speedup_mnist.py b/examples/model_compress/quantization/mixed_precision_speedup_mnist.py
index bdcdcb7f5f..687fec6a1f 100644
--- a/examples/model_compress/quantization/mixed_precision_speedup_mnist.py
+++ b/examples/model_compress/quantization/mixed_precision_speedup_mnist.py
@@ -5,28 +5,9 @@
 from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer
 from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT
 
-class Mnist(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
-        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
-        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
-        self.fc2 = torch.nn.Linear(500, 10)
-        self.relu1 = torch.nn.ReLU6()
-        self.relu2 = torch.nn.ReLU6()
-        self.relu3 = torch.nn.ReLU6()
-        self.max_pool1 = torch.nn.MaxPool2d(2, 2)
-        self.max_pool2 = torch.nn.MaxPool2d(2, 2)
-
-    def forward(self, x):
-        x = self.relu1(self.conv1(x))
-        x = self.max_pool1(x)
-        x = self.relu2(self.conv2(x))
-        x = self.max_pool2(x)
-        x = x.view(-1, 4 * 4 * 50)
-        x = self.relu3(self.fc1(x))
-        x = self.fc2(x)
-        return F.log_softmax(x, dim=1)
+import sys
+sys.path.append('../models')
+from mnist.naive import NaiveModel
 
 
 def train(model, device, train_loader, optimizer):
@@ -74,7 +55,7 @@ def test_trt(engine, test_loader):
     print("Inference elapsed_time (whole dataset): {}s".format(time_elasped))
 
 def post_training_quantization_example(train_loader, test_loader, device):
-    model = Mnist()
+    model = NaiveModel()
 
     config = {
         'conv1':{'weight_bit':8, 'activation_bit':8},
@@ -99,7 +80,7 @@ def post_training_quantization_example(train_loader, test_loader, device):
     test_trt(engine, test_loader)
 
 def quantization_aware_training_example(train_loader, test_loader, device):
-    model = Mnist()
+    model = NaiveModel()
 
     configure_list = [{
             'quant_types': ['weight', 'output'],
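
After quantization, `ModelSpeedupTensorRT` builds a mixed-precision engine from the per-layer bit config shown above. A sketch of the call pattern, following the NNI quantization-speedup docs of this period (requires a TensorRT installation; `model` and `test_data` are assumed, and the exact signatures should be treated as assumptions):

```python
from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT

# per-layer bit widths, as in the post-training quantization example above
calibration_config = {
    'conv1': {'weight_bit': 8, 'activation_bit': 8},
    'conv2': {'weight_bit': 32, 'activation_bit': 32},
}
engine = ModelSpeedupTensorRT(model, (32, 1, 28, 28), config=calibration_config, batchsize=32)
engine.compress()                                # builds the TensorRT engine
output, latency = engine.inference(test_data)    # returns predictions and elapsed time
```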
diff --git a/examples/trials/benchmarking/automlbenchmark/.gitignore b/examples/trials/benchmarking/automlbenchmark/.gitignore
new file mode 100644
index 0000000000..5184f9196f
--- /dev/null
+++ b/examples/trials/benchmarking/automlbenchmark/.gitignore
@@ -0,0 +1,13 @@
+# data files 
+nni/data/
+
+# benchmark repository 
+automlbenchmark/
+
+# all experiment results
+results*
+
+# intermediate outputs of tuners
+smac3-output*
+param_config_space.pcs
+scenario.txt
\ No newline at end of file
diff --git a/examples/trials/benchmarking/automlbenchmark/nni/benchmarks/nnismall.yaml b/examples/trials/benchmarking/automlbenchmark/nni/benchmarks/nnismall.yaml
new file mode 100644
index 0000000000..5b68dc898d
--- /dev/null
+++ b/examples/trials/benchmarking/automlbenchmark/nni/benchmarks/nnismall.yaml
@@ -0,0 +1,77 @@
+---
+- name: __defaults__
+  folds: 2
+  cores: 2
+  max_runtime_seconds: 300
+  
+- name: cholesterol
+  openml_task_id: 2295
+
+- name: liver-disorders
+  openml_task_id: 52948
+
+- name: kin8nm
+  openml_task_id: 2280
+
+- name: cpu_small
+  openml_task_id: 4883
+  
+- name: titanic_2
+  openml_task_id: 211993
+ 
+- name: boston
+  openml_task_id: 4857
+ 
+- name: stock
+  openml_task_id: 2311
+ 
+- name: space_ga
+  openml_task_id: 4835
+ 
+- name: Australian
+  openml_task_id: 146818
+
+- name: blood-transfusion
+  openml_task_id: 10101
+
+- name: car
+  openml_task_id: 146821
+
+- name: christine
+  openml_task_id: 168908
+
+- name: cnae-9
+  openml_task_id: 9981
+
+- name: credit-g
+  openml_task_id: 31
+
+- name: dilbert
+  openml_task_id: 168909
+
+- name: fabert
+  openml_task_id: 168910
+
+- name: jasmine
+  openml_task_id: 168911
+
+- name: kc1
+  openml_task_id: 3917
+
+- name: kr-vs-kp
+  openml_task_id: 3
+
+- name: mfeat-factors
+  openml_task_id: 12
+
+- name: phoneme
+  openml_task_id: 9952
+
+- name: segment
+  openml_task_id: 146822
+
+- name: sylvine
+  openml_task_id: 168912
+
+- name: vehicle
+  openml_task_id: 53
diff --git a/examples/trials/benchmarking/automlbenchmark/nni/benchmarks/nnismall_description.txt b/examples/trials/benchmarking/automlbenchmark/nni/benchmarks/nnismall_description.txt
new file mode 100644
index 0000000000..d8d0aa0ec4
--- /dev/null
+++ b/examples/trials/benchmarking/automlbenchmark/nni/benchmarks/nnismall_description.txt
@@ -0,0 +1,152 @@
+nnismall:
+This benchmark contains 24 tasks: 8 tasks each for binary classification, multi-class classification, and regression.
+
+Binary Classification: 
+- name: Australian
+  openml_task_id: 146818
+  Introduction: Australian Credit Approval dataset, originating from the StatLog project. It concerns credit card applications. 
+  Features: 6 numerical and 8 categorical features, all normalized to [-1,1].
+  Number of instances: 690
+
+- name: blood-transfusion
+  openml_task_id: 10101
+  Introduction: Data taken from the Blood Transfusion Service Center in Hsin-Chu City, Taiwan. The target attribute is a binary variable indicating whether the person donated blood in March 2007 (2 stands for donating blood; 1 stands for not donating blood).
+  Features: 4 numerical features.
+  Number of instances: 748
+
+- name: christine
+  openml_task_id: 168908
+  Introduction: An OpenML challenge dataset on classification. The identity of the dataset and the type of data are concealed.
+  Features: 1599 numerical features and 38 categorical features
+  Number of instances: 5418
+
+- name: credit-g
+  openml_task_id: 31
+  Introduction: This dataset classifies people described by a set of attributes as good or bad credit risks.
+  Features: 7 numerical features and 13 categorical features
+  Number of instances: 1000
+
+- name: kc1
+  openml_task_id: 3917
+  Introduction: One of the NASA Metrics Data Program defect data sets. Data from software for storage management for receiving and processing ground data.
+  Features: 21 numerical features
+  Number of instances: 2109
+
+- name: kr-vs-kp
+  openml_task_id: 3
+  Introduction: Given a board configuration, predict whether white can win or not. 
+  Features: 37 categorical features
+  Number of instances: 3196
+
+- name: phoneme
+  openml_task_id: 9952
+  Introduction: The aim of this dataset is to distinguish between nasal (class 0) and oral sounds (class 1). 
+  Features: 5 numerical features
+  Number of instances: 5404
+
+- name: sylvine
+  openml_task_id: 168912
+  Introduction: An OpenML challenge dataset on classification. The identity of the dataset and the type of data are concealed.
+  Features: 20 numerical features
+  Number of instances: 5124
+
+
+
+Multi-class Classification:
+- name: car
+  openml_task_id: 146821
+  Introduction: The model evaluates cars using six intermediate concepts. 
+  Features: 6 categorical features
+  Number of instances: 1728
+
+- name: cnae-9
+  openml_task_id: 9981
+  Introduction: This is a data set containing 1080 documents of free text business descriptions of Brazilian companies categorized into a subset of 9 categories.
+  Features: 856 numerical features (word frequency)
+  Number of instances: 1080
+
+- name: dilbert
+  openml_task_id: 168909
+  Introduction: An OpenML challenge dataset on classification. The identity of the dataset and the type of data are concealed.
+  Features: 2000 numerical features
+  Number of instances: 10000
+
+- name: fabert
+  openml_task_id: 168910
+  Introduction: An OpenML challenge dataset on classification. The identity of the dataset and the type of data are concealed.
+  Features: 800 numerical features
+  Number of instances: 8237
+
+- name: jasmine
+  openml_task_id: 168911
+  Introduction: An OpenML challenge dataset on classification. The identity of the dataset and the type of data are concealed.
+  Features: 8 numerical features and 137 categorical features 
+  Number of instances: 2984
+
+- name: mfeat-factors
+  openml_task_id: 12
+  Introduction: Hand-written numeral classification. 
+  Features: 216 numerical features (corresponding to binarized images)
+  Number of instances: 2000
+
+- name: segment
+  openml_task_id: 146822
+  Introduction: Segmentation of outdoor images into 7 classes.
+  Features: 19 numerical features
+  Number of instances: 2310 (3x3 patches from 7 images)
+
+- name: vehicle
+  openml_task_id: 53
+  Introduction: Classify a given silhouette as one of four types of vehicle, using a set of features extracted from the silhouette. The vehicle may be viewed from one of many different angles.
+  Features: 18 numerical features
+  Number of instances: 846
+
+
+Regression:
+- name: cholesterol
+  openml_task_id: 2295
+  Introduction: Predict the cholesterol level of patients. 
+  Features: 6 numerical features and 7 categorical features 
+  Number of instances: 303
+
+- name: liver-disorders
+  openml_task_id: 52948
+  Introduction: Predict alcohol consumption based on blood test results.
+  Features: 5 numerical features
+  Number of instances: 345
+
+- name: kin8nm
+  openml_task_id: 2280
+  Introduction: This dataset is concerned with the forward kinematics of an 8 link robot arm.
+  Features: 8 numerical features
+  Number of instances: 8192
+
+- name: cpu_small
+  openml_task_id: 4883
+  Introduction: Predict the portion of time that CPUs run in user mode.
+  Features: 12 numerical features
+  Number of instances: 8192
+
+- name: titanic_2
+  openml_task_id: 211993
+  Introduction: Predict the probability of survival.
+  Features: 7 numerical features
+  Number of instances: 891
+
+- name: boston
+  openml_task_id: 4857
+  Introduction: Predict Boston house prices.
+  Features: 11 numerical features and 2 categorical features 
+  Number of instances: 506
+
+- name: stock
+  openml_task_id: 2311
+  Introduction: This is a dataset obtained from the StatLib repository. The data provided are daily stock prices from January 1988 through October 1991, for ten aerospace companies.
+  Features: 11 numerical features
+  Number of instances: 950
+
+- name: space_ga
+  openml_task_id: 4835
+  Introduction: Predict the log of the proportion of votes cast for both candidates in the 1980 presidential election.
+  Features: 6 numerical attributes
+  Number of instances: 3107
diff --git a/examples/trials/benchmarking/automlbenchmark/nni/benchmarks/nnivalid.yaml b/examples/trials/benchmarking/automlbenchmark/nni/benchmarks/nnivalid.yaml
new file mode 100644
index 0000000000..ffc4d3d8ee
--- /dev/null
+++ b/examples/trials/benchmarking/automlbenchmark/nni/benchmarks/nnivalid.yaml
@@ -0,0 +1,43 @@
+---
+# for doc purposes, using <placeholder:default_value> syntax where it applies.
+
+#FORMAT: global defaults are defined in config.yaml
+- name: __dummy-task
+  enabled: false  # actual default is `true` of course...
+  openml_task_id: 0
+  metric: # the first metric in the task list will be optimized against and used for the main result, the other ones are optional and purely informative. Only the metrics annotated with (*) can be used as a performance metric.
+    - # classification
+    - acc # (*) accuracy
+    - auc # (*) area under curve
+    - logloss # (*) log loss
+    - f1  # F1 score
+    - # regression
+    - mae  # (*) mean absolute error
+    - mse # (*) mean squared error
+    - rmse  # root mean squared error
+    - rmsle  # root mean squared log error
+    - r2  # R^2 score
+  folds: 1
+  max_runtime_seconds: 600
+  cores: 1
+  max_mem_size_mb: -1
+  ec2_instance_type: m5.large
+
+
+# local defaults (applying only to tasks defined in this file) can be defined in a task named "__defaults__"
+- name: __defaults__
+  folds: 2
+  cores: 2
+  max_runtime_seconds: 180
+
+- name: kc2
+  openml_task_id: 3913
+  description: "binary test dataset"
+
+- name: iris
+  openml_task_id: 59
+  description: "multiclass test dataset"
+
+- name: cholesterol
+  openml_task_id: 2295
+  description: "regression test dataset"
diff --git a/examples/trials/benchmarking/automlbenchmark/nni/config.yaml b/examples/trials/benchmarking/automlbenchmark/nni/config.yaml
new file mode 100644
index 0000000000..b4eb5360d5
--- /dev/null
+++ b/examples/trials/benchmarking/automlbenchmark/nni/config.yaml
@@ -0,0 +1,16 @@
+---
+input_dir: '{user}/data'
+
+frameworks:
+  definition_file:
+    - '{root}/resources/frameworks.yaml'
+    - '{user}/frameworks.yaml'
+
+benchmarks:
+  definition_dir:
+    - '{user}/benchmarks'
+    - '{root}/resources/benchmarks'
+    
+  constraints_file:
+    - '{user}/constraints.yaml'
+    - '{root}/resources/constraints.yaml'
diff --git a/examples/trials/benchmarking/automlbenchmark/nni/extensions/NNI/.marker_setup_safe_to_delete b/examples/trials/benchmarking/automlbenchmark/nni/extensions/NNI/.marker_setup_safe_to_delete
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/examples/trials/benchmarking/automlbenchmark/nni/extensions/NNI/__init__.py b/examples/trials/benchmarking/automlbenchmark/nni/extensions/NNI/__init__.py
new file mode 100644
index 0000000000..5f5a2aa16a
--- /dev/null
+++ b/examples/trials/benchmarking/automlbenchmark/nni/extensions/NNI/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+def run(*args, **kwargs):
+    from .exec import run
+    return run(*args, **kwargs)
diff --git a/examples/trials/benchmarking/automlbenchmark/nni/extensions/NNI/architectures/run_random_forest.py b/examples/trials/benchmarking/automlbenchmark/nni/extensions/NNI/architectures/run_random_forest.py
new file mode 100644
index 0000000000..35044665be
--- /dev/null
+++ b/examples/trials/benchmarking/automlbenchmark/nni/extensions/NNI/architectures/run_random_forest.py
@@ -0,0 +1,163 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import logging
+import sklearn
+import time
+import numpy as np
+
+from sklearn.impute import SimpleImputer
+from sklearn.compose import ColumnTransformer
+from sklearn.preprocessing import OrdinalEncoder
+from sklearn.pipeline import Pipeline
+from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
+from sklearn.model_selection import cross_val_score
+
+from amlb.benchmark import TaskConfig
+from amlb.data import Dataset
+from amlb.datautils import impute
+from amlb.utils import Timer
+from amlb.results import save_predictions_to_file
+
+
+SEARCH_SPACE = {
+    "n_estimators": {"_type":"randint", "_value": [8, 512]},
+    "max_depth": {"_type":"choice", "_value": [4, 8, 16, 32, 64, 128, 256, 0]},   # 0 for None
+    "min_samples_leaf": {"_type":"randint", "_value": [1, 8]},
+    "min_samples_split": {"_type":"randint", "_value": [2, 16]},
+    "max_leaf_nodes": {"_type":"randint", "_value": [0, 4096]}                    # 0 for None
+}
+
+SEARCH_SPACE_CHOICE = {
+    "n_estimators": {"_type":"choice", "_value": [8, 16, 32, 64, 128, 256, 512]},
+    "max_depth": {"_type":"choice", "_value": [4, 8, 16, 32, 64, 128, 0]},   # 0 for None
+    "min_samples_leaf": {"_type":"choice", "_value": [1, 2, 4, 8]},
+    "min_samples_split": {"_type":"choice", "_value": [2, 4, 8, 16]},
+    "max_leaf_nodes": {"_type":"choice", "_value": [8, 32, 128, 512, 0]}     # 0 for None
+}
+
+SEARCH_SPACE_SIMPLE = {
+    "n_estimators": {"_type":"choice", "_value": [10]},
+    "max_depth": {"_type":"choice", "_value": [5]},
+    "min_samples_leaf": {"_type":"choice", "_value": [8]},
+    "min_samples_split": {"_type":"choice", "_value": [16]},
+    "max_leaf_nodes": {"_type":"choice", "_value": [64]}
+}
+
+
+def preprocess_random_forest(dataset, log):
+    '''
+    For random forest:
+    - Do nothing for numerical features except null imputation. 
+    - For categorical features, use ordinal encoding to map them into integers. 
+    '''
+    cat_columns, num_columns = [], []
+    shift_amount = 0
+    for i, f in enumerate(dataset.features):
+        if f.is_target:
+            shift_amount += 1
+            continue
+        elif f.is_categorical():
+            cat_columns.append(i - shift_amount)
+        else:
+            num_columns.append(i - shift_amount)
+
+    cat_pipeline = Pipeline([('imputer', SimpleImputer(strategy='most_frequent')),
+                             ('ordinal_encoder', OrdinalEncoder()),
+                             ])
+    
+    num_pipeline = Pipeline([('imputer', SimpleImputer(strategy='mean')),
+                             ])
+    
+    data_pipeline = ColumnTransformer([
+        ('categorical', cat_pipeline, cat_columns),
+        ('numerical', num_pipeline, num_columns),
+    ])
+
+    data_pipeline.fit(np.concatenate([dataset.train.X, dataset.test.X], axis=0))
+    
+    X_train = data_pipeline.transform(dataset.train.X)
+    X_test = data_pipeline.transform(dataset.test.X)  
+    
+    return X_train, X_test
+
+    
+def run_random_forest(dataset, config, tuner, log):
+    """
+    Using the given tuner, tune a random forest within the given time constraint.
+    This function uses cross validation score as the feedback score to the tuner. 
+    The search space that the tuners explore is defined empirically above as a global variable.
+    """
+    
+    limit_type, trial_limit = config.framework_params['limit_type'], None
+    if limit_type == 'ntrials':
+        trial_limit = int(config.framework_params['trial_limit'])
+    
+    X_train, X_test = preprocess_random_forest(dataset, log)
+    y_train, y_test = dataset.train.y, dataset.test.y
+
+    is_classification = config.type == 'classification'
+    estimator = RandomForestClassifier if is_classification else RandomForestRegressor
+
+    best_score, best_params, best_model = None, None, None
+    score_higher_better = True
+
+    tuner.update_search_space(SEARCH_SPACE)    
+    
+    start_time = time.time()
+    trial_count = 0
+    intermediate_scores = []
+    intermediate_best_scores = []           # should be monotonically increasing 
+    
+    while True:
+        try:
+            trial_count += 1
+            param_idx, cur_params = tuner.generate_parameters()
+            train_params = cur_params.copy()
+            if 'TRIAL_BUDGET' in cur_params:
+                train_params.pop('TRIAL_BUDGET')
+            if cur_params['max_leaf_nodes'] == 0: 
+                train_params.pop('max_leaf_nodes')
+            if cur_params['max_depth'] == 0:
+                train_params.pop('max_depth')
+            log.info("Trial {}: \n{}\n".format(param_idx, cur_params))
+                
+            cur_model = estimator(random_state=config.seed, **train_params)
+            
+            # Here score is the output of score() from the estimator
+            cur_score = cross_val_score(cur_model, X_train, y_train)
+            cur_score = sum(cur_score) / float(len(cur_score))
+            if np.isnan(cur_score):
+                cur_score = 0
+            
+            log.info("Score: {}\n".format(cur_score))
+            if best_score is None or (score_higher_better and cur_score > best_score) or (not score_higher_better and cur_score < best_score):
+                best_score, best_params, best_model = cur_score, cur_params, cur_model    
+            
+            intermediate_scores.append(cur_score)
+            intermediate_best_scores.append(best_score)
+            tuner.receive_trial_result(param_idx, cur_params, cur_score)
+
+            if limit_type == 'time':
+                current_time = time.time()
+                elapsed_time = current_time - start_time
+                if elapsed_time >= config.max_runtime_seconds:
+                    break
+            elif limit_type == 'ntrials':
+                if trial_count >= trial_limit:
+                    break
+        except Exception:
+            # stop tuning on any error instead of crashing the whole benchmark run
+            break
+
+    # This line is required to fully terminate some advisors
+    tuner.handle_terminate()
+        
+    log.info("Tuning done, the best parameters are:\n{}\n".format(best_params))
+
+    # retrain on the whole dataset 
+    with Timer() as training:
+        best_model.fit(X_train, y_train)     
+    predictions = best_model.predict(X_test)
+    probabilities = best_model.predict_proba(X_test) if is_classification else None
+
+    return probabilities, predictions, training, y_test, intermediate_scores, intermediate_best_scores
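
Because NNI search spaces cannot express `None` directly, the loop above uses 0 as a sentinel for `max_depth` and `max_leaf_nodes` and strips it before building the estimator. The same translation, pulled out as a hypothetical helper for clarity (not part of the file):

```python
def decode_params(cur_params):
    """Turn a sampled NNI configuration into valid sklearn keyword arguments."""
    train_params = cur_params.copy()
    train_params.pop('TRIAL_BUDGET', None)        # advisors may attach a budget key
    for key in ('max_leaf_nodes', 'max_depth'):
        if train_params.get(key) == 0:
            train_params.pop(key)                 # sklearn's default for both is None
    return train_params
```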
diff --git a/examples/trials/benchmarking/automlbenchmark/nni/extensions/NNI/exec.py b/examples/trials/benchmarking/automlbenchmark/nni/extensions/NNI/exec.py
new file mode 100644
index 0000000000..a79e0d3b5c
--- /dev/null
+++ b/examples/trials/benchmarking/automlbenchmark/nni/extensions/NNI/exec.py
@@ -0,0 +1,71 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import logging
+
+from .tuners import NNITuner
+from .run_experiment import *
+
+from amlb.benchmark import TaskConfig
+from amlb.data import Dataset
+from amlb.results import save_predictions_to_file
+from amlb.utils import Timer
+
+
+log = logging.getLogger(__name__)
+
+
+def validate_config(config: TaskConfig):
+    if 'tuner_type' not in config.framework_params:
+        raise RuntimeError('framework.yaml does not have a "tuner_type" field.')
+    if 'limit_type' not in config.framework_params:
+        raise RuntimeError('framework.yaml does not have a "limit_type" field.')
+    if config.framework_params['limit_type'] not in ['time', 'ntrials']:
+        raise RuntimeError('"limit_type" field must be "time" or "ntrials".')
+    if config.framework_params['limit_type'] == 'ntrials':
+        if 'trial_limit' not in config.framework_params:
+            raise RuntimeError('framework.yaml does not have a "trial_limit" field.')
+        else:
+            try:
+                _ = int(config.framework_params['trial_limit'])
+            except (TypeError, ValueError):
+                raise RuntimeError('"trial_limit" field must be an integer.')
+
+
+def save_scores_to_file(intermediate_scores, intermediate_best_scores, out_file):
+    """
+    Save statistics of every trial to a log file for generating reports. 
+    """
+    with open(out_file, 'w') as f:
+        f.write('ntrials,trial_score,best_score\n')
+        for i, (trial_score, best_score) in enumerate(zip(intermediate_scores, intermediate_best_scores)):
+            f.write('{},{},{}\n'.format(i+1, trial_score, best_score))
+            
+    
+def run(dataset: Dataset, config: TaskConfig):
+    validate_config(config)
+    tuner = NNITuner(config)
+    if config.framework_params['limit_type']  == 'time':
+        log.info("Tuning {} with NNI {} with a maximum time of {}s\n"
+                 .format(config.framework_params['arch_type'], tuner.description, config.max_runtime_seconds))
+    elif config.framework_params['limit_type'] == 'ntrials':
+        log.info("Tuning {} with NNI {} with a maximum number of trials of {}\n"
+                 .format(config.framework_params['arch_type'], tuner.description, config.framework_params['trial_limit']))
+        log.info("Note: any time constraints are ignored.")
+
+    probabilities, predictions, train_timer, y_test, intermediate_scores, intermediate_best_scores = run_experiment(dataset, config, tuner, log)
+
+    save_predictions_to_file(dataset=dataset,
+                             output_file=config.output_predictions_file,
+                             probabilities=probabilities,
+                             predictions=predictions,
+                             truth=y_test)
+
+    scores_file = '/'.join(config.output_predictions_file.split('/')[:-3]) + '/scorelogs/' + config.output_predictions_file.split('/')[-1]
+    assert(len(intermediate_scores) == len(intermediate_best_scores))
+    save_scores_to_file(intermediate_scores, intermediate_best_scores, scores_file)
+
+    return dict(
+        models_count=1,
+        training_duration=train_timer.duration
+    )
diff --git a/examples/trials/benchmarking/automlbenchmark/nni/extensions/NNI/run_experiment.py b/examples/trials/benchmarking/automlbenchmark/nni/extensions/NNI/run_experiment.py
new file mode 100644
index 0000000000..ae85c909e3
--- /dev/null
+++ b/examples/trials/benchmarking/automlbenchmark/nni/extensions/NNI/run_experiment.py
@@ -0,0 +1,15 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from .architectures.run_random_forest import *
+
+
+def run_experiment(dataset, config, tuner, log):
+    if 'arch_type' not in config.framework_params:
+        raise RuntimeError('framework.yaml does not have an "arch_type" field.')
+    
+    if config.framework_params['arch_type'] == 'random_forest':
+        return run_random_forest(dataset, config, tuner, log)
+
+    else:
+        raise RuntimeError('The requested arch type in framework.yaml is unavailable.') 
diff --git a/examples/trials/benchmarking/automlbenchmark/nni/extensions/NNI/tuners.py b/examples/trials/benchmarking/automlbenchmark/nni/extensions/NNI/tuners.py
new file mode 100644
index 0000000000..4016aae6e5
--- /dev/null
+++ b/examples/trials/benchmarking/automlbenchmark/nni/extensions/NNI/tuners.py
@@ -0,0 +1,156 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import os
+import yaml
+import importlib
+
+import nni
+from nni.runtime.config import get_config_file 
+from nni.utils import MetricType 
+from nni.tuner import Tuner
+from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
+
+from amlb.benchmark import TaskConfig
+
+
+def get_tuner_class_dict():
+    config_file = str(get_config_file('registered_algorithms.yml'))
+    if os.path.exists(config_file):
+        with open(config_file, 'r') as f:
+            config = yaml.load(f, Loader=yaml.SafeLoader)
+    else:
+        config = {}
+    ret = {}
+    for t in ['tuners', 'advisors']:
+        for entry in config[t]:
+            ret[entry['builtinName']] = entry['className']
+    return ret
+
+
+def get_tuner(config: TaskConfig):
+    name2tuner = get_tuner_class_dict()
+    if config.framework_params['tuner_type'] not in name2tuner:
+        raise RuntimeError('The requested tuner type is unavailable.')
+    else:
+        module_name = name2tuner[config.framework_params['tuner_type']]
+        tuner_name = module_name.split('.')[-1]
+        module_name = '.'.join(module_name.split('.')[:-1])
+        tuner_type = getattr(importlib.import_module(module_name), tuner_name)
+
+        # special handlings for tuner initialization
+        tuner = None
+        if config.framework_params['tuner_type'] == 'TPE':
+            tuner = tuner_type('tpe')
+
+        elif config.framework_params['tuner_type'] == 'Random':
+            tuner = tuner_type('random_search')
+
+        elif config.framework_params['tuner_type'] == 'Anneal':
+            tuner = tuner_type('anneal')
+
+        elif config.framework_params['tuner_type'] == 'Hyperband':
+            if 'max_resource' in config.framework_params:
+                tuner = tuner_type(R=config.framework_params['max_resource'])
+            else:
+                tuner = tuner_type()
+
+        elif config.framework_params['tuner_type'] == 'BOHB':
+            if 'max_resource' in config.framework_params:
+                tuner = tuner_type(max_budget=config.framework_params['max_resource'])
+            else:
+                tuner = tuner_type(max_budget=60)
+
+        else:
+            tuner = tuner_type()
+
+        assert(tuner is not None)
+
+        return tuner, config.framework_params['tuner_type']
+
+    
+class NNITuner:
+    '''
+    A specialized wrapper for the automlbenchmark framework.
+    Abstracts the different behaviors of tuners and advisors into a tuner API. 
+    '''
+    def __init__(self, config: TaskConfig):
+        self.config = config
+        self.core, self.description = get_tuner(config)
+
+        # 'tuner' or 'advisor'
+        self.core_type = None      
+        if isinstance(self.core, Tuner):
+            self.core_type = 'tuner'
+        elif isinstance(self.core, MsgDispatcherBase):
+            self.core_type = 'advisor'
+        else:
+            raise RuntimeError('Unsupported tuner or advisor type') 
+
+        # note: tuners and advisors use this variable differently
+        self.cur_param_id = 0
+
+        
+    def __del__(self):
+        self.handle_terminate()
+
+        
+    def update_search_space(self, search_space):
+        if self.core_type == 'tuner':
+            self.core.update_search_space(search_space)
+            
+        elif self.core_type == 'advisor':
+            self.core.handle_update_search_space(search_space)
+            # special initializations for BOHB Advisor
+            from nni.algorithms.hpo.hyperband_advisor import Hyperband
+            if isinstance(self.core, Hyperband):
+                pass
+            else:
+                from nni.algorithms.hpo.bohb_advisor.bohb_advisor import BOHB
+                from nni.algorithms.hpo.bohb_advisor.config_generator import CG_BOHB   
+                if isinstance(self.core, BOHB):
+                    self.core.cg = CG_BOHB(configspace=self.core.search_space,
+                                           min_points_in_model=self.core.min_points_in_model,
+                                           top_n_percent=self.core.top_n_percent,
+                                           num_samples=self.core.num_samples,
+                                           random_fraction=self.core.random_fraction,
+                                           bandwidth_factor=self.core.bandwidth_factor,
+                                           min_bandwidth=self.core.min_bandwidth)
+                    self.core.generate_new_bracket()
+                
+        
+    def generate_parameters(self):
+        self.cur_param_id += 1
+        if self.core_type == 'tuner':
+            self.cur_param = self.core.generate_parameters(self.cur_param_id-1)
+            return self.cur_param_id-1, self.cur_param
+            
+        elif self.core_type == 'advisor':
+            self.cur_param = self.core._get_one_trial_job()
+            hyperparams = self.cur_param['parameters'].copy()
+            #hyperparams.pop('TRIAL_BUDGET')
+            return self.cur_param['parameter_id'], hyperparams
+
+        
+    def receive_trial_result(self, parameter_id, parameters, value):
+        if self.core_type == 'tuner':
+            return self.core.receive_trial_result(parameter_id, parameters, value)
+
+        elif self.core_type == 'advisor':
+            metric_report = {}
+            metric_report['parameter_id'] = parameter_id
+            metric_report['trial_job_id'] = self.cur_param_id
+            metric_report['type'] = MetricType.FINAL
+            metric_report['value'] = str(value)
+            metric_report['sequence'] = self.cur_param_id
+            return self.core.handle_report_metric_data(metric_report)   
+
+        
+    def handle_terminate(self):
+        if self.core_type == 'tuner':
+            pass
+        
+        elif self.core_type == 'advisor':   
+            self.core.stopping = True 
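
Taken together, `NNITuner` lets the benchmark drive any registered tuner or advisor through one loop-style API. A hypothetical standalone use with a toy objective (`config` is an amlb `TaskConfig` whose `framework_params` select, e.g., the built-in TPE tuner):

```python
tuner = NNITuner(config)
tuner.update_search_space({'x': {'_type': 'uniform', '_value': [0, 1]}})
for _ in range(10):
    pid, params = tuner.generate_parameters()
    score = 1 - (params['x'] - 0.3) ** 2      # toy objective to maximize
    tuner.receive_trial_result(pid, params, score)
tuner.handle_terminate()                      # required to fully stop some advisors
```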
diff --git a/examples/trials/benchmarking/automlbenchmark/nni/frameworks.yaml b/examples/trials/benchmarking/automlbenchmark/nni/frameworks.yaml
new file mode 100644
index 0000000000..5fa057ac0b
--- /dev/null
+++ b/examples/trials/benchmarking/automlbenchmark/nni/frameworks.yaml
@@ -0,0 +1,85 @@
+---
+
+NNI:
+  module: extensions.NNI
+  version: 'stable'
+  project: https://github.com/microsoft/nni
+
+# tuner_type in ['TPE', 'Random', 'Anneal', 'Evolution', 'SMAC', 'GPTuner', 'MetisTuner', 'Hyperband', 'BOHB']
+# arch_type in ['random_forest']
+# limit_type in ['time', 'ntrials']
+# trial_limit must be an integer
+
+TPE:
+  extends: NNI
+  params:
+    tuner_type: 'TPE'
+    arch_type: 'random_forest'
+    limit_type: 'ntrials'
+    trial_limit: 10
+
+Random:
+  extends: NNI 
+  params:
+    tuner_type: 'Random'
+    arch_type: 'random_forest'
+    limit_type: 'ntrials'
+    trial_limit: 10
+
+Anneal:
+  extends: NNI 
+  params:
+    tuner_type: 'Anneal'
+    arch_type: 'random_forest'
+    limit_type: 'ntrials'
+    trial_limit: 10
+
+Evolution:
+  extends: NNI 
+  params:
+    tuner_type: 'Evolution'
+    arch_type: 'random_forest'
+    limit_type: 'ntrials'
+    trial_limit: 10
+
+SMAC:
+  extends: NNI 
+  params:
+    tuner_type: 'SMAC'
+    arch_type: 'random_forest'
+    limit_type: 'ntrials'
+    trial_limit: 10
+
+GPTuner:
+  extends: NNI 
+  params:
+    tuner_type: 'GPTuner'
+    arch_type: 'random_forest'
+    limit_type: 'ntrials'
+    trial_limit: 10
+
+MetisTuner:
+  extends: NNI 
+  params:
+    tuner_type: 'MetisTuner'
+    arch_type: 'random_forest'
+    limit_type: 'ntrials'
+    trial_limit: 10
+
+Hyperband:
+  extends: NNI 
+  params:
+    tuner_type: 'Hyperband'
+    arch_type: 'random_forest'
+    max_resource: 60
+    limit_type: 'ntrials'
+    trial_limit: 10
+
+BOHB:
+  extends: NNI 
+  params:
+    tuner_type: 'BOHB'
+    arch_type: 'random_forest'
+    max_resource: 60
+    limit_type: 'ntrials'
+    trial_limit: 10
diff --git a/examples/trials/benchmarking/automlbenchmark/parse_result_csv.py b/examples/trials/benchmarking/automlbenchmark/parse_result_csv.py
new file mode 100644
index 0000000000..b10827da36
--- /dev/null
+++ b/examples/trials/benchmarking/automlbenchmark/parse_result_csv.py
@@ -0,0 +1,166 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import pandas as pd
+import sys
+import matplotlib.pyplot as plt
+from matplotlib.lines import Line2D
+
+
+def generate_perf_report(result_file_name):
+    """
+    Generate a performance report. 
+    The input result_file_name should be the path of the "results.csv" generated by automlbenchmark.
+    This function outputs 1) a formatted report named "performances.txt" in the "reports/" directory 
+    located in the same parent directory as "results.csv" and 2) a report named "rankings.txt" in the
+    same directory ranking the tuners contained in "results.csv". 
+    """
+    result = pd.read_csv(result_file_name)
+    task_ids = result['id'].unique()
+    tuners = result['framework'].unique()
+    metric_types = ['rmse', 'auc', 'logloss']
+    metric2taskres = {}
+    for m in metric_types:
+        metric2taskres[m] = []
+    keep_parameters = ['framework', 'constraint', 'result', 'metric', 'params', 'utc', 'duration'] + list(result.columns[16:])    
+
+    # performance report: one table per task
+    with open(result_file_name.replace('results.csv', 'reports/performances.txt'), 'w') as out_f:
+        for task_id in task_ids:
+            task_results = result[result['id'] == task_id]
+            task_name = task_results.task.unique()[0]
+            out_f.write("====================================================\n")
+            out_f.write("Task ID: {}\n".format(task_id))
+            out_f.write("Task Name: {}\n".format(task_name))
+            folds = task_results['fold'].unique()
+            for fold in folds:
+                out_f.write("Fold {}:\n".format(fold))
+                res = task_results[task_results['fold'] == fold][keep_parameters]
+                out_f.write(res.to_string())
+                out_f.write('\n')
+                # save results for the next step
+                res_list = []
+                for _, row in res.iterrows():
+                    res_list.append([row['framework'], row['result']])
+                metric2taskres[res['metric'].unique()[0]].append(res_list)
+            out_f.write('\n')
+
+    # rankings report: per task and per tuner    
+    with open(result_file_name.replace('results.csv', 'reports/rankings.txt'), 'w') as out_f:
+        # generate reports per task
+        ranking_aggs = {}
+        for metric_type in metric_types:
+            sorted_lists = []
+            if metric_type in ['auc']:
+                for l in metric2taskres[metric_type]:
+                    l_sorted = sorted(l, key=(lambda x: x[-1]), reverse=True)
+                    l_sorted = [[x[0], x[1], i+1] for (i, x) in enumerate(l_sorted)]
+                    sorted_lists.append(l_sorted)
+            elif metric_type in ['rmse', 'logloss']:
+                for l in metric2taskres[metric_type]:
+                    l_sorted = sorted(l, key=(lambda x: x[-1]))
+                    l_sorted = [[x[0], x[1], i+1] for (i, x) in enumerate(l_sorted)]
+                    sorted_lists.append(l_sorted)
+            metric2taskres[metric_type] = sorted_lists
+
+            out_f.write("====================================================\n")
+            out_f.write("Average rankings for metric {}:\n".format(metric_type))
+            ranking_agg = [[t, 0] for t in tuners]
+            for i, tuner in enumerate(tuners):
+                for trial_res in metric2taskres[metric_type]:
+                    for t, s, r in trial_res:
+                        if t == tuner:
+                            ranking_agg[i][-1] += r
+
+            ranking_agg = [[x[0], x[1]/float(len(metric2taskres[metric_type]))] for x in ranking_agg]
+            ranking_agg = sorted(ranking_agg, key=(lambda x: x[-1]))
+            for t, r in ranking_agg:
+                out_f.write('{:<12} {:.2f}\n'.format(t, r))
+            ranking_aggs[metric_type] = ranking_agg
+            out_f.write('\n') 
+
+        # generate reports per tuner
+        out_f.write("====================================================\n")
+        out_f.write("Average rankings for tuners:\n")
+        header_string = '{:<12}'
+        for _ in metric_types:
+            header_string += ' {:<12}'
+        header_string += '\n'
+        out_f.write(header_string.format("Tuner", *metric_types))    
+        for tuner in tuners:
+            tuner_ranks = []
+            for m in metric_types:
+                for t, r in ranking_aggs[m]:
+                    if t == tuner:
+                        tuner_ranks.append('{:.2f}'.format(r))
+                        break
+            out_f.write(header_string.format(tuner, *tuner_ranks))
+        out_f.write('\n')
+
+
+def generate_graphs(result_file_name):
+    """
+    Generate graphs describing performance statistics.
+    The input result_file_name should be the path of the "results.csv" generated by automlbenchmark.
+    For each task, this function outputs two graphs in the "reports/" directory located in the same 
+    parent directory as "results.csv".
+    The graph named <task>_fold<x>_1.jpg summarizes the best score each tuner has obtained after n trials.
+    The graph named <task>_fold<x>_2.jpg summarizes the score each tuner obtains in each individual trial.
+    """
+    markers = list(Line2D.markers.keys())
+    result = pd.read_csv(result_file_name)
+    scorelog_dir = result_file_name.replace('results.csv', 'scorelogs/')
+    output_dir = result_file_name.replace('results.csv', 'reports/') 
+    task_ids = result['id'].unique()
+    for task_id in task_ids:
+        task_results = result[result['id'] == task_id]
+        task_name = task_results.task.unique()[0]
+        folds = task_results['fold'].unique()        
+
+        for fold in folds:            
+            # load scorelog files
+            trial_scores, best_scores = [], []
+            tuners = list(task_results[task_results.fold == fold]['framework'].unique())
+            for tuner in tuners:
+                scorelog_name = '{}_{}_{}.csv'.format(tuner.lower(), task_name, fold)
+                intermediate_scores = pd.read_csv(scorelog_dir + scorelog_name)
+                bs = list(intermediate_scores['best_score'])
+                ts = [(i+1, x) for i, x in enumerate(list(intermediate_scores['trial_score'])) if x != 0]
+                best_scores.append([tuner, bs])
+                trial_scores.append([tuner, ts])
+
+            # generate the best score graph
+            plt.figure(figsize=(16, 8))
+            for i, (tuner, score) in enumerate(best_scores):
+                plt.plot(score, label=tuner, marker=markers[i])
+            plt.title('{} Fold {}'.format(task_name, fold))
+            plt.xlabel("Number of Trials")
+            plt.ylabel("Best Score")        
+            plt.legend()
+            plt.savefig(output_dir + '{}_fold{}_1.jpg'.format(task_name, fold))
+            plt.close()
+
+            # generate the trial score graph
+            plt.figure(figsize=(16, 8))
+            for i, (tuner, score) in enumerate(trial_scores):
+                x = [l[0] for l in score]
+                y = [l[1] for l in score] 
+                plt.plot(x, y, label=tuner)
+            plt.title('{} Fold {}'.format(task_name, fold))
+            plt.xlabel("Trial Number")
+            plt.ylabel("Trial Score")        
+            plt.legend()
+            plt.savefig(output_dir + '{}_fold{}_2.jpg'.format(task_name, fold))
+            plt.close()
+
+            
+def main():
+    if len(sys.argv) != 2:
+        print("Usage: python parse_result_csv.py <result.csv file>")
+        exit(0)
+    generate_perf_report(sys.argv[1])
+    generate_graphs(sys.argv[1])
+    
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/trials/benchmarking/automlbenchmark/requirements.txt b/examples/trials/benchmarking/automlbenchmark/requirements.txt
new file mode 100644
index 0000000000..9a7b00059d
--- /dev/null
+++ b/examples/trials/benchmarking/automlbenchmark/requirements.txt
@@ -0,0 +1,3 @@
+pandas>=1.2.0
+pyyaml>=5.4.1
+matplotlib>=3.4.1
diff --git a/examples/trials/benchmarking/automlbenchmark/runbenchmark_nni.sh b/examples/trials/benchmarking/automlbenchmark/runbenchmark_nni.sh
new file mode 100755
index 0000000000..d2d740925b
--- /dev/null
+++ b/examples/trials/benchmarking/automlbenchmark/runbenchmark_nni.sh
@@ -0,0 +1,56 @@
+#!/bin/bash
+
+time=$(date "+%Y%m%d%H%M%S")
+installation='automlbenchmark'
+outdir="results_$time"
+benchmark='nnivalid'      # 'nnismall'  
+serialize=true            # if false, run all experiments together in background
+
+mkdir $outdir $outdir/scorelogs $outdir/reports 
+
+if [ "$#" -eq 0 ]; then
+    tuner_array=('TPE' 'Random' 'Anneal' 'Evolution' 'GPTuner' 'MetisTuner' 'Hyperband')
+else
+    tuner_array=( "$@" )
+fi
+
+if [ "$serialize" = true ]; then
+    # run tuners serially 
+    for tuner in ${tuner_array[*]}; do
+	echo "python $installation/runbenchmark.py $tuner $benchmark -o $outdir -u nni"
+	python $installation/runbenchmark.py $tuner $benchmark -o $outdir -u nni
+    done
+
+    # parse final results
+    echo "python parse_result_csv.py $outdir/results.csv"
+    python parse_result_csv.py "$outdir/results.csv"
+
+else
+    # run all the tuners in background
+    for tuner in ${tuner_array[*]}; do
+	mkdir "$outdir/$tuner" "$outdir/$tuner/scorelogs"
+	echo "python $installation/runbenchmark.py $tuner $benchmark -o $outdir/$tuner -u nni &"
+	python $installation/runbenchmark.py $tuner $benchmark -o $outdir/$tuner -u nni &
+    done
+    
+    wait
+
+    # aggregate results
+    touch "$outdir/results.csv"
+    let i=0
+    for tuner in ${tuner_array[*]}; do
+	cp "$outdir/$tuner/scorelogs"/* $outdir/scorelogs
+	if [ $i -eq 0 ]; then
+	    cp "$outdir/$tuner/results.csv" "$outdir/results.csv"
+	else
+	    let nlines=`cat "$outdir/$tuner/results.csv" | wc -l`
+	    ((nlines=nlines-1))
+	    tail -n $nlines "$outdir/$tuner/results.csv" >> "$outdir/results.csv" 
+	fi
+	((i=i+1))
+    done
+
+    # parse final results
+    echo "python parse_result_csv.py $outdir/results.csv"
+    python parse_result_csv.py "$outdir/results.csv"
+fi
diff --git a/examples/trials/benchmarking/automlbenchmark/setup.sh b/examples/trials/benchmarking/automlbenchmark/setup.sh
new file mode 100755
index 0000000000..d87aea1b66
--- /dev/null
+++ b/examples/trials/benchmarking/automlbenchmark/setup.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+# download automlbenchmark repository
+if [ ! -d './automlbenchmark' ] ; then
+    git clone https://github.com/openml/automlbenchmark.git --branch stable --depth 1
+fi
+
+# install dependencies 
+pip3 install -r automlbenchmark/requirements.txt
+pip3 install -r requirements.txt --ignore-installed
diff --git a/examples/trials/benchmarking/config_hyperband.yml b/examples/trials/benchmarking/hyperband/config_hyperband.yml
similarity index 100%
rename from examples/trials/benchmarking/config_hyperband.yml
rename to examples/trials/benchmarking/hyperband/config_hyperband.yml
diff --git a/examples/trials/benchmarking/main.py b/examples/trials/benchmarking/hyperband/main.py
similarity index 100%
rename from examples/trials/benchmarking/main.py
rename to examples/trials/benchmarking/hyperband/main.py
diff --git a/examples/trials/benchmarking/search_space.json b/examples/trials/benchmarking/hyperband/search_space.json
similarity index 100%
rename from examples/trials/benchmarking/search_space.json
rename to examples/trials/benchmarking/hyperband/search_space.json
diff --git a/nni/algorithms/compression/pytorch/pruning/__init__.py b/nni/algorithms/compression/pytorch/pruning/__init__.py
index f534b25da0..f49cf0cb65 100644
--- a/nni/algorithms/compression/pytorch/pruning/__init__.py
+++ b/nni/algorithms/compression/pytorch/pruning/__init__.py
@@ -1,14 +1,13 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
-from .finegrained_pruning import *
-from .structured_pruning import *
-from .one_shot import *
-from .agp import *
+from .finegrained_pruning_masker import *
+from .structured_pruning_masker import *
+from .one_shot_pruner import *
+from .iterative_pruner import *
 from .lottery_ticket import LotteryTicketPruner
 from .simulated_annealing_pruner import SimulatedAnnealingPruner
 from .net_adapt_pruner import NetAdaptPruner
-from .admm_pruner import ADMMPruner
 from .auto_compress_pruner import AutoCompressPruner
 from .sensitivity_pruner import SensitivityPruner
 from .amc import AMCPruner
diff --git a/nni/algorithms/compression/pytorch/pruning/admm_pruner.py b/nni/algorithms/compression/pytorch/pruning/admm_pruner.py
deleted file mode 100644
index 30e73a23f8..0000000000
--- a/nni/algorithms/compression/pytorch/pruning/admm_pruner.py
+++ /dev/null
@@ -1,177 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT license.
-
-import logging
-import torch
-from schema import And, Optional
-import copy
-
-from nni.compression.pytorch.utils.config_validation import CompressorSchema
-from .constants import MASKER_DICT
-from .one_shot import OneshotPruner
-
-
-_logger = logging.getLogger(__name__)
-
-
-class ADMMPruner(OneshotPruner):
-    """
-    A Pytorch implementation of ADMM Pruner algorithm.
-
-    Parameters
-    ----------
-    model : torch.nn.Module
-        Model to be pruned.
-    config_list : list
-        List on pruning configs.
-    trainer : function
-        Function used for the first subproblem.
-        Users should write this function as a normal function to train the Pytorch model
-        and include `model, optimizer, criterion, epoch, callback` as function arguments.
-        Here `callback` acts as an L2 regulizer as presented in the formula (7) of the original paper.
-        The logic of `callback` is implemented inside the Pruner,
-        users are just required to insert `callback()` between `loss.backward()` and `optimizer.step()`.
-        Example::
-
-            def trainer(model, criterion, optimizer, epoch, callback):
-                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-                train_loader = ...
-                model.train()
-                for batch_idx, (data, target) in enumerate(train_loader):
-                    data, target = data.to(device), target.to(device)
-                    optimizer.zero_grad()
-                    output = model(data)
-                    loss = criterion(output, target)
-                    loss.backward()
-                    # callback should be inserted between loss.backward() and optimizer.step()
-                    if callback:
-                        callback()
-                    optimizer.step()
-    num_iterations : int
-        Total number of iterations.
-    training_epochs : int
-        Training epochs of the first subproblem.
-    row : float
-        Penalty parameters for ADMM training.
-    base_algo : str
-        Base pruning algorithm. `level`, `l1`, `l2` or `fpgm`, by default `l1`. Given the sparsity distribution among the ops,
-        the assigned `base_algo` is used to decide which filters/channels/weights to prune.
-
-    """
-
-    def __init__(self, model, config_list, trainer, num_iterations=30, training_epochs=5, row=1e-4, base_algo='l1'):
-        self._base_algo = base_algo
-
-        super().__init__(model, config_list)
-
-        self._trainer = trainer
-        self._num_iterations = num_iterations
-        self._training_epochs = training_epochs
-        self._row = row
-
-        self.set_wrappers_attribute("if_calculated", False)
-        self.masker = MASKER_DICT[self._base_algo](self.bound_model, self)
-
-    def validate_config(self, model, config_list):
-        """
-        Parameters
-        ----------
-        model : torch.nn.Module
-            Model to be pruned
-        config_list : list
-            List on pruning configs
-        """
-
-        if self._base_algo == 'level':
-            schema = CompressorSchema([{
-                'sparsity': And(float, lambda n: 0 < n < 1),
-                Optional('op_types'): [str],
-                Optional('op_names'): [str],
-            }], model, _logger)
-        elif self._base_algo in ['l1', 'l2', 'fpgm']:
-            schema = CompressorSchema([{
-                'sparsity': And(float, lambda n: 0 < n < 1),
-                'op_types': ['Conv2d'],
-                Optional('op_names'): [str]
-            }], model, _logger)
-
-        schema.validate(config_list)
-
-    def _projection(self, weight, sparsity, wrapper):
-        '''
-        Return the Euclidean projection of the weight matrix according to the pruning mode.
-
-        Parameters
-        ----------
-        weight : tensor
-            original matrix
-        sparsity : float
-            the ratio of parameters which need to be set to zero
-        wrapper: PrunerModuleWrapper
-            layer wrapper of this layer
-
-        Returns
-        -------
-        tensor
-            the projected matrix
-        '''
-        wrapper_copy = copy.deepcopy(wrapper)
-        wrapper_copy.module.weight.data = weight
-        return weight.data.mul(self.masker.calc_mask(sparsity, wrapper_copy)['weight_mask'])
-
-    def compress(self):
-        """
-        Compress the model with ADMM.
-
-        Returns
-        -------
-        torch.nn.Module
-            model with specified modules compressed.
-        """
-        _logger.info('Starting ADMM Compression...')
-
-        # initiaze Z, U
-        # Z_i^0 = W_i^0
-        # U_i^0 = 0
-        Z = []
-        U = []
-        for wrapper in self.get_modules_wrapper():
-            z = wrapper.module.weight.data
-            Z.append(z)
-            U.append(torch.zeros_like(z))
-
-        optimizer = torch.optim.Adam(
-            self.bound_model.parameters(), lr=1e-3, weight_decay=5e-5)
-
-        # Loss = cross_entropy +  l2 regulization + \Sum_{i=1}^N \row_i ||W_i - Z_i^k + U_i^k||^2
-        criterion = torch.nn.CrossEntropyLoss()
-
-        # callback function to do additonal optimization, refer to the deriatives of Formula (7)
-        def callback():
-            for i, wrapper in enumerate(self.get_modules_wrapper()):
-                wrapper.module.weight.data -= self._row * \
-                    (wrapper.module.weight.data - Z[i] + U[i])
-
-        # optimization iteration
-        for k in range(self._num_iterations):
-            _logger.info('ADMM iteration : %d', k)
-
-            # step 1: optimize W with AdamOptimizer
-            for epoch in range(self._training_epochs):
-                self._trainer(self.bound_model, optimizer=optimizer,
-                              criterion=criterion, epoch=epoch, callback=callback)
-
-            # step 2: update Z, U
-            # Z_i^{k+1} = projection(W_i^{k+1} + U_i^k)
-            # U_i^{k+1} = U^k + W_i^{k+1} - Z_i^{k+1}
-            for i, wrapper in enumerate(self.get_modules_wrapper()):
-                z = wrapper.module.weight.data + U[i]
-                Z[i] = self._projection(z, wrapper.config['sparsity'], wrapper)
-                U[i] = U[i] + wrapper.module.weight.data - Z[i]
-
-        # apply prune
-        self.update_mask()
-
-        _logger.info('Compression finished.')
-
-        return self.bound_model
diff --git a/nni/algorithms/compression/pytorch/pruning/agp.py b/nni/algorithms/compression/pytorch/pruning/agp.py
deleted file mode 100644
index ef9ca71635..0000000000
--- a/nni/algorithms/compression/pytorch/pruning/agp.py
+++ /dev/null
@@ -1,151 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT license.
-
-"""
-An automated gradual pruning algorithm that prunes the smallest magnitude
-weights to achieve a preset level of network sparsity.
-Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the
-efficacy of pruning for model compression", 2017 NIPS Workshop on Machine
-Learning of Phones and other Consumer Devices.
-"""
-
-import logging
-import torch
-from schema import And, Optional
-from .constants import MASKER_DICT
-from nni.compression.pytorch.utils.config_validation import CompressorSchema
-from nni.compression.pytorch.compressor import Pruner
-
-__all__ = ['AGPPruner']
-
-logger = logging.getLogger('torch pruner')
-
-class AGPPruner(Pruner):
-    """
-    Parameters
-    ----------
-    model : torch.nn.Module
-        Model to be pruned.
-    config_list : listlist
-        Supported keys:
-            - initial_sparsity: This is to specify the sparsity when compressor starts to compress.
-            - final_sparsity: This is to specify the sparsity when compressor finishes to compress.
-            - start_epoch: This is to specify the epoch number when compressor starts to compress, default start from epoch 0.
-            - end_epoch: This is to specify the epoch number when compressor finishes to compress.
-            - frequency: This is to specify every *frequency* number epochs compressor compress once, default frequency=1.
-    optimizer: torch.optim.Optimizer
-        Optimizer used to train model.
-    pruning_algorithm: str
-        Algorithms being used to prune model,
-        choose from `['level', 'slim', 'l1', 'l2', 'fpgm', 'taylorfo', 'apoz', 'mean_activation']`, by default `level`
-    """
-
-    def __init__(self, model, config_list, optimizer, pruning_algorithm='level'):
-        super().__init__(model, config_list, optimizer)
-        assert isinstance(optimizer, torch.optim.Optimizer), "AGP pruner is an iterative pruner, please pass optimizer of the model to it"
-        self.masker = MASKER_DICT[pruning_algorithm](model, self)
-
-        self.now_epoch = 0
-        self.set_wrappers_attribute("if_calculated", False)
-
-    def validate_config(self, model, config_list):
-        """
-        Parameters
-        ----------
-        model : torch.nn.Module
-            Model to be pruned
-        config_list : list
-            List on pruning configs
-        """
-        schema = CompressorSchema([{
-            'initial_sparsity': And(float, lambda n: 0 <= n <= 1),
-            'final_sparsity': And(float, lambda n: 0 <= n <= 1),
-            'start_epoch': And(int, lambda n: n >= 0),
-            'end_epoch': And(int, lambda n: n >= 0),
-            'frequency': And(int, lambda n: n > 0),
-            Optional('op_types'): [str],
-            Optional('op_names'): [str]
-        }], model, logger)
-
-        schema.validate(config_list)
-
-    def calc_mask(self, wrapper, wrapper_idx=None):
-        """
-        Calculate the mask of given layer.
-        Scale factors with the smallest absolute value in the BN layer are masked.
-        Parameters
-        ----------
-        wrapper : Module
-            the layer to instrument the compression operation
-        wrapper_idx: int
-            index of this wrapper in pruner's all wrappers
-        Returns
-        -------
-        dict | None
-            Dictionary for storing masks, keys of the dict:
-            'weight_mask':  weight mask tensor
-            'bias_mask': bias mask tensor (optional)
-        """
-
-        config = wrapper.config
-
-        start_epoch = config.get('start_epoch', 0)
-        freq = config.get('frequency', 1)
-
-        if wrapper.if_calculated:
-            return None
-        if not (self.now_epoch >= start_epoch and (self.now_epoch - start_epoch) % freq == 0):
-            return None
-
-        target_sparsity = self.compute_target_sparsity(config)
-        new_mask = self.masker.calc_mask(sparsity=target_sparsity, wrapper=wrapper, wrapper_idx=wrapper_idx)
-        if new_mask is not None:
-            wrapper.if_calculated = True
-
-        return new_mask
-
-    def compute_target_sparsity(self, config):
-        """
-        Calculate the sparsity for pruning
-        Parameters
-        ----------
-        config : dict
-            Layer's pruning config
-        Returns
-        -------
-        float
-            Target sparsity to be pruned
-        """
-
-        end_epoch = config.get('end_epoch', 1)
-        start_epoch = config.get('start_epoch', 0)
-        freq = config.get('frequency', 1)
-        final_sparsity = config.get('final_sparsity', 0)
-        initial_sparsity = config.get('initial_sparsity', 0)
-        if end_epoch <= start_epoch or initial_sparsity >= final_sparsity:
-            logger.warning('your end epoch <= start epoch or initial_sparsity >= final_sparsity')
-            return final_sparsity
-
-        if end_epoch <= self.now_epoch:
-            return final_sparsity
-
-        span = ((end_epoch - start_epoch - 1) // freq) * freq
-        assert span > 0
-        target_sparsity = (final_sparsity +
-                           (initial_sparsity - final_sparsity) *
-                           (1.0 - ((self.now_epoch - start_epoch) / span)) ** 3)
-        return target_sparsity
-
-    def update_epoch(self, epoch):
-        """
-        Update epoch
-        Parameters
-        ----------
-        epoch : int
-            current training epoch
-        """
-
-        if epoch > 0:
-            self.now_epoch = epoch
-            for wrapper in self.get_modules_wrapper():
-                wrapper.if_calculated = False
diff --git a/nni/algorithms/compression/pytorch/pruning/auto_compress_pruner.py b/nni/algorithms/compression/pytorch/pruning/auto_compress_pruner.py
index fdc27ac2f4..82a8f1cb98 100644
--- a/nni/algorithms/compression/pytorch/pruning/auto_compress_pruner.py
+++ b/nni/algorithms/compression/pytorch/pruning/auto_compress_pruner.py
@@ -13,8 +13,7 @@
 from nni.compression.pytorch.compressor import Pruner
 from nni.compression.pytorch.utils.config_validation import CompressorSchema
 from .simulated_annealing_pruner import SimulatedAnnealingPruner
-from .admm_pruner import ADMMPruner
-
+from .iterative_pruner import ADMMPruner
 
 _logger = logging.getLogger(__name__)
 
@@ -34,26 +33,7 @@ class AutoCompressPruner(Pruner):
     trainer : function
         Function used for the first subproblem of ADMM Pruner.
         Users should write this function as a normal function to train the Pytorch model
-        and include `model, optimizer, criterion, epoch, callback` as function arguments.
-        Here `callback` acts as an L2 regulizer as presented in the formula (7) of the original paper.
-        The logic of `callback` is implemented inside the Pruner,
-        users are just required to insert `callback()` between `loss.backward()` and `optimizer.step()`.
-        Example::
-
-            def trainer(model, criterion, optimizer, epoch, callback):
-                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-                train_loader = ...
-                model.train()
-                for batch_idx, (data, target) in enumerate(train_loader):
-                    data, target = data.to(device), target.to(device)
-                    optimizer.zero_grad()
-                    output = model(data)
-                    loss = criterion(output, target)
-                    loss.backward()
-                    # callback should be inserted between loss.backward() and optimizer.step()
-                    if callback:
-                        callback()
-                    optimizer.step()
+        and include `model, optimizer, criterion, epoch` as function arguments.
     evaluator : function
         function to evaluate the pruned model.
         This function should include `model` as the only parameter, and returns a scalar value.
@@ -80,8 +60,8 @@ def evaluator(model):
     optimize_mode : str
         optimize mode, `maximize` or `minimize`, by default `maximize`.
     base_algo : str
-        Base pruning algorithm. `level`, `l1`, `l2` or `fpgm`, by default `l1`. Given the sparsity distribution among the ops,
-        the assigned `base_algo` is used to decide which filters/channels/weights to prune.
+        Base pruning algorithm. `level`, `l1`, `l2` or `fpgm`, by default `l1`. Given the sparsity distribution among
+        the ops, the assigned `base_algo` is used to decide which filters/channels/weights to prune.
     start_temperature : float
         Start temperature of the simulated annealing process.
     stop_temperature : float
@@ -92,7 +72,7 @@ def evaluator(model):
         Initial perturbation magnitude to the sparsities. The magnitude decreases with current temperature.
     admm_num_iterations : int
         Number of iterations of ADMM Pruner.
-    admm_training_epochs : int
+    admm_epochs_per_iteration : int
         Training epochs of the first optimization subproblem of ADMMPruner.
     row : float
         Penalty parameters for ADMM training.
@@ -100,18 +80,19 @@ def evaluator(model):
         PATH to store temporary experiment data.
     """
 
-    def __init__(self, model, config_list, trainer, evaluator, dummy_input,
+    def __init__(self, model, config_list, trainer, criterion, evaluator, dummy_input,
                  num_iterations=3, optimize_mode='maximize', base_algo='l1',
                  # SimulatedAnnealing related
                  start_temperature=100, stop_temperature=20, cool_down_rate=0.9, perturbation_magnitude=0.35,
                  # ADMM related
-                 admm_num_iterations=30, admm_training_epochs=5, row=1e-4,
+                 admm_num_iterations=30, admm_epochs_per_iteration=5, row=1e-4,
                  experiment_data_dir='./'):
         # original model
         self._model_to_prune = model
         self._base_algo = base_algo
 
         self._trainer = trainer
+        self._criterion = criterion
         self._evaluator = evaluator
         self._dummy_input = dummy_input
         self._num_iterations = num_iterations
@@ -125,7 +106,7 @@ def __init__(self, model, config_list, trainer, evaluator, dummy_input,
 
         # hyper parameters for ADMM algorithm
         self._admm_num_iterations = admm_num_iterations
-        self._admm_training_epochs = admm_training_epochs
+        self._admm_epochs_per_iteration = admm_epochs_per_iteration
         self._row = row
 
         # overall pruning rate
@@ -174,12 +155,12 @@ def compress(self):
         """
         _logger.info('Starting AutoCompress pruning...')
 
-        sparsity_each_round = 1 - pow(1-self._sparsity, 1/self._num_iterations)
+        sparsity_each_round = 1 - pow(1 - self._sparsity, 1 / self._num_iterations)
 
         for i in range(self._num_iterations):
             _logger.info('Pruning iteration: %d', i)
             _logger.info('Target sparsity this round: %s',
-                         1-pow(1-sparsity_each_round, i+1))
+                         1 - pow(1 - sparsity_each_round, i + 1))
 
             # SimulatedAnnealingPruner
             _logger.info(
@@ -204,9 +185,10 @@ def compress(self):
             ADMMpruner = ADMMPruner(
                 model=copy.deepcopy(self._model_to_prune),
                 config_list=config_list,
+                criterion=self._criterion,
                 trainer=self._trainer,
                 num_iterations=self._admm_num_iterations,
-                training_epochs=self._admm_training_epochs,
+                epochs_per_iteration=self._admm_epochs_per_iteration,
                 row=self._row,
                 base_algo=self._base_algo)
             ADMMpruner.compress()
@@ -214,12 +196,13 @@ def compress(self):
             ADMMpruner.export_model(os.path.join(self._experiment_data_dir, 'model_admm_masked.pth'), os.path.join(
                 self._experiment_data_dir, 'mask.pth'))
 
-            # use speed up to prune the model before next iteration, because SimulatedAnnealingPruner & ADMMPruner don't take masked models
+            # use speedup to prune the model before the next iteration,
+            # because SimulatedAnnealingPruner & ADMMPruner don't take masked models
             self._model_to_prune.load_state_dict(torch.load(os.path.join(
                 self._experiment_data_dir, 'model_admm_masked.pth')))
 
             masks_file = os.path.join(self._experiment_data_dir, 'mask.pth')
-            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            device = next(self._model_to_prune.parameters()).device
 
             _logger.info('Speeding up models...')
             m_speedup = ModelSpeedup(self._model_to_prune, self._dummy_input, masks_file, device)
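
The refactored `AutoCompressPruner` now takes the loss `criterion` separately and expects a `trainer` with the plain signature `(model, optimizer, criterion, epoch)`; the old `callback` argument is gone because the ADMM update is patched into the optimizer step internally (see `patch_optimizer_before` in `iterative_pruner.py` below). A minimal sketch of a compatible trainer, with the data loader as a placeholder that is not part of this patch:

```python
import torch

def trainer(model, optimizer, criterion, epoch):
    device = next(model.parameters()).device
    train_loader = ...  # supply your own DataLoader here
    model.train()
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()
```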
diff --git a/nni/algorithms/compression/pytorch/pruning/constants_pruner.py b/nni/algorithms/compression/pytorch/pruning/constants_pruner.py
index b0ad5cce37..55ba9506f3 100644
--- a/nni/algorithms/compression/pytorch/pruning/constants_pruner.py
+++ b/nni/algorithms/compression/pytorch/pruning/constants_pruner.py
@@ -2,7 +2,7 @@
 # Licensed under the MIT license.
 
 
-from .one_shot import LevelPruner, L1FilterPruner, L2FilterPruner, FPGMPruner
+from .one_shot_pruner import LevelPruner, L1FilterPruner, L2FilterPruner, FPGMPruner
 
 PRUNER_DICT = {
     'level': LevelPruner,
diff --git a/nni/algorithms/compression/pytorch/pruning/dependency_aware_pruner.py b/nni/algorithms/compression/pytorch/pruning/dependency_aware_pruner.py
new file mode 100644
index 0000000000..c0ca053a7d
--- /dev/null
+++ b/nni/algorithms/compression/pytorch/pruning/dependency_aware_pruner.py
@@ -0,0 +1,162 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import logging
+from schema import And, Optional, SchemaError
+from nni.common.graph_utils import TorchModuleGraph
+from nni.compression.pytorch.utils.shape_dependency import ChannelDependency, GroupDependency
+from nni.compression.pytorch.utils.config_validation import CompressorSchema
+from nni.compression.pytorch.compressor import Pruner
+from .constants import MASKER_DICT
+
+__all__ = ['DependencyAwarePruner']
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+class DependencyAwarePruner(Pruner):
+    """
+    DependencyAwarePruner has two ways to calculate the masks
+    for conv layers. In the normal way, the DependencyAwarePruner
+    calculates the mask of each layer separately. For example, each
+    conv layer determines which filters should be pruned according to
+    its L1 norm. In contrast, in the dependency-aware way, the layers
+    in a dependency group are pruned jointly and are forced to prune
+    the same channels.
+    """
+
+    def __init__(self, model, config_list, optimizer=None, pruning_algorithm='level', dependency_aware=False,
+                 dummy_input=None, **algo_kwargs):
+        super().__init__(model, config_list=config_list, optimizer=optimizer)
+
+        self.dependency_aware = dependency_aware
+        self.dummy_input = dummy_input
+
+        if self.dependency_aware:
+            if not self._supported_dependency_aware():
+                raise ValueError('This pruner does not support dependency aware!')
+
+            errmsg = "When dependency_aware is set, the dummy_input should not be None"
+            assert self.dummy_input is not None, errmsg
+            # Get the TorchModuleGraph of the target model
+            # to trace the model, we need to unwrap the wrappers
+            self._unwrap_model()
+            self.graph = TorchModuleGraph(model, dummy_input)
+            self._wrap_model()
+            self.channel_depen = ChannelDependency(
+                traced_model=self.graph.trace)
+            self.group_depen = GroupDependency(traced_model=self.graph.trace)
+            self.channel_depen = self.channel_depen.dependency_sets
+            self.channel_depen = {
+                name: sets for sets in self.channel_depen for name in sets}
+            self.group_depen = self.group_depen.dependency_sets
+
+        self.masker = MASKER_DICT[pruning_algorithm](
+            model, self, **algo_kwargs)
+        # set the dependency-aware switch for the masker
+        self.masker.dependency_aware = dependency_aware
+        self.set_wrappers_attribute("if_calculated", False)
+
+    def calc_mask(self, wrapper, wrapper_idx=None):
+        if not wrapper.if_calculated:
+            sparsity = wrapper.config['sparsity']
+            masks = self.masker.calc_mask(
+                sparsity=sparsity, wrapper=wrapper, wrapper_idx=wrapper_idx)
+
+            # masker.calc_mask returning None means the mask was not calculated successfully; it can be tried again later
+            if masks is not None:
+                wrapper.if_calculated = True
+            return masks
+        else:
+            return None
+
+    def update_mask(self):
+        if not self.dependency_aware:
+            # if we use the normal way to update the mask,
+            # then call the update_mask of the father class
+            super(DependencyAwarePruner, self).update_mask()
+        else:
+            # if we update the mask in a dependency-aware way
+            # then we call _dependency_update_mask
+            self._dependency_update_mask()
+
+    def validate_config(self, model, config_list):
+        schema = CompressorSchema([{
+            Optional('sparsity'): And(float, lambda n: 0 < n < 1),
+            Optional('op_types'): ['Conv2d'],
+            Optional('op_names'): [str],
+            Optional('exclude'): bool
+        }], model, logger)
+
+        schema.validate(config_list)
+        for config in config_list:
+            if 'exclude' not in config and 'sparsity' not in config:
+                raise SchemaError('Either sparsity or exclude must be specified!')
+
+    def _supported_dependency_aware(self):
+        raise NotImplementedError
+
+    def _dependency_calc_mask(self, wrappers, channel_dsets, wrappers_idx=None):
+        """
+        Calculate the masks for the conv layers in the same
+        channel dependency set. All the layers passed in have
+        the same number of channels.
+
+        Parameters
+        ----------
+        wrappers: list
+            The list of the wrappers that are in the same channel
+            dependency set.
+        wrappers_idx: list
+            The list of the indexes of the wrappers.
+        Returns
+        -------
+        masks: dict
+            A dict object that contains the masks of the layers in this
+            dependency group, the key is the name of the convolutional layers.
+        """
+        # The number of groups for each conv layer.
+        # Note that this number may differ from the layer's
+        # original number of filter groups.
+        groups = [self.group_depen[_w.name] for _w in wrappers]
+        sparsities = [_w.config['sparsity'] for _w in wrappers]
+        masks = self.masker.calc_mask(
+            sparsities, wrappers, wrappers_idx, channel_dsets=channel_dsets, groups=groups)
+        if masks is not None:
+            # if masks is None, then the mask calculation fails.
+            # for example, in activation related maskers, we should
+            # pass enough batches of data to the model, so that the
+            # masks can be calculated successfully.
+            for _w in wrappers:
+                _w.if_calculated = True
+        return masks
+
+    def _dependency_update_mask(self):
+        """
+        In the original update_mask, the wrapper of each layer updates its
+        own mask according to the sparsity specified in the config_list. However, in
+        _dependency_update_mask, we may prune several layers at the same
+        time according to the sparsities and the channel/group dependencies.
+        """
+        name2wrapper = {x.name: x for x in self.get_modules_wrapper()}
+        wrapper2index = {x: i for i, x in enumerate(self.get_modules_wrapper())}
+        for wrapper in self.get_modules_wrapper():
+            if wrapper.if_calculated:
+                continue
+            # find all the conv layers that have a channel dependency with this layer
+            # and prune all these layers at the same time.
+            _names = list(self.channel_depen[wrapper.name])
+            logger.info('Pruning the dependent layers: %s', ','.join(_names))
+            _wrappers = [name2wrapper[name]
+                         for name in _names if name in name2wrapper]
+            _wrapper_idxes = [wrapper2index[_w] for _w in _wrappers]
+
+            masks = self._dependency_calc_mask(
+                _wrappers, _names, wrappers_idx=_wrapper_idxes)
+            if masks is not None:
+                for layer in masks:
+                    for mask_type in masks[layer]:
+                        assert hasattr(name2wrapper[layer], mask_type), "there is no attribute '%s' in wrapper on %s" \
+                            % (mask_type, layer)
+                        setattr(name2wrapper[layer], mask_type, masks[layer][mask_type])
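
A hypothetical usage sketch of the dependency-aware mode: pruners built on `DependencyAwarePruner` accept `dependency_aware=True` together with a `dummy_input` on the model's device. This assumes `L1FilterPruner` (defined in `one_shot_pruner.py`, whose diff is not shown in full here) exposes these constructor arguments:

```python
import torch
from torchvision.models import resnet18
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner

model = resnet18()
# the dummy input must live on the same device as the model
dummy_input = torch.rand(1, 3, 224, 224)
config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
pruner = L1FilterPruner(model, config_list,
                        dependency_aware=True, dummy_input=dummy_input)
# layers in the same channel-dependency set are forced to prune the same channels
pruner.compress()
```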
diff --git a/nni/algorithms/compression/pytorch/pruning/finegrained_pruning.py b/nni/algorithms/compression/pytorch/pruning/finegrained_pruning_masker.py
similarity index 100%
rename from nni/algorithms/compression/pytorch/pruning/finegrained_pruning.py
rename to nni/algorithms/compression/pytorch/pruning/finegrained_pruning_masker.py
diff --git a/nni/algorithms/compression/pytorch/pruning/iterative_pruner.py b/nni/algorithms/compression/pytorch/pruning/iterative_pruner.py
new file mode 100644
index 0000000000..9651e9e35a
--- /dev/null
+++ b/nni/algorithms/compression/pytorch/pruning/iterative_pruner.py
@@ -0,0 +1,576 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import logging
+import copy
+import torch
+from schema import And, Optional
+from nni.compression.pytorch.utils.config_validation import CompressorSchema
+from .constants import MASKER_DICT
+from .dependency_aware_pruner import DependencyAwarePruner
+
+__all__ = ['AGPPruner', 'ADMMPruner', 'SlimPruner', 'TaylorFOWeightFilterPruner', 'ActivationAPoZRankFilterPruner',
+           'ActivationMeanRankFilterPruner']
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+class IterativePruner(DependencyAwarePruner):
+    """
+    Prune model during the training process.
+    """
+
+    def __init__(self, model, config_list, optimizer=None, pruning_algorithm='slim', trainer=None, criterion=None,
+                 num_iterations=20, epochs_per_iteration=5, dependency_aware=False, dummy_input=None, **algo_kwargs):
+        """
+        Parameters
+        ----------
+        model: torch.nn.Module
+            Model to be pruned
+        config_list: list
+            List on pruning configs
+        optimizer: torch.optim.Optimizer
+            Optimizer used to train model
+        pruning_algorithm: str
+            algorithms being used to prune model
+        trainer: function
+            Function used to train the model.
+            Users should write this function as a normal function to train the Pytorch model
+            and include `model, optimizer, criterion, epoch` as function arguments.
+        criterion: function
+            Function used to calculate the loss between the target and the output.
+        num_iterations: int
+            Total number of iterations in the pruning process. The mask is calculated at the end of each iteration.
+        epochs_per_iteration: Union[int, list]
+            The number of training epochs for each iteration. An `int` means the same value for every iteration;
+            a `list` gives a specific value for each iteration.
+        dependency_aware: bool
+            Whether to prune the model in a dependency-aware way.
+        dummy_input: torch.Tensor
+            The dummy input used to analyze the topology constraints. Note that
+            the dummy_input should be on the same device as the model.
+        algo_kwargs: dict
+            Additional parameters passed to pruning algorithm masker class
+        """
+        super().__init__(model, config_list, optimizer, pruning_algorithm, dependency_aware, dummy_input, **algo_kwargs)
+
+        if isinstance(epochs_per_iteration, list):
+            assert len(epochs_per_iteration) == num_iterations, 'num_iterations should equal the length of epochs_per_iteration'
+            self.epochs_per_iteration = epochs_per_iteration
+        else:
+            self.epochs_per_iteration = [epochs_per_iteration] * num_iterations
+
+        self._trainer = trainer
+        self._criterion = criterion
+
+    def _fresh_calculated(self):
+        for wrapper in self.get_modules_wrapper():
+            wrapper.if_calculated = False
+
+    def compress(self):
+        training = self.bound_model.training
+        self.bound_model.train()
+        for _, epochs_num in enumerate(self.epochs_per_iteration):
+            self._fresh_calculated()
+            for epoch in range(epochs_num):
+                self._trainer(self.bound_model, optimizer=self.optimizer, criterion=self._criterion, epoch=epoch)
+            self.update_mask()
+        self.bound_model.train(training)
+
+        return self.bound_model
+
+
+class AGPPruner(IterativePruner):
+    """
+    Parameters
+    ----------
+    model : torch.nn.Module
+        Model to be pruned.
+    config_list : list
+        Supported keys:
+            - sparsity : This is to specify the target sparsity of the operations to be compressed.
+            - op_types : See the supported types in your specific pruning algorithm.
+    optimizer: torch.optim.Optimizer
+        Optimizer used to train model.
+    trainer: function
+        Function to train the model
+    criterion: function
+        Function used to calculate the loss between the target and the output.
+    num_iterations: int
+        Total number of iterations in the pruning process. The mask is calculated at the end of each iteration.
+    epochs_per_iteration: int
+        The number of training epochs for each iteration.
+    pruning_algorithm: str
+        Algorithms being used to prune model,
+        choose from `['level', 'slim', 'l1', 'l2', 'fpgm', 'taylorfo', 'apoz', 'mean_activation']`, by default `level`
+    """
+
+    def __init__(self, model, config_list, optimizer, trainer, criterion, num_iterations=10, epochs_per_iteration=1, pruning_algorithm='level'):
+        super().__init__(model, config_list, optimizer=optimizer, trainer=trainer, criterion=criterion,
+                         num_iterations=num_iterations, epochs_per_iteration=epochs_per_iteration)
+        assert isinstance(optimizer, torch.optim.Optimizer), "AGP pruner is an iterative pruner, please pass optimizer of the model to it"
+        self.masker = MASKER_DICT[pruning_algorithm](model, self)
+        self.now_epoch = 0
+        self.freq = epochs_per_iteration
+        self.end_epoch = epochs_per_iteration * num_iterations
+        self.set_wrappers_attribute("if_calculated", False)
+
+    def validate_config(self, model, config_list):
+        """
+        Parameters
+        ----------
+        model : torch.nn.Module
+            Model to be pruned
+        config_list : list
+            List on pruning configs
+        """
+        schema = CompressorSchema([{
+            'sparsity': And(float, lambda n: 0 <= n <= 1),
+            Optional('op_types'): [str],
+            Optional('op_names'): [str]
+        }], model, logger)
+
+        schema.validate(config_list)
+
+    def _supported_dependency_aware(self):
+        return False
+
+    def calc_mask(self, wrapper, wrapper_idx=None):
+        """
+        Calculate the mask of the given layer.
+        Weights with the smallest importance, as measured by the chosen pruning algorithm, are masked.
+        Parameters
+        ----------
+        wrapper : Module
+            the layer to instrument the compression operation
+        wrapper_idx: int
+            index of this wrapper in pruner's all wrappers
+        Returns
+        -------
+        dict | None
+            Dictionary for storing masks, keys of the dict:
+            'weight_mask':  weight mask tensor
+            'bias_mask': bias mask tensor (optional)
+        """
+
+        config = wrapper.config
+
+        if wrapper.if_calculated:
+            return None
+
+        if self.now_epoch % self.freq != 0:
+            return None
+
+        target_sparsity = self.compute_target_sparsity(config)
+        new_mask = self.masker.calc_mask(sparsity=target_sparsity, wrapper=wrapper, wrapper_idx=wrapper_idx)
+
+        if new_mask is not None:
+            wrapper.if_calculated = True
+
+        return new_mask
+
+    def compute_target_sparsity(self, config):
+        """
+        Calculate the sparsity for pruning
+        Parameters
+        ----------
+        config : dict
+            Layer's pruning config
+        Returns
+        -------
+        float
+            Target sparsity to be pruned
+        """
+
+        initial_sparsity = 0
+        self.target_sparsity = final_sparsity = config.get('sparsity', 0)
+
+        if initial_sparsity >= final_sparsity:
+            logger.warning('your initial_sparsity >= final_sparsity')
+            return final_sparsity
+
+        if self.end_epoch == 1 or self.end_epoch <= self.now_epoch:
+            return final_sparsity
+
+        span = ((self.end_epoch - 1) // self.freq) * self.freq
+        assert span > 0
+        self.target_sparsity = (final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - (self.now_epoch / span)) ** 3)
+        return self.target_sparsity
+
+    def update_epoch(self, epoch):
+        """
+        Update epoch
+        Parameters
+        ----------
+        epoch : int
+            current training epoch
+        """
+
+        if epoch > 0:
+            self.now_epoch = epoch
+            for wrapper in self.get_modules_wrapper():
+                wrapper.if_calculated = False
+
+    # TODO: need refactor
+    def compress(self):
+        training = self.bound_model.training
+        self.bound_model.train()
+
+        for epoch in range(self.end_epoch):
+            self.update_epoch(epoch)
+            self._trainer(self.bound_model, optimizer=self.optimizer, criterion=self._criterion, epoch=epoch)
+            self.update_mask()
+            logger.info(f'sparsity is {self.target_sparsity:.2f} at epoch {epoch}')
+            self.get_pruned_weights()
+
+        self.bound_model.train(training)
+
+        return self.bound_model
+
+
+class ADMMPruner(IterativePruner):
+    """
+    A Pytorch implementation of ADMM Pruner algorithm.
+
+    Parameters
+    ----------
+    model : torch.nn.Module
+        Model to be pruned.
+    config_list : list
+        List on pruning configs.
+    trainer : function
+        Function used for the first subproblem.
+        Users should write this function as a normal function to train the Pytorch model
+        and include `model, optimizer, criterion, epoch` as function arguments.
+    criterion: function
+        Function used to calculate the loss between the target and the output. By default, we use CrossEntropyLoss in ADMMPruner.
+    num_iterations: int
+        Total number of iterations in the pruning process. In ADMMPruner, the mask is calculated after all iterations finish.
+    epochs_per_iteration: int
+        Training epochs of the first subproblem.
+    row : float
+        Penalty parameter (rho) for ADMM training.
+    base_algo : str
+        Base pruning algorithm. `level`, `l1`, `l2` or `fpgm`, by default `l1`. Given the sparsity distribution among
+        the ops, the assigned `base_algo` is used to decide which filters/channels/weights to prune.
+
+    """
+
+    def __init__(self, model, config_list, trainer, criterion=torch.nn.CrossEntropyLoss(),
+                 num_iterations=30, epochs_per_iteration=5, row=1e-4, base_algo='l1'):
+        self._base_algo = base_algo
+
+        super().__init__(model, config_list)
+
+        self._trainer = trainer
+        self.optimizer = torch.optim.Adam(
+            self.bound_model.parameters(), lr=1e-3, weight_decay=5e-5)
+        self._criterion = criterion
+        self._num_iterations = num_iterations
+        self._training_epochs = epochs_per_iteration
+        self._row = row
+
+        self.set_wrappers_attribute("if_calculated", False)
+        self.masker = MASKER_DICT[self._base_algo](self.bound_model, self)
+
+        self.patch_optimizer_before(self._callback)
+
+    def validate_config(self, model, config_list):
+        """
+        Parameters
+        ----------
+        model : torch.nn.Module
+            Model to be pruned
+        config_list : list
+            List on pruning configs
+        """
+
+        if self._base_algo == 'level':
+            schema = CompressorSchema([{
+                'sparsity': And(float, lambda n: 0 < n < 1),
+                Optional('op_types'): [str],
+                Optional('op_names'): [str],
+            }], model, logger)
+        elif self._base_algo in ['l1', 'l2', 'fpgm']:
+            schema = CompressorSchema([{
+                'sparsity': And(float, lambda n: 0 < n < 1),
+                'op_types': ['Conv2d'],
+                Optional('op_names'): [str]
+            }], model, logger)
+
+        schema.validate(config_list)
+
+    def _supported_dependency_aware(self):
+        return False
+
+    def _projection(self, weight, sparsity, wrapper):
+        '''
+        Return the Euclidean projection of the weight matrix according to the pruning mode.
+
+        Parameters
+        ----------
+        weight : tensor
+            original matrix
+        sparsity : float
+            the ratio of parameters which need to be set to zero
+        wrapper: PrunerModuleWrapper
+            layer wrapper of this layer
+
+        Returns
+        -------
+        tensor
+            the projected matrix
+        '''
+        wrapper_copy = copy.deepcopy(wrapper)
+        wrapper_copy.module.weight.data = weight
+        return weight.data.mul(self.masker.calc_mask(sparsity, wrapper_copy)['weight_mask'])
+
+    def _callback(self):
+        # callback function to do additional optimization, refer to the derivatives of Formula (7)
+        for i, wrapper in enumerate(self.get_modules_wrapper()):
+            wrapper.module.weight.data -= self._row * \
+                (wrapper.module.weight.data - self.Z[i] + self.U[i])
+
+    def compress(self):
+        """
+        Compress the model with ADMM.
+
+        Returns
+        -------
+        torch.nn.Module
+            model with specified modules compressed.
+        """
+        logger.info('Starting ADMM Compression...')
+
+        # initialize Z, U
+        # Z_i^0 = W_i^0
+        # U_i^0 = 0
+        self.Z = []
+        self.U = []
+        for wrapper in self.get_modules_wrapper():
+            z = wrapper.module.weight.data
+            self.Z.append(z)
+            self.U.append(torch.zeros_like(z))
+
+        # Loss = cross_entropy + l2 regularization + \sum_{i=1}^N \rho_i ||W_i - Z_i^k + U_i^k||^2
+        # optimization iteration
+        for k in range(self._num_iterations):
+            logger.info('ADMM iteration : %d', k)
+
+            # step 1: optimize W with AdamOptimizer
+            for epoch in range(self._training_epochs):
+                self._trainer(self.bound_model, optimizer=self.optimizer, criterion=self._criterion, epoch=epoch)
+
+            # step 2: update Z, U
+            # Z_i^{k+1} = projection(W_i^{k+1} + U_i^k)
+            # U_i^{k+1} = U^k + W_i^{k+1} - Z_i^{k+1}
+            for i, wrapper in enumerate(self.get_modules_wrapper()):
+                z = wrapper.module.weight.data + self.U[i]
+                self.Z[i] = self._projection(z, wrapper.config['sparsity'], wrapper)
+                self.U[i] = self.U[i] + wrapper.module.weight.data - self.Z[i]
+
+        # apply prune
+        self.update_mask()
+
+        logger.info('Compression finished.')
+
+        return self.bound_model
+
+
+class SlimPruner(IterativePruner):
+    """
+    Parameters
+    ----------
+    model : torch.nn.Module
+        Model to be pruned
+    config_list : list
+        Supported keys:
+            - sparsity : This is to specify the target sparsity of the operations to be compressed.
+            - op_types : Only BatchNorm2d is supported in Slim Pruner.
+    optimizer : torch.optim.Optimizer
+        Optimizer used to train model
+    trainer : function
+        Function used to sparsify BatchNorm2d scaling factors.
+        Users should write this function as a normal function to train the Pytorch model
+        and include `model, optimizer, criterion, epoch` as function arguments.
+    criterion : function
+        Function used to calculate the loss between the target and the output.
+    sparsity_training_epochs: int
+        The number of channel sparsity regularization training epochs before pruning.
+    scale : float
+        Penalty parameters for sparsification.
+    dependency_aware: bool
+        Whether to prune the model in a dependency-aware way. If it is `True`, this pruner
+        will prune the model according to the l2-norm of weights and the channel or group
+        dependencies of the model. In this way, the pruner forces the conv layers that have
+        dependencies to prune the same channels, so the speedup module can better harvest
+        the speed benefit from the pruned model. Note that if this flag is set to True,
+        the dummy_input cannot be None, because the pruner needs a dummy input to trace
+        the dependencies between the conv layers.
+    dummy_input : torch.Tensor
+        The dummy input used to analyze the topology constraints. Note that the dummy_input
+        should be on the same device as the model.
+    """
+
+    def __init__(self, model, config_list, optimizer, trainer, criterion, sparsity_training_epochs=10, scale=0.0001,
+                 dependency_aware=False, dummy_input=None):
+        super().__init__(model, config_list, optimizer=optimizer, pruning_algorithm='slim', trainer=trainer, criterion=criterion,
+                         num_iterations=1, epochs_per_iteration=sparsity_training_epochs, dependency_aware=dependency_aware,
+                         dummy_input=dummy_input)
+        self.scale = scale
+        self.patch_optimizer_before(self._callback)
+
+    def validate_config(self, model, config_list):
+        schema = CompressorSchema([{
+            'sparsity': And(float, lambda n: 0 < n < 1),
+            'op_types': ['BatchNorm2d'],
+            Optional('op_names'): [str]
+        }], model, logger)
+
+        schema.validate(config_list)
+
+        if len(config_list) > 1:
+            logger.warning('Slim pruner only supports 1 configuration')
+
+    def _supported_dependency_aware(self):
+        return True
+
+    def _callback(self):
+        for _, wrapper in enumerate(self.get_modules_wrapper()):
+            wrapper.module.weight.grad.data.add_(self.scale * torch.sign(wrapper.module.weight.data))
+
+
+class TaylorFOWeightFilterPruner(IterativePruner):
+    """
+    Parameters
+    ----------
+    model : torch.nn.Module
+        Model to be pruned
+    config_list : list
+        Supported keys:
+            - sparsity : How much percentage of convolutional filters are to be pruned.
+            - op_types : Currently only Conv2d is supported in TaylorFOWeightFilterPruner.
+    optimizer: torch.optim.Optimizer
+        Optimizer used to train model
+    trainer : function
+        Function used to train the model and collect the filter contributions.
+        Users should write this function as a normal function to train the Pytorch model
+        and include `model, optimizer, criterion, epoch` as function arguments.
+    criterion : function
+        Function used to calculate the loss between the target and the output.
+    sparsity_training_epochs: int
+        The number of epochs to collect the contributions.
+    dependency_aware: bool
+        Whether to prune the model in a dependency-aware way. If it is `True`, this pruner
+        will prune the model according to the l2-norm of weights and the channel or group
+        dependencies of the model. In this way, the pruner forces the conv layers that have
+        dependencies to prune the same channels, so the speedup module can better harvest
+        the speed benefit from the pruned model. Note that if this flag is set to True,
+        the dummy_input cannot be None, because the pruner needs a dummy input to trace
+        the dependencies between the conv layers.
+    dummy_input : torch.Tensor
+        The dummy input used to analyze the topology constraints. Note that the dummy_input
+        should be on the same device as the model.
+
+    """
+
+    def __init__(self, model, config_list, optimizer, trainer, criterion, sparsity_training_epochs=1, dependency_aware=False,
+                 dummy_input=None):
+        super().__init__(model, config_list, optimizer=optimizer, pruning_algorithm='taylorfo', trainer=trainer,
+                         criterion=criterion, num_iterations=1, epochs_per_iteration=sparsity_training_epochs,
+                         dependency_aware=dependency_aware, dummy_input=dummy_input)
+
+    def _supported_dependency_aware(self):
+        return True
+
+
+class ActivationAPoZRankFilterPruner(IterativePruner):
+    """
+    Parameters
+    ----------
+    model : torch.nn.Module
+        Model to be pruned
+    config_list : list
+        Supported keys:
+            - sparsity : How much percentage of convolutional filters are to be pruned.
+            - op_types : Only Conv2d is supported in ActivationAPoZRankFilterPruner.
+    optimizer: torch.optim.Optimizer
+        Optimizer used to train model
+    trainer: function
+        Function used to train the model.
+        Users should write this function as a normal function to train the Pytorch model
+        and include `model, optimizer, criterion, epoch` as function arguments.
+    criterion : function
+        Function used to calculate the loss between the target and the output.
+    activation: str
+        The activation type.
+    sparsity_training_epochs: int
+        The number of epochs to collect activation statistics.
+    dependency_aware: bool
+        Whether to prune the model in a dependency-aware way. If it is `True`, this pruner
+        will prune the model according to the l2-norm of weights and the channel or group
+        dependencies of the model. In this way, the pruner forces the conv layers that have
+        dependencies to prune the same channels, so the speedup module can better harvest
+        the speed benefit from the pruned model. Note that if this flag is set to True,
+        the dummy_input cannot be None, because the pruner needs a dummy input to trace
+        the dependencies between the conv layers.
+    dummy_input : torch.Tensor
+        The dummy input used to analyze the topology constraints. Note that the dummy_input
+        should be on the same device as the model.
+
+    """
+
+    def __init__(self, model, config_list, optimizer, trainer, criterion, activation='relu',
+                 sparsity_training_epochs=1, dependency_aware=False, dummy_input=None):
+        super().__init__(model, config_list, pruning_algorithm='apoz', optimizer=optimizer, trainer=trainer,
+                         criterion=criterion, dependency_aware=dependency_aware, dummy_input=dummy_input,
+                         activation=activation, num_iterations=1, epochs_per_iteration=sparsity_training_epochs)
+
+    def _supported_dependency_aware(self):
+        return True
+
+
+class ActivationMeanRankFilterPruner(IterativePruner):
+    """
+    Parameters
+    ----------
+    model : torch.nn.Module
+        Model to be pruned
+    config_list : list
+        Supported keys:
+            - sparsity : How much percentage of convolutional filters are to be pruned.
+            - op_types : Only Conv2d is supported in ActivationMeanRankFilterPruner.
+    optimizer: torch.optim.Optimizer
+        Optimizer used to train model.
+    trainer: function
+        Function used to train the model.
+        Users should write this function as a normal function to train the Pytorch model
+        and include `model, optimizer, criterion, epoch` as function arguments.
+    criterion : function
+        Function used to calculate the loss between the target and the output.
+    activation: str
+        The activation type.
+    sparsity_training_epochs: int
+        The number of epochs to collect activation statistics.
+    dependency_aware: bool
+        Whether to prune the model in a dependency-aware way. If it is `True`, this pruner
+        will prune the model according to the l2-norm of weights and the channel or group
+        dependencies of the model. In this way, the pruner forces the conv layers that have
+        dependencies to prune the same channels, so the speedup module can better harvest
+        the speed benefit from the pruned model. Note that if this flag is set to True,
+        the dummy_input cannot be None, because the pruner needs a dummy input to trace
+        the dependencies between the conv layers.
+    dummy_input : torch.Tensor
+        The dummy input used to analyze the topology constraints. Note that the dummy_input
+        should be on the same device as the model.
+    """
+
+    def __init__(self, model, config_list, optimizer, trainer, criterion, activation='relu',
+                 sparsity_training_epochs=1, dependency_aware=False, dummy_input=None):
+        super().__init__(model, config_list, pruning_algorithm='mean_activation', optimizer=optimizer, trainer=trainer,
+                         criterion=criterion, dependency_aware=dependency_aware, dummy_input=dummy_input,
+                         activation=activation, num_iterations=1, epochs_per_iteration=sparsity_training_epochs)
+
+    def _supported_dependency_aware(self):
+        return True
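
For reference, a minimal sketch of driving one of the iterative pruners above with the new `trainer`/`criterion` API; the model, data loader, and hyperparameters are illustrative placeholders only:

```python
import torch
from nni.algorithms.compression.pytorch.pruning import AGPPruner

model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(784, 10))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

def trainer(model, optimizer, criterion, epoch):
    # one epoch of ordinary training; train_loader is assumed to exist
    for data, target in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()

config_list = [{'sparsity': 0.8, 'op_types': ['default']}]
pruner = AGPPruner(model, config_list, optimizer, trainer, criterion,
                   num_iterations=10, epochs_per_iteration=1)
pruner.compress()  # trains and gradually updates the masks epoch by epoch
```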
diff --git a/nni/algorithms/compression/pytorch/pruning/lottery_ticket.py b/nni/algorithms/compression/pytorch/pruning/lottery_ticket.py
index b0d041dd02..caa1c831e6 100644
--- a/nni/algorithms/compression/pytorch/pruning/lottery_ticket.py
+++ b/nni/algorithms/compression/pytorch/pruning/lottery_ticket.py
@@ -7,7 +7,7 @@
 from schema import And, Optional
 from nni.compression.pytorch.utils.config_validation import CompressorSchema
 from nni.compression.pytorch.compressor import Pruner
-from .finegrained_pruning import LevelPrunerMasker
+from .finegrained_pruning_masker import LevelPrunerMasker
 
 logger = logging.getLogger('torch pruner')
 
diff --git a/nni/algorithms/compression/pytorch/pruning/one_shot.py b/nni/algorithms/compression/pytorch/pruning/one_shot.py
deleted file mode 100644
index 75e2a7c307..0000000000
--- a/nni/algorithms/compression/pytorch/pruning/one_shot.py
+++ /dev/null
@@ -1,460 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT license.
-
-import logging
-from schema import And, Optional, SchemaError
-from nni.common.graph_utils import TorchModuleGraph
-from nni.compression.pytorch.utils.shape_dependency import ChannelDependency, GroupDependency
-from .constants import MASKER_DICT
-from nni.compression.pytorch.utils.config_validation import CompressorSchema
-from nni.compression.pytorch.compressor import Pruner
-
-
-__all__ = ['LevelPruner', 'SlimPruner', 'L1FilterPruner', 'L2FilterPruner', 'FPGMPruner',
-           'TaylorFOWeightFilterPruner', 'ActivationAPoZRankFilterPruner', 'ActivationMeanRankFilterPruner']
-
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
-
-
-class OneshotPruner(Pruner):
-    """
-    Prune model to an exact pruning level for one time.
-    """
-
-    def __init__(self, model, config_list, pruning_algorithm='level', optimizer=None, **algo_kwargs):
-        """
-        Parameters
-        ----------
-        model : torch.nn.Module
-            Model to be pruned
-        config_list : list
-            List on pruning configs
-        pruning_algorithm: str
-            algorithms being used to prune model
-        optimizer: torch.optim.Optimizer
-            Optimizer used to train model
-        algo_kwargs: dict
-            Additional parameters passed to pruning algorithm masker class
-        """
-
-        super().__init__(model, config_list, optimizer)
-        self.set_wrappers_attribute("if_calculated", False)
-        self.masker = MASKER_DICT[pruning_algorithm](
-            model, self, **algo_kwargs)
-
-    def validate_config(self, model, config_list):
-        """
-        Parameters
-        ----------
-        model : torch.nn.Module
-            Model to be pruned
-        config_list : list
-            List on pruning configs
-        """
-        schema = CompressorSchema([{
-            'sparsity': And(float, lambda n: 0 < n < 1),
-            Optional('op_types'): [str],
-            Optional('op_names'): [str]
-        }], model, logger)
-
-        schema.validate(config_list)
-
-    def calc_mask(self, wrapper, wrapper_idx=None):
-        """
-        Calculate the mask of given layer
-        Parameters
-        ----------
-        wrapper : Module
-            the module to instrument the compression operation
-        wrapper_idx: int
-            index of this wrapper in pruner's all wrappers
-        Returns
-        -------
-        dict
-            dictionary for storing masks, keys of the dict:
-            'weight_mask':  weight mask tensor
-            'bias_mask': bias mask tensor (optional)
-        """
-        if wrapper.if_calculated:
-            return None
-
-        sparsity = wrapper.config['sparsity']
-        if not wrapper.if_calculated:
-            masks = self.masker.calc_mask(
-                sparsity=sparsity, wrapper=wrapper, wrapper_idx=wrapper_idx)
-
-            # masker.calc_mask returns None means calc_mask is not calculated sucessfully, can try later
-            if masks is not None:
-                wrapper.if_calculated = True
-            return masks
-        else:
-            return None
-
-
-class LevelPruner(OneshotPruner):
-    """
-    Parameters
-    ----------
-    model : torch.nn.Module
-        Model to be pruned
-    config_list : list
-        Supported keys:
-            - sparsity : This is to specify the sparsity operations to be compressed to.
-            - op_types : Operation types to prune.
-    optimizer: torch.optim.Optimizer
-            Optimizer used to train model
-    """
-
-    def __init__(self, model, config_list, optimizer=None):
-        super().__init__(model, config_list, pruning_algorithm='level', optimizer=optimizer)
-
-
-class SlimPruner(OneshotPruner):
-    """
-    Parameters
-    ----------
-    model : torch.nn.Module
-        Model to be pruned
-    config_list : list
-        Supported keys:
-            - sparsity : This is to specify the sparsity operations to be compressed to.
-            - op_types : Only BatchNorm2d is supported in Slim Pruner.
-    optimizer: torch.optim.Optimizer
-            Optimizer used to train model
-    """
-
-    def __init__(self, model, config_list, optimizer=None):
-        super().__init__(model, config_list, pruning_algorithm='slim', optimizer=optimizer)
-
-    def validate_config(self, model, config_list):
-        schema = CompressorSchema([{
-            'sparsity': And(float, lambda n: 0 < n < 1),
-            'op_types': ['BatchNorm2d'],
-            Optional('op_names'): [str]
-        }], model, logger)
-
-        schema.validate(config_list)
-
-        if len(config_list) > 1:
-            logger.warning('Slim pruner only supports 1 configuration')
-
-
-class _StructuredFilterPruner(OneshotPruner):
-    """
-    _StructuredFilterPruner has two ways to calculate the masks
-    for conv layers. In the normal way, the _StructuredFilterPruner
-    will calculate the mask of each layer separately. For example, each
-    conv layer determine which filters should be pruned according to its L1
-    norm. In constrast, in the dependency-aware way, the layers that in a
-    dependency group will be pruned jointly and these layers will be forced
-    to prune the same channels.
-    """
-
-    def __init__(self, model, config_list, pruning_algorithm, optimizer=None, dependency_aware=False, dummy_input=None, **algo_kwargs):
-        super().__init__(model, config_list, pruning_algorithm=pruning_algorithm,
-                         optimizer=optimizer, **algo_kwargs)
-        self.dependency_aware = dependency_aware
-        # set the dependency-aware switch for the masker
-        self.masker.dependency_aware = dependency_aware
-        self.dummy_input = dummy_input
-        if self.dependency_aware:
-            errmsg = "When dependency_aware is set, the dummy_input should not be None"
-            assert self.dummy_input is not None, errmsg
-            # Get the TorchModuleGraph of the target model
-            # to trace the model, we need to unwrap the wrappers
-            self._unwrap_model()
-            self.graph = TorchModuleGraph(model, dummy_input)
-            self._wrap_model()
-            self.channel_depen = ChannelDependency(
-                traced_model=self.graph.trace)
-            self.group_depen = GroupDependency(traced_model=self.graph.trace)
-            self.channel_depen = self.channel_depen.dependency_sets
-            self.channel_depen = {
-                name: sets for sets in self.channel_depen for name in sets}
-            self.group_depen = self.group_depen.dependency_sets
-
-    def update_mask(self):
-        if not self.dependency_aware:
-            # if we use the normal way to update the mask,
-            # then call the update_mask of the father class
-            super(_StructuredFilterPruner, self).update_mask()
-        else:
-            # if we update the mask in a dependency-aware way
-            # then we call _dependency_update_mask
-            self._dependency_update_mask()
-
-    def validate_config(self, model, config_list):
-        schema = CompressorSchema([{
-            Optional('sparsity'): And(float, lambda n: 0 < n < 1),
-            Optional('op_types'): ['Conv2d'],
-            Optional('op_names'): [str],
-            Optional('exclude'): bool
-        }], model, logger)
-
-        schema.validate(config_list)
-        for config in config_list:
-            if 'exclude' not in config and 'sparsity' not in config:
-                raise SchemaError('Either sparisty or exclude must be specified!')
-
-    def _dependency_calc_mask(self, wrappers, channel_dsets, wrappers_idx=None):
-        """
-        calculate the masks for the conv layers in the same
-        channel dependecy set. All the layers passed in have
-        the same number of channels.
-
-        Parameters
-        ----------
-        wrappers: list
-            The list of the wrappers that in the same channel dependency
-            set.
-        wrappers_idx: list
-            The list of the indexes of wrapppers.
-        Returns
-        -------
-        masks: dict
-            A dict object that contains the masks of the layers in this
-            dependency group, the key is the name of the convolutional layers.
-        """
-        # The number of the groups for each conv layers
-        # Note that, this number may be different from its
-        # original number of groups of filters.
-        groups = [self.group_depen[_w.name] for _w in wrappers]
-        sparsities = [_w.config['sparsity'] for _w in wrappers]
-        masks = self.masker.calc_mask(
-            sparsities, wrappers, wrappers_idx, channel_dsets=channel_dsets, groups=groups)
-        if masks is not None:
-            # if masks is None, then the mask calculation fails.
-            # for example, in activation related maskers, we should
-            # pass enough batches of data to the model, so that the
-            # masks can be calculated successfully.
-            for _w in wrappers:
-                _w.if_calculated = True
-        return masks
-
-    def _dependency_update_mask(self):
-        """
-        In the original update_mask, the wraper of each layer will update its
-        own mask according to the sparsity specified in the config_list. However, in
-        the _dependency_update_mask, we may prune several layers at the same
-        time according the sparsities and the channel/group dependencies.
-        """
-        name2wrapper = {x.name: x for x in self.get_modules_wrapper()}
-        wrapper2index = {x: i for i, x in enumerate(self.get_modules_wrapper())}
-        for wrapper in self.get_modules_wrapper():
-            if wrapper.if_calculated:
-                continue
-            # find all the conv layers that have channel dependecy with this layer
-            # and prune all these layers at the same time.
-            _names = [x for x in self.channel_depen[wrapper.name]]
-            logger.info('Pruning the dependent layers: %s', ','.join(_names))
-            _wrappers = [name2wrapper[name]
-                         for name in _names if name in name2wrapper]
-            _wrapper_idxes = [wrapper2index[_w] for _w in _wrappers]
-
-            masks = self._dependency_calc_mask(
-                _wrappers, _names, wrappers_idx=_wrapper_idxes)
-            if masks is not None:
-                for layer in masks:
-                    for mask_type in masks[layer]:
-                        assert hasattr(
-                            name2wrapper[layer], mask_type), "there is no attribute '%s' in wrapper on %s" % (mask_type, layer)
-                        setattr(name2wrapper[layer], mask_type, masks[layer][mask_type])
-
-
-class L1FilterPruner(_StructuredFilterPruner):
-    """
-    Parameters
-    ----------
-    model : torch.nn.Module
-        Model to be pruned
-    config_list : list
-        Supported keys:
-            - sparsity : This is to specify the sparsity operations to be compressed to.
-            - op_types : Only Conv2d is supported in L1FilterPruner.
-    optimizer: torch.optim.Optimizer
-            Optimizer used to train model
-    dependency_aware: bool
-        If prune the model in a dependency-aware way. If it is `True`, this pruner will
-        prune the model according to the l2-norm of weights and the channel-dependency or
-        group-dependency of the model. In this way, the pruner will force the conv layers
-        that have dependencies to prune the same channels, so the speedup module can better
-        harvest the speed benefit from the pruned model. Note that, if this flag is set True
-        , the dummy_input cannot be None, because the pruner needs a dummy input to trace the
-        dependency between the conv layers.
-    dummy_input : torch.Tensor
-        The dummy input to analyze the topology constraints. Note that, the dummy_input
-        should on the same device with the model.
-    """
-
-    def __init__(self, model, config_list, optimizer=None, dependency_aware=False, dummy_input=None):
-        super().__init__(model, config_list, pruning_algorithm='l1', optimizer=optimizer,
-                         dependency_aware=dependency_aware, dummy_input=dummy_input)
-
-
-class L2FilterPruner(_StructuredFilterPruner):
-    """
-    Parameters
-    ----------
-    model : torch.nn.Module
-        Model to be pruned
-    config_list : list
-        Supported keys:
-            - sparsity : This is to specify the sparsity operations to be compressed to.
-            - op_types : Only Conv2d is supported in L2FilterPruner.
-    optimizer: torch.optim.Optimizer
-            Optimizer used to train model
-    dependency_aware: bool
-        If prune the model in a dependency-aware way. If it is `True`, this pruner will
-        prune the model according to the l2-norm of weights and the channel-dependency or
-        group-dependency of the model. In this way, the pruner will force the conv layers
-        that have dependencies to prune the same channels, so the speedup module can better
-        harvest the speed benefit from the pruned model. Note that, if this flag is set True
-        , the dummy_input cannot be None, because the pruner needs a dummy input to trace the
-        dependency between the conv layers.
-    dummy_input : torch.Tensor
-        The dummy input to analyze the topology constraints. Note that, the dummy_input
-        should on the same device with the model.
-    """
-
-    def __init__(self, model, config_list, optimizer=None, dependency_aware=False, dummy_input=None):
-        super().__init__(model, config_list, pruning_algorithm='l2', optimizer=optimizer,
-                         dependency_aware=dependency_aware, dummy_input=dummy_input)
-
-
-class FPGMPruner(_StructuredFilterPruner):
-    """
-    Parameters
-    ----------
-    model : torch.nn.Module
-        Model to be pruned
-    config_list : list
-        Supported keys:
-            - sparsity : This is to specify the sparsity operations to be compressed to.
-            - op_types : Only Conv2d is supported in FPGM Pruner.
-    optimizer: torch.optim.Optimizer
-            Optimizer used to train model
-    dependency_aware: bool
-        If prune the model in a dependency-aware way. If it is `True`, this pruner will
-        prune the model according to the l2-norm of weights and the channel-dependency or
-        group-dependency of the model. In this way, the pruner will force the conv layers
-        that have dependencies to prune the same channels, so the speedup module can better
-        harvest the speed benefit from the pruned model. Note that, if this flag is set True
-        , the dummy_input cannot be None, because the pruner needs a dummy input to trace the
-        dependency between the conv layers.
-    dummy_input : torch.Tensor
-        The dummy input to analyze the topology constraints. Note that, the dummy_input
-        should on the same device with the model.
-    """
-
-    def __init__(self, model, config_list, optimizer=None, dependency_aware=False, dummy_input=None):
-        super().__init__(model, config_list, pruning_algorithm='fpgm',
-                         dependency_aware=dependency_aware, dummy_input=dummy_input, optimizer=optimizer)
-
-
-class TaylorFOWeightFilterPruner(_StructuredFilterPruner):
-    """
-    Parameters
-    ----------
-    model : torch.nn.Module
-        Model to be pruned
-    config_list : list
-        Supported keys:
-            - sparsity : How much percentage of convolutional filters are to be pruned.
-            - op_types : Currently only Conv2d is supported in TaylorFOWeightFilterPruner.
-    optimizer: torch.optim.Optimizer
-            Optimizer used to train model
-    statistics_batch_num: int
-        The number of batches to statistic the activation.
-    dependency_aware: bool
-        If prune the model in a dependency-aware way. If it is `True`, this pruner will
-        prune the model according to the l2-norm of weights and the channel-dependency or
-        group-dependency of the model. In this way, the pruner will force the conv layers
-        that have dependencies to prune the same channels, so the speedup module can better
-        harvest the speed benefit from the pruned model. Note that, if this flag is set True
-        , the dummy_input cannot be None, because the pruner needs a dummy input to trace the
-        dependency between the conv layers.
-    dummy_input : torch.Tensor
-        The dummy input to analyze the topology constraints. Note that, the dummy_input
-        should on the same device with the model.
-
-    """
-
-    def __init__(self, model, config_list, optimizer=None, statistics_batch_num=1,
-                 dependency_aware=False, dummy_input=None):
-        super().__init__(model, config_list, pruning_algorithm='taylorfo',
-                         dependency_aware=dependency_aware, dummy_input=dummy_input,
-                         optimizer=optimizer, statistics_batch_num=statistics_batch_num)
-
-
-class ActivationAPoZRankFilterPruner(_StructuredFilterPruner):
-    """
-    Parameters
-    ----------
-    model : torch.nn.Module
-        Model to be pruned
-    config_list : list
-        Supported keys:
-            - sparsity : How much percentage of convolutional filters are to be pruned.
-            - op_types : Only Conv2d is supported in ActivationAPoZRankFilterPruner.
-    optimizer: torch.optim.Optimizer
-            Optimizer used to train model
-    activation: str
-        The activation type.
-    statistics_batch_num: int
-        The number of batches to statistic the activation.
-    dependency_aware: bool
-        If prune the model in a dependency-aware way. If it is `True`, this pruner will
-        prune the model according to the l2-norm of weights and the channel-dependency or
-        group-dependency of the model. In this way, the pruner will force the conv layers
-        that have dependencies to prune the same channels, so the speedup module can better
-        harvest the speed benefit from the pruned model. Note that, if this flag is set True
-        , the dummy_input cannot be None, because the pruner needs a dummy input to trace the
-        dependency between the conv layers.
-    dummy_input : torch.Tensor
-        The dummy input to analyze the topology constraints. Note that, the dummy_input
-        should on the same device with the model.
-
-    """
-
-    def __init__(self, model, config_list, optimizer=None, activation='relu',
-                 statistics_batch_num=1, dependency_aware=False, dummy_input=None):
-        super().__init__(model, config_list, pruning_algorithm='apoz', optimizer=optimizer,
-                         dependency_aware=dependency_aware, dummy_input=dummy_input,
-                         activation=activation, statistics_batch_num=statistics_batch_num)
-
-
-class ActivationMeanRankFilterPruner(_StructuredFilterPruner):
-    """
-    Parameters
-    ----------
-    model : torch.nn.Module
-        Model to be pruned
-    config_list : list
-        Supported keys:
-            - sparsity : How much percentage of convolutional filters are to be pruned.
-            - op_types : Only Conv2d is supported in ActivationMeanRankFilterPruner.
-    optimizer: torch.optim.Optimizer
-            Optimizer used to train model.
-    activation: str
-        The activation type.
-    statistics_batch_num: int
-        The number of batches to statistic the activation.
-    dependency_aware: bool
-        If prune the model in a dependency-aware way. If it is `True`, this pruner will
-        prune the model according to the l2-norm of weights and the channel-dependency or
-        group-dependency of the model. In this way, the pruner will force the conv layers
-        that have dependencies to prune the same channels, so the speedup module can better
-        harvest the speed benefit from the pruned model. Note that, if this flag is set True
-        , the dummy_input cannot be None, because the pruner needs a dummy input to trace the
-        dependency between the conv layers.
-    dummy_input : torch.Tensor
-        The dummy input to analyze the topology constraints. Note that, the dummy_input
-        should on the same device with the model.
-    """
-
-    def __init__(self, model, config_list, optimizer=None, activation='relu',
-                 statistics_batch_num=1, dependency_aware=False, dummy_input=None):
-        super().__init__(model, config_list, pruning_algorithm='mean_activation', optimizer=optimizer,
-                         dependency_aware=dependency_aware, dummy_input=dummy_input,
-                         activation=activation, statistics_batch_num=statistics_batch_num)
diff --git a/nni/algorithms/compression/pytorch/pruning/one_shot_pruner.py b/nni/algorithms/compression/pytorch/pruning/one_shot_pruner.py
new file mode 100644
index 0000000000..c17a5ddafa
--- /dev/null
+++ b/nni/algorithms/compression/pytorch/pruning/one_shot_pruner.py
@@ -0,0 +1,169 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import logging
+from schema import And, Optional
+
+from nni.compression.pytorch.utils.config_validation import CompressorSchema
+from .dependency_aware_pruner import DependencyAwarePruner
+
+__all__ = ['LevelPruner', 'L1FilterPruner', 'L2FilterPruner', 'FPGMPruner']
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+class OneshotPruner(DependencyAwarePruner):
+    """
+    Prune model to an exact pruning level for one time.
+    """
+
+    def __init__(self, model, config_list, pruning_algorithm='level', dependency_aware=False, dummy_input=None,
+                 **algo_kwargs):
+        """
+        Parameters
+        ----------
+        model : torch.nn.Module
+            Model to be pruned
+        config_list : list
+            List of pruning configs
+        pruning_algorithm: str
+            Algorithm used to prune the model
+        dependency_aware: bool
+            Whether to prune the model in a dependency-aware way.
+        dummy_input : torch.Tensor
+            The dummy input used to analyze the topology constraints. Note that
+            the dummy_input should be on the same device as the model.
+        algo_kwargs: dict
+            Additional parameters passed to pruning algorithm masker class
+        """
+        super().__init__(model, config_list, None, pruning_algorithm, dependency_aware, dummy_input, **algo_kwargs)
+
+    def validate_config(self, model, config_list):
+        """
+        Parameters
+        ----------
+        model : torch.nn.Module
+            Model to be pruned
+        config_list : list
+            List of pruning configs
+        """
+        schema = CompressorSchema([{
+            'sparsity': And(float, lambda n: 0 < n < 1),
+            Optional('op_types'): [str],
+            Optional('op_names'): [str]
+        }], model, logger)
+
+        schema.validate(config_list)
+
+
+class LevelPruner(OneshotPruner):
+    """
+    Parameters
+    ----------
+    model : torch.nn.Module
+        Model to be pruned
+    config_list : list
+        Supported keys:
+            - sparsity : The target sparsity of the operations to be compressed.
+            - op_types : Operation types to prune.
+    """
+
+    def __init__(self, model, config_list):
+        super().__init__(model, config_list, pruning_algorithm='level')
+
+    def _supported_dependency_aware(self):
+        return False
+
+
+class L1FilterPruner(OneshotPruner):
+    """
+    Parameters
+    ----------
+    model : torch.nn.Module
+        Model to be pruned
+    config_list : list
+        Supported keys:
+            - sparsity : The target sparsity of the operations to be compressed.
+            - op_types : Only Conv2d is supported in L1FilterPruner.
+    dependency_aware: bool
+        Whether to prune the model in a dependency-aware way. If it is `True`, this pruner will
+        prune the model according to the l1-norm of weights and the channel/group
+        dependencies of the model. In this way, the pruner forces the conv layers that
+        have dependencies to prune the same channels, so the speedup module can better
+        harvest the speed benefit from the pruned model. Note that if this flag is set
+        to True, dummy_input cannot be None, because the pruner needs a dummy input to
+        trace the dependencies between the conv layers.
+    dummy_input : torch.Tensor
+        The dummy input used to analyze the topology constraints. Note that the
+        dummy_input should be on the same device as the model.
+    """
+
+    def __init__(self, model, config_list, dependency_aware=False, dummy_input=None):
+        super().__init__(model, config_list, pruning_algorithm='l1', dependency_aware=dependency_aware,
+                         dummy_input=dummy_input)
+
+    def _supported_dependency_aware(self):
+        return True
+
+
+class L2FilterPruner(OneshotPruner):
+    """
+    Parameters
+    ----------
+    model : torch.nn.Module
+        Model to be pruned
+    config_list : list
+        Supported keys:
+            - sparsity : The target sparsity of the operations to be compressed.
+            - op_types : Only Conv2d is supported in L2FilterPruner.
+    dependency_aware: bool
+        Whether to prune the model in a dependency-aware way. If it is `True`, this pruner will
+        prune the model according to the l2-norm of weights and the channel/group
+        dependencies of the model. In this way, the pruner forces the conv layers that
+        have dependencies to prune the same channels, so the speedup module can better
+        harvest the speed benefit from the pruned model. Note that if this flag is set
+        to True, dummy_input cannot be None, because the pruner needs a dummy input to
+        trace the dependencies between the conv layers.
+    dummy_input : torch.Tensor
+        The dummy input used to analyze the topology constraints. Note that the
+        dummy_input should be on the same device as the model.
+    """
+
+    def __init__(self, model, config_list, dependency_aware=False, dummy_input=None):
+        super().__init__(model, config_list, pruning_algorithm='l2', dependency_aware=dependency_aware,
+                         dummy_input=dummy_input)
+
+    def _supported_dependency_aware(self):
+        return True
+
+
+class FPGMPruner(OneshotPruner):
+    """
+    Parameters
+    ----------
+    model : torch.nn.Module
+        Model to be pruned
+    config_list : list
+        Supported keys:
+            - sparsity : The target sparsity of the operations to be compressed.
+            - op_types : Only Conv2d is supported in FPGM Pruner.
+    dependency_aware: bool
+        Whether to prune the model in a dependency-aware way. If it is `True`, this pruner will
+        prune the model according to the geometric median of the filters and the channel/group
+        dependencies of the model. In this way, the pruner forces the conv layers that
+        have dependencies to prune the same channels, so the speedup module can better
+        harvest the speed benefit from the pruned model. Note that if this flag is set
+        to True, dummy_input cannot be None, because the pruner needs a dummy input to
+        trace the dependencies between the conv layers.
+    dummy_input : torch.Tensor
+        The dummy input used to analyze the topology constraints. Note that the
+        dummy_input should be on the same device as the model.
+    """
+
+    def __init__(self, model, config_list, dependency_aware=False, dummy_input=None):
+        super().__init__(model, config_list, pruning_algorithm='fpgm', dependency_aware=dependency_aware,
+                         dummy_input=dummy_input)
+
+    def _supported_dependency_aware(self):
+        return True
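
For reference, a minimal usage sketch of the refactored one-shot pruner API above (the model and input names are illustrative, torchvision is assumed to be available, and note that the optimizer argument is no longer a constructor parameter):

```python
import torch
from torchvision.models import resnet18  # any Conv2d-based model works
from nni.algorithms.compression.pytorch.pruning.one_shot_pruner import L1FilterPruner

model = resnet18()
dummy_input = torch.rand(1, 3, 224, 224)
config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]

# dependency_aware=True requires dummy_input so the pruner can trace
# channel/group dependencies between the conv layers before masking.
pruner = L1FilterPruner(model, config_list, dependency_aware=True, dummy_input=dummy_input)
pruner.compress()
```
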
diff --git a/nni/algorithms/compression/pytorch/pruning/structured_pruning.py b/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py
similarity index 98%
rename from nni/algorithms/compression/pytorch/pruning/structured_pruning.py
rename to nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py
index 277bed4757..671811e138 100644
--- a/nni/algorithms/compression/pytorch/pruning/structured_pruning.py
+++ b/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py
@@ -474,8 +474,8 @@ class TaylorFOWeightFilterPrunerMasker(StructuredWeightMasker):
     def __init__(self, model, pruner, statistics_batch_num=1):
         super().__init__(model, pruner)
         self.pruner.statistics_batch_num = statistics_batch_num
-        self.pruner.set_wrappers_attribute("contribution", None)
         self.pruner.iterations = 0
+        self.pruner.set_wrappers_attribute("contribution", None)
         self.pruner.patch_optimizer(self.calc_contributions)
 
     def get_mask(self, base_mask, weight, num_prune, wrapper, wrapper_idx, channel_masks=None):
@@ -499,6 +499,7 @@ def calc_contributions(self):
         """
         if self.pruner.iterations >= self.pruner.statistics_batch_num:
             return
+
         for wrapper in self.pruner.get_modules_wrapper():
             filters = wrapper.module.weight.size(0)
             contribution = (
@@ -677,16 +678,24 @@ class SlimPrunerMasker(WeightMasker):
 
     def __init__(self, model, pruner, **kwargs):
         super().__init__(model, pruner)
+        self.global_threshold = None
+
+    def _get_global_threshold(self):
         weight_list = []
-        for (layer, _) in pruner.get_modules_to_compress():
+        for (layer, _) in self.pruner.get_modules_to_compress():
             weight_list.append(layer.module.weight.data.abs().clone())
         all_bn_weights = torch.cat(weight_list)
-        k = int(all_bn_weights.shape[0] * pruner.config_list[0]['sparsity'])
+        k = int(all_bn_weights.shape[0] * self.pruner.config_list[0]['sparsity'])
         self.global_threshold = torch.topk(
             all_bn_weights.view(-1), k, largest=False)[0].max()
+        logger.info('set global threshold to %s', self.global_threshold)
 
     def calc_mask(self, sparsity, wrapper, wrapper_idx=None):
         assert wrapper.type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
+
+        if self.global_threshold is None:
+            self._get_global_threshold()
+
         weight = wrapper.module.weight.data.clone()
         if wrapper.weight_mask is not None:
             # apply base mask for iterative pruning
@@ -706,7 +715,6 @@ def calc_mask(self, sparsity, wrapper, wrapper_idx=None):
             ), 'bias_mask': mask_bias.detach()}
         return mask
 
-
 def least_square_sklearn(X, Y):
     from sklearn.linear_model import LinearRegression
     reg = LinearRegression(fit_intercept=False)
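
The SlimPrunerMasker hunk above defers the global-threshold computation from `__init__` to the first `calc_mask` call. A self-contained sketch of that lazy pattern (assuming the concatenated weights contain at least one element to rank):

```python
import torch

class LazyGlobalThreshold:
    """Compute an expensive global statistic on first use, not at construction."""

    def __init__(self, tensors, sparsity):
        self.tensors = tensors
        self.sparsity = sparsity
        self.threshold = None  # filled in lazily

    def get(self):
        if self.threshold is None:
            flat = torch.cat([t.abs().view(-1) for t in self.tensors])
            k = max(1, int(flat.numel() * self.sparsity))
            # the k smallest magnitudes fall at or below the threshold
            self.threshold = torch.topk(flat, k, largest=False)[0].max()
        return self.threshold
```
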
diff --git a/nni/algorithms/compression/pytorch/quantization/quantizers.py b/nni/algorithms/compression/pytorch/quantization/quantizers.py
index 62703d449b..dbd5e5b3c3 100644
--- a/nni/algorithms/compression/pytorch/quantization/quantizers.py
+++ b/nni/algorithms/compression/pytorch/quantization/quantizers.py
@@ -148,6 +148,7 @@ def __init__(self, model, config_list, optimizer=None):
         super().__init__(model, config_list, optimizer)
         self.quant_grad = QATGrad.apply
         modules_to_compress = self.get_modules_to_compress()
+        device = next(model.parameters()).device
         self.bound_model.register_buffer("steps", torch.Tensor([1]))
         for layer, config in modules_to_compress:
             layer.module.register_buffer("zero_point", torch.Tensor([0.0]))
@@ -161,7 +162,7 @@ def __init__(self, model, config_list, optimizer=None):
                 layer.module.register_buffer('activation_bit', torch.zeros(1))
                 layer.module.register_buffer('tracked_min_activation', torch.zeros(1))
                 layer.module.register_buffer('tracked_max_activation', torch.zeros(1))
-                
+        self.bound_model.to(device)
 
     def _del_simulated_attr(self, module):
         """
@@ -359,7 +360,7 @@ def step_with_optimizer(self):
         """
         override `compressor` `step` method, quantization only happens after certain number of steps
         """
-        self.bound_model.steps +=1
+        self.bound_model.steps += 1
 
 
 class DoReFaQuantizer(Quantizer):
@@ -370,10 +371,12 @@ class DoReFaQuantizer(Quantizer):
 
     def __init__(self, model, config_list, optimizer=None):
         super().__init__(model, config_list, optimizer)
+        device = next(model.parameters()).device
         modules_to_compress = self.get_modules_to_compress()
         for layer, config in modules_to_compress:
             if "weight" in config.get("quant_types", []):
                 layer.module.register_buffer('weight_bit', torch.zeros(1))
+        self.bound_model.to(device)
 
     def _del_simulated_attr(self, module):
         """
@@ -474,11 +477,13 @@ class BNNQuantizer(Quantizer):
 
     def __init__(self, model, config_list, optimizer=None):
         super().__init__(model, config_list, optimizer)
+        device = next(model.parameters()).device
         self.quant_grad = ClipGrad.apply
         modules_to_compress = self.get_modules_to_compress()
         for layer, config in modules_to_compress:
             if "weight" in config.get("quant_types", []):
                 layer.module.register_buffer('weight_bit', torch.zeros(1))
+        self.bound_model.to(device)
 
     def _del_simulated_attr(self, module):
         """
@@ -589,6 +594,7 @@ def __init__(self, model, config_list, optimizer=None):
                     types of nn.module you want to apply quantization, eg. 'Conv2d'
         """
         super().__init__(model, config_list, optimizer)
+        device = next(model.parameters()).device
         self.quant_grad = QuantForward()
         modules_to_compress = self.get_modules_to_compress()
         self.bound_model.register_buffer("steps", torch.Tensor([1]))
@@ -631,6 +637,8 @@ def __init__(self, model, config_list, optimizer=None):
 
                 self.optimizer.add_param_group({"params": layer.module.input_scale})
 
+        self.bound_model.to(device)
+
     @staticmethod
     def grad_scale(x, scale):
         """
diff --git a/nni/algorithms/compression/tensorflow/pruning/__init__.py b/nni/algorithms/compression/tensorflow/pruning/__init__.py
index f8ac8ea9b9..c535fd7512 100644
--- a/nni/algorithms/compression/tensorflow/pruning/__init__.py
+++ b/nni/algorithms/compression/tensorflow/pruning/__init__.py
@@ -1 +1 @@
-from .one_shot import *
+from .one_shot_pruner import *
diff --git a/nni/algorithms/compression/tensorflow/pruning/one_shot.py b/nni/algorithms/compression/tensorflow/pruning/one_shot_pruner.py
similarity index 100%
rename from nni/algorithms/compression/tensorflow/pruning/one_shot.py
rename to nni/algorithms/compression/tensorflow/pruning/one_shot_pruner.py
diff --git a/nni/compression/pytorch/compressor.py b/nni/compression/pytorch/compressor.py
index 08543caf1a..01b8bb24e4 100644
--- a/nni/compression/pytorch/compressor.py
+++ b/nni/compression/pytorch/compressor.py
@@ -8,7 +8,6 @@
 
 _logger = logging.getLogger(__name__)
 
-
 class LayerInfo:
     def __init__(self, name, module):
         self.module = module
@@ -235,7 +234,6 @@ def _wrap_modules(self, layer, config):
         """
         raise NotImplementedError()
 
-
     def add_activation_collector(self, collector):
         self._fwd_hook_id += 1
         self._fwd_hook_handles[self._fwd_hook_id] = []
@@ -264,6 +262,18 @@ def new_step(_, *args, **kwargs):
         if self.optimizer is not None:
             self.optimizer.step = types.MethodType(patch_step(self.optimizer.step), self.optimizer)
 
+    def patch_optimizer_before(self, *tasks):
+        def patch_step(old_step):
+            def new_step(_, *args, **kwargs):
+                for task in tasks:
+                    task()
+                # call the original optimizer step method
+                output = old_step(*args, **kwargs)
+                return output
+            return new_step
+        if self.optimizer is not None:
+            self.optimizer.step = types.MethodType(patch_step(self.optimizer.step), self.optimizer)
+
 class PrunerModuleWrapper(torch.nn.Module):
     def __init__(self, module, module_name, module_type, config, pruner):
         """
@@ -319,8 +329,6 @@ class Pruner(Compressor):
 
     def __init__(self, model, config_list, optimizer=None):
         super().__init__(model, config_list, optimizer)
-        if optimizer is not None:
-            self.patch_optimizer(self.update_mask)
 
     def compress(self):
         self.update_mask()
@@ -386,7 +394,7 @@ def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=N
         """
         assert model_path is not None, 'model_path must be specified'
         mask_dict = {}
-        self._unwrap_model() # used for generating correct state_dict name without wrapper state
+        self._unwrap_model()  # used for generating correct state_dict name without wrapper state
 
         for wrapper in self.get_modules_wrapper():
             weight_mask = wrapper.weight_mask
@@ -433,6 +441,27 @@ def load_model_state_dict(self, model_state):
         else:
             self.bound_model.load_state_dict(model_state)
 
+    def get_pruned_weights(self, dim=0):
+        """
+        Log the simulated prune sparsity.
+
+        Parameters
+        ----------
+        dim : int
+            the pruned dim.
+        """
+        for wrapper in self.get_modules_wrapper():
+            weight_mask = wrapper.weight_mask
+            mask_size = weight_mask.size()
+            if len(mask_size) == 1:
+                index = torch.nonzero(weight_mask.abs() != 0).tolist()
+            else:
+                sum_idx = list(range(len(mask_size)))
+                sum_idx.remove(dim)
+                index = torch.nonzero(weight_mask.abs().sum(sum_idx) != 0).tolist()
+            _logger.info(f'simulated prune {wrapper.name} remain/total: {len(index)}/{weight_mask.size(dim)}')
+
+
 class QuantizerModuleWrapper(torch.nn.Module):
     def __init__(self, module, module_name, module_type, config, quantizer):
         """
@@ -549,7 +578,6 @@ def quantize_input(self, *inputs, wrapper, **kwargs):
         """
         raise NotImplementedError('Quantizer must overload quantize_input()')
 
-
     def _wrap_modules(self, layer, config):
         """
         Create a wrapper forward function to replace the original one.
@@ -571,8 +599,8 @@ def _wrap_modules(self, layer, config):
 
         return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self)
 
-    def export_model_save(self, model, model_path, calibration_config=None, calibration_path=None, onnx_path=None, \
-        input_shape=None, device=None):
+    def export_model_save(self, model, model_path, calibration_config=None, calibration_path=None, onnx_path=None,
+                          input_shape=None, device=None):
         """
         This method helps save pytorch model, calibration config, onnx model in quantizer.
 
@@ -671,6 +699,7 @@ def _quantize(cls, x, scale, zero_point):
             quantized x without clamped
         """
         return ((x / scale) + zero_point).round()
+
     @classmethod
     def get_bits_length(cls, config, quant_type):
         """
@@ -703,8 +732,8 @@ def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qma
         grad_output : Tensor
             gradient of the output of quantization operation
         scale : Tensor
-            the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`, `QuantType.QUANT_OUTPUT`,
-            you can define different behavior for different types.
+            the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`,
+            `QuantType.QUANT_OUTPUT`, you can define different behavior for different types.
         zero_point : Tensor
             zero_point for quantizing tensor
         qmin : Tensor
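
The new `patch_optimizer_before` wraps `optimizer.step` so that the given tasks run before the original step executes. A standalone sketch of that wrapping technique, outside of any Pruner (the helper name `patch_step_before` is illustrative):

```python
import types
import torch

def patch_step_before(optimizer, *tasks):
    old_step = optimizer.step  # bound method of this optimizer instance
    def new_step(_, *args, **kwargs):
        for task in tasks:
            task()                        # e.g. update masks from fresh gradients
        return old_step(*args, **kwargs)  # then run the original step
    optimizer.step = types.MethodType(new_step, optimizer)

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=0.1)
patch_step_before(opt, lambda: print('task runs before step'))
param.grad = torch.ones_like(param)
opt.step()  # prints the message, then applies the gradient
```
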
diff --git a/nni/compression/pytorch/utils/mask_conflict.py b/nni/compression/pytorch/utils/mask_conflict.py
index 8e37893ba4..e89372d60e 100644
--- a/nni/compression/pytorch/utils/mask_conflict.py
+++ b/nni/compression/pytorch/utils/mask_conflict.py
@@ -31,7 +31,7 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
         # if the input is the path of the mask_file
         assert os.path.exists(masks)
         masks = torch.load(masks)
-    assert len(masks) > 0,  'Mask tensor cannot be empty'
+    assert len(masks) > 0, 'Mask tensor cannot be empty'
     # if the user uses the model and dummy_input to trace the model, we
     # should get the traced model handly, so that, we only trace the
     # model once, GroupMaskConflict and ChannelMaskConflict will reuse
@@ -181,10 +181,8 @@ def fix_mask(self):
             w_mask = self.masks[layername]['weight']
             shape = w_mask.size()
             count = np.prod(shape[1:])
-            all_ones = (w_mask.flatten(1).sum(-1) ==
-                        count).nonzero().squeeze(1).tolist()
-            all_zeros = (w_mask.flatten(1).sum(-1) ==
-                         0).nonzero().squeeze(1).tolist()
+            all_ones = (w_mask.flatten(1).sum(-1) == count).nonzero().squeeze(1).tolist()
+            all_zeros = (w_mask.flatten(1).sum(-1) == 0).nonzero().squeeze(1).tolist()
             if len(all_ones) + len(all_zeros) < w_mask.size(0):
                 # In fine-grained pruning, skip this layer
                 _logger.info('Layers %s using fine-grained pruning', layername)
@@ -198,7 +196,7 @@ def fix_mask(self):
             group_masked = []
             for i in range(group):
                 _start = step * i
-                _end = step * (i+1)
+                _end = step * (i + 1)
                 _tmp_list = list(
                     filter(lambda x: _start <= x and x < _end, all_zeros))
                 group_masked.append(_tmp_list)
@@ -286,7 +284,7 @@ def fix_mask(self):
                             0, 2, 3) if self.conv_prune_dim == 0 else (1, 2, 3)
                         channel_mask = (mask.abs().sum(tmp_sum_idx) != 0).int()
                         channel_masks.append(channel_mask)
-                        if (channel_mask.sum() * (mask.numel() / mask.shape[1-self.conv_prune_dim])).item() != (mask > 0).sum().item():
+                        if (channel_mask.sum() * (mask.numel() / mask.shape[1 - self.conv_prune_dim])).item() != (mask > 0).sum().item():
                             fine_grained = True
                     else:
                         raise RuntimeError(
diff --git a/nni/retiarii/__init__.py b/nni/retiarii/__init__.py
index 762af7c834..f441367460 100644
--- a/nni/retiarii/__init__.py
+++ b/nni/retiarii/__init__.py
@@ -5,4 +5,4 @@
 from .graph import *
 from .execution import *
 from .mutator import *
-from .serializer import basic_unit, json_dump, json_dumps, json_load, json_loads, serialize, serialize_cls
+from .serializer import basic_unit, json_dump, json_dumps, json_load, json_loads, serialize, serialize_cls, model_wrapper
diff --git a/nni/retiarii/converter/graph_gen.py b/nni/retiarii/converter/graph_gen.py
index 373cd9b69a..f8b06b887a 100644
--- a/nni/retiarii/converter/graph_gen.py
+++ b/nni/retiarii/converter/graph_gen.py
@@ -642,6 +642,16 @@ def convert_module(self, script_module, module, module_name, ir_model):
 
         ir_graph._register()
 
+        # add mutation signal for special modules
+        if original_type_name == OpTypeName.Repeat:
+            attrs = {
+                'mutation': 'repeat',
+                'label': module.label,
+                'min_depth': module.min_depth,
+                'max_depth': module.max_depth
+            }
+            return ir_graph, attrs
+
         return ir_graph, {}
 
 
diff --git a/nni/retiarii/converter/op_types.py b/nni/retiarii/converter/op_types.py
index 0d59d9ea08..1a4ba5a42d 100644
--- a/nni/retiarii/converter/op_types.py
+++ b/nni/retiarii/converter/op_types.py
@@ -17,3 +17,5 @@ class OpTypeName(str, Enum):
     ValueChoice = 'ValueChoice'
     Placeholder = 'Placeholder'
     MergedSlice = 'MergedSlice'
+    Repeat = 'Repeat'
+    Cell = 'Cell'
diff --git a/nni/retiarii/execution/api.py b/nni/retiarii/execution/api.py
index c53ee56fd0..8027e7e363 100644
--- a/nni/retiarii/execution/api.py
+++ b/nni/retiarii/execution/api.py
@@ -15,19 +15,18 @@
            'list_models', 'submit_models', 'wait_models', 'query_available_resources',
            'set_execution_engine', 'is_stopped_exec', 'budget_exhausted']
 
-def set_execution_engine(engine) -> None:
+
+def set_execution_engine(engine: AbstractExecutionEngine) -> None:
     global _execution_engine
     if _execution_engine is None:
         _execution_engine = engine
     else:
-        raise RuntimeError('execution engine is already set')
+        raise RuntimeError('Execution engine is already set.')
 
 
 def get_execution_engine() -> AbstractExecutionEngine:
-    """
-    Currently we assume the default execution engine is BaseExecutionEngine.
-    """
     global _execution_engine
+    assert _execution_engine is not None, 'You need to set an execution engine before using it.'
     return _execution_engine
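
The tightened contract here is a set-once accessor: the engine must be installed explicitly before anything queries it. Stripped of the NNI specifics, the pattern is:

```python
_engine = None

def set_engine(engine):
    global _engine
    if _engine is None:
        _engine = engine
    else:
        raise RuntimeError('Execution engine is already set.')

def get_engine():
    assert _engine is not None, 'You need to set an execution engine before using it.'
    return _engine
```
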
 
 
diff --git a/nni/retiarii/execution/base.py b/nni/retiarii/execution/base.py
index 65ab99fc2b..36d09b505f 100644
--- a/nni/retiarii/execution/base.py
+++ b/nni/retiarii/execution/base.py
@@ -5,7 +5,7 @@
 import os
 import random
 import string
-from typing import Dict, Iterable, List
+from typing import Any, Dict, Iterable, List
 
 from .interface import AbstractExecutionEngine, AbstractGraphListener
 from .. import codegen, utils
@@ -59,7 +59,7 @@ def __init__(self) -> None:
 
     def submit_models(self, *models: Model) -> None:
         for model in models:
-            data = BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator)
+            data = self.pack_model_data(model)
             self._running_models[send_trial(data.dump())] = model
             self._history.append(model)
 
@@ -108,6 +108,10 @@ def budget_exhausted(self) -> bool:
         advisor = get_advisor()
         return advisor.stopping
 
+    @classmethod
+    def pack_model_data(cls, model: Model) -> Any:
+        return BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator)
+
     @classmethod
     def trial_execute_graph(cls) -> None:
         """
diff --git a/nni/retiarii/execution/python.py b/nni/retiarii/execution/python.py
new file mode 100644
index 0000000000..93a4d10333
--- /dev/null
+++ b/nni/retiarii/execution/python.py
@@ -0,0 +1,53 @@
+from typing import Dict, Any, List
+
+from ..graph import Evaluator, Model
+from ..integration_api import receive_trial_parameters
+from ..utils import ContextStack, import_, get_importable_name
+from .base import BaseExecutionEngine
+
+
+class PythonGraphData:
+    def __init__(self, class_name: str, init_parameters: Dict[str, Any],
+                 mutation: Dict[str, Any], evaluator: Evaluator) -> None:
+        self.class_name = class_name
+        self.init_parameters = init_parameters
+        self.mutation = mutation
+        self.evaluator = evaluator
+
+    def dump(self) -> dict:
+        return {
+            'class_name': self.class_name,
+            'init_parameters': self.init_parameters,
+            'mutation': self.mutation,
+            'evaluator': self.evaluator
+        }
+
+    @staticmethod
+    def load(data) -> 'PythonGraphData':
+        return PythonGraphData(data['class_name'], data['init_parameters'], data['mutation'], data['evaluator'])
+
+
+class PurePythonExecutionEngine(BaseExecutionEngine):
+    @classmethod
+    def pack_model_data(cls, model: Model) -> Any:
+        mutation = {mut.mutator.label: _unpack_if_only_one(mut.samples) for mut in model.history}
+        graph_data = PythonGraphData(get_importable_name(model.python_class, relocate_module=True),
+                                     model.python_init_params, mutation, model.evaluator)
+        return graph_data
+
+    @classmethod
+    def trial_execute_graph(cls) -> None:
+        graph_data = PythonGraphData.load(receive_trial_parameters())
+
+        class _model(import_(graph_data.class_name)):
+            def __init__(self):
+                super().__init__(**graph_data.init_parameters)
+
+        with ContextStack('fixed', graph_data.mutation):
+            graph_data.evaluator._execute(_model)
+
+
+def _unpack_if_only_one(ele: List[Any]):
+    if len(ele) == 1:
+        return ele[0]
+    return ele
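
`PythonGraphData` is a plain serializable record, so it survives a dump/load round trip. A hypothetical example (the class name, init parameters, and mutation keys below are made up for illustration, and the evaluator is omitted):

```python
from nni.retiarii.execution.python import PythonGraphData

data = PythonGraphData(class_name='mypkg.Net',
                       init_parameters={'width': 16},
                       mutation={'layerchoice_1': 'conv3x3'},
                       evaluator=None)
restored = PythonGraphData.load(data.dump())
assert restored.init_parameters == {'width': 16}
```
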
diff --git a/nni/retiarii/experiment/pytorch.py b/nni/retiarii/experiment/pytorch.py
index b7473507a4..f15931dcaf 100644
--- a/nni/retiarii/experiment/pytorch.py
+++ b/nni/retiarii/experiment/pytorch.py
@@ -28,11 +28,11 @@
 
 from ..codegen import model_to_pytorch_script
 from ..converter import convert_to_graph
-from ..execution import list_models
+from ..execution import list_models, set_execution_engine
 from ..graph import Model, Evaluator
 from ..integration import RetiariiAdvisor
 from ..mutator import Mutator
-from ..nn.pytorch.mutator import process_inline_mutation
+from ..nn.pytorch.mutator import process_inline_mutation, extract_mutation_from_pt_module
 from ..strategy import BaseStrategy
 from ..oneshot.interface import BaseOneShotTrainer
 
@@ -43,7 +43,7 @@
 class RetiariiExeConfig(ConfigBase):
     experiment_name: Optional[str] = None
     search_space: Any = ''  # TODO: remove
-    trial_command: str = 'python3 -m nni.retiarii.trial_entry'
+    trial_command: str = '_reserved'
     trial_code_directory: PathLike = '.'
     trial_concurrency: int
     trial_gpu_number: int = 0
@@ -55,21 +55,26 @@ class RetiariiExeConfig(ConfigBase):
     experiment_working_directory: PathLike = '~/nni-experiments'
     # remove configuration of tuner/assessor/advisor
     training_service: TrainingServiceConfig
+    execution_engine: str = 'base'
 
     def __init__(self, training_service_platform: Optional[str] = None, **kwargs):
         super().__init__(**kwargs)
         if training_service_platform is not None:
             assert 'training_service' not in kwargs
             self.training_service = util.training_service_config_factory(platform = training_service_platform)
+        self.__dict__['trial_command'] = 'python3 -m nni.retiarii.trial_entry base'
 
     def __setattr__(self, key, value):
         fixed_attrs = {'search_space': '',
-                       'trial_command': 'python3 -m nni.retiarii.trial_entry'}
+                       'trial_command': '_reserved'}
         if key in fixed_attrs and fixed_attrs[key] != value:
             raise AttributeError(f'{key} is not supposed to be set in Retiarii mode by users!')
         # 'trial_code_directory' is handled differently because the path will be converted to absolute path by us
         if key == 'trial_code_directory' and not (value == Path('.') or os.path.isabs(value)):
             raise AttributeError(f'{key} is not supposed to be set in Retiarii mode by users!')
+        if key == 'execution_engine':
+            assert value in ['base', 'py', 'cgo'], f'The specified execution engine "{value}" is not supported.'
+            self.__dict__['trial_command'] = 'python3 -m nni.retiarii.trial_entry ' + value
         self.__dict__[key] = value
 
     def validate(self, initialized_tuner: bool = False) -> None:
@@ -100,23 +105,27 @@ def _validation_rules(self):
     'training_service': lambda value: (type(value) is not TrainingServiceConfig, 'cannot be abstract base class')
 }
 
-def preprocess_model(base_model, trainer, applied_mutators):
+def preprocess_model(base_model, trainer, applied_mutators, full_ir=True):
+    # TODO: this logic might need to be refactored into execution engine
+    if full_ir:
         try:
             script_module = torch.jit.script(base_model)
         except Exception as e:
             _logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:')
             raise e
         base_model_ir = convert_to_graph(script_module, base_model)
-        base_model_ir.evaluator = trainer
-
         # handle inline mutations
         mutators = process_inline_mutation(base_model_ir)
-        if mutators is not None and applied_mutators:
-            raise RuntimeError('Have not supported mixed usage of LayerChoice/InputChoice and mutators, '
-                               'do not use mutators when you use LayerChoice/InputChoice')
-        if mutators is not None:
-            applied_mutators = mutators
-        return base_model_ir, applied_mutators
+    else:
+        base_model_ir, mutators = extract_mutation_from_pt_module(base_model)
+    base_model_ir.evaluator = trainer
+
+    if mutators is not None and applied_mutators:
+        raise RuntimeError('Mixed usage of LayerChoice/InputChoice and mutators is not supported yet; '
+                           'do not use mutators when you use LayerChoice/InputChoice')
+    if mutators is not None:
+        applied_mutators = mutators
+    return base_model_ir, applied_mutators
 
 def debug_mutated_model(base_model, trainer, applied_mutators):
     """
@@ -160,7 +169,8 @@ def __init__(self, base_model: nn.Module, trainer: Union[Evaluator, BaseOneShotT
         self._pipe: Optional[Pipe] = None
 
     def _start_strategy(self):
-        base_model_ir, self.applied_mutators = preprocess_model(self.base_model, self.trainer, self.applied_mutators)
+        base_model_ir, self.applied_mutators = preprocess_model(
+            self.base_model, self.trainer, self.applied_mutators, full_ir=self.config.execution_engine != 'py')
 
         _logger.info('Start strategy...')
         self.strategy.run(base_model_ir, self.applied_mutators)
@@ -182,6 +192,18 @@ def start(self, port: int = 8080, debug: bool = False) -> None:
         """
         atexit.register(self.stop)
 
+        # we will probably need an execution engine factory to make this clean and elegant
+        if self.config.execution_engine == 'base':
+            from ..execution.base import BaseExecutionEngine
+            engine = BaseExecutionEngine()
+        elif self.config.execution_engine == 'cgo':
+            from ..execution.cgo_engine import CGOExecutionEngine
+            engine = CGOExecutionEngine()
+        elif self.config.execution_engine == 'py':
+            from ..execution.python import PurePythonExecutionEngine
+            engine = PurePythonExecutionEngine()
+        set_execution_engine(engine)
+
         self.id = management.generate_experiment_id()
 
         if self.config.experiment_working_directory is not None:
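
With these changes, users pick an engine through the experiment config rather than an environment variable; setting `execution_engine` rewrites the reserved `trial_command` accordingly. A brief sketch, assuming a local training service:

```python
from nni.retiarii.experiment.pytorch import RetiariiExeConfig

config = RetiariiExeConfig('local')
config.execution_engine = 'py'  # one of 'base' (default), 'py', 'cgo'
config.trial_concurrency = 1
# trial_command is now 'python3 -m nni.retiarii.trial_entry py'
```
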
diff --git a/nni/retiarii/graph.py b/nni/retiarii/graph.py
index 48f2971a74..2255e288f1 100644
--- a/nni/retiarii/graph.py
+++ b/nni/retiarii/graph.py
@@ -9,12 +9,12 @@
 import copy
 import json
 from enum import Enum
-from typing import (Any, Dict, Iterable, List, Optional, Tuple, Union, overload)
+from typing import (Any, Dict, Iterable, List, Optional, Tuple, Type, Union, overload)
 
 from .operation import Cell, Operation, _IOPseudoOperation
 from .utils import get_importable_name, import_, uid
 
-__all__ = ['Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'IllegalGraphError', 'MetricData']
+__all__ = ['Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'Mutation', 'IllegalGraphError', 'MetricData']
 
 
 MetricData = Any
@@ -80,6 +80,10 @@ class Model:
 
     Attributes
     ----------
+    python_class
+        Python class that the base model is converted from.
+    python_init_params
+        Initialization parameters of the python class.
     status
         See `ModelStatus`.
     root_graph
@@ -102,6 +106,8 @@ class Model:
     def __init__(self, _internal=False):
         assert _internal, '`Model()` is private, use `model.fork()` instead'
         self.model_id: int = uid('model')
+        self.python_class: Optional[Type] = None
+        self.python_init_params: Optional[Dict[str, Any]] = None
 
         self.status: ModelStatus = ModelStatus.Mutating
 
@@ -116,7 +122,8 @@ def __init__(self, _internal=False):
 
     def __repr__(self):
         return f'Model(model_id={self.model_id}, status={self.status}, graphs={list(self.graphs.keys())}, ' + \
-            f'evaluator={self.evaluator}, metric={self.metric}, intermediate_metrics={self.intermediate_metrics})'
+            f'evaluator={self.evaluator}, metric={self.metric}, intermediate_metrics={self.intermediate_metrics}, ' + \
+            f'python_class={self.python_class})'
 
     @property
     def root_graph(self) -> 'Graph':
@@ -133,9 +140,12 @@ def fork(self) -> 'Model':
         """
         new_model = Model(_internal=True)
         new_model._root_graph_name = self._root_graph_name
+        new_model.python_class = self.python_class
+        new_model.python_init_params = self.python_init_params
         new_model.graphs = {name: graph._fork_to(new_model) for name, graph in self.graphs.items()}
         new_model.evaluator = copy.deepcopy(self.evaluator)  # TODO this may be a problem when evaluator is large
-        new_model.history = self.history + [self]
+        new_model.history = [*self.history]
+        # Note: the history is not updated here; it is appended to when the model is mutated, i.e. in Mutator.apply().
         return new_model
 
     @staticmethod
@@ -167,8 +177,8 @@ def get_nodes(self) -> Iterable['Node']:
 
     def get_nodes_by_label(self, label: str) -> List['Node']:
         """
-        Traverse all the nodes to find the matched node(s) with the given name.
-        There could be multiple nodes with the same name. Name space name can uniquely
+        Traverse all the nodes to find the matched node(s) with the given label.
+        There could be multiple nodes with the same label. Name space name can uniquely
         identify a graph or node.
 
         NOTE: the implementation does not support the class abstration
@@ -493,6 +503,8 @@ class Node:
         If two models have nodes with same ID, they are semantically the same node.
     name
         Mnemonic name. It should have an one-to-one mapping with ID.
+    label
+        Optional. If two nodes have the same label, they are considered same by the mutator.
     operation
         ...
     cell
@@ -515,7 +527,7 @@ def __init__(self, graph, node_id, name, operation, _internal=False):
         # TODO: the operation is likely to be considered editable by end-user and it will be hard to debug
         # maybe we should copy it here or make Operation class immutable, in next release
         self.operation: Operation = operation
-        self.label: str = None
+        self.label: Optional[str] = None
 
     def __repr__(self):
         return f'Node(id={self.id}, name={self.name}, label={self.label}, operation={self.operation})'
@@ -673,6 +685,37 @@ def _dump(self) -> Any:
         }
 
 
+class Mutation:
+    """
+    An execution of mutation, which consists of four parts: a mutator, a list of decisions (choices),
+    the model that it comes from, and the model that it becomes.
+
+    In general cases, the mutation logs are not reliable and should not be replayed as the mutators can
+    be arbitrarily complex. However, for inline mutations, the labels correspond to mutator labels here,
+    which can be useful for metadata visualization and the python execution mode.
+
+    Attributes
+    ----------
+    mutator
+        Mutator.
+    samples
+        Decisions/choices.
+    from_
+        Model that it comes from.
+    to
+        Model that it becomes.
+    """
+
+    def __init__(self, mutator: 'Mutator', samples: List[Any], from_: Model, to: Model):  # noqa: F821
+        self.mutator: 'Mutator' = mutator  # noqa: F821
+        self.samples: List[Any] = samples
+        self.from_: Model = from_
+        self.to: Model = to
+
+    def __repr__(self):
+        return f'Mutation(mutator={self.mutator}, samples={self.samples}, from={self.from_}, to={self.to})'
+
+
 class IllegalGraphError(ValueError):
     def __init__(self, graph, *args):
         self._debug_dump_graph(graph)
diff --git a/nni/retiarii/integration.py b/nni/retiarii/integration.py
index 1027062d8e..189db5ff5c 100644
--- a/nni/retiarii/integration.py
+++ b/nni/retiarii/integration.py
@@ -2,7 +2,6 @@
 # Licensed under the MIT license.
 
 import logging
-import os
 from typing import Any, Callable
 
 from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
@@ -10,9 +9,6 @@
 from nni.utils import MetricType
 
 from .graph import MetricData
-from .execution.base import BaseExecutionEngine
-from .execution.cgo_engine import CGOExecutionEngine
-from .execution.api import set_execution_engine
 from .integration_api import register_advisor
 from .serializer import json_dumps, json_loads
 
@@ -62,15 +58,6 @@ def __init__(self):
 
         self.parameters_count = 0
 
-        engine = self._create_execution_engine()
-        set_execution_engine(engine)
-
-    def _create_execution_engine(self):
-        if os.environ.get('CGO') == 'true':
-            return CGOExecutionEngine()
-        else:
-            return BaseExecutionEngine()
-
     def handle_initialize(self, data):
         """callback for initializing the advisor
         Parameters
diff --git a/nni/retiarii/mutator.py b/nni/retiarii/mutator.py
index fac3350f7c..e7d5708169 100644
--- a/nni/retiarii/mutator.py
+++ b/nni/retiarii/mutator.py
@@ -3,7 +3,7 @@
 
 from typing import (Any, Iterable, List, Optional)
 
-from .graph import Model
+from .graph import Model, Mutation, ModelStatus
 
 
 __all__ = ['Sampler', 'Mutator']
@@ -40,10 +40,13 @@ class Mutator:
     and then use `Mutator.apply()` to mutate model.
     For certain mutator subclasses, strategy or sampler can use `Mutator.dry_run()` to predict choice candidates.
     # Method names are open for discussion.
+
+    If a mutator has a label, in most cases it means that the mutator is applied to nodes with this label.
     """
 
-    def __init__(self, sampler: Optional[Sampler] = None):
+    def __init__(self, sampler: Optional[Sampler] = None, label: Optional[str] = None):
         self.sampler: Optional[Sampler] = sampler
+        self.label: Optional[str] = label
         self._cur_model: Optional[Model] = None
         self._cur_choice_idx: Optional[int] = None
 
@@ -64,9 +67,12 @@ def apply(self, model: Model) -> Model:
         copy = model.fork()
         self._cur_model = copy
         self._cur_choice_idx = 0
+        self._cur_samples = []
         self.sampler.mutation_start(self, copy)
         self.mutate(copy)
         self.sampler.mutation_end(self, copy)
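+        # record this mutation in the child's history and freeze it, so that the
+        # sampled choices can later be recovered from the model itself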
+        copy.history.append(Mutation(self, self._cur_samples, model, copy))
+        copy.status = ModelStatus.Frozen
         self._cur_model = None
         self._cur_choice_idx = None
         return copy
@@ -97,6 +103,7 @@ def choice(self, candidates: Iterable[Choice]) -> Choice:
         """
         assert self.sampler is not None and self._cur_model is not None and self._cur_choice_idx is not None
         ret = self.sampler.choice(list(candidates), self, self._cur_model, self._cur_choice_idx)
+        self._cur_samples.append(ret)
         self._cur_choice_idx += 1
         return ret
 
diff --git a/nni/retiarii/nn/pytorch/__init__.py b/nni/retiarii/nn/pytorch/__init__.py
index dffb882777..5c392164b1 100644
--- a/nni/retiarii/nn/pytorch/__init__.py
+++ b/nni/retiarii/nn/pytorch/__init__.py
@@ -1,2 +1,3 @@
 from .api import *
+from .component import *
 from .nn import *
diff --git a/nni/retiarii/nn/pytorch/api.py b/nni/retiarii/nn/pytorch/api.py
index b394d3ca55..69d12fb908 100644
--- a/nni/retiarii/nn/pytorch/api.py
+++ b/nni/retiarii/nn/pytorch/api.py
@@ -4,13 +4,13 @@
 import copy
 import warnings
 from collections import OrderedDict
-from typing import Any, List, Union, Dict
+from typing import Any, List, Union, Dict, Optional
 
 import torch
 import torch.nn as nn
 
 from ...serializer import Translatable, basic_unit
-from ...utils import uid
+from .utils import generate_new_label, get_fixed_value
 
 
 __all__ = ['LayerChoice', 'InputChoice', 'ValueChoice', 'Placeholder', 'ChosenInputs']
@@ -55,7 +55,17 @@ class LayerChoice(nn.Module):
     ``self.op_choice[1] = nn.Conv3d(...)``. Adding more choices is not supported yet.
     """
 
-    def __init__(self, candidates: Union[Dict[str, nn.Module], List[nn.Module]], label: str = None, **kwargs):
+    def __new__(cls, candidates: Union[Dict[str, nn.Module], List[nn.Module]], label: Optional[str] = None, **kwargs):
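+        # When a 'fixed' context has been pushed onto the ContextStack (see
+        # nni/retiarii/utils.py), skip creating a mutable LayerChoice and return
+        # the chosen candidate module directly.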
+        try:
+            chosen = get_fixed_value(label)
+            if isinstance(candidates, list):
+                return candidates[int(chosen)]
+            else:
+                return candidates[chosen]
+        except AssertionError:
+            return super().__new__(cls)
+
+    def __init__(self, candidates: Union[Dict[str, nn.Module], List[nn.Module]], label: Optional[str] = None, **kwargs):
         super(LayerChoice, self).__init__()
         if 'key' in kwargs:
             warnings.warn(f'"key" is deprecated. Assuming label.')
@@ -65,7 +75,7 @@ def __init__(self, candidates: Union[Dict[str, nn.Module], List[nn.Module]], lab
         if 'reduction' in kwargs:
             warnings.warn(f'"reduction" is deprecated. Ignoring...')
         self.candidates = candidates
-        self._label = label if label is not None else f'layerchoice_{uid()}'
+        self._label = generate_new_label(label)
 
         self.names = []
         if isinstance(candidates, OrderedDict):
@@ -163,7 +173,13 @@ class InputChoice(nn.Module):
         Identifier of the input choice.
     """
 
-    def __init__(self, n_candidates: int, n_chosen: int = 1, reduction: str = 'sum', label: str = None, **kwargs):
+    def __new__(cls, n_candidates: int, n_chosen: int = 1, reduction: str = 'sum', label: Optional[str] = None, **kwargs):
+        try:
+            return ChosenInputs(get_fixed_value(label), reduction=reduction)
+        except AssertionError:
+            return super().__new__(cls)
+
+    def __init__(self, n_candidates: int, n_chosen: int = 1, reduction: str = 'sum', label: Optional[str] = None, **kwargs):
         super(InputChoice, self).__init__()
         if 'key' in kwargs:
             warnings.warn(f'"key" is deprecated. Assuming label.')
@@ -176,7 +192,7 @@ def __init__(self, n_candidates: int, n_chosen: int = 1, reduction: str = 'sum',
         self.n_chosen = n_chosen
         self.reduction = reduction
         assert self.reduction in ['mean', 'concat', 'sum', 'none']
-        self._label = label if label is not None else f'inputchoice_{uid()}'
+        self._label = generate_new_label(label)
 
     @property
     def key(self):
@@ -265,10 +281,16 @@ def forward(self, x):
         Identifier of the value choice.
     """
 
-    def __init__(self, candidates: List[Any], label: str = None):
+    def __new__(cls, candidates: List[Any], label: Optional[str] = None):
+        try:
+            return get_fixed_value(label)
+        except AssertionError:
+            return super().__new__(cls)
+
+    def __init__(self, candidates: List[Any], label: Optional[str] = None):
         super().__init__()
         self.candidates = candidates
-        self._label = label if label is not None else f'valuechoice_{uid()}'
+        self._label = generate_new_label(label)
         self._accessor = []
 
     @property
@@ -297,6 +319,14 @@ def access(self, value):
             raise KeyError(''.join([f'[{a}]' for a in self._accessor]) + f' does not work on {value}')
         return v
 
+    def __copy__(self):
+        return self
+
+    def __deepcopy__(self, memo):
+        new_item = ValueChoice(self.candidates, self.label)
+        new_item._accessor = [*self._accessor]
+        return new_item
+
     def __getitem__(self, item):
         """
         Get a sub-element of value choice.
@@ -331,9 +361,9 @@ class ChosenInputs(nn.Module):
     The already-chosen version of InputChoice.
     """
 
-    def __init__(self, chosen: List[int], reduction: str):
+    def __init__(self, chosen: Union[List[int], int], reduction: str):
         super().__init__()
-        self.chosen = chosen
+        self.chosen = chosen if isinstance(chosen, list) else [chosen]
         self.reduction = reduction
 
     def forward(self, candidate_inputs):
diff --git a/nni/retiarii/nn/pytorch/component.py b/nni/retiarii/nn/pytorch/component.py
new file mode 100644
index 0000000000..4ae5dc03bb
--- /dev/null
+++ b/nni/retiarii/nn/pytorch/component.py
@@ -0,0 +1,147 @@
+import copy
+from typing import Callable, List, Union, Tuple, Optional
+
+import torch
+import torch.nn as nn
+
+from .api import LayerChoice, InputChoice
+from .nn import ModuleList
+
+from .utils import generate_new_label, get_fixed_value
+
+
+__all__ = ['Repeat', 'Cell']
+
+
+class Repeat(nn.Module):
+    """
+    Repeat a block by a variable number of times.
+
+    Parameters
+    ----------
+    blocks : function, list of function, module or list of module
+        The block to be repeated. If not a list, it will be replicated into a list.
+        If a list, it should be of length ``max_depth``; the modules will be instantiated in order and a prefix will be taken.
+        If a function, it will be called to instantiate a module. Otherwise the module will be deep-copied.
+    depth : int or tuple of int
+        If one number, the block will be repeated by a fixed number of times. If a tuple, it should be (min, max),
+        meaning that the block will be repeated at least `min` times and at most `max` times.
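+
+    Example (a sketch; ``MyBlock`` is a hypothetical module)::
+
+        # repeat MyBlock 3, 4, or 5 times; the exact depth is sampled by the strategy
+        self.blocks = Repeat(MyBlock(), (3, 5))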
+    """
+
+    def __new__(cls, blocks: Union[Callable[[], nn.Module], List[Callable[[], nn.Module]], nn.Module, List[nn.Module]],
+                depth: Union[int, Tuple[int, int]], label: Optional[str] = None):
+        try:
+            repeat = get_fixed_value(label)
+            return nn.Sequential(*cls._replicate_and_instantiate(blocks, repeat))
+        except AssertionError:
+            return super().__new__(cls)
+
+    def __init__(self,
+                 blocks: Union[Callable[[], nn.Module], List[Callable[[], nn.Module]], nn.Module, List[nn.Module]],
+                 depth: Union[int, Tuple[int, int]], label: Optional[str] = None):
+        super().__init__()
+        self._label = generate_new_label(label)
+        self.min_depth = depth if isinstance(depth, int) else depth[0]
+        self.max_depth = depth if isinstance(depth, int) else depth[1]
+        assert self.max_depth >= self.min_depth > 0
+        self.blocks = nn.ModuleList(self._replicate_and_instantiate(blocks, self.max_depth))
+
+    @property
+    def label(self):
+        return self._label
+
+    def forward(self, x):
+        for block in self.blocks:
+            x = block(x)
+        return x
+
+    @staticmethod
+    def _replicate_and_instantiate(blocks, repeat):
+        if not isinstance(blocks, list):
+            if isinstance(blocks, nn.Module):
+                blocks = [blocks] + [copy.deepcopy(blocks) for _ in range(repeat - 1)]
+            else:
+                blocks = [blocks for _ in range(repeat)]
+        assert len(blocks) > 0
+        assert repeat <= len(blocks), f'Not enough blocks to be used. {repeat} expected, only found {len(blocks)}.'
+        blocks = blocks[:repeat]
+        if not isinstance(blocks[0], nn.Module):
+            blocks = [b() for b in blocks]
+        return blocks
+
+
+class Cell(nn.Module):
+    """
+    Cell structure [1]_ [2]_ that is popularly used in NAS literature.
+
+    A cell consists of multiple "nodes". Each node is a sum of multiple operators. Each operator is chosen from
+    ``op_candidates`` and takes one input from previous nodes or predecessors (i.e., the inputs of the cell).
+    The output of the cell is the concatenation of some of the nodes in the cell (currently all of them).
+
+    Parameters
+    ----------
+    op_candidates : function or list of module
+        A list of modules to choose from, or a function that returns a list of modules.
+    num_nodes : int
+        Number of nodes in the cell.
+    num_ops_per_node: int
+        Number of operators in each node. The output of each node is the sum of all operators in the node. Default: 1.
+    num_predecessors : int
+        Number of inputs of the cell. The input to forward should be a list of tensors. Default: 1.
+    merge_op : str
+        Currently only ``all`` is supported, which differs slightly from what is described in the references. Default: all.
+    label : str
+        Identifier of the cell. Cells sharing the same label will semantically share the same choices.
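+
+    Example (a sketch; the candidate list and sizes are illustrative)::
+
+        cell = Cell([nn.Identity(), nn.Conv2d(16, 16, 3, padding=1)], num_nodes=4)
+        out = cell([x])  # predecessors are passed as a list of tensors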
+
+    References
+    ----------
+    .. [1] Barret Zoph, Quoc V. Le, "Neural Architecture Search with Reinforcement Learning". https://arxiv.org/abs/1611.01578
+    .. [2] Barret Zoph, Vijay Vasudevan, Jonathon Shlens, Quoc V. Le,
+        "Learning Transferable Architectures for Scalable Image Recognition". https://arxiv.org/abs/1707.07012
+    """
+
+    # TODO:
+    # Support loose end concat (shape inference on the following cells)
+    # How to dynamically create convolution with stride as the first node
+
+    def __init__(self,
+                 op_candidates: Union[Callable, List[nn.Module]],
+                 num_nodes: int,
+                 num_ops_per_node: int = 1,
+                 num_predecessors: int = 1,
+                 merge_op: str = 'all',
+                 label: Optional[str] = None):
+        super().__init__()
+        self._label = generate_new_label(label)
+        self.ops = ModuleList()
+        self.inputs = ModuleList()
+        self.num_nodes = num_nodes
+        self.num_ops_per_node = num_ops_per_node
+        self.num_predecessors = num_predecessors
+        for i in range(num_nodes):
+            self.ops.append(ModuleList())
+            self.inputs.append(ModuleList())
+            for k in range(num_ops_per_node):
+                if isinstance(op_candidates, list):
+                    assert len(op_candidates) > 0 and isinstance(op_candidates[0], nn.Module)
+                    ops = copy.deepcopy(op_candidates)
+                else:
+                    ops = op_candidates()
+                self.ops[-1].append(LayerChoice(ops, label=f'{self.label}__op_{i}_{k}'))
+                self.inputs[-1].append(InputChoice(i + num_predecessors, 1, label=f'{self.label}/input_{i}_{k}'))
+        assert merge_op in ['all']  # TODO: loose_end
+        self.merge_op = merge_op
+
+    @property
+    def label(self):
+        return self._label
+
+    def forward(self, x: List[torch.Tensor]):
+        states = x
+        for ops, inps in zip(self.ops, self.inputs):
+            current_state = []
+            for op, inp in zip(ops, inps):
+                current_state.append(op(inp(states)))
+            current_state = torch.sum(torch.stack(current_state), 0)
+            states.append(current_state)
+        return torch.cat(states[self.num_predecessors:], 1)
diff --git a/nni/retiarii/nn/pytorch/mutator.py b/nni/retiarii/nn/pytorch/mutator.py
index 3f4b256da8..6ef2ef19af 100644
--- a/nni/retiarii/nn/pytorch/mutator.py
+++ b/nni/retiarii/nn/pytorch/mutator.py
@@ -1,11 +1,16 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
+import inspect
 from typing import Any, List, Optional, Tuple
 
+import torch.nn as nn
+
 from ...mutator import Mutator
-from ...graph import Cell, Model, Node
-from .api import ValueChoice
+from ...graph import Cell, Graph, Model, ModelStatus, Node
+from .api import LayerChoice, InputChoice, ValueChoice, Placeholder
+from .component import Repeat
+from ...utils import uid
 
 
 class LayerChoiceMutator(Mutator):
@@ -40,7 +45,7 @@ def __init__(self, nodes: List[Node]):
 
     def mutate(self, model):
         n_candidates = self.nodes[0].operation.parameters['n_candidates']
-        n_chosen =  self.nodes[0].operation.parameters['n_chosen']
+        n_chosen = self.nodes[0].operation.parameters['n_chosen']
         candidates = list(range(n_candidates))
         chosen = [self.choice(candidates) for _ in range(n_chosen)]
         for node in self.nodes:
@@ -76,6 +81,42 @@ def mutate(self, model):
             target.update_operation(target.operation.type, {**target.operation.parameters, argname: chosen_value})
 
 
+class RepeatMutator(Mutator):
+    def __init__(self, nodes: List[Node]):
+        # nodes is a subgraph consisting of repeated blocks.
+        super().__init__()
+        self.nodes = nodes
+
+    def _retrieve_chain_from_graph(self, graph: Graph) -> List[Node]:
+        u = graph.input_node
+        chain = []
+        while u != graph.output_node:
+            if u != graph.input_node:
+                chain.append(u)
+            assert len(u.successors) == 1, f'This graph is an illegal chain. {u} has successors {u.successors}.'
+            u = u.successors[0]
+        return chain
+
+    def mutate(self, model):
+        min_depth = self.nodes[0].operation.parameters['min_depth']
+        max_depth = self.nodes[0].operation.parameters['max_depth']
+        if min_depth < max_depth:
+            chosen_depth = self.choice(list(range(min_depth, max_depth + 1)))
+        else:
+            # depth is fixed; nothing to sample
+            chosen_depth = max_depth
+        for node in self.nodes:
+            # the logic here is similar to layer choice. We find the cell attached to each node.
+            target: Graph = model.graphs[node.operation.cell_name]
+            chain = self._retrieve_chain_from_graph(target)
+            for edge in chain[chosen_depth - 1].outgoing_edges:
+                edge.remove()
+            target.add_edge((chain[chosen_depth - 1], None), (target.output_node, None))
+            for rm_node in chain[chosen_depth:]:
+                for edge in rm_node.outgoing_edges:
+                    edge.remove()
+                rm_node.remove()
+            # to delete the unused parameters.
+            model.get_node_by_name(node.name).update_operation(Cell(node.operation.cell_name))
+
+
 def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
     applied_mutators = []
 
@@ -116,12 +157,110 @@ def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
         mutator = LayerChoiceMutator(node_list)
         applied_mutators.append(mutator)
 
+    repeat_nodes = _group_by_label(filter(lambda d: d.operation.parameters.get('mutation') == 'repeat',
+                                          model.get_nodes_by_type('_cell')))
+    for node_list in repeat_nodes:
+        assert _is_all_equal(map(lambda node: node.operation.parameters['max_depth'], node_list)) and \
+            _is_all_equal(map(lambda node: node.operation.parameters['min_depth'], node_list)), \
+            'Repeat with the same label must have the same depth range.'
+        mutator = RepeatMutator(node_list)
+        applied_mutators.append(mutator)
 
     if applied_mutators:
         return applied_mutators
     return None
 
 
+# The following are written for pure-python mode
+
+
+class ManyChooseManyMutator(Mutator):
+    """
+    Choose based on labels. Will not affect the model itself.
+    """
+
+    def __init__(self, label: Optional[str]):
+        super().__init__(label=label)
+
+    @staticmethod
+    def candidates(node):
+        if 'n_candidates' in node.operation.parameters:
+            return list(range(node.operation.parameters['n_candidates']))
+        else:
+            return node.operation.parameters['candidates']
+
+    @staticmethod
+    def number_of_chosen(node):
+        if 'n_chosen' in node.operation.parameters:
+            return node.operation.parameters['n_chosen']
+        return 1
+
+    def mutate(self, model: Model):
+        # this mutation does not modify the model itself, but it is recorded in the mutation history
+        for node in model.get_nodes_by_label(self.label):
+            for _ in range(self.number_of_chosen(node)):
+                self.choice(self.candidates(node))
+            break
+
+
+def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Optional[List[Mutator]]]:
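+    # Build a graph-IR shell around the python model: every mutable module found
+    # via named_modules() becomes a placeholder node carrying its label and
+    # candidates, and one ManyChooseManyMutator is created per (label, type) group.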
+    model = Model(_internal=True)
+    graph = Graph(model, uid(), '_model', _internal=True)._register()
+    model.python_class = pytorch_model.__class__
+    if len(inspect.signature(model.python_class.__init__).parameters) > 1:
+        if not hasattr(pytorch_model, '_init_parameters'):
+            raise ValueError('Please annotate the model with the @serialize decorator in python execution mode '
+                             'if your model has init parameters.')
+        model.python_init_params = pytorch_model._init_parameters
+    else:
+        model.python_init_params = {}
+
+    for name, module in pytorch_model.named_modules():
+        # tricky case: value choices that serve as parameters are stored in _init_parameters
+        if hasattr(module, '_init_parameters'):
+            for key, value in module._init_parameters.items():
+                if isinstance(value, ValueChoice):
+                    node = graph.add_node(name + '.init.' + key, 'ValueChoice', {'candidates': value.candidates})
+                    node.label = value.label
+
+        if isinstance(module, (LayerChoice, InputChoice, ValueChoice)):
+            # TODO: check the label of module and warn if it's auto-generated
+            pass
+        if isinstance(module, LayerChoice):
+            node = graph.add_node(name, 'LayerChoice', {'candidates': module.names})
+            node.label = module.label
+        if isinstance(module, InputChoice):
+            node = graph.add_node(name, 'InputChoice',
+                                  {'n_candidates': module.n_candidates, 'n_chosen': module.n_chosen})
+            node.label = module.label
+        if isinstance(module, ValueChoice):
+            node = graph.add_node(name, 'ValueChoice', {'candidates': module.candidates})
+            node.label = module.label
+        if isinstance(module, Repeat) and module.min_depth <= module.max_depth:
+            node = graph.add_node(name, 'Repeat', {
+                'candidates': list(range(module.min_depth, module.max_depth + 1))
+            })
+            node.label = module.label
+        if isinstance(module, Placeholder):
+            raise NotImplementedError('Placeholder is not supported in python execution mode.')
+
+    model.status = ModelStatus.Frozen
+    if not graph.hidden_nodes:
+        return model, None
+
+    mutators = []
+    for nodes in _group_by_label_and_type(graph.hidden_nodes):
+        assert _is_all_equal(map(lambda n: n.operation.type, nodes)), \
+            f'Nodes with label "{nodes[0].label}" do not all have the same type.'
+        assert _is_all_equal(map(lambda n: n.operation.parameters, nodes)), \
+            f'Nodes with label "{nodes[0].label}" do not agree on parameters.'
+        mutators.append(ManyChooseManyMutator(nodes[0].label))
+    return model, mutators
+
+
+# utility functions
+
+
 def _is_all_equal(lst):
     last = None
     for x in lst:
@@ -131,6 +270,16 @@ def _is_all_equal(lst):
     return True
 
 
+def _group_by_label_and_type(nodes: List[Node]) -> List[List[Node]]:
+    result = {}
+    for node in nodes:
+        key = (node.label, node.operation.type)
+        if key not in result:
+            result[key] = []
+        result[key].append(node)
+    return list(result.values())
+
+
 def _group_by_label(nodes: List[Node]) -> List[List[Node]]:
     result = {}
     for node in nodes:
diff --git a/nni/retiarii/nn/pytorch/utils.py b/nni/retiarii/nn/pytorch/utils.py
new file mode 100644
index 0000000000..352348b997
--- /dev/null
+++ b/nni/retiarii/nn/pytorch/utils.py
@@ -0,0 +1,17 @@
+from typing import Optional
+
+from ...utils import uid, get_current_context
+
+
+def generate_new_label(label: Optional[str]):
+    if label is None:
+        return '_mutation_' + str(uid('mutation'))
+    return label
+
+
+def get_fixed_value(label: str):
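+    # Look up the label in the current 'fixed' context (pushed via ContextStack
+    # when instantiating a fixed architecture); ContextStack.top raises an
+    # AssertionError if no such context exists, which callers catch in __new__.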
+    ret = get_current_context('fixed')
+    try:
+        return ret[generate_new_label(label)]
+    except KeyError:
+        raise KeyError(f'Fixed context with {label} not found. Existing values are: {ret}')
diff --git a/nni/retiarii/serializer.py b/nni/retiarii/serializer.py
index 9aad75c9c6..e0c2a26115 100644
--- a/nni/retiarii/serializer.py
+++ b/nni/retiarii/serializer.py
@@ -9,7 +9,7 @@
 
 import json_tricks
 
-from .utils import get_importable_name, get_module_name, import_
+from .utils import get_importable_name, get_module_name, import_, reset_uid
 
 
 def get_init_parameters_or_fail(obj, silently=False):
@@ -83,9 +83,11 @@ def _translate(self) -> Any:
         pass
 
 
-def _create_wrapper_cls(cls, store_init_parameters=True):
+def _create_wrapper_cls(cls, store_init_parameters=True, reset_mutation_uid=False):
     class wrapper(cls):
         def __init__(self, *args, **kwargs):
+            if reset_mutation_uid:
+                reset_uid('mutation')
             if store_init_parameters:
                 argname_list = list(inspect.signature(cls.__init__).parameters.keys())[1:]
                 full_args = {}
@@ -149,3 +151,15 @@ def basic_unit(cls):
     import torch.nn as nn
     assert issubclass(cls, nn.Module), 'When using @basic_unit, the class must be a subclass of nn.Module.'
     return serialize_cls(cls)
+
+
+def model_wrapper(cls):
+    """
+    Wrap the model if you are using pure-python execution engine.
+
+    The wrapper serves two purposes:
+
+        1. Capture the init parameters of the python class so that it can be re-instantiated in another process.
+        2. Reset the uid in the `mutation` namespace so that each model counts from zero. Can be useful in unit tests and other multi-model scenarios.
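+
+    Usage sketch (``ModelSpace`` is an arbitrary user-defined model)::
+
+        @model_wrapper
+        class ModelSpace(nn.Module):
+            ...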
+    """
+    return _create_wrapper_cls(cls, reset_mutation_uid=True)
diff --git a/nni/retiarii/strategy/__init__.py b/nni/retiarii/strategy/__init__.py
index e3cd6c5591..04511eaa69 100644
--- a/nni/retiarii/strategy/__init__.py
+++ b/nni/retiarii/strategy/__init__.py
@@ -6,3 +6,4 @@
 from .evolution import RegularizedEvolution
 from .tpe_strategy import TPEStrategy
 from .local_debug_strategy import _LocalDebugStrategy
+from .rl import PolicyBasedRL
diff --git a/nni/retiarii/strategy/_rl_impl.py b/nni/retiarii/strategy/_rl_impl.py
new file mode 100644
index 0000000000..43b07a3136
--- /dev/null
+++ b/nni/retiarii/strategy/_rl_impl.py
@@ -0,0 +1,121 @@
+# This file might cause an import error for those who haven't installed RL-related dependencies.
+
+import logging
+
+import gym
+import numpy as np
+import torch
+import torch.nn as nn
+
+from gym import spaces
+from tianshou.data import to_torch
+
+from .utils import get_targeted_model
+from ..graph import ModelStatus
+from ..execution import submit_models, wait_models
+
+
+_logger = logging.getLogger(__name__)
+
+
+class ModelEvaluationEnv(gym.Env):
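+    # One episode corresponds to one architecture: each step fixes one key of the
+    # search space, and the final step submits the model and uses its metric as reward.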
+    def __init__(self, base_model, mutators, search_space):
+        self.base_model = base_model
+        self.mutators = mutators
+        self.search_space = search_space
+        self.ss_keys = list(self.search_space.keys())
+        self.action_dim = max(map(lambda v: len(v), self.search_space.values()))
+        self.num_steps = len(self.search_space)
+
+    @property
+    def observation_space(self):
+        return spaces.Dict({
+            'action_history': spaces.MultiDiscrete([self.action_dim] * self.num_steps),
+            'cur_step': spaces.Discrete(self.num_steps + 1),
+            'action_dim': spaces.Discrete(self.action_dim + 1)
+        })
+
+    @property
+    def action_space(self):
+        return spaces.Discrete(self.action_dim)
+
+    def reset(self):
+        self.action_history = np.zeros(self.num_steps, dtype=np.int32)
+        self.cur_step = 0
+        self.sample = {}
+        return {
+            'action_history': self.action_history,
+            'cur_step': self.cur_step,
+            'action_dim': len(self.search_space[self.ss_keys[self.cur_step]])
+        }
+
+    def step(self, action):
+        cur_key = self.ss_keys[self.cur_step]
+        assert action < len(self.search_space[cur_key]), \
+            f'Current action {action} out of range {self.search_space[cur_key]}.'
+        self.action_history[self.cur_step] = action
+        self.sample[cur_key] = self.search_space[cur_key][action]
+        self.cur_step += 1
+        obs = {
+            'action_history': self.action_history,
+            'cur_step': self.cur_step,
+            'action_dim': len(self.search_space[self.ss_keys[self.cur_step]]) \
+                if self.cur_step < self.num_steps else self.action_dim
+        }
+        if self.cur_step == self.num_steps:
+            model = get_targeted_model(self.base_model, self.mutators, self.sample)
+            _logger.info(f'New model created: {self.sample}')
+            submit_models(model)
+            wait_models(model)
+            if model.status == ModelStatus.Failed:
+                return self.reset(), 0., False, {}
+            rew = model.metric
+            _logger.info(f'Model metric received as reward: {rew}')
+            return obs, rew, True, {}
+        else:
+            return obs, 0., False, {}
+
+
+class Preprocessor(nn.Module):
+    def __init__(self, obs_space, hidden_dim=64, num_layers=1):
+        super().__init__()
+        self.action_dim = obs_space['action_history'].nvec[0]
+        self.hidden_dim = hidden_dim
+        # first token is [SOS]
+        self.embedding = nn.Embedding(self.action_dim + 1, hidden_dim)
+        self.rnn = nn.LSTM(hidden_dim, hidden_dim, num_layers, batch_first=True)
+
+    def forward(self, obs):
+        seq = nn.functional.pad(obs['action_history'] + 1, (1, 1))  # pad the start token and end token
+        # end token is used to avoid out-of-range of v_s_. Will not actually affect BP.
+        seq = self.embedding(seq.long())
+        feature, _ = self.rnn(seq)
+        return feature[torch.arange(len(feature), device=feature.device), obs['cur_step'].long() + 1]
+
+
+class Actor(nn.Module):
+    def __init__(self, action_space, preprocess):
+        super().__init__()
+        self.preprocess = preprocess
+        self.action_dim = action_space.n
+        self.linear = nn.Linear(self.preprocess.hidden_dim, self.action_dim)
+
+    def forward(self, obs, **kwargs):
+        obs = to_torch(obs, device=self.linear.weight.device)
+        out = self.linear(self.preprocess(obs))
+        # to take care of choices with different number of options
+        mask = torch.arange(self.action_dim).expand(len(out), self.action_dim) >= obs['action_dim'].unsqueeze(1)
+        out[mask.to(out.device)] = float('-inf')
+        return nn.functional.softmax(out, dim=-1), kwargs.get('state', None)
+
+
+class Critic(nn.Module):
+    def __init__(self, preprocess):
+        super().__init__()
+        self.preprocess = preprocess
+        self.linear = nn.Linear(self.preprocess.hidden_dim, 1)
+
+    def forward(self, obs, **kwargs):
+        obs = to_torch(obs, device=self.linear.weight.device)
+        return self.linear(self.preprocess(obs)).squeeze(-1)
diff --git a/nni/retiarii/strategy/rl.py b/nni/retiarii/strategy/rl.py
new file mode 100644
index 0000000000..50052471eb
--- /dev/null
+++ b/nni/retiarii/strategy/rl.py
@@ -0,0 +1,92 @@
+import logging
+from typing import Optional, Callable
+
+from .base import BaseStrategy
+from .utils import dry_run_for_search_space
+from ..execution import query_available_resources
+
+try:
+    has_tianshou = True
+    import torch
+    from tianshou.data import AsyncCollector, Collector, VectorReplayBuffer
+    from tianshou.env import SubprocVectorEnv
+    from tianshou.policy import BasePolicy, PPOPolicy  # pylint: disable=unused-import
+    from ._rl_impl import ModelEvaluationEnv, Preprocessor, Actor, Critic
+except ImportError:
+    has_tianshou = False
+
+
+_logger = logging.getLogger(__name__)
+
+
+class PolicyBasedRL(BaseStrategy):
+    """
+    Algorithm for policy-based reinforcement learning.
+    This is a wrapper of algorithms provided in tianshou (PPO by default),
+    and can be easily customized with other algorithms that inherit ``BasePolicy`` (e.g., REINFORCE [1]_).
+
+    Note that RL algorithms are known to have issues on Windows and macOS. Support will be added in the future.
+
+    Parameters
+    ----------
+    max_collect : int
+        How many times the collector runs to collect trials for RL. Default: 100.
+    trial_per_collect : int
+        How many trials (trajectories) the collector collects each time.
+        After each collect, the trainer samples a batch from the replay buffer and performs an update. Default: 20.
+    policy_fn : function
+        Takes ``ModelEvaluationEnv`` as input and returns a policy. See ``_default_policy_fn`` for an example.
+    asynchronous : bool
+        If true, in each step, the collector won't wait for all the envs to complete.
+        This should generally not affect the result, but might affect the efficiency. Note that slightly more trials
+        than expected might be collected if this is enabled.
+        If asynchronous is false, the collector will wait for all parallel environments to complete in each step.
+        See ``tianshou.data.AsyncCollector`` for more details.
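+
+    A usage sketch (as a strategy passed to a Retiarii experiment; the experiment
+    setup itself is omitted here)::
+
+        search_strategy = PolicyBasedRL(max_collect=50, trial_per_collect=10)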
+
+    References
+    ----------
+
+    .. [1] Barret Zoph and Quoc V. Le, "Neural Architecture Search with Reinforcement Learning".
+        https://arxiv.org/abs/1611.01578
+    """
+
+    def __init__(self, max_collect: int = 100, trial_per_collect: int = 20,
+                 policy_fn: Optional[Callable[['ModelEvaluationEnv'], 'BasePolicy']] = None, asynchronous: bool = True):
+        if not has_tianshou:
+            raise ImportError('`tianshou` is required to run RL-based strategy. '
+                              'Please use "pip install tianshou" to install it beforehand.')
+
+        self.policy_fn = policy_fn or self._default_policy_fn
+        self.max_collect = max_collect
+        self.trial_per_collect = trial_per_collect
+        self.asynchronous = asynchronous
+
+    @staticmethod
+    def _default_policy_fn(env):
+        net = Preprocessor(env.observation_space)
+        actor = Actor(env.action_space, net)
+        critic = Critic(net)
+        optim = torch.optim.Adam(set(actor.parameters()).union(critic.parameters()), lr=1e-4)
+        return PPOPolicy(actor, critic, optim, torch.distributions.Categorical,
+                         discount_factor=1., action_space=env.action_space)
+
+    def run(self, base_model, applied_mutators):
+        search_space = dry_run_for_search_space(base_model, applied_mutators)
+        concurrency = query_available_resources()
+
+        env_fn = lambda: ModelEvaluationEnv(base_model, applied_mutators, search_space)
+        policy = self.policy_fn(env_fn())
+
+        if self.asynchronous:
+            # wait for half of the envs to complete in each step
+            env = SubprocVectorEnv([env_fn for _ in range(concurrency)], wait_num=int(concurrency * 0.5))
+            collector = AsyncCollector(policy, env, VectorReplayBuffer(20000, len(env)))
+        else:
+            env = SubprocVectorEnv([env_fn for _ in range(concurrency)])
+            collector = Collector(policy, env, VectorReplayBuffer(20000, len(env)))
+
+        for cur_collect in range(1, self.max_collect + 1):
+            _logger.info('Collect [%d] Running...', cur_collect)
+            result = collector.collect(n_episode=self.trial_per_collect)
+            _logger.info('Collect [%d] Result: %s', cur_collect, str(result))
+            policy.update(0, collector.buffer, batch_size=64, repeat=5)
diff --git a/nni/retiarii/trial_entry.py b/nni/retiarii/trial_entry.py
index 16855ca820..7d805dd47f 100644
--- a/nni/retiarii/trial_entry.py
+++ b/nni/retiarii/trial_entry.py
@@ -6,13 +6,20 @@
 
 Assuming execution engine is BaseExecutionEngine.
 """
-import os
+import argparse
 
-from .execution.base import BaseExecutionEngine
-from .execution.cgo_engine import CGOExecutionEngine
 
 if __name__ == '__main__':
-    if os.environ.get('CGO') == 'true':
-        CGOExecutionEngine.trial_execute_graph()
-    else:
-        BaseExecutionEngine.trial_execute_graph()
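+    # This module is used as the trial entrypoint, e.g. invoked as
+    # ``python3 -m nni.retiarii.trial_entry py`` for the pure-python engine.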
+    parser = argparse.ArgumentParser()
+    parser.add_argument('exec', choices=['base', 'py', 'cgo'])
+    args = parser.parse_args()
+    if args.exec == 'base':
+        from .execution.base import BaseExecutionEngine
+        engine = BaseExecutionEngine
+    elif args.exec == 'cgo':
+        from .execution.cgo_engine import CGOExecutionEngine
+        engine = CGOExecutionEngine
+    elif args.exec == 'py':
+        from .execution.python import PurePythonExecutionEngine
+        engine = PurePythonExecutionEngine
+    engine.trial_execute_graph()
diff --git a/nni/retiarii/utils.py b/nni/retiarii/utils.py
index d17ccabcfc..c8b02dfba4 100644
--- a/nni/retiarii/utils.py
+++ b/nni/retiarii/utils.py
@@ -4,7 +4,7 @@
 import inspect
 import warnings
 from collections import defaultdict
-from typing import Any
+from typing import Any, List, Dict
 from pathlib import Path
 
 
@@ -31,6 +31,10 @@ def uid(namespace: str = 'default') -> int:
     return _last_uid[namespace]
 
 
+def reset_uid(namespace: str = 'default') -> None:
+    _last_uid[namespace] = 0
+
+
 def get_module_name(cls_or_func):
     module_name = cls_or_func.__module__
     if module_name == '__main__':
@@ -61,3 +65,42 @@ def get_module_name(cls_or_func):
 def get_importable_name(cls, relocate_module=False):
     module_name = get_module_name(cls) if relocate_module else cls.__module__
     return module_name + '.' + cls.__name__
+
+
+class ContextStack:
+    """
+    This maintains a globally-accessible context environment that is visible everywhere.
+
+    Use ``with ContextStack(namespace, value):`` to initiate, and use ``get_current_context(namespace)`` to
+    get the corresponding value in the namespace.
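+
+    A minimal usage sketch (the ``fixed`` namespace is the one read by
+    ``get_fixed_value`` in ``nn/pytorch/utils.py``)::
+
+        with ContextStack('fixed', {'mychoice': 0}):
+            assert get_current_context('fixed') == {'mychoice': 0}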
+    """
+
+    _stack: Dict[str, List[Any]] = defaultdict(list)
+
+    def __init__(self, key: str, value: Any):
+        self.key = key
+        self.value = value
+
+    def __enter__(self):
+        self.push(self.key, self.value)
+        return self
+
+    def __exit__(self, *args, **kwargs):
+        self.pop(self.key)
+
+    @classmethod
+    def push(cls, key: str, value: Any):
+        cls._stack[key].append(value)
+
+    @classmethod
+    def pop(cls, key: str) -> None:
+        cls._stack[key].pop()
+
+    @classmethod
+    def top(cls, key: str) -> Any:
+        assert cls._stack[key], 'Context is empty.'
+        return cls._stack[key][-1]
+
+
+def get_current_context(key: str) -> Any:
+    return ContextStack.top(key)
diff --git a/nni/tools/nnictl/url_utils.py b/nni/tools/nnictl/url_utils.py
index 71af16de68..379ac8e729 100644
--- a/nni/tools/nnictl/url_utils.py
+++ b/nni/tools/nnictl/url_utils.py
@@ -67,7 +67,7 @@ def trial_job_id_url(port, job_id):
 
 def export_data_url(port):
     '''get export_data url'''
-    return '{0}:{1}{2}{3}{4}'.format(BASE_URL, port, API_ROOT_URL, EXPORT_DATA_API)
+    return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, EXPORT_DATA_API)
 
 
 def tensorboard_url(port):
diff --git a/pipelines/full-test-linux.yml b/pipelines/full-test-linux.yml
index aaf09d175c..97f8ab415a 100644
--- a/pipelines/full-test-linux.yml
+++ b/pipelines/full-test-linux.yml
@@ -35,6 +35,7 @@ jobs:
       python3 -m pip install keras==2.1.6
       python3 -m pip install tensorflow==2.3.1 tensorflow-estimator==2.3.0
       python3 -m pip install thop
+      python3 -m pip install tianshou>=0.4.1 gym
       sudo apt-get install swig -y
     displayName: Install extra dependencies
 
diff --git a/pipelines/full-test-windows.yml b/pipelines/full-test-windows.yml
index 282b9fd1e7..30dd072ddb 100644
--- a/pipelines/full-test-windows.yml
+++ b/pipelines/full-test-windows.yml
@@ -30,6 +30,7 @@ jobs:
       python -m pip install torch==1.6.0 torchvision==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
       python -m pip install 'pytorch-lightning>=1.1.1'
       python -m pip install tensorflow==2.3.1 tensorflow-estimator==2.3.0
+      python -m pip install tianshou>=0.4.1 gym
     displayName: Install extra dependencies
 
   # Need del later
diff --git a/pipelines/integration-test-hybrid.yml b/pipelines/integration-test-hybrid.yml
new file mode 100644
index 0000000000..81ae2afcfb
--- /dev/null
+++ b/pipelines/integration-test-hybrid.yml
@@ -0,0 +1,102 @@
+trigger: none
+pr: none
+schedules:
+- cron: 0 16 * * *
+  branches:
+    include: [ master ]
+
+variables:
+  worker: remote_nni-ci-gpu-03
+
+
+jobs:
+- job: hybrid
+  pool: NNI CI REMOTE CLI
+  timeoutInMinutes: 120
+
+  steps:
+  - script: |
+      export NNI_RELEASE=999.$(date -u +%Y%m%d%H%M%S)
+      echo "##vso[task.setvariable variable=PATH]${PATH}:${HOME}/.local/bin"
+      echo "##vso[task.setvariable variable=NNI_RELEASE]${NNI_RELEASE}"
+
+      echo "Working directory: ${PWD}"
+      echo "NNI version: ${NNI_RELEASE}"
+
+      python3 -m pip install --upgrade pip setuptools
+    displayName: Prepare
+
+  - script: |
+      set -e
+      python3 test/vso_tools/install_nni.py $(NNI_RELEASE) SMAC,BOHB
+
+      cd examples/tuners/customized_tuner
+      python3 setup.py develop --user
+      nnictl algo register --meta meta_file.yml
+    displayName: Install NNI
+
+  - task: CopyFilesOverSSH@0
+    inputs:
+      sshEndpoint: $(worker)
+      sourceFolder: dist
+      targetFolder: /tmp/nnitest/$(Build.BuildId)/dist
+      overwrite: true
+    displayName: Copy wheel to remote machine
+    timeoutInMinutes: 10
+
+  - task: CopyFilesOverSSH@0
+    inputs:
+      sshEndpoint: $(worker)
+      contents: Dockerfile
+      targetFolder: /tmp/nnitest/$(Build.BuildId)
+      overwrite: true
+    displayName: Copy dockerfile to remote machine
+    timeoutInMinutes: 10
+
+  - task: CopyFilesOverSSH@0
+    inputs:
+      sshEndpoint: $(worker)
+      sourceFolder: test
+      targetFolder: /tmp/nnitest/$(Build.BuildId)/test
+      overwrite: true
+    displayName: Copy test scripts to remote machine
+    timeoutInMinutes: 10
+
+  # Need del later
+  - task: CopyFilesOverSSH@0
+    inputs:
+      sshEndpoint: $(worker)
+      contents: interim_vision_patch.py
+      targetFolder: /tmp/nnitest/$(Build.BuildId)
+      overwrite: true
+    displayName: Copy vision patch to remote machine
+    timeoutInMinutes: 10
+
+  - task: SSH@0
+    inputs:
+      sshEndpoint: $(worker)
+      runOptions: commands
+      commands: python3 /tmp/nnitest/$(Build.BuildId)/test/vso_tools/start_docker.py $(NNI_RELEASE) $(Build.BuildId) $(password_in_docker)
+    displayName: Install NNI and run docker on Linux worker
+
+  - script: |
+      cd test
+      python3 nni_test/nnitest/generate_ts_config.py \
+          --ts hybrid \
+          --remote_reuse true \
+          --remote_user nni \
+          --remote_host $(worker_ip) \
+          --remote_pwd $(password_in_docker) \
+          --remote_port $(docker_port) \
+          --nni_manager_ip $(manager_ip) \
+          --config_version v2
+      python3 nni_test/nnitest/run_tests.py --config config/integration_tests_config_v2.yml --ts hybrid
+    displayName: Integration test
+
+  - task: SSH@0
+    inputs:
+      sshEndpoint: $(worker)
+      runOptions: commands
+      commands: python3 /tmp/nnitest/$(Build.BuildId)/test/vso_tools/stop_docker.py $(Build.BuildId)
+    condition: always()
+    displayName: Stop docker
diff --git a/test/config/assessors/curvefitting-v2.yml b/test/config/assessors/curvefitting-v2.yml
new file mode 100644
index 0000000000..ce17bbb655
--- /dev/null
+++ b/test/config/assessors/curvefitting-v2.yml
@@ -0,0 +1,20 @@
+experimentName: default_test
+searchSpaceFile: ../naive_trial/search_space.json
+trialCommand: python3 trial.py
+trialCodeDirectory: ../naive_trial
+trialGpuNumber: 0
+trialConcurrency: 8
+maxExperimentDuration: 15m
+maxTrialNumber: 8
+tuner:
+  name: TPE
+  classArgs:
+    optimize_mode: maximize
+trainingService:
+  platform: local
+assessor:
+  name: Curvefitting
+  classArgs:
+    epoch_num: 20
+    start_step: 6
+    threshold: 0.95
diff --git a/test/config/assessors/medianstop-v2.yml b/test/config/assessors/medianstop-v2.yml
new file mode 100644
index 0000000000..b59a594304
--- /dev/null
+++ b/test/config/assessors/medianstop-v2.yml
@@ -0,0 +1,18 @@
+experimentName: default_test
+searchSpaceFile: ../naive_trial/search_space.json
+trialCommand: python3 trial.py
+trialCodeDirectory: ../naive_trial
+trialGpuNumber: 0
+trialConcurrency: 8
+maxExperimentDuration: 15m
+maxTrialNumber: 8
+tuner:
+  name: TPE
+  classArgs:
+    optimize_mode: maximize
+trainingService:
+  platform: local
+assessor:
+  name: Medianstop
+  classArgs:
+    optimize_mode: maximize
\ No newline at end of file
diff --git a/test/config/customized_tuners/demotuner-sklearn-classification-v2.yml b/test/config/customized_tuners/demotuner-sklearn-classification-v2.yml
new file mode 100644
index 0000000000..6c8a5bfb35
--- /dev/null
+++ b/test/config/customized_tuners/demotuner-sklearn-classification-v2.yml
@@ -0,0 +1,16 @@
+experimentName: default_test
+searchSpaceFile: ../../../examples/trials/sklearn/classification/search_space.json
+trialCommand: python3 main.py
+trialCodeDirectory: ../../../examples/trials/sklearn/classification
+trialGpuNumber: 0
+trialConcurrency: 4
+maxExperimentDuration: 15m
+maxTrialNumber: 2
+tuner:
+  name: demotuner
+trainingService:
+  platform: local
+assessor:
+  name: Medianstop
+  classArgs:
+    optimize_mode: maximize
diff --git a/test/config/examples/cifar10-pytorch-v2.yml b/test/config/examples/cifar10-pytorch-v2.yml
new file mode 100644
index 0000000000..1273f94567
--- /dev/null
+++ b/test/config/examples/cifar10-pytorch-v2.yml
@@ -0,0 +1,16 @@
+experimentName: default_test
+searchSpaceFile: cifar10_search_space.json
+trialCommand: python3 main.py --epochs 1 --batches 1
+trialCodeDirectory: ../../../examples/trials/cifar10_pytorch
+trialGpuNumber: 0
+trialConcurrency: 1
+maxExperimentDuration: 15m
+maxTrialNumber: 1
+tuner:
+  name: Random
+trainingService:
+  platform: local
+assessor:
+  name: Medianstop
+  classArgs:
+    optimize_mode: maximize
diff --git a/test/config/examples/classic-nas-pytorch-v2.yml b/test/config/examples/classic-nas-pytorch-v2.yml
new file mode 100644
index 0000000000..585cccbda3
--- /dev/null
+++ b/test/config/examples/classic-nas-pytorch-v2.yml
@@ -0,0 +1,18 @@
+experimentName: default_test
+searchSpaceFile: nni-nas-search-space.json
+trialCommand: python3 main.py --epochs 1 --batches 1
+trialCodeDirectory: ../../../examples/nas/legacy/classic_nas
+trialGpuNumber: 0
+trialConcurrency: 1
+maxExperimentDuration: 15m
+maxTrialNumber: 1
+tuner:
+  name: PPOTuner
+  classArgs:
+    optimize_mode: maximize
+trainingService:
+  platform: local
+assessor:
+  name: Medianstop
+  classArgs:
+    optimize_mode: maximize
diff --git a/test/config/examples/mnist-annotation-v2.yml b/test/config/examples/mnist-annotation-v2.yml
new file mode 100644
index 0000000000..0601ffedba
--- /dev/null
+++ b/test/config/examples/mnist-annotation-v2.yml
@@ -0,0 +1,16 @@
+experimentName: default_test
+searchSpaceFile: ../../../examples/trials/mnist-keras/search_space.json
+trialCommand: python3 mnist.py --batch_num 10
+trialCodeDirectory: ../../../examples/trials/mnist-annotation
+trialGpuNumber: 0
+trialConcurrency: 2
+maxExperimentDuration: 15m
+maxTrialNumber: 2
+tuner:
+  name: Random
+trainingService:
+  platform: local
+assessor:
+  name: Medianstop
+  classArgs:
+    optimize_mode: maximize
diff --git a/test/config/examples/mnist-keras-v2.yml b/test/config/examples/mnist-keras-v2.yml
new file mode 100644
index 0000000000..7d1313d239
--- /dev/null
+++ b/test/config/examples/mnist-keras-v2.yml
@@ -0,0 +1,16 @@
+experimentName: default_test
+searchSpaceFile: ../../../examples/trials/mnist-keras/search_space.json
+trialCommand: python3 mnist-keras.py --num_train 200 --epochs 1
+trialCodeDirectory: ../../../examples/trials/mnist-keras
+trialGpuNumber: 0
+trialConcurrency: 2
+maxExperimentDuration: 15m
+maxTrialNumber: 2
+tuner:
+  name: Random
+trainingService:
+  platform: local
+assessor:
+  name: Medianstop
+  classArgs:
+    optimize_mode: maximize
diff --git a/test/config/examples/mnist-nested-search-space-v2.yml b/test/config/examples/mnist-nested-search-space-v2.yml
new file mode 100644
index 0000000000..e5099daf90
--- /dev/null
+++ b/test/config/examples/mnist-nested-search-space-v2.yml
@@ -0,0 +1,16 @@
+experimentName: default_test
+searchSpaceFile: ../../../examples/trials/mnist-nested-search-space/search_space.json
+trialCommand: python3 mnist.py --batch_num 10
+trialCodeDirectory: ../../../examples/trials/mnist-nested-search-space
+trialGpuNumber: 0
+trialConcurrency: 2
+maxExperimentDuration: 15m
+maxTrialNumber: 2
+tuner:
+  name: TPE
+trainingService:
+  platform: local
+assessor:
+  name: Medianstop
+  classArgs:
+    optimize_mode: maximize
diff --git a/test/config/examples/mnist-pytorch-v2.yml b/test/config/examples/mnist-pytorch-v2.yml
new file mode 100644
index 0000000000..e481507d74
--- /dev/null
+++ b/test/config/examples/mnist-pytorch-v2.yml
@@ -0,0 +1,17 @@
+experimentName: default_test
+searchSpaceFile: ./mnist_pytorch_search_space.json
+trialCommand: python3 mnist.py  --epochs 1 --batch_num 10
+trialCodeDirectory: ../../../examples/trials/mnist-pytorch
+trialGpuNumber: 0
+trialConcurrency: 1
+maxExperimentDuration: 15m
+maxTrialNumber: 1
+tuner:
+  name: Random
+trainingService:
+  platform: local
+assessor:
+  name: Medianstop
+  classArgs:
+    optimize_mode: maximize
+
diff --git a/test/config/examples/mnist-tfv1-v2.yml b/test/config/examples/mnist-tfv1-v2.yml
new file mode 100644
index 0000000000..49f3a79e9d
--- /dev/null
+++ b/test/config/examples/mnist-tfv1-v2.yml
@@ -0,0 +1,16 @@
+experimentName: default_test
+searchSpaceFile: ./mnist_search_space.json
+trialCommand: python3 mnist.py --batch_num 10
+trialCodeDirectory: ../../../examples/trials/mnist-tfv1
+trialGpuNumber: 0
+trialConcurrency: 1
+maxExperimentDuration: 15m
+maxTrialNumber: 1
+tuner:
+  name: Random
+trainingService:
+  platform: local
+assessor:
+  name: Medianstop
+  classArgs:
+    optimize_mode: maximize
diff --git a/test/config/examples/mnist-tfv2-v2.yml b/test/config/examples/mnist-tfv2-v2.yml
new file mode 100644
index 0000000000..089ce672e8
--- /dev/null
+++ b/test/config/examples/mnist-tfv2-v2.yml
@@ -0,0 +1,17 @@
+experimentName: default_test
+searchSpaceFile: ./mnist_search_space.json
+trialCommand: python3 mnist.py
+trialCodeDirectory: ../../../examples/trials/mnist-tfv2
+trialGpuNumber: 0
+trialConcurrency: 2
+maxExperimentDuration: 15m
+maxTrialNumber: 4
+tuner:
+  name: Random
+trainingService:
+  platform: local
+assessor:
+  name: Medianstop
+  classArgs:
+    optimize_mode: maximize
+
diff --git a/test/config/examples/sklearn-classification-v2.yml b/test/config/examples/sklearn-classification-v2.yml
new file mode 100644
index 0000000000..1c7f566cd6
--- /dev/null
+++ b/test/config/examples/sklearn-classification-v2.yml
@@ -0,0 +1,16 @@
+experimentName: default_test
+searchSpaceFile: ../../../examples/trials/sklearn/classification/search_space.json
+trialCommand: python3 main.py
+trialCodeDirectory: ../../../examples/trials/sklearn/classification
+trialGpuNumber: 0
+trialConcurrency: 2
+maxExperimentDuration: 15m
+maxTrialNumber: 4
+tuner:
+  name: Random
+trainingService:
+  platform: local
+assessor:
+  name: Medianstop
+  classArgs:
+    optimize_mode: maximize
diff --git a/test/config/examples/sklearn-regression-v2.yml b/test/config/examples/sklearn-regression-v2.yml
new file mode 100644
index 0000000000..2d128cf061
--- /dev/null
+++ b/test/config/examples/sklearn-regression-v2.yml
@@ -0,0 +1,17 @@
+experimentName: default_test
+searchSpaceFile: ../../../examples/trials/sklearn/regression/search_space.json
+trialCommand: python3 main.py
+trialCodeDirectory: ../../../examples/trials/sklearn/regression
+trialGpuNumber: 0
+trialConcurrency: 2
+maxExperimentDuration: 15m
+maxTrialNumber: 4
+tuner:
+  name: Random
+trainingService:
+  platform: local
+assessor:
+  name: Medianstop
+  classArgs:
+    optimize_mode: maximize
+
diff --git a/test/config/integration_tests_config_v2.yml b/test/config/integration_tests_config_v2.yml
new file mode 100644
index 0000000000..42a428f09e
--- /dev/null
+++ b/test/config/integration_tests_config_v2.yml
@@ -0,0 +1,135 @@
+
+defaultTestCaseConfig:
+  launchCommand: nnictl create --config $configFile --debug
+  stopCommand: nnictl stop
+  experimentStatusCheck: True
+  platform: linux darwin win32
+  trainingService: all
+
+testCases:
+#######################################################################
+# nni examples test
+#######################################################################
+- name: sklearn-classification
+  # test case config yml file relative to nni source code directory
+  configFile: test/config/examples/sklearn-classification-v2.yml
+
+- name: sklearn-regression
+  configFile: test/config/examples/sklearn-regression-v2.yml
+
+- name: mnist-tensorflow
+  configFile: test/config/examples/mnist-tfv2-v2.yml
+  trainingService: local remote hybrid
+
+- name: mnist-pytorch-local
+  configFile: test/config/examples/mnist-pytorch-v2.yml
+  # download data first, to prevent concurrent issue.
+  launchCommand: python3 ../examples/trials/mnist-pytorch/mnist.py --epochs 1 --batch_num 0 --data_dir ../examples/trials/mnist-pytorch/data && nnictl create --config $configFile --debug
+  trainingService: local
+
+- name: mnist-pytorch-non-local
+  configFile: test/config/examples/mnist-pytorch-v2.yml
+  trainingService: remote pai kubeflow frameworkcontroller dlts hybrid
+
+- name: cifar10-pytorch
+  configFile: test/config/examples/cifar10-pytorch-v2.yml
+
+- name: cifar10-pytorch-adl
+  configFile: test/config/examples/cifar10-pytorch-adl.yml
+  trainingService: adl
+
+- name: classic-nas-gen-ss
+  configFile: test/config/examples/classic-nas-pytorch-v2.yml
+  launchCommand: nnictl ss_gen --trial_command="python3 mnist.py --epochs 1" --trial_dir=../examples/nas/legacy/classic_nas --file=config/examples/nni-nas-search-space.json
+  stopCommand:
+  experimentStatusCheck: False
+  trainingService: local
+
+- name: classic-nas-pytorch
+  configFile: test/config/examples/classic-nas-pytorch-v2.yml
+  # remove search space file
+  stopCommand: nnictl stop
+  onExitCommand: python3 -c "import os; os.remove('config/examples/nni-nas-search-space.json')"
+  trainingService: local
+
+#########################################################################
+# nni features test
+#########################################################################
+- name: metrics-float
+  configFile: test/config/metrics_test/config-v2.yml
+  validator:
+    class: MetricsValidator
+    kwargs:
+      expected_result_file: expected_metrics.json
+
+- name: export-float
+  configFile: test/config/metrics_test/config-v2.yml
+  validator:
+    class: ExportValidator 
+
+- name: metrics-dict
+  configFile: test/config/metrics_test/config_dict_metrics-v2.yml
+  validator:
+    class: MetricsValidator
+    kwargs:
+      expected_result_file: expected_metrics_dict.json
+
+- name: export-dict
+  configFile: test/config/metrics_test/config_dict_metrics-v2.yml
+  validator:
+    class: ExportValidator 
+
+- name: experiment-import
+  configFile: test/config/nnictl_experiment/sklearn-classification-v2.yml
+  validator:
+    class: ImportValidator
+    kwargs:
+      import_data_file_path: config/nnictl_experiment/test_import.json
+
+- name: foreground
+  configFile: test/config/examples/sklearn-regression-v2.yml
+  launchCommand: python3 nni_test/nnitest/foreground.py --config $configFile --timeout 45
+  stopCommand:
+  experimentStatusCheck: False
+  platform: linux darwin
+
+# Experiment resume test part 1
+- name: nnictl-resume-1
+  configFile: test/config/examples/sklearn-regression-v2.yml
+  setExperimentIdtoVar: $resumeExpId
+  # for subfolder in codedir test
+  launchCommand: python3 -c "import os; os.makedirs('../examples/trials/sklearn/regression/subfolder', exist_ok=True); open('../examples/trials/sklearn/regression/subfolder/subfile', 'a').close()" && nnictl create --config $configFile --debug
+
+# Experiment resume test part 2
+- name: nnictl-resume-2
+  configFile: test/config/examples/sklearn-regression-v2.yml
+  launchCommand: nnictl resume $resumeExpId
+
+# Experiment view test
+- name: nnictl-view
+  configFile: test/config/examples/sklearn-regression-v2.yml
+  launchCommand: nnictl view $resumeExpId
+  experimentStatusCheck: False
+
+
+#########################################################################
+# nni assessor test
+#########################################################################
+- name: assessor-curvefitting
+  configFile: test/config/assessors/curvefitting-v2.yml
+
+- name: assessor-medianstop
+  configFile: test/config/assessors/medianstop-v2.yml
+
+#########################################################################
+# nni tuners test
+#########################################################################
+- name: tuner-regularized_evolution
+  configFile: test/config/tuners/regularized_evolution_tuner-v2.yml
+
+#########################################################################
+# nni customized-tuners test
+#########################################################################
+- name: customized-tuners-demotuner
+  configFile: test/config/customized_tuners/demotuner-sklearn-classification-v2.yml
+
diff --git a/test/config/metrics_test/config-v2.yml b/test/config/metrics_test/config-v2.yml
new file mode 100644
index 0000000000..d553c8c878
--- /dev/null
+++ b/test/config/metrics_test/config-v2.yml
@@ -0,0 +1,16 @@
+experimentName: default_test
+searchSpaceFile: ./search_space.json
+trialCommand: python3 trial.py
+trialCodeDirectory: .
+trialGpuNumber: 0
+trialConcurrency: 1
+maxExperimentDuration: 15m
+maxTrialNumber: 1
+tuner:
+  name: Random
+trainingService:
+  platform: local
+assessor:
+  name: Medianstop
+  classArgs:
+    optimize_mode: maximize
diff --git a/test/config/metrics_test/config_dict_metrics-v2.yml b/test/config/metrics_test/config_dict_metrics-v2.yml
new file mode 100644
index 0000000000..f387487048
--- /dev/null
+++ b/test/config/metrics_test/config_dict_metrics-v2.yml
@@ -0,0 +1,16 @@
+experimentName: default_test
+searchSpaceFile: ./search_space.json
+trialCommand: python3 trial.py --dict_metrics
+trialCodeDirectory: .
+trialGpuNumber: 0
+trialConcurrency: 1
+maxExperimentDuration: 15m
+maxTrialNumber: 1
+tuner:
+  name: Random
+trainingService:
+  platform: local
+assessor:
+  name: Medianstop
+  classArgs:
+    optimize_mode: maximize
diff --git a/test/config/metrics_test/config_failure-v2.yml b/test/config/metrics_test/config_failure-v2.yml
new file mode 100644
index 0000000000..28cc8a6398
--- /dev/null
+++ b/test/config/metrics_test/config_failure-v2.yml
@@ -0,0 +1,17 @@
+experimentName: default_test
+searchSpaceFile: ./search_space.json
+trialCommand: python3 not_exist.py
+trialCodeDirectory: .
+trialGpuNumber: 0
+trialConcurrency: 1
+maxExperimentDuration: 15m
+maxTrialNumber: 1
+tuner:
+  name: Random
+trainingService:
+  platform: local
+assessor:
+  name: Medianstop
+  classArgs:
+    optimize_mode: maximize
+
diff --git a/test/config/nnictl_experiment/sklearn-classification-v2.yml b/test/config/nnictl_experiment/sklearn-classification-v2.yml
new file mode 100644
index 0000000000..4d8b575776
--- /dev/null
+++ b/test/config/nnictl_experiment/sklearn-classification-v2.yml
@@ -0,0 +1,16 @@
+experimentName: default_test
+searchSpaceFile: ../../../examples/trials/sklearn/classification/search_space.json
+trialCommand: python3 main.py
+trialCodeDirectory: ../../../examples/trials/sklearn/classification
+trialGpuNumber: 0
+trialConcurrency: 1
+maxExperimentDuration: 15m
+maxTrialNumber: 1
+tuner:
+  name: TPE
+trainingService:
+  platform: local
+assessor:
+  name: Medianstop
+  classArgs:
+    optimize_mode: maximize
\ No newline at end of file
diff --git a/test/config/training_service.yml b/test/config/training_service.yml
index cc85f57d5c..94d2288367 100644
--- a/test/config/training_service.yml
+++ b/test/config/training_service.yml
@@ -87,6 +87,26 @@ remote:
     port:
     username:
   trainingServicePlatform: remote
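+# hybrid: run a single experiment across several training services at once (remote + local here)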
+hybrid:
+  maxExecDuration: 15m
+  nniManagerIp:
+  maxTrialNum: 2
+  trialConcurrency: 2
+  trial:
+    gpuNum: 0
+  trainingServicePlatform: hybrid
+  hybridConfig:
+    # TODO: Add more platforms
+    trainingServicePlatforms:
+      - remote
+      - local
+  machineList:
+  - ip:
+    passwd:
+    port:
+    username:
+  remoteConfig:
+    reuse: true
 adl:
   maxExecDuration: 15m
   nniManagerIp:
diff --git a/test/config/training_service_v2.yml b/test/config/training_service_v2.yml
new file mode 100644
index 0000000000..853e0df058
--- /dev/null
+++ b/test/config/training_service_v2.yml
@@ -0,0 +1,9 @@
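+# v2-style hybrid config: trainingService is a list with one entry per platform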
+hybrid:
+  trainingService:
+  - platform: remote
+    machineList:
+      - host:
+        user:
+        password:
+        port:
+  - platform: local
\ No newline at end of file
diff --git a/test/config/tuners/regularized_evolution_tuner-v2.yml b/test/config/tuners/regularized_evolution_tuner-v2.yml
new file mode 100644
index 0000000000..a2a923b220
--- /dev/null
+++ b/test/config/tuners/regularized_evolution_tuner-v2.yml
@@ -0,0 +1,14 @@
+experimentName: default_test
+searchSpaceFile: seach_space_classic_nas.json
+trialCommand: python3 mnist.py --epochs 1
+trialCodeDirectory: ../../../examples/nas/legacy/classic_nas
+trialGpuNumber: 0
+trialConcurrency: 1
+maxExperimentDuration: 15m
+maxTrialNumber: 1
+tuner:
+  name: RegularizedEvolutionTuner
+  classArgs:
+    optimize_mode: maximize
+trainingService:
+  platform: local
diff --git a/test/nni_test/nnitest/generate_ts_config.py b/test/nni_test/nnitest/generate_ts_config.py
index 7dcf5465c8..d406131661 100644
--- a/test/nni_test/nnitest/generate_ts_config.py
+++ b/test/nni_test/nnitest/generate_ts_config.py
@@ -8,10 +8,11 @@
 from utils import get_yml_content, dump_yml_content
 
 TRAINING_SERVICE_FILE = os.path.join('config', 'training_service.yml')
+TRAINING_SERVICE_FILE_V2 = os.path.join('config', 'training_service_v2.yml')
 
 def update_training_service_config(args):
     config = get_yml_content(TRAINING_SERVICE_FILE)
-    if args.nni_manager_ip is not None:
+    if args.nni_manager_ip is not None and args.config_version == 'v1':
         config[args.ts]['nniManagerIp'] = args.nni_manager_ip
     if args.ts == 'pai':
         if args.pai_user is not None:
@@ -99,13 +100,22 @@ def update_training_service_config(args):
             config[args.ts]['amlConfig']['workspaceName'] = args.workspace_name
         if args.compute_target is not None:
             config[args.ts]['amlConfig']['computeTarget'] = args.compute_target
-
     dump_yml_content(TRAINING_SERVICE_FILE, config)
 
+    if args.ts == 'hybrid':
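+        # hybrid uses the v2 schema: remote machine credentials live under trainingService[0]['machineList']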
+        config = get_yml_content(TRAINING_SERVICE_FILE_V2)
+        config[args.ts]['trainingService'][0]['machineList'][0]['user'] = args.remote_user
+        config[args.ts]['trainingService'][0]['machineList'][0]['host'] = args.remote_host
+        config[args.ts]['trainingService'][0]['machineList'][0]['password'] = args.remote_pwd
+        config[args.ts]['trainingService'][0]['machineList'][0]['port'] = args.remote_port
+        config[args.ts]['nni_manager_ip'] = args.nni_manager_ip
+        dump_yml_content(TRAINING_SERVICE_FILE_V2, config)
+
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument("--ts", type=str, choices=['pai', 'kubeflow', 'remote', 'local', 'frameworkcontroller', 'adl', 'aml'], default='pai')
+    parser.add_argument("--ts", type=str, choices=['pai', 'kubeflow', 'remote', 'local', 'frameworkcontroller', 'adl', 'aml', 'hybrid'], default='pai')
+    parser.add_argument("--config_version", type=str, choices=['v1', 'v2'], default='v1')
     parser.add_argument("--nni_docker_image", type=str)
     parser.add_argument("--nni_manager_ip", type=str)
     # args for PAI
diff --git a/test/nni_test/nnitest/run_tests.py b/test/nni_test/nnitest/run_tests.py
index 0b888cf3b7..f517f0a816 100644
--- a/test/nni_test/nnitest/run_tests.py
+++ b/test/nni_test/nnitest/run_tests.py
@@ -52,8 +52,11 @@ def update_training_service_config(config, training_service, config_file_path):
             containerCodeDir = config['trial']['codeDir'].replace('../../../', '/')
         it_ts_config[training_service]['trial']['codeDir'] = containerCodeDir
         it_ts_config[training_service]['trial']['command'] = 'cd {0} && {1}'.format(containerCodeDir, config['trial']['command'])
-
-    deep_update(config, it_ts_config['all'])
+
+    if training_service == 'hybrid':
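+        # hybrid test cases read the v2-style config as-is; the v1 'all' defaults do not apply to it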
+        it_ts_config = get_yml_content(os.path.join('config', 'training_service_v2.yml'))
+    else:
+        deep_update(config, it_ts_config['all'])
     deep_update(config, it_ts_config[training_service])
 
 
@@ -123,7 +126,10 @@ def invoke_validator(test_case_config, nni_source_dir, training_service):
 
 def get_max_values(config_file):
     experiment_config = get_yml_content(config_file)
-    return parse_max_duration_time(experiment_config['maxExecDuration']), experiment_config['maxTrialNum']
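+    # v1 configs use maxExecDuration/maxTrialNum, v2 configs use maxExperimentDuration/maxTrialNumber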
+    if experiment_config.get('maxExecDuration'):
+        return parse_max_duration_time(experiment_config['maxExecDuration']), experiment_config['maxTrialNum']
+    else:
+        return parse_max_duration_time(experiment_config['maxExperimentDuration']), experiment_config['maxTrialNumber']
 
 
 def get_command(test_case_config, commandKey):
@@ -259,7 +265,7 @@ def run(args):
                 name, args.ts, test_case_config['trainingService']))
             continue
         # remote mode needs more time to clean up
-        if args.ts == 'remote':
+        if args.ts == 'remote' or args.ts == 'hybrid':
             wait_for_port_available(8080, 240)
         else:
             wait_for_port_available(8080, 60)
@@ -281,7 +287,7 @@ def run(args):
     parser.add_argument("--cases", type=str, default=None)
     parser.add_argument("--exclude", type=str, default=None)
     parser.add_argument("--ts", type=str, choices=['local', 'remote', 'pai',
-                                                   'kubeflow', 'frameworkcontroller', 'adl', 'aml'], default='local')
+                                                   'kubeflow', 'frameworkcontroller', 'adl', 'aml', 'hybrid'], default='local')
     args = parser.parse_args()
 
     run(args)
diff --git a/test/ut/retiarii/debug_mnist_pytorch.py b/test/ut/retiarii/debug_mnist_pytorch.py
index 4ac3ddff8d..a15977e5f2 100644
--- a/test/ut/retiarii/debug_mnist_pytorch.py
+++ b/test/ut/retiarii/debug_mnist_pytorch.py
@@ -3,22 +3,24 @@
 import torch.nn.functional as F
 import torch.optim as optim
 
+import torch
+
 
 class _model(nn.Module):
     def __init__(self):
         super().__init__()
         self.stem = stem()
-        
-        self.fc1 = nn.Linear(1024, 256)
-        self.fc2 = nn.Linear(256, 10)
-        
+        self.flatten = torch.nn.Flatten()
+        self.fc1 = torch.nn.Linear(out_features=256, in_features=1024)
+        self.fc2 = torch.nn.Linear(out_features=10, in_features=256)
+        self.softmax = torch.nn.Softmax()
 
     def forward(self, image):
         stem = self.stem(image)
-        flatten = stem.view(stem.size(0), -1)
+        flatten = self.flatten(stem)
         fc1 = self.fc1(flatten)
         fc2 = self.fc2(fc1)
-        softmax = F.softmax(fc2, -1)
+        softmax = self.softmax(fc2)
         return softmax
 
 
@@ -26,10 +28,10 @@ def forward(self, image):
 class stem(nn.Module):
     def __init__(self):
         super().__init__()
-        self.conv1 = nn.Conv2d(out_channels=32, in_channels=1, kernel_size=5)
-        self.pool1 = nn.MaxPool2d(kernel_size=2)
-        self.conv2 = nn.Conv2d(out_channels=64, in_channels=32, kernel_size=5)
-        self.pool2 = nn.MaxPool2d(kernel_size=2)
+        self.conv1 = torch.nn.Conv2d(out_channels=32, in_channels=1, kernel_size=5)
+        self.pool1 = torch.nn.MaxPool2d(kernel_size=2)
+        self.conv2 = torch.nn.Conv2d(out_channels=64, in_channels=32, kernel_size=5)
+        self.pool2 = torch.nn.MaxPool2d(kernel_size=2)
 
     def forward(self, *_inputs):
         conv1 = self.conv1(_inputs[0])
diff --git a/test/ut/retiarii/mnist_pytorch.json b/test/ut/retiarii/mnist_pytorch.json
index 5788136d8a..b5ddc87887 100644
--- a/test/ut/retiarii/mnist_pytorch.json
+++ b/test/ut/retiarii/mnist_pytorch.json
@@ -5,10 +5,10 @@
 
         "nodes": {
             "stem": {"operation": {"type": "_cell", "cell_name": "stem"}},
-            "flatten": {"operation": {"type": "Flatten"}},
-            "fc1": {"operation": {"type": "Dense", "parameters": {"out_features": 256, "in_features": 1024}}},
-            "fc2": {"operation": {"type": "Dense", "parameters": {"out_features": 10, "in_features": 256}}},
-            "softmax": {"operation": {"type": "Softmax"}}
+            "flatten": {"operation": {"type": "__torch__.torch.nn.Flatten"}},
+            "fc1": {"operation": {"type": "__torch__.torch.nn.Linear", "parameters": {"out_features": 256, "in_features": 1024}}},
+            "fc2": {"operation": {"type": "__torch__.torch.nn.Linear", "parameters": {"out_features": 10, "in_features": 256}}},
+            "softmax": {"operation": {"type": "__torch__.torch.nn.Softmax"}}
         },
 
         "edges": [
@@ -23,10 +23,10 @@
 
     "stem": {
         "nodes": {
-            "conv1": {"operation": {"type": "__torch__.Conv2d", "parameters": {"out_channels": 32, "in_channels": 1, "kernel_size": 5}}},
-            "pool1": {"operation": {"type": "__torch__.MaxPool2d", "parameters": {"kernel_size": 2}}},
-            "conv2": {"operation": {"type": "__torch__.Conv2d", "parameters": {"out_channels": 64, "in_channels": 32, "kernel_size": 5}}},
-            "pool2": {"operation": {"type": "__torch__.MaxPool2d", "parameters": {"kernel_size": 2}}}
+            "conv1": {"operation": {"type": "__torch__.torch.nn.Conv2d", "parameters": {"out_channels": 32, "in_channels": 1, "kernel_size": 5}}},
+            "pool1": {"operation": {"type": "__torch__.torch.nn.MaxPool2d", "parameters": {"kernel_size": 2}}},
+            "conv2": {"operation": {"type": "__torch__.torch.nn.Conv2d", "parameters": {"out_channels": 64, "in_channels": 32, "kernel_size": 5}}},
+            "pool2": {"operation": {"type": "__torch__.torch.nn.MaxPool2d", "parameters": {"kernel_size": 2}}}
         },
 
         "edges": [
@@ -36,26 +36,5 @@
             {"head": ["conv2", null], "tail": ["pool2", null]},
             {"head": ["pool2", null], "tail": ["_outputs", 0]}
         ]
-    },
-
-    "_evaluator": {
-        "module": "nni.retiarii.trainer.PyTorchImageClassificationTrainer",
-        "kwargs": {
-            "dataset_cls": "MNIST",
-            "dataset_kwargs": {
-                "root": "data/mnist",
-                "download": true
-            },
-            "dataloader_kwargs": {
-                "batch_size": 32
-            },
-            "optimizer_cls" : "SGD",
-            "optimizer_kwargs": {
-                "lr": 1e-3
-            },
-            "trainer_kwargs": {
-                "max_epochs": 1
-            }
-        }
     }
 }
diff --git a/test/ut/retiarii/test_engine.py b/test/ut/retiarii/test_engine.py
index 48dc53e9bb..0a7881308b 100644
--- a/test/ut/retiarii/test_engine.py
+++ b/test/ut/retiarii/test_engine.py
@@ -1,59 +1,68 @@
 import json
 import os
-import sys
-import threading
 import unittest
 from pathlib import Path
 
-import nni
+import nni.retiarii
 from nni.retiarii import Model, submit_models
 from nni.retiarii.codegen import model_to_pytorch_script
-from nni.retiarii.integration import RetiariiAdvisor, register_advisor
-from nni.retiarii.evaluator.pytorch import PyTorchImageClassificationTrainer
-from nni.retiarii.utils import import_
+from nni.retiarii.execution import set_execution_engine
+from nni.retiarii.execution.base import BaseExecutionEngine
+from nni.retiarii.execution.python import PurePythonExecutionEngine
+from nni.retiarii.integration import RetiariiAdvisor
 
 
-@unittest.skip('Skipped in this version')
-class CodeGenTest(unittest.TestCase):
-    def test_mnist_example_pytorch(self):
-        with open('mnist_pytorch.json') as f:
+class EngineTest(unittest.TestCase):
+    def test_codegen(self):
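+        # generated PyTorch script must match the checked-in reference module line for line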
+        with open(self.enclosing_dir / 'mnist_pytorch.json') as f:
             model = Model._load(json.load(f))
             script = model_to_pytorch_script(model)
-        with open('debug_mnist_pytorch.py') as f:
+        with open(self.enclosing_dir / 'debug_mnist_pytorch.py') as f:
             reference_script = f.read()
         self.assertEqual(script.strip(), reference_script.strip())
 
+    def test_base_execution_engine(self):
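+        # submit two copies of the same graph model through the base (graph-based) execution engine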
+        advisor = RetiariiAdvisor()
+        set_execution_engine(BaseExecutionEngine())
+        with open(self.enclosing_dir / 'mnist_pytorch.json') as f:
+            model = Model._load(json.load(f))
+        submit_models(model, model)
 
-@unittest.skip('Skipped in this version')
-class TrainerTest(unittest.TestCase):
-    def test_trainer(self):
-        sys.path.insert(0, Path(__file__).parent.as_posix())
-        Model = import_('debug_mnist_pytorch._model')
-        trainer = PyTorchImageClassificationTrainer(
-            Model(),
-            dataset_kwargs={'root': (Path(__file__).parent / 'data' / 'mnist').as_posix(), 'download': True},
-            dataloader_kwargs={'batch_size': 32},
-            optimizer_kwargs={'lr': 1e-3},
-            trainer_kwargs={'max_epochs': 1}
-        )
-        trainer.fit()
-
-
-@unittest.skip('Skipped in this version')
-class EngineTest(unittest.TestCase):
+        advisor.stopping = True
+        advisor.default_worker.join()
+        advisor.assessor_worker.join()
 
-    def test_submit_models(self):
-        os.makedirs('generated', exist_ok=True)
-        from nni.runtime import protocol
-        protocol._out_file = open(Path(__file__).parent / 'generated/debug_protocol_out_file.py', 'wb')
+    def test_py_execution_engine(self):
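+        # the pure-Python engine relies on python_class rather than codegen, so a minimal one-node graph suffices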
         advisor = RetiariiAdvisor()
-        with open('mnist_pytorch.json') as f:
-            model = Model._load(json.load(f))
+        set_execution_engine(PurePythonExecutionEngine())
+        model = Model._load({
+            '_model': {
+                'inputs': None,
+                'outputs': None,
+                'nodes': {
+                    'layerchoice_1': {
+                        'operation': {'type': 'LayerChoice', 'parameters': {'candidates': ['0', '1']}}
+                    }
+                },
+                'edges': []
+            }
+        })
+        model.python_class = object
         submit_models(model, model)
 
         advisor.stopping = True
         advisor.default_worker.join()
         advisor.assessor_worker.join()
 
-    def test_execution_engine(self):
-        pass
+    def setUp(self) -> None:
+        self.enclosing_dir = Path(__file__).parent
+        os.makedirs(self.enclosing_dir / 'generated', exist_ok=True)
+        from nni.runtime import protocol
+        protocol._out_file = open(self.enclosing_dir / 'generated/debug_protocol_out_file.py', 'wb')
+
+    def tearDown(self) -> None:
+        from nni.runtime import protocol
+        protocol._out_file.close()
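+        # clear the global execution engine and advisor singletons so each test starts fresh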
+        nni.retiarii.execution.api._execution_engine = None
+        nni.retiarii.integration_api._advisor = None
diff --git a/test/ut/retiarii/test_highlevel_apis.py b/test/ut/retiarii/test_highlevel_apis.py
index 20fbbcddb6..a3ff2b2d5d 100644
--- a/test/ut/retiarii/test_highlevel_apis.py
+++ b/test/ut/retiarii/test_highlevel_apis.py
@@ -8,7 +8,10 @@
 from nni.retiarii import Sampler, basic_unit
 from nni.retiarii.converter import convert_to_graph
 from nni.retiarii.codegen import model_to_pytorch_script
-from nni.retiarii.nn.pytorch.mutator import process_inline_mutation
+from nni.retiarii.execution.python import _unpack_if_only_one
+from nni.retiarii.nn.pytorch.mutator import process_inline_mutation, extract_mutation_from_pt_module
+from nni.retiarii.serializer import model_wrapper
+from nni.retiarii.utils import ContextStack
 
 
 class EnumerateSampler(Sampler):
@@ -44,7 +47,7 @@ def forward(self, x: torch.Tensor, index: int):
             return self.conv2(x)
 
 
-class TestHighLevelAPI(unittest.TestCase):
+class GraphIR(unittest.TestCase):
 
     def _convert_to_ir(self, model):
         script_module = torch.jit.script(model)
@@ -56,7 +59,19 @@ def _get_converted_pytorch_model(self, model_ir):
         exec(model_code + '\n\nconverted_model = _model()', exec_vars)
         return exec_vars['converted_model']
 
+    def _get_model_with_mutators(self, pytorch_model):
+        model = self._convert_to_ir(pytorch_model)
+        mutators = process_inline_mutation(model)
+        return model, mutators
+
+    def get_serializer(self):
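+        # no-op serializer for graph-IR tests; the Python subclass below overrides this with model_wrapper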
+        def dummy(cls):
+            return cls
+
+        return dummy
+
     def test_layer_choice(self):
+        @self.get_serializer()
         class Net(nn.Module):
             def __init__(self):
                 super().__init__()
@@ -68,8 +83,7 @@ def __init__(self):
             def forward(self, x):
                 return self.module(x)
 
-        model = self._convert_to_ir(Net())
-        mutators = process_inline_mutation(model)
+        model, mutators = self._get_model_with_mutators(Net())
         self.assertEqual(len(mutators), 1)
         mutator = mutators[0].bind_sampler(EnumerateSampler())
         model1 = mutator.apply(model)
@@ -80,6 +94,7 @@ def forward(self, x):
                          torch.Size([1, 5, 3, 3]))
 
     def test_input_choice(self):
+        @self.get_serializer()
         class Net(nn.Module):
             def __init__(self):
                 super().__init__()
@@ -92,8 +107,7 @@ def forward(self, x):
                 x2 = self.conv2(x)
                 return self.input([x1, x2])
 
-        model = self._convert_to_ir(Net())
-        mutators = process_inline_mutation(model)
+        model, mutators = self._get_model_with_mutators(Net())
         self.assertEqual(len(mutators), 1)
         mutator = mutators[0].bind_sampler(EnumerateSampler())
         model1 = mutator.apply(model)
@@ -104,6 +118,7 @@ def forward(self, x):
                          torch.Size([1, 5, 3, 3]))
 
     def test_chosen_inputs(self):
+        @self.get_serializer()
         class Net(nn.Module):
             def __init__(self, reduction):
                 super().__init__()
@@ -117,8 +132,7 @@ def forward(self, x):
                 return self.input([x1, x2])
 
         for reduction in ['none', 'sum', 'mean', 'concat']:
-            model = self._convert_to_ir(Net(reduction))
-            mutators = process_inline_mutation(model)
+            model, mutators = self._get_model_with_mutators(Net(reduction))
             self.assertEqual(len(mutators), 1)
             mutator = mutators[0].bind_sampler(EnumerateSampler())
             model = mutator.apply(model)
@@ -133,6 +147,7 @@ def forward(self, x):
                 self.assertEqual(result.size(), torch.Size([1, 3, 3, 3]))
 
     def test_value_choice(self):
+        @self.get_serializer()
         class Net(nn.Module):
             def __init__(self):
                 super().__init__()
@@ -142,8 +157,7 @@ def __init__(self):
             def forward(self, x):
                 return self.conv(x, self.index())
 
-        model = self._convert_to_ir(Net())
-        mutators = process_inline_mutation(model)
+        model, mutators = self._get_model_with_mutators(Net())
         self.assertEqual(len(mutators), 1)
         mutator = mutators[0].bind_sampler(EnumerateSampler())
         model1 = mutator.apply(model)
@@ -154,6 +168,7 @@ def forward(self, x):
                          torch.Size([1, 5, 3, 3]))
 
     def test_value_choice_as_parameter(self):
+        @self.get_serializer()
         class Net(nn.Module):
             def __init__(self):
                 super().__init__()
@@ -162,8 +177,7 @@ def __init__(self):
             def forward(self, x):
                 return self.conv(x)
 
-        model = self._convert_to_ir(Net())
-        mutators = process_inline_mutation(model)
+        model, mutators = self._get_model_with_mutators(Net())
         self.assertEqual(len(mutators), 1)
         mutator = mutators[0].bind_sampler(EnumerateSampler())
         model1 = mutator.apply(model)
@@ -174,6 +188,7 @@ def forward(self, x):
                          torch.Size([1, 5, 1, 1]))
 
     def test_value_choice_as_parameter(self):
+        @self.get_serializer()
         class Net(nn.Module):
             def __init__(self):
                 super().__init__()
@@ -182,8 +197,7 @@ def __init__(self):
             def forward(self, x):
                 return self.conv(x)
 
-        model = self._convert_to_ir(Net())
-        mutators = process_inline_mutation(model)
+        model, mutators = self._get_model_with_mutators(Net())
         self.assertEqual(len(mutators), 1)
         mutator = mutators[0].bind_sampler(EnumerateSampler())
         model1 = mutator.apply(model)
@@ -194,6 +208,7 @@ def forward(self, x):
                          torch.Size([1, 5, 1, 1]))
 
     def test_value_choice_as_parameter(self):
+        @self.get_serializer()
         class Net(nn.Module):
             def __init__(self):
                 super().__init__()
@@ -202,8 +217,7 @@ def __init__(self):
             def forward(self, x):
                 return self.conv(x)
 
-        model = self._convert_to_ir(Net())
-        mutators = process_inline_mutation(model)
+        model, mutators = self._get_model_with_mutators(Net())
         self.assertEqual(len(mutators), 2)
         mutators[0].bind_sampler(EnumerateSampler())
         mutators[1].bind_sampler(EnumerateSampler())
@@ -214,6 +228,7 @@ def forward(self, x):
                          torch.Size([1, 8, 1, 1]))
 
     def test_value_choice_as_parameter_shared(self):
+        @self.get_serializer()
         class Net(nn.Module):
             def __init__(self):
                 super().__init__()
@@ -223,8 +238,7 @@ def __init__(self):
             def forward(self, x):
                 return self.conv1(x) + self.conv2(x)
 
-        model = self._convert_to_ir(Net())
-        mutators = process_inline_mutation(model)
+        model, mutators = self._get_model_with_mutators(Net())
         self.assertEqual(len(mutators), 1)
         mutator = mutators[0].bind_sampler(EnumerateSampler())
         model1 = mutator.apply(model)
@@ -235,6 +249,7 @@ def forward(self, x):
                          torch.Size([1, 8, 5, 5]))
 
     def test_value_choice_in_functional(self):
+        @self.get_serializer()
         class Net(nn.Module):
             def __init__(self):
                 super().__init__()
@@ -243,8 +258,7 @@ def __init__(self):
             def forward(self, x):
                 return F.dropout(x, self.dropout_rate())
 
-        model = self._convert_to_ir(Net())
-        mutators = process_inline_mutation(model)
+        model, mutators = self._get_model_with_mutators(Net())
         self.assertEqual(len(mutators), 1)
         mutator = mutators[0].bind_sampler(EnumerateSampler())
         model1 = mutator.apply(model)
@@ -254,6 +268,7 @@ def forward(self, x):
         self.assertAlmostEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0)
 
     def test_value_choice_in_layer_choice(self):
+        @self.get_serializer()
         class Net(nn.Module):
             def __init__(self):
                 super().__init__()
@@ -265,8 +280,7 @@ def __init__(self):
             def forward(self, x):
                 return self.linear(x)
 
-        model = self._convert_to_ir(Net())
-        mutators = process_inline_mutation(model)
+        model, mutators = self._get_model_with_mutators(Net())
         self.assertEqual(len(mutators), 3)
         sz_counter = Counter()
         sampler = RandomSampler()
@@ -278,6 +292,7 @@ def forward(self, x):
         self.assertEqual(len(sz_counter), 4)
 
     def test_shared(self):
+        @self.get_serializer()
         class Net(nn.Module):
             def __init__(self, shared=True):
                 super().__init__()
@@ -294,16 +309,14 @@ def __init__(self, shared=True):
             def forward(self, x):
                 return self.module1(x) + self.module2(x)
 
-        model = self._convert_to_ir(Net())
-        mutators = process_inline_mutation(model)
+        model, mutators = self._get_model_with_mutators(Net())
         self.assertEqual(len(mutators), 1)
         sampler = RandomSampler()
         mutator = mutators[0].bind_sampler(sampler)
         self.assertEqual(self._get_converted_pytorch_model(mutator.apply(model))(torch.randn(1, 3, 3, 3)).size(0), 1)
         self.assertEqual(sampler.counter, 1)
 
-        model = self._convert_to_ir(Net(shared=False))
-        mutators = process_inline_mutation(model)
+        model, mutators = self._get_model_with_mutators(Net(shared=False))
         self.assertEqual(len(mutators), 2)
         sampler = RandomSampler()
         # repeat test. Expectation: sometimes succeeds, sometimes fails.
@@ -321,6 +334,7 @@ def forward(self, x):
         self.assertLess(failed_count, 30)
 
     def test_valuechoice_access(self):
+        @self.get_serializer()
         class Net(nn.Module):
             def __init__(self):
                 super().__init__()
@@ -330,8 +344,7 @@ def __init__(self):
             def forward(self, x):
                 return self.conv(x)
 
-        model = self._convert_to_ir(Net())
-        mutators = process_inline_mutation(model)
+        model, mutators = self._get_model_with_mutators(Net())
         self.assertEqual(len(mutators), 1)
         mutators[0].bind_sampler(EnumerateSampler())
         input = torch.randn(1, 3, 5, 5)
@@ -340,6 +353,7 @@ def forward(self, x):
         self.assertEqual(self._get_converted_pytorch_model(mutators[0].apply(model))(input).size(),
                          torch.Size([1, 8, 1, 1]))
 
+        @self.get_serializer()
         class Net2(nn.Module):
             def __init__(self):
                 super().__init__()
@@ -354,24 +368,23 @@ def forward(self, x):
                 x = self.conv(x)
                 return self.conv1(torch.cat((x, x), 1))
 
-        model = self._convert_to_ir(Net2())
-        mutators = process_inline_mutation(model)
+        model, mutators = self._get_model_with_mutators(Net2())
         self.assertEqual(len(mutators), 1)
         mutators[0].bind_sampler(EnumerateSampler())
         input = torch.randn(1, 3, 5, 5)
         self._get_converted_pytorch_model(mutators[0].apply(model))(input)
 
     def test_valuechoice_access_functional(self):
+        @self.get_serializer()
         class Net(nn.Module):
             def __init__(self):
                 super().__init__()
-                self.dropout_rate = nn.ValueChoice([[0.,], [1.,]])
+                self.dropout_rate = nn.ValueChoice([[0., ], [1., ]])
 
             def forward(self, x):
                 return F.dropout(x, self.dropout_rate()[0])
 
-        model = self._convert_to_ir(Net())
-        mutators = process_inline_mutation(model)
+        model, mutators = self._get_model_with_mutators(Net())
         self.assertEqual(len(mutators), 1)
         mutator = mutators[0].bind_sampler(EnumerateSampler())
         model1 = mutator.apply(model)
@@ -381,18 +394,18 @@ def forward(self, x):
         self.assertAlmostEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0)
 
     def test_valuechoice_access_functional_expression(self):
+        @self.get_serializer()
         class Net(nn.Module):
             def __init__(self):
                 super().__init__()
-                self.dropout_rate = nn.ValueChoice([[1.05,], [1.1,]])
+                self.dropout_rate = nn.ValueChoice([[1.05, ], [1.1, ]])
 
             def forward(self, x):
                 # if expression failed, the exception would be:
                 # ValueError: dropout probability has to be between 0 and 1, but got 1.05
                 return F.dropout(x, self.dropout_rate()[0] - .1)
 
-        model = self._convert_to_ir(Net())
-        mutators = process_inline_mutation(model)
+        model, mutators = self._get_model_with_mutators(Net())
         self.assertEqual(len(mutators), 1)
         mutator = mutators[0].bind_sampler(EnumerateSampler())
         model1 = mutator.apply(model)
@@ -400,3 +413,90 @@ def forward(self, x):
         self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3))
         self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(), torch.Size([1, 3, 3, 3]))
         self.assertAlmostEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0)
+
+    def test_repeat(self):
+        class AddOne(nn.Module):
+            def forward(self, x):
+                return x + 1
+
+        @self.get_serializer()
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.block = nn.Repeat(AddOne(), (3, 5))
+
+            def forward(self, x):
+                return self.block(x)
+
+        model, mutators = self._get_model_with_mutators(Net())
+        self.assertEqual(len(mutators), 1)
+        mutator = mutators[0].bind_sampler(EnumerateSampler())
+        model1 = mutator.apply(model)
+        model2 = mutator.apply(model)
+        model3 = mutator.apply(model)
+        self.assertTrue((self._get_converted_pytorch_model(model1)(torch.zeros(1, 16)) == 3).all())
+        self.assertTrue((self._get_converted_pytorch_model(model2)(torch.zeros(1, 16)) == 4).all())
+        self.assertTrue((self._get_converted_pytorch_model(model3)(torch.zeros(1, 16)) == 5).all())
+
+    def test_cell(self):
+        @self.get_serializer()
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.cell = nn.Cell([nn.Linear(16, 16), nn.Linear(16, 16, bias=False)],
+                                    num_nodes=4, num_ops_per_node=2, num_predecessors=2, merge_op='all')
+
+            def forward(self, x, y):
+                return self.cell([x, y])
+
+        raw_model, mutators = self._get_model_with_mutators(Net())
+        for _ in range(10):
+            sampler = EnumerateSampler()
+            model = raw_model
+            for mutator in mutators:
+                model = mutator.bind_sampler(sampler).apply(model)
+            self.assertTrue(self._get_converted_pytorch_model(model)(
+                torch.randn(1, 16), torch.randn(1, 16)).size() == torch.Size([1, 64]))
+
+        @self.get_serializer()
+        class Net2(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.cell = nn.Cell([nn.Linear(16, 16), nn.Linear(16, 16, bias=False)], num_nodes=4)
+
+            def forward(self, x):
+                return self.cell([x])
+
+        raw_model, mutators = self._get_model_with_mutators(Net2())
+        for _ in range(10):
+            sampler = EnumerateSampler()
+            model = raw_model
+            for mutator in mutators:
+                model = mutator.bind_sampler(sampler).apply(model)
+            self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(1, 16)).size() == torch.Size([1, 64]))
+
+
+class Python(GraphIR):
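+    # re-runs the whole GraphIR suite through the pure-Python mutation path (no torchscript graph conversion)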
+    def _get_converted_pytorch_model(self, model_ir):
+        mutation = {mut.mutator.label: _unpack_if_only_one(mut.samples) for mut in model_ir.history}
+        with ContextStack('fixed', mutation):
+            model = model_ir.python_class(**model_ir.python_init_params)
+            return model
+
+    def _get_model_with_mutators(self, pytorch_model):
+        return extract_mutation_from_pt_module(pytorch_model)
+
+    def get_serializer(self):
+        return model_wrapper
+
+    @unittest.skip
+    def test_value_choice(self): ...
+
+    @unittest.skip
+    def test_value_choice_in_functional(self): ...
+
+    @unittest.skip
+    def test_valuechoice_access_functional(self): ...
+
+    @unittest.skip
+    def test_valuechoice_access_functional_expression(self): ...
diff --git a/test/ut/retiarii/test_mutator.py b/test/ut/retiarii/test_mutator.py
index 0c4cfd404b..a0cd05296d 100644
--- a/test/ut/retiarii/test_mutator.py
+++ b/test/ut/retiarii/test_mutator.py
@@ -60,7 +60,14 @@ def test_mutation():
     model2 = mutator.apply(model1)
     assert _get_pools(model2) == (global_pool, max_pool)
 
-    assert model2.history == [model0, model1]
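+    # history entries are now mutation records carrying the source model (from_), the result (to) and the mutator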
+    assert len(model2.history) == 2
+    assert model2.history[0].from_ == model0
+    assert model2.history[0].to == model1
+    assert model2.history[1].from_ == model1
+    assert model2.history[1].to == model2
+    assert model2.history[0].mutator == mutator
+    assert model2.history[1].mutator == mutator
+
     assert _get_pools(model0) == (max_pool, max_pool)
     assert _get_pools(model1) == (avg_pool, global_pool)
 
diff --git a/test/ut/retiarii/test_strategy.py b/test/ut/retiarii/test_strategy.py
index 0333b4cfed..02fefdd3a9 100644
--- a/test/ut/retiarii/test_strategy.py
+++ b/test/ut/retiarii/test_strategy.py
@@ -1,4 +1,5 @@
 import random
+import sys
 import time
 import threading
 from typing import *
@@ -6,6 +7,7 @@
 import nni.retiarii.execution.api
 import nni.retiarii.nn.pytorch as nn
 import nni.retiarii.strategy as strategy
+import pytest
 import torch
 import torch.nn.functional as F
 from nni.retiarii import Model
@@ -58,7 +60,7 @@ def _reset_execution_engine(engine=None):
 
 
 class Net(nn.Module):
-    def __init__(self, hidden_size=32):
+    def __init__(self, hidden_size=32, diff_size=False):
         super(Net, self).__init__()
         self.conv1 = nn.Conv2d(1, 20, 5, 1)
         self.conv2 = nn.Conv2d(20, 50, 5, 1)
@@ -69,7 +71,7 @@ def __init__(self, hidden_size=32):
         self.fc2 = nn.LayerChoice([
             nn.Linear(hidden_size, 10, bias=False),
             nn.Linear(hidden_size, 10, bias=True)
-        ], label='fc2')
+        ] + ([] if not diff_size else [nn.Linear(hidden_size, 10, bias=False)]), label='fc2')
 
     def forward(self, x):
         x = F.relu(self.conv1(x))
@@ -82,8 +84,8 @@ def forward(self, x):
         return F.log_softmax(x, dim=1)
 
 
-def _get_model_and_mutators():
-    base_model = Net()
+def _get_model_and_mutators(**kwargs):
+    base_model = Net(**kwargs)
     script_module = torch.jit.script(base_model)
     base_model_ir = convert_to_graph(script_module, base_model)
     base_model_ir.evaluator = DebugEvaluator()
@@ -139,7 +141,25 @@ def test_evolution():
     _reset_execution_engine()
 
 
+@pytest.mark.skipif(sys.platform in ('win32', 'darwin'), reason='Does not run on Windows and macOS')
+def test_rl():
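+    # asynchronous run; diff_size=True gives the two LayerChoices different numbers of candidates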
+    rl = strategy.PolicyBasedRL(max_collect=2, trial_per_collect=10)
+    engine = MockExecutionEngine(failure_prob=0.2)
+    _reset_execution_engine(engine)
+    rl.run(*_get_model_and_mutators(diff_size=True))
+    wait_models(*engine.models)
+    _reset_execution_engine()
+
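+    # synchronous variant of the same strategy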
+    rl = strategy.PolicyBasedRL(max_collect=2, trial_per_collect=10, asynchronous=False)
+    engine = MockExecutionEngine(failure_prob=0.2)
+    _reset_execution_engine(engine)
+    rl.run(*_get_model_and_mutators())
+    wait_models(*engine.models)
+    _reset_execution_engine()
+
+
 if __name__ == '__main__':
     test_grid_search()
     test_random_search()
     test_evolution()
+    test_rl()
diff --git a/test/ut/sdk/test_compressor_torch.py b/test/ut/sdk/test_compressor_torch.py
index b7b0b2019e..350e46025e 100644
--- a/test/ut/sdk/test_compressor_torch.py
+++ b/test/ut/sdk/test_compressor_torch.py
@@ -61,9 +61,8 @@ def test_torch_quantizer_modules_detection(self):
 
     def test_torch_level_pruner(self):
         model = TorchModel()
-        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
         configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
-        torch_pruner.LevelPruner(model, configure_list, optimizer).compress()
+        torch_pruner.LevelPruner(model, configure_list).compress()
 
     def test_torch_naive_quantizer(self):
         model = TorchModel()
@@ -93,7 +92,7 @@ def test_torch_fpgm_pruner(self):
 
         model = TorchModel()
         config_list = [{'sparsity': 0.6, 'op_types': ['Conv2d']}, {'sparsity': 0.2, 'op_types': ['Conv2d']}]
-        pruner = torch_pruner.FPGMPruner(model, config_list, torch.optim.SGD(model.parameters(), lr=0.01))
+        pruner = torch_pruner.FPGMPruner(model, config_list)
 
         model.conv2.module.weight.data = torch.tensor(w).float()
         masks = pruner.calc_mask(model.conv2)
@@ -152,7 +151,7 @@ def test_torch_slim_pruner(self):
         config_list = [{'sparsity': 0.2, 'op_types': ['BatchNorm2d']}]
         model.bn1.weight.data = torch.tensor(w).float()
         model.bn2.weight.data = torch.tensor(-w).float()
-        pruner = torch_pruner.SlimPruner(model, config_list)
+        pruner = torch_pruner.SlimPruner(model, config_list, optimizer=None, trainer=None, criterion=None)
 
         mask1 = pruner.calc_mask(model.bn1)
         mask2 = pruner.calc_mask(model.bn2)
@@ -165,7 +164,7 @@ def test_torch_slim_pruner(self):
         config_list = [{'sparsity': 0.6, 'op_types': ['BatchNorm2d']}]
         model.bn1.weight.data = torch.tensor(w).float()
         model.bn2.weight.data = torch.tensor(w).float()
-        pruner = torch_pruner.SlimPruner(model, config_list)
+        pruner = torch_pruner.SlimPruner(model, config_list, optimizer=None, trainer=None, criterion=None)
 
         mask1 = pruner.calc_mask(model.bn1)
         mask2 = pruner.calc_mask(model.bn2)
@@ -202,8 +201,8 @@ def test_torch_taylorFOweight_pruner(self):
 
         model = TorchModel()
         optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
-        pruner = torch_pruner.TaylorFOWeightFilterPruner(model, config_list, optimizer, statistics_batch_num=1)
-        
+        pruner = torch_pruner.TaylorFOWeightFilterPruner(model, config_list, optimizer, trainer=None, criterion=None, sparsity_training_epochs=1)
+
         x = torch.rand((1, 1, 28, 28), requires_grad=True)
         model.conv1.module.weight.data = torch.tensor(w1).float()
         model.conv2.module.weight.data = torch.tensor(w2).float()
@@ -345,7 +344,7 @@ def test_torch_pruner_validation(self):
             ],
             [
                 {'sparsity': 0.2 },
-                {'sparsity': 0.6, 'op_names': 'abc' }
+                {'sparsity': 0.6, 'op_names': 'abc'}
             ]
         ]
         model = TorchModel()
@@ -353,7 +352,13 @@ def test_torch_pruner_validation(self):
         for pruner_class in pruner_classes:
             for config_list in bad_configs:
                 try:
-                    pruner_class(model, config_list, optimizer)
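+                    # newer pruner signatures require trainer/criterion; passing None still exercises config validation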
+                    kwargs = {}
+                    if pruner_class in (torch_pruner.SlimPruner, torch_pruner.AGPPruner, torch_pruner.ActivationMeanRankFilterPruner, torch_pruner.ActivationAPoZRankFilterPruner):
+                        kwargs = {'optimizer': None, 'trainer': None, 'criterion': None}
+
+                    pruner_class(model, config_list, **kwargs)
+
                     print(config_list)
                     assert False, 'Validation error should be raised for bad configuration'
                 except schema.SchemaError:
diff --git a/test/ut/sdk/test_dependecy_aware.py b/test/ut/sdk/test_dependecy_aware.py
index 5918f502bf..5823d1a408 100644
--- a/test/ut/sdk/test_dependecy_aware.py
+++ b/test/ut/sdk/test_dependecy_aware.py
@@ -46,6 +46,24 @@ def generate_random_sparsity_v2(model):
                              'sparsity': sparsity})
     return cfg_list
 
+def train(model, criterion, optimizer, callback=None):
+    model.train()
+    device = next(model.parameters()).device
+    data = torch.randn(2, 3, 224, 224).to(device)
+    target = torch.tensor([0, 1]).long().to(device)
+    optimizer.zero_grad()
+    output = model(data)
+    loss = criterion(output, target)
+    loss.backward()
+
+    # callback should be inserted between loss.backward() and optimizer.step()
+    if callback:
+        callback()
+
+    optimizer.step()
+
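+# adapter exposing the trainer(model, optimizer, criterion, epoch) signature the pruners expect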
+def trainer(model, optimizer, criterion, epoch, callback=None):
+    return train(model, criterion, optimizer, callback=callback)
 
 class DependencyawareTest(TestCase):
     @unittest.skipIf(torch.__version__ < "1.3.0", "not supported")
@@ -55,6 +73,7 @@ def test_dependency_aware_pruning(self):
         sparsity = 0.7
         cfg_list = [{'op_types': ['Conv2d'], 'sparsity':sparsity}]
         dummy_input = torch.ones(1, 3, 224, 224)
+
         for model_name in model_zoo:
             for pruner in pruners:
                 print('Testing on ', pruner)
@@ -72,16 +91,12 @@ def test_dependency_aware_pruning(self):
                                  momentum=0.9,
                                  weight_decay=4e-5)
                 criterion = torch.nn.CrossEntropyLoss()
-                tmp_pruner = pruner(
-                    net, cfg_list, optimizer, dependency_aware=True, dummy_input=dummy_input)
-                # train one single batch so that the the pruner can collect the
-                # statistic
-                optimizer.zero_grad()
-                out = net(dummy_input)
-                batchsize = dummy_input.size(0)
-                loss = criterion(out, torch.zeros(batchsize, dtype=torch.int64))
-                loss.backward()
-                optimizer.step()
+                if pruner == TaylorFOWeightFilterPruner:
+                    tmp_pruner = pruner(
+                        net, cfg_list, optimizer, trainer=trainer, criterion=criterion, dependency_aware=True, dummy_input=dummy_input)
+                else:
+                    tmp_pruner = pruner(
+                        net, cfg_list, dependency_aware=True, dummy_input=dummy_input)
 
                 tmp_pruner.compress()
                 tmp_pruner.export_model(MODEL_FILE, MASK_FILE)
@@ -91,7 +106,7 @@ def test_dependency_aware_pruning(self):
                 ms.speedup_model()
                 for name, module in net.named_modules():
                     if isinstance(module, nn.Conv2d):
-                        expected = int(ori_filters[name] * (1-sparsity))
+                        expected = int(ori_filters[name] * (1 - sparsity))
                         filter_diff = abs(expected - module.out_channels)
                         errmsg = '%s Ori: %d, Expected: %d, Real: %d' % (
                             name, ori_filters[name], expected, module.out_channels)
@@ -124,16 +139,13 @@ def test_dependency_aware_random_config(self):
                                     momentum=0.9,
                                     weight_decay=4e-5)
                     criterion = torch.nn.CrossEntropyLoss()
-                    tmp_pruner = pruner(
-                        net, cfg_list, optimizer, dependency_aware=True, dummy_input=dummy_input)
-                    # train one single batch so that the the pruner can collect the
-                    # statistic
-                    optimizer.zero_grad()
-                    out = net(dummy_input)
-                    batchsize = dummy_input.size(0)
-                    loss = criterion(out, torch.zeros(batchsize, dtype=torch.int64))
-                    loss.backward()
-                    optimizer.step()
+
+                    if pruner in (TaylorFOWeightFilterPruner, ActivationMeanRankFilterPruner, ActivationAPoZRankFilterPruner):
+                        tmp_pruner = pruner(
+                            net, cfg_list, optimizer, trainer=trainer, criterion=criterion, dependency_aware=True, dummy_input=dummy_input)
+                    else:
+                        tmp_pruner = pruner(
+                            net, cfg_list, dependency_aware=True, dummy_input=dummy_input)
 
                     tmp_pruner.compress()
                     tmp_pruner.export_model(MODEL_FILE, MASK_FILE)
diff --git a/test/ut/sdk/test_model_speedup.py b/test/ut/sdk/test_model_speedup.py
index ecbdb89e6d..9ce7a7cba9 100644
--- a/test/ut/sdk/test_model_speedup.py
+++ b/test/ut/sdk/test_model_speedup.py
@@ -17,7 +17,7 @@
 from nni.compression.pytorch import ModelSpeedup, apply_compression_results
 from nni.algorithms.compression.pytorch.pruning import L1FilterPruner
 from nni.algorithms.compression.pytorch.pruning.weight_masker import WeightMasker
-from nni.algorithms.compression.pytorch.pruning.one_shot import _StructuredFilterPruner
+from nni.algorithms.compression.pytorch.pruning.dependency_aware_pruner import DependencyAwarePruner
 
 torch.manual_seed(0)
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -205,7 +205,7 @@ def calc_mask(self, sparsity, wrapper, wrapper_idx=None):
             return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias}
 
 
-class L1ChannelPruner(_StructuredFilterPruner):
+class L1ChannelPruner(DependencyAwarePruner):
     def __init__(self, model, config_list, optimizer=None, dependency_aware=False, dummy_input=None):
         super().__init__(model, config_list, pruning_algorithm='l1', optimizer=optimizer,
                          dependency_aware=dependency_aware, dummy_input=dummy_input)
diff --git a/test/ut/sdk/test_pruners.py b/test/ut/sdk/test_pruners.py
index d9b10d63a9..e6948f2677 100644
--- a/test/ut/sdk/test_pruners.py
+++ b/test/ut/sdk/test_pruners.py
@@ -42,13 +42,10 @@ def validate_sparsity(wrapper, sparsity, bias=False):
     'agp': {
         'pruner_class': AGPPruner,
         'config_list': [{
-            'initial_sparsity': 0.,
-            'final_sparsity': 0.8,
-            'start_epoch': 0,
-            'end_epoch': 10,
-            'frequency': 1,
+            'sparsity': 0.8,
             'op_types': ['Conv2d']
         }],
+        'trainer': lambda model, optimizer, criterion, epoch: model,
         'validators': []
     },
     'slim': {
@@ -57,6 +54,7 @@ def validate_sparsity(wrapper, sparsity, bias=False):
             'sparsity': 0.7,
             'op_types': ['BatchNorm2d']
         }],
+        'trainer': lambda model, optimizer, criterion, epoch: model,
         'validators': [
             lambda model: validate_sparsity(model.bn1, 0.7, model.bias)
         ]
@@ -97,6 +95,7 @@ def validate_sparsity(wrapper, sparsity, bias=False):
             'sparsity': 0.5,
             'op_types': ['Conv2d'],
         }],
+        'trainer': lambda model, optimizer, criterion, epoch: model,
         'validators': [
             lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
         ]
@@ -107,6 +106,7 @@ def validate_sparsity(wrapper, sparsity, bias=False):
             'sparsity': 0.5,
             'op_types': ['Conv2d'],
         }],
+        'trainer': lambda model, optimizer, criterion, epoch: model,
         'validators': [
             lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
         ]
@@ -117,6 +117,7 @@ def validate_sparsity(wrapper, sparsity, bias=False):
             'sparsity': 0.5,
             'op_types': ['Conv2d'],
         }],
+        'trainer': lambda model, optimizer, criterion, epoch: model,
         'validators': [
             lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
         ]
@@ -127,7 +128,7 @@ def validate_sparsity(wrapper, sparsity, bias=False):
             'sparsity': 0.5,
             'op_types': ['Conv2d']
         }],
-        'short_term_fine_tuner': lambda model:model, 
+        'short_term_fine_tuner': lambda model: model,
         'evaluator': lambda model: 0.9,
         'validators': []
     },
@@ -146,7 +147,7 @@ def validate_sparsity(wrapper, sparsity, bias=False):
             'sparsity': 0.5,
             'op_types': ['Conv2d'],
         }],
-        'trainer': lambda model, optimizer, criterion, epoch, callback : model, 
+        'trainer': lambda model, optimizer, criterion, epoch: model,
         'validators': [
             lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
         ]
@@ -158,7 +159,7 @@ def validate_sparsity(wrapper, sparsity, bias=False):
             'op_types': ['Conv2d'],
         }],
         'base_algo': 'l1',
-        'trainer': lambda model, optimizer, criterion, epoch, callback : model,
+        'trainer': lambda model, optimizer, criterion, epoch: model,
         'evaluator': lambda model: 0.9,
         'dummy_input': torch.randn([64, 1, 28, 28]),
         'validators': []
@@ -170,7 +171,7 @@ def validate_sparsity(wrapper, sparsity, bias=False):
             'op_types': ['Conv2d'],
         }],
         'base_algo': 'l2',
-        'trainer': lambda model, optimizer, criterion, epoch, callback : model,
+        'trainer': lambda model, optimizer, criterion, epoch: model,
         'evaluator': lambda model: 0.9,
         'dummy_input': torch.randn([64, 1, 28, 28]),
         'validators': []
@@ -182,7 +183,7 @@ def validate_sparsity(wrapper, sparsity, bias=False):
             'op_types': ['Conv2d'],
         }],
         'base_algo': 'fpgm',
-        'trainer': lambda model, optimizer, criterion, epoch, callback : model,
+        'trainer': lambda model, optimizer, criterion, epoch: model,
         'evaluator': lambda model: 0.9,
         'dummy_input': torch.randn([64, 1, 28, 28]),
         'validators': []
@@ -206,88 +207,87 @@ def __init__(self, bias=True):
     def forward(self, x):
         return self.fc(self.pool(self.bn1(self.conv1(x))).view(x.size(0), -1))
 
+class SimpleDataset:
+    def __getitem__(self, index):
+        return torch.randn(3, 32, 32), 1.
+
+    def __len__(self):
+        return 1000
+
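+# lightweight training loop on random data, just enough for iterative pruners to collect statistics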
+def train(model, train_loader, criterion, optimizer):
+    model.train()
+    device = next(model.parameters()).device
+    x = torch.randn(2, 1, 28, 28).to(device)
+    y = torch.tensor([0, 1]).long().to(device)
+
+    for _ in range(2):
+        out = model(x)
+        loss = criterion(out, y)
+        optimizer.zero_grad()
+        loss.backward()
+
+        optimizer.step()
+
 def pruners_test(pruner_names=['level', 'agp', 'slim', 'fpgm', 'l1', 'l2', 'taylorfo', 'mean_activation', 'apoz', 'netadapt', 'simulatedannealing', 'admm', 'autocompress_l1', 'autocompress_l2', 'autocompress_fpgm',], bias=True):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    dummy_input = torch.randn(2, 1, 28, 28).to(device)
+
+    criterion = torch.nn.CrossEntropyLoss()
+    train_loader = torch.utils.data.DataLoader(SimpleDataset(), batch_size=16, shuffle=False, drop_last=True)
+
+    def trainer(model, optimizer, criterion, epoch):
+        return train(model, train_loader, criterion, optimizer)
+
     for pruner_name in pruner_names:
         print('testing {}...'.format(pruner_name))
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
         model = Model(bias=bias).to(device)
         optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
         config_list = prune_config[pruner_name]['config_list']
 
-        x = torch.randn(2, 1, 28, 28).to(device)
-        y = torch.tensor([0, 1]).long().to(device)
-        out = model(x)
-        loss = F.cross_entropy(out, y)
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
-
         if pruner_name == 'netadapt':
             pruner = prune_config[pruner_name]['pruner_class'](model, config_list, short_term_fine_tuner=prune_config[pruner_name]['short_term_fine_tuner'], evaluator=prune_config[pruner_name]['evaluator'])
         elif pruner_name == 'simulatedannealing':
             pruner = prune_config[pruner_name]['pruner_class'](model, config_list, evaluator=prune_config[pruner_name]['evaluator'])
+        elif pruner_name in ('agp', 'slim', 'taylorfo', 'apoz', 'mean_activation'):
+            pruner = prune_config[pruner_name]['pruner_class'](model, config_list, trainer=trainer, optimizer=optimizer, criterion=criterion)
         elif pruner_name == 'admm':
-            pruner = prune_config[pruner_name]['pruner_class'](model, config_list, trainer=prune_config[pruner_name]['trainer'])
+            pruner = prune_config[pruner_name]['pruner_class'](model, config_list, trainer=trainer)
         elif pruner_name.startswith('autocompress'):
-            pruner = prune_config[pruner_name]['pruner_class'](model, config_list, trainer=prune_config[pruner_name]['trainer'], evaluator=prune_config[pruner_name]['evaluator'], dummy_input=x, base_algo=prune_config[pruner_name]['base_algo'])
+            pruner = prune_config[pruner_name]['pruner_class'](model, config_list, trainer=prune_config[pruner_name]['trainer'], evaluator=prune_config[pruner_name]['evaluator'], criterion=torch.nn.CrossEntropyLoss(), dummy_input=dummy_input, base_algo=prune_config[pruner_name]['base_algo'])
         else:
-            pruner = prune_config[pruner_name]['pruner_class'](model, config_list, optimizer)
-        pruner.compress()
-
-        x = torch.randn(2, 1, 28, 28).to(device)
-        y = torch.tensor([0, 1]).long().to(device)
-        out = model(x)
-        loss = F.cross_entropy(out, y)
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
-
-        if pruner_name == 'taylorfo':
-            # taylorfo algorithm calculate contributions at first iteration(step), and do pruning
-            # when iteration >= statistics_batch_num (default 1)
-            optimizer.step()
+            pruner = prune_config[pruner_name]['pruner_class'](model, config_list)
 
+        pruner.compress()
         pruner.export_model('./model_tmp.pth', './mask_tmp.pth', './onnx_tmp.pth', input_shape=(2,1,28,28), device=device)
 
         for v in prune_config[pruner_name]['validators']:
             v(model)
 
-    
     filePaths = ['./model_tmp.pth', './mask_tmp.pth', './onnx_tmp.pth', './search_history.csv', './search_result.json']
     for f in filePaths:
         if os.path.exists(f):
             os.remove(f)
 
-def _test_agp(pruning_algorithm):
-        model = Model()
-        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
-        config_list = prune_config['agp']['config_list']
 
-        pruner = AGPPruner(model, config_list, optimizer, pruning_algorithm=pruning_algorithm)
-        pruner.compress()
-
-        x = torch.randn(2, 1, 28, 28)
-        y = torch.tensor([0, 1]).long()
+def _test_agp(pruning_algorithm):
+    train_loader = torch.utils.data.DataLoader(SimpleDataset(), batch_size=16, shuffle=False, drop_last=True)
+    model = Model()
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
 
-        for epoch in range(config_list[0]['start_epoch'], config_list[0]['end_epoch']+1):
-            pruner.update_epoch(epoch)
-            out = model(x)
-            loss = F.cross_entropy(out, y)
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
+    def trainer(model, optimizer, criterion, epoch):
+        return train(model, train_loader, criterion, optimizer)
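+    # the closure above adapts the shared train() helper to the
+    # trainer(model, optimizer, criterion, epoch) signature the pruner calls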
 
-            target_sparsity = pruner.compute_target_sparsity(config_list[0])
-            actual_sparsity = (model.conv1.weight_mask == 0).sum().item() / model.conv1.weight_mask.numel()
-            # set abs_tol = 0.2, considering the sparsity error for channel pruning when number of channels is small.
-            assert math.isclose(actual_sparsity, target_sparsity, abs_tol=0.2)
+    config_list = prune_config['agp']['config_list']
+    pruner = AGPPruner(model, config_list, optimizer=optimizer, trainer=trainer, criterion=torch.nn.CrossEntropyLoss(), pruning_algorithm=pruning_algorithm)
+    pruner.compress()
 
-class SimpleDataset:
-    def __getitem__(self, index):
-        return torch.randn(3, 32, 32), 1.
+    target_sparsity = pruner.compute_target_sparsity(config_list[0])
+    actual_sparsity = (model.conv1.weight_mask == 0).sum().item() / model.conv1.weight_mask.numel()
+    # set abs_tol=0.2 to allow for the sparsity error of channel pruning when the number of channels is small
+    assert math.isclose(actual_sparsity, target_sparsity, abs_tol=0.2)
 
-    def __len__(self):
-        return 1000
 
 class PrunerTestCase(TestCase):
     def test_pruners(self):
diff --git a/ts/webui/scripts/start.js b/ts/webui/scripts/start.js
index e6236bd6a2..5daa8c24e7 100644
--- a/ts/webui/scripts/start.js
+++ b/ts/webui/scripts/start.js
@@ -41,7 +41,7 @@ if (!checkRequiredFiles([paths.appHtml, paths.appIndexJs])) {
 }
 
 // Tools like Cloud9 rely on this.
-const DEFAULT_PORT = parseInt(process.env.PORT, 10) || 3000;
+const DEFAULT_PORT = parseInt(process.env.PORT, 10) || 8000;
 const HOST = process.env.HOST || '0.0.0.0';
 
 if (process.env.HOST) {
diff --git a/ts/webui/src/App.tsx b/ts/webui/src/App.tsx
index fde6d54de9..ce0ced6d3b 100644
--- a/ts/webui/src/App.tsx
+++ b/ts/webui/src/App.tsx
@@ -48,6 +48,8 @@ export const AppContext = React.createContext({
     // eslint-disable-next-line @typescript-eslint/no-empty-function
     updateOverviewPage: () => {},
     // eslint-disable-next-line @typescript-eslint/no-empty-function
+    updateDetailPage: () => {},
+    // eslint-disable-next-line @typescript-eslint/no-empty-function
     changeExpandRowIDs: (_val: string, _type?: string): void => {}
 });
 
@@ -133,6 +135,12 @@ class App extends React.Component<{}, AppState> {
         }));
     };
 
+    updateDetailPage = (): void => {
+        this.setState(state => ({
+            trialsUpdateBroadcast: state.trialsUpdateBroadcast + 1
+        }));
+    };
+
     shouldComponentUpdate(nextProps: any, nextState: AppState): boolean {
         if (!(nextState.isUpdate || nextState.isUpdate === undefined)) {
             nextState.isUpdate = true;
@@ -207,6 +215,7 @@ class App extends React.Component<{}, AppState> {
                                         bestTrialEntries,
                                         changeEntries: this.changeEntries,
                                         updateOverviewPage: this.updateOverviewPage,
+                                        updateDetailPage: this.updateDetailPage,
                                         expandRowIDs,
                                         changeExpandRowIDs: this.changeExpandRowIDs
                                     }}
diff --git a/ts/webui/src/components/TrialsDetail.tsx b/ts/webui/src/components/TrialsDetail.tsx
index 4d8d119d72..2665bbe662 100644
--- a/ts/webui/src/components/TrialsDetail.tsx
+++ b/ts/webui/src/components/TrialsDetail.tsx
@@ -83,10 +83,7 @@ class TrialsDetail extends React.Component<{}, TrialDetailState> {
                         </div>
                         {/* trial table list */}
                         <div className='detailTable' style={{ marginTop: 10 }}>
-                            <TableList
-                                tableSource={source}
-                                trialsUpdateBroadcast={this.context.trialsUpdateBroadcast}
-                            />
+                            <TableList tableSource={source} updateDetailPage={this.context.updateDetailPage} />
                         </div>
                     </React.Fragment>
                 )}
diff --git a/ts/webui/src/components/modals/ChildrenGap.ts b/ts/webui/src/components/modals/ChildrenGap.ts
new file mode 100644
index 0000000000..2471bbe58b
--- /dev/null
+++ b/ts/webui/src/components/modals/ChildrenGap.ts
@@ -0,0 +1,7 @@
+import { IStackTokens } from '@fluentui/react';
+
+const searchConditionsGap: IStackTokens = {
+    childrenGap: 10
+};
+
+export { searchConditionsGap };
diff --git a/ts/webui/src/components/trial-detail/TableList.tsx b/ts/webui/src/components/trial-detail/TableList.tsx
index a4405e63c4..e25da074fb 100644
--- a/ts/webui/src/components/trial-detail/TableList.tsx
+++ b/ts/webui/src/components/trial-detail/TableList.tsx
@@ -1,10 +1,8 @@
 import React from 'react';
 import {
     DefaultButton,
-    Dropdown,
     IColumn,
     Icon,
-    IDropdownOption,
     PrimaryButton,
     Stack,
     StackItem,
@@ -14,13 +12,15 @@ import {
 } from '@fluentui/react';
 import { EXPERIMENT, TRIALS } from '../../static/datamodel';
 import { TOOLTIP_BACKGROUND_COLOR } from '../../static/const';
-import { convertDuration, formatTimestamp, copyAndSort } from '../../static/function';
-import { TableObj, SortInfo } from '../../static/interface';
+import { convertDuration, formatTimestamp, copyAndSort, parametersType } from '../../static/function';
+import { TableObj, SortInfo, SearchItems } from '../../static/interface';
+import { getTrialsBySearchFilters } from './search/searchFunction';
 import { blocked, copy, LineChart, tableListIcon } from '../buttons/Icon';
 import ChangeColumnComponent from '../modals/ChangeColumnComponent';
 import Compare from '../modals/Compare';
 import Customize from '../modals/CustomizedTrial';
 import TensorboardUI from '../modals/tensorboard/TensorboardUI';
+import Search from './search/Search';
 import KillJob from '../modals/Killjob';
 import ExpandableDetails from '../public-child/ExpandableDetails';
 import PaginationTable from '../public-child/PaginationTable';
@@ -41,12 +41,6 @@ require('echarts/lib/component/tooltip');
 require('echarts/lib/component/title');
 
 type SearchOptionType = 'id' | 'trialnum' | 'status' | 'parameters';
-const searchOptionLiterals = {
-    id: 'ID',
-    trialnum: 'Trial No.',
-    status: 'Status',
-    parameters: 'Parameters'
-};
 
 const defaultDisplayedColumns = ['sequenceId', 'id', 'duration', 'status', 'latestAccuracy'];
 
@@ -76,7 +70,7 @@ function _inferColumnTitle(columnKey: string): string {
 
 interface TableListProps {
     tableSource: TableObj[];
-    trialsUpdateBroadcast: number;
+    updateDetailPage: () => void;
 }
 
 interface TableListState {
@@ -91,6 +85,8 @@ interface TableListState {
     intermediateDialogTrial: TableObj | undefined;
     copiedTrialId: string | undefined;
     sortInfo: SortInfo;
+    searchItems: Array<SearchItems>;
+    relation: Map<string, string>;
 }
 
 class TableList extends React.Component<TableListProps, TableListState> {
@@ -114,47 +110,14 @@ class TableList extends React.Component<TableListProps, TableListState> {
             selectedRowIds: [],
             intermediateDialogTrial: undefined,
             copiedTrialId: undefined,
-            sortInfo: { field: '', isDescend: true }
+            sortInfo: { field: '', isDescend: true },
+            searchItems: [],
+            relation: parametersType()
         };
 
         this._expandedTrialIds = new Set<string>();
     }
 
-    /* Search related methods */
-
-    // This functions as the filter for the final trials displayed in the current table
-    private _filterTrials(trials: TableObj[]): TableObj[] {
-        const { searchText, searchType } = this.state;
-        // search a trial by Trial No. | Trial ID | Parameters | Status
-        let searchFilter = (_: TableObj): boolean => true; // eslint-disable-line no-unused-vars
-        if (searchText.trim()) {
-            if (searchType === 'id') {
-                searchFilter = (trial): boolean => trial.id.toUpperCase().includes(searchText.toUpperCase());
-            } else if (searchType === 'trialnum') {
-                searchFilter = (trial): boolean => trial.sequenceId.toString() === searchText;
-            } else if (searchType === 'status') {
-                searchFilter = (trial): boolean => trial.status.toUpperCase().includes(searchText.toUpperCase());
-            } else if (searchType === 'parameters') {
-                // TODO: support filters like `x: 2` (instead of `'x': 2`)
-                searchFilter = (trial): boolean => JSON.stringify(trial.description.parameters).includes(searchText);
-            }
-        }
-        return trials.filter(searchFilter);
-    }
-
-    private _updateSearchFilterType(_event: React.FormEvent<HTMLDivElement>, item: IDropdownOption | undefined): void {
-        if (item !== undefined) {
-            const value = item.key.toString();
-            if (searchOptionLiterals.hasOwnProperty(value)) {
-                this.setState({ searchType: value as SearchOptionType }, this._updateTableSource);
-            }
-        }
-    }
-
-    private _updateSearchText(ev: React.ChangeEvent<HTMLInputElement>): void {
-        this.setState({ searchText: ev.target.value }, this._updateTableSource);
-    }
-
     /* Table basic function related methods */
 
     private _onColumnClick(ev: React.MouseEvent<HTMLElement>, column: IColumn): void {
@@ -180,7 +143,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
             const ret = {
                 sequenceId: trial.sequenceId,
                 id: trial.id,
-                checked: selectedRowIds.includes(trial.id) ? true : false,
+                _checked: selectedRowIds.includes(trial.id) ? true : false,
                 startTime: (trial as Trial).info.startTime, // FIXME: why do we need info here?
                 endTime: (trial as Trial).info.endTime,
                 duration: trial.duration,
@@ -221,7 +184,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
         }
         items.forEach(item => {
             if (item.id === id) {
-                item.checked = !!checked;
+                item._checked = !!checked;
             }
         });
         this.setState(() => ({ displayedItems: items, selectedRowIds: temp }));
@@ -231,7 +194,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
         const { displayedItems } = this.state;
         const newDisplayedItems = displayedItems;
         newDisplayedItems.forEach(item => {
-            item.checked = false;
+            item._checked = false;
         });
         this.setState(() => ({
             selectedRowIds: [],
@@ -253,7 +216,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
                 onRender: (record): React.ReactNode => (
                     <Checkbox
                         label={undefined}
-                        checked={record.checked}
+                        checked={record._checked}
                         className='detail-check'
                         onChange={this.selectedTrialOnChangeEvent.bind(this, record.id)}
                     />
@@ -438,7 +401,11 @@ class TableList extends React.Component<TableListProps, TableListState> {
 
     private _updateTableSource(): void {
         // call this method when trials or the computation of trial filter has changed
-        const items = this._trialsToTableItems(this._filterTrials(this.props.tableSource));
+        const { searchItems, relation } = this.state;
+        let items = this._trialsToTableItems(this.props.tableSource);
+        if (searchItems.length > 0) {
+            items = getTrialsBySearchFilters(items, searchItems, relation); // use search filter to filter data
+        }
         if (items.length > 0) {
             const columns = this._buildColumnsFromTableItems(items);
             this.setState({
@@ -496,6 +463,12 @@ class TableList extends React.Component<TableListProps, TableListState> {
         );
     }
 
+    private changeSearchFilterList = (arr: Array<SearchItems>): void => {
+        this.setState(() => ({
+            searchItems: arr
+        }));
+    };
+
     componentDidUpdate(prevProps: TableListProps): void {
         if (this.props.tableSource !== prevProps.tableSource) {
             this._updateTableSource();
@@ -510,13 +483,13 @@ class TableList extends React.Component<TableListProps, TableListState> {
         const {
             displayedItems,
             columns,
-            searchType,
             customizeColumnsDialogVisible,
             compareDialogVisible,
             displayedColumns,
             selectedRowIds,
             intermediateDialogTrial,
-            copiedTrialId
+            copiedTrialId,
+            searchItems
         } = this.state;
 
         return (
@@ -526,7 +499,24 @@ class TableList extends React.Component<TableListProps, TableListState> {
                     <span>Trial jobs</span>
                 </Stack>
                 <Stack horizontal className='allList'>
-                    <StackItem grow={50}>
+                    <StackItem>
+                        <Stack horizontal horizontalAlign='end' className='allList'>
+                            <Search
+                                searchFilter={searchItems} // search filter list
+                                changeSearchFilterList={this.changeSearchFilterList}
+                                updatePage={this.props.updateDetailPage}
+                            />
+                        </Stack>
+                    </StackItem>
+
+                    <StackItem styles={{ root: { position: 'absolute', right: '0' } }}>
+                        <DefaultButton
+                            className='allList-button-gap'
+                            text='Add/Remove columns'
+                            onClick={(): void => {
+                                this.setState({ customizeColumnsDialogVisible: true });
+                            }}
+                        />
                         <DefaultButton
                             text='Compare'
                             className='allList-compare'
@@ -540,37 +530,6 @@ class TableList extends React.Component<TableListProps, TableListState> {
                             changeSelectTrialIds={this.changeSelectTrialIds}
                         />
                     </StackItem>
-                    <StackItem grow={50}>
-                        <Stack horizontal horizontalAlign='end' className='allList'>
-                            <DefaultButton
-                                className='allList-button-gap'
-                                text='Add/Remove columns'
-                                onClick={(): void => {
-                                    this.setState({ customizeColumnsDialogVisible: true });
-                                }}
-                            />
-                            <Dropdown
-                                selectedKey={searchType}
-                                options={Object.entries(searchOptionLiterals).map(([k, v]) => ({
-                                    key: k,
-                                    text: v
-                                }))}
-                                onChange={this._updateSearchFilterType.bind(this)}
-                                styles={{ root: { width: 150 } }}
-                            />
-                            <input
-                                type='text'
-                                className='allList-search-input'
-                                placeholder={`Search by ${
-                                    ['id', 'trialnum'].includes(searchType)
-                                        ? searchOptionLiterals[searchType]
-                                        : searchType
-                                }`}
-                                onChange={this._updateSearchText.bind(this)}
-                                style={{ width: 230 }}
-                            />
-                        </Stack>
-                    </StackItem>
                 </Stack>
                 {columns && displayedItems && (
                     <PaginationTable
diff --git a/ts/webui/src/components/trial-detail/search/GeneralSearch.tsx b/ts/webui/src/components/trial-detail/search/GeneralSearch.tsx
new file mode 100644
index 0000000000..86501ab20c
--- /dev/null
+++ b/ts/webui/src/components/trial-detail/search/GeneralSearch.tsx
@@ -0,0 +1,82 @@
+import React, { useState } from 'react';
+import PropTypes from 'prop-types';
+import { Stack, PrimaryButton } from '@fluentui/react';
+import { searchConditionsGap } from '../../modals/ChildrenGap';
+import { getSearchInputValueBySearchList } from './searchFunction';
+
+// This component searches trials by 'Trial id' or 'Trial No.'
+
+function GeneralSearch(props): any {
+    // searchName is either 'Trial id' or 'Trial No.'
+    const { searchName, searchFilter, dismiss, changeSearchFilterList, setSearchInputVal, updatePage } = props;
+    const [firstInputVal, setFirstInputVal] = useState(getSearchNameInit());
+
+    function updateFirstInputVal(ev: React.ChangeEvent<HTMLInputElement>): void {
+        setFirstInputVal(ev.target.value);
+    }
+
+    function getSearchNameInit(): string {
+        let str = ''; // default: empty input value
+        const find = searchFilter.find(item => item.name === searchName);
+
+        if (find !== undefined) {
+            str = find.value1; // initialize from the existing filter value
+        }
+
+        return str;
+    }
+
+    function startFilterTrial(): void {
+        const { searchFilter } = props;
+        const searchFilterConditions = JSON.parse(JSON.stringify(searchFilter));
+        const find = searchFilterConditions.filter(item => item.name === searchName);
+
+        if (firstInputVal === '') {
+            alert('Please input a search value!');
+            return;
+        }
+
+        if (find.length > 0) {
+            // update the existing record;
+            // Trial id and Trial No. filters only need the search name and search value
+            searchFilterConditions.forEach(item => {
+                if (item.name === searchName) {
+                    item.value1 = firstInputVal;
+                    // item.operator = '';
+                    item.isChoice = false;
+                }
+            });
+        } else {
+            searchFilterConditions.push({
+                name: searchName,
+                // operator: '',
+                value1: firstInputVal,
+                isChoice: false
+            });
+        }
+        setSearchInputVal(getSearchInputValueBySearchList(searchFilterConditions));
+        changeSearchFilterList(searchFilterConditions);
+        updatePage();
+        dismiss(); // close menu
+    }
+
+    return (
+        // Trial id & Trial No.
+        <Stack horizontal className='filterConditions' tokens={searchConditionsGap}>
+            <span>{searchName === 'Trial id' ? 'Includes' : 'Equal to'}</span>
+            <input type='text' className='input input-padding' onChange={updateFirstInputVal} value={firstInputVal} />
+            <PrimaryButton text='Apply' className='btn-vertical-middle' onClick={startFilterTrial} />
+        </Stack>
+    );
+}
+
+GeneralSearch.propTypes = {
+    searchName: PropTypes.string,
+    searchFilter: PropTypes.array,
+    dismiss: PropTypes.func,
+    setSearchInputVal: PropTypes.func,
+    changeSearchFilterList: PropTypes.func,
+    updatePage: PropTypes.func
+};
+
+export default GeneralSearch;
diff --git a/ts/webui/src/components/trial-detail/search/Search.tsx b/ts/webui/src/components/trial-detail/search/Search.tsx
new file mode 100644
index 0000000000..ab39ee919b
--- /dev/null
+++ b/ts/webui/src/components/trial-detail/search/Search.tsx
@@ -0,0 +1,259 @@
+import React, { useState } from 'react';
+import PropTypes from 'prop-types';
+import {
+    Stack,
+    DefaultButton,
+    IContextualMenuProps,
+    IContextualMenuItem,
+    DirectionalHint,
+    SearchBox
+} from '@fluentui/react';
+import { EXPERIMENT } from '../../../static/datamodel';
+import { SearchItems } from '../../../static/interface';
+import SearchParameterConditions from './SearchParameterConditions';
+import GeneralSearch from './GeneralSearch';
+import { classNames, isChoiceType } from './searchFunction';
+
+// TableList search layout
+
+function Search(props): any {
+    const { searchFilter, changeSearchFilterList, updatePage } = props;
+    const [searchInputVal, setSearchInputVal] = useState('');
+
+    function getSearchMenu(parameterList): IContextualMenuProps {
+        const menu: Array<object> = [];
+
+        parameterList.unshift('StatusNNI');
+
+        ['Trial id', 'Trial No.'].forEach(item => {
+            menu.push({
+                key: item,
+                text: item,
+                subMenuProps: {
+                    items: [
+                        {
+                            key: item,
+                            text: item,
+                            // component: GeneralSearch.tsx
+                            onRender: renderIdAndNoComponent
+                        }
+                    ]
+                }
+            });
+        });
+
+        parameterList.forEach(item => {
+            menu.push({
+                key: item,
+                text: item === 'StatusNNI' ? 'Status' : item,
+                subMenuProps: {
+                    items: [
+                        {
+                            key: item,
+                            text: item,
+                            // component: SearchParameterConditions.tsx
+                            onRender: renderParametersSearchComponent
+                        }
+                    ]
+                }
+            });
+        });
+
+        const filterMenu: IContextualMenuProps = {
+            shouldFocusOnMount: true,
+            directionalHint: DirectionalHint.bottomLeftEdge,
+            className: classNames.menu,
+            items: menu as any
+        };
+
+        return filterMenu;
+    }
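+
+    // the Filter menu: 'Trial id' and 'Trial No.' entries render GeneralSearch,
+    // while status and choice-type parameters render SearchParameterConditions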
+
+    // Nested experiments do not support hyperparameter search, so pass an empty parameter list for them
+    const searchMenuProps: IContextualMenuProps = getSearchMenu(
+        EXPERIMENT.isNestedExp() ? [] : Object.keys(EXPERIMENT.searchSpace)
+    );
+
+    function renderParametersSearchComponent(item: IContextualMenuItem, dismissMenu: () => void): JSX.Element {
+        return (
+            <SearchParameterConditions
+                parameter={item.text}
+                searchFilter={searchFilter} // search filter list
+                changeSearchFilterList={changeSearchFilterList}
+                updatePage={updatePage}
+                setSearchInputVal={setSearchInputVal}
+                dismiss={dismissMenu} // close menu
+            />
+        );
+    }
+
+    function renderIdAndNoComponent(item: IContextualMenuItem, dismissMenu: () => void): JSX.Element {
+        return (
+            <GeneralSearch
+                searchName={item.text}
+                searchFilter={searchFilter} // search filter list
+                changeSearchFilterList={changeSearchFilterList}
+                setSearchInputVal={setSearchInputVal}
+                updatePage={updatePage}
+                dismiss={dismissMenu} // close the menu after the Apply button is clicked
+            />
+        );
+    }
+
+    function updateSearchText(_, newValue): void {
+        setSearchInputVal(newValue);
+    }
+
+    // update TableList page
+    function changeTableListPage(searchFilterList: Array<SearchItems>): void {
+        changeSearchFilterList(searchFilterList);
+        updatePage();
+    }
+
+    // "[hello, world]", JSON.parse(it) doesn't work so write this function
+    function convertStringArrToList(str: string): string[] {
+        const value = str.slice(1, str.length - 1); // delete []
+        // delete ""
+        const result: string[] = [];
+
+        if (value.includes(',')) {
+            const arr = value.split(',');
+            arr.forEach(item => {
+                if (item !== '') {
+                    result.push(item);
+                }
+            });
+            return result;
+        } else {
+            if (value === '') {
+                return result;
+            } else {
+                return [value];
+            }
+        }
+    }
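+    // e.g. convertStringArrToList('[cat,dog]') returns ['cat', 'dog'],
+    // and convertStringArrToList('[]') returns []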
+
+    // SearchBox onSearch event: Filter based on the filter criteria entered by the user
+    function startFilter(): void {
+        // reject queries that contain special characters (English or Chinese punctuation)
+        const regEn = /[`~!@#$%^&*()+?"{}.']/im;
+        const regCn = /[·!#¥(——):;“”‘、,|《。》?、【】[\]]/im;
+        if (regEn.test(searchInputVal) || regCn.test(searchInputVal)) {
+            alert('Please remove the special characters from the conditions!');
+            return;
+        }
+        // rebuild the searchFilter list from the input value
+        const allFilterConditions = searchInputVal.trim().split(';');
+        const newSearchFilter: any = [];
+
+        // drop the empty entry produced by a trailing ';'
+        if (allFilterConditions.includes('')) {
+            allFilterConditions.splice(
+                allFilterConditions.findIndex(item => item === ''),
+                1
+            );
+        }
+
+        allFilterConditions.forEach(eachFilterConditionStr => {
+            let eachFilterConditionArr: string[] = [];
+
+            // numeric comparison, e.g. drop_rate>0.5
+            if (eachFilterConditionStr.includes('>') || eachFilterConditionStr.includes('<')) {
+                const operator = eachFilterConditionStr.includes('>') ? '>' : '<';
+                eachFilterConditionArr = eachFilterConditionStr.trim().split(operator);
+                newSearchFilter.push({
+                    name: eachFilterConditionArr[0],
+                    operator: operator,
+                    value1: eachFilterConditionArr[1],
+                    value2: '',
+                    choice: [],
+                    isChoice: false
+                });
+            } else if (eachFilterConditionStr.includes('≠')) {
+                // drop_rate≠6; status≠[x,xx,xxx]; conv_size≠[3,7]
+                eachFilterConditionArr = eachFilterConditionStr.trim().split('≠');
+                const filterName = eachFilterConditionArr[0] === 'Status' ? 'StatusNNI' : eachFilterConditionArr[0];
+                const isChoicesType = isChoiceType(filterName);
+                newSearchFilter.push({
+                    name: filterName,
+                    operator: '≠',
+                    value1: isChoicesType ? '' : JSON.parse(eachFilterConditionArr[1]),
+                    value2: '',
+                    choice: isChoicesType ? convertStringArrToList(eachFilterConditionArr[1]) : [],
+                    isChoice: isChoicesType ? true : false
+                });
+            } else {
+                // '=' cases: conv_size:[1,2,3,4]; Trial id:3; hidden_size:[1,2]; status:[val1,val2,val3]
+                eachFilterConditionArr = eachFilterConditionStr.trim().split(':');
+                const filterName = eachFilterConditionArr[0] === 'Status' ? 'StatusNNI' : eachFilterConditionArr[0];
+                const isChoicesType = isChoiceType(filterName);
+                const isArray =
+                    eachFilterConditionArr.length > 1 &&
+                    (eachFilterConditionArr[1].includes('[') || eachFilterConditionArr[1].includes(']'));
+                if (isArray === true) {
+                    if (isChoicesType === true) {
+                        // status:[SUCCEEDED]
+                        newSearchFilter.push({
+                            name: filterName,
+                            operator: '=',
+                            value1: '',
+                            value2: '',
+                            choice: convertStringArrToList(eachFilterConditionArr[1]),
+                            isChoice: true
+                        });
+                    } else {
+                        // drop_rate:[1,10]
+                        newSearchFilter.push({
+                            name: eachFilterConditionArr[0],
+                            operator: 'between',
+                            value1: JSON.parse(eachFilterConditionArr[1])[0],
+                            value2: JSON.parse(eachFilterConditionArr[1])[1],
+                            choice: [],
+                            isChoice: false
+                        });
+                    }
+                } else {
+                    newSearchFilter.push({
+                        name: eachFilterConditionArr[0],
+                        operator: '=',
+                        value1: eachFilterConditionArr[1],
+                        value2: '',
+                        choice: [],
+                        isChoice: false
+                    });
+                }
+            }
+        });
+
+        changeTableListPage(newSearchFilter);
+    }
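+    // example queries this parser accepts (assuming a search space with a numeric
+    // drop_rate and a choice-type conv_size):
+    //   Trial id:ABC; drop_rate>0.5; conv_size:[2,3]; Status≠[FAILED]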
+
+    // clear the search input and all search filters
+    function clearFilter(): void {
+        changeTableListPage([]);
+    }
+
+    return (
+        <div>
+            <Stack horizontal>
+                <DefaultButton text='Filter' menuProps={searchMenuProps} />
+                {/* the search box shows the current filter conditions; users may also type conditions and search directly */}
+                <SearchBox
+                    styles={{ root: { width: 530 } }}
+                    placeholder='Search'
+                    onChange={updateSearchText}
+                    value={searchInputVal}
+                    onSearch={startFilter}
+                    onEscape={clearFilter}
+                    onClear={clearFilter}
+                />
+            </Stack>
+        </div>
+    );
+}
+
+Search.propTypes = {
+    searchFilter: PropTypes.array,
+    changeSearchFilterList: PropTypes.func,
+    updatePage: PropTypes.func
+};
+
+export default Search;
diff --git a/ts/webui/src/components/trial-detail/search/SearchParameterConditions.tsx b/ts/webui/src/components/trial-detail/search/SearchParameterConditions.tsx
new file mode 100644
index 0000000000..7d3b946507
--- /dev/null
+++ b/ts/webui/src/components/trial-detail/search/SearchParameterConditions.tsx
@@ -0,0 +1,197 @@
+import React, { useState } from 'react';
+import PropTypes from 'prop-types';
+import { Stack, PrimaryButton, Dropdown, IDropdownOption } from '@fluentui/react';
+import { EXPERIMENT } from '../../../static/datamodel';
+import { getDropdownOptions, getSearchInputValueBySearchList } from './searchFunction';
+import { searchConditionsGap } from '../../modals/ChildrenGap';
+
+// This component builds filter conditions on trial parameters and trial status
+
+function SearchParameterConditions(props): any {
+    const { parameter, searchFilter, dismiss, changeSearchFilterList, updatePage, setSearchInputVal } = props;
+    const isChoiceTypeSearchFilter = parameter === 'StatusNNI' || EXPERIMENT.searchSpace[parameter]._type === 'choice';
+    const operatorList = isChoiceTypeSearchFilter ? ['=', '≠'] : ['between', '>', '<', '=', '≠'];
+
+    const initValueList = getInitVal();
+    const [operatorVal, setOperatorVal] = useState(initValueList[0]);
+    const [firstInputVal, setFirstInputVal] = useState(initValueList[1] as string);
+    const [secondInputVal, setSecondInputVal] = useState(initValueList[2] as string);
+    // values selected in the status / choice-parameter dropdown
+    const [choiceList, setChoiceList] = useState(initValueList[3] as string[]);
+
+    function getInitVal(): Array<string | string[]> {
+        // push value: operator, firstInputVal(value1), secondInputVal(value2), choiceValue
+        const str: Array<string | string[]> = [];
+
+        if (searchFilter.length > 0) {
+            const filterElement = searchFilter.find(ele => ele.name === parameter);
+            if (filterElement !== undefined) {
+                str.push(
+                    filterElement.operator,
+                    filterElement.value1.toString(),
+                    filterElement.value2.toString(),
+                    filterElement.choice.toString().split(',')
+                );
+            } else {
+                // set init value
+                str.push(`${isChoiceTypeSearchFilter ? '=' : 'between'}`, '', '', [] as string[]);
+            }
+        } else {
+            str.push(`${isChoiceTypeSearchFilter ? '=' : 'between'}`, '', '', [] as string[]);
+        }
+
+        return str;
+    }
+
+    function updateOperatorDropdown(_event: React.FormEvent<HTMLDivElement>, item: IDropdownOption | undefined): void {
+        if (item !== undefined) {
+            setOperatorVal(item.key.toString());
+        }
+    }
+
+    // update the selected list for status or choice-type parameters
+    function updateChoiceDropdown(_event: React.FormEvent<HTMLDivElement>, item: IDropdownOption | undefined): void {
+        if (item !== undefined) {
+            const result = item.selected
+                ? [...choiceList, item.key as string]
+                : choiceList.filter(key => key !== item.key);
+            setChoiceList(result);
+        }
+    }
+
+    function updateFirstInputVal(ev: React.ChangeEvent<HTMLInputElement>): void {
+        setFirstInputVal(ev.target.value);
+    }
+
+    function updateSecondInputVal(ev: React.ChangeEvent<HTMLInputElement>): void {
+        setSecondInputVal(ev.target.value);
+    }
+
+    function getSecondInputVal(): string {
+        if (secondInputVal === '' && operatorVal === 'between') {
+            // if the user picks the 'between' operator but leaves the second input empty,
+            // default the second value to the parameter's maximum
+            return EXPERIMENT.searchSpace[parameter]._value[1].toString();
+        }
+
+        return secondInputVal as string;
+    }
+
+    // click Apply button
+    function startFilterTrials(): void {
+        if (isChoiceTypeSearchFilter === false) {
+            if (firstInputVal === '') {
+                alert('Please input a search value!');
+                return;
+            }
+        }
+
+        if (firstInputVal.match(/[a-zA-Z]/) || secondInputVal.match(/[a-zA-Z]/)) {
+            alert('Please input a number!');
+            return;
+        }
+
+        let newSearchFilters = JSON.parse(JSON.stringify(searchFilter));
+        const find = newSearchFilters.filter(ele => ele.name === parameter);
+
+        if (find.length > 0) {
+            // if the user clears all selected options, remove this filter condition from the searchFilter list,
+            // e.g. when conv_size's choiceList becomes [], the (name === 'conv_size') entry is dropped
+            if ((isChoiceTypeSearchFilter && choiceList.length !== 0) || isChoiceTypeSearchFilter === false) {
+                newSearchFilters.forEach(item => {
+                    if (item.name === parameter) {
+                        item.operator = operatorVal;
+                        item.value1 = firstInputVal;
+                        item.value2 = getSecondInputVal();
+                        item.choice = choiceList;
+                        item.isChoice = isChoiceTypeSearchFilter ? true : false;
+                    }
+                });
+            } else {
+                newSearchFilters = newSearchFilters.filter(item => item.name !== parameter);
+            }
+        } else {
+            if ((isChoiceTypeSearchFilter && choiceList.length !== 0) || isChoiceTypeSearchFilter === false) {
+                newSearchFilters.push({
+                    name: parameter,
+                    operator: operatorVal,
+                    value1: firstInputVal,
+                    value2: getSecondInputVal(),
+                    choice: choiceList,
+                    isChoice: isChoiceTypeSearchFilter ? true : false
+                });
+            }
+        }
+
+        setSearchInputVal(getSearchInputValueBySearchList(newSearchFilters));
+        changeSearchFilterList(newSearchFilters);
+        updatePage();
+        dismiss(); // close menu
+    }
+
+    return (
+        // for trial parameters & Status
+        <Stack horizontal className='filterConditions' tokens={searchConditionsGap}>
+            <Dropdown
+                selectedKey={operatorVal}
+                options={operatorList.map(item => ({
+                    key: item,
+                    text: item
+                }))}
+                onChange={updateOperatorDropdown}
+                className='btn-vertical-middle'
+                styles={{ root: { width: 100 } }}
+            />
+            {isChoiceTypeSearchFilter ? (
+                <Dropdown
+                    // selectedKeys (string[]) enables multi-select; selectedKey (string) is single-select
+                    selectedKeys={choiceList}
+                    multiSelect
+                    options={getDropdownOptions(parameter)}
+                    onChange={updateChoiceDropdown}
+                    className='btn-vertical-middle'
+                    styles={{ root: { width: 190 } }}
+                />
+            ) : (
+                <React.Fragment>
+                    {operatorVal === 'between' ? (
+                        <div>
+                            <input
+                                type='text'
+                                className='input input-padding'
+                                onChange={updateFirstInputVal}
+                                value={firstInputVal}
+                            />
+                            <span className='and'>and</span>
+                            <input
+                                type='text'
+                                className='input input-padding'
+                                onChange={updateSecondInputVal}
+                                value={secondInputVal}
+                            />
+                        </div>
+                    ) : (
+                        <input
+                            type='text'
+                            className='input input-padding'
+                            onChange={updateFirstInputVal}
+                            value={firstInputVal}
+                        />
+                    )}
+                </React.Fragment>
+            )}
+            <PrimaryButton text='Apply' className='btn-vertical-middle' onClick={startFilterTrials} />
+        </Stack>
+    );
+}
+
+SearchParameterConditions.propTypes = {
+    parameter: PropTypes.string,
+    searchFilter: PropTypes.array,
+    dismiss: PropTypes.func,
+    setSearchInputVal: PropTypes.func,
+    changeSearchFilterList: PropTypes.func,
+    updatePage: PropTypes.func
+};
+
+export default SearchParameterConditions;
diff --git a/ts/webui/src/components/trial-detail/search/searchFunction.ts b/ts/webui/src/components/trial-detail/search/searchFunction.ts
new file mode 100644
index 0000000000..2bc4bf4bc7
--- /dev/null
+++ b/ts/webui/src/components/trial-detail/search/searchFunction.ts
@@ -0,0 +1,203 @@
+import { mergeStyleSets } from '@fluentui/react';
+import { trialJobStatus } from '../../../static/const';
+import { EXPERIMENT } from '../../../static/datamodel';
+import { TableObj, SearchItems } from '../../../static/interface';
+
+const classNames = mergeStyleSets({
+    menu: {
+        textAlign: 'center',
+        maxWidth: 600,
+        selectors: {
+            '.ms-ContextualMenu-item': {
+                height: 'auto'
+            }
+        }
+    },
+    item: {
+        display: 'inline-block',
+        width: 40,
+        height: 40,
+        lineHeight: 40,
+        textAlign: 'center',
+        verticalAlign: 'middle',
+        marginBottom: 8,
+        cursor: 'pointer',
+        selectors: {
+            '&:hover': {
+                backgroundColor: '#eaeaea'
+            }
+        }
+    },
+    categoriesList: {
+        margin: 0,
+        padding: 0,
+        listStyleType: 'none'
+    },
+    button: {
+        width: '40%',
+        margin: '2%'
+    }
+});
+
+function getDropdownOptions(parameter): any {
+    if (parameter === 'StatusNNI') {
+        return trialJobStatus.map(item => ({
+            key: item,
+            text: item
+        }));
+    } else {
+        return EXPERIMENT.searchSpace[parameter]._value.map(item => ({
+            key: item.toString(),
+            text: item.toString()
+        }));
+    }
+}
+
+// convert the stored string values to numbers when the parameter type requires it
+const convertParametersValue = (searchItems: SearchItems[], relation: Map<string, string>): SearchItems[] => {
+    searchItems.forEach(item => {
+        if (relation.get(item.name) === 'number') {
+            if (item.isChoice === true) {
+                // collect the converted values per item rather than sharing one array across items
+                const choice: any[] = [];
+                item.choice.forEach(ele => {
+                    choice.push(JSON.parse(ele));
+                });
+                item.choice = choice;
+            } else {
+                item.value1 = JSON.parse(item.value1);
+                if (item.value2 !== '') {
+                    item.value2 = JSON.parse(item.value2);
+                }
+            }
+        }
+    });
+
+    return searchItems;
+};
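+// e.g. with relation = {conv_size -> 'number'}, a filter whose choice list is
+// ['2', '3'] is converted in place to [2, 3]
+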
+// relation maps a trial parameter to its value type, e.g. conv_size -> 'number'
+const getTrialsBySearchFilters = (
+    arr: TableObj[],
+    searchItems: SearchItems[],
+    relation: Map<string, string>
+): TableObj[] => {
+    const que = convertParametersValue(searchItems, relation);
+    // filter the data by 'Trial id', 'Trial No.', 'Status', and then by each trial parameter
+    que.forEach(element => {
+        if (element.name === 'Trial id') {
+            arr = arr.filter(trial => trial.id.toUpperCase().includes(element.value1.toUpperCase()));
+        } else if (element.name === 'Trial No.') {
+            arr = arr.filter(trial => trial.sequenceId.toString() === element.value1);
+        } else if (element.name === 'StatusNNI') {
+            arr = searchChoiceFilter(arr, element, 'status');
+        } else {
+            const parameter = `space/${element.name}`;
+
+            if (element.isChoice === true) {
+                arr = searchChoiceFilter(arr, element, element.name);
+            } else {
+                if (element.operator === '=') {
+                    arr = arr.filter(trial => trial[parameter] === element.value1);
+                } else if (element.operator === '>') {
+                    arr = arr.filter(trial => trial[parameter] > element.value1);
+                } else if (element.operator === '<') {
+                    arr = arr.filter(trial => trial[parameter] < element.value1);
+                } else if (element.operator === 'between') {
+                    arr = arr.filter(trial => trial[parameter] > element.value1 && trial[parameter] < element.value2);
+                } else {
+                    // operator is '≠'
+                    arr = arr.filter(trial => trial[parameter] !== element.value1);
+                }
+            }
+        }
+    });
+
+    return arr;
+};
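+// e.g. a single filter {name: 'drop_rate', operator: '>', value1: 0.5, ...} keeps
+// only the trials whose 'space/drop_rate' value is greater than 0.5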
+
+// used by choice-type filters: status and choice-type trial parameters
+function findTrials(arr: TableObj[], choice: string[], field: string): TableObj[] {
+    const newResult: TableObj[] = [];
+    const parameter = field === 'status' ? 'status' : `space/${field}`;
+    arr.forEach(trial => {
+        choice.forEach(item => {
+            if (trial[parameter] === item) {
+                newResult.push(trial);
+            }
+        });
+    });
+
+    return newResult;
+}
+
+function searchChoiceFilter(arr: TableObj[], element: SearchItems, field: string): TableObj[] {
+    if (element.operator === '=') {
+        return findTrials(arr, element.choice, field);
+    } else {
+        let choice;
+        if (field === 'status') {
+            choice = trialJobStatus.filter(index => !new Set(element.choice).has(index));
+        } else {
+            choice = EXPERIMENT.searchSpace[field]._value.filter(index => !new Set(element.choice).has(index));
+        }
+        return findTrials(arr, choice, field);
+    }
+}
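+
+// note: '≠' is handled by filtering with the complement of the selected choices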
+
+// called when Apply is clicked: rebuild the search box text from the filter list
+function getSearchInputValueBySearchList(searchFilter): string {
+    let str = ''; // store search input value
+
+    searchFilter.forEach(item => {
+        const filterName = item.name === 'StatusNNI' ? 'Status' : item.name;
+
+        if (item.isChoice === false) {
+            // Trial id, Trial No., and non-choice parameters
+            if (item.name === 'Trial id' || item.name === 'Trial No.') {
+                str = str + `${item.name}:${item.value1}; `;
+            } else {
+                // non-choice parameter
+                if (['=', '≠', '>', '<'].includes(item.operator)) {
+                    str = str + `${filterName}${item.operator === '=' ? ':' : item.operator}${item.value1}; `;
+                } else {
+                    // between
+                    str = str + `${filterName}:[${item.value1},${item.value2}]; `;
+                }
+            }
+        } else {
+            // status or a choice-type parameter
+            str = str + `${filterName}${item.operator === '=' ? ':' : '≠'}[${[...item.choice]}]; `;
+        }
+    });
+
+    return str;
+}
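+// e.g. [{name: 'conv_size', operator: '=', choice: ['2', '3'], isChoice: true}]
+// renders in the search box as 'conv_size:[2,3]; '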
+
+/**
+ * Sample entry from the experiment search space:
+ * "conv_size": {
+ *     "_type": "choice",   // a choice-type parameter
+ *     "_value": [2, 3, 5, 7]
+ * }
+ */
+function isChoiceType(parameterName): boolean {
+    // decide whether the parameter is a choice type (status also counts as choice) or an ordinary type
+    let flag = false; // ordinary type by default
+
+    if (parameterName === 'StatusNNI') {
+        flag = true;
+    }
+
+    if (parameterName in EXPERIMENT.searchSpace) {
+        flag = EXPERIMENT.searchSpace[parameterName]._type === 'choice' ? true : false;
+    }
+
+    return flag;
+}
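+// e.g. given the sample search space above, isChoiceType('conv_size') === true,
+// while a non-choice parameter such as a uniform drop_rate would return false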
+
+export { classNames, getDropdownOptions, getTrialsBySearchFilters, getSearchInputValueBySearchList, isChoiceType };
diff --git a/ts/webui/src/static/function.ts b/ts/webui/src/static/function.ts
index cb1c9275df..4c43be8503 100644
--- a/ts/webui/src/static/function.ts
+++ b/ts/webui/src/static/function.ts
@@ -2,6 +2,7 @@ import * as JSON5 from 'json5';
 import axios from 'axios';
 import { IContextualMenuProps } from '@fluentui/react';
 import { MANAGER_IP } from './const';
+import { EXPERIMENT } from './datamodel';
 import { MetricDataRecord, FinalType, TableObj, Tensorboard } from './interface';
 
 function getPrefix(): string | undefined {
@@ -356,6 +357,19 @@ function getTensorboardMenu(queryTensorboardList: Tensorboard[], stopFunc, seeDe
 
     return tensorboardMenu;
 }
+
+// map each search-space parameter name to the type of its values
+const parametersType = (): Map<string, string> => {
+    const parametersTypeMap = new Map();
+    const trialParameterList = Object.keys(EXPERIMENT.searchSpace);
+
+    trialParameterList.forEach(item => {
+        parametersTypeMap.set(item, typeof EXPERIMENT.searchSpace[item]._value[0]);
+    });
+
+    return parametersTypeMap;
+};
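+// e.g. a search space {"conv_size": {"_type": "choice", "_value": [2, 3, 5, 7]}}
+// yields a map with the single entry conv_size -> 'number'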
+
 export {
     getPrefix,
     convertTime,
@@ -381,5 +395,6 @@ export {
     caclMonacoEditorHeight,
     copyAndSort,
     disableTensorboard,
-    getTensorboardMenu
+    getTensorboardMenu,
+    parametersType
 };
diff --git a/ts/webui/src/static/interface.ts b/ts/webui/src/static/interface.ts
index 1c25510b68..89b345e683 100644
--- a/ts/webui/src/static/interface.ts
+++ b/ts/webui/src/static/interface.ts
@@ -203,6 +203,16 @@ interface Tensorboard {
     port: string;
 }
 
+// for TableList search
+interface SearchItems {
+    name: string;
+    operator: string;
+    value1: string; // first input value
+    value2: string; // second input value
+    choice: string[]; // values selected in the multi-select dropdown
+    isChoice: boolean; // true for choice-type parameters; status is also treated as a choice type
+}
+
 export {
     TableObj,
     TableRecord,
@@ -226,5 +236,6 @@ export {
     MultipleAxes,
     SortInfo,
     AllExperimentList,
-    Tensorboard
+    Tensorboard,
+    SearchItems
 };
diff --git a/ts/webui/src/static/style/common.scss b/ts/webui/src/static/style/common.scss
index 95af28b5d6..a869bf0cdf 100644
--- a/ts/webui/src/static/style/common.scss
+++ b/ts/webui/src/static/style/common.scss
@@ -73,3 +73,7 @@ $themeBlue: #0071bc;
 .bold {
     font-weight: bold;
 }
+
+.input-padding {
+    padding-left: 10px;
+}
diff --git a/ts/webui/src/static/style/search.scss b/ts/webui/src/static/style/search.scss
index b302acea70..9ad10b3473 100644
--- a/ts/webui/src/static/style/search.scss
+++ b/ts/webui/src/static/style/search.scss
@@ -3,23 +3,17 @@
     width: 96%;
     margin: 0 auto;
     margin-top: 15px;
+    position: relative;
 
     &-compare {
         margin-top: 15px;
     }
 
-    &-entry {
-        line-height: 32px;
-    }
-
     /* compare button style */
     &-button-gap {
         margin-right: 10px;
     }
 
-    &-search-input {
-        padding-left: 10px;
-    }
 }
 
 /* each row's Intermediate btn -> Modal */
@@ -33,3 +27,32 @@
         width: 120px;
     }
 }
+
+$filterConditionsHeight: 54px;
+
+.filterConditions {
+    height: $filterConditionsHeight;
+    line-height: $filterConditionsHeight;
+    padding: 0 15px;
+
+    .btn-vertical-middle {
+        margin-top: 11px;
+    }
+
+    .input {
+        width: 100px;
+        height: 24px;
+        margin-top: 13px;
+        border: 1px solid #333;
+        border-radius: 20px;
+        outline: none;
+    }
+
+    .and {
+        margin: 0 4px;
+    }
+}
+
+.ms-ContextualMenu-Callout {
+    display: block;
+}