Skip to content
This repository has been archived by the owner on Jun 14, 2023. It is now read-only.

TypeError: 'NoneType' object is not callable #21

Open
putuoka opened this issue Jul 26, 2022 · 0 comments
Open

TypeError: 'NoneType' object is not callable #21

putuoka opened this issue Jul 26, 2022 · 0 comments

Comments

@putuoka
Copy link

putuoka commented Jul 26, 2022


class FairSlipLoaderBase(BaseMmcLoader):
    """
    SLIP models via https://github.com/facebookresearch/SLIP

    Loader that installs the facebookresearch/SLIP repo via napm, builds the
    requested SLIP model, loads its checkpoint weights and wraps it in a
    ``MultiModalComparator`` with TEXT and IMAGE modalities.
    """
    def __init__(
        self,
        id,
        architecture,
    ):
        """
        Args:
            id: model identifier; ``load`` passes it to
                ``model_factory_from_id`` and ``url_from_id``.
            architecture: architecture label, stored on ``self.architecture``.
        """
        self.architecture = architecture
        self.publisher = 'facebookresearch'
        self.id = id
        self.modalities = (TEXT, IMAGE)

    def _napm_install(self):
        """
        Make the facebookresearch/SLIP repo importable and expose its model
        factories.

        A plain ``from SLIP.models import ...`` inside this method only binds
        local names that are discarded on return, so nothing outside the
        method could ever see them. The factories are therefore bound as
        module-level globals here.
        NOTE(review): presumably ``model_factory_from_id`` resolves these
        names from module globals; if they stay unset it yields None and
        ``load`` fails with "'NoneType' object is not callable" — confirm.
        """
        global SLIP_VITS16, SLIP_VITB16, SLIP_VITL16
        logger.debug('using napm to "install" facebookresearch/SLIP')
        url = "https://github.com/facebookresearch/SLIP"
        napm.pseudoinstall_git_repo(url, env_name='mmc', add_install_dir_to_path=True)
        napm.populate_pythonpaths('mmc')
        # Import the module, then assign explicitly: the language reference
        # disallows names declared `global` from being import targets.
        from SLIP import models as slip_models
        SLIP_VITS16 = slip_models.SLIP_VITS16
        SLIP_VITB16 = slip_models.SLIP_VITB16
        SLIP_VITL16 = slip_models.SLIP_VITL16

    def load(self, device=DEVICE):
        """
        Returns the MMC associated with this loader.

        Args:
            device: device the checkpoint is loaded onto and the model and
                preprocessed images are moved to.

        Returns:
            A ``MultiModalComparator`` with TEXT (SLIP ``SimpleTokenizer`` +
            ``model.encode_text``) and IMAGE (``val_transform`` +
            ``model.encode_image``) modalities registered.

        Raises:
            RuntimeError: if no model factory can be resolved for ``self.id``.
        """
        self._napm_install()

        model_factory = model_factory_from_id(self.id)
        logger.debug(f"model_factory: {model_factory}")
        if model_factory is None:
            # Fail loudly here instead of the opaque
            # "'NoneType' object is not callable" at the factory call below.
            raise RuntimeError(
                f"No SLIP model factory found for id {self.id!r}; "
                "check the model id and that SLIP.models imported correctly."
            )
        ckpt_url = url_from_id(self.id)
        ckpt = fetch_weights(
            url=ckpt_url,
            namespace='fair_slip',
            device=device,
            )
        # Only forward the embedding-size args the factory accepts, and only
        # if the checkpoint recorded them.
        d_args = vars(ckpt['args'])
        kwargs = {k: d_args[k] for k in ('ssl_emb_dim', 'ssl_mlp_dim') if k in d_args}
        logger.debug(kwargs)
        fix_param_names(ckpt)
        model = model_factory(**kwargs)
        model.load_state_dict(ckpt['state_dict'], strict=True)
        model = model.eval().to(device)

        from SLIP.tokenizer import SimpleTokenizer
        tokenizer = SimpleTokenizer()

        def preprocess_image_extended(*args, **kwargs):
            """Apply ``val_transform``, add a batch dim to 3-D output, move to ``device``."""
            x = val_transform(*args, **kwargs)
            if x.ndim == 3:
                logger.debug("adding batch dimension")
                x = x.unsqueeze(0)
            return x.to(device)

        mmc = MultiModalComparator(name=str(self), device=device)
        mmc.register_modality(modality=TEXT, projector=model.encode_text, preprocessor=tokenizer)
        mmc.register_modality(modality=IMAGE, projector=model.encode_image, preprocessor=preprocess_image_extended)
        # Keep a handle on the underlying model on the comparator.
        mmc._model = model
        return mmc

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant