[MXNET-120] Float16 support for distributed training #10183
Conversation
@solin319 This allows us to send some keys in fp16 and some in fp32. By using a parameter server with type char, we can also avoid interface changes. Now the user doesn't have to do anything special to use fp16. Could you also review this?
@marcoabreu I see something weird in the CI build for this.
To what extent did it break compilation locally? What's your build environment?
Locally I noticed it failing for (GPU GCC Openblas) as well as (CPU Clang). The CI status page says it used the right commit 8152ca4.
src/kvstore/kvstore_dist.h
Outdated
pull_pskv.size *= num_bytes;
CHECK_EQ(static_cast<size_t>(push_pskv.size), compr_size * num_bytes);
CHECK_EQ(static_cast<size_t>(pull_pskv.size), original_size * num_bytes);
CHECK_EQ(push_pskv.lens.size(), num_servers * 2);
Can we not do something like:
CHECK_EQ(static_cast<size_t>(push_pskv.size), compr_size);
CHECK_EQ(static_cast<size_t>(pull_pskv.size), original_size);
push_pskv.size *= num_bytes;
pull_pskv.size *= num_bytes;
instead, so that we perform two fewer multiplications?
Good idea, will do
Huh? How come you closed the PR?
Thanks. Do you know why the integration test doesn't show up in the CI steps?
The CI steps are cached and only re-evaluated if the stage is actually reached. It will show up once the job reaches the integration stage.
I added the nightly tests we had for distributed kvstore as integration tests to the CI, and the build has passed. @everyone Please review and let me know if there are any more concerns.
Thanks a lot, Rahul!
@rahul003 can you fix the conflicts?
@eric-haibin-lin @piiswrong Is this good to be merged?
amalgamation/Makefile
Outdated
@@ -23,6 +23,11 @@ ifndef OPENBLAS_ROOT
export OPENBLAS_ROOT=/usr/local/opt/openblas
endif

# use F16C if the architecture supports it, turned off by default
ifndef USE_F16C
what does 16C stand for?
Added a better comment. F16C is an instruction set extension supported by newer x86 CPUs. It provides intrinsics for faster fp16 compute.
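For context, here is a minimal sketch (my own illustration, not code from this PR) of the scalar F16C conversion intrinsics, assuming an x86 CPU with F16C and a GCC/Clang build with `-mf16c`:

```cpp
// Minimal sketch: F16C provides hardware conversions between fp32 and the
// 16-bit half-precision bit pattern. Build with: g++ -mf16c f16c_demo.cpp
#include <cstdint>
#include <cstdio>
#include <immintrin.h>

int main() {
  const float x = 3.14159f;
  // Float -> half, rounding to nearest even.
  const uint16_t half = _cvtss_sh(x, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
  // Half -> float.
  const float back = _cvtsh_ss(half);
  std::printf("original=%f roundtrip=%f\n", x, back);
  return 0;
}
```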
BTW, the build is failing.
The amalgamation build failed after I updated mshadow. I've updated the amalgamation makefile to fix the build.
@eric-haibin-lin can we merge this?
@piiswrong can this be merged?
* send as char
* fix bug on pull response, and rowsparse on worker side
* three modes
* default to mode 0 and add support for row sparse
* refactor sparse
* rowsparse numbytes fixes
* WIP tests
* update test sync
* remove prints
* refactoring
* Revert "refactoring" (this reverts commit 05ffa1b)
* undo refactoring to keep PR simple
* add wait to stored in pull default
* lint fixes
* undo static cast for recvblob
* lint fixes
* mode 1 changes
* sparse bug fix dtype
* mshadow default
* remove unused var
* remove debug statements
* clearer variables, reduced multiplication, const vars
* add const for more vars, comments
* comment syntax, code watcher, test default val
* remove unnecessary print in test
* trigger ci
* multi precision mode (debugging race condition)
* working rsp pushes
* finish multiprecision for row sparse
* rename num-bytes
* fix bug due to rename of numbytes, and remove debug logs
* address comments
* add integration test
* trigger ci
* integration test
* integration test
* fix path of script
* update mshadow
* disable f16c for amalgamation
* fix amalgamation build
* trigger ci
* disable f16c for jetson
Description
Supports distributed training with float16 (and technically float64 as well).
Modified the parameter server to work at the byte level, so that both fp16 and fp32 gradients can be sent while training in fp16 or mixed precision mode. The user does not have to do anything special to handle fp16 gradients; KVStore handles it internally. ps-lite already converted the data to char at a later step, so doing the conversion earlier gives MXNet more flexibility.
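To make the byte-level idea concrete, here is a minimal sketch of packing a typed gradient buffer into raw bytes so that fp16 and fp32 keys can share one char-typed transport. This is my own illustration under that assumption; `FakeKVPairs` and `PackAsBytes` are hypothetical names, not the PR's code.

```cpp
// Minimal sketch: lengths are tracked in bytes, so the same char-typed
// transport can carry fp16 (2 bytes/element) and fp32 (4 bytes/element).
#include <cstddef>
#include <cstdio>
#include <vector>

struct FakeKVPairs {            // hypothetical stand-in for a char-typed KV message
  std::vector<char> vals;       // raw payload bytes
  std::vector<int> lens;        // per-key lengths, in bytes
};

template <typename DType>
void PackAsBytes(const DType* grad, size_t num_elems, FakeKVPairs* out) {
  const size_t num_bytes = sizeof(DType);               // 2 for fp16, 4 for fp32
  const char* raw = reinterpret_cast<const char*>(grad);
  out->vals.assign(raw, raw + num_elems * num_bytes);   // copy values as raw bytes
  out->lens.push_back(static_cast<int>(num_elems * num_bytes));
}

int main() {
  float grad32[4] = {0.1f, 0.2f, 0.3f, 0.4f};
  FakeKVPairs kv;
  PackAsBytes(grad32, 4, &kv);                          // 16 bytes for an fp32 key
  std::printf("payload bytes: %zu\n", kv.vals.size());
  return 0;
}
```

Because the transport only ever sees byte counts, keys of different dtypes can be mixed in one kvstore without interface changes.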
The type of request and the data type received by the server are encoded in a command computed with the Cantor pairing function. This function generates a unique number `z` from two numbers `x` and `y`, and allows us to invert `z` to recover the unique original pair `x` and `y`.
For keys with fp16 gradients, the values are sent as fp16. All computation on the server is done in fp16, and the server responds to workers' pull requests in fp16.
FP16 support depends heavily on fp16 computation on the CPU, which is currently very slow in MXNet.
Two changes can greatly improve this speed, and I have opened PRs for both. For best performance, and to reproduce the numbers below, both PRs would need to be merged.
Results
8 p3.8xlarge machines, with all 4 GPUs on each machine in use. Throughput in samples/sec for ResNet-50 with synthetic ImageNet data.
The effect of PR 9029 is small for fp32 end to end, but profiling shows the addition operation is faster with the kernel launch; the gain is hidden by other factors. In the case of fp16 and gradient compression, the effect is very visible.
Checklist
Essentials
Changes
Comments
@eric-haibin-lin @reminisce @cjolivier01 @anirudh2290 @haojin2 @ptrendx please review