1- from typing import Any , Dict , Optional , Tuple
1+ from typing import Any , Dict , Optional , Tuple , Union
22
33from neo4j import GraphDatabase
44from pandas import DataFrame
88from graphdatascience .call_builder import IndirectCallBuilder
99from graphdatascience .endpoints import AlphaEndpoints , BetaEndpoints , DirectEndpoints
1010from graphdatascience .error .uncallable_namespace import UncallableNamespace
11+ from graphdatascience .graph .graph_proc_runner import GraphRemoteProcRunner
1112from graphdatascience .query_runner .arrow_query_runner import ArrowQueryRunner
1213from graphdatascience .query_runner .aura_db_arrow_query_runner import (
1314 AuraDbArrowQueryRunner ,
1415 AuraDbConnectionInfo ,
1516)
17+ from graphdatascience .query_runner .query_runner import QueryRunner
1618
1719
1820class AuraGraphDataScience (DirectEndpoints , UncallableNamespace ):
@@ -23,46 +25,51 @@ class AuraGraphDataScience(DirectEndpoints, UncallableNamespace):
2325
2426 def __init__ (
2527 self ,
26- endpoint : str ,
28+ endpoint : Union [ str , QueryRunner ] ,
2729 auth : Tuple [str , str ],
2830 aura_db_connection_info : AuraDbConnectionInfo ,
2931 arrow_disable_server_verification : bool = True ,
3032 arrow_tls_root_certs : Optional [bytes ] = None ,
3133 bookmarks : Optional [Any ] = None ,
3234 ):
33- gds_query_runner = ArrowQueryRunner .create (
34- Neo4jQueryRunner .create (endpoint , auth , aura_ds = True ),
35- auth ,
36- True ,
37- arrow_disable_server_verification ,
38- arrow_tls_root_certs ,
39- )
40-
41- self ._server_version = gds_query_runner .server_version ()
42-
43- if self ._server_version < ServerVersion (2 , 6 , 0 ):
44- raise RuntimeError (
45- f"AuraDB connection info was provided but GDS version { self ._server_version } \
46- does not support connecting to AuraDB"
35+ if isinstance (endpoint , str ):
36+ gds_query_runner = ArrowQueryRunner .create (
37+ Neo4jQueryRunner .create (endpoint , auth , aura_ds = True ),
38+ auth ,
39+ True ,
40+ arrow_disable_server_verification ,
41+ arrow_tls_root_certs ,
4742 )
4843
49- self ._driver_config = gds_query_runner .driver_config ()
50- driver = GraphDatabase .driver (
51- aura_db_connection_info .uri , auth = aura_db_connection_info .auth , ** self ._driver_config
52- )
53- self ._db_query_runner = Neo4jQueryRunner (
54- driver , auto_close = True , bookmarks = bookmarks , server_version = self ._server_version
55- )
44+ self ._server_version = gds_query_runner .server_version ()
5645
57- # we need to explicitly set these as the default value is None
58- # which signals the driver to use the default configured database
59- # from the dbms.
60- gds_query_runner . set_database ( "neo4j" )
61- self . _db_query_runner . set_database ( "neo4j" )
46+ if self . _server_version < ServerVersion ( 2 , 6 , 0 ):
47+ raise RuntimeError (
48+ f"AuraDB connection info was provided but GDS version { self . _server_version } \
49+ does not support connecting to AuraDB"
50+ )
6251
63- self ._query_runner = AuraDbArrowQueryRunner (
64- gds_query_runner , self ._db_query_runner , driver .encrypted , aura_db_connection_info
65- )
52+ self ._driver_config = gds_query_runner .driver_config ()
53+ driver = GraphDatabase .driver (
54+ aura_db_connection_info .uri , auth = aura_db_connection_info .auth , ** self ._driver_config
55+ )
56+ self ._db_query_runner : QueryRunner = Neo4jQueryRunner (
57+ driver , auto_close = True , bookmarks = bookmarks , server_version = self ._server_version
58+ )
59+
60+ # we need to explicitly set these as the default value is None
61+ # which signals the driver to use the default configured database
62+ # from the dbms.
63+ gds_query_runner .set_database ("neo4j" )
64+ self ._db_query_runner .set_database ("neo4j" )
65+
66+ self ._query_runner = AuraDbArrowQueryRunner (
67+ gds_query_runner , self ._db_query_runner , driver .encrypted , aura_db_connection_info
68+ )
69+ else :
70+ self ._query_runner = endpoint
71+ self ._db_query_runner = endpoint
72+ self ._server_version = self ._query_runner .server_version ()
6673
6774 super ().__init__ (self ._query_runner , "gds" , self ._server_version )
6875
@@ -87,6 +94,10 @@ def run_cypher(
8794 # This will avoid calling valid gds procedures through a raw string
8895 return self ._db_query_runner .run_cypher (query , params , database , False )
8996
97+ @property
98+ def graph (self ) -> GraphRemoteProcRunner :
99+ return GraphRemoteProcRunner (self ._query_runner , f"{ self ._namespace } .graph" , self ._server_version )
100+
90101 @property
91102 def alpha (self ) -> AlphaEndpoints :
92103 return AlphaEndpoints (self ._query_runner , "gds.alpha" , self ._server_version )
0 commit comments