Skip to content

Commit

Permalink
Add feature group api and hive and spark engine (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
moritzmeister authored Feb 19, 2020
1 parent 77a1935 commit 4d8647c
Show file tree
Hide file tree
Showing 14 changed files with 381 additions and 60 deletions.
58 changes: 24 additions & 34 deletions python/hopsworks/connection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import functools
from requests.exceptions import ConnectionError

from hopsworks import util, engine
from hopsworks.core import client, feature_store_api


Expand Down Expand Up @@ -31,6 +32,7 @@ def __init__(
self._hostname_verification = (
hostname_verification or self.HOSTNAME_VERIFICATION_DEFAULT
)
# what's the difference between trust store path and cert folder
self._trust_store_path = trust_store_path
self._cert_folder = cert_folder or self.CERT_FOLDER_DEFAULT
self._api_key_file = api_key_file
Expand Down Expand Up @@ -77,40 +79,26 @@ def connect(self):
self._cert_folder,
self._api_key_file,
)
engine.init(
"hive", self._host, self._cert_folder, self._client._cert_key
)
else:
self._client = client.HopsworksClient()
except TypeError:
engine.init("spark")
self._feature_store_api = feature_store_api.FeatureStoreApi(self._client)
except (TypeError, ConnectionError):
self._connected = False
raise
self._feature_store_api = feature_store_api.FeatureStoreApi(self._client)
print("CONNECTED")

def close(self):
self._client._close()
self._feature_store_api = None
self._client = None
engine.stop()
self._connected = False
print("CONNECTION CLOSED")

def not_connected(fn):
@functools.wraps(fn)
def if_not_connected(inst, *args, **kwargs):
if inst._connected:
raise ConnectionError
return fn(inst, *args, **kwargs)

return if_not_connected

def connected(fn):
@functools.wraps(fn)
def if_connected(inst, *args, **kwargs):
if not inst._connected:
raise NoConnectionError
return fn(inst, *args, **kwargs)

return if_connected

@connected
@util.connected
def get_feature_store(self, name=None):
"""Get a reference to a feature store, to perform operations on.
Expand All @@ -124,14 +112,16 @@ def get_feature_store(self, name=None):
"""
if not name:
name = self._client._project_name + "_featurestore"
# TODO: this won't work with multiple feature stores
engine.get_instance()._feature_store = name
return self._feature_store_api.get(name)

@property
def host(self):
return self._host

@host.setter
@not_connected
@util.not_connected
def host(self, host):
self._host = host

Expand All @@ -140,7 +130,7 @@ def port(self):
return self._port

@port.setter
@not_connected
@util.not_connected
def port(self, port):
self._port = port

Expand All @@ -149,7 +139,7 @@ def project(self):
return self._project

@project.setter
@not_connected
@util.not_connected
def project(self, project):
self._project = project

Expand All @@ -158,7 +148,7 @@ def region_name(self):
return self._region_name

@region_name.setter
@not_connected
@util.not_connected
def region_name(self, region_name):
self._region_name = region_name

Expand All @@ -167,7 +157,7 @@ def secrets_store(self):
return self._secrets_store

@secrets_store.setter
@not_connected
@util.not_connected
def secrets_store(self, secrets_store):
self._secrets_store = secrets_store

Expand All @@ -176,7 +166,7 @@ def hostname_verification(self):
return self._hostname_verification

@hostname_verification.setter
@not_connected
@util.not_connected
def hostname_verification(self, hostname_verification):
self._hostname_verification = hostname_verification

Expand All @@ -185,7 +175,7 @@ def trust_store_path(self):
return self._trust_store_path

@trust_store_path.setter
@not_connected
@util.not_connected
def trust_store_path(self, trust_store_path):
self._trust_store_path = trust_store_path

Expand All @@ -194,7 +184,7 @@ def cert_folder(self):
return self._cert_folder

@cert_folder.setter
@not_connected
@util.not_connected
def cert_folder(self, cert_folder):
self._cert_folder = cert_folder

Expand All @@ -203,7 +193,7 @@ def api_key_file(self):
return self._api_key_file

@api_key_file.setter
@not_connected
@util.not_connected
def api_key_file(self, api_key_file):
self._api_key_file = api_key_file

Expand All @@ -215,7 +205,7 @@ def __exit__(self, type, value, traceback):
self.close()


class ConnectionError(Exception):
class HopsworksConnectionError(Exception):
"""Thrown when attempted to change connection attributes while connected."""

def __init__(self):
Expand All @@ -224,7 +214,7 @@ def __init__(self):
)


class NoConnectionError(Exception):
class NoHopsworksConnectionError(Exception):
"""Thrown when attempted to perform operation on connection while not connected."""

def __init__(self):
Expand Down
25 changes: 18 additions & 7 deletions python/hopsworks/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import requests
import urllib3

from hopsworks import util


urllib3.disable_warnings(urllib3.exceptions.SecurityWarning)
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
Expand Down Expand Up @@ -102,6 +104,7 @@ def _read_jwt(self):
with open(self.TOKEN_FILE, "r") as jwt:
return jwt.read()

@util.connected
def _send_request(
self, method, path_params, query_params=None, headers=None, data=None
):
Expand Down Expand Up @@ -158,7 +161,7 @@ def _send_request(

def _close(self):
"""Closes a client. Can be implemented for clean up purposes, not mandatory."""
pass
self._connected = False


class HopsworksClient(BaseClient):
Expand All @@ -172,7 +175,7 @@ class HopsworksClient(BaseClient):
def __init__(self):
"""Initializes a client being run from a job/notebook directly on Hopsworks."""
self._base_url = self._get_hopsworks_rest_endpoint()
host, port = self._get_host_port_pair()
self._host, self._port = self._get_host_port_pair()
trust_store_path = (
os.environ[self.DOMAIN_CA_TRUSTSTORE_PEM]
if self.DOMAIN_CA_TRUSTSTORE_PEM in os.environ
Expand All @@ -187,10 +190,12 @@ def __init__(self):
self._project_name = self._project_name()
self._auth = BearerAuth(self._read_jwt())
self._verify = self._get_verify(
host, port, hostname_verification, trust_store_path
self._host, self._port, hostname_verification, trust_store_path
)
self._session = requests.session()

self._connected = True

def _get_hopsworks_rest_endpoint(self):
"""Get the hopsworks REST endpoint for making requests to the REST API."""
return os.environ[self.REST_ENDPOINT]
Expand Down Expand Up @@ -240,7 +245,9 @@ def __init__(
if not project:
raise ExternalClientError("project")

self._base_url = "https://" + host + ":" + str(port)
self._host = host
self._port = port
self._base_url = "https://" + self._host + ":" + str(self._port)
self._project_name = project
self._region_name = region_name
self._cert_folder = cert_folder
Expand All @@ -250,19 +257,22 @@ def __init__(
)

self._session = requests.session()
self._connected = True
self._verify = self._get_verify(
host, port, hostname_verification, trust_store_path
self._host, self._port, hostname_verification, trust_store_path
)

project_info = self._get_project_info(self._project_name)
self._project_id = str(project_info["projectId"])

credentials = self._get_credentials(self._project_id)
self._write_b64_cert_to_bytes(
str(credentials["kStore"]), path=os.path.join(cert_folder, "keyStore.jks")
str(credentials["kStore"]),
path=os.path.join(self._cert_folder, "keyStore.jks"),
)
self._write_b64_cert_to_bytes(
str(credentials["tStore"]), path=os.path.join(cert_folder, "trustStore.jks")
str(credentials["tStore"]),
path=os.path.join(self._cert_folder, "trustStore.jks"),
)

self._cert_key = str(credentials["password"])
Expand All @@ -271,6 +281,7 @@ def _close(self):
"""Closes a client and deletes certificates."""
self._cleanup_file(os.path.join(self._cert_folder, "keyStore.jks"))
self._cleanup_file(os.path.join(self._cert_folder, "trustStore.jks"))
self._connected = False

def _get_secret(self, secrets_store, secret_key=None, api_key_file=None):
"""Returns secret value from the AWS Secrets Manager or Parameter Store.
Expand Down
30 changes: 30 additions & 0 deletions python/hopsworks/core/feature_group_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from hopsworks import feature_group


class FeatureGroupApi:
def __init__(self, client, feature_store_id):
# or should the client be passed with every call to the API
self._client = client
self._feature_store_id = feature_store_id

def get(self, name, version):
"""Get feature store with specific id or name.
:param identifier: id or name of the feature store
:type identifier: int, str
:return: the featurestore metadata
:rtype: FeatureStore
"""
path_params = [
"project",
self._client._project_id,
"featurestores",
self._feature_store_id,
"featuregroups",
name,
]
query_params = {"version": version}
return feature_group.FeatureGroup.from_response_json(
self._client,
self._client._send_request("GET", path_params, query_params)[0],
)
2 changes: 1 addition & 1 deletion python/hopsworks/core/feature_store_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ def get(self, identifier):
"""
path_params = ["project", self._client._project_id, "featurestores", identifier]
return FeatureStore.from_response_json(
self._client._send_request("GET", path_params)
self._client, self._client._send_request("GET", path_params)
)
24 changes: 24 additions & 0 deletions python/hopsworks/core/query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import json

from hopsworks import util, engine


class Query:
def __init__(
self, query_constructor_api, left_feature_group, left_features, joins=None,
):
self._left_feature_group = left_feature_group
self._left_features = left_features
self._joins = joins
self._query_constructor_api = query_constructor_api

def read(self, dataframe_type="default"):
sql_query = self._query_constructor_api.construct_query(self)["query"]
return engine.get_instance().sql(sql_query, dataframe_type)

def show(self, n):
sql_query = self._query_constructor_api.construct_query(self)["query"]
return engine.get_instance().show(sql_query, n)

def json(self):
return json.dumps(self, cls=util.QueryEncoder)
10 changes: 10 additions & 0 deletions python/hopsworks/core/query_constructor_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
class QueryConstructorApi:
def __init__(self, client):
self._client = client

def construct_query(self, query):
path_params = ["project", self._client._project_id, "featurestores", "query"]
headers = {"content-type": "application/json"}
return self._client._send_request(
"PUT", path_params, headers=headers, data=query.json()
)
24 changes: 24 additions & 0 deletions python/hopsworks/engine/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from hopsworks.engine import spark, hive

_engine = None


def init(engine_type, host=None, cert_folder=None, cert_key=None):
global _engine
if not _engine:
if engine_type == "spark":
_engine = spark.Engine()
elif engine_type == "hive":
_engine = hive.Engine(host, cert_folder, cert_key)


def get_instance():
global _engine
if _engine:
return _engine
raise Exception("Couldn't find execution engine. Try reconnecting to Hopsworks.")


def stop():
global _engine
_engine = None
Loading

0 comments on commit 4d8647c

Please sign in to comment.