-
Notifications
You must be signed in to change notification settings - Fork 753
/
run_pretraining_google_fast.py
565 lines (460 loc) · 21.3 KB
/
run_pretraining_google_fast.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
# coding=utf-8
# Copyright 2019 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python2, python3
"""Run masked LM/next sentence masked_lm pre-training for ALBERT."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
from six.moves import range
import tensorflow as tf
import modeling_google_fast as modeling
import optimization_google as optimization
flags = tf.flags

FLAGS = flags.FLAGS

# All flags below are registered through the single `flags` alias so that the
# definitions read uniformly. Help strings assembled from adjacent literals
# carry an explicit separating space at every join.

## Required parameters
# NOTE(review): `export_dir` is grouped with the required parameters but is
# not passed to `mark_flag_as_required` in the `__main__` guard of this file.
flags.DEFINE_string(
    "albert_config_file", None,
    "The config json file corresponding to the pre-trained ALBERT model. "
    "This specifies the model architecture.")

flags.DEFINE_string(
    "input_file", None,
    "Input TF example files (can be a glob or comma separated).")

flags.DEFINE_string(
    "output_dir", None,
    "The output directory where the model checkpoints will be written.")

flags.DEFINE_string(
    "export_dir", None,
    "The output directory where the saved models will be written.")

## Other parameters
flags.DEFINE_string(
    "init_checkpoint", None,
    "Initial checkpoint (usually from a pre-trained ALBERT model).")

flags.DEFINE_integer(
    "max_seq_length", 512,
    "The maximum total input sequence length after WordPiece tokenization. "
    "Sequences longer than this will be truncated, and sequences shorter "
    "than this will be padded. Must match data generation.")

flags.DEFINE_integer(
    "max_predictions_per_seq", 20,
    "Maximum number of masked LM predictions per sequence. "
    "Must match data generation.")

flags.DEFINE_bool("do_train", True, "Whether to run training.")

flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")

flags.DEFINE_integer("train_batch_size", 4096, "Total batch size for training.")

flags.DEFINE_integer("eval_batch_size", 64, "Total batch size for eval.")

flags.DEFINE_enum("optimizer", "lamb", ["adamw", "lamb"],
                  "The optimizer for training.")

flags.DEFINE_float("learning_rate", 0.00176, "The initial learning rate.")

flags.DEFINE_float("poly_power", 1.0, "The power of poly decay.")

flags.DEFINE_integer("num_train_steps", 125000, "Number of training steps.")

flags.DEFINE_integer("num_warmup_steps", 3125, "Number of warmup steps.")

flags.DEFINE_integer("start_warmup_step", 0, "The starting step of warmup.")

flags.DEFINE_integer("save_checkpoints_steps", 5000,
                     "How often to save the model checkpoint.")

flags.DEFINE_integer("iterations_per_loop", 1000,
                     "How many steps to make in each estimator call.")

flags.DEFINE_integer("max_eval_steps", 100, "Maximum number of eval steps.")

flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")

flags.DEFINE_bool("init_from_group0", False, "Whether to initialize "
                  "parameters of other groups from group 0")

flags.DEFINE_string(
    "tpu_name", None,
    "The Cloud TPU to use for training. This should be either the name "
    "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
    "url.")

flags.DEFINE_string(
    "tpu_zone", None,
    "[Optional] GCE zone where the Cloud TPU is located in. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.")

flags.DEFINE_string(
    "gcp_project", None,
    "[Optional] Project name for the Cloud TPU-enabled project. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.")

flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")

flags.DEFINE_integer(
    "num_tpu_cores", 8,
    "Only used if `use_tpu` is True. Total number of TPU cores to use.")

flags.DEFINE_float(
    "masked_lm_budget", 0,
    "If >0, the ratio of masked ngrams to unmasked ngrams. Default 0, "
    "for offline masking")
def model_fn_builder(albert_config, init_checkpoint, learning_rate,
                     num_train_steps, num_warmup_steps, use_tpu,
                     use_one_hot_embeddings, optimizer, poly_power,
                     start_warmup_step):
  """Builds the `model_fn` closure consumed by TPUEstimator.

  Args:
    albert_config: Model configuration handed to `modeling.AlbertModel` and to
      the two loss heads; `num_hidden_groups`, `net_structure_type` and
      `num_hidden_layers` are also read when restoring from a checkpoint.
    init_checkpoint: Checkpoint to initialize variables from; skipped when
      falsy.
    learning_rate: Forwarded to `optimization.create_optimizer`.
    num_train_steps: Forwarded to `optimization.create_optimizer`.
    num_warmup_steps: Forwarded to `optimization.create_optimizer`.
    use_tpu: When true, checkpoint initialization is deferred to a
      `scaffold_fn`; otherwise it happens while the graph is built. Also
      forwarded to `optimization.create_optimizer`.
    use_one_hot_embeddings: Forwarded to `modeling.AlbertModel`.
    optimizer: Forwarded to `optimization.create_optimizer`.
    poly_power: Forwarded to `optimization.create_optimizer`.
    start_warmup_step: Forwarded to `optimization.create_optimizer`.

  Returns:
    A `model_fn(features, labels, mode, params)` callable that handles the
    TRAIN and EVAL modes and returns a `TPUEstimatorSpec`.
  """

  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """Builds the pre-training graph for a single estimator mode."""
    tf.logging.info("*** Features ***")
    for feature_name in sorted(features.keys()):
      tf.logging.info(" name = %s, shape = %s", feature_name,
                      features[feature_name].shape)

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    # NOTE(review): the input_fn in this file leaves the three masked_lm_*
    # features out when --masked_lm_budget is non-zero, yet they are read
    # unconditionally here -- confirm how online masking is meant to be wired.
    masked_lm_positions = features["masked_lm_positions"]
    masked_lm_ids = features["masked_lm_ids"]
    masked_lm_weights = features["masked_lm_weights"]
    # The serialized feature keeps the historical name `next_sentence_labels`
    # for compatibility with existing data; for ALBERT it carries the
    # sentence-order labels.
    sentence_order_labels = features["next_sentence_labels"]

    model = modeling.AlbertModel(
        config=albert_config,
        is_training=(mode == tf.estimator.ModeKeys.TRAIN),
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=use_one_hot_embeddings)

    mlm_loss, mlm_example_loss, mlm_log_probs = get_masked_lm_output(
        albert_config, model.get_sequence_output(),
        model.get_embedding_table(), masked_lm_positions, masked_lm_ids,
        masked_lm_weights)
    sop_loss, sop_example_loss, sop_log_probs = get_sentence_order_output(
        albert_config, model.get_pooled_output(), sentence_order_labels)
    total_loss = mlm_loss + sop_loss

    tvars = tf.trainable_variables()
    initialized_variable_names = {}
    scaffold_fn = None
    if init_checkpoint:
      tf.logging.info("number of hidden group %d to initialize",
                      albert_config.num_hidden_groups)
      # One group is restored by default; --init_from_group0 widens that to
      # every hidden group, or to every layer for net_structure_type > 0.
      num_of_initialize_group = 1
      if FLAGS.init_from_group0:
        num_of_initialize_group = albert_config.num_hidden_groups
        if albert_config.net_structure_type > 0:
          num_of_initialize_group = albert_config.num_hidden_layers
      (assignment_map, initialized_variable_names
      ) = modeling.get_assignment_map_from_checkpoint(
          tvars, init_checkpoint, num_of_initialize_group)

      def init_groups_from_checkpoint():
        """Applies the per-group assignment maps to `init_checkpoint`."""
        for gid in range(num_of_initialize_group):
          tf.logging.info("initialize the %dth layer", gid)
          tf.logging.info(assignment_map[gid])
          tf.train.init_from_checkpoint(init_checkpoint, assignment_map[gid])

      if use_tpu:

        def tpu_scaffold():
          init_groups_from_checkpoint()
          return tf.train.Scaffold()

        scaffold_fn = tpu_scaffold
      else:
        init_groups_from_checkpoint()

    tf.logging.info("**** Trainable Variables ****")
    for var in tvars:
      init_string = (", *INIT_FROM_CKPT*"
                     if var.name in initialized_variable_names else "")
      tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
                      init_string)

    if mode == tf.estimator.ModeKeys.TRAIN:
      train_op = optimization.create_optimizer(
          total_loss, learning_rate, num_train_steps, num_warmup_steps,
          use_tpu, optimizer, poly_power, start_warmup_step)
      return tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=total_loss,
          train_op=train_op,
          scaffold_fn=scaffold_fn)

    if mode != tf.estimator.ModeKeys.EVAL:
      raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))

    def metric_fn(*args):
      """Computes the loss and accuracy of the model."""
      (lm_losses, lm_log_probs, lm_ids, lm_weights, order_losses,
       order_log_probs, order_labels) = args[:7]

      # Masked-LM metrics, weighted so that padded predictions do not count.
      lm_log_probs = tf.reshape(lm_log_probs, [-1, lm_log_probs.shape[-1]])
      lm_predictions = tf.argmax(
          lm_log_probs, axis=-1, output_type=tf.int32)
      lm_losses = tf.reshape(lm_losses, [-1])
      lm_ids = tf.reshape(lm_ids, [-1])
      lm_weights = tf.reshape(lm_weights, [-1])
      lm_accuracy = tf.metrics.accuracy(
          labels=lm_ids, predictions=lm_predictions, weights=lm_weights)
      lm_mean_loss = tf.metrics.mean(values=lm_losses, weights=lm_weights)

      # Sentence-order metrics (unweighted).
      order_log_probs = tf.reshape(
          order_log_probs, [-1, order_log_probs.shape[-1]])
      order_predictions = tf.argmax(
          order_log_probs, axis=-1, output_type=tf.int32)
      order_labels = tf.reshape(order_labels, [-1])
      order_accuracy = tf.metrics.accuracy(
          labels=order_labels, predictions=order_predictions)
      order_mean_loss = tf.metrics.mean(values=order_losses)

      return {
          "masked_lm_accuracy": lm_accuracy,
          "masked_lm_loss": lm_mean_loss,
          "sentence_order_accuracy": order_accuracy,
          "sentence_order_loss": order_mean_loss,
      }

    metric_inputs = [
        mlm_example_loss, mlm_log_probs, masked_lm_ids, masked_lm_weights,
        sop_example_loss, sop_log_probs, sentence_order_labels
    ]
    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        eval_metrics=(metric_fn, metric_inputs),
        scaffold_fn=scaffold_fn)

  return model_fn
def get_masked_lm_output(albert_config, input_tensor, output_weights, positions,
                         label_ids, label_weights):
  """Builds the masked-LM head and returns its loss and log-probabilities.

  Args:
    albert_config: Config providing `embedding_size`, `hidden_act`,
      `initializer_range` and `vocab_size`.
    input_tensor: Sequence output of the encoder; the vectors at `positions`
      are selected with `gather_indexes`.
    output_weights: Matrix the transformed vectors are multiplied with
      (transposed) to obtain vocabulary logits.
    positions: Positions of the masked tokens within each sequence.
    label_ids: Target token ids for the masked positions.
    label_weights: Per-prediction weights; padded predictions carry 0.0.

  Returns:
    A `(loss, per_example_loss, log_probs)` tuple where `loss` is the
    weight-normalized mean of `per_example_loss`.
  """
  gathered = gather_indexes(input_tensor, positions)

  with tf.variable_scope("cls/predictions"):
    # An extra non-linear projection sits in front of the output layer; its
    # weights are only needed during pre-training.
    with tf.variable_scope("transform"):
      projected = tf.layers.dense(
          gathered,
          units=albert_config.embedding_size,
          activation=modeling.get_activation(albert_config.hidden_act),
          kernel_initializer=modeling.create_initializer(
              albert_config.initializer_range))
      projected = modeling.layer_norm(projected)

    # The output projection reuses `output_weights`; only the per-token bias
    # is a variable owned by this head.
    vocab_bias = tf.get_variable(
        "output_bias",
        shape=[albert_config.vocab_size],
        initializer=tf.zeros_initializer())
    logits = tf.nn.bias_add(
        tf.matmul(projected, output_weights, transpose_b=True), vocab_bias)
    log_probs = tf.nn.log_softmax(logits, axis=-1)

    flat_label_ids = tf.reshape(label_ids, [-1])
    flat_label_weights = tf.reshape(label_weights, [-1])
    one_hot_labels = tf.one_hot(
        flat_label_ids, depth=albert_config.vocab_size, dtype=tf.float32)

    # `positions` may be zero-padded when a sequence has fewer masked tokens
    # than the maximum; the weights zero those entries out, and the small
    # epsilon keeps the division defined when every weight is zero.
    per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])
    weighted_loss_sum = tf.reduce_sum(flat_label_weights * per_example_loss)
    weight_total = tf.reduce_sum(flat_label_weights) + 1e-5
    loss = weighted_loss_sum / weight_total

  return (loss, per_example_loss, log_probs)
def get_sentence_order_output(albert_config, input_tensor, labels):
  """Builds the two-way sentence-order head and returns its loss and log probs.

  Args:
    albert_config: Config providing `hidden_size` and `initializer_range`.
    input_tensor: Pooled encoder output fed to the classifier.
    labels: Integer class labels (0 or 1); flattened before use.

  Returns:
    A `(loss, per_example_loss, log_probs)` tuple where `loss` is the mean of
    `per_example_loss`.
  """
  # A plain binary classifier. What each class id stands for is fixed by the
  # data generation step and is not visible here. The classifier weights are
  # only needed during pre-training.
  with tf.variable_scope("cls/seq_relationship"):
    class_weights = tf.get_variable(
        "output_weights",
        shape=[2, albert_config.hidden_size],
        initializer=modeling.create_initializer(
            albert_config.initializer_range))
    class_bias = tf.get_variable(
        "output_bias", shape=[2], initializer=tf.zeros_initializer())

    logits = tf.nn.bias_add(
        tf.matmul(input_tensor, class_weights, transpose_b=True), class_bias)
    log_probs = tf.nn.log_softmax(logits, axis=-1)

    flat_labels = tf.reshape(labels, [-1])
    one_hot_labels = tf.one_hot(flat_labels, depth=2, dtype=tf.float32)
    per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
    loss = tf.reduce_mean(per_example_loss)
    return (loss, per_example_loss, log_probs)
def gather_indexes(sequence_tensor, positions):
  """Selects the vectors at `positions` from every sequence in a minibatch.

  Args:
    sequence_tensor: Rank-3 tensor; its three dimensions are treated as
      (batch, sequence position, width).
    positions: Integer positions within each sequence, one row per batch
      entry.

  Returns:
    A 2-D tensor holding the selected vectors, flattened over the batch.
  """
  dims = modeling.get_shape_list(sequence_tensor, expected_rank=3)
  batch_size, seq_length, width = dims[0], dims[1], dims[2]

  # Offset every row's positions by the row's start index in the flattened
  # (batch * seq_length) layout so that one gather serves the whole batch.
  row_starts = tf.range(0, batch_size, dtype=tf.int32) * seq_length
  flat_positions = tf.reshape(
      positions + tf.reshape(row_starts, [-1, 1]), [-1])
  flat_sequence = tf.reshape(sequence_tensor,
                             [batch_size * seq_length, width])
  return tf.gather(flat_sequence, flat_positions)
def input_fn_builder(input_files,
                     max_seq_length,
                     max_predictions_per_seq,
                     is_training,
                     num_cpu_threads=4):
  """Creates an `input_fn` closure to be passed to TPUEstimator.

  Args:
    input_files: Non-empty list of TFRecord file paths.
    max_seq_length: Length of the per-token features in every record.
    max_predictions_per_seq: Length of the `masked_lm_*` features in every
      record (only parsed when `--masked_lm_budget` is 0).
    is_training: When true, files are read in parallel and shuffled;
      otherwise they are read in order.
    num_cpu_threads: Upper bound on files read in parallel and the number of
      parallel batches for decoding.

  Returns:
    An `input_fn(params)` callable that reads `params["batch_size"]` and
    returns a batched `tf.data.Dataset`.

  Raises:
    ValueError: If `input_files` is empty.
  """
  # Fail here rather than inside tf.data: an empty list would otherwise reach
  # `shuffle(buffer_size=0)` and `cycle_length=0` and produce an obscure error.
  if not input_files:
    raise ValueError(
        "`input_files` is empty; at least one TFRecord file is required.")

  def input_fn(params):
    """The actual input function."""
    batch_size = params["batch_size"]

    name_to_features = {
        "input_ids": tf.FixedLenFeature([max_seq_length], tf.int64),
        "input_mask": tf.FixedLenFeature([max_seq_length], tf.int64),
        "segment_ids": tf.FixedLenFeature([max_seq_length], tf.int64),
        # Note: We keep this feature name `next_sentence_labels` to be
        # compatible with the original data created by lanzhzh@. However, in
        # the ALBERT case it does represent sentence_order_labels.
        "next_sentence_labels": tf.FixedLenFeature([1], tf.int64),
    }

    # With a masking budget the records carry token boundaries instead of
    # precomputed masked-LM targets; without one the targets are read as-is.
    if FLAGS.masked_lm_budget:
      name_to_features.update({
          "token_boundary":
              tf.FixedLenFeature([max_seq_length], tf.int64)})
    else:
      name_to_features.update({
          "masked_lm_positions":
              tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
          "masked_lm_ids":
              tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
          "masked_lm_weights":
              tf.FixedLenFeature([max_predictions_per_seq], tf.float32)})

    # For training, we want a lot of parallel reading and shuffling.
    # For eval, we want no shuffling and parallel reading doesn't matter.
    if is_training:
      d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))
      d = d.repeat()
      d = d.shuffle(buffer_size=len(input_files))

      # `cycle_length` is the number of parallel files that get read.
      cycle_length = min(num_cpu_threads, len(input_files))

      # `sloppy` mode means that the interleaving is not exact. This adds
      # even more randomness to the training pipeline.
      d = d.apply(
          tf.contrib.data.parallel_interleave(
              tf.data.TFRecordDataset,
              sloppy=is_training,
              cycle_length=cycle_length))
      d = d.shuffle(buffer_size=100)
    else:
      d = tf.data.TFRecordDataset(input_files)
      # Since we evaluate for a fixed number of steps we don't want to encounter
      # out-of-range exceptions.
      d = d.repeat()

    # Batches are always built with `drop_remainder=True` so that every batch
    # has a fixed size, as the TPU requires. Both branches above repeat the
    # dataset before it is batched, so this applies to eval as well.
    d = d.apply(
        tf.data.experimental.map_and_batch_with_legacy_function(
            lambda record: _decode_record(record, name_to_features),
            batch_size=batch_size,
            num_parallel_batches=num_cpu_threads,
            drop_remainder=True))
    tf.logging.info(d)
    return d

  return input_fn
def _decode_record(record, name_to_features):
  """Decodes a record to a TensorFlow example.

  Args:
    record: A serialized `tf.Example`.
    name_to_features: Feature spec passed to `tf.parse_single_example`.

  Returns:
    The parsed feature dict with every int64 tensor cast to int32.
  """
  example = tf.parse_single_example(record, name_to_features)

  # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
  # So cast all int64 to int32. `tf.cast` is used instead of the deprecated
  # `tf.to_int32`; the two are equivalent for this conversion.
  for name, tensor in list(example.items()):
    if tensor.dtype == tf.int64:
      example[name] = tf.cast(tensor, tf.int32)

  return example
def main(_):
  """Runs ALBERT pre-training and/or evaluation as selected by the flags.

  Training runs the estimator up to `--num_train_steps`. Evaluation polls the
  model directory and evaluates the latest checkpoint repeatedly until a
  checkpoint at or beyond `--num_train_steps` has been evaluated, appending
  each result to `eval_results.txt` in `--output_dir`.

  Raises:
    ValueError: If neither `--do_train` nor `--do_eval` is set, or if
      `--input_file` matches no files.
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  if not FLAGS.do_train and not FLAGS.do_eval:
    raise ValueError("At least one of `do_train` or `do_eval` must be True.")

  albert_config = modeling.AlbertConfig.from_json_file(FLAGS.albert_config_file)

  tf.gfile.MakeDirs(FLAGS.output_dir)

  input_files = []
  for input_pattern in FLAGS.input_file.split(","):
    input_files.extend(tf.gfile.Glob(input_pattern))
  if not input_files:
    # Without this check the failure surfaces much later, inside tf.data.
    raise ValueError(
        "No input files matched `input_file`: %s" % FLAGS.input_file)

  tf.logging.info("*** Input Files ***")
  for input_file in input_files:
    tf.logging.info(" %s", input_file)

  tpu_cluster_resolver = None
  if FLAGS.use_tpu and FLAGS.tpu_name:
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
  run_config = tf.contrib.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      master=FLAGS.master,
      model_dir=FLAGS.output_dir,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      tpu_config=tf.contrib.tpu.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          num_shards=FLAGS.num_tpu_cores,
          per_host_input_for_training=is_per_host))

  model_fn = model_fn_builder(
      albert_config=albert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=FLAGS.num_train_steps,
      num_warmup_steps=FLAGS.num_warmup_steps,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_tpu,
      optimizer=FLAGS.optimizer,
      poly_power=FLAGS.poly_power,
      start_warmup_step=FLAGS.start_warmup_step)

  # If TPU is not available, this will fall back to normal Estimator on CPU
  # or GPU.
  estimator = tf.contrib.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size)

  if FLAGS.do_train:
    tf.logging.info("***** Running training *****")
    tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
    train_input_fn = input_fn_builder(
        input_files=input_files,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        is_training=True)
    estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)

  if FLAGS.do_eval:
    tf.logging.info("***** Running evaluation *****")
    tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size)
    global_step = -1
    output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
    # `export_dir` defaults to None and is not a required flag, so only
    # create the directory when one was actually supplied.
    if FLAGS.export_dir:
      tf.gfile.MakeDirs(FLAGS.export_dir)

    eval_input_fn = input_fn_builder(
        input_files=input_files,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        is_training=False)

    # The context manager guarantees the results file is closed even when an
    # evaluation raises; previously the handle was never closed.
    with tf.gfile.GFile(output_eval_file, "w") as writer:
      while global_step < FLAGS.num_train_steps:
        if estimator.latest_checkpoint() is None:
          tf.logging.info("No checkpoint found yet. Sleeping.")
          time.sleep(1)
          continue

        result = estimator.evaluate(
            input_fn=eval_input_fn, steps=FLAGS.max_eval_steps)
        global_step = result["global_step"]
        tf.logging.info("***** Eval results *****")
        for key in sorted(result.keys()):
          tf.logging.info(" %s = %s", key, str(result[key]))
          writer.write("%s = %s\n" % (key, str(result[key])))
        # Persist each round so results survive an interrupted eval loop.
        writer.flush()
if __name__ == "__main__":
  # Register the flags that must be supplied before handing control to
  # `tf.app.run`.
  for required_flag in ("input_file", "albert_config_file", "output_dir"):
    flags.mark_flag_as_required(required_flag)
  tf.app.run()