TPU выдает ошибку при попытке инициализации и настройки, а затем при построении модели tf.keras

Я запустил следующий блок кода для настройки TPU (TF 2.2.0, Keras):

    # Pick a tf.distribute strategy: TPU if one is reachable, else multi-GPU, else default.
try:
    # os.environ[...] raises KeyError when TPU_NAME is unset (presumably on a
    # CPU/GPU runtime); TPUClusterResolver() raises ValueError when no TPU is
    # found. Both mean "no TPU", so both are caught below.
    TPU_WORKER = os.environ["TPU_NAME"]
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print(f"Running on TPU: {tpu.cluster_spec().as_dict()['worker']}")
    print(f"TPU_WORKER: {TPU_WORKER}")
except (KeyError, ValueError):
    # No TPU: fall back to whatever GPUs are visible.
    tpu = None
    gpus = tf.config.experimental.list_logical_devices("GPU")

if tpu is not None:
    # The TPU must be connected to and initialized before TPUStrategy is built.
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
elif len(gpus) > 1:  # multiple GPUs on the VM
    # MirroredStrategy's `devices` argument takes device-name strings,
    # not LogicalDevice objects.
    strategy = tf.distribute.MirroredStrategy([gpu.name for gpu in gpus])
else:
    # Single GPU or CPU only: use the default strategy.
    strategy = tf.distribute.get_strategy()

и получил это сообщение об ошибке:

InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-27-a49335a43189> in <module>
     16 if tpu:
---> 17     tf.config.experimental_connect_to_cluster(tpu)
     18     tf.tpu.experimental.initialize_tpu_system(tpu)
     19     strategy = tf.distribute.experimental.TPUStrategy(tpu)

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/ in connect_to_cluster(cluster_spec_or_resolver, job_name, task_index, protocol, make_master_device_default, cluster_device_filters)
    181     context.set_server_def(server_def)
    182   else:
--> 183     context.update_server_def(server_def)
    185   if make_master_device_default and isinstance(

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/ in update_server_def(server_def)
   2138 def update_server_def(server_def):
-> 2139   context().update_server_def(server_def)

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/ in update_server_def(self, server_def, keep_alive_secs)
    596       # Current executor might have pending nodes that involves updated remote
    597       # devices. Wait for them to finish before updating.
--> 598       self.executor.wait()
    599       self.executor.clear_error()
    600       pywrap_tfe.TFE_ContextUpdateServerDef(self._context_handle,

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/ in wait(self)
     65   def wait(self):
     66     """Waits for ops dispatched in this executor to finish."""
---> 67     pywrap_tfe.TFE_ExecutorWaitForAllPendingNodes(self._handle)
     69   def clear_error(self):

InvalidArgumentError: {{function_node __inference_train_function_75067}} Compilation failure: XLA can't deduce compile time constant output shape for strided slice: [4,?], output shape must be a compile-time constant
     [[{{node model/tf_op_layer_strided_slice/strided_slice}}]]
    TPU compilation failed

Эта ошибка:

InvalidArgumentError: {{function_node __inference_train_function_75067}} Compilation failure: XLA can't deduce compile time constant output shape for strided slice: [4,?], output shape must be a compile-time constant
     [[{{node model/tf_op_layer_strided_slice/strided_slice}}]]
    TPU compilation failed

возникла во время предыдущего запуска, и с тех пор я не могу повторно запустить свой код.

Обойти это можно, перезапустив ноутбук и выполнив его заново.

Но затем я получаю ту же ошибку в другом месте:

InvalidArgumentError                      Traceback (most recent call last)
<timed exec> in <module>

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/ in _method_wrapper(self, *args, **kwargs)
     64   def _method_wrapper(self, *args, **kwargs):
     65     if not self._in_multi_worker_mode():  # pylint: disable=protected-access
---> 66       return method(self, *args, **kwargs)
     68     # Running inside `run_distribute_coordinator` already.

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/ in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
    853                 context.async_wait()
    854               logs = tmp_logs  # No error, now safe to assign to logs.
--> 855               callbacks.on_train_batch_end(step, logs)
    856         epoch_logs = copy.copy(logs)

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/ in on_train_batch_end(self, batch, logs)
    387     """
    388     if self._should_call_train_batch_hooks:
--> 389       logs = self._process_logs(logs)
    390       self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/ in _process_logs(self, logs)
    263     """Turns tensors into numpy arrays or Python scalars."""
    264     if logs:
--> 265       return tf_utils.to_numpy_or_python_type(logs)
    266     return {}

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/utils/ in to_numpy_or_python_type(tensors)
    521     return t  # Don't turn ragged or sparse tensors to NumPy.
--> 523   return nest.map_structure(_to_single_numpy_or_python_type, tensors)

/opt/conda/lib/python3.7/site-packages/tensorflow/python/util/ in map_structure(func, *structure, **kwargs)
    616   return pack_sequence_as(
--> 617       structure[0], [func(*x) for x in entries],
    618       expand_composites=expand_composites)

/opt/conda/lib/python3.7/site-packages/tensorflow/python/util/ in <listcomp>(.0)
    616   return pack_sequence_as(
--> 617       structure[0], [func(*x) for x in entries],
    618       expand_composites=expand_composites)

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/utils/ in _to_single_numpy_or_python_type(t)
    517   def _to_single_numpy_or_python_type(t):
    518     if isinstance(t, ops.Tensor):
--> 519       x = t.numpy()
    520       return x.item() if np.ndim(x) == 0 else x
    521     return t  # Don't turn ragged or sparse tensors to NumPy.

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ in numpy(self)
    959     """
    960     # TODO(slebedev): Consider avoiding a copy for non-CPU or remote tensors.
--> 961     maybe_arr = self._numpy()  # pylint: disable=protected-access
    962     return maybe_arr.copy() if isinstance(maybe_arr, np.ndarray) else maybe_arr

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ in _numpy(self)
    927       return self._numpy_internal()
    928     except core._NotOkStatusException as e:
--> 929       six.raise_from(core._status_to_exception(e.code, e.message), None)
    931   @property

/opt/conda/lib/python3.7/site-packages/ in raise_from(value, from_value)

InvalidArgumentError: {{function_node __inference_train_function_78422}} Compilation failure: XLA can't deduce compile time constant output shape for strided slice: [16,?], output shape must be a compile-time constant
     [[{{node model/tf_op_layer_strided_slice/strided_slice}}]]
    TPU compilation failed

при попытке обучить (`model.fit`) многослойную модель Keras, хотя из приведённого выше стека вызовов неясно, в какой момент произошла эта ошибка.

Ещё один вопрос: как очистить кеш или буфер, в котором хранится эта ошибка, чтобы сбросить TPU и снова запустить код после внесения изменений, не перезапуская сеанс или ядро?

Когда я запускаю тот же код инициализации TPU в Colab (выбрана среда выполнения TPU):

tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
print(f"Running on TPU: {tpu.cluster_spec().as_dict()['worker']}")
# Connect to the TPU cluster and (re)initialize it before building the strategy.
# initialize_tpu_system is the call that logs "Initializing the TPU system" and
# "Clearing out eager caches" when re-run in the same session.
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
strategy = tf.distribute.experimental.TPUStrategy(tpu)

Он работает без ошибок, повторно инициализирует TPU и очищает кеши eager-режима (eager caches). Вот логи:

Running on TPU: ['']
WARNING:tensorflow:TPU system grpc:// has already been initialized. Reinitializing the TPU can cause previously created variables on TPU to be lost.
WARNING:tensorflow:TPU system grpc:// has already been initialized. Reinitializing the TPU can cause previously created variables on TPU to be lost.
INFO:tensorflow:Initializing the TPU system: grpc://
INFO:tensorflow:Initializing the TPU system: grpc://
INFO:tensorflow:Clearing out eager caches
INFO:tensorflow:Clearing out eager caches

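Что касается второй проблемы («[4,?] output shape must be a compile-time constant»): XLA на TPU требует, чтобы формы всех тензоров были известны на этапе компиляции, а `?` в `[4,?]` означает неизвестную размерность. Поэтому явно задайте формы входов и выходов при построении модели (например, фиксированную длину в `tf.keras.Input(shape=...)` вместо `None`).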

