Hitting 2GB protobuf limit in sparse transforms #177

Closed
atait opened this issue Sep 2, 2020 · 3 comments
atait commented Sep 2, 2020

This is related to #160, which was solved, so I'm hoping the fix will be manageable. Before, it was a 10k x 10k Dense transform; now, it is a Sparse transform with 500M nonzeros. Is it possible to use the approach of #163 and apply it to these lines?

self.sparse_indices = tf.constant(

Here is the command, the trace, and some debug lines (at the `ipdb>` prompt) pointing out the offending variable:

(tf) atait@renobuntu:examples$ ipython --pdb benchmark_wattsstrogatz.py dl 50_000_000
This autorun file is: /home/atait/.ipython/profile_default/startup/0-autoreload.py
autoreload activated
INFO:nengo_dl.simulator:Running on CPU/GPU
Build finished in 0:00:06
|# Optimizing graph | 0:00:00
INFO:nengo_dl.tensor_graph:Initial plan length: 13
INFO:nengo_dl.tensor_graph:Optimized plan length: 7
INFO:nengo_dl.tensor_graph:Number of base arrays: (trainable, 0), (non_trainable, 1), (state, 9)
Optimization finished in 0:00:08
| Constructing graph # | 0:00:10
OrderedDict([('benchmark', 'watts-strogatz'), ('name', 'dl'), ('n_neurons', 50000000), ('simtime', 1.0), ('status', 'exception'), ('exception', 'Cannot create a tensor proto whose content is larger than 2GB.')])
dl, n_neurons=50000000 exception
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
~/Documents/git-research/nengo-suite/nengo-ocl/examples/benchmark_wattsstrogatz.py in <module>
    135 
    136         # -- build
--> 137         with sim_class(model, **sim_kwargs) as sim:
    138             # sim._probe_outputs[E_p] = EventBasedSignal(channels=n_neurons, numba_step=True)
    139             # profiler.add_function(sim._probe_outputs[E_p].step)

~/Documents/git-research/nengo-suite/nengo-dl/nengo_dl/simulator.py in __init__(self, network, dt, seed, model, device, unroll_simulation, minibatch_size, progress_bar)
    530             "Constructing graph", "Construction", max_value=None
    531         ) as progress:
--> 532             self._build_keras(progress)
    533 
    534         # initialize sim attributes

~/miniconda3/envs/tf/lib/python3.7/site-packages/nengo/utils/magic.py in __call__(self, *args, **kwargs)
    179                 return self.wrapper(wrapped, instance, args, kwargs)
    180             else:
--> 181                 return self.wrapper(self.__wrapped__, self.instance, args, kwargs)
    182         else:
    183             instance = getattr(self.__wrapped__, "__self__", None)

~/Documents/git-research/nengo-suite/nengo-dl/nengo_dl/simulator.py in with_self(wrapped, instance, args, kwargs)
     56     try:
     57         with tf.device(instance.tensor_graph.device):
---> 58             output = wrapped(*args, **kwargs)
     59     finally:
     60         tf.keras.backend.set_floatx(keras_dtype)

~/Documents/git-research/nengo-suite/nengo-dl/nengo_dl/simulator.py in _build_keras(self, progress)
    553         inputs = list(self.node_inputs.values()) + [n_steps]
    554 
--> 555         outputs = self.tensor_graph(inputs, stateful=self.stateful, progress=progress)
    556 
    557         self.keras_model = tf.keras.Model(

~/miniconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, *args, **kwargs)
    924     if _in_functional_construction_mode(self, inputs, args, kwargs, input_list):
    925       return self._functional_construction_call(inputs, args, kwargs,
--> 926                                                 input_list)
    927 
    928     # Maintains info about the `Layer.call` stack.

~/miniconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py in _functional_construction_call(self, inputs, args, kwargs, input_list)
   1115           try:
   1116             with ops.enable_auto_cast_variables(self._compute_dtype_object):
-> 1117               outputs = call_fn(cast_inputs, *args, **kwargs)
   1118 
   1119           except errors.OperatorNotAllowedInGraphError as e:

~/miniconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
    300   def wrapper(*args, **kwargs):
    301     with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED):
--> 302       return func(*args, **kwargs)
    303 
    304   if inspect.isfunction(func) or inspect.ismethod(func):

~/Documents/git-research/nengo-suite/nengo-dl/nengo_dl/tensor_graph.py in call(self, inputs, training, progress, stateful)
    472         # pre-build stage
    473         with progress.sub("pre-build stage", max_value=len(self.plan)) as sub:
--> 474             self.op_builder.build_pre(self.signals, build_config, sub)
    475 
    476         # build stage

~/Documents/git-research/nengo-suite/nengo-dl/nengo_dl/builder.py in build_pre(self, signals, config, progress)
     68 
     69             with self.name_scope(ops):
---> 70                 self.op_builds[ops].build_pre(signals, config)
     71 
     72             if progress is not None:

~/Documents/git-research/nengo-suite/nengo-dl/nengo_dl/op_builders.py in build_pre(self, signals, config)
    508                 tf.int32
    509                 if np.all(sparse_indices < np.iinfo(np.int32).max)
--> 510                 else tf.int64
    511             ),
    512         )

~/miniconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/framework/constant_op.py in constant(value, dtype, shape, name)
    262   """
    263   return _constant_impl(value, dtype, shape, name, verify_shape=False,
--> 264                         allow_broadcast=True)
    265 
    266 

~/miniconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/framework/constant_op.py in _constant_impl(value, dtype, shape, name, verify_shape, allow_broadcast)
    280       tensor_util.make_tensor_proto(
    281           value, dtype=dtype, shape=shape, verify_shape=verify_shape,
--> 282           allow_broadcast=allow_broadcast))
    283   dtype_value = attr_value_pb2.AttrValue(type=tensor_value.tensor.dtype)
    284   attrs = {"value": tensor_value, "dtype": dtype_value}

~/miniconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/framework/tensor_util.py in make_tensor_proto(values, dtype, shape, verify_shape, allow_broadcast)
    525     if nparray.size * nparray.itemsize >= (1 << 31):
    526       raise ValueError(
--> 527           "Cannot create a tensor proto whose content is larger than 2GB.")
    528     tensor_proto.tensor_content = nparray.tobytes()
    529     return tensor_proto

ValueError: Cannot create a tensor proto whose content is larger than 2GB.
> /home/atait/miniconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/framework/tensor_util.py(527)make_tensor_proto()
    525     if nparray.size * nparray.itemsize >= (1 << 31):
    526       raise ValueError(
--> 527           "Cannot create a tensor proto whose content is larger than 2GB.")
    528     tensor_proto.tensor_content = nparray.tobytes()
    529     return tensor_proto

ipdb> up
> /home/atait/miniconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/framework/constant_op.py(282)_constant_impl()
    280       tensor_util.make_tensor_proto(
    281           value, dtype=dtype, shape=shape, verify_shape=verify_shape,
--> 282           allow_broadcast=allow_broadcast))
    283   dtype_value = attr_value_pb2.AttrValue(type=tensor_value.tensor.dtype)
    284   attrs = {"value": tensor_value, "dtype": dtype_value}

ipdb> up
> /home/atait/miniconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/framework/constant_op.py(264)constant()
    262   """
    263   return _constant_impl(value, dtype, shape, name, verify_shape=False,
--> 264                         allow_broadcast=True)
    265 
    266 

ipdb> up
> /home/atait/Documents/git-research/nengo-suite/nengo-dl/nengo_dl/op_builders.py(510)build_pre()
    508                 tf.int32
    509                 if np.all(sparse_indices < np.iinfo(np.int32).max)
--> 510                 else tf.int64
    511             ),
    512         )

ipdb> sparse_indices.nbytes / 1e9
4.0
ipdb> 

TF 2.3.0,
Nengo DL master,
Nengo 3.0.0 (PyPI)

drasmuss commented Sep 2, 2020

I looked into this a bit, but there's not an obvious solution. The issue is that during the Keras model construction it's building a symbolic graph, which happens in TensorFlow's graph mode (which has the 2GB limit). The bottleneck in #160 occurs in a different (eager) part of the process, which is why we could resolve that by switching to eager mode.
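For reference, the failure is pure arithmetic: `make_tensor_proto` rejects any constant whose raw content reaches 2^31 bytes, and the index array of a Sparse transform with 500M nonzeros is well past that even with int32 indices. A quick check (assuming the standard (nnz, 2) index layout):

```python
import numpy as np

PROTO_LIMIT = 1 << 31  # byte cap checked in tensor_util.make_tensor_proto

def indices_nbytes(nnz, dtype=np.int32):
    # Each nonzero in a Sparse transform stores a (row, col) index pair.
    return nnz * 2 * np.dtype(dtype).itemsize

size = indices_nbytes(500_000_000)
print(size / 1e9)            # 4.0 -- matches `sparse_indices.nbytes / 1e9` above
print(size >= PROTO_LIMIT)   # True -> make_tensor_proto raises ValueError
```

So even halving the index width (int32 vs. int64) leaves this matrix at roughly twice the limit.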

You could work around this by manually splitting up your matrix into smaller pieces. Something like

weimat = wattsstrogatz_adjacencies(n_neurons)
n_split = 5
split_neurons = n_neurons // n_split
for i in range(n_split):
    # take a horizontal slice of the weight matrix
    split_weimat = weimat[i * split_neurons : (i + 1) * split_neurons]
    if sparse:
        transform = nengo.transforms.Sparse(
            (split_neurons, n_neurons),
            init=split_weimat,
        )
    else:
        transform = split_weimat.toarray()
    # one connection per slice, targeting the matching neuron range
    nengo.Connection(
        ens.neurons,
        ens.neurons[i * split_neurons : (i + 1) * split_neurons],
        transform=transform,
        synapse=0.1,
    )

(caveat: I haven't tested this thoroughly)

Note that when you do this you will also need to disable the operator merging (`nengo_dl.configure_settings(planner=nengo_dl.graph_optimizer.noop_planner)`), otherwise NengoDL will helpfully re-combine those split-up sparse matrices into one big one 😉.

You will pay some performance penalty when splitting up connections like that, but hopefully not too bad a one. I don't think there is much else we can do until TensorFlow addresses that underlying 2GB limit.

@atait
Copy link
Author

atait commented Sep 2, 2020

That's a good idea. It will work for me. I think the penalty will come at build time, not run time, and I can live with that.

I have read a bunch of Stack Overflow answers and developer-forum threads about this protobuf limit, and I still don't understand why TF uses this data structure. What are the advantages? Why aren't they outweighed by the obvious disadvantage? Are there really no 2GB constants used in machine learning? That's not a Nengo DL question, so we don't have to answer it here, but any insight from a neuromorphic-getting-into-ML perspective would be appreciated.

For Nengo DL, all I would suggest is putting that allocation within a try/except to give the user more information.
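A sketch of that suggestion, assuming a hypothetical `check_proto_size` pre-check around the allocation (the name and the message are illustrative, not NengoDL's actual API; the `tf.constant` call itself is elided so the sketch stays TF-free):

```python
import numpy as np

PROTO_LIMIT = 1 << 31  # TensorFlow's graph-constant serialization cap

def check_proto_size(sparse_indices):
    """Raise a more informative error than TensorFlow's generic message
    before the (nnz, 2) index array reaches tf.constant."""
    if sparse_indices.nbytes >= PROTO_LIMIT:
        raise ValueError(
            f"Sparse transform indices are {sparse_indices.nbytes / 1e9:.1f} GB, "
            "which exceeds TensorFlow's 2GB graph-constant limit; consider "
            "splitting the connection into smaller pieces."
        )
    return sparse_indices

check_proto_size(np.zeros((10, 2), dtype=np.int32))  # small array passes
```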

Thanks

@drasmuss

The plan is to add some documentation about the 2GB protobuf limit and possible workarounds as part of the same memory "tips and tricks" documentation discussed in #178. Going to close this so that we have a single place to track any updates there, but feel free to reopen if there is anything else that isn't addressed!
