From 2bd1b862c61309921f35212c2de06f7ea0481308 Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Wed, 30 Jun 2021 13:20:14 +0200 Subject: [PATCH] Update initializers for typing in numpy 1.21+ Update initializers to cast values to the thinc-specific types supported by the ops. --- thinc/initializers.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/thinc/initializers.py b/thinc/initializers.py index 7aac1e926..db16a8079 100644 --- a/thinc/initializers.py +++ b/thinc/initializers.py @@ -1,4 +1,4 @@ -from typing import Callable +from typing import Callable, cast import numpy from .backends import Ops @@ -16,7 +16,7 @@ def lecun_normal_init(ops: Ops, shape: Shape) -> FloatsXd: scale = numpy.sqrt(1.0 / shape[1]) - return ops.asarray_f(numpy.random.normal(0, scale, shape)) + return ops.asarray_f(cast(FloatsXd, numpy.random.normal(0, scale, shape))) @registry.initializers("lecun_normal_init.v1") @@ -26,7 +26,7 @@ def configure_lecun_normal_init() -> Callable[[Shape], FloatsXd]: def he_normal_init(ops: Ops, shape: Shape) -> FloatsXd: scale = numpy.sqrt(2.0 / shape[1]) - return ops.asarray_f(numpy.random.normal(0, scale, shape)) + return ops.asarray_f(cast(FloatsXd, numpy.random.normal(0, scale, shape))) @registry.initializers("he_normal_init.v1") @@ -36,7 +36,7 @@ def configure_he_normal_init() -> Callable[[Shape], FloatsXd]: def glorot_normal_init(ops: Ops, shape: Shape) -> FloatsXd: scale = numpy.sqrt(2.0 / (shape[1] + shape[0])) - return ops.asarray_f(numpy.random.normal(0, scale, shape)) + return ops.asarray_f(cast(FloatsXd, numpy.random.normal(0, scale, shape))) @registry.initializers("glorot_normal_init.v1") @@ -46,7 +46,7 @@ def configure_glorot_normal_init() -> Callable[[Shape], FloatsXd]: def he_uniform_init(ops: Ops, shape: Shape) -> FloatsXd: scale = numpy.sqrt(6.0 / shape[1]) - return ops.asarray_f(numpy.random.uniform(-scale, scale, shape)) + return ops.asarray_f(cast(FloatsXd, numpy.random.uniform(-scale, scale, shape))) @registry.initializers("he_uniform_init.v1") @@ -56,7 +56,7 @@ def configure_he_uniform_init() -> Callable[[Shape], FloatsXd]: def lecun_uniform_init(ops: Ops, shape: Shape) -> FloatsXd: scale = numpy.sqrt(3.0 / shape[1]) - return ops.asarray_f(numpy.random.uniform(-scale, scale, shape)) + return ops.asarray_f(cast(FloatsXd, numpy.random.uniform(-scale, scale, shape))) @registry.initializers("lecun_uniform_init.v1") @@ -66,7 +66,7 @@ def configure_lecun_uniform_init() -> Callable[[Shape], FloatsXd]: def glorot_uniform_init(ops: Ops, shape: Shape) -> FloatsXd: scale = numpy.sqrt(6.0 / (shape[0] + shape[1])) - return ops.asarray_f(numpy.random.uniform(-scale, scale, shape)) + return ops.asarray_f(cast(FloatsXd, numpy.random.uniform(-scale, scale, shape))) @registry.initializers("glorot_uniform_init.v1") @@ -87,7 +87,7 @@ def uniform_init( ops: Ops, shape: Shape, *, lo: float = -0.1, hi: float = 0.1 ) -> FloatsXd: values = numpy.random.uniform(lo, hi, shape) - return ops.asarray_f(values.astype("float32")) + return ops.asarray_f(cast(FloatsXd, values.astype("float32"))) @registry.initializers("uniform_init.v1") @@ -99,7 +99,7 @@ def configure_uniform_init( def normal_init(ops: Ops, shape: Shape, *, mean: int = 0) -> FloatsXd: size = int(ops.xp.prod(ops.xp.asarray(shape))) - inits = numpy.random.normal(scale=mean, size=size).astype("float32") + inits = cast(FloatsXd, numpy.random.normal(scale=mean, size=size).astype("float32")) inits = ops.reshape_f(inits, shape) return ops.asarray_f(inits)