-
Notifications
You must be signed in to change notification settings - Fork 383
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor RAIInsights into RAIInsightsBase class for basic functionality
- Loading branch information
1 parent
2f550f5
commit a6594b1
Showing
2 changed files
with
353 additions
and
215 deletions.
There are no files selected for viewing
318 changes: 318 additions & 0 deletions
318
responsibleai/responsibleai/rai_insights/rai_base_insights.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,318 @@ | ||
# Copyright (c) Microsoft Corporation | ||
# Licensed under the MIT License. | ||
|
||
"""Defines the RAIBaseInsights class.""" | ||
|
||
import json | ||
import pickle | ||
import warnings | ||
from pathlib import Path | ||
from typing import Any, Optional | ||
|
||
import pandas as pd | ||
|
||
from abc import ABC, abstractmethod | ||
|
||
from responsibleai._internal.constants import Metadata | ||
|
||
_DATA = 'data' | ||
_DTYPES = 'dtypes' | ||
_TRAIN = 'train' | ||
_TEST = 'test' | ||
_MODEL = Metadata.MODEL | ||
_MODEL_PKL = _MODEL + '.pkl' | ||
_SERIALIZER = 'serializer' | ||
_MANAGERS = 'managers' | ||
_JSON_EXTENSION = '.json' | ||
|
||
|
||
class RAIBaseInsights(ABC): | ||
"""Defines the base class RAIBaseInsights for the top-level API. | ||
This class is abstract and should not be instantiated. | ||
""" | ||
|
||
def __init__(self, model: Any, train: pd.DataFrame, test: pd.DataFrame, | ||
target_column: str, task_type: str, | ||
serializer: Optional[Any] = None): | ||
"""Creates an RAIBaseInsights object. | ||
:param model: The model to compute RAI insights for. | ||
A model that implements sklearn.predict or sklearn.predict_proba | ||
or function that accepts a 2d ndarray. | ||
:type model: object | ||
:param train: The training dataset including the label column. | ||
:type train: pandas.DataFrame | ||
:param test: The test dataset including the label column. | ||
:type test: pandas.DataFrame | ||
:param target_column: The name of the label column. | ||
:type target_column: str | ||
:param task_type: The task to run. | ||
:type task_type: str | ||
:param classes: The class labels in the training dataset | ||
:type classes: numpy.ndarray | ||
:param serializer: Picklable custom serializer with save and load | ||
methods for custom model serialization. | ||
The save method writes the model to file given a parent directory. | ||
The load method returns the deserialized model from the same | ||
parent directory. | ||
:type serializer: object | ||
""" | ||
self.model = model | ||
self.train = train | ||
self.test = test | ||
self.target_column = target_column | ||
self.task_type = task_type | ||
self._serializer = serializer | ||
self._initialize_managers() | ||
|
||
@abstractmethod | ||
def _initialize_managers(self): | ||
"""Initializes the managers. | ||
This method is abstract and should not be called. | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def _validate_model_analysis_input_parameters(self, *args): | ||
"""Abstract method to validate the inputs for the constructor.""" | ||
pass | ||
|
||
def compute(self): | ||
"""Calls compute on each of the managers.""" | ||
for manager in self._managers: | ||
manager.compute() | ||
|
||
def list(self): | ||
"""List information about each of the managers. | ||
:return: Information about each of the managers. | ||
:rtype: dict | ||
""" | ||
configs = {} | ||
for manager in self._managers: | ||
configs[manager.name] = manager.list() | ||
return configs | ||
|
||
def get(self): | ||
"""List information about each of the managers. | ||
:return: Information about each of the managers. | ||
:rtype: dict | ||
""" | ||
configs = {} | ||
for manager in self._managers: | ||
configs[manager.name] = manager.get() | ||
return configs | ||
|
||
@abstractmethod | ||
def get_data(self): | ||
"""Get all data as RAIInsightsData object | ||
:return: Model Analysis Data | ||
:rtype: RAIInsightsData | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def _get_dataset(self): | ||
pass | ||
|
||
def _write_to_file(self, file_path, content): | ||
"""Save the string content to the given file path. | ||
:param file_path: The file path to save the content to. | ||
:type file_path: str | ||
:param content: The string content to save. | ||
:type content: str | ||
""" | ||
with open(file_path, 'w') as file: | ||
file.write(content) | ||
|
||
@abstractmethod | ||
def _save_predictions(self, path): | ||
"""Save the predict() and predict_proba() output. | ||
:param path: The directory path to save the RAIInsights to. | ||
:type path: str | ||
""" | ||
pass | ||
|
||
def _save_data(self, path): | ||
"""Save the copy of raw data (train and test sets) and | ||
their related metadata. | ||
:param path: The directory path to save the RAIBaseInsights to. | ||
:type path: str | ||
""" | ||
data_directory = Path(path) / _DATA | ||
data_directory.mkdir(parents=True, exist_ok=True) | ||
dtypes = self.train.dtypes.astype(str).to_dict() | ||
self._write_to_file(data_directory / | ||
(_TRAIN + _DTYPES + _JSON_EXTENSION), | ||
json.dumps(dtypes)) | ||
self._write_to_file(data_directory / (_TRAIN + _JSON_EXTENSION), | ||
self.train.to_json(orient='split')) | ||
|
||
dtypes = self.test.dtypes.astype(str).to_dict() | ||
self._write_to_file(data_directory / | ||
(_TEST + _DTYPES + _JSON_EXTENSION), | ||
json.dumps(dtypes)) | ||
self._write_to_file(data_directory / (_TEST + _JSON_EXTENSION), | ||
self.test.to_json(orient='split')) | ||
|
||
@abstractmethod | ||
def _save_metadata(self, path): | ||
"""Save the metadata like target column, categorical features, | ||
task type and the classes (if any). | ||
:param path: The directory path to save the RAIBaseInsights to. | ||
:type path: str | ||
""" | ||
pass | ||
|
||
def _save_model(self, path): | ||
"""Save the model and the serializer (if any). | ||
:param path: The directory path to save the RAIInsights to. | ||
:type path: str | ||
""" | ||
top_dir = Path(path) | ||
if self._serializer is not None: | ||
# save the model | ||
self._serializer.save(self.model, top_dir) | ||
# save the serializer | ||
with open(top_dir / _SERIALIZER, 'wb') as file: | ||
pickle.dump(self._serializer, file) | ||
else: | ||
if self.model is not None: | ||
has_setstate = hasattr(self.model, '__setstate__') | ||
has_getstate = hasattr(self.model, '__getstate__') | ||
if not (has_setstate and has_getstate): | ||
raise ValueError( | ||
"Model must be picklable or a custom serializer must" | ||
" be specified") | ||
with open(top_dir / _MODEL_PKL, 'wb') as file: | ||
pickle.dump(self.model, file) | ||
|
||
def _save_managers(self, path): | ||
"""Save the state of individual managers. | ||
:param path: The directory path to save the RAIInsights to. | ||
:type path: str | ||
""" | ||
top_dir = Path(path) | ||
# save each of the individual managers | ||
for manager in self._managers: | ||
manager._save(top_dir / manager.name) | ||
|
||
def save(self, path): | ||
"""Save the RAIBaseInsights to the given path. | ||
:param path: The directory path to save the RAIInsights to. | ||
:type path: str | ||
""" | ||
self._save_managers(path) | ||
self._save_data(path) | ||
self._save_metadata(path) | ||
self._save_model(path) | ||
self._save_predictions(path) | ||
|
||
@staticmethod | ||
def _load_data(inst, path): | ||
"""Load the raw data (train and test sets). | ||
:param inst: RAIInsights object instance. | ||
:type inst: RAIInsights | ||
:param path: The directory path to data location. | ||
:type path: str | ||
""" | ||
data_directory = Path(path) / _DATA | ||
with open(data_directory / | ||
(_TRAIN + _DTYPES + _JSON_EXTENSION), 'r') as file: | ||
types = json.load(file) | ||
with open(data_directory / (_TRAIN + _JSON_EXTENSION), 'r') as file: | ||
train = pd.read_json(file, dtype=types, orient='split') | ||
inst.__dict__[_TRAIN] = train | ||
with open(data_directory / | ||
(_TEST + _DTYPES + _JSON_EXTENSION), 'r') as file: | ||
types = json.load(file) | ||
with open(data_directory / (_TEST + _JSON_EXTENSION), 'r') as file: | ||
test = pd.read_json(file, dtype=types, orient='split') | ||
inst.__dict__[_TEST] = test | ||
|
||
@staticmethod | ||
def _load_model(inst, path): | ||
"""Load the model. | ||
:param inst: RAIBaseInsights object instance. | ||
:type inst: RAIBaseInsights | ||
:param path: The directory path to model location. | ||
:type path: str | ||
""" | ||
top_dir = Path(path) | ||
serializer_path = top_dir / _SERIALIZER | ||
model_load_err = ('ERROR-LOADING-USER-MODEL: ' | ||
'There was an error loading the user model. ' | ||
'Some of RAI dashboard features may not work.') | ||
if serializer_path.exists(): | ||
try: | ||
with open(serializer_path, 'rb') as file: | ||
serializer = pickle.load(file) | ||
inst.__dict__['_' + _SERIALIZER] = serializer | ||
inst.__dict__[_MODEL] = serializer.load(top_dir) | ||
except Exception: | ||
warnings.warn(model_load_err) | ||
inst.__dict__[_MODEL] = None | ||
else: | ||
inst.__dict__['_' + _SERIALIZER] = None | ||
try: | ||
with open(top_dir / _MODEL_PKL, 'rb') as file: | ||
inst.__dict__[_MODEL] = pickle.load(file) | ||
except Exception: | ||
warnings.warn(model_load_err) | ||
inst.__dict__[_MODEL] = None | ||
|
||
@staticmethod | ||
def _load_managers(inst, path, manager_map): | ||
"""Load the model. | ||
:param inst: RAIInsights object instance. | ||
:type inst: RAIInsights | ||
:param path: The directory path to manager location. | ||
:type path: str | ||
:param manager_map: The map of manager names to manager classes. | ||
:type manager_map: dict | ||
""" | ||
top_dir = Path(path) | ||
managers = [] | ||
for manager_name, manager_class in manager_map.items(): | ||
full_name = f'_{manager_name}_manager' | ||
manager_dir = top_dir / manager_name | ||
manager = manager_class._load(manager_dir, inst) | ||
inst.__dict__[full_name] = manager | ||
managers.append(manager) | ||
|
||
inst.__dict__['_' + _MANAGERS] = managers | ||
|
||
@staticmethod | ||
def _load(path, inst, manager_map, load_metadata_func): | ||
"""Load the RAIInsights from the given path. | ||
:param path: The directory path to load the RAIInsights from. | ||
:type path: str | ||
:param inst: RAIInsights object instance. | ||
:type inst: RAIInsights | ||
:param manager_map: The map of manager names to manager classes. | ||
:type manager_map: dict | ||
:param load_metadata_func: The function to load the metadata. | ||
:type load_metadata_func: function | ||
:return: The RAIBaseInsights object after loading. | ||
:rtype: RAIBaseInsights | ||
""" | ||
# load current state | ||
RAIBaseInsights._load_data(inst, path) | ||
load_metadata_func(inst, path) | ||
RAIBaseInsights._load_model(inst, path) | ||
RAIBaseInsights._load_managers(inst, path, manager_map) | ||
|
||
return inst |
Oops, something went wrong.