diff --git a/tensorflow_text/core/kernels/string_vocab.cc b/tensorflow_text/core/kernels/string_vocab.cc
index fadbe277d..0171f31fd 100644
--- a/tensorflow_text/core/kernels/string_vocab.cc
+++ b/tensorflow_text/core/kernels/string_vocab.cc
@@ -19,6 +19,7 @@ namespace text {
 
 StringVocab::StringVocab(const std::vector<std::string>& vocab)
     : vocab_(vocab) {
+  index_map_.reserve(vocab.size());
   for (int i = 0; i < vocab.size(); ++i) {
     index_map_[vocab_[i]] = i;
   }
diff --git a/tensorflow_text/python/ops/fast_wordpiece_tokenizer_test.py b/tensorflow_text/python/ops/fast_wordpiece_tokenizer_test.py
index 2e841f9f6..453b58d88 100644
--- a/tensorflow_text/python/ops/fast_wordpiece_tokenizer_test.py
+++ b/tensorflow_text/python/ops/fast_wordpiece_tokenizer_test.py
@@ -796,5 +796,86 @@ def detokenize(self, input_tensor):
     self.assertAllEqual(tf_detokenization_result,
                         tflite_detokenization_result)
 
+
+class StringVocabHashMapBugTest(test_util.TensorFlowTestCase):
+  """Tests for the StringVocab hash map rehashing bug (b/XXXXX).
+
+  Regression test for the issue where FastWordpieceTokenizer fails with
+  "Cannot find unk_token in the vocab!" when the vocabulary size is >= 7
+  and unknown_token is not the last element. The failure is caused by a bug
+  in the StringVocab constructor: hash map rehashing during construction
+  causes lookups to fail.
+
+  See: tensorflow_text/core/kernels/string_vocab.cc
+  """
+
+  def test_vocab_size_6_unk_at_first_position(self):
+    """Vocab size 6 should always work (no rehashing)."""
+    vocab = [b'[UNK]', b'token1', b'token2', b'token3', b'token4', b'token5']
+    # Should not raise.
+    tokenizer = tf_text.FastWordpieceTokenizer(
+        vocab=vocab, unknown_token=b'[UNK]', no_pretokenization=True)
+    # Verify it works by tokenizing.
+    result = tokenizer.tokenize([b'[UNK]', b'token1'])
+    self.assertIsNotNone(result)
+
+  def test_vocab_size_7_unk_at_first_position(self):
+    """Vocab size 7 with unknown_token at the first position.
+
+    This is the critical test case. Without the fix (index_map_.reserve()),
+    it fails with "Cannot find unk_token in the vocab!" due to hash map
+    rehashing during StringVocab construction.
+    """
+    vocab = [
+        b'[UNK]', b'token1', b'token2', b'token3', b'token4', b'token5',
+        b'token6'
+    ]
+    # Should not raise - this is the regression being tested.
+    tokenizer = tf_text.FastWordpieceTokenizer(
+        vocab=vocab, unknown_token=b'[UNK]', no_pretokenization=True)
+    # Verify it works by tokenizing.
+    result = tokenizer.tokenize([b'[UNK]', b'token1'])
+    self.assertIsNotNone(result)
+
+  def test_vocab_size_7_unk_at_middle_position(self):
+    """Vocab size 7 with unknown_token at a middle position."""
+    vocab = [
+        b'token1', b'token2', b'[UNK]', b'token3', b'token4', b'token5',
+        b'token6'
+    ]
+    # Should not raise with the fix.
+    tokenizer = tf_text.FastWordpieceTokenizer(
+        vocab=vocab, unknown_token=b'[UNK]', no_pretokenization=True)
+    result = tokenizer.tokenize([b'[UNK]', b'token1'])
+    self.assertIsNotNone(result)
+
+  def test_vocab_size_7_unk_at_last_position(self):
+    """Vocab size 7 with unknown_token at the last position.
+
+    This works even without the fix because hash map rehashing happens
+    before the unknown_token is inserted.
+ """ + vocab = [ + b'token1', b'token2', b'token3', b'token4', b'token5', b'token6', + b'[UNK]' + ] + # Should always work (control test) + tokenizer = tf_text.FastWordpieceTokenizer( + vocab=vocab, unknown_token=b'[UNK]', no_pretokenization=True) + result = tokenizer.tokenize([b'[UNK]', b'token1']) + self.assertIsNotNone(result) + + def test_vocab_size_8_unk_at_first_position(self): + """Larger vocabulary size to ensure fix works consistently.""" + vocab = [ + b'[UNK]', b'token1', b'token2', b'token3', b'token4', b'token5', + b'token6', b'token7' + ] + # Should not raise with the fix + tokenizer = tf_text.FastWordpieceTokenizer( + vocab=vocab, unknown_token=b'[UNK]', no_pretokenization=True) + result = tokenizer.tokenize([b'[UNK]', b'token1']) + self.assertIsNotNone(result) + + if __name__ == "__main__": test.main()