Cloudpickle breaks local dataclass #386

Closed
froody opened this issue Jun 22, 2020 · 13 comments · Fixed by #513

Comments

@froody

froody commented Jun 22, 2020

Consider the following test; the last assertEqual fails. It fails because the check f._field_type is _FIELD inside dataclasses.fields evaluates to False. See https://github.com/python/cpython/blob/3.7/Lib/dataclasses.py#L1028

This is because f._field_type points to a different object than dataclasses._FIELD.

def testCloudpickle(self):
    import cloudpickle
    import dataclasses

    @dataclasses.dataclass
    class potato:
        drink: int

    v = potato(-42)
    ref = {"drink" : -42}

    self.assertEqual(ref, dataclasses.asdict(v))

    t2 = cloudpickle.loads(cloudpickle.dumps(potato))

    v2 = potato(-42)

    self.assertEqual(ref, dataclasses.asdict(v2))
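
The mechanism can be reproduced without cloudpickle. The following is a minimal sketch, assuming the private names dataclasses._FIELD and Field._field_type keep their current CPython meaning. It swaps a pickled copy of the sentinel into the field by hand to imitate the effect described above; that is an illustration of the identity check, not a trace of cloudpickle's actual code path.

```python
import dataclasses
import pickle


@dataclasses.dataclass
class Potato:
    drink: int


# Plain pickle rebuilds the private sentinel as a new _FIELD_BASE instance:
# same name, different object.
clone = pickle.loads(pickle.dumps(dataclasses._FIELD))
same_name = clone.name == dataclasses._FIELD.name
same_object = clone is dataclasses._FIELD

before = dataclasses.fields(Potato)

# Imitate the broken state by hand (assumption: this mirrors the bug).
Potato.__dataclass_fields__["drink"]._field_type = clone

# fields() filters with "f._field_type is _FIELD", so the field disappears,
# and asdict(), which relies on fields(), loses it as well.
after = dataclasses.fields(Potato)
as_dict_after = dataclasses.asdict(Potato(-42))
```

With the cloned sentinel in place, the field is still stored in __dataclass_fields__, but every helper built on dataclasses.fields no longer sees it.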
@pierreglaser
Member

pierreglaser commented Jun 23, 2020

Hi, thanks for the report. I see the problem. This may be worth a patch on CPython to define a custom reducer for the private _FIELD sentinel. I'm not sure yet about fixing this directly in cloudpickle, because that requires manipulating private dataclasses attributes that are likely to change without notice and break things in new Python versions. WDYT @ogrisel?

@ogrisel
Contributor

ogrisel commented Jul 2, 2020

The problem is that such a patch is probably useless for CPython: if the dataclass is defined in an importable module, the above problem does not happen. Since _FIELD is private, I see no reason to make it picklable, so the CPython devs might not be interested in adding a reducer. But maybe I am wrong.

I think we will have to deal with a cloudpickle fix that depends on a private API and rely on tests to make sure our code tracks the internal changes of the CPython standard library.

@avikchaudhuri

@pierreglaser @ogrisel did you reach a resolution on a fix? We're using cloudpickle for a new project that relies pretty heavily on dataclasses for validation and not being able to use locally defined dataclasses is causing hard constraints on the design. Any update would be much appreciated! Thank you.

@tsiq-bertram

+1 on this issue. Is there any known good workaround for this?

@jseppanen

One simple workaround is to avoid asdict and write an equivalent method instead, for example:

@dataclass
class potato:
    drink: int

    def asdict(self):
        return {k: getattr(self, k) for k in self.__dataclass_fields__}
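
One caveat worth checking before adopting this: the method is shallow, while dataclasses.asdict recurses into nested dataclasses. A small sketch of the difference, using hypothetical Inner and Outer classes that are not part of the original report:

```python
from dataclasses import asdict, dataclass


@dataclass
class Inner:
    x: int


@dataclass
class Outer:
    inner: Inner

    def asdict(self):
        # Shallow: nested dataclass instances are returned as-is.
        return {k: getattr(self, k) for k in self.__dataclass_fields__}


outer = Outer(Inner(1))
shallow = outer.asdict()  # values are the original attribute objects
deep = asdict(outer)      # stdlib helper converts nested dataclasses too
```

If nested conversion matters, a recursive variant is needed.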

@mickare

mickare commented Feb 15, 2021

This is still a major issue.

So I have some questions:

In my case this bug destroys the original dataclasses when a ray worker returns a dataclass result.

@jakubkwiatkowski

I haven't tested it extensively, but in my case this workaround is working. As @jseppanen suggested, I've replaced asdict with an equivalent function.

import copy
import dataclasses


def _is_dataclass_instance(obj):
    # Instances only: is_dataclass() is also True for dataclass types.
    return dataclasses.is_dataclass(obj) and not isinstance(obj, type)


def as_dict(obj, *, dict_factory=dict):
    if not _is_dataclass_instance(obj):
        raise TypeError("as_dict() should be called on dataclass instances")
    return as_dict_inner(obj, dict_factory)


def as_dict_inner(obj, dict_factory=dict):
    if _is_dataclass_instance(obj):
        # Walk the instance __dict__ instead of dataclasses.fields(), which
        # returns nothing once the _FIELD identity check fails.
        result = []
        for f in obj.__dict__:
            value = as_dict_inner(getattr(obj, f), dict_factory)
            result.append((f, value))
        return dict_factory(result)
    elif isinstance(obj, tuple) and hasattr(obj, '_fields'):
        # namedtuple: rebuild with positional arguments
        return type(obj)(*[as_dict_inner(v, dict_factory) for v in obj])
    elif isinstance(obj, (list, tuple)):
        return type(obj)(as_dict_inner(v, dict_factory) for v in obj)
    elif isinstance(obj, dict):
        return type(obj)((as_dict_inner(k, dict_factory),
                          as_dict_inner(v, dict_factory))
                         for k, v in obj.items())
    else:
        return copy.deepcopy(obj)

@omry

omry commented May 18, 2021

I am not 100% sure it's the same issue, but deserializing a dataclass right now is breaking the existing dataclass if it's defined in the same process (in the __main__ module)

import cloudpickle
from dataclasses import dataclass
import dataclasses


@dataclass
class Test:
    dim: int = 1

print(dataclasses.fields(Test))
_unused_deserialized_class = cloudpickle.loads(cloudpickle.dumps(Test))
print(dataclasses.fields(Test))

Output:

(Field(name='dim',type=<class 'int'>,default=1,default_factory=<dataclasses._MISSING_TYPE object at 0x7f2ecc1629d0>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),_field_type=_FIELD),)
()

@yodahuang

In the meantime, an alternative to calling dataclasses.fields is to use the __dataclass_fields__ attribute, though that returns a dict instead of a tuple.
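
A further difference beyond dict versus tuple: __dataclass_fields__ also contains ClassVar and InitVar pseudo-fields, which dataclasses.fields filters out. A short sketch on an unaffected class, where the registry attribute is a hypothetical addition for illustration:

```python
import dataclasses
from typing import ClassVar


@dataclasses.dataclass
class Test:
    dim: int = 1
    registry: ClassVar[int] = 0  # hypothetical class-level attribute


# The raw mapping keeps every annotated name, in definition order.
raw_names = tuple(Test.__dataclass_fields__)

# fields() keeps only real instance fields.
field_names = tuple(f.name for f in dataclasses.fields(Test))
```

Code that switches to __dataclass_fields__ may therefore need its own filtering if the class uses ClassVar or InitVar.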

@salim7

salim7 commented Jan 6, 2022

Here is another workaround that works like dataclasses.fields but compares sentinels by name instead of identity:

import dataclasses


def dataclass_fields(class_or_instance):
    """This function is based on dataclasses.fields(), but contains a workaround
    for https://github.com/cloudpipe/cloudpickle/issues/386
    """

    try:
        fields = getattr(class_or_instance, dataclasses._FIELDS)
    except AttributeError:
        raise TypeError('must be called with a dataclass type or instance')

    return tuple(f for f in fields.values() if f._field_type.name == dataclasses._FIELD.name)

@simon-bachhuber

I recently discovered and posted a related issue here.

Are there any workarounds that do not involve changing the way the dataclass is defined (I cannot change the source code) but can be applied "after the fact"?

@rmorshea
Contributor

rmorshea commented Aug 18, 2023

Here's a solution using a subclass of CloudPickler:

import cloudpickle
import io
import dataclasses
from dataclasses import fields, dataclass, _FIELD_BASE


def _get_dataclass_field_sentinel(name):
    """Return a sentinel object for a dataclass field."""
    return getattr(dataclasses, name)


class PatchedCloudPickler(cloudpickle.CloudPickler):
    def reducer_override(self, obj):
        """Custom reducer for dataclass field sentinels (_FIELD and friends)."""
        if isinstance(obj, _FIELD_BASE):
            return _get_dataclass_field_sentinel, (obj.name,)
        return super().reducer_override(obj)


def dumps(value, protocol=None):
    with io.BytesIO() as file:
        PatchedCloudPickler(file, protocol).dump(value)
        return file.getvalue()


@dataclass
class InClass:
    a: int
    b: int


OutClass = cloudpickle.loads(dumps(InClass))
assert fields(OutClass)

If you need to you can monkey-patch cloudpickle:

cloudpickle.cloudpickle_fast.CloudPickler = PatchedCloudPickler

If the isinstance check has negative performance implications, it might be faster to compare by identity against the individual sentinels (obj is _FIELD, obj is _FIELD_CLASSVAR, obj is _FIELD_INITVAR).
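
A sketch of the identity-based variant, written against the standard library's pickle.Pickler so it runs without cloudpickle (an assumption for illustration; Python 3.8+ is needed for reducer_override). It also swaps the helper function for pickle's "return a string" form of reduction, which saves the object as a module-level global. SentinelPickler and dumps_with_sentinels are hypothetical names.

```python
import dataclasses
import io
import pickle

# Private sentinels; these names could change in future Python versions.
_SENTINELS = (
    dataclasses._FIELD,
    dataclasses._FIELD_CLASSVAR,
    dataclasses._FIELD_INITVAR,
)


class SentinelPickler(pickle.Pickler):
    def reducer_override(self, obj):
        for sentinel in _SENTINELS:
            if obj is sentinel:
                # A string result makes pickle store a reference to the
                # global of that name in the object's module (dataclasses),
                # so loading returns the original sentinel object.
                return obj.name
        return NotImplemented  # fall back to the default behaviour


def dumps_with_sentinels(value):
    with io.BytesIO() as buffer:
        SentinelPickler(buffer).dump(value)
        return buffer.getvalue()


restored = pickle.loads(dumps_with_sentinels(dataclasses._FIELD))
plain = pickle.loads(dumps_with_sentinels({"a": 1}))
```

The same loop could be placed in a CloudPickler subclass, with super().reducer_override(obj) as the fallback, as in the snippet above.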

@rmorshea
Contributor

I posted a potential fix here: #513

Would be great to get some feedback on it.
