Skip to content

Commit

Permalink
Merge pull request #12 from zilliztech/nameczz/dev
Browse files Browse the repository at this point in the history
Support user related apis and connect with auth
  • Loading branch information
shanghaikid authored Jul 29, 2022
2 parents b26f400 + a5ee6cd commit 70f829d
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 31 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
__pycache__/
*.py[cod]
*$py.class
.vscode

# C extensions
*.so
Expand Down
1 change: 1 addition & 0 deletions Dockerfile.test
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
From python:3.9-slim
84 changes: 82 additions & 2 deletions milvus_cli/scripts/milvus_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,30 @@ def help():
default=19530,
type=int,
)
@click.option(
"-s",
"--secure",
"secure",
help="[Optional] - Secure, default is `False`.",
default=False,
type=bool,
)
@click.option(
"-u",
"--username",
"username",
help="[Optional] - Username , default is `None`.",
default=None,
type=str,
)
@click.option(
"-pwd",
"--password",
"password",
help="[Optional] - Password , default is `None`.",
default=None,
type=str,
)
@click.option(
"-D",
"--disconnect",
Expand All @@ -75,7 +99,7 @@ def help():
is_flag=True,
)
@click.pass_obj
def connect(obj, alias, host, port, disconnect):
def connect(obj, alias, host, port,secure,username,password, disconnect):
"""
Connect to Milvus.
Expand All @@ -84,7 +108,7 @@ def connect(obj, alias, host, port, disconnect):
milvus_cli > connect -h 127.0.0.1 -p 19530 -a default
"""
try:
obj.connect(alias, host, port, disconnect)
obj.connect(alias, host, port, disconnect,secure,username,password)
except Exception as e:
click.echo(message=e, err=True)
else:
Expand Down Expand Up @@ -424,6 +448,16 @@ def indexes(obj, collection):
except Exception as e:
click.echo(message=e, err=True)

@listDetails.command()
@click.pass_obj
def users(obj):
"""List all users in Milvus"""
try:
obj.checkConnection()
click.echo(obj.listCredUsers())
except Exception as e:
click.echo(message=e, err=True)


@cli.group("describe", no_args_is_help=False)
@click.pass_obj
Expand Down Expand Up @@ -716,6 +750,27 @@ def createIndex(obj):
click.echo("Create index successfully!")


@createDetails.command("user")
@click.option("-u", "--username", "username", help="The username of milvus user.")
@click.option("-p", "--password", "password", help="The pawssord of milvus user.")
@click.pass_obj
def createUser(obj, username, password):
"""
Create user.
Example:
milvus_cli > create user -u zilliz -p zilliz
"""
try:
obj.checkConnection()
click.echo(obj.createCredUser(username,password))
click.echo("Create user successfully")
except Exception as e:
click.echo(message=e, err=True)



@cli.group("delete", no_args_is_help=False)
@click.pass_obj
def deleteObject(obj):
Expand Down Expand Up @@ -868,6 +923,31 @@ def deleteIndex(obj, collectionName, timeout):
)


@deleteObject.command("user")
@click.option("-u", "--username", "username", help="The username of milvus user")
@click.pass_obj
def deleteUser(obj, username):
"""
Drop user in milvus by username
Example:
milvus_cli > delete user -u zilliz
"""
click.echo(
"Warning!\nYou are trying to delete the user in milvus. This action cannot be undone!\n"
)
if not click.confirm("Do you want to continue?"):
return
try:
obj.checkConnection()
result = obj.deleteCredUser(username)
click.echo("Drop user successfully!")
click.echo(result)
except Exception as e:
click.echo(message=e, err=True)


@deleteObject.command("entities")
@click.option("-c", "--collection-name", "collectionName", help="Collection name.")
@click.option(
Expand Down
93 changes: 66 additions & 27 deletions milvus_cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from Types import DataTypeByNum
from Types import ParameterException, ConnectException
from time import time
from string import Template
from pymilvus import __version__


def getPackageVersion():
Expand Down Expand Up @@ -35,16 +37,19 @@ class PyOrm(object):
port = 19530
alias = "default"

def connect(self, alias=None, host=None, port=None, disconnect=False):
def connect(self, alias=None, host=None, port=None, disconnect=False, secure=False, username=None, password=None):
self.alias = alias
self.host = host
self.port = port
trimUsername = None if username is None else username.strip()
trimPwd = None if password is None else password.strip()

from pymilvus import connections

if disconnect:
connections.disconnect(alias)
return
connections.connect(self.alias, host=self.host, port=self.port)
connections.connect(self.alias, host=self.host, port=self.port,user=trimUsername,password=trimPwd,secure=secure)

def checkConnection(self):
from pymilvus import list_collections
Expand All @@ -65,10 +70,10 @@ def showConnection(self, alias="default", showAll=False):
)
aliasList = map(lambda x: x[0], allConnections)
if tempAlias in aliasList:
host, port = connections.get_connection_addr(tempAlias).values()
secure, host, port = connections.get_connection_addr(tempAlias).values()
# return """Host: {}\nPort: {}\nAlias: {}""".format(host, port, alias)
return tabulate(
[["Host", host], ["Port", port], ["Alias", tempAlias]],
[["Host", host], ["Port", port], ["Alias", tempAlias],["Secure",secure]],
tablefmt="pretty",
)
else:
Expand All @@ -89,7 +94,8 @@ def _list_field_names(self, collectionName, showVectorOnly=False):
result = target.schema.fields
if showVectorOnly:
return reduce(
lambda x, y: x + [y.name] if y.dtype in [100, 101] else x, result, []
lambda x, y: x + [y.name] if y.dtype in [100,
101] else x, result, []
)
return [i.name for i in result]

Expand All @@ -115,7 +121,7 @@ def listCollections(self, timeout=None, showLoadedOnly=False):
collectionNames = self._list_collection_names(timeout)
for name in collectionNames:
loadingProgress = self.showCollectionLoadingProgress(name)
loaded, total = loadingProgress.values()
loaded, total, alias = loadingProgress.values()
# isLoaded = (total > 0) and (loaded == total)
# shouldBeAdded = isLoaded if showLoadedOnly else True
# if shouldBeAdded:
Expand Down Expand Up @@ -194,34 +200,40 @@ def releaseCollection(self, collectionName):
)

def releasePartition(self, collectionName, partitionName):
targetPartition = self.getTargetPartition(collectionName, partitionName)
targetPartition = self.getTargetPartition(
collectionName, partitionName)
targetPartition.release()
result = self.showCollectionLoadingProgress(collectionName, [partitionName])
result = self.showCollectionLoadingProgress(
collectionName, [partitionName])
return result

def releasePartitions(self, collectionName, partitionNameList):
result = []
for name in partitionNameList:
tmp = self.releasePartition(collectionName, name)
result.append(
[name, tmp.get("num_loaded_entities"), tmp.get("num_total_entities")]
[name, tmp.get("num_loaded_entities"),
tmp.get("num_total_entities")]
)
return tabulate(
result, headers=["Partition Name", "Loaded", "Total"], tablefmt="grid"
)

def loadPartition(self, collectionName, partitionName):
targetPartition = self.getTargetPartition(collectionName, partitionName)
targetPartition = self.getTargetPartition(
collectionName, partitionName)
targetPartition.load()
result = self.showCollectionLoadingProgress(collectionName, [partitionName])
result = self.showCollectionLoadingProgress(
collectionName, [partitionName])
return result

def loadPartitions(self, collectionName, partitionNameList):
result = []
for name in partitionNameList:
tmp = self.loadPartition(collectionName, name)
result.append(
[name, tmp.get("num_loaded_entities"), tmp.get("num_total_entities")]
[name, tmp.get("num_loaded_entities"),
tmp.get("num_total_entities")]
)
return tabulate(
result, headers=["Partition Name", "Loaded", "Total"], tablefmt="grid"
Expand Down Expand Up @@ -280,8 +292,10 @@ def getCollectionDetails(self, collectionName="", collection=None):
schemaDetails = """Description: {}\n\nAuto ID: {}\n\nFields(* is the primary field):{}""".format(
schema.description, schema.auto_id, fieldSchemaDetails
)
partitionDetails = " - " + "\n- ".join(map(lambda x: x.name, partitions))
indexesDetails = " - " + "\n- ".join(map(lambda x: x.field_name, indexes))
partitionDetails = " - " + \
"\n- ".join(map(lambda x: x.name, partitions))
indexesDetails = " - " + \
"\n- ".join(map(lambda x: x.field_name, indexes))
rows.append(["Name", target.name])
rows.append(["Description", target.description])
rows.append(["Is Empty", target.is_empty])
Expand Down Expand Up @@ -313,7 +327,8 @@ def getIndexDetails(self, collection):
rows.append(["Index Type", index.params["index_type"]])
rows.append(["Metric Type", index.params["metric_type"]])
params = index.params["params"]
paramsDetails = "\n- ".join(map(lambda k: f"{k[0]}: {k[1]}", params.items()))
paramsDetails = "\n- ".join(
map(lambda k: f"{k[0]}: {k[1]}", params.items()))
rows.append(["Params", paramsDetails])
return tabulate(rows, tablefmt="grid")

Expand All @@ -329,7 +344,8 @@ def createCollection(
if fieldType in ["BINARY_VECTOR", "FLOAT_VECTOR"]:
fieldList.append(
FieldSchema(
name=fieldName, dtype=DataType[fieldType], dim=int(fieldData)
name=fieldName, dtype=DataType[fieldType], dim=int(
fieldData)
)
)
else:
Expand Down Expand Up @@ -417,7 +433,6 @@ def search(self, collectionName, searchParameters, prettierFormat=True):

def query(self, collectionName, queryParameters):
collection = self.getTargetCollection(collectionName)
collection.load()
print(queryParameters)
res = collection.query(**queryParameters)
# return f"- Query results: {res}"
Expand All @@ -433,7 +448,8 @@ def query(self, collectionName, queryParameters):

def insert(self, collectionName, data, partitionName=None, timeout=None):
collection = self.getTargetCollection(collectionName)
result = collection.insert(data, partition_name=partitionName, timeout=timeout)
result = collection.insert(
data, partition_name=partitionName, timeout=timeout)
entitiesNum = collection.num_entities
return [result, entitiesNum]

Expand All @@ -459,7 +475,8 @@ def calcDistance(self, vectors_left, vectors_right, params=None, timeout=None):

def deleteEntities(self, expr, collectionName, partition_name=None, timeout=None):
collection = self.getTargetCollection(collectionName)
result = collection.delete(expr, partition_name=partition_name, timeout=timeout)
result = collection.delete(
expr, partition_name=partition_name, timeout=timeout)
return result

def getQuerySegmentInfo(self, collectionName, timeout=None, prettierFormat=False):
Expand All @@ -471,7 +488,8 @@ def getQuerySegmentInfo(self, collectionName, timeout=None, prettierFormat=False
if not prettierFormat or not result:
return result
firstChild = result[0]
headers = ["segmentID", "collectionID", "partitionID", "mem_size", "num_rows"]
headers = ["segmentID", "collectionID",
"partitionID", "mem_size", "num_rows"]
return tabulate(
[[getattr(_, i) for i in headers] for _ in result],
headers=headers,
Expand Down Expand Up @@ -579,6 +597,21 @@ def getCompactCollectionPlans(self, collectionName, timeout=None):
collection = self.getTargetCollection(collectionName)
return collection.get_compaction_plans(timeout=timeout)

def listCredUsers(self):
from pymilvus import list_cred_users
users = list_cred_users(self.alias)
return users

def createCredUser(self, username=None, password=None):
from pymilvus import create_credential
create_credential(username, password, self.alias)
return self.listCredUsers()

def deleteCredUser(self, username=None):
from pymilvus import delete_credential
delete_credential(username, self.alias)
return self.listCredUsers()


class Completer(object):
# COMMANDS = ['clear', 'connect', 'create', 'delete', 'describe', 'exit',
Expand All @@ -594,13 +627,13 @@ class Completer(object):
"clear": [],
"compact": [],
"connect": [],
"create": ["alias", "collection", "partition", "index"],
"delete": ["alias", "collection", "entities", "partition", "index"],
"create": ["alias", "collection", "partition", "index","user"],
"delete": ["alias", "collection", "entities", "partition", "index","user"],
"describe": ["collection", "partition", "index"],
"exit": [],
"help": [],
"import": [],
"list": ["collections", "partitions", "indexes"],
"list": ["collections", "partitions", "indexes","users"],
"load_balance": [],
"load": [],
"query": [],
Expand Down Expand Up @@ -702,11 +735,12 @@ def complete(self, text, state):
if args:
return (impl(args) + [None])[state]
return [cmd + " "][state]
results = [c + " " for c in self.COMMANDS if c.startswith(cmd)] + [None]
results = [
c + " " for c in self.COMMANDS if c.startswith(cmd)] + [None]
return results[state]


WELCOME_MSG = """
msgTemp = Template("""
__ __ _ _ ____ _ ___
Expand All @@ -715,9 +749,14 @@ def complete(self, text, state):
| | | | | |\ V /| |_| \__ \ | |___| |___ | |
|_| |_|_|_| \_/ \__,_|___/ \____|_____|___|
Milvus cli version: ${cli}
Pymilvus version: ${py}
Learn more: https://github.com/zilliztech/milvus_cli.
"""
""")


WELCOME_MSG = msgTemp.safe_substitute(cli=getPackageVersion(), py=__version__)

EXIT_MSG = "\n\nThanks for using.\nWe hope your feedback: https://github.com/zilliztech/milvus_cli/issues/new.\n\n"
Loading

0 comments on commit 70f829d

Please sign in to comment.