Skip to content

Commit

Permalink
Add simple pad on view
Browse files Browse the repository at this point in the history
  • Loading branch information
kvkenyon committed Jul 12, 2024
1 parent 9f5c1ba commit 94d2f8e
Show file tree
Hide file tree
Showing 3 changed files with 233 additions and 4 deletions.
209 changes: 209 additions & 0 deletions shrimpgrad/examples/conv2d.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 4,
"id": "17a3775a-01c7-48d6-b898-d5e71251738e",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "5dbf1936-a2d9-4b60-8377-e69d7c73ade3",
"metadata": {},
"outputs": [],
"source": [
"x = torch.ones(2,2,2,2)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8f8f07d9-e857-43c9-bdf9-ee643c6db6fc",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[1., 1.],\n",
" [1., 1.]],\n",
"\n",
" [[1., 1.],\n",
" [1., 1.]]],\n",
"\n",
"\n",
" [[[1., 1.],\n",
" [1., 1.]],\n",
"\n",
" [[1., 1.],\n",
" [1., 1.]]]])"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "6f4886b3-4943-4f8b-95e3-6451bdc6a860",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[0., 0.],\n",
" [0., 0.]],\n",
"\n",
" [[0., 0.],\n",
" [0., 0.]],\n",
"\n",
" [[0., 0.],\n",
" [0., 0.]],\n",
"\n",
" [[0., 0.],\n",
" [0., 0.]]],\n",
"\n",
"\n",
" [[[0., 0.],\n",
" [0., 0.]],\n",
"\n",
" [[1., 1.],\n",
" [1., 1.]],\n",
"\n",
" [[1., 1.],\n",
" [1., 1.]],\n",
"\n",
" [[0., 0.],\n",
" [0., 0.]]],\n",
"\n",
"\n",
" [[[0., 0.],\n",
" [0., 0.]],\n",
"\n",
" [[1., 1.],\n",
" [1., 1.]],\n",
"\n",
" [[1., 1.],\n",
" [1., 1.]],\n",
"\n",
" [[0., 0.],\n",
" [0., 0.]]],\n",
"\n",
"\n",
" [[[0., 0.],\n",
" [0., 0.]],\n",
"\n",
" [[0., 0.],\n",
" [0., 0.]],\n",
"\n",
" [[0., 0.],\n",
" [0., 0.]],\n",
"\n",
" [[0., 0.],\n",
" [0., 0.]]]])"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"F.pad(x, (0,0,0,0,1,1,1,1), 'constant', 0)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "c670712f-5929-412f-8d0c-9c156c421be7",
"metadata": {},
"outputs": [
{
"ename": "RuntimeError",
"evalue": "Padding length should be less than or equal to two times the input dimension but got padding length 12 and input of dimension 4",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[17], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m y \u001b[38;5;241m=\u001b[39m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpad\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2\u001b[0m y\n",
"File \u001b[0;32m/nix/store/7rlyhybbpanqaibk6682gq63zqljsc61-python3.11-torch-2.3.0/lib/python3.11/site-packages/torch/nn/functional.py:4522\u001b[0m, in \u001b[0;36mpad\u001b[0;34m(input, pad, mode, value)\u001b[0m\n\u001b[1;32m 4515\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m mode \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mreplicate\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m 4516\u001b[0m \u001b[38;5;66;03m# Use slow decomp whose backward will be in terms of index_put.\u001b[39;00m\n\u001b[1;32m 4517\u001b[0m \u001b[38;5;66;03m# importlib is required because the import cannot be top level\u001b[39;00m\n\u001b[1;32m 4518\u001b[0m \u001b[38;5;66;03m# (cycle) and cannot be nested (TS doesn't support)\u001b[39;00m\n\u001b[1;32m 4519\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m importlib\u001b[38;5;241m.\u001b[39mimport_module(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtorch._decomp.decompositions\u001b[39m\u001b[38;5;124m'\u001b[39m)\u001b[38;5;241m.\u001b[39m_replication_pad(\n\u001b[1;32m 4520\u001b[0m \u001b[38;5;28minput\u001b[39m, pad\n\u001b[1;32m 4521\u001b[0m )\n\u001b[0;32m-> 4522\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_C\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_nn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpad\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpad\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[0;31mRuntimeError\u001b[0m: Padding length should be less than or equal to two times the input dimension but got padding length 12 and input of dimension 4"
]
}
],
"source": [
"y = F.pad(x, (1,1,1,1,1,1,1,1,1,1,1,1))\n",
"y"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "b45a340d-2e81-4832-869c-1787aa374733",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([2, 4, 4, 4])"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y.size()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "adae1784-1684-4d59-877b-f84d28320b25",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "6bfcf63f-448f-4e9a-a659-f8c5aeb3ffb9",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
20 changes: 19 additions & 1 deletion shrimpgrad/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def expand(self, new_shape: Tuple[int,...]) -> ViewTracker:
def permute(self, order: Tuple[int,...]) -> ViewTracker:
return ViewTracker.from_views(self.views[0:-1] + [self.view.permute(order)] )

def pad(self, pad_width: Tuple[Tuple[int,int], ...]) -> ViewTracker:
return ViewTracker.from_views(self.views[0:-1] + [self.view.pad(pad_width)])

@staticmethod
def from_views(views: List[View]) -> ViewTracker:
return ViewTracker(views)
Expand All @@ -62,9 +65,10 @@ def __repr__(self) -> str:
class View:
"""A description of how a thunk's data is interpreted
"""
def __init__(self, shape: Tuple[int,...]):
def __init__(self, shape: Tuple[int,...], mask=None):
self.shape = shape
self._strides = tuple(accumulate(self.shape[-1:0:-1], func=operator.mul, initial=(1 if len(self.shape)else None)))[::-1]
self.mask = mask

@property
def strides(self) -> Tuple[int,...]: return self._strides
Expand Down Expand Up @@ -147,6 +151,20 @@ def expand(self, shape: Tuple[int,...]) -> View:
out._strides = tuple(strd)
return out

def pad(self, pad_width: Tuple[Tuple[int,int],...]):
assert all(s >= 0 and e >= 0 for s,e in pad_width), "pad_width must all be >= 0"
assert len(pad_width) == self.ndim, f'pad_width length must equal view ndim: {len(pad_width) != self.ndim}'

# No padding needed
if all(s == 0 and e == 0 for s,e in pad_width): return self
new_shape = list(self.shape)
mask = [None]*self.ndim
for i, ((pad_start, pad_end), shp) in enumerate(zip(pad_width, self.shape)):
new_shape[i] += pad_start + pad_end
# start index of non-padded values, end value of non-padded values
mask[i] = (pad_start, shp + pad_start)
return View(tuple(new_shape), tuple(mask))

@staticmethod
def from_view(view: View):
return View(view.shape)
Expand Down
8 changes: 5 additions & 3 deletions test/test_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ def test_reshape_permute_reshape_expand(self):
assert vt.shape == (1,1,2,2)
assert vt.strides == (0,0,1,2)




def test_pad(self):
vt = ViewTracker.from_shape((2,2))
vt = vt.pad(((1,1),(0,0)))
assert vt.shape == (4,2)
assert vt.view.mask == ((1,3),(0,2))

0 comments on commit 94d2f8e

Please sign in to comment.