diff --git a/libs/labelbox/src/labelbox/schema/project.py b/libs/labelbox/src/labelbox/schema/project.py index 8c89b9ae4..48b678675 100644 --- a/libs/labelbox/src/labelbox/schema/project.py +++ b/libs/labelbox/src/labelbox/schema/project.py @@ -1118,18 +1118,19 @@ def get_label_count(self) -> int: res = self.client.execute(query_str, {"projectId": self.uid}) return res["project"]["labelCount"] - def add_model_config(self, model_config_id: str) -> str: + def add_model_config(self, model_config_id: str, response_count: Optional[int] = None) -> str: """Adds a model config to this project. Args: model_config_id (str): ID of a model config to add to this project. + response_count (Optional[int]): Number of responses to generate. If not provided, uses the default. Returns: str, ID of the project model config association. This is needed for updating and deleting associations. """ - query = """mutation CreateProjectModelConfigPyApi($projectId: ID!, $modelConfigId: ID!) { - createProjectModelConfig(input: {projectId: $projectId, modelConfigId: $modelConfigId}) { + query = """mutation CreateProjectModelConfigPyApi($projectId: ID!, $modelConfigId: ID!, $responseCount: Int) { + createProjectModelConfig(input: {projectId: $projectId, modelConfigId: $modelConfigId, responseCount: $responseCount}) { projectModelConfigId } }""" @@ -1137,6 +1138,7 @@ def add_model_config(self, model_config_id: str) -> str: params = { "projectId": self.uid, "modelConfigId": model_config_id, + "responseCount": response_count, } try: result = self.client.execute(query, params)