diff --git a/evals/wheat_from_chaff_test.cc b/evals/wheat_from_chaff_test.cc index 0e031e6c..ad84821a 100644 --- a/evals/wheat_from_chaff_test.cc +++ b/evals/wheat_from_chaff_test.cc @@ -46,11 +46,18 @@ static const char* kQuestions = "Which people first proposed the quark model of hadrons, and when?"; // All phrases in kAnswers must appear in the response in the order given for -// the test to pass. -static const char* kAnswers[] = { - "a ship's anchor", "a dark forest", "an hour", - "enormous sand", "castles", "limpet shells", - "Murray Gell-Mann", "George Zweig", "1964"}; +// the test to pass. Multiple acceptable answers can be provided for each +// expected phrase. +static const std::vector<std::vector<const char*>> kAnswers = { + {"rusty metal", "ship's anchor"}, + {"dark forest"}, + {"an hour"}, + {"enormous sand"}, + {"castles"}, + {"limpet shells"}, + {"Murray Gell-Mann"}, + {"George Zweig"}, + {"1964"}}; std::string LoadPromptFile(const std::string& filename) { // If the filename is empty, return an empty string. 
@@ -108,12 +115,22 @@ class GemmaTest : public ::testing::Test { void TestExpectations(const std::string& response) { fprintf(stderr, "Response: '%s'\n", response.c_str()); size_t pos = 0; - for (const char* answer : kAnswers) { - auto found = response.find(answer, pos); - EXPECT_NE(found, std::string::npos) - << "Response does not contain " << answer; - if (found != std::string::npos) { - pos = found + strlen(answer); + for (const auto& answer_group : kAnswers) { + size_t earliest_pos = std::string::npos; + const char* matched_answer = nullptr; + for (const char* answer : answer_group) { + auto found = response.find(answer, pos); + if (found != std::string::npos && + (earliest_pos == std::string::npos || found < earliest_pos)) { + earliest_pos = found; + matched_answer = answer; + } + } + EXPECT_NE(earliest_pos, std::string::npos) + << "Response does not contain acceptable answers, e.g., " + << answer_group[0]; + if (earliest_pos != std::string::npos) { + pos = earliest_pos + strlen(matched_answer); } } s_env->PrintProfileResults(); diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index 5922d8cc..fbb39a66 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -489,7 +489,7 @@ HWY_INLINE float SingleFlashAttentionRowVector(DF df, size_t start_pos, } float m = hn::ReduceMax(df, x); m = std::max(m, old_max); - x = hn::Exp(df, hn::Sub(x, hn::Set(df, m))); + x = hn::FastExpMinusOrZero(df, hn::Sub(x, hn::Set(df, m))); float scale = old_d * std::exp(old_max - m); old_d = hn::ReduceSum(df, x) + scale; old_max = m; @@ -538,8 +538,8 @@ HWY_INLINE float DoubleFlashAttentionRowVector(DF df, size_t start_pos, float m = hn::ReduceMax(df, x_max); m = std::max(m, old_max); VF m_vec = hn::Set(df, m); - x0 = hn::Exp(df, hn::Sub(x0, m_vec)); - x1 = hn::Exp(df, hn::Sub(x1, m_vec)); + x0 = hn::FastExpMinusOrZero(df, hn::Sub(x0, m_vec)); + x1 = hn::FastExpMinusOrZero(df, hn::Sub(x1, m_vec)); float scale = old_d * std::exp(old_max - m); VF x_sum = 
hn::Add(x0, x1); old_d = hn::ReduceSum(df, x_sum) + scale; @@ -672,7 +672,8 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4( x_sum = Reduce4(df, x_0_sum, x_1_sum, x_2_sum, x_3_sum, [](auto a, auto b) HWY_ATTR { return hn::Add(a, b); }); } - VF4 scale = hn::Mul(old_d_vf, hn::Exp(df4, hn::Sub(old_max_vf, new_max))); + VF4 scale = hn::Mul( + old_d_vf, hn::FastExpMinusOrZero(df4, hn::Sub(old_max_vf, new_max))); old_d_vf = hn::Add(scale, x_sum); auto non_zero_mask = hn::Gt(old_d_vf, hn::Set(df4, 0.0f)); const VF zero = hn::Zero(df); @@ -810,7 +811,8 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap8( x_6_sum, x_7_sum, [](auto a, auto b) HWY_ATTR { return hn::Add(a, b); }); } - VF8 scale = hn::Mul(old_d_vf, hn::Exp(df8, hn::Sub(old_max_vf, new_max))); + VF8 scale = hn::Mul( + old_d_vf, hn::FastExpMinusOrZero(df8, hn::Sub(old_max_vf, new_max))); old_d_vf = hn::Add(scale, x_sum); auto non_zero_mask = hn::Gt(old_d_vf, hn::Set(df8, 0.0f)); const VF zero = hn::Zero(df);