Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 28 additions & 11 deletions evals/wheat_from_chaff_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,18 @@ static const char* kQuestions =
"Which people first proposed the quark model of hadrons, and when?";

// All phrases in kAnswers must appear in the response in the order given for
// the test to pass.
static const char* kAnswers[] = {
"a ship's anchor", "a dark forest", "an hour",
"enormous sand", "castles", "limpet shells",
"Murray Gell-Mann", "George Zweig", "1964"};
// the test to pass. Multiple acceptable answers can be provided for each
// expected phrase.
static const std::vector<std::vector<const char*>> kAnswers = {
{"rusty metal", "ship's anchor"},
{"dark forest"},
{"an hour"},
{"enormous sand"},
{"castles"},
{"limpet shells"},
{"Murray Gell-Mann"},
{"George Zweig"},
{"1964"}};

std::string LoadPromptFile(const std::string& filename) {
// If the filename is empty, return an empty string.
Expand Down Expand Up @@ -108,12 +115,22 @@ class GemmaTest : public ::testing::Test {
void TestExpectations(const std::string& response) {
fprintf(stderr, "Response: '%s'\n", response.c_str());
size_t pos = 0;
for (const char* answer : kAnswers) {
auto found = response.find(answer, pos);
EXPECT_NE(found, std::string::npos)
<< "Response does not contain " << answer;
if (found != std::string::npos) {
pos = found + strlen(answer);
for (const auto& answer_group : kAnswers) {
size_t earliest_pos = std::string::npos;
const char* matched_answer = nullptr;
for (const char* answer : answer_group) {
auto found = response.find(answer, pos);
if (found != std::string::npos &&
(earliest_pos == std::string::npos || found < earliest_pos)) {
earliest_pos = found;
matched_answer = answer;
}
}
EXPECT_NE(earliest_pos, std::string::npos)
<< "Response does not contain acceptable answers, e.g., "
<< answer_group[0];
if (earliest_pos != std::string::npos) {
pos = earliest_pos + strlen(matched_answer);
}
}
s_env->PrintProfileResults();
Expand Down
12 changes: 7 additions & 5 deletions gemma/flash_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ HWY_INLINE float SingleFlashAttentionRowVector(DF df, size_t start_pos,
}
float m = hn::ReduceMax(df, x);
m = std::max(m, old_max);
x = hn::Exp(df, hn::Sub(x, hn::Set(df, m)));
x = hn::FastExpMinusOrZero(df, hn::Sub(x, hn::Set(df, m)));
float scale = old_d * std::exp(old_max - m);
old_d = hn::ReduceSum(df, x) + scale;
old_max = m;
Expand Down Expand Up @@ -538,8 +538,8 @@ HWY_INLINE float DoubleFlashAttentionRowVector(DF df, size_t start_pos,
float m = hn::ReduceMax(df, x_max);
m = std::max(m, old_max);
VF m_vec = hn::Set(df, m);
x0 = hn::Exp(df, hn::Sub(x0, m_vec));
x1 = hn::Exp(df, hn::Sub(x1, m_vec));
x0 = hn::FastExpMinusOrZero(df, hn::Sub(x0, m_vec));
x1 = hn::FastExpMinusOrZero(df, hn::Sub(x1, m_vec));
float scale = old_d * std::exp(old_max - m);
VF x_sum = hn::Add(x0, x1);
old_d = hn::ReduceSum(df, x_sum) + scale;
Expand Down Expand Up @@ -672,7 +672,8 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4(
x_sum = Reduce4(df, x_0_sum, x_1_sum, x_2_sum, x_3_sum,
[](auto a, auto b) HWY_ATTR { return hn::Add(a, b); });
}
VF4 scale = hn::Mul(old_d_vf, hn::Exp(df4, hn::Sub(old_max_vf, new_max)));
VF4 scale = hn::Mul(
old_d_vf, hn::FastExpMinusOrZero(df4, hn::Sub(old_max_vf, new_max)));
old_d_vf = hn::Add(scale, x_sum);
auto non_zero_mask = hn::Gt(old_d_vf, hn::Set(df4, 0.0f));
const VF zero = hn::Zero(df);
Expand Down Expand Up @@ -810,7 +811,8 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap8(
x_6_sum, x_7_sum,
[](auto a, auto b) HWY_ATTR { return hn::Add(a, b); });
}
VF8 scale = hn::Mul(old_d_vf, hn::Exp(df8, hn::Sub(old_max_vf, new_max)));
VF8 scale = hn::Mul(
old_d_vf, hn::FastExpMinusOrZero(df8, hn::Sub(old_max_vf, new_max)));
old_d_vf = hn::Add(scale, x_sum);
auto non_zero_mask = hn::Gt(old_d_vf, hn::Set(df8, 0.0f));
const VF zero = hn::Zero(df);
Expand Down
Loading