|
16 | 16 | from .array import *
|
17 | 17 | from .util import *
|
18 | 18 | from .util import _is_number
|
| 19 | +from .random import randu, randn, set_seed, get_seed |
19 | 20 |
|
20 | 21 | def constant(val, d0, d1=None, d2=None, d3=None, dtype=Dtype.f32):
|
21 | 22 | """
|
@@ -186,105 +187,6 @@ def iota(d0, d1=None, d2=None, d3=None, dim=-1, tile_dims=None, dtype=Dtype.f32)
|
186 | 187 | 4, ct.pointer(tdims), dtype.value))
|
187 | 188 | return out
|
188 | 189 |
|
189 |
| -def randu(d0, d1=None, d2=None, d3=None, dtype=Dtype.f32): |
190 |
| - """ |
191 |
| - Create a multi dimensional array containing values from a uniform distribution. |
192 |
| -
|
193 |
| - Parameters |
194 |
| - ---------- |
195 |
| - d0 : int. |
196 |
| - Length of first dimension. |
197 |
| -
|
198 |
| - d1 : optional: int. default: None. |
199 |
| - Length of second dimension. |
200 |
| -
|
201 |
| - d2 : optional: int. default: None. |
202 |
| - Length of third dimension. |
203 |
| -
|
204 |
| - d3 : optional: int. default: None. |
205 |
| - Length of fourth dimension. |
206 |
| -
|
207 |
| - dtype : optional: af.Dtype. default: af.Dtype.f32. |
208 |
| - Data type of the array. |
209 |
| -
|
210 |
| - Returns |
211 |
| - ------- |
212 |
| -
|
213 |
| - out : af.Array |
214 |
| - Multi dimensional array whose elements are sampled uniformly between [0, 1]. |
215 |
| - - If d1 is None, `out` is 1D of size (d0,). |
216 |
| - - If d1 is not None and d2 is None, `out` is 2D of size (d0, d1). |
217 |
| - - If d1 and d2 are not None and d3 is None, `out` is 3D of size (d0, d1, d2). |
218 |
| - - If d1, d2, d3 are all not None, `out` is 4D of size (d0, d1, d2, d3). |
219 |
| - """ |
220 |
| - out = Array() |
221 |
| - dims = dim4(d0, d1, d2, d3) |
222 |
| - |
223 |
| - safe_call(backend.get().af_randu(ct.pointer(out.arr), 4, ct.pointer(dims), dtype.value)) |
224 |
| - return out |
225 |
| - |
226 |
| -def randn(d0, d1=None, d2=None, d3=None, dtype=Dtype.f32): |
227 |
| - """ |
228 |
| - Create a multi dimensional array containing values from a normal distribution. |
229 |
| -
|
230 |
| - Parameters |
231 |
| - ---------- |
232 |
| - d0 : int. |
233 |
| - Length of first dimension. |
234 |
| -
|
235 |
| - d1 : optional: int. default: None. |
236 |
| - Length of second dimension. |
237 |
| -
|
238 |
| - d2 : optional: int. default: None. |
239 |
| - Length of third dimension. |
240 |
| -
|
241 |
| - d3 : optional: int. default: None. |
242 |
| - Length of fourth dimension. |
243 |
| -
|
244 |
| - dtype : optional: af.Dtype. default: af.Dtype.f32. |
245 |
| - Data type of the array. |
246 |
| -
|
247 |
| - Returns |
248 |
| - ------- |
249 |
| -
|
250 |
| - out : af.Array |
251 |
| - Multi dimensional array whose elements are sampled from a normal distribution with mean 0 and sigma of 1. |
252 |
| - - If d1 is None, `out` is 1D of size (d0,). |
253 |
| - - If d1 is not None and d2 is None, `out` is 2D of size (d0, d1). |
254 |
| - - If d1 and d2 are not None and d3 is None, `out` is 3D of size (d0, d1, d2). |
255 |
| - - If d1, d2, d3 are all not None, `out` is 4D of size (d0, d1, d2, d3). |
256 |
| - """ |
257 |
| - |
258 |
| - out = Array() |
259 |
| - dims = dim4(d0, d1, d2, d3) |
260 |
| - |
261 |
| - safe_call(backend.get().af_randn(ct.pointer(out.arr), 4, ct.pointer(dims), dtype.value)) |
262 |
| - return out |
263 |
| - |
264 |
| -def set_seed(seed=0): |
265 |
| - """ |
266 |
| - Set the seed for the random number generator. |
267 |
| -
|
268 |
| - Parameters |
269 |
| - ---------- |
270 |
| - seed: int. |
271 |
| - Seed for the random number generator |
272 |
| - """ |
273 |
| - safe_call(backend.get().af_set_seed(ct.c_ulonglong(seed))) |
274 |
| - |
275 |
| -def get_seed(): |
276 |
| - """ |
277 |
| - Get the seed for the random number generator. |
278 |
| -
|
279 |
| - Returns |
280 |
| - ---------- |
281 |
| - seed: int. |
282 |
| - Seed for the random number generator |
283 |
| - """ |
284 |
| - seed = ct.c_ulonglong(0) |
285 |
| - safe_call(backend.get().af_get_seed(ct.pointer(seed))) |
286 |
| - return seed.value |
287 |
| - |
288 | 190 | def identity(d0, d1, d2=None, d3=None, dtype=Dtype.f32):
|
289 | 191 | """
|
290 | 192 | Create an identity matrix or batch of identity matrices.
|
|
0 commit comments