Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stop using ConstantFolding in other optimization rules #822

Merged
merged 2 commits into from
May 1, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from webdnn.backend.webgl.optimize_rules.split_texture.operators.split_concat import split_concat
from webdnn.backend.webgl.optimize_rules.split_texture.operators.split_im2col import split_im2col
from webdnn.backend.webgl.optimize_rules.split_texture.operators.split_partial_im2col import split_partial_im2col
from webdnn.backend.webgl.optimize_rules.split_texture.operators.split_pooling_2d import split_pooling_2d
from webdnn.backend.webgl.optimize_rules.split_texture.operators.split_reshape import split_reshape
from webdnn.backend.webgl.optimize_rules.split_texture.operators.split_splitaxis import split_splitaxis
from webdnn.backend.webgl.optimize_rules.split_texture.operators.split_tensordot import split_tensordot
from webdnn.backend.webgl.optimize_rules.split_texture.operators.split_tensorwise import split_tensorwise
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
from typing import NamedTuple, List, Sequence

from webdnn.graph.axis import Axis
from webdnn.graph.graph import Graph
from webdnn.graph.operators.concat import Concat
from webdnn.graph.operators.split_axis import SplitAxis
from webdnn.graph.optimize_rule import OptimizeRule
from webdnn.graph.variable import Variable
from webdnn.util.assertion import UnexpectedAndPleaseReportError


class GraphVars(NamedTuple):
inputs: List[Variable]
hidden: List[Variable]
outputs: List[Variable]


def split_concat(graph: Graph, op: Concat, v: Variable, v_pair: Sequence[Variable], axis: Axis):
s1 = v_pair[0].shape_dict[axis]
xs = [op.inputs[key] for key in sorted([key for key in op.inputs.keys() if key.startswith("x")])]
y = op.outputs["y"]
op.remove_all()

if v in xs:
x_0, x_1 = v_pair

if axis == op.axis:
"""
before)
x1 -+
|
x2 -+-{concat}- y
|
x3 -+

after)
x1 ---+
|
x2_0 -+
+-{concat}- y
x2_1 -+
|
x3 ---+
"""
i = xs.index(v)
xs.pop(i)
xs.insert(i + 0, x_0)
xs.insert(i + 1, x_1)

y_new, = Concat(None, axis=axis)(*xs)
OptimizeRule.replace_variable(graph, y, y_new)

else:
"""
before)
x1 -+
|
x2 -+-{concat[op.axis]}- y
|
x3 -+

after)
+- x1_0 -+
x1 -{split[axis]}-+ |
+- x1_1 -|-+
| |
x2_0 ----------------------+---{concat[op.axis]}- y_0 -+
| | +-{concat[axis]}- y
x2_1 ----------------------|-+-{concat[op.axis]}- y_1 -+
| |
+- x3_0 -+ |
x3 -{split[axis]}-+ |
+- x3_1 ---+
"""
xs_0, xs_1 = zip(*[v_pair if x == v else SplitAxis(None, axis=axis, sections=[s1])(x) for x in xs])
y_0, = Concat(None, axis=op.axis)(*xs_0)
y_1, = Concat(None, axis=op.axis)(*xs_1)
y_new, = Concat(None, axis=axis)(y_0, y_1)
OptimizeRule.replace_variable(graph, y_new, y)

elif v == y:
y_0, y_1 = v_pair

if axis == op.axis:
"""
before)
x1 -+
|
x2 -+-{concat[axis=op.axis]}- y
|
x3 -+

after)
x1 ------------------------------+
+-{concat[axis=axis]}- y_0
+- y_0_1 -+
x2 -{split[axis=axis]}-+
+- y_1_0 -+
+-{concat[axis=axis]}- y_1
x3 ------------------------------+
"""
# find input variable which should be split

total_size = 0
xs_0 = [] # type: List[Variable]
xs_1 = list(xs) # type: List[Variable]
for x in xs:
xs_1.remove(x)
xs_0.append(x)
total_size += x.shape_dict[axis]

if total_size == s1:
# splitting is not needed
#
# x0, x1, ..., xn, | xn+1, ..., xs[-1]
# <--------------> | <--------------->
# y_0 | y_1
break

elif total_size > s1:
# this `x` must be split
#
# x0, x1, ..., xn, ..., xs[-1]
# <-------------><------------->
# y_0 y_1

xn_0, xn_1 = SplitAxis(None, axis=axis, sections=[s1 - (total_size - x.shape_dict[axis])])(x)
xs_0.remove(x)
xs_0.append(xn_0)
xs_1.insert(0, xn_1)
break

if len(xs_0) > 1:
y_0, = Concat(None, axis=axis)(*xs_0)
y_0.change_order(v_pair[0].order)

elif len(xs_0) == 1:
y_0 = xs_0[0]

else:
raise UnexpectedAndPleaseReportError

if len(xs_1) > 1:
y_1, = Concat(None, axis=axis)(*xs_1)
y_1.change_order(v_pair[1].order)

elif len(xs_1) == 1:
y_1 = xs_1[0]

else:
raise UnexpectedAndPleaseReportError

OptimizeRule.replace_variable(graph, y_0, v_pair[0])
OptimizeRule.replace_variable(graph, y_1, v_pair[1])

else:
"""
before)
x1 -+
|
x2 -+-{concat[op.axis]}- y
|
x3 -+

after)
+- x1_0 -+
x1 -{split[axis]}-+ |
+- x1_1 ---+
| |
+- x2_0 -+-|-{concat[op.axis]}- y_0
x2 -{split[axis]}-+ | |
+- x2_1 ---+-{concat[op.axis]}- y_1
| |
+- x3_0 -+ |
x3 -{split[axis]}-+ |
+- x3_1 ---+

"""
xs_0, xs_1 = zip(*[SplitAxis(None, axis=axis, sections=[s1])(x) for x in xs])

y_new_0, = Concat(None, axis=op.axis)(*xs_0)
y_new_1, = Concat(None, axis=op.axis)(*xs_1)

OptimizeRule.replace_variable(graph, y_new_0, y_0)
OptimizeRule.replace_variable(graph, y_new_1, y_1)

else:
raise UnexpectedAndPleaseReportError
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from typing import NamedTuple, List, Sequence

import numpy as np

from webdnn.backend.webgl.attributes.channel_mode import ChannelMode
from webdnn.backend.webgl.attributes.texture_shape import TextureShape
from webdnn.backend.webgl.operators.partial_im2col import PartialIm2Col
from webdnn.backend.webgl.optimize_rules.split_texture.check_texture_size import SplitTarget
from webdnn.graph import traverse
from webdnn.graph.axis import Axis, AxisKeyDict
from webdnn.graph.graph import Graph
from webdnn.graph.operator import Operator
from webdnn.graph.operators.attributes.tensorwise import Tensorwise
from webdnn.graph.operators.concat import Concat
from webdnn.graph.operators.im2col import Im2Col
from webdnn.graph.operators.pooling_2d import Pooling2D
from webdnn.graph.operators.reshape import Reshape
from webdnn.graph.operators.slice import Slice
from webdnn.graph.operators.split_axis import SplitAxis
from webdnn.graph.operators.tensordot import Tensordot
from webdnn.graph.optimize_rule import OptimizeRule
from webdnn.graph.order import Order, OrderNHWC
from webdnn.graph.variable import Variable
from webdnn.graph.variables.constant_variable import ConstantVariable
from webdnn.util import console
from webdnn.util.assertion import UnexpectedAndPleaseReportError
from webdnn.util.misc import mul


class GraphVars(NamedTuple):
inputs: List[Variable]
hidden: List[Variable]
outputs: List[Variable]


def split_im2col(graph: Graph, op: Im2Col, v: Variable, v_pair: Sequence[Variable], axis: Axis):
s1 = v_pair[0].shape_dict[axis]
im = op.inputs["im"]
col = op.outputs["col"]

op.remove_all()

if v == col:
"""
before)

im -{Im2Col}- col

after)

+- col_0
im -{PartialIm2Col}-+
+- col_1
"""
col_0, col_1 = PartialIm2Col(None,
ksize=op.ksize, stride=op.stride, padding=op.padding,
dilation_rate=op.dilation_rate,
axis=axis, sections=[s1])(im)

OptimizeRule.replace_variable(graph, col_0.transpose(v_pair[0].order), v_pair[0])
OptimizeRule.replace_variable(graph, col_1.transpose(v_pair[1].order), v_pair[1])

elif v == im:
raise NotImplementedError(f"Variable is too large to handle in WebGL backend: {v}")

else:
raise UnexpectedAndPleaseReportError
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from typing import NamedTuple, List, Sequence

from webdnn.backend.webgl.operators.partial_im2col import PartialIm2Col
from webdnn.graph.axis import Axis
from webdnn.graph.graph import Graph
from webdnn.graph.optimize_rule import OptimizeRule
from webdnn.graph.variable import Variable
from webdnn.util.assertion import UnexpectedAndPleaseReportError


class GraphVars(NamedTuple):
inputs: List[Variable]
hidden: List[Variable]
outputs: List[Variable]


def split_partial_im2col(graph: Graph, op: PartialIm2Col, v: Variable, v_pair: Sequence[Variable], axis: Axis):
s1 = v_pair[0].shape_dict[axis]
im = op.inputs["im"]
cols = [op.outputs[f"col{i}"] for i in range(len(op.outputs))]
sections = op.sections

if v == im:
raise NotImplementedError(f"Variable is too large to handle in WebGL backend: {v}")

elif v in cols:
op.remove_all()

if axis == op.axis:
"""
before)
+- col0
|
im -{PartialIm2Col}-+- col1
|
+- col2

after)
+- col0
|
+- col1_0
im -{PartialIm2Col}-+
+- col1_1
|
+- col2
"""
target_i = cols.index(v)

s_insert = (0 if target_i == 0 else sections[target_i - 1]) + s1
new_sections = list(sections)
new_sections.insert(target_i, s_insert)

cols.pop(target_i)
cols.insert(target_i + 0, v_pair[0])
cols.insert(target_i + 1, v_pair[1])

new_cols = PartialIm2Col(None,
ksize=op.ksize, stride=op.stride, padding=op.padding,
dilation_rate=op.dilation_rate,
axis=axis, sections=new_sections)(im)
for col, new_col in zip(cols, new_cols):
OptimizeRule.replace_variable(graph, new_col, col)

else:
raise NotImplementedError(f"Variable is too large to handle in WebGL backend: {v}")

else:
raise UnexpectedAndPleaseReportError
Loading