eincheck

Tensor shape checks inspired by Einstein notation

Overview

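This library has three main functions:

  • check_shapes takes tuples of (Tensor, shape) and checks that every tensor matches its shape. A named dimension such as i must have the same size everywhere it appears.
```python
from eincheck import check_shapes
check_shapes((x, "i 3"), (y, "i 3"))
```
  • check_func is a function decorator that checks the input and output shapes of a function. In the spec below, *i matches the leading dimensions shared by both inputs, and the output's last dimension must equal x + y.
```python
import numpy as np
from eincheck import check_func

@check_func("*i x, *i y -> *i (x + y)")
def concat(a, b):
    return np.concatenate([a, b], -1)
```
  • check_data is a class decorator that checks the shapes of a data class's fields.
```python
from typing import NamedTuple
import torch
from eincheck import check_data

@check_data(start="i 2", end="i 2")
class LineSegment2D(NamedTuple):
    start: torch.Tensor
    end: torch.Tensor
```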
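The core rule behind check_shapes, that every occurrence of a dimension name must have the same size, can be pictured in plain Python. This is an illustrative sketch under stated assumptions, not eincheck's actual code: toy_check_shapes is a hypothetical helper that takes plain shape tuples instead of tensors, understands only integer literals, names, and at most one *name group per spec, and raises ValueError on a mismatch.

```python
from typing import Dict, Sequence, Tuple, Union

# A name binds to one size ("i") or, for a variadic group ("*i"), a tuple of sizes.
Binding = Union[int, Tuple[int, ...]]


def toy_check_shapes(*pairs: Tuple[Sequence[int], str]) -> Dict[str, Binding]:
    """Hypothetical sketch (not eincheck's code) of Einstein-notation shape checks.

    Each pair is (shape, spec), e.g. ((7, 3), "i 3"). A spec token is an integer
    literal, a name, or one "*name" group matching zero or more axes. Returns the
    bindings, or raises ValueError if any name takes two different values.
    """
    bindings: Dict[str, Binding] = {}

    def bind(name: str, value: Binding) -> None:
        # First use of a name binds it; every later use must agree.
        if name in bindings and bindings[name] != value:
            raise ValueError(f"{name!r} was {bindings[name]}, now {value}")
        bindings[name] = value

    for shape, spec in pairs:
        shape = tuple(shape)
        dims = spec.split()
        stars = [k for k, d in enumerate(dims) if d.startswith("*")]
        if len(stars) > 1:
            raise ValueError(f"more than one variadic group in {spec!r}")
        if stars:
            k = stars[0]
            n_fixed = len(dims) - 1
            if len(shape) < n_fixed:
                raise ValueError(f"shape {shape} is too short for {spec!r}")
            n_var = len(shape) - n_fixed
            bind(dims[k][1:], shape[k:k + n_var])
            # Drop the variadic group so the fixed dims line up with the remaining axes.
            dims = dims[:k] + dims[k + 1:]
            shape = shape[:k] + shape[k + n_var:]
        elif len(shape) != len(dims):
            raise ValueError(f"shape {shape} has the wrong rank for {spec!r}")
        for dim, size in zip(dims, shape):
            if dim.isdigit():
                if int(dim) != size:
                    raise ValueError(f"expected {dim}, got {size} in {spec!r}")
            else:
                bind(dim, size)
    return bindings


print(toy_check_shapes(((7, 3), "i 3"), ((7, 3), "i 3")))  # {'i': 7}
print(toy_check_shapes(((2, 5, 3), "*i x"), ((2, 5, 4), "*i y")))
# {'i': (2, 5), 'x': 3, 'y': 4}
```

Binding on first use and comparing on every later use is what lets a single letter such as i tie the leading sizes of several tensors together without the caller stating those sizes up front.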
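To make the decorator idea behind check_func concrete, here is a small self-contained sketch. It is not eincheck's implementation: toy_check_func, FakeArray, tokens, and size_of are hypothetical names; FakeArray stands in for a real tensor so the example runs without NumPy or PyTorch; and the toy spec grammar drops variadic *i groups and allows sums only in the output. The wrapper binds each named input dimension on first use, checks later uses against that binding, calls the function, and then evaluates the output spec against the bindings.

```python
import functools
import re
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple


@dataclass
class FakeArray:
    """Hypothetical stand-in for a NumPy/PyTorch array: only .shape is used."""
    shape: Tuple[int, ...]


def tokens(spec: str) -> List[str]:
    # Split on whitespace, but keep a parenthesised expression like "(x + y)" whole.
    return re.findall(r"\([^)]*\)|\S+", spec)


def size_of(dim: str, env: Dict[str, int]) -> int:
    # Evaluate a literal ("3"), a name ("x"), or a sum ("(x + y)") against bound sizes.
    terms = [t.strip() for t in dim.strip("()").split("+")]
    return sum(int(t) if t.isdigit() else env[t] for t in terms)


def toy_check_func(spec: str) -> Callable:
    """Hypothetical decorator (not eincheck's API): check argument shapes, then the result.

    `spec` looks like "i x, i y -> i (x + y)". Inputs may use only literals and
    names; the output may also use sums. Variadic "*i" groups are not supported.
    """
    inputs, output = spec.split("->")
    in_specs = [tokens(s) for s in inputs.split(",")]
    out_spec = tokens(output)

    def decorator(fn: Callable) -> Callable:
        @functools.wraps(fn)
        def wrapper(*args):
            if len(args) != len(in_specs):
                raise TypeError(f"expected {len(in_specs)} arguments, got {len(args)}")
            env: Dict[str, int] = {}  # sizes bound so far, shared across arguments
            for n, (arg, dims) in enumerate(zip(args, in_specs)):
                if len(arg.shape) != len(dims):
                    raise ValueError(f"argument {n}: rank {len(arg.shape)}, spec {dims}")
                for dim, size in zip(dims, arg.shape):
                    expected = int(dim) if dim.isdigit() else env.setdefault(dim, size)
                    if expected != size:
                        raise ValueError(f"argument {n}: {dim!r} is {expected}, got {size}")
            result = fn(*args)
            want = tuple(size_of(d, env) for d in out_spec)
            if tuple(result.shape) != want:
                raise ValueError(f"output shape {tuple(result.shape)}, expected {want}")
            return result

        return wrapper

    return decorator


@toy_check_func("i x, i y -> i (x + y)")
def concat(a, b):
    # Mimics concatenation along the last axis by computing only the result shape.
    return FakeArray(a.shape[:-1] + (a.shape[-1] + b.shape[-1],))


print(concat(FakeArray((4, 2)), FakeArray((4, 3))).shape)  # (4, 5)
```

Checking the output after the call, using sizes bound from the inputs, is what allows an output dimension such as x + y to be verified without the function reporting anything beyond its return value.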

For more info, read the docs!