From 94d2f8e93f46ef66783059857f2e899cb034253f Mon Sep 17 00:00:00 2001 From: Kevin Kenyon Date: Thu, 11 Jul 2024 19:50:55 -0500 Subject: [PATCH 1/3] Add simple pad on view --- shrimpgrad/examples/conv2d.ipynb | 209 +++++++++++++++++++++++++++++++ shrimpgrad/view.py | 20 ++- test/test_view.py | 8 +- 3 files changed, 233 insertions(+), 4 deletions(-) create mode 100644 shrimpgrad/examples/conv2d.ipynb diff --git a/shrimpgrad/examples/conv2d.ipynb b/shrimpgrad/examples/conv2d.ipynb new file mode 100644 index 0000000..c87415d --- /dev/null +++ b/shrimpgrad/examples/conv2d.ipynb @@ -0,0 +1,209 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 4, + "id": "17a3775a-01c7-48d6-b898-d5e71251738e", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn.functional as F" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "5dbf1936-a2d9-4b60-8377-e69d7c73ade3", + "metadata": {}, + "outputs": [], + "source": [ + "x = torch.ones(2,2,2,2)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "8f8f07d9-e857-43c9-bdf9-ee643c6db6fc", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[[1., 1.],\n", + " [1., 1.]],\n", + "\n", + " [[1., 1.],\n", + " [1., 1.]]],\n", + "\n", + "\n", + " [[[1., 1.],\n", + " [1., 1.]],\n", + "\n", + " [[1., 1.],\n", + " [1., 1.]]]])" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "6f4886b3-4943-4f8b-95e3-6451bdc6a860", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[[0., 0.],\n", + " [0., 0.]],\n", + "\n", + " [[0., 0.],\n", + " [0., 0.]],\n", + "\n", + " [[0., 0.],\n", + " [0., 0.]],\n", + "\n", + " [[0., 0.],\n", + " [0., 0.]]],\n", + "\n", + "\n", + " [[[0., 0.],\n", + " [0., 0.]],\n", + "\n", + " [[1., 1.],\n", + " [1., 1.]],\n", + "\n", + " [[1., 1.],\n", + " [1., 1.]],\n", + "\n", + " [[0., 0.],\n", + " [0., 0.]]],\n", + "\n", + "\n", + " [[[0., 0.],\n", + " [0., 0.]],\n", + "\n", + " [[1., 1.],\n", + " [1., 1.]],\n", + "\n", + " [[1., 1.],\n", + " [1., 1.]],\n", + "\n", + " [[0., 0.],\n", + " [0., 0.]]],\n", + "\n", + "\n", + " [[[0., 0.],\n", + " [0., 0.]],\n", + "\n", + " [[0., 0.],\n", + " [0., 0.]],\n", + "\n", + " [[0., 0.],\n", + " [0., 0.]],\n", + "\n", + " [[0., 0.],\n", + " [0., 0.]]]])" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "F.pad(x, (0,0,0,0,1,1,1,1), 'constant', 0)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "c670712f-5929-412f-8d0c-9c156c421be7", + "metadata": {}, + "outputs": [ + { + "ename": "RuntimeError", + "evalue": "Padding length should be less than or equal to two times the input dimension but got padding length 12 and input of dimension 4", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[17], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m y \u001b[38;5;241m=\u001b[39m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpad\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2\u001b[0m y\n", + "File \u001b[0;32m/nix/store/7rlyhybbpanqaibk6682gq63zqljsc61-python3.11-torch-2.3.0/lib/python3.11/site-packages/torch/nn/functional.py:4522\u001b[0m, in \u001b[0;36mpad\u001b[0;34m(input, pad, mode, value)\u001b[0m\n\u001b[1;32m 4515\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m mode \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mreplicate\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m 4516\u001b[0m \u001b[38;5;66;03m# Use slow decomp whose backward will be in terms of index_put.\u001b[39;00m\n\u001b[1;32m 4517\u001b[0m \u001b[38;5;66;03m# importlib is required because the import cannot be top level\u001b[39;00m\n\u001b[1;32m 4518\u001b[0m \u001b[38;5;66;03m# (cycle) and cannot be nested (TS doesn't support)\u001b[39;00m\n\u001b[1;32m 4519\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m importlib\u001b[38;5;241m.\u001b[39mimport_module(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtorch._decomp.decompositions\u001b[39m\u001b[38;5;124m'\u001b[39m)\u001b[38;5;241m.\u001b[39m_replication_pad(\n\u001b[1;32m 4520\u001b[0m \u001b[38;5;28minput\u001b[39m, pad\n\u001b[1;32m 4521\u001b[0m )\n\u001b[0;32m-> 4522\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_C\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_nn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpad\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpad\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[0;31mRuntimeError\u001b[0m: Padding length should be less than or equal to two times the input dimension but got padding length 12 and input of dimension 4" + ] + } + ], + "source": [ + "y = F.pad(x, (1,1,1,1,1,1,1,1,1,1,1,1))\n", + "y" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "b45a340d-2e81-4832-869c-1787aa374733", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2, 4, 4, 4])" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y.size()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "adae1784-1684-4d59-877b-f84d28320b25", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6bfcf63f-448f-4e9a-a659-f8c5aeb3ffb9", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/shrimpgrad/view.py b/shrimpgrad/view.py index 85bb15a..456d0fd 100644 --- a/shrimpgrad/view.py +++ b/shrimpgrad/view.py @@ -47,6 +47,9 @@ def expand(self, new_shape: Tuple[int,...]) -> ViewTracker: def permute(self, order: Tuple[int,...]) -> ViewTracker: return ViewTracker.from_views(self.views[0:-1] + [self.view.permute(order)] ) + def pad(self, pad_width: Tuple[Tuple[int,int], ...]) -> ViewTracker: + return ViewTracker.from_views(self.views[0:-1] + [self.view.pad(pad_width)]) + @staticmethod def from_views(views: List[View]) -> ViewTracker: return ViewTracker(views) @@ -62,9 +65,10 @@ def __repr__(self) -> str: class View: """A description of how a thunk's data is interpreted """ - def __init__(self, shape: Tuple[int,...]): + def __init__(self, shape: Tuple[int,...], mask=None): self.shape = shape self._strides = tuple(accumulate(self.shape[-1:0:-1], func=operator.mul, initial=(1 if len(self.shape)else None)))[::-1] + self.mask = mask @property def strides(self) -> Tuple[int,...]: return self._strides @@ -147,6 +151,20 @@ def expand(self, shape: Tuple[int,...]) -> View: out._strides = tuple(strd) return out + def pad(self, pad_width: Tuple[Tuple[int,int],...]): + assert all(s >= 0 and e >= 0 for s,e in pad_width), "pad_width must all be >= 0" + assert len(pad_width) == self.ndim, f'pad_width length must equal view ndim: {len(pad_width) != self.ndim}' + + # No padding needed + if all(s == 0 and e == 0 for s,e in pad_width): return self + new_shape = list(self.shape) + mask = [None]*self.ndim + for i, ((pad_start, pad_end), shp) in enumerate(zip(pad_width, self.shape)): + new_shape[i] += pad_start + pad_end + # start index of non-padded values, end value of non-padded values + mask[i] = (pad_start, shp + pad_start) + return View(tuple(new_shape), tuple(mask)) + @staticmethod def from_view(view: View): return View(view.shape) diff --git a/test/test_view.py b/test/test_view.py index 19d6c40..0f68f9d 100644 --- a/test/test_view.py +++ b/test/test_view.py @@ -29,6 +29,8 @@ def test_reshape_permute_reshape_expand(self): assert vt.shape == (1,1,2,2) assert vt.strides == (0,0,1,2) - - - + def test_pad(self): + vt = ViewTracker.from_shape((2,2)) + vt = vt.pad(((1,1),(0,0))) + assert vt.shape == (4,2) + assert vt.view.mask == ((1,3),(0,2)) \ No newline at end of file From 5e2b02f2b79d910db1a24b6c77e45cef70175820 Mon Sep 17 00:00:00 2001 From: Kevin Kenyon Date: Fri, 12 Jul 2024 15:48:12 -0500 Subject: [PATCH 2/3] Add shrink and update view --- shrimpgrad/examples/conv2d.ipynb | 209 ------------------------------- shrimpgrad/view.py | 84 ++++++++----- test/test_view.py | 44 ++++++- test/test_viewtracker.py | 2 +- 4 files changed, 96 insertions(+), 243 deletions(-) delete mode 100644 shrimpgrad/examples/conv2d.ipynb diff --git a/shrimpgrad/examples/conv2d.ipynb b/shrimpgrad/examples/conv2d.ipynb deleted file mode 100644 index c87415d..0000000 --- a/shrimpgrad/examples/conv2d.ipynb +++ /dev/null @@ -1,209 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 4, - "id": "17a3775a-01c7-48d6-b898-d5e71251738e", - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import torch.nn.functional as F" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "5dbf1936-a2d9-4b60-8377-e69d7c73ade3", - "metadata": {}, - "outputs": [], - "source": [ - "x = torch.ones(2,2,2,2)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "8f8f07d9-e857-43c9-bdf9-ee643c6db6fc", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[[[1., 1.],\n", - " [1., 1.]],\n", - "\n", - " [[1., 1.],\n", - " [1., 1.]]],\n", - "\n", - "\n", - " [[[1., 1.],\n", - " [1., 1.]],\n", - "\n", - " [[1., 1.],\n", - " [1., 1.]]]])" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "6f4886b3-4943-4f8b-95e3-6451bdc6a860", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[[[0., 0.],\n", - " [0., 0.]],\n", - "\n", - " [[0., 0.],\n", - " [0., 0.]],\n", - "\n", - " [[0., 0.],\n", - " [0., 0.]],\n", - "\n", - " [[0., 0.],\n", - " [0., 0.]]],\n", - "\n", - "\n", - " [[[0., 0.],\n", - " [0., 0.]],\n", - "\n", - " [[1., 1.],\n", - " [1., 1.]],\n", - "\n", - " [[1., 1.],\n", - " [1., 1.]],\n", - "\n", - " [[0., 0.],\n", - " [0., 0.]]],\n", - "\n", - "\n", - " [[[0., 0.],\n", - " [0., 0.]],\n", - "\n", - " [[1., 1.],\n", - " [1., 1.]],\n", - "\n", - " [[1., 1.],\n", - " [1., 1.]],\n", - "\n", - " [[0., 0.],\n", - " [0., 0.]]],\n", - "\n", - "\n", - " [[[0., 0.],\n", - " [0., 0.]],\n", - "\n", - " [[0., 0.],\n", - " [0., 0.]],\n", - "\n", - " [[0., 0.],\n", - " [0., 0.]],\n", - "\n", - " [[0., 0.],\n", - " [0., 0.]]]])" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "F.pad(x, (0,0,0,0,1,1,1,1), 'constant', 0)" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "c670712f-5929-412f-8d0c-9c156c421be7", - "metadata": {}, - "outputs": [ - { - "ename": "RuntimeError", - "evalue": "Padding length should be less than or equal to two times the input dimension but got padding length 12 and input of dimension 4", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[17], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m y \u001b[38;5;241m=\u001b[39m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpad\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2\u001b[0m y\n", - "File \u001b[0;32m/nix/store/7rlyhybbpanqaibk6682gq63zqljsc61-python3.11-torch-2.3.0/lib/python3.11/site-packages/torch/nn/functional.py:4522\u001b[0m, in \u001b[0;36mpad\u001b[0;34m(input, pad, mode, value)\u001b[0m\n\u001b[1;32m 4515\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m mode \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mreplicate\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m 4516\u001b[0m \u001b[38;5;66;03m# Use slow decomp whose backward will be in terms of index_put.\u001b[39;00m\n\u001b[1;32m 4517\u001b[0m \u001b[38;5;66;03m# importlib is required because the import cannot be top level\u001b[39;00m\n\u001b[1;32m 4518\u001b[0m \u001b[38;5;66;03m# (cycle) and cannot be nested (TS doesn't support)\u001b[39;00m\n\u001b[1;32m 4519\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m importlib\u001b[38;5;241m.\u001b[39mimport_module(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtorch._decomp.decompositions\u001b[39m\u001b[38;5;124m'\u001b[39m)\u001b[38;5;241m.\u001b[39m_replication_pad(\n\u001b[1;32m 4520\u001b[0m \u001b[38;5;28minput\u001b[39m, pad\n\u001b[1;32m 4521\u001b[0m )\n\u001b[0;32m-> 4522\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_C\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_nn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpad\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpad\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m)\u001b[49m\n", - "\u001b[0;31mRuntimeError\u001b[0m: Padding length should be less than or equal to two times the input dimension but got padding length 12 and input of dimension 4" - ] - } - ], - "source": [ - "y = F.pad(x, (1,1,1,1,1,1,1,1,1,1,1,1))\n", - "y" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "b45a340d-2e81-4832-869c-1787aa374733", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([2, 4, 4, 4])" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "y.size()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "adae1784-1684-4d59-877b-f84d28320b25", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6bfcf63f-448f-4e9a-a659-f8c5aeb3ffb9", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.9" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/shrimpgrad/view.py b/shrimpgrad/view.py index 456d0fd..dcb6100 100644 --- a/shrimpgrad/view.py +++ b/shrimpgrad/view.py @@ -1,7 +1,8 @@ from __future__ import annotations -from itertools import accumulate +import functools +import itertools import operator -from typing import List, Tuple +from typing import List, Optional, Tuple from shrimpgrad.util import prod def can_merge_axes(shape: Tuple[int,...], strides: Tuple[int,...], start:int, stop:int): @@ -9,10 +10,17 @@ def can_merge_axes(shape: Tuple[int,...], strides: Tuple[int,...], start:int, st if strides[axis] != strides[axis+1]*shape[axis+1]: return False return True +@functools.lru_cache(maxsize=None) def normalize_strides(shape: Tuple[int, ...], strides: Tuple[int, ...]): # replace the stride value for dimensions of 1 with 0 return tuple([0 if s == 1 else st for s,st in zip(shape, strides)]) +@functools.lru_cache(maxsize=None) +def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]: + if not shape: return () + strides = tuple(itertools.accumulate(reversed(shape[1:]), operator.mul, initial=1))[::-1] + return normalize_strides(shape, strides) + class ViewTracker: def __init__(self, views: List[View]): self.views: List[View] = views @@ -50,6 +58,10 @@ def permute(self, order: Tuple[int,...]) -> ViewTracker: def pad(self, pad_width: Tuple[Tuple[int,int], ...]) -> ViewTracker: return ViewTracker.from_views(self.views[0:-1] + [self.view.pad(pad_width)]) + def shrink(self, arg: Tuple[Tuple[int,int], ...]) -> ViewTracker: + return ViewTracker.from_views(self.views[0:-1] + [self.view.shrink(arg)]) + + @staticmethod def from_views(views: List[View]) -> ViewTracker: return ViewTracker(views) @@ -62,22 +74,33 @@ def __repr__(self) -> str: return f"" +def create_view(shape: Tuple[int,...], + strides: Optional[Tuple[int,...]]=None, + mask: Optional[Tuple[Tuple[int,int],...]]=None, + offset:int=0): + + # standardize 0 in shape + if 0 in shape: return View(shape, (0,)*len(shape)) + # standardize empty mask to None + if mask is not None and all((s==0 and e == dim_size for ((s,e), dim_size) in zip(mask, shape))): mask = None + + return View(shape, normalize_strides(shape, strides) if strides is not None else strides, mask, offset) + class View: - """A description of how a thunk's data is interpreted + """The layout for the thunk """ - def __init__(self, shape: Tuple[int,...], mask=None): - self.shape = shape - self._strides = tuple(accumulate(self.shape[-1:0:-1], func=operator.mul, initial=(1 if len(self.shape)else None)))[::-1] - self.mask = mask + def __init__(self, shape: Tuple[int,...], + strides: Optional[Tuple[int,...]]=None, + mask: Optional[Tuple[Tuple[int,int],...]]=None, + offset:int=0): + + self.shape, self.strides, self.mask, self.offset = shape, strides, mask, offset + self.strides = strides if strides is not None else strides_for_shape(shape) - @property - def strides(self) -> Tuple[int,...]: return self._strides @property def contiguous(self) -> bool: - if not self.shape: return True - if not self._strides: return True - return all(self._strides[i] == self.shape[i+1]*self._strides[i+1] for i in range(0, self.ndim-1)) + return self.offset == 0 and self.mask is None and self.strides == strides_for_shape(self.shape) @property def scalar(self): return self.ndim == 0 @@ -90,13 +113,11 @@ def reshape(self, new_shape: Tuple[int,...]) -> View: if len(self.shape): assert prod(new_shape) == self.numel, f'shape \'{new_shape}\' is invalid for input of size {self.numel} of shape {self.shape}' # Fast path (new strides are easy to compute) - if self.contiguous: return View(new_shape) + if self.contiguous: return create_view(new_shape, mask=self.mask, offset=self.offset) # Slow path (reconstruct the new strides without copying) - newstrides = self._attempt_no_copy_reshape(new_shape) - view = View(new_shape) - view._strides = normalize_strides(new_shape, tuple(newstrides)) - return view - return View(new_shape) + new_strides = tuple(self._attempt_no_copy_reshape(new_shape)) + return create_view(new_shape, new_strides, self.mask, self.offset) + return create_view(new_shape, mask=self.mask, offset=self.offset) def _attempt_no_copy_reshape(self, new_shape): # Remove ones from the old shape @@ -138,20 +159,16 @@ def _attempt_no_copy_reshape(self, new_shape): def permute(self, order: Tuple[int,...]) -> View: new_shape = tuple([self.shape[i] for i in order]) new_strides = tuple([self.strides[i] for i in order]) - v = View(new_shape) - v._strides = new_strides - return v + return create_view(new_shape, new_strides) def expand(self, shape: Tuple[int,...]) -> View: - out = View.from_view(self) + assert all(((s0 == s1) or (s0 == 1) for s0,s1 in zip(self.shape, shape))), f'invalid expand from {self.shape} to {shape}' strd = list(self.strides) for i, (si, so) in enumerate(zip(self.shape, shape)): if si != so: strd[i] = 0 - out.shape = shape - out._strides = tuple(strd) - return out + return create_view(shape, tuple(strd)) - def pad(self, pad_width: Tuple[Tuple[int,int],...]): + def pad(self, pad_width: Tuple[Tuple[int,int],...]) -> View: assert all(s >= 0 and e >= 0 for s,e in pad_width), "pad_width must all be >= 0" assert len(pad_width) == self.ndim, f'pad_width length must equal view ndim: {len(pad_width) != self.ndim}' @@ -163,10 +180,19 @@ def pad(self, pad_width: Tuple[Tuple[int,int],...]): new_shape[i] += pad_start + pad_end # start index of non-padded values, end value of non-padded values mask[i] = (pad_start, shp + pad_start) - return View(tuple(new_shape), tuple(mask)) + return create_view(tuple(new_shape), self.strides, tuple(mask)) + + def shrink(self, arg: Tuple[Tuple[int, int]]) -> View: + assert all(0<=start<=stop<=shape for ((start,stop), shape) in zip(arg, self.shape)), 'invalid shrink slices' + new_shape = tuple([stop - start for start, stop in arg]) + if self.mask is not None: + # mask[0] = pad_size_left, dim_size + pad_size_right + # + pass + return create_view(new_shape) + @staticmethod - def from_view(view: View): - return View(view.shape) + def from_view(view: View): return create_view(view.shape, view.strides, view.mask, view.offset) def __repr__(self): return f'' \ No newline at end of file diff --git a/test/test_view.py b/test/test_view.py index 0f68f9d..7bf9a81 100644 --- a/test/test_view.py +++ b/test/test_view.py @@ -7,7 +7,7 @@ def test_view(self): v = View(()) self.assertTrue(v.scalar) - def test_reshape_permute_reshape_expand(self): + def test_reshape_permute_reshape(self): vt = ViewTracker.from_shape((2,2)) assert vt.shape == (2,2) assert vt.strides == (2,1) @@ -15,14 +15,14 @@ def test_reshape_permute_reshape_expand(self): assert vt.contiguous vt = vt.reshape((1,2,2)) assert vt.shape == (1,2,2) - assert vt.strides == (4,2,1) + assert vt.strides == (0,2,1) assert len(vt.views) == 1 assert vt.contiguous vt = vt.permute((0,2,1)) assert not vt.contiguous assert len(vt.views) == 1 assert vt.shape == (1,2,2) - assert vt.strides == (4,1,2) + assert vt.strides == (0,1,2) vt = vt.reshape((1,1,2,2)) assert not vt.contiguous assert len(vt.views) == 2 @@ -33,4 +33,40 @@ def test_pad(self): vt = ViewTracker.from_shape((2,2)) vt = vt.pad(((1,1),(0,0))) assert vt.shape == (4,2) - assert vt.view.mask == ((1,3),(0,2)) \ No newline at end of file + assert vt.view.mask == ((1,3),(0,2)) + + def test_pad2(self): + vt = ViewTracker.from_shape((2,2,2)) + vt = vt.pad(((2,1),(4,4),(4,4))) + assert vt.shape == (5, 10, 10) + assert vt.views[-1].mask == ((2,4), (4,6), (4,6)) + assert vt.numel == 5*10*10 + assert vt.strides == (4, 2, 1) + assert vt.ndim == 3 + assert len(vt.views) == 1 + + def test_permute_pad(self): + vt = ViewTracker.from_shape((2,2,2)) + vt = vt.permute((2,1,0)) + assert vt.strides == (1,2,4) + vt = vt.pad(((1,1),(0,0),(1,1))) + assert vt.shape == (4,2,4) + assert vt.strides == (1,2,4) + assert len(vt.views) == 1 + + def test_shrink(self): + view = View((2,2,2)) + view = view.shrink(((0,1), (0,1), (0,0))) + + assert view.shape == (1,1,0) + assert view.strides == (0,0,0) + + def test_shrink1(self): + vt = ViewTracker.from_shape((2,4,2)) + vt = vt.shrink(((0,1),(1,3), (0,2))) + + assert vt.view.shape == (1,2,2) + assert vt.view.strides == (0,2,1) + assert vt.view.mask == None + assert vt.view.offset == 0 + diff --git a/test/test_viewtracker.py b/test/test_viewtracker.py index fec5cc6..b368799 100644 --- a/test/test_viewtracker.py +++ b/test/test_viewtracker.py @@ -12,7 +12,7 @@ def test_viewtracker1(self): vt2 = vt.reshape((2,2,1)) self.assertEqual(1, len(vt2.views)) self.assertEqual((2,2,1), vt2.view.shape) - self.assertEqual((2,1,1), vt2.view.strides) + self.assertEqual((2,1,0), vt2.view.strides) vt3 = vt2.permute((2,1,0)) self.assertEqual(1, len(vt3.views)) From de5e93bab82a95499683da096ae64482ac54f701 Mon Sep 17 00:00:00 2001 From: Kevin Kenyon Date: Sat, 13 Jul 2024 19:17:54 -0500 Subject: [PATCH 3/3] Add shrink mask updates --- shrimpgrad/view.py | 19 +++++++++---- test/test_view.py | 67 +++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 80 insertions(+), 6 deletions(-) diff --git a/shrimpgrad/view.py b/shrimpgrad/view.py index dcb6100..a8ea86e 100644 --- a/shrimpgrad/view.py +++ b/shrimpgrad/view.py @@ -185,14 +185,23 @@ def pad(self, pad_width: Tuple[Tuple[int,int],...]) -> View: def shrink(self, arg: Tuple[Tuple[int, int]]) -> View: assert all(0<=start<=stop<=shape for ((start,stop), shape) in zip(arg, self.shape)), 'invalid shrink slices' new_shape = tuple([stop - start for start, stop in arg]) + new_mask = None if self.mask is not None: - # mask[0] = pad_size_left, dim_size + pad_size_right - # - pass - return create_view(new_shape) + new_mask = [[None,None]]*len(self.mask) + for i, (start,stop) in enumerate(arg): + if start < self.mask[i][0]: + new_mask[i][0] = start + else: + new_mask[i][0] = 0 + if stop < self.mask[i][1]: + new_mask[i][1] = stop + else: + new_mask[i][1] = new_mask[i][0] + self.mask[i][1] - self.mask[i][0] + new_mask[i] = tuple(new_mask[i]) + return create_view(new_shape, mask=tuple(new_mask) if new_mask is not None else None) @staticmethod def from_view(view: View): return create_view(view.shape, view.strides, view.mask, view.offset) - def __repr__(self): return f'' \ No newline at end of file + def __repr__(self): return f'' \ No newline at end of file diff --git a/test/test_view.py b/test/test_view.py index 7bf9a81..a234a61 100644 --- a/test/test_view.py +++ b/test/test_view.py @@ -7,6 +7,19 @@ def test_view(self): v = View(()) self.assertTrue(v.scalar) + def test_reshape(self): + vt = ViewTracker.from_shape((10,4)) + self.assertEqual((10,4), vt.shape) + vt = vt.reshape((40,1)) + self.assertEqual((40,1), vt.shape) + + def test_reshape_like_permute(self): + vt = ViewTracker.from_shape((2,4)) + self.assertEqual((2,4), vt.shape) + vt = vt.reshape((4,2)) + self.assertEqual((4,2), vt.shape) + self.assertTrue(vt.contiguous) + def test_reshape_permute_reshape(self): vt = ViewTracker.from_shape((2,2)) assert vt.shape == (2,2) @@ -63,10 +76,62 @@ def test_shrink(self): def test_shrink1(self): vt = ViewTracker.from_shape((2,4,2)) - vt = vt.shrink(((0,1),(1,3), (0,2))) + vt = vt.shrink(((0,1),(1,3),(0,2))) assert vt.view.shape == (1,2,2) assert vt.view.strides == (0,2,1) assert vt.view.mask == None assert vt.view.offset == 0 + def test_undo_pad_with_shrink_and_mask(self): + vt = ViewTracker.from_shape((4,7,4)) + vt = vt.pad(((2,1),(0,0),(1,2))) + self.assertEqual((7,7,7), vt.shape) + self.assertEqual(((2,6),(0,7),(1,5)), vt.view.mask) + vt = vt.shrink(vt.view.mask) + self.assertEqual((4,7,4), vt.shape) + self.assertEqual(None, vt.view.mask) + + def test_pad_then_shrink_a_bit(self): + vt = ViewTracker.from_shape((4,)) + vt = vt.pad(((2,2),)) + self.assertEqual((8,), vt.shape) + self.assertEqual(((2,6),), vt.view.mask) + vt = vt.shrink(((1,8),)) + self.assertEqual((7,), vt.shape) + self.assertEqual(((1,5),), vt.view.mask) + + def test_pad_then_shrink_into_outer_pad(self): + vt = ViewTracker.from_shape((4,)) + vt = vt.pad(((2,2),)) + self.assertEqual((8,), vt.shape) + self.assertEqual(((2,6),), vt.view.mask) + vt = vt.shrink(((1,4),)) + self.assertEqual((3,), vt.shape) + self.assertEqual(((1,4),), vt.view.mask) + + def test_shrink_pad_back(self): + vt = ViewTracker.from_shape((4,)) + vt = vt.shrink(((1,3),)) + self.assertEqual((2,), vt.shape) + self.assertEqual(None, vt.view.mask) + vt = vt.pad(((1,1),)) + self.assertEqual((4,),vt.shape) + self.assertEqual(((1,3),), vt.view.mask) + + def test_pad_reshape_adjusts_mask(self): + vt = ViewTracker.from_shape((2,2)) + vt = vt.pad(((1,1),(1,1))) + self.assertEqual((4,4), vt.shape) + self.assertEqual(((1,3),(1,3)), vt.view.mask) + vt = vt.reshape((1,4,4)) + self.assertEqual((1,4,4), vt.shape) + #self.assertEqual(((0,1),(1,3),(1,3)), vt.view.mask) + + def test_pad_reshape_adjust_mask2(self): + vt = ViewTracker.from_shape((8,2)) + vt = vt.pad(((1,1),(1,1))) + self.assertEqual((10,4), vt.shape) + # print(vt.view) + # vt = vt.reshape((40,1)) + # self.assertEqual((40,1), vt.shape)