compatible_with_transformers.py 563 B

12345678910111213141516
  1. import transformers
  2. from packaging import version
  3. def compatible_position_ids(state_dict, position_id_key):
  4. """Transformers no longer expect position_ids after transformers==4.31
  5. https://github.com/huggingface/transformers/pull/24505
  6. Args:
  7. position_id_key (str): position_ids key,
  8. such as(encoder.embeddings.position_ids)
  9. """
  10. transformer_version = version.parse('.'.join(
  11. transformers.__version__.split('.')[:2]))
  12. if transformer_version >= version.parse('4.31.0'):
  13. del state_dict[position_id_key]