@@ -39,38 +39,19 @@ pub fn init_table(
39
39
init:: init_pgmq ( ) ?;
40
40
41
41
let guc_configs = get_guc_configs ( & transformer. source ) ;
42
- let provider = get_provider (
43
- & transformer. source ,
44
- guc_configs. api_key . clone ( ) ,
45
- guc_configs. service_url ,
46
- None ,
47
- ) ?;
48
-
49
- //synchronous
50
- let runtime = tokio:: runtime:: Builder :: new_current_thread ( )
51
- . enable_io ( )
52
- . enable_time ( )
53
- . build ( )
54
- . unwrap_or_else ( |e| error ! ( "failed to initialize tokio runtime: {}" , e) ) ;
55
- let model_dim =
56
- match runtime. block_on ( async { provider. model_dim ( & transformer. api_name ( ) ) . await } ) {
57
- Ok ( e) => e,
58
- Err ( e) => {
59
- error ! ( "error getting model dim: {}" , e) ;
60
- }
61
- } ;
62
-
63
- // validate API key where necessary
42
+ info ! ( "guc_configs: {:?}" , guc_configs) ;
43
+ // validate API key where necessary and collect any optional arguments
64
44
// certain embedding services require an API key, e.g. openAI
65
45
// key can be set in a GUC, so if its required but not provided in args, and not in GUC, error
66
- match transformer. source {
46
+ let optional_args = match transformer. source {
67
47
ModelSource :: OpenAI => {
68
48
openai:: validate_api_key (
69
49
& guc_configs
70
50
. api_key
71
51
. clone ( )
72
52
. context ( "OpenAI key is required" ) ?,
73
53
) ?;
54
+ None
74
55
}
75
56
ModelSource :: Tembo => {
76
57
error ! ( "Tembo not implemented for search yet" ) ;
@@ -85,15 +66,40 @@ pub fn init_table(
85
66
let res = check_model_host ( & url) ;
86
67
match res {
87
68
Ok ( _) => {
88
- info ! ( "Model host active!" )
69
+ info ! ( "Model host active!" ) ;
70
+ None
89
71
}
90
72
Err ( e) => {
91
73
error ! ( "Error with model host: {:?}" , e)
92
74
}
93
75
}
94
76
}
95
- _ => ( ) ,
96
- }
77
+ ModelSource :: Portkey => Some ( serde_json:: json!( {
78
+ "virtual_key" : guc_configs. virtual_key. clone( ) . expect( "Portkey virtual key is required" )
79
+ } ) ) ,
80
+ _ => None ,
81
+ } ;
82
+
83
+ let provider = get_provider (
84
+ & transformer. source ,
85
+ guc_configs. api_key . clone ( ) ,
86
+ guc_configs. service_url . clone ( ) ,
87
+ guc_configs. virtual_key . clone ( ) ,
88
+ ) ?;
89
+
90
+ // synchronous
91
+ let runtime = tokio:: runtime:: Builder :: new_current_thread ( )
92
+ . enable_io ( )
93
+ . enable_time ( )
94
+ . build ( )
95
+ . unwrap_or_else ( |e| error ! ( "failed to initialize tokio runtime: {}" , e) ) ;
96
+ let model_dim =
97
+ match runtime. block_on ( async { provider. model_dim ( & transformer. api_name ( ) ) . await } ) {
98
+ Ok ( e) => e,
99
+ Err ( e) => {
100
+ error ! ( "error getting model dim: {}" , e) ;
101
+ }
102
+ } ;
97
103
98
104
let valid_params = types:: JobParams {
99
105
schema : schema. to_string ( ) ,
@@ -105,6 +111,7 @@ pub fn init_table(
105
111
pkey_type,
106
112
api_key : guc_configs. api_key . clone ( ) ,
107
113
schedule : schedule. to_string ( ) ,
114
+ args : optional_args,
108
115
} ;
109
116
let params =
110
117
pgrx:: JsonB ( serde_json:: to_value ( valid_params. clone ( ) ) . expect ( "error serializing params" ) ) ;
0 commit comments