Skip to content

Commit ee9f35c

Browse files
committed
adds init from scalar to Dataset
1 parent 70997ef commit ee9f35c

File tree

4 files changed

+82
-8
lines changed

4 files changed

+82
-8
lines changed

xarray/core/merge.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,12 @@ def collect_variables_and_indexes(
333333
indexes = {}
334334

335335
grouped: dict[Hashable, list[MergeElement]] = defaultdict(list)
336+
sizes: dict[Hashable, int] = {
337+
k: v
338+
for i in list_of_mappings
339+
for j in i.values()
340+
for k, v in getattr(j, "sizes", {}).items()
341+
}
336342

337343
def append(name, variable, index):
338344
grouped[name].append((variable, index))
@@ -355,7 +361,7 @@ def append_all(variables, indexes):
355361
indexes_.pop(name, None)
356362
append_all(coords_, indexes_)
357363

358-
variable = as_variable(variable, name=name, auto_convert=False)
364+
variable = as_variable(variable, name=name, auto_convert=False, sizes=sizes)
359365
if name in indexes:
360366
append(name, variable, indexes[name])
361367
elif variable.dims == (name,):

xarray/core/variable.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,10 @@ class MissingDimensionsError(ValueError):
8686

8787

8888
def as_variable(
89-
obj: T_DuckArray | Any, name=None, auto_convert: bool = True
89+
obj: T_DuckArray | Any,
90+
name=None,
91+
auto_convert: bool = True,
92+
sizes: Mapping | None = None,
9093
) -> Variable | IndexVariable:
9194
"""Convert an object into a Variable.
9295
@@ -127,24 +130,44 @@ def as_variable(
127130
if isinstance(obj, Variable):
128131
obj = obj.copy(deep=False)
129132
elif isinstance(obj, tuple):
133+
if len(obj) < 2:
134+
obj += (np.nan,)
130135
try:
131-
dims_, data_, *attrs = obj
136+
dims_, data_, *attrs_ = obj
132137
except ValueError as err:
133138
raise ValueError(
134-
f"Tuple {obj} is not in the form (dims, data[, attrs])"
139+
f"Tuple {obj} is not in the form (dims, [data[, attrs[, encoding]]])"
135140
) from err
136141

137142
if isinstance(data_, DataArray):
138143
raise TypeError(
139144
f"Variable {name!r}: Using a DataArray object to construct a variable is"
140145
" ambiguous, please extract the data using the .data property."
141146
)
147+
148+
if utils.is_scalar(data_, include_0d=True):
149+
try:
150+
shape_ = tuple(sizes[i] for i in dims_)
151+
except TypeError as err:
152+
message = (
153+
f"Variable {name!r}: Could not convert tuple of form "
154+
f"(dims, [data, [attrs, [encoding]]]): {obj} to Variable."
155+
)
156+
raise ValueError(message) from err
157+
except KeyError as err:
158+
message = (
159+
f"Variable {name!r}: Provide `coords` with dimension(s) {dims_} to "
160+
f"initialize with `np.full({dims_}, {data_!r})`."
161+
)
162+
raise ValueError(message) from err
163+
data_ = np.full(shape_, data_)
164+
142165
try:
143-
obj = Variable(dims_, data_, *attrs)
166+
obj = Variable(dims_, data_, *attrs_)
144167
except (TypeError, ValueError) as error:
145168
raise error.__class__(
146169
f"Variable {name!r}: Could not convert tuple of form "
147-
f"(dims, data[, attrs, encoding]): {obj} to Variable."
170+
f"(dims, [data, [attrs, [encoding]]]): {obj} to Variable."
148171
) from error
149172
elif utils.is_scalar(obj):
150173
obj = Variable([], obj)

xarray/tests/test_dataset.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,7 @@ def test_constructor(self) -> None:
476476

477477
with pytest.raises(ValueError, match=r"conflicting sizes"):
478478
Dataset({"a": x1, "b": x2})
479-
with pytest.raises(TypeError, match=r"tuple of form"):
479+
with pytest.raises(ValueError, match=r"tuple of form"):
480480
Dataset({"x": (1, 2, 3, 4, 5, 6, 7)})
481481
with pytest.raises(ValueError, match=r"already exists as a scalar"):
482482
Dataset({"x": 0, "y": ("x", [1, 2, 3])})
@@ -527,6 +527,51 @@ class Arbitrary:
527527
actual = Dataset({"x": arg})
528528
assert_identical(expected, actual)
529529

530+
def test_constructor_scalar(self) -> None:
531+
fill_value = np.nan
532+
x = np.arange(2)
533+
a = {"foo": "bar"}
534+
535+
# a suitable `coords`` argument is required
536+
with pytest.raises(ValueError):
537+
Dataset({"f": (["x"], fill_value), "x": x})
538+
539+
# 1d coordinates
540+
expected = Dataset(
541+
{
542+
"f": DataArray(fill_value, dims=["x"], coords={"x": x}),
543+
},
544+
)
545+
for actual in (
546+
Dataset({"f": (["x"], fill_value)}, coords=expected.coords),
547+
Dataset({"f": (["x"], fill_value)}, coords={"x": x}),
548+
Dataset({"f": (["x"],)}, coords=expected.coords),
549+
Dataset({"f": (["x"],)}, coords={"x": x}),
550+
):
551+
assert_identical(expected, actual)
552+
expected["f"].attrs.update(a)
553+
actual = Dataset({"f": (["x"], fill_value, a)}, coords={"x": x})
554+
assert_identical(expected, actual)
555+
556+
# 2d coordinates
557+
yx = np.arange(6).reshape(2, -1)
558+
try:
559+
# TODO(itcarroll): aux coords broken in DataArray from scalar
560+
array = DataArray(
561+
fill_value, dims=["y", "x"], coords={"lat": (["y", "x"], yx)}
562+
)
563+
expected = Dataset({"f": array})
564+
except ValueError:
565+
expected = Dataset(
566+
data_vars={"f": (["y", "x"], np.full(yx.shape, fill_value))},
567+
coords={"lat": (["y", "x"], yx)},
568+
)
569+
actual = Dataset(
570+
{"f": (["y", "x"], fill_value)},
571+
coords=expected.coords,
572+
)
573+
assert_identical(expected, actual)
574+
530575
def test_constructor_auto_align(self) -> None:
531576
a = DataArray([1, 2], [("x", [0, 1])])
532577
b = DataArray([3, 4], [("x", [1, 2])])

xarray/tests/test_variable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1220,7 +1220,7 @@ def test_as_variable(self):
12201220
)
12211221
assert_identical(expected_extra, as_variable(xarray_tuple))
12221222

1223-
with pytest.raises(TypeError, match=r"tuple of form"):
1223+
with pytest.raises(ValueError, match=r"tuple of form"):
12241224
as_variable(tuple(data))
12251225
with pytest.raises(ValueError, match=r"tuple of form"): # GH1016
12261226
as_variable(("five", "six", "seven"))

0 commit comments

Comments
 (0)