Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MNT Support numpy2 #429

Merged
merged 18 commits into from
Jul 3, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add missing file
adrinjalali committed Jun 25, 2024
commit 70349bab6534e2d8235a57e5c09feede18b5c8fe
44 changes: 44 additions & 0 deletions skops/io/old/_numpy_v1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from __future__ import annotations

from typing import Any, Optional, Sequence

import numpy as np

from .._audit import Node, get_tree
from .._utils import LoadContext, gettype

PROTOCOL = 1


class RandomGeneratorNode(Node):
def __init__(
self,
state: dict[str, Any],
load_context: LoadContext,
trusted: Optional[Sequence[str]] = None,
) -> None:
super().__init__(state, load_context, trusted)
self.children = {
"bit_generator_state": get_tree(
state["content"]["bit_generator"], load_context, trusted=trusted
)
}
self.trusted = self._get_trusted(trusted, [np.random.Generator])

def _construct(self):
# first restore the state of the bit generator
bit_generator_state = self.children["bit_generator_state"].construct()
bit_generator_cls = gettype(
"numpy.random", bit_generator_state["bit_generator"]
)
bit_generator = bit_generator_cls()
bit_generator.state = bit_generator_state

# next create the generator instance
return gettype(self.module_name, self.class_name)(bit_generator=bit_generator)


# tuples of type and function that creates the instance of that type
NODE_TYPE_MAPPING = {
("RandomGeneratorNode", PROTOCOL): RandomGeneratorNode,
}