diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index 6fd44b0382..ecfa5ff26d 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -596,13 +596,13 @@ def join_nonshared_inputs( raise ValueError("Empty list of input variables.") raveled_inputs = pt.concatenate([var.ravel() for var in inputs]) + size = sum(point[var_name].size for var_name in point) if not make_inputs_shared: - tensor_type = raveled_inputs.type - joined_inputs = tensor_type("joined_inputs") + joined_inputs = pt.tensor("joined_inputs", shape=(size,), dtype=raveled_inputs.dtype) else: joined_values = np.concatenate([point[var.name].ravel() for var in inputs]) - joined_inputs = pytensor.shared(joined_values, "joined_inputs") + joined_inputs = pytensor.shared(joined_values, "joined_inputs", shape=(size,)) if pytensor.config.compute_test_value != "off": joined_inputs.tag.test_value = raveled_inputs.tag.test_value