Skip to content

Commit

Permalink
Weights data type now also specifies window data type
Browse files Browse the repository at this point in the history
  • Loading branch information
the-lay committed Jan 6, 2022
1 parent a731aab commit 7d75102
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 10 deletions.
12 changes: 6 additions & 6 deletions docs/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -2000,7 +2000,7 @@ <h6 id="return">Return</h6>
<span class="sd"> data_dtype (np.dtype): Specify data type for data buffer that stores cumulative result.</span>
<span class="sd"> Default is `np.float32`.</span>

<span class="sd"> weights_dtype (np.dtype): Specify data type for weights buffer that stores cumulative weights.</span>
<span class="sd"> weights_dtype (np.dtype): Specify data type for weights buffer that stores cumulative weights and window array.</span>
<span class="sd"> If you don&#39;t need precision but would rather save memory you can use `np.float16`.</span>
<span class="sd"> Likewise, on the opposite, you can use `np.float64`.</span>
<span class="sd"> Default is `np.float32`.</span>
Expand Down Expand Up @@ -2039,7 +2039,7 @@ <h6 id="return">Return</h6>
<span class="sd"> np.ndarray: n-dimensional window of the given shape and function</span>
<span class="sd"> &quot;&quot;&quot;</span>

<span class="n">w</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">shape</span><span class="p">)</span>
<span class="n">w</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">weights_dtype</span><span class="p">)</span>
<span class="n">overlap</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">tiler</span><span class="o">.</span><span class="n">_tile_overlap</span>
<span class="k">for</span> <span class="n">axis</span><span class="p">,</span> <span class="n">length</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">shape</span><span class="p">):</span>
<span class="k">if</span> <span class="n">axis</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">tiler</span><span class="o">.</span><span class="n">channel_dimension</span><span class="p">:</span>
Expand Down Expand Up @@ -2095,7 +2095,7 @@ <h6 id="return">Return</h6>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;Window function must have the same shape as tile shape.&quot;</span>
<span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">window</span> <span class="o">=</span> <span class="n">window</span>
<span class="bp">self</span><span class="o">.</span><span class="n">window</span> <span class="o">=</span> <span class="n">window</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">weights_dtype</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;Unsupported type for window function (</span><span class="si">{</span><span class="nb">type</span><span class="p">(</span><span class="n">window</span><span class="p">)</span><span class="si">}</span><span class="s2">), expected str or np.ndarray.&quot;</span>
Expand Down Expand Up @@ -2363,7 +2363,7 @@ <h6 id="return">Return</h6>
<span class="sd"> data_dtype (np.dtype): Specify data type for data buffer that stores cumulative result.</span>
<span class="sd"> Default is `np.float32`.</span>

<span class="sd"> weights_dtype (np.dtype): Specify data type for weights buffer that stores cumulative weights.</span>
<span class="sd"> weights_dtype (np.dtype): Specify data type for weights buffer that stores cumulative weights and window array.</span>
<span class="sd"> If you don&#39;t need precision but would rather save memory you can use `np.float16`.</span>
<span class="sd"> Likewise, on the opposite, you can use `np.float64`.</span>
<span class="sd"> Default is `np.float32`.</span>
Expand Down Expand Up @@ -2424,7 +2424,7 @@ <h6 id="args">Args</h6>
<code>self.data_visits</code>. Can be disabled to save some memory. Default is <code>True</code>.</li>
<li><strong>data_dtype (np.dtype):</strong> Specify data type for data buffer that stores cumulative result.
Default is <code>np.float32</code>.</li>
<li><strong>weights_dtype (np.dtype):</strong> Specify data type for weights buffer that stores cumulative weights.
<li><strong>weights_dtype (np.dtype):</strong> Specify data type for weights buffer that stores cumulative weights and window array.
If you don't need precision but would rather save memory you can use <code>np.float16</code>.
Likewise, on the opposite, you can use <code>np.float64</code>.
Default is <code>np.float32</code>.</li>
Expand Down Expand Up @@ -2516,7 +2516,7 @@ <h6 id="args">Args</h6>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;Window function must have the same shape as tile shape.&quot;</span>
<span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">window</span> <span class="o">=</span> <span class="n">window</span>
<span class="bp">self</span><span class="o">.</span><span class="n">window</span> <span class="o">=</span> <span class="n">window</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">weights_dtype</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;Unsupported type for window function (</span><span class="si">{</span><span class="nb">type</span><span class="p">(</span><span class="n">window</span><span class="p">)</span><span class="si">}</span><span class="s2">), expected str or np.ndarray.&quot;</span>
Expand Down
8 changes: 7 additions & 1 deletion tests/test_merger.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,17 @@ def test_init(self):
self.assert_(merger3.data_visits is not None)

# Check data and weights dtypes
merger4 = Merger(tiler=tiler, data_dtype=np.float32, weights_dtype=np.float32)
merger4 = Merger(
tiler=tiler,
data_dtype=np.float32,
weights_dtype=np.float32,
window="boxcar",
)
self.assertEqual(merger4.data.dtype, np.float32)
self.assertEqual(merger4.data_dtype, np.float32)
self.assertEqual(merger4.weights_sum.dtype, np.float32)
self.assertEqual(merger4.weights_dtype, np.float32)
self.assertEqual(merger4.window.dtype, np.float32)

def test_add(self):
tiler = Tiler(data_shape=self.data.shape, tile_shape=(10,))
Expand Down
6 changes: 3 additions & 3 deletions tiler/merger.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(
data_dtype (np.dtype): Specify data type for data buffer that stores cumulative result.
Default is `np.float32`.
weights_dtype (np.dtype): Specify data type for weights buffer that stores cumulative weights.
weights_dtype (np.dtype): Specify data type for weights buffer that stores cumulative weights and window array.
If you don't need precision but would rather save memory you can use `np.float16`.
Likewise, on the opposite, you can use `np.float64`.
Default is `np.float32`.
Expand Down Expand Up @@ -133,7 +133,7 @@ def _generate_window(self, window: str, shape: Union[Tuple, List]) -> np.ndarray
np.ndarray: n-dimensional window of the given shape and function
"""

w = np.ones(shape)
w = np.ones(shape, dtype=self.weights_dtype)
overlap = self.tiler._tile_overlap
for axis, length in enumerate(shape):
if axis == self.tiler.channel_dimension:
Expand Down Expand Up @@ -189,7 +189,7 @@ def set_window(self, window: Union[None, str, np.ndarray] = None) -> None:
raise ValueError(
f"Window function must have the same shape as tile shape."
)
self.window = window
self.window = window.astype(self.weights_dtype)
else:
raise ValueError(
f"Unsupported type for window function ({type(window)}), expected str or np.ndarray."
Expand Down

0 comments on commit 7d75102

Please sign in to comment.