-
Notifications
You must be signed in to change notification settings - Fork 4.3k
/
adlsgen2setup.py
198 lines (176 loc) · 8.9 KB
/
adlsgen2setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import argparse
import asyncio
import json
import logging
import os
from typing import Any, Optional
import aiohttp
from azure.core.credentials_async import AsyncTokenCredential
from azure.identity.aio import AzureDeveloperCliCredential
from azure.storage.filedatalake.aio import (
DataLakeDirectoryClient,
DataLakeServiceClient,
)
from load_azd_env import load_azd_env
logger = logging.getLogger("scripts")
class AdlsGen2Setup:
    """
    Sets up a Data Lake Storage Gen 2 account with sample data and access control.

    Reads a JSON "data access control" description (see sampleacls.json) that lists
    groups to create in Microsoft Entra, directories to create in the filesystem,
    files to upload, and which groups/oids get read access to which directories.
    """

    def __init__(
        self,
        data_directory: str,
        storage_account_name: str,
        filesystem_name: str,
        security_enabled_groups: bool,
        data_access_control_format: dict[str, Any],
        credentials: AsyncTokenCredential,
    ):
        """
        Initializes the command

        Parameters
        ----------
        data_directory
            Directory where sample files are located
        storage_account_name
            Name of the Data Lake Storage Gen 2 account to use
        filesystem_name
            Name of the container / filesystem in the Data Lake Storage Gen 2 account to use
        security_enabled_groups
            When creating groups in Microsoft Entra, whether or not to make them security enabled
        data_access_control_format
            File describing how to create groups, upload files with access control. See the sampleacls.json for the format of this file
        credentials
            Async credential used both for storage access and for Microsoft Graph calls
        """
        self.data_directory = data_directory
        self.storage_account_name = storage_account_name
        self.filesystem_name = filesystem_name
        self.credentials = credentials
        self.security_enabled_groups = security_enabled_groups
        self.data_access_control_format = data_access_control_format
        # Lazily-populated Authorization header for Microsoft Graph; filled on first
        # create_or_get_group call. NOTE(review): the token is never refreshed, so a
        # very long run could outlive it — acceptable for a setup script.
        self.graph_headers: Optional[dict[str, str]] = None

    async def run(self) -> None:
        """Create the filesystem, groups, directories and files, then apply ACLs."""
        async with self.create_service_client() as service_client:
            logger.info("Ensuring %s exists...", self.filesystem_name)
            async with service_client.get_file_system_client(self.filesystem_name) as filesystem_client:
                if not await filesystem_client.exists():
                    await filesystem_client.create_file_system()

                logger.info("Creating groups...")
                groups: dict[str, str] = {}
                for group in self.data_access_control_format["groups"]:
                    group_id = await self.create_or_get_group(group)
                    groups[group] = group_id

                logger.info("Ensuring directories exist...")
                directories: dict[str, DataLakeDirectoryClient] = {}
                try:
                    for directory in self.data_access_control_format["directories"].keys():
                        # "/" maps to the filesystem root, which has no public accessor
                        # on the SDK client, hence the private helper.
                        directory_client = (
                            await filesystem_client.create_directory(directory)
                            if directory != "/"
                            else filesystem_client._get_root_directory_client()
                        )
                        directories[directory] = directory_client

                    logger.info("Uploading files...")
                    for file, file_info in self.data_access_control_format["files"].items():
                        directory = file_info["directory"]
                        if directory not in directories:
                            logger.error("File %s has unknown directory %s, exiting...", file, directory)
                            return
                        await self.upload_file(
                            directory_client=directories[directory], file_path=os.path.join(self.data_directory, file)
                        )

                    logger.info("Setting access control...")
                    for directory, access_control in self.data_access_control_format["directories"].items():
                        directory_client = directories[directory]
                        if "groups" in access_control:
                            for group_name in access_control["groups"]:
                                if group_name not in groups:
                                    logger.error(
                                        "Directory %s has unknown group %s in access control list, exiting",
                                        directory,
                                        group_name,
                                    )
                                    return
                                await directory_client.update_access_control_recursive(
                                    acl=f"group:{groups[group_name]}:r-x"
                                )
                        if "oids" in access_control:
                            for oid in access_control["oids"]:
                                await directory_client.update_access_control_recursive(acl=f"user:{oid}:r-x")
                finally:
                    # Directory clients hold network resources; close them even on error.
                    for directory_client in directories.values():
                        await directory_client.close()

    def create_service_client(self) -> DataLakeServiceClient:
        """Build a DataLakeServiceClient for the configured storage account."""
        return DataLakeServiceClient(
            account_url=f"https://{self.storage_account_name}.dfs.core.windows.net", credential=self.credentials
        )

    async def upload_file(self, directory_client: DataLakeDirectoryClient, file_path: str) -> None:
        """Upload a local file into the given directory, overwriting any existing blob."""
        with open(file=file_path, mode="rb") as f:
            file_client = directory_client.get_file_client(file=os.path.basename(file_path))
            await file_client.upload_data(f, overwrite=True)

    async def create_or_get_group(self, group_name: str) -> str:
        """
        Return the object id of the Microsoft Entra group named group_name,
        creating it (as a Unified group) if it does not exist.

        Raises
        ------
        Exception
            If the Graph search or create request fails (body is the Graph error payload).
        """
        group_id = None
        if not self.graph_headers:
            token_result = await self.credentials.get_token("https://graph.microsoft.com/.default")
            self.graph_headers = {"Authorization": f"Bearer {token_result.token}"}
        async with aiohttp.ClientSession(headers=self.graph_headers) as session:
            logger.info("Searching for group %s...", group_name)
            # OData string literals escape a single quote by doubling it; without this,
            # a group name containing ' produces a malformed $filter.
            escaped_name = group_name.replace("'", "''")
            async with session.get(
                f"https://graph.microsoft.com/v1.0/groups?$select=id&$top=1&$filter=displayName eq '{escaped_name}'"
            ) as response:
                content = await response.json()
                if response.status != 200:
                    raise Exception(content)
                if len(content["value"]) == 1:
                    group_id = content["value"][0]["id"]
            if not group_id:
                logger.info("Could not find group %s, creating...", group_name)
                group = {
                    "displayName": group_name,
                    "securityEnabled": self.security_enabled_groups,
                    "groupTypes": ["Unified"],
                    # If Unified does not work for you, then you may need the following settings instead:
                    # "mailEnabled": False,
                    # "mailNickname": group_name,
                }
                async with session.post("https://graph.microsoft.com/v1.0/groups", json=group) as response:
                    content = await response.json()
                    if response.status != 201:
                        raise Exception(content)
                    group_id = content["id"]
        logger.info("Group %s ID %s", group_name, group_id)
        return group_id
async def main(args: Any) -> None:
    """Load azd environment, read the ACL description file, and run the setup."""
    load_azd_env()

    # Fail fast with a clear message before touching any Azure resources.
    if not os.getenv("AZURE_ADLS_GEN2_STORAGE_ACCOUNT"):
        raise Exception("AZURE_ADLS_GEN2_STORAGE_ACCOUNT must be set to continue")

    async with AzureDeveloperCliCredential() as azure_credential:
        with open(args.data_access_control) as acl_file:
            acl_description = json.load(acl_file)
        setup = AdlsGen2Setup(
            data_directory=args.data_directory,
            storage_account_name=os.environ["AZURE_ADLS_GEN2_STORAGE_ACCOUNT"],
            filesystem_name="gptkbcontainer",
            security_enabled_groups=args.create_security_enabled_groups,
            credentials=azure_credential,
            data_access_control_format=acl_description,
        )
        await setup.run()
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Upload sample data to a Data Lake Storage Gen2 account and associate sample access control lists with it using sample groups",
        # The flag below is store_true, so the example must not show a value argument.
        epilog="Example: ./scripts/adlsgen2setup.py ./data --data-access-control ./scripts/sampleacls.json --create-security-enabled-groups",
    )
    parser.add_argument("data_directory", help="Data directory that contains sample PDFs")
    parser.add_argument(
        "--create-security-enabled-groups",
        required=False,
        action="store_true",
        help="Whether or not the sample groups created are security enabled in Microsoft Entra",
    )
    parser.add_argument(
        "--data-access-control", required=True, help="JSON file describing access control for the sample data"
    )
    parser.add_argument("--verbose", "-v", required=False, action="store_true", help="Verbose output")
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig()
        logging.getLogger().setLevel(logging.INFO)

    asyncio.run(main(args))