9
9
from typing import List , Optional , Pattern
10
10
11
11
RELEASE_PATTERN = re .compile (r"release_[0-9]+(_docs)*" )
12
+ # This matches the various ways to invoke pip: "pip", "pip3", "python -m pip"
13
+ # It matches "mlagents" and "mlagents_envs", accessible as group "package"
14
+ # and optionally matches the version, e.g. "==1.2.3"
15
+ PIP_INSTALL_PATTERN = re .compile (
16
+ r"(python -m )?pip3* install (?P<package>mlagents(_envs)?)(==[0-9]\.[0-9]\.[0-9](\.dev[0-9]+)?)?"
17
+ )
12
18
TRAINER_INIT_FILE = "ml-agents/mlagents/trainers/__init__.py"
13
19
14
20
MATCH_ANY = re .compile (r"(?s).*" )
15
21
# Filename -> regex list to allow specific lines.
16
- # To allow everything in the file, use None for the value
22
+ # To allow everything in the file (effectively skipping it) , use MATCH_ANY for the value
17
23
ALLOW_LIST = {
18
24
# Previous release table
19
25
"README.md" : re .compile (r"\*\*(Verified Package ([0-9]\.?)*|Release [0-9]+)\*\*" ),
24
30
}
25
31
26
32
27
- def test_pattern ():
33
+ def test_release_pattern ():
28
34
# Just some sanity check that the regex works as expected.
29
- assert RELEASE_PATTERN .search (
30
- "https://github.com/Unity-Technologies/ml-agents/blob/release_4_docs/Food.md"
31
- )
32
- assert RELEASE_PATTERN .search (
33
- "https://github.com/Unity-Technologies/ml-agents/blob/release_4/Foo.md"
34
- )
35
- assert RELEASE_PATTERN .search (
36
- "git clone --branch release_4 https://github.com/Unity-Technologies/ml-agents.git"
37
- )
38
- assert RELEASE_PATTERN .search (
39
- "https://github.com/Unity-Technologies/ml-agents/blob/release_123_docs/Foo.md"
40
- )
41
- assert RELEASE_PATTERN .search (
42
- "https://github.com/Unity-Technologies/ml-agents/blob/release_123/Foo.md"
43
- )
44
- assert not RELEASE_PATTERN .search (
45
- "https://github.com/Unity-Technologies/ml-agents/blob/latest_release/docs/Foo.md"
35
+ for s , expected in [
36
+ (
37
+ "https://github.com/Unity-Technologies/ml-agents/blob/release_4_docs/Food.md" ,
38
+ True ,
39
+ ),
40
+ ("https://github.com/Unity-Technologies/ml-agents/blob/release_4/Foo.md" , True ),
41
+ (
42
+ "git clone --branch release_4 https://github.com/Unity-Technologies/ml-agents.git" ,
43
+ True ,
44
+ ),
45
+ (
46
+ "https://github.com/Unity-Technologies/ml-agents/blob/release_123_docs/Foo.md" ,
47
+ True ,
48
+ ),
49
+ (
50
+ "https://github.com/Unity-Technologies/ml-agents/blob/release_123/Foo.md" ,
51
+ True ,
52
+ ),
53
+ (
54
+ "https://github.com/Unity-Technologies/ml-agents/blob/latest_release/docs/Foo.md" ,
55
+ False ,
56
+ ),
57
+ ]:
58
+ assert bool (RELEASE_PATTERN .search (s )) is expected
59
+
60
+ print ("release tests OK!" )
61
+
62
+
63
+ def test_pip_pattern ():
64
+ # Just some sanity check that the regex works as expected.
65
+ for s , expected in [
66
+ ("pip install mlagents" , True ),
67
+ ("pip3 install mlagents" , True ),
68
+ ("python -m pip install mlagents" , True ),
69
+ ("python -m pip install mlagents==1.2.3" , True ),
70
+ ("python -m pip install mlagents_envs==1.2.3" , True ),
71
+ ]:
72
+ assert bool (PIP_INSTALL_PATTERN .search (s )) is expected
73
+
74
+ sub_expected = "Try running rm -rf / to install"
75
+ assert sub_expected == PIP_INSTALL_PATTERN .sub (
76
+ "rm -rf /" , "Try running python -m pip install mlagents==1.2.3 to install"
46
77
)
47
- print ("tests OK!" )
78
+
79
+ print ("pip tests OK!" )
80
+
81
+
82
+ def update_pip_install_line (line , package_verion ):
83
+ match = PIP_INSTALL_PATTERN .search (line )
84
+ package_name = match .group ("package" )
85
+ replacement_version = f"python -m pip install { package_name } =={ package_verion } "
86
+ updated = PIP_INSTALL_PATTERN .sub (replacement_version , line )
87
+ return updated
48
88
49
89
50
90
def git_ls_files () -> List [str ]:
@@ -74,8 +114,28 @@ def get_release_tag() -> Optional[str]:
74
114
raise RuntimeError ("Can't determine release tag" )
75
115
76
116
117
+ def get_python_package_version () -> str :
118
+ """
119
+ Returns the mlagents python package.
120
+ :return:
121
+ """
122
+ with open (TRAINER_INIT_FILE ) as f :
123
+ for line in f :
124
+ if "__version__" in line :
125
+ lhs , equals_string , rhs = line .strip ().partition (" = " )
126
+ # Evaluate the right hand side of the expression
127
+ return ast .literal_eval (rhs )
128
+ # If we couldn't find the release tag, raise an exception
129
+ # (since we can't return None here)
130
+ raise RuntimeError ("Can't determine python package version" )
131
+
132
+
77
133
def check_file (
78
- filename : str , global_allow_pattern : Pattern , release_tag : str
134
+ filename : str ,
135
+ release_tag_pattern : Pattern ,
136
+ release_tag : str ,
137
+ pip_allow_pattern : Pattern ,
138
+ package_version : str ,
79
139
) -> List [str ]:
80
140
"""
81
141
Validate a single file and return any offending lines.
@@ -90,21 +150,37 @@ def check_file(
90
150
allow_list_pattern = ALLOW_LIST .get (filename , None )
91
151
with open (filename ) as f :
92
152
for line in f :
93
- keep_line = True
94
- keep_line = not RELEASE_PATTERN .search (line )
95
- keep_line |= global_allow_pattern .search (line ) is not None
96
- keep_line |= (
97
- allow_list_pattern is not None
153
+ # Does it contain anything of the form release_123
154
+ has_release_pattern = RELEASE_PATTERN .search (line ) is not None
155
+ # Does it contain this particular release, e.g. release_42 or release_42_docs
156
+ has_release_tag_pattern = (
157
+ release_tag_pattern .search (line ) is not None
158
+ )
159
+ # Does it contain the allow list pattern for the file (if there is one)
160
+ has_allow_list_pattern = (
161
+ allow_list_pattern
98
162
and allow_list_pattern .search (line ) is not None
99
163
)
100
164
101
- if keep_line :
165
+ pip_install_ok = (
166
+ has_allow_list_pattern
167
+ or PIP_INSTALL_PATTERN .search (line ) is None
168
+ or pip_allow_pattern .search (line ) is not None
169
+ )
170
+
171
+ release_tag_ok = (
172
+ not has_release_pattern
173
+ or has_release_tag_pattern
174
+ or has_allow_list_pattern
175
+ )
176
+
177
+ if release_tag_ok and pip_install_ok :
102
178
new_file .write (line )
103
179
else :
104
180
bad_lines .append (f"{ filename } : { line } " )
105
- new_file . write (
106
- re . sub ( r"release_[0-9]+" , fr" { release_tag } " , line )
107
- )
181
+ new_line = re . sub ( r"release_[0-9]+" , fr" { release_tag } " , line )
182
+ new_line = update_pip_install_line ( new_line , package_version )
183
+ new_file . write ( new_line )
108
184
if bad_lines :
109
185
if os .path .exists (filename ):
110
186
os .remove (filename )
@@ -113,17 +189,28 @@ def check_file(
113
189
return bad_lines
114
190
115
191
116
- def check_all_files (allow_pattern : Pattern , release_tag : str ) -> List [str ]:
192
+ def check_all_files (
193
+ release_allow_pattern : Pattern ,
194
+ release_tag : str ,
195
+ pip_allow_pattern : Pattern ,
196
+ package_version : str ,
197
+ ) -> List [str ]:
117
198
"""
118
199
Validate all files tracked by git.
119
- :param allow_pattern :
200
+ :param release_allow_pattern :
120
201
"""
121
202
bad_lines = []
122
203
file_types = {".py" , ".md" , ".cs" }
123
204
for file_name in git_ls_files ():
124
205
if "localized" in file_name or os .path .splitext (file_name )[1 ] not in file_types :
125
206
continue
126
- bad_lines += check_file (file_name , allow_pattern , release_tag )
207
+ bad_lines += check_file (
208
+ file_name ,
209
+ release_allow_pattern ,
210
+ release_tag ,
211
+ pip_allow_pattern ,
212
+ package_version ,
213
+ )
127
214
return bad_lines
128
215
129
216
@@ -133,9 +220,16 @@ def main():
133
220
print ("Release tag is None, exiting" )
134
221
sys .exit (0 )
135
222
223
+ package_version = get_python_package_version ()
136
224
print (f"Release tag: { release_tag } " )
137
- allow_pattern = re .compile (f"{ release_tag } (_docs)*" )
138
- bad_lines = check_all_files (allow_pattern , release_tag )
225
+ print (f"Python package version: { package_version } " )
226
+ release_allow_pattern = re .compile (f"{ release_tag } (_docs)?" )
227
+ pip_allow_pattern = re .compile (
228
+ f"python -m pip install mlagents(_envs)?=={ package_version } "
229
+ )
230
+ bad_lines = check_all_files (
231
+ release_allow_pattern , release_tag , pip_allow_pattern , package_version
232
+ )
139
233
if bad_lines :
140
234
for line in bad_lines :
141
235
print (line )
@@ -151,5 +245,6 @@ def main():
151
245
152
246
if __name__ == "__main__" :
153
247
if "--test" in sys .argv :
154
- test_pattern ()
248
+ test_release_pattern ()
249
+ test_pip_pattern ()
155
250
main ()
0 commit comments