-
Notifications
You must be signed in to change notification settings - Fork 469
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor tensor data #1916
Refactor tensor data #1916
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1916 +/- ##
==========================================
- Coverage 85.02% 84.72% -0.31%
==========================================
Files 790 791 +1
Lines 93275 93591 +316
==========================================
- Hits 79310 79291 -19
- Misses 13965 14300 +335 ☔ View full report in Codecov by Sentry. |
One of the drawbacks of this change is that creating a tensor from data would previously give a compile error for something like the following: let tensor = Tensor::<B, 1>::from_floats([[1.0], [2.0]], &device)
But now it gives a runtime error instead since we removed the const generic
Also, this let tensor = Tensor::<B, 2>::from_floats([1.0, 2.0], &device) wouldn't work previously (similar compilation error) but now the data is coerced into shape We could discuss how we want to handle this. |
I think we should create a custom parnic message that indicates a rank mismatch. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have some general comments, but thanks for the huge refactor!
With the latest changes it would look like this:
|
@nathanielsimard one thing we missed is documentation. Where should we document these?
In the book? And also README? Just so it's visible to users and they know what to do for the transition. |
Checklist
run-checks all
script has been executed.Changes
This is a big refactor to move away from the
Data<E, D>
andDataSerialize<E>
structures, which are marked as deprecated.It introduces the new data format
TensorData
, which stores the elements as bytes to preserve the flexibility to switch between data types.Breaking⚠️
The new serialization format is not compatible with the previous format, but we can easily load records saved in previous versions with the
record-backward-compat
feature flag (except for binary formats, but this is expected). So any other self-describing format that we support can be loaded and saved into the new format easily.Important Bits
This PR impacts a lot of the codebase for minor stuff, but make sure you take a look at the most important bits.
New
TensorData
type & implementationburn-tensor/src/tensor/data.rs
burn-tensor/src/tensor/element/base.rs
De/serialization with backward compat
burn-core/src/record/tensor.rs
De/serialization of NestedValue for new u8 type
burn-core/src/record/serde/data.rs
burn-core/src/record/serde/de.rs
burn-core/src/record/serde/ser.rs
from_data
implementationburn-candle/src/tensor.rs
burn-jit/src/ops/base.rs
burn-ndarray/src/tensor.rs
burn-tch/src/tensor.rs
Deprecations
burn-tensor/src/libs.rs (allow deprecated usage internally)
burn-tensor/src/tensor/data.rs (
Data
andDataSerialize
deprecation message)Testing
All tests passed.