test_hf_model.py

import pytest
import torch
from open_clip.hf_model import _POOLERS, HFTextEncoder
from transformers import AutoConfig
from transformers.modeling_outputs import BaseModelOutput

# test poolers
def test_poolers():
    bs, sl, d = 2, 10, 5
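    # synthetic hidden states of shape (bs, sl, d), increasing along the sequence dimension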
    h = torch.arange(sl).repeat(bs).reshape(bs, sl)[..., None] * torch.linspace(0.2, 1., d)
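    # attention mask that pads out the last four positions of every sequence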
    mask = torch.ones(bs, sl, dtype=torch.long)
    mask[:2, 6:] = 0
    x = BaseModelOutput(h)
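    # every registered pooler should reduce (bs, sl, d) to (bs, d)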
    for name, cls in _POOLERS.items():
        pooler = cls()
        res = pooler(x, mask)
        assert res.shape == (bs, d), f"{name} returned wrong shape"

# test HFTextEncoder
@pytest.mark.parametrize("model_id", ["arampacha/roberta-tiny", "roberta-base", "xlm-roberta-base", "google/mt5-base"])
def test_pretrained_text_encoder(model_id):
    bs, sl, d = 2, 10, 64
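    # the pretrained config is only needed for the vocab size when sampling input ids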
    cfg = AutoConfig.from_pretrained(model_id)
    model = HFTextEncoder(model_id, d, proj='linear')
    x = torch.randint(0, cfg.vocab_size, (bs, sl))
    with torch.no_grad():
        emb = model(x)
    assert emb.shape == (bs, d)