Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add run_glue_tpu.py that trains models on TPUs #3702

Merged
merged 24 commits into from
Apr 10, 2020

Conversation

jysohn23
Copy link
Collaborator

@jysohn23 jysohn23 commented Apr 8, 2020

No description provided.

Initial commit to get GLUE (BERT) on TPU
TPU runner is currently implemented in:
https://github.com/pytorch-tpu/transformers/blob/tpu/examples/run_glue_tpu.py.

We plan to upstream this directly into `huggingface/transformers`
(either `master` or `tpu`) branch once it's been more thoroughly tested.
TPU runner is currently implemented in:
https://github.com/pytorch-tpu/transformers/blob/tpu/examples/run_glue_tpu.py.

We plan to upstream this directly into `huggingface/transformers`
(either `master` or `tpu`) branch once it's been more thoroughly tested.
Since for gradient accumulation we're accumulating on batches from
`ParallelLoader` instance which on next() marks the step itself.
* Shard eval dataset and aggregate eval metrics

Also, instead of calling `eval_loss.item()` every time do summation with
tensors on device.

* Change defaultdict to float

* Reduce the pred, label tensors instead of metrics

As brought up during review some metrics like f1 cannot be aggregated
via averaging. GLUE task metrics depends largely on the dataset, so
instead we sync the prediction and label tensors so that the metrics can
be computed accurately on those instead.
--task_name $TASK_NAME \
--do_train \
--do_eval \
--do_lower_case \
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know it's present in other example codes (and should be changed), but should we keep the --do_lower_case option with cased models?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. Removed.

Comment on lines 323 to 329
def save_pretrained(self, save_directory, xla_device=False):
""" Save a model and its configuration file to a directory, so that it
can be re-loaded using the `:func:`~transformers.PreTrainedModel.from_pretrained`` class method.

Arguments:
save_directory: directory to which to save.
xla_device: True if saving after training on TPU/XLA.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This makes me a bit uncomfortable. I'm pretty sure the best option would be to save it in the model configuration instead of adding an argument to save_pretrained. I'll look into what can be done to have this clean.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your changes have now been merged and I like the way it looks now better 😄

jysohn23 and others added 4 commits April 8, 2020 21:50
This is needed for our testing framework which checks regressions
against key metrics writtern by the summary writer.
Using configuration for `xla_device`
@jysohn23 jysohn23 requested a review from LysandreJik April 9, 2020 17:03
Copy link
Contributor

@srush srush left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me with a couple of small changes.


import numpy as np
import torch
import torch_xla.core.xla_model as xm
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like this naming for non-standard modules. Can you refer to these by their full names.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In all our examples we use xm, xmp, met, pl consistently so I'd rather keep this way for the sake of consistency. I similarly do see np throughout Huggingface, but if you feel strongly, I'm happy to call them by full name instead.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. If that's the XLA standard I guess it's okay. (I feel like numpy is a special case).

set_seed(args.seed) # Added here for reproductibility (even between python 2 and 3)
for epoch in train_iterator:
# Get TPU parallel loader which sends data to TPU in background.
train_dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a "TPU" comment like this over each of the non-standard lines? I think that would be helpful for learning from this code.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea! Done.

loss.backward()

if (step + 1) % args.gradient_accumulation_steps == 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was under the impression that this line was very slow on TPU?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you're referring to the comment made here: https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md I've sent our this PR to clarify.

tl;dr: we've patched torch.nn.utils.clip_grad_norm_ so that it's not slow.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like our comments crossed haha

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah good to know.

)

# TPU Parameters
parser.add_argument("--num_cores", default=8, type=int, help="Number of TPU cores to use.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On colab you can only use 1 or 8 here right? Is that true generally?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it's not only colab specific but you have to user either 1 or all TPU cores at the moment.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gotcha, maybe we could put that comment in the doc string just so users don't get confused. (I found the error messages a bit cryptic when I did this wrong.)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, clarified on the CLI flag.

@codecov-io
Copy link

codecov-io commented Apr 10, 2020

Codecov Report

Merging #3702 into master will decrease coverage by 0.02%.
The diff coverage is 41.66%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #3702      +/-   ##
==========================================
- Coverage   78.06%   78.03%   -0.03%     
==========================================
  Files         100      100              
  Lines       17134    17144      +10     
==========================================
+ Hits        13375    13378       +3     
- Misses       3759     3766       +7     
Impacted Files Coverage Δ
src/transformers/modeling_utils.py 91.30% <36.36%> (-0.81%) ⬇️
src/transformers/configuration_utils.py 97.01% <100.00%> (+0.02%) ⬆️
src/transformers/modeling_tf_utils.py 93.45% <0.00%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update f68d228...6e959fd. Read the comment docs.

jysohn23 and others added 2 commits April 10, 2020 06:16
Instead of under `args.data_dir`. This is needed as our test infra uses
data_dir with a read-only filesystem.
Copy link
Member

@LysandreJik LysandreJik left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Won't move the script to examples/glue/ yet as I know @julien-c is working on scripts and I don't want to create a complicated rebase.

@julien-c
Copy link
Member

Hmm don't worry about me, this shouldn't be a problem – feel free to merge if ready @LysandreJik !

@LysandreJik LysandreJik merged commit 551b450 into huggingface:master Apr 10, 2020
@akhileshgotmare
Copy link

@jysohn23 I'm trying to run a variant of run_glue_tpu.py on TPUs and am stuck at an oom error. The first iteration of the below for loop runs fine, but it breaks on the second one. Any pointers on how to fix this?

train_dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)
epoch_iterator = tqdm(train_dataloader, desc="Iteration", total=len(dataloader), disable=disable_logging)
for step, batch in enumerate(epoch_iterator):

I tried reducing the batch-size to 1 and running on a single core, both led to the same error. I'm using this gcr.io/tpu-pytorch/xla:nightly_3.6 image for my experiments.

full log - shorturl.at/iswxR
few lines of the error log -

020-06-30 21:49:29.304998: E    6014 tensorflow/compiler/xla/xla_client/xla_util.cc:76] >>> Dumping Computation 0                     | 1/6136 [01:16<131:08:36, 76.95s/it]
2020-06-30 21:49:29.305126: E    6014 tensorflow/compiler/xla/xla_client/xla_util.cc:76] HloModule SyncTensorsGraph.33776, input_output_alias={ {0}: (250, {}), {1}: (249, {}), {2}: (265, {}), {3}: (248, {}), {4}: (247, {}), {5}: (246, {}), {6}: (245, {}), {7}: (244, {}), {8}: (269, {}), {9}: (243, {}), {10}: (242, {}), {11}: (241, {}), {12}: (240, {}), {13}: (239, {}), {14}: (271, {}), {15}: (238, {}), {16}: (237, {}), {17}: (236, {}), {18}: (235, {}), {19}: (234, {}), {20}: (273, {}), {21}: (233, {}), {22}: (232, {}), {23}: (231, {}), {24}: (230, {}), {25}: (229, {}), {26}: (274, {}), {27}: (228, {}), {28}: (227, {}), {29}: (226, {}), {30}: (225, {}), {31}: (224, {}), {32}: (276, {}), {33}: (223, {}), {34}: (222, {}), {35}: (221, {}), {36}: (220, {}), {37}: (219, {}), {38}: (277, {}), {39}: (218, {}), {40}: (217, {}), {41}: (216, {}), {42}: (215, {}), {43}: (214, {}), {44}: (279, {}), {45}: (213, {}), {46}: (212, {}), {47}: (211, {}), {48}: (210, {}), {49}: (209, {}), {50}: (280, {}), {51}: (208, {}), {52}: (207, {}), {53}: (206, {}), {54}: (205, {}), {55}: (204, {}), {56}: (282, {}), {57}: (203, {}), {58}: (202, {}), {59}: (201, {}), {60}: (200, {}), {61}: (199, {}), {62}: (283, {}), {63}: (198, {}), {64}: (197, {}), {65}: (196, {}), {66}: (195, {}), {67}: (194, {}), {68}: (285, {}), {69}: (193, {}), {70}: (192, {}), {71}: (191, {}), {72}: (190, {}), {73}: (189, {}), {74}: (286, {}), {75}: (188, {}), {76}: (187, {}), {77}: (186, {}), {78}: (185, {}), {79}: (184, {}), {80}: (288, {}), {81}: (183, {}), {82}: (182, {}), {83}: (181, {}), {84}: (180, {}), {85}: (179, {}), {86}: (289, {}), {87}: (178, {}), {88}: (177, {}), {89}: (176, {}), {90}: (175, {}), {91}: (174, {}), {92}: (291, {}), {93}: (173, {}), {94}: (172, {}), {95}: (171, {}), {96}: (170, {}), {97}: (169, {}), {98}: (292, {}), {99}: (168, {}), {100}: (167, {}), {101}: (166, {}), {102}: (165, {}), {103}: (164, {}), {104}: (294, {}), {105}: (163, {}), {106}: (162, {}), {107}: (161, {}), {108}: (160, {}), {109}: (159, {}), {110}: (295, {}), {111}: (158, {}), {112}: (157, {}), {113}: (156, {}), {114}: (155, {}), {115}: (154, {}), {116}: (297, {}), {117}: (153, {}), {118}: (152, {}), {119}: (151, {}), {120}: (150, {}), {121}: (149, {}), {122}: (298, {}), {123}: (148, {}), {124}: (147, {}), {125}: (146, {}), {126}: (145, {}), {127}: (144, {}), {128}: (300, {}), {129}: (143, {}), {130}: (142, {}), {131}: (141, {}), {132}: (140, {}), {133}: (139, {}), {134}: (301, {}), {135}: (138, {}), {136}: (137, {}), {137}: (136, {}), {138}: (135, {}), {139}: (134, {}), {140}: (303, {}), {141}: (133, {}), {142}: (132, {}), {143}: (131, {}), {144}: (130, {}), {145}: (129, {}), {146}: (304, {}), {147}: (128, {}), {148}: (127, {}), {149}: (126, {}), {150}: (125, {}), {151}: (124, {}), {152}: (306, {}), {153}: (123, {}), {154}: (122, {}), {155}: (121, {}), {156}: (120, {}), {157}: (119, {}), {158}: (307, {}), {159}: (118, {}), {160}: (117, {}), {161}: (116, {}), {162}: (115, {}), {163}: (114, {}), {164}: (309, {}), {165}: (113, {}), {166}: (112, {}), {167}: (111, {}), {168}: (110, {}), {169}: (109, {}), {170}: (310, {}), {171}: (108, {}), {172}: (107, {}), {173}: (106, {}), {174}: (105, {}), {175}: (104, {}), {176}: (312, {}), {177}: (103, {}), {178}: (102, {}), {179}: (101, {}), {180}: (100, {}), {181}: (99, {}), {182}: (313, {}), {183}: (98, {}), {184}: (97, {}), {185}: (96, {}), {186}: (95, {}), {187}: (94, {}), {188}: (315, {}), {189}: (93, {}), {190}: (92, {}), {191}: (91, {}), {192}: (90, {}), {193}: (89, {}), {194}: (316, {}), {195}: (88, {}), {196}: (87, {}), {197}: (86, {}), {198}: (85, {}), {199}: (84, {}), {200}: (318, {}), {201}: (83, {}), {202}: (82, {}), {203}: (81, {}), {204}: (80, {}), {205}: (79, {}), {206}: (319, {}), {207}: (78, {}), {208}: (77, {}), {209}: (76, {}), {210}: (75, {}), {211}: (74, {}), {212}: (321, {}), {213}: (73, {}), {214}: (72, {}), {215}: (71, {}), {216}: (70, {}), {217}: (69, {}), {218}: (322, {}), {219}: (68, {}), {220}: (67, {}), {221}: (66, {}), {222}: (65, {}), {223}: (64, {}), {224}: (324, {}), {225}: (63, {}), {226}: (62, {}), {227}: (61, {}), {228}: (60, {}), {229}: (59, {}), {230}: (325, {}), {231}: (58, {}), {232}: (57, {}), {233}: (56, {}), {234}: (55, {}), {235}: (54, {}), {236}: (327, {}), {237}: (53, {}), {238}: (52, {}), {239}: (51, {}), {240}: (50, {}), {241}: (49, {}), {242}: (328, {}), {243}: (48, {}), {244}: (47, {}), {245}: (46, {}), {246}: (45, {}), {247}: (44, {}), {248}: (330, {}), {249}: (43, {}), {250}: (42, {}), {251}: (41, {}), {252}: (40, {}), {253}: (39, {}), {254}: (331, {}), {255}: (38, {}), {256}: (37, {}), {257}: (36, {}), {258}: (35, {}), {259}: (34, {}), {260}: (333, {}), {261}: (33, {}), {262}: (32, {}), {263}: (31, {}), {264}: (30, {}), {265}: (29, {}), {266}: (334, {}), {267}: (28, {}), {268}: (27, {}), {269}: (26, {}), {270}: (25, {}), {271}: (24, {}), {272}: (336, {}), {273}: (23, {}), {274}: (22, {}), {275}: (21, {}), {276}: (20, {}), {277}: (19, {}), {278}: (337, {}), {279}: (18, {}), {280}: (17, {}), {281}: (16, {}), {282}: (15, {}), {283}: (14, {}), {284}: (339, {}), {285}: (13, {}), {286}: (12, {}), {287}: (8, {}), {288}: (7, {}), {289}: (5, {}), {290}: (340, {}), {291}: (346, {}), {292}: (4, {}), {377}: (342, {}) }
2020-06-30 21:49:29.305162: E    6014 tensorflow/compiler/xla/xla_client/xla_util.cc:76] 
2020-06-30 21:49:29.305173: E    6014 tensorflow/compiler/xla/xla_client/xla_util.cc:76] %MaxComputation.2092 (x.2093: f32[], y.2094: f32[]) -> f32[] {
2020-06-30 21:49:29.305181: E    6014 tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %x.2093 = f32[] parameter(0)
2020-06-30 21:49:29.305196: E    6014 tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %y.2094 = f32[] parameter(1)
2020-06-30 21:49:29.305204: E    6014 tensorflow/compiler/xla/xla_client/xla_util.cc:76]   ROOT %maximum.2095 = f32[] maximum(f32[] %x.2093, f32[] %y.2094)
2020-06-30 21:49:29.305212: E    6014 tensorflow/compiler/xla/xla_client/xla_util.cc:76] }
2020-06-30 21:49:29.305221: E    6014 tensorflow/compiler/xla/xla_client/xla_util.cc:76] 
2020-06-30 21:49:29.305235: E    6014 tensorflow/compiler/xla/xla_client/xla_util.cc:76] %AddComputation.2101 (x.2102: f32[], y.2103: f32[]) -> f32[] {
2020-06-30 21:49:29.305244: E    6014 tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %x.2102 = f32[] parameter(0)
2020-06-30 21:49:29.305254: E    6014 tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %y.2103 = f32[] parameter(1)
2020-06-30 21:49:29.305264: E    6014 tensorflow/compiler/xla/xla_client/xla_util.cc:76]   ROOT %add.2104 = f32[] add(f32[] %x.2102, f32[] %y.2103)
2020-06-30 21:49:29.305273: E    6014 tensorflow/compiler/xla/xla_client/xla_util.cc:76] }
2020-06-30 21:49:29.305283: E    6014 tensorflow/compiler/xla/xla_client/xla_util.cc:76] 
.
.
.
.
2020-06-30 21:49:29.568300: E    5603 tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %subtract.5549 = f32[] subtract(f32[] %constant.5532, f32[] %constant.5533)
2020-06-30 21:49:29.568320: E    6014 tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %constant.20745 = f32[] constant(0.125)
2020-06-30 21:49:29.568321: E    5603 tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %broadcast.5550 = f32[1,16,128,128]{3,2,1,0} broadcast(f32[] %subtract.5549), dimensions={}
2020-06-30 21:49:29.568331: E    6014 tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %broadcast.20746 = f32[1024,4096]{1,0} broadcast(f32[] %constant.20745), dimensions={}
2020-06-30 21:49:29.568332: E    5603 tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %multiply.5551 = f32[1,16,128,128]{3,2,1,0} multiply(f32[1,16,128,128]{3,2,1,0} %multiply.5548, f32[1,16,128,128]{3,2,1,0} %broadcast.5550)
2020-06-30 21:49:29.568342: E    6014 tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %multiply.20747 = f32[1024,4096]{1,0} multiply(f32[1024,4096]{1,0} %get-tuple-element.20744, f32[1024,4096]{1,0} %broadcast.20746)
2020-06-30 21:49:29.568344: E    5603 tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %broadcast.5552 = f32[1,16,128,128]{3,2,1,0} broadcast(f32[] %constant.5533), dimensions={}
2020-06-30 21:49:29.568353: E    6014 tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %reshape.25706 = f32[1,1]{1,0} reshape(f32[] %p263.1975)
2020-06-30 21:49:29.568354: E    5603 tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %add.5553 = f32[1,16,128,128]{3,2,1,0} add(f32[1,16,128,128]{3,2,1,0} %multiply.5551, f32[1,16,128,128]{3,2,1,0} %broadcast.5552)
.
.
.
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.

  (1) Resource exhausted: Ran out of memory in memory space vmem. It should not be possible to run out of vmem - please file a bug against XLA.

Largest program allocations in vmem:

  XLA label: %fusion.4431 = (f32[1024]{0:T(1024)}, f32[24,128]{1,0:T(8,128)}, f32[24,128]{1,0:T(8,128)}, f32[1024]{0:T(1024)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24...
  Allocation type: scoped

  XLA label: %fusion.4431 = (f32[1024]{0:T(1024)}, f32[24,128]{1,0:T(8,128)}, f32[24,128]{1,0:T(8,128)}, f32[1024]{0:T(1024)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24...
  Allocation type: scoped

  XLA label: %fusion.4431 = (f32[1024]{0:T(1024)}, f32[24,128]{1,0:T(8,128)}, f32[24,128]{1,0:T(8,128)}, f32[1024]{0:T(1024)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24...
  Allocation type: scoped

  XLA label: %fusion.4431 = (f32[1024]{0:T(1024)}, f32[24,128]{1,0:T(8,128)}, f32[24,128]{1,0:T(8,128)}, f32[1024]{0:T(1024)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24...
  Allocation type: scoped

  XLA label: %fusion.4431 = (f32[1024]{0:T(1024)}, f32[24,128]{1,0:T(8,128)}, f32[24,128]{1,0:T(8,128)}, f32[1024]{0:T(1024)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24...
  Allocation type: scoped

	 [[{{node XRTCompile}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.

	 [[XRTCompile_G6]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.

0 successful operations.
0 derived errors ignored.
Traceback (most recent call last):
  File "/root/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 235, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/root/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 229, in _start_fn
    fn(gindex, *args)
  File "/export/share/akhilesh-gotmare/tpu_gedi/transformers/examples/run_tpu_glue.py", line 797, in _mp_fn
    main(args)
  File "/export/share/akhilesh-gotmare/tpu_gedi/transformers/examples/run_tpu_glue.py", line 607, in main
    global_step, tr_loss = train(args, train_dataset, model, tokenizer, disable_logging=disable_logging)
  File "/export/share/akhilesh-gotmare/tpu_gedi/transformers/examples/run_tpu_glue.py", line 186, in train
    for step, batch in enumerate(epoch_iterator):
  File "/root/anaconda3/envs/pytorch/lib/python3.6/site-packages/tqdm/std.py", line 1107, in __iter__
    for obj in iterable:
  File "/root/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch_xla/distributed/parallel_loader.py", line 31, in __next__
    return self.next()
  File "/root/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch_xla/distributed/parallel_loader.py", line 37, in next
    xm.mark_step()
  File "/root/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch_xla/core/xla_model.py", line 549, in mark_step
    wait=xu.getenv_as('XLA_SYNC_WAIT', bool, False))
RuntimeError: Resource exhausted: From /job:tpu_worker/replica:0/task:0:
2 root error(s) found.
  (0) Resource exhausted: Ran out of memory in memory space vmem. It should not be possible to run out of vmem - please file a bug against XLA.

Largest program allocations in vmem:

  XLA label: %fusion.4431 = (f32[1024]{0:T(1024)}, f32[24,128]{1,0:T(8,128)}, f32[24,128]{1,0:T(8,128)}, f32[1024]{0:T(1024)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24...
  Allocation type: scoped

  XLA label: %fusion.4431 = (f32[1024]{0:T(1024)}, f32[24,128]{1,0:T(8,128)}, f32[24,128]{1,0:T(8,128)}, f32[1024]{0:T(1024)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24...
  Allocation type: scoped

  XLA label: %fusion.4431 = (f32[1024]{0:T(1024)}, f32[24,128]{1,0:T(8,128)}, f32[24,128]{1,0:T(8,128)}, f32[1024]{0:T(1024)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24...
  Allocation type: scoped

  XLA label: %fusion.4431 = (f32[1024]{0:T(1024)}, f32[24,128]{1,0:T(8,128)}, f32[24,128]{1,0:T(8,128)}, f32[1024]{0:T(1024)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24...
  Allocation type: scoped

  XLA label: %fusion.4431 = (f32[1024]{0:T(1024)}, f32[24,128]{1,0:T(8,128)}, f32[24,128]{1,0:T(8,128)}, f32[1024]{0:T(1024)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24...
  Allocation type: scoped

	 [[{{node XRTCompile}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.

  (1) Resource exhausted: Ran out of memory in memory space vmem. It should not be possible to run out of vmem - please file a bug against XLA.

Largest program allocations in vmem:

  XLA label: %fusion.4431 = (f32[1024]{0:T(1024)}, f32[24,128]{1,0:T(8,128)}, f32[24,128]{1,0:T(8,128)}, f32[1024]{0:T(1024)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24...
  Allocation type: scoped

  XLA label: %fusion.4431 = (f32[1024]{0:T(1024)}, f32[24,128]{1,0:T(8,128)}, f32[24,128]{1,0:T(8,128)}, f32[1024]{0:T(1024)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24...
  Allocation type: scoped

  XLA label: %fusion.4431 = (f32[1024]{0:T(1024)}, f32[24,128]{1,0:T(8,128)}, f32[24,128]{1,0:T(8,128)}, f32[1024]{0:T(1024)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24...
  Allocation type: scoped

  XLA label: %fusion.4431 = (f32[1024]{0:T(1024)}, f32[24,128]{1,0:T(8,128)}, f32[24,128]{1,0:T(8,128)}, f32[1024]{0:T(1024)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24...
  Allocation type: scoped

  XLA label: %fusion.4431 = (f32[1024]{0:T(1024)}, f32[24,128]{1,0:T(8,128)}, f32[24,128]{1,0:T(8,128)}, f32[1024]{0:T(1024)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24...
  Allocation type: scoped

	 [[{{node XRTCompile}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.

	 [[XRTCompile_G6]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.

0 successful operations.
0 derived errors ignored.
Traceback (most recent call last):
  File "/root/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 235, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/root/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 229, in _start_fn
    fn(gindex, *args)
  File "/export/share/akhilesh-gotmare/tpu_gedi/transformers/examples/run_tpu_glue.py", line 797, in _mp_fn
    main(args)
  File "/export/share/akhilesh-gotmare/tpu_gedi/transformers/examples/run_tpu_glue.py", line 607, in main
    global_step, tr_loss = train(args, train_dataset, model, tokenizer, disable_logging=disable_logging)
  File "/export/share/akhilesh-gotmare/tpu_gedi/transformers/examples/run_tpu_glue.py", line 186, in train
    for step, batch in enumerate(epoch_iterator):
  File "/root/anaconda3/envs/pytorch/lib/python3.6/site-packages/tqdm/std.py", line 1107, in __iter__
    for obj in iterable:
  File "/root/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch_xla/distributed/parallel_loader.py", line 31, in __next__
    return self.next()
  File "/root/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch_xla/distributed/parallel_loader.py", line 37, in next
    xm.mark_step()
  File "/root/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch_xla/core/xla_model.py", line 549, in mark_step
    wait=xu.getenv_as('XLA_SYNC_WAIT', bool, False))
RuntimeError: Resource exhausted: From /job:tpu_worker/replica:0/task:0:
2 root error(s) found.
  (0) Resource exhausted: Ran out of memory in memory space vmem. It should not be possible to run out of vmem - please file a bug against XLA.

Largest program allocations in vmem:

  XLA label: %fusion.4431 = (f32[1024]{0:T(1024)}, f32[24,128]{1,0:T(8,128)}, f32[24,128]{1,0:T(8,128)}, f32[1024]{0:T(1024)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24...
  Allocation type: scoped

  XLA label: %fusion.4431 = (f32[1024]{0:T(1024)}, f32[24,128]{1,0:T(8,128)}, f32[24,128]{1,0:T(8,128)}, f32[1024]{0:T(1024)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24...
  Allocation type: scoped

  XLA label: %fusion.4431 = (f32[1024]{0:T(1024)}, f32[24,128]{1,0:T(8,128)}, f32[24,128]{1,0:T(8,128)}, f32[1024]{0:T(1024)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24...
  Allocation type: scoped

  XLA label: %fusion.4431 = (f32[1024]{0:T(1024)}, f32[24,128]{1,0:T(8,128)}, f32[24,128]{1,0:T(8,128)}, f32[1024]{0:T(1024)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24...
  Allocation type: scoped

  XLA label: %fusion.4431 = (f32[1024]{0:T(1024)}, f32[24,128]{1,0:T(8,128)}, f32[24,128]{1,0:T(8,128)}, f32[1024]{0:T(1024)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24...
  Allocation type: scoped

	 [[{{node XRTCompile}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.

  (1) Resource exhausted: Ran out of memory in memory space vmem. It should not be possible to run out of vmem - please file a bug against XLA.

Largest program allocations in vmem:

  XLA label: %fusion.4431 = (f32[1024]{0:T(1024)}, f32[24,128]{1,0:T(8,128)}, f32[24,128]{1,0:T(8,128)}, f32[1024]{0:T(1024)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24...
  Allocation type: scoped

  XLA label: %fusion.4431 = (f32[1024]{0:T(1024)}, f32[24,128]{1,0:T(8,128)}, f32[24,128]{1,0:T(8,128)}, f32[1024]{0:T(1024)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24...
  Allocation type: scoped

  XLA label: %fusion.4431 = (f32[1024]{0:T(1024)}, f32[24,128]{1,0:T(8,128)}, f32[24,128]{1,0:T(8,128)}, f32[1024]{0:T(1024)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24...
  Allocation type: scoped

  XLA label: %fusion.4431 = (f32[1024]{0:T(1024)}, f32[24,128]{1,0:T(8,128)}, f32[24,128]{1,0:T(8,128)}, f32[1024]{0:T(1024)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24...
  Allocation type: scoped

  XLA label: %fusion.4431 = (f32[1024]{0:T(1024)}, f32[24,128]{1,0:T(8,128)}, f32[24,128]{1,0:T(8,128)}, f32[1024]{0:T(1024)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24,128,1024]{2,1,0:T(8,128)}, f32[24...
  Allocation type: scoped

	 [[{{node XRTCompile}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.

	 [[XRTCompile_G6]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.

0 successful operations.
0 derived errors ignored.
/root/anaconda3/envs/pytorch/lib/python3.6/multiprocessing/semaphore_tracker.py:143: UserWarning: semaphore_tracker: There appear to be 1 leaked semaphores to clean up at shutdown
  len(cache))
/root/anaconda3/envs/pytorch/lib/python3.6/multiprocessing/semaphore_tracker.py:143: UserWarning: semaphore_tracker: There appear to be 1 leaked semaphores to clean up at shutdown
  len(cache))
/root/anaconda3/envs/pytorch/lib/python3.6/multiprocessing/semaphore_tracker.py:143: UserWarning: semaphore_tracker: There appear to be 1 leaked semaphores to clean up at shutdown
  len(cache))
/root/anaconda3/envs/pytorch/lib/python3.6/multiprocessing/semaphore_tracker.py:143: UserWarning: semaphore_tracker: There appear to be 1 leaked semaphores to clean up at shutdown
  len(cache))
/root/anaconda3/envs/pytorch/lib/python3.6/multiprocessing/semaphore_tracker.py:143: UserWarning: semaphore_tracker: There appear to be 1 leaked semaphores to clean up at shutdown
  len(cache))
/root/anaconda3/envs/pytorch/lib/python3.6/multiprocessing/semaphore_tracker.py:143: UserWarning: semaphore_tracker: There appear to be 1 leaked semaphores to clean up at shutdown
  len(cache))
Traceback (most recent call last):
  File "run_tpu_glue.py", line 806, in <module>
    main_cli()
  File "run_tpu_glue.py", line 802, in main_cli
    xmp.spawn(_mp_fn, args=(args,), nprocs=args.num_cores)
  File "/root/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 300, in spawn
    start_method=start_method)
  File "/root/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 158, in start_processes
    while not context.join():
  File "/root/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 113, in join
    (error_index, exitcode)
Exception: process 6 terminated with exit code 17

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants