Skip to content

Commit 78a7e3d

Browse files
authored
[PLT-1797] Vb/member upload plt 1797 (#1924)
1 parent 3af4500 commit 78a7e3d

File tree

2 files changed

+233
-0
lines changed

2 files changed

+233
-0
lines changed

docs/labelbox/index.rst

+1
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,5 @@ Labelbox Python SDK Documentation
5252
task
5353
task-queue
5454
user
55+
user-group-upload
5556
webhook
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
import json
2+
import warnings
3+
from dataclasses import dataclass
4+
from io import BytesIO
5+
from typing import List, Optional
6+
7+
import requests
8+
from lbox.exceptions import (
9+
InternalServerError,
10+
LabelboxError,
11+
ResourceNotFoundError,
12+
)
13+
14+
from labelbox import Client
15+
from labelbox.pagination import PaginatedCollection
16+
17+
18+
@dataclass
19+
class UploadReportLine:
20+
"""A single line in the CSV report of the upload members mutation.
21+
Both errors and successes are reported here.
22+
23+
Example output when using dataclasses.asdict():
24+
>>> {
25+
>>> 'lines': [
26+
>>> {
27+
>>> 'email': '...',
28+
>>> 'result': 'Not added',
29+
>>> 'error': 'User not found in the current organization'
30+
>>> },
31+
>>> {
32+
>>> 'email': '...',
33+
>>> 'result': 'Not added',
34+
>>> 'error': 'Member already exists in group'
35+
>>> },
36+
>>> {
37+
>>> 'email': '...',
38+
>>> 'result': 'Added',
39+
>>> 'error': ''
40+
>>> }
41+
>>> ]
42+
>>> }
43+
"""
44+
45+
email: str
46+
result: str
47+
error: Optional[str] = None
48+
49+
50+
@dataclass
51+
class UploadReport:
52+
"""The report of the upload members mutation."""
53+
54+
lines: List[UploadReportLine]
55+
56+
57+
class UserGroupUpload:
58+
"""Upload members to a user group."""
59+
60+
def __init__(self, client: Client):
61+
self.client = client
62+
63+
def upload_members(
64+
self, group_id: str, role: str, emails: List[str]
65+
) -> Optional[UploadReport]:
66+
"""Upload members to a user group.
67+
68+
Args:
69+
group_id: A valid ID of the user group.
70+
role: The name of the role to assign to the uploaded members as it appears in the UI on the Import Members popup.
71+
emails: The list of emails of the members to upload.
72+
73+
Returns:
74+
UploadReport: The report of the upload members mutation.
75+
76+
Raises:
77+
ResourceNotFoundError: If the role is not found.
78+
LabelboxError: If the upload fails.
79+
80+
For indicvidual email errors, the error message is available in the UploadReport.
81+
"""
82+
warnings.warn(
83+
"The upload_members for UserGroupUpload is in beta. The method name and signature may change in the future.”",
84+
)
85+
86+
if len(emails) == 0:
87+
print("No emails to upload.")
88+
return None
89+
90+
role_id = self._get_role_id(role)
91+
if role_id is None:
92+
raise ResourceNotFoundError(
93+
message="Could not find a valid role with the name provided. Please make sure the role name is correct."
94+
)
95+
96+
buffer = BytesIO()
97+
buffer.write(b"email\n") # Header row
98+
for email in emails:
99+
buffer.write(f"{email}\n".encode("utf-8"))
100+
# Reset pointer to start of stream
101+
buffer.seek(0)
102+
103+
multipart_file_field = "1"
104+
gql_file_field = "file"
105+
files = {
106+
multipart_file_field: (
107+
f"{multipart_file_field}.csv",
108+
buffer,
109+
"text/csv",
110+
)
111+
}
112+
query = """mutation ImportMembersToGroup(
113+
$roleId: ID!
114+
$file: Upload!
115+
$where: WhereUniqueIdInput!
116+
) {
117+
importUsersAsCsvToGroup(roleId: $roleId, file: $file, where: $where) {
118+
csvReport
119+
addedCount
120+
count
121+
}
122+
}
123+
"""
124+
params = {
125+
"roleId": role_id,
126+
gql_file_field: None,
127+
"where": {"id": group_id},
128+
}
129+
130+
request_data = {
131+
"operations": json.dumps(
132+
{
133+
"variables": params,
134+
"query": query,
135+
}
136+
),
137+
"map": (
138+
None,
139+
json.dumps(
140+
{multipart_file_field: [f"variables.{gql_file_field}"]}
141+
),
142+
),
143+
}
144+
145+
client = self.client
146+
headers = dict(client.connection.headers)
147+
headers.pop("Content-Type", None)
148+
request = requests.Request(
149+
"POST",
150+
client.endpoint,
151+
headers=headers,
152+
data=request_data,
153+
files=files,
154+
)
155+
156+
prepped: requests.PreparedRequest = request.prepare()
157+
158+
response = client.connection.send(prepped)
159+
160+
if response.status_code == 502:
161+
error_502 = "502 Bad Gateway"
162+
raise InternalServerError(error_502)
163+
elif response.status_code == 503:
164+
raise InternalServerError(response.text)
165+
elif response.status_code == 520:
166+
raise InternalServerError(response.text)
167+
168+
try:
169+
file_data = response.json().get("data", None)
170+
except ValueError as e: # response is not valid JSON
171+
raise LabelboxError("Failed to upload, unknown cause", e)
172+
173+
if not file_data or not file_data.get("importUsersAsCsvToGroup", None):
174+
try:
175+
errors = response.json().get("errors", [])
176+
error_msg = "Unknown error"
177+
if errors:
178+
error_msg = errors[0].get("message", "Unknown error")
179+
except Exception:
180+
error_msg = "Unknown error"
181+
raise LabelboxError("Failed to upload, message: %s" % error_msg)
182+
183+
csv_report = file_data["importUsersAsCsvToGroup"]["csvReport"]
184+
return self._parse_csv_report(csv_report)
185+
186+
def _get_role_id(self, role_name: str) -> Optional[str]:
187+
role_id = None
188+
query = """query GetAvailableUserRolesPyPi {
189+
roles(skip: %d, first: %d) {
190+
id
191+
organizationId
192+
name
193+
description
194+
}
195+
}
196+
"""
197+
198+
result = PaginatedCollection(
199+
client=self.client,
200+
query=query,
201+
params={},
202+
dereferencing=["roles"],
203+
obj_class=lambda _, data: data, # type: ignore
204+
)
205+
if result is None:
206+
raise ResourceNotFoundError(
207+
message="Could not find any valid roles."
208+
)
209+
for role in result:
210+
if role["name"].strip() == role_name.strip():
211+
role_id = role["id"]
212+
break
213+
214+
return role_id
215+
216+
def _parse_csv_report(self, csv_report: str) -> UploadReport:
217+
lines = csv_report.strip().split("\n")
218+
headers = lines[0].split(",")
219+
report_lines = []
220+
for line in lines[1:]:
221+
values = line.split(",")
222+
row = dict(zip(headers, values))
223+
report_lines.append(
224+
UploadReportLine(
225+
email=row["Email"],
226+
result=row["Result"],
227+
error=row.get(
228+
"Error"
229+
), # Using get() since error is optional
230+
)
231+
)
232+
return UploadReport(lines=report_lines)

0 commit comments

Comments
 (0)