Merge pull request #42 from milvus-io/split-files
Refine utils file structure
Showing 5 changed files with 371 additions and 361 deletions.
@@ -0,0 +1,73 @@
from Types import ParameterException
import os


def readCsvFile(path='', withCol=True):
    if not path or not path[-4:] == '.csv':
        raise ParameterException('Path is empty or target file is not .csv')
    fileSize = os.stat(path).st_size
    if fileSize >= 512000000:
        raise ParameterException(
            'File is too large! Only allow csv files less than 512MB.')
    from csv import reader
    from json import JSONDecodeError
    import click
    try:
        result = {'columns': [], 'data': []}
        with click.open_file(path, 'r') as csv_file:
            click.echo(f'Opening csv file({fileSize} bytes)...')
            csv_reader = reader(csv_file, delimiter=',')
            # For progressbar, transform it to list.
            rows = list(csv_reader)
            line_count = 0
            with click.progressbar(rows, label='Reading csv rows...', show_percent=True) as bar:
                # for row in csv_reader:
                for row in bar:
                    if withCol and line_count == 0:
                        result['columns'] = row
                        line_count += 1
                    else:
                        formatRowForData(row, result['data'])
                        line_count += 1
            click.echo(f'''Column names are {result['columns']}''')
            click.echo(f'Processed {line_count} lines.')
    except FileNotFoundError as fe:
        raise ParameterException(f'FileNotFoundError {str(fe)}')
    except UnicodeDecodeError as ue:
        raise ParameterException(f'UnicodeDecodeError {str(ue)}')
    except JSONDecodeError as je:
        raise ParameterException(f'JSONDecodeError {str(je)}')
    else:
        return result


# For readCsvFile formatting data.
def formatRowForData(row=[], data=[]):
    from json import loads
    # init data with empty list
    if not data:
        for _in in range(len(row)):
            data.append([])
    for idx, val in enumerate(row):
        formattedVal = loads(val)
        data[idx].append(formattedVal)


def writeCsvFile(path, rows, headers=[]):
    if not path:
        raise ParameterException(f'Path should not be empty')
    from csv import writer
    import click
    try:
        with click.open_file(path, 'w+') as csv_file:
            csv_writer = writer(csv_file, delimiter=',')
            if headers:
                csv_writer.writerow(headers)
            line_count = 0
            with click.progressbar(rows, label='Writing csv rows...', show_percent=True) as bar:
                for row in bar:
                    csv_writer.writerow(row)
                    line_count += 1
            click.echo(f'Processed {line_count} lines.')
    except Exception as e:
        raise ParameterException(f'Export csv file error! {str(e)}')
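A minimal usage sketch of these helpers (not part of the diff; the file name example.csv and the column names are invented, and the module is assumed to be importable as Fs, as the validation code further below does):

from Fs import readCsvFile, writeCsvFile

# Write a small CSV, then read it back. Cells must be JSON-parseable,
# because formatRowForData decodes every value with json.loads.
rows = [[1, 0.5], [2, 0.75]]
writeCsvFile('example.csv', rows, headers=['id', 'score'])

parsed = readCsvFile('example.csv', withCol=True)
print(parsed['columns'])  # ['id', 'score']
print(parsed['data'])     # column-oriented: [[1, 2], [0.5, 0.75]]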
@@ -0,0 +1,118 @@
from functools import reduce


class ParameterException(Exception):
    "Custom Exception for parameters checking."

    def __init__(self, msg):
        self.msg = msg

    def __str__(self):
        return str(self.msg)


class ConnectException(Exception):
    "Custom Exception for milvus connection."

    def __init__(self, msg):
        self.msg = msg

    def __str__(self):
        return str(self.msg)


FiledDataTypes = [
    "BOOL",
    "INT8",
    "INT16",
    "INT32",
    "INT64",
    "FLOAT",
    "DOUBLE",
    "STRING",
    "BINARY_VECTOR",
    "FLOAT_VECTOR"
]

IndexTypes = [
    "FLAT",
    "IVF_FLAT",
    "IVF_SQ8",
    "IVF_PQ",
    "RNSG",
    "HNSW",
    # "NSG",
    "ANNOY",
    # "RHNSW_FLAT",
    # "RHNSW_PQ",
    # "RHNSW_SQ",
    # "BIN_FLAT",
    # "BIN_IVF_FLAT"
]

IndexParams = [
    "nlist",
    "m",
    "nbits",
    "M",
    "efConstruction",
    "n_trees",
    "PQM",
]

IndexTypesMap = {
    "FLAT": {
        "index_building_parameters": [],
        "search_parameters": ["metric_type"],
    },
    "IVF_FLAT": {
        "index_building_parameters": ["nlist"],
        "search_parameters": ["nprobe"],
    },
    "IVF_SQ8": {
        "index_building_parameters": ["nlist"],
        "search_parameters": ["nprobe"],
    },
    "IVF_PQ": {
        "index_building_parameters": ["nlist", "m", "nbits"],
        "search_parameters": ["nprobe"],
    },
    "RNSG": {
        "index_building_parameters": ["out_degree", "candidate_pool_size", "search_length", "knng"],
        "search_parameters": ["search_length"],
    },
    "HNSW": {
        "index_building_parameters": ["M", "efConstruction"],
        "search_parameters": ["ef"],
    },
    "ANNOY": {
        "index_building_parameters": ["n_trees"],
        "search_parameters": ["search_k"],
    },
}

DupSearchParams = reduce(
    lambda x, y: x+IndexTypesMap[y]['search_parameters'], IndexTypesMap.keys(), [])
SearchParams = list(dict.fromkeys(DupSearchParams))

MetricTypes = [
    "L2",
    "IP",
    "HAMMING",
    "TANIMOTO"
]

DataTypeByNum = {
    0: 'NONE',
    1: 'BOOL',
    2: 'INT8',
    3: 'INT16',
    4: 'INT32',
    5: 'INT64',
    10: 'FLOAT',
    11: 'DOUBLE',
    20: 'STRING',
    100: 'BINARY_VECTOR',
    101: 'FLOAT_VECTOR',
    999: 'UNKNOWN',
}
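As a quick illustration (not in the diff) of how SearchParams is derived: the reduce concatenates every index type's search_parameters, and dict.fromkeys then drops duplicates while preserving insertion order. A small sketch, assuming the module is importable as Types:

from Types import IndexTypesMap, SearchParams, DataTypeByNum

# Rebuild the same de-duplicated, order-preserving union by hand.
expected = []
for name in IndexTypesMap:
    for p in IndexTypesMap[name]['search_parameters']:
        if p not in expected:
            expected.append(p)
assert expected == SearchParams
# -> ['metric_type', 'nprobe', 'search_length', 'ef', 'search_k']

# DataTypeByNum maps Milvus numeric type codes back to readable names.
print(DataTypeByNum[101])  # 'FLOAT_VECTOR'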
@@ -0,0 +1,169 @@
from Types import ParameterException
from Types import FiledDataTypes, IndexTypes, IndexTypesMap, SearchParams, MetricTypes
from Fs import readCsvFile


def validateParamsByCustomFunc(customFunc, errMsg, *params):
    try:
        customFunc(*params)
    except Exception as e:
        raise ParameterException(f"{errMsg}")


def validateCollectionParameter(collectionName, primaryField, fields):
    if not collectionName:
        raise ParameterException('Missing collection name.')
    if not primaryField:
        raise ParameterException('Missing primary field.')
    if not fields:
        raise ParameterException('Missing fields.')
    fieldNames = []
    for field in fields:
        fieldList = field.split(':')
        if not (len(fieldList) == 3):
            raise ParameterException(
                'Field should contain three parameters and concat by ":".')
        [fieldName, fieldType, fieldData] = fieldList
        fieldNames.append(fieldName)
        if fieldType not in FiledDataTypes:
            raise ParameterException(
                'Invalid field data type, should be one of {}'.format(str(FiledDataTypes)))
        if fieldType in ['BINARY_VECTOR', 'FLOAT_VECTOR']:
            try:
                int(fieldData)
            except ValueError as e:
                raise ParameterException("""Vector's dim should be int.""")
    # Dedup field name.
    newNames = list(set(fieldNames))
    if not (len(newNames) == len(fieldNames)):
        raise ParameterException('Field names are duplicated.')
    if primaryField not in fieldNames:
        raise ParameterException(
            """Primary field name doesn't exist in input fields.""")


def validateIndexParameter(indexType, metricType, params):
    if indexType not in IndexTypes:
        raise ParameterException(
            'Invalid index type, should be one of {}'.format(str(IndexTypes)))
    if metricType not in MetricTypes:
        raise ParameterException(
            'Invalid index metric type, should be one of {}'.format(str(MetricTypes)))
    # if not params:
    #     raise ParameterException('Missing params')
    paramNames = []
    buildingParameters = IndexTypesMap[indexType]['index_building_parameters']
    for param in params:
        paramList = param.split(':')
        if not (len(paramList) == 2):
            raise ParameterException(
                'Params should contain two parameters and concat by ":".')
        [paramName, paramValue] = paramList
        paramNames.append(paramName)
        if paramName not in buildingParameters:
            raise ParameterException(
                'Invalid index param, should be one of {}'.format(str(buildingParameters)))
        try:
            int(paramValue)
        except ValueError as e:
            raise ParameterException("""Index param's value should be int.""")
    # Dedup field name.
    newNames = list(set(paramNames))
    if not (len(newNames) == len(paramNames)):
        raise ParameterException('Index params are duplicated.')


def validateSearchParams(data, annsField, metricType, params, limit, expr, partitionNames, timeout, roundDecimal, hasIndex=True):
    import json
    result = {}
    # Validate data
    try:
        if '.csv' in data:
            csvData = readCsvFile(data, withCol=False)
            result['data'] = csvData['data'][0]
        else:
            result['data'] = json.loads(
                data.replace('\'', '').replace('\"', ''))
    except Exception as e:
        raise ParameterException(
            'Format(list[list[float]]) "Data" error! {}'.format(str(e)))
    # Validate annsField
    if not annsField:
        raise ParameterException('annsField is empty!')
    result['anns_field'] = annsField
    if hasIndex:
        # Validate metricType
        if metricType not in MetricTypes:
            raise ParameterException(
                'Invalid index metric type, should be one of {}'.format(str(MetricTypes)))
        # Validate params
        paramDict = {}
        if type(params) == str:
            paramsList = params.replace(' ', '').split(',')
        else:
            paramsList = params
        for param in paramsList:
            if not param:
                continue
            paramList = param.split(':')
            if not (len(paramList) == 2):
                raise ParameterException(
                    'Params should contain two parameters and concat by ":".')
            [paramName, paramValue] = paramList
            if paramName not in SearchParams:
                raise ParameterException(
                    'Invalid search parameter, should be one of {}'.format(str(SearchParams)))
            try:
                paramDict[paramName] = int(paramValue)
            except ValueError as e:
                raise ParameterException(
                    """Search parameter's value should be int.""")
        result['param'] = {"metric_type": metricType}
        if paramDict.keys():
            result['param']['params'] = paramDict
    else:
        result['param'] = {}
    # Validate limit
    try:
        result['limit'] = int(limit)
    except Exception as e:
        raise ParameterException(
            'Format(int) "limit" error! {}'.format(str(e)))
    # Validate expr
    result['expr'] = expr
    # Validate partitionNames
    if partitionNames:
        try:
            result['partition_names'] = partitionNames.replace(
                ' ', '').split(',')
        except Exception as e:
            raise ParameterException(
                'Format(list[str]) "partitionNames" error! {}'.format(str(e)))
    # Validate timeout
    if timeout:
        result['timeout'] = float(timeout)
    if roundDecimal:
        result['round_decimal'] = int(roundDecimal)
    return result


def validateQueryParams(expr, partitionNames, outputFields, timeout):
    result = {}
    if not expr:
        raise ParameterException('expr is empty!')
    if ' in ' not in expr:
        raise ParameterException(
            'expr only accepts "<field_name> in [<min>,<max>]"!')
    result['expr'] = expr
    if not outputFields:
        result['output_fields'] = None
    else:
        nameList = outputFields.replace(' ', '').split(',')
        result['output_fields'] = nameList
    if not partitionNames:
        result['partition_names'] = None
    else:
        nameList = partitionNames.replace(' ', '').split(',')
        result['partition_names'] = nameList
    result['timeout'] = float(timeout) if timeout else None
    return result
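A hedged usage sketch of the validators (not part of the diff; the collection and field names are invented, and the module is assumed to be importable as Validation, since the file name isn't visible in this extract):

from Types import ParameterException
from Validation import validateCollectionParameter, validateSearchParams

try:
    # Each field is "<name>:<type>:<description or dim>"; vector fields need an int dim.
    validateCollectionParameter(
        'films', 'id',
        ['id:INT64:primary key', 'embedding:FLOAT_VECTOR:8'])

    # Search params are "<name>:<int value>" pairs, e.g. nprobe for IVF indexes.
    searchArgs = validateSearchParams(
        data='[[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]]',
        annsField='embedding', metricType='L2', params='nprobe:10',
        limit=5, expr=None, partitionNames='', timeout=None, roundDecimal=None)
    # searchArgs now holds keyword-ready arguments such as
    # {'anns_field': 'embedding', 'param': {'metric_type': 'L2', 'params': {'nprobe': 10}}, ...}
except ParameterException as e:
    print(f'Invalid input: {e}')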