diff --git a/Framework/AnalysisSupport/src/AODJAlienReaderHelpers.cxx b/Framework/AnalysisSupport/src/AODJAlienReaderHelpers.cxx index b532c51b8d307..cde6c85f2c624 100644 --- a/Framework/AnalysisSupport/src/AODJAlienReaderHelpers.cxx +++ b/Framework/AnalysisSupport/src/AODJAlienReaderHelpers.cxx @@ -98,29 +98,6 @@ using o2::monitoring::tags::Value; namespace o2::framework::readers { -auto setEOSCallback(InitContext& ic) -{ - ic.services().get().set( - [](EndOfStreamContext& eosc) { - auto& control = eosc.services().get(); - control.endOfStream(); - control.readyToQuit(QuitRequest::Me); - }); -} - -template -static inline auto extractTypedOriginal(ProcessingContext& pc) -{ - /// FIXME: this should be done in invokeProcess() as some of the originals may be compound tables - return O{pc.inputs().get(aod::MetadataTrait::metadata::tableLabel())->asArrowTable()}; -} - -template -static inline auto extractOriginalsTuple(framework::pack, ProcessingContext& pc) -{ - return std::make_tuple(extractTypedOriginal(pc)...); -} - AlgorithmSpec AODJAlienReaderHelpers::rootFileReaderCallback(ConfigContext const& ctx) { // aod-parent-base-path-replacement is now a workflow option, so it needs to be diff --git a/Framework/AnalysisSupport/src/AODReaderHelpers.cxx b/Framework/AnalysisSupport/src/AODReaderHelpers.cxx index 7f08dd0b36a64..485f3fa69edad 100644 --- a/Framework/AnalysisSupport/src/AODReaderHelpers.cxx +++ b/Framework/AnalysisSupport/src/AODReaderHelpers.cxx @@ -18,6 +18,7 @@ #include "Framework/DataProcessingHelpers.h" #include "Framework/AlgorithmSpec.h" #include "Framework/DataSpecUtils.h" +#include "Framework/DataSpecViews.h" #include "Framework/ConfigContext.h" #include "Framework/DanglingEdgesContext.h" @@ -29,6 +30,7 @@ struct Buildable { bool exclusive = false; std::string binding; std::vector labels; + std::vector matchers; header::DataOrigin origin; header::DataDescription description; header::DataHeader::SubSpecificationType version; @@ -52,6 +54,7 @@ struct Buildable { for (auto const& r : records) { labels.emplace_back(r.label); + matchers.emplace_back(r.matcher); } outputSchema = std::make_shared([](std::vector const& recs) { std::vector> fields; @@ -68,6 +71,7 @@ struct Buildable { return { exclusive, labels, + matchers, records, outputSchema, origin, @@ -105,6 +109,7 @@ namespace struct Spawnable { std::string binding; std::vector labels; + std::vector matchers; std::vector projectors; std::vector> expressions; std::shared_ptr outputSchema; @@ -132,14 +137,17 @@ struct Spawnable { o2::framework::addLabelToSchema(outputSchema, binding.c_str()); std::vector> schemas; - for (auto& i : spec.metadata) { - if (i.name.starts_with("input-schema:")) { - labels.emplace_back(i.name.substr(13)); - iws.clear(); - auto json = i.defaultValue.get(); - iws.str(json); - schemas.emplace_back(ArrowJSONHelpers::read(iws)); - } + for (auto const& i : spec.metadata | views::filter_string_params_starts_with("input-schema:")) { + labels.emplace_back(i.name.substr(13)); + iws.clear(); + auto json = i.defaultValue.get(); + iws.str(json); + schemas.emplace_back(ArrowJSONHelpers::read(iws)); + } + for (auto const& i : spec.metadata | views::filter_string_params_starts_with("input:") | std::ranges::views::transform([](auto const& param) { + return DataSpecUtils::fromMetadataString(param.defaultValue.template get()); + })) { + matchers.emplace_back(std::get(i.matcher)); } std::vector> fields; @@ -169,6 +177,7 @@ struct Spawnable { return { binding, labels, + matchers, expressions, makeProjector(), outputSchema, diff --git a/Framework/AnalysisSupport/src/AODWriterHelpers.cxx b/Framework/AnalysisSupport/src/AODWriterHelpers.cxx index 5a43683afd364..d868b7498fb76 100644 --- a/Framework/AnalysisSupport/src/AODWriterHelpers.cxx +++ b/Framework/AnalysisSupport/src/AODWriterHelpers.cxx @@ -185,13 +185,12 @@ AlgorithmSpec AODWriterHelpers::getOutputTTreeWriter(ConfigContext const& ctx) } // get the TableConsumer and corresponding arrow table - auto msg = pc.inputs().get(ref.spec->binding); - if (msg.header == nullptr) { + if (ref.header == nullptr) { LOGP(error, "No header for message {}:{}", ref.spec->binding, DataSpecUtils::describe(*ref.spec)); continue; } - auto s = pc.inputs().get(ref.spec->binding); - auto table = s->asArrowTable(); + + auto table = pc.inputs().get(std::get(ref.spec->matcher))->asArrowTable(); if (!table->Validate().ok()) { LOGP(warning, "The table \"{}\" is not valid and will not be saved!", tableName); continue; diff --git a/Framework/CCDBSupport/src/AnalysisCCDBHelpers.cxx b/Framework/CCDBSupport/src/AnalysisCCDBHelpers.cxx index fcc856669cd92..9ec911518f754 100644 --- a/Framework/CCDBSupport/src/AnalysisCCDBHelpers.cxx +++ b/Framework/CCDBSupport/src/AnalysisCCDBHelpers.cxx @@ -83,6 +83,7 @@ AlgorithmSpec AnalysisCCDBHelpers::fetchFromCCDB(ConfigContext const& ctx) if (m.name.starts_with("input:")) { auto name = m.name.substr(6); schemaMetadata->Append("sourceTable", name); + schemaMetadata->Append("sourceMatcher", DataSpecUtils::describe(std::get(DataSpecUtils::fromMetadataString(m.defaultValue.get()).matcher))); continue; } // Ignore the non ccdb: entries @@ -109,13 +110,13 @@ AlgorithmSpec AnalysisCCDBHelpers::fetchFromCCDB(ConfigContext const& ctx) for (auto& schema : schemas) { std::vector ops; auto inputBinding = *schema->metadata()->Get("sourceTable"); + auto inputMatcher = DataSpecUtils::fromString(*schema->metadata()->Get("sourceMatcher")); auto outRouteDesc = *schema->metadata()->Get("outputRoute"); std::string outBinding = *schema->metadata()->Get("outputBinding"); O2_SIGNPOST_EVENT_EMIT_INFO(ccdb, sid, "fetchFromAnalysisCCDB", "Fetching CCDB objects for %{public}s's columns with timestamps from %{public}s and putting them in route %{public}s", outBinding.c_str(), inputBinding.c_str(), outRouteDesc.c_str()); - auto ref = inputs.get(inputBinding); - auto table = ref->asArrowTable(); + auto table = inputs.get(inputMatcher)->asArrowTable(); // FIXME: make the fTimestamp column configurable. auto timestampColumn = table->GetColumnByName("fTimestamp"); O2_SIGNPOST_EVENT_EMIT_INFO(ccdb, sid, "fetchFromAnalysisCCDB", diff --git a/Framework/Core/include/Framework/ASoA.h b/Framework/Core/include/Framework/ASoA.h index 43079a4634e97..ec02c7e47132b 100644 --- a/Framework/Core/include/Framework/ASoA.h +++ b/Framework/Core/include/Framework/ASoA.h @@ -12,6 +12,7 @@ #ifndef O2_FRAMEWORK_ASOA_H_ #define O2_FRAMEWORK_ASOA_H_ +#include "Framework/ConcreteDataMatcher.h" #include "Framework/Pack.h" // IWYU pragma: export #include "Framework/FunctionalHelpers.h" // IWYU pragma: export #include "Headers/DataHeader.h" // IWYU pragma: export @@ -375,6 +376,12 @@ consteval const char* signature() return o2::aod::Hash::str; } +template +constexpr framework::ConcreteDataMatcher matcher() +{ + return {origin(), description(signature()), R.version}; +} + /// hash identification concepts template concept is_aod_hash = requires(T t) { t.hash; t.str; }; @@ -1393,6 +1400,12 @@ static constexpr std::pair hasKey(std::string const& key) return {hasColumnForKey(typename aod::MetadataTrait>::metadata::columns{}, key), aod::label()}; } +template +static constexpr std::pair hasKeyM(std::string const& key) +{ + return {hasColumnForKey(typename aod::MetadataTrait>::metadata::columns{}, key), aod::matcher()}; +} + template static constexpr auto haveKey(framework::pack, std::string const& key) { @@ -1427,6 +1440,31 @@ static constexpr std::string getLabelFromTypeForKey(std::string const& key) O2_BUILTIN_UNREACHABLE(); } +template +static constexpr framework::ConcreteDataMatcher getMatcherFromTypeForKey(std::string const& key) +{ + if constexpr (T::originals.size() == 1) { + auto locate = hasKeyM(key); + if (locate.first) { + return locate.second; + } + } else { + auto locate = [&](std::index_sequence) { + return std::vector{hasKeyM(key)...}; + }(std::make_index_sequence{}); + auto it = std::find_if(locate.begin(), locate.end(), [](auto const& x) { return x.first; }); + if (it != locate.end()) { + return it->second; + } + } + if constexpr (!OPT) { + notFoundColumn(getLabelFromType>().data(), key.data()); + } else { + return framework::ConcreteDataMatcher{header::DataOrigin{"AOD"}, header::DataDescription{"[MISSING]"}, 0}; + } + O2_BUILTIN_UNREACHABLE(); +} + template consteval static bool hasIndexTo(framework::pack&&) { @@ -1477,7 +1515,10 @@ struct PreslicePolicyGeneral : public PreslicePolicyBase { std::span getSliceFor(int value) const; }; -template +template +concept is_preslice_policy = std::derived_from; + +template struct PresliceBase : public Policy { constexpr static bool optional = OPT; using target_t = T; @@ -1485,7 +1526,7 @@ struct PresliceBase : public Policy { const std::string binding; PresliceBase(expressions::BindingNode index_) - : Policy{PreslicePolicyBase{{o2::soa::getLabelFromTypeForKey(std::string{index_.name})}, Entry(o2::soa::getLabelFromTypeForKey(std::string{index_.name}), std::string{index_.name})}, {}} + : Policy{PreslicePolicyBase{{o2::soa::getLabelFromTypeForKey(std::string{index_.name})}, Entry(o2::soa::getLabelFromTypeForKey(std::string{index_.name}), o2::soa::getMatcherFromTypeForKey(std::string{index_.name}), std::string{index_.name})}, {}} { } @@ -1520,7 +1561,11 @@ template using PresliceOptional = PresliceBase; template -concept is_preslice = std::derived_from; +concept is_preslice = std::derived_from&& + requires(T) +{ + T::optional; +}; /// Can be user to group together a number of Preslice declaration /// to avoid the limit of 100 data members per task @@ -1667,10 +1712,10 @@ auto doFilteredSliceBy(T const* table, o2::framework::PresliceBase +template auto doSliceByCached(T const* table, framework::expressions::BindingNode const& node, int value, o2::framework::SliceCache& cache) { - auto localCache = cache.ptr->getCacheFor({o2::soa::getLabelFromTypeForKey(node.name), node.name}); + auto localCache = cache.ptr->getCacheFor({"", o2::soa::getMatcherFromTypeForKey(node.name), node.name}); auto [offset, count] = localCache.getSliceFor(value); auto t = typename T::self_t({table->asArrowTable()->Slice(static_cast(offset), count)}, static_cast(offset)); if (t.tableSize() != 0) { @@ -1679,19 +1724,19 @@ auto doSliceByCached(T const* table, framework::expressions::BindingNode const& return t; } -template +template auto doFilteredSliceByCached(T const* table, framework::expressions::BindingNode const& node, int value, o2::framework::SliceCache& cache) { - auto localCache = cache.ptr->getCacheFor({o2::soa::getLabelFromTypeForKey(node.name), node.name}); + auto localCache = cache.ptr->getCacheFor({"", o2::soa::getMatcherFromTypeForKey(node.name), node.name}); auto [offset, count] = localCache.getSliceFor(value); auto slice = table->asArrowTable()->Slice(static_cast(offset), count); return prepareFilteredSlice(table, slice, offset); } -template +template auto doSliceByCachedUnsorted(T const* table, framework::expressions::BindingNode const& node, int value, o2::framework::SliceCache& cache) { - auto localCache = cache.ptr->getCacheUnsortedFor({o2::soa::getLabelFromTypeForKey(node.name), node.name}); + auto localCache = cache.ptr->getCacheUnsortedFor({"", o2::soa::getMatcherFromTypeForKey(node.name), node.name}); if constexpr (soa::is_filtered_table) { auto t = typename T::self_t({table->asArrowTable()}, localCache.getSliceFor(value)); if (t.tableSize() != 0) { diff --git a/Framework/Core/include/Framework/AnalysisHelpers.h b/Framework/Core/include/Framework/AnalysisHelpers.h index 3666fe1299489..a01d14b6632a9 100644 --- a/Framework/Core/include/Framework/AnalysisHelpers.h +++ b/Framework/Core/include/Framework/AnalysisHelpers.h @@ -30,6 +30,7 @@ namespace o2::soa { struct IndexRecord { std::string label; + framework::ConcreteDataMatcher matcher; std::string columnLabel; IndexKind kind; int pos; @@ -142,6 +143,7 @@ std::vector> extractSources(ProcessingContext& pc, struct Spawner { std::string binding; std::vector labels; + std::vector matchers; std::vector> expressions; std::shared_ptr projector = nullptr; std::shared_ptr schema = nullptr; @@ -157,6 +159,7 @@ struct Spawner { struct Builder { bool exclusive; std::vector labels; + std::vector matchers; std::vector records; std::shared_ptr outputSchema; header::DataOrigin origin; @@ -258,9 +261,9 @@ inline constexpr auto getIndexMapping() ([&idx]() mutable { constexpr auto pos = o2::aod::MetadataTrait>::metadata::template getIndexPosToKey(); if constexpr (pos == -1) { - idx.emplace_back(o2::aod::label(), C::columnLabel(), IndexKind::IdxSelf, pos); + idx.emplace_back(o2::aod::label(), o2::aod::matcher(), C::columnLabel(), IndexKind::IdxSelf, pos); } else { - idx.emplace_back(o2::aod::label(), C::columnLabel(), getIndexKind(), pos); + idx.emplace_back(o2::aod::label(), o2::aod::matcher(), C::columnLabel(), getIndexKind(), pos); } }.template operator()>(), ...); diff --git a/Framework/Core/include/Framework/AnalysisManagers.h b/Framework/Core/include/Framework/AnalysisManagers.h index fbb499940b9b9..5112e3659f4aa 100644 --- a/Framework/Core/include/Framework/AnalysisManagers.h +++ b/Framework/Core/include/Framework/AnalysisManagers.h @@ -38,7 +38,7 @@ template refs> static inline auto extractOriginals(ProcessingContext& pc) { return [&](std::index_sequence) -> std::vector> { - return {pc.inputs().get(o2::aod::label())->asArrowTable()...}; + return {pc.inputs().get(o2::aod::matcher())->asArrowTable()...}; }(std::make_index_sequence()); } } // namespace @@ -151,7 +151,7 @@ template concept with_base_table = requires { T::base_specs(); }; template -bool requestInputs(std::vector& inputs, T const& entity) +bool requestInputs(std::vector& inputs, T const& /*entity*/) { auto base_specs = T::base_specs(); for (auto base_spec : base_specs) { @@ -586,7 +586,7 @@ bool registerCache(T& preslice, Cache& bsks, Cache&) return true; } } - auto locate = std::find_if(bsks.begin(), bsks.end(), [&](auto const& entry) { return (entry.binding == preslice.bindingKey.binding) && (entry.key == preslice.bindingKey.key); }); + auto locate = std::find(bsks.begin(), bsks.end(), preslice.getBindingKey()); if (locate == bsks.end()) { bsks.emplace_back(preslice.getBindingKey()); } else if (locate->enabled == false) { @@ -604,7 +604,7 @@ bool registerCache(T& preslice, Cache&, Cache& bsksU) return true; } } - auto locate = std::find_if(bsksU.begin(), bsksU.end(), [&](auto const& entry) { return (entry.binding == preslice.bindingKey.binding) && (entry.key == preslice.bindingKey.key); }); + auto locate = std::find(bsksU.begin(), bsksU.end(), preslice.getBindingKey()); if (locate == bsksU.end()) { bsksU.emplace_back(preslice.getBindingKey()); } else if (locate->enabled == false) { diff --git a/Framework/Core/include/Framework/AnalysisTask.h b/Framework/Core/include/Framework/AnalysisTask.h index 53f6bc0f862d6..c50b5358990de 100644 --- a/Framework/Core/include/Framework/AnalysisTask.h +++ b/Framework/Core/include/Framework/AnalysisTask.h @@ -75,11 +75,11 @@ struct AnalysisDataProcessorBuilder { auto key = std::string{"fIndex"} + o2::framework::cutString(soa::getLabelFromType>()); ([&bk, &bku, &key, enabled]() mutable { if constexpr (soa::relatedByIndex, std::decay_t>()) { - auto binding = soa::getLabelFromTypeForKey>(key); + Entry e{soa::getLabelFromTypeForKey>(key), soa::getMatcherFromTypeForKey>(key), key, enabled}; if constexpr (o2::soa::is_smallgroups>) { - framework::updatePairList(bku, binding, key, enabled); + framework::updatePairList(bku, e); } else { - framework::updatePairList(bk, binding, key, enabled); + framework::updatePairList(bk, e); } } }(), @@ -214,7 +214,7 @@ struct AnalysisDataProcessorBuilder { template static auto extractTableFromRecord(InputRecord& record) { - auto table = record.get(o2::aod::label())->asArrowTable(); + auto table = record.get(o2::aod::matcher())->asArrowTable(); if (table->num_rows() == 0) { table = makeEmptyTable(); } diff --git a/Framework/Core/include/Framework/ArrowTableSlicingCache.h b/Framework/Core/include/Framework/ArrowTableSlicingCache.h index a6117ec3e01bc..073eadc22d72c 100644 --- a/Framework/Core/include/Framework/ArrowTableSlicingCache.h +++ b/Framework/Core/include/Framework/ArrowTableSlicingCache.h @@ -12,6 +12,7 @@ #ifndef ARROWTABLESLICINGCACHE_H #define ARROWTABLESLICINGCACHE_H +#include "Framework/ConcreteDataMatcher.h" #include "Framework/ServiceHandle.h" #include #include @@ -36,20 +37,28 @@ struct SliceInfoUnsortedPtr { struct Entry { std::string binding; + ConcreteDataMatcher matcher; std::string key; bool enabled; - Entry(std::string b, std::string k, bool e = true) + Entry(std::string b, ConcreteDataMatcher m, std::string k, bool e = true) : binding{b}, + matcher{m}, key{k}, enabled{e} { } + + friend bool operator==(Entry const& lhs, Entry const& rhs) + { + return (lhs.matcher == rhs.matcher) && + (lhs.key == rhs.key); + } }; using Cache = std::vector; -void updatePairList(Cache& list, std::string const& binding, std::string const& key, bool enabled); +void updatePairList(Cache& list, Entry& entry); struct ArrowTableSlicingCacheDef { constexpr static ServiceKind service_kind = ServiceKind::Global; diff --git a/Framework/Core/include/Framework/ConcreteDataMatcher.h b/Framework/Core/include/Framework/ConcreteDataMatcher.h index 247e3cd6ed8b9..bfbd2a05a8709 100644 --- a/Framework/Core/include/Framework/ConcreteDataMatcher.h +++ b/Framework/Core/include/Framework/ConcreteDataMatcher.h @@ -56,9 +56,9 @@ struct ConcreteDataMatcher { header::DataDescription description; header::DataHeader::SubSpecificationType subSpec; - ConcreteDataMatcher(header::DataOrigin origin_, - header::DataDescription description_, - header::DataHeader::SubSpecificationType subSpec_) + constexpr ConcreteDataMatcher(header::DataOrigin origin_, + header::DataDescription description_, + header::DataHeader::SubSpecificationType subSpec_) : origin(origin_), description(description_), subSpec(subSpec_) diff --git a/Framework/Core/include/Framework/DataSpecUtils.h b/Framework/Core/include/Framework/DataSpecUtils.h index 588aa30da7e08..fe322334a8edb 100644 --- a/Framework/Core/include/Framework/DataSpecUtils.h +++ b/Framework/Core/include/Framework/DataSpecUtils.h @@ -127,6 +127,9 @@ struct DataSpecUtils { /// unique way a description should be done, so we keep this outside. static std::string describe(OutputSpec const& spec); + /// Describes a ConcreteDataMatcher + static std::string describe(ConcreteDataMatcher const& matcher); + /// Provide a unique label for the input spec. Again this is outside because there /// is no standard way of doing it, so better not to pollute the API. static std::string label(InputSpec const& spec); @@ -211,6 +214,9 @@ struct DataSpecUtils { /// Create an InputSpec from metadata string static InputSpec fromMetadataString(std::string s); + /// Create a concrete data matcher from serialized string + static ConcreteDataMatcher fromString(std::string s); + /// Get the origin, if available static std::optional getOptionalOrigin(InputSpec const& spec); diff --git a/Framework/Core/include/Framework/DataSpecViews.h b/Framework/Core/include/Framework/DataSpecViews.h index 162a12419594e..b38866d8aa6fd 100644 --- a/Framework/Core/include/Framework/DataSpecViews.h +++ b/Framework/Core/include/Framework/DataSpecViews.h @@ -43,6 +43,13 @@ static auto filter_string_params_with(std::string match) }); } +static auto filter_string_params_starts_with(std::string match) +{ + return std::views::filter([match](auto const& param) { + return (param.type == VariantType::String) && (param.name.starts_with(match)); + }); +} + static auto input_to_output_specs() { return std::views::transform([](auto const& input) { diff --git a/Framework/Core/include/Framework/GroupSlicer.h b/Framework/Core/include/Framework/GroupSlicer.h index 4cfbb8c440fd3..596e68d8cdd4c 100644 --- a/Framework/Core/include/Framework/GroupSlicer.h +++ b/Framework/Core/include/Framework/GroupSlicer.h @@ -55,7 +55,7 @@ struct GroupSlicer { { constexpr auto index = framework::has_type_at_v>(associated_pack_t{}); auto binding = o2::soa::getLabelFromTypeForKey>(mIndexColumnName); - auto bk = Entry(binding, mIndexColumnName); + auto bk = Entry(binding, o2::soa::getMatcherFromTypeForKey>(mIndexColumnName), mIndexColumnName); if constexpr (!o2::soa::is_smallgroups>) { if (table.size() == 0) { return; diff --git a/Framework/Core/include/Framework/InputRecord.h b/Framework/Core/include/Framework/InputRecord.h index 0c9f36d00c634..96963f88524be 100644 --- a/Framework/Core/include/Framework/InputRecord.h +++ b/Framework/Core/include/Framework/InputRecord.h @@ -189,6 +189,7 @@ class InputRecord }; int getPos(const char* name) const; + int getPos(ConcreteDataMatcher matcher) const; [[nodiscard]] static InputPos getPos(std::vector const& routes, ConcreteDataMatcher matcher); [[nodiscard]] static DataRef getByPos(std::vector const& routes, InputSpan const& span, int pos, int part = 0); @@ -511,6 +512,27 @@ class InputRecord return cache.idToMetadata[id]; } + template + requires(std::same_as) + decltype(auto) get(ConcreteDataMatcher matcher, int part = 0) + { + auto pos = getPos(matcher); + if (pos < 0) { + auto msg = describeAvailableInputs(); + throw runtime_error_f("InputRecord::get: no input with binding %s found. %s", DataSpecUtils::describe(matcher).c_str(), msg.c_str()); + } + return getByPos(pos, part); + } + + template + requires(std::same_as) + decltype(auto) get(ConcreteDataMatcher matcher, int part = 0) + { + auto ref = get(matcher, part); + auto data = reinterpret_cast(ref.payload); + return std::make_unique(data, DataRefUtils::getPayloadSize(ref)); + } + /// Helper method to be used to check if a given part of the InputRecord is present. [[nodiscard]] bool isValid(std::string const& s) const { diff --git a/Framework/Core/src/AnalysisHelpers.cxx b/Framework/Core/src/AnalysisHelpers.cxx index b8e0348d5df9c..f2ecb2d68ce28 100644 --- a/Framework/Core/src/AnalysisHelpers.cxx +++ b/Framework/Core/src/AnalysisHelpers.cxx @@ -185,18 +185,18 @@ std::string serializeIndexRecords(std::vector& irs) return osm.str(); } -std::vector> extractSources(ProcessingContext& pc, std::vector const& labels) +std::vector> extractSources(ProcessingContext& pc, std::vector const& matchers) { std::vector> tables; - for (auto const& label : labels) { - tables.emplace_back(pc.inputs().get(label.c_str())->asArrowTable()); + for (auto const& matcher : matchers) { + tables.emplace_back(pc.inputs().get(matcher)->asArrowTable()); } return tables; } std::shared_ptr Spawner::materialize(ProcessingContext& pc) const { - auto tables = extractSources(pc, labels); + auto tables = extractSources(pc, matchers); auto fullTable = soa::ArrowHelpers::joinTables(std::move(tables), std::span{labels.begin(), labels.size()}); if (fullTable->num_rows() == 0) { return arrow::Table::MakeEmpty(schema).ValueOrDie(); @@ -212,7 +212,7 @@ std::shared_ptr Builder::materialize(ProcessingContext& pc) builders->reserve(records.size()); } std::shared_ptr result; - auto tables = extractSources(pc, labels); + auto tables = extractSources(pc, matchers); result = o2::soa::IndexBuilder::materialize(*builders.get(), std::move(tables), records, outputSchema, exclusive); return result; } diff --git a/Framework/Core/src/ArrowSupport.cxx b/Framework/Core/src/ArrowSupport.cxx index c0280b144e146..0da06c7bc8d7b 100644 --- a/Framework/Core/src/ArrowSupport.cxx +++ b/Framework/Core/src/ArrowSupport.cxx @@ -751,7 +751,7 @@ o2::framework::ServiceSpec ArrowSupport::arrowTableSlicingCacheSpec() auto& caches = service->bindingsKeys; for (auto i = 0u; i < caches.size(); ++i) { if (caches[i].enabled && pc.inputs().getPos(caches[i].binding.c_str()) >= 0) { - auto status = service->updateCacheEntry(i, pc.inputs().get(caches[i].binding.c_str())->asArrowTable()); + auto status = service->updateCacheEntry(i, pc.inputs().get(caches[i].matcher)->asArrowTable()); if (!status.ok()) { throw runtime_error_f("Failed to update slice cache for %s/%s", caches[i].binding.c_str(), caches[i].key.c_str()); } @@ -760,7 +760,7 @@ o2::framework::ServiceSpec ArrowSupport::arrowTableSlicingCacheSpec() auto& unsortedCaches = service->bindingsKeysUnsorted; for (auto i = 0u; i < unsortedCaches.size(); ++i) { if (unsortedCaches[i].enabled && pc.inputs().getPos(unsortedCaches[i].binding.c_str()) >= 0) { - auto status = service->updateCacheEntryUnsorted(i, pc.inputs().get(unsortedCaches[i].binding.c_str())->asArrowTable()); + auto status = service->updateCacheEntryUnsorted(i, pc.inputs().get(unsortedCaches[i].matcher)->asArrowTable()); if (!status.ok()) { throw runtime_error_f("failed to update slice cache (unsorted) for %s/%s", unsortedCaches[i].binding.c_str(), unsortedCaches[i].key.c_str()); } diff --git a/Framework/Core/src/ArrowTableSlicingCache.cxx b/Framework/Core/src/ArrowTableSlicingCache.cxx index 75b4bbfac701d..634c51f71f5a6 100644 --- a/Framework/Core/src/ArrowTableSlicingCache.cxx +++ b/Framework/Core/src/ArrowTableSlicingCache.cxx @@ -37,12 +37,12 @@ std::shared_ptr GetColumnByNameCI(std::shared_ptrenabled && enabled) { + list.emplace_back(entry); + } else if (!locate->enabled && entry.enabled) { locate->enabled = true; } } @@ -110,7 +110,7 @@ arrow::Status ArrowTableSlicingCache::updateCacheEntry(int pos, std::shared_ptr< if (table->num_rows() == 0) { return arrow::Status::OK(); } - auto& [b, k, e] = bindingsKeys[pos]; + auto& [b, m, k, e] = bindingsKeys[pos]; if (!e) { throw runtime_error_f("Disabled cache %s/%s update requested", b.c_str(), k.c_str()); } @@ -169,7 +169,7 @@ arrow::Status ArrowTableSlicingCache::updateCacheEntryUnsorted(int pos, const st if (table->num_rows() == 0) { return arrow::Status::OK(); } - auto& [b, k, e] = bindingsKeysUnsorted[pos]; + auto& [b, m, k, e] = bindingsKeysUnsorted[pos]; if (!e) { throw runtime_error_f("Disabled unsorted cache %s/%s update requested", b.c_str(), k.c_str()); } @@ -210,7 +210,7 @@ std::pair ArrowTableSlicingCache::getCachePos(const Entry& bindingKey int ArrowTableSlicingCache::getCachePosSortedFor(Entry const& bindingKey) const { - auto locate = std::find_if(bindingsKeys.begin(), bindingsKeys.end(), [&](Entry const& bk) { return (bindingKey.binding == bk.binding) && (bindingKey.key == bk.key); }); + auto locate = std::find(bindingsKeys.begin(), bindingsKeys.end(), bindingKey); if (locate != bindingsKeys.end()) { return std::distance(bindingsKeys.begin(), locate); } @@ -219,7 +219,7 @@ int ArrowTableSlicingCache::getCachePosSortedFor(Entry const& bindingKey) const int ArrowTableSlicingCache::getCachePosUnsortedFor(Entry const& bindingKey) const { - auto locate_unsorted = std::find_if(bindingsKeysUnsorted.begin(), bindingsKeysUnsorted.end(), [&](Entry const& bk) { return (bindingKey.binding == bk.binding) && (bindingKey.key == bk.key); }); + auto locate_unsorted = std::find(bindingsKeysUnsorted.begin(), bindingsKeysUnsorted.end(), bindingKey); if (locate_unsorted != bindingsKeysUnsorted.end()) { return std::distance(bindingsKeysUnsorted.begin(), locate_unsorted); } @@ -269,7 +269,10 @@ SliceInfoUnsortedPtr ArrowTableSlicingCache::getCacheUnsortedForPos(int pos) con void ArrowTableSlicingCache::validateOrder(Entry const& bindingKey, const std::shared_ptr& input) { - auto const& [target, key, enabled] = bindingKey; + auto const& [target, matcher, key, enabled] = bindingKey; + if (!enabled) { + return; + } auto column = o2::framework::GetColumnByNameCI(input, key); auto array0 = static_cast>(column->chunk(0)->data()); int32_t prev = 0; diff --git a/Framework/Core/src/DataSpecUtils.cxx b/Framework/Core/src/DataSpecUtils.cxx index 48f5e6abcad5b..bc1fcd180ed76 100644 --- a/Framework/Core/src/DataSpecUtils.cxx +++ b/Framework/Core/src/DataSpecUtils.cxx @@ -89,6 +89,11 @@ std::string DataSpecUtils::describe(OutputSpec const& spec) spec.matcher); } +std::string DataSpecUtils::describe(ConcreteDataMatcher const& matcher) +{ + return join(matcher, "/"); +} + template size_t DataSpecUtils::describe(char* buffer, size_t size, T const& spec) { @@ -664,16 +669,39 @@ InputSpec DataSpecUtils::fromMetadataString(std::string s) if (std::distance(words, std::sregex_iterator()) != 4) { throw runtime_error_f("Malformed input spec metadata: %s", s.c_str()); } - std::vector data; + std::array data; + auto pos = 0; for (auto i = words; i != std::sregex_iterator(); ++i) { - data.emplace_back(i->str()); + data[pos] = i->str(); + ++pos; } char origin[4]; char description[16]; std::memcpy(&origin, data[1].c_str(), 4); std::memcpy(&description, data[2].c_str(), 16); auto version = static_cast(std::atoi(data[3].c_str())); - return InputSpec{data[0], header::DataOrigin{origin}, header::DataDescription{description}, version, Lifetime::Timeframe}; + return {data[0], header::DataOrigin{origin}, header::DataDescription{description}, version, Lifetime::Timeframe}; +} + +ConcreteDataMatcher DataSpecUtils::fromString(std::string s) +{ + std::regex word_regex("(\\w+)"); + auto words = std::sregex_iterator(s.begin(), s.end(), word_regex); + if (std::distance(words, std::sregex_iterator()) != 3) { + throw runtime_error_f("Malformed serialized matcher: %s", s.c_str()); + } + std::array data; + auto pos = 0; + for (auto i = words; i != std::sregex_iterator(); ++i) { + data[pos] = i->str(); + ++pos; + } + char origin[4]; + char description[16]; + std::memcpy(&origin, data[0].c_str(), 4); + std::memcpy(&description, data[1].c_str(), 16); + auto version = static_cast(std::atoi(data[2].c_str())); + return {header::DataOrigin{origin}, header::DataDescription{description}, version}; } std::optional DataSpecUtils::getOptionalOrigin(InputSpec const& spec) diff --git a/Framework/Core/src/IndexJSONHelpers.cxx b/Framework/Core/src/IndexJSONHelpers.cxx index 19ae94a4bcd4c..a5c6c70579599 100644 --- a/Framework/Core/src/IndexJSONHelpers.cxx +++ b/Framework/Core/src/IndexJSONHelpers.cxx @@ -41,6 +41,7 @@ struct IndexRecordsReader : public rapidjson::BaseReaderHandler& w, std::vector const& schema, ConcreteDataMatcher concrete) { size_t inputIndex = 0; diff --git a/Framework/Core/test/benchmark_EventMixing.cxx b/Framework/Core/test/benchmark_EventMixing.cxx index 99a7d0d4b1cb9..0e7e6839ee35e 100644 --- a/Framework/Core/test/benchmark_EventMixing.cxx +++ b/Framework/Core/test/benchmark_EventMixing.cxx @@ -78,7 +78,8 @@ static void BM_EventMixingTraditional(benchmark::State& state) auto tableTrack = trackBuilder.finalize(); o2::aod::StoredTracks tracks{tableTrack}; - ArrowTableSlicingCache atscache({{getLabelFromType(), "fIndex" + cutString(getLabelFromType())}}); + std::string key = "fIndex" + cutString(getLabelFromType()); + ArrowTableSlicingCache atscache({{getLabelFromType(), getMatcherFromTypeForKey(key), key}}); auto s = atscache.updateCacheEntry(0, tableTrack); SliceCache cache{&atscache}; @@ -171,7 +172,8 @@ static void BM_EventMixingCombinations(benchmark::State& state) int64_t count = 0; int64_t colCount = 0; - ArrowTableSlicingCache atscache{{{getLabelFromType(), "fIndex" + getLabelFromType()}}}; + std::string key = "fIndex" + getLabelFromType(); + ArrowTableSlicingCache atscache{{{getLabelFromType(), getMatcherFromTypeForKey(key), key}}}; auto s = atscache.updateCacheEntry(0, tableTrack); SliceCache cache{&atscache}; diff --git a/Framework/Core/test/test_ASoA.cxx b/Framework/Core/test/test_ASoA.cxx index 80519aebc9ee7..117dddff4c548 100644 --- a/Framework/Core/test/test_ASoA.cxx +++ b/Framework/Core/test/test_ASoA.cxx @@ -1187,7 +1187,8 @@ TEST_CASE("TestSliceByCached") auto refs = w.finalize(); o2::aod::References r{refs}; - ArrowTableSlicingCache atscache({{o2::soa::getLabelFromType(), "fIndex" + o2::framework::cutString(o2::soa::getLabelFromType())}}); + std::string key = "fIndex" + o2::framework::cutString(o2::soa::getLabelFromType()); + ArrowTableSlicingCache atscache({{o2::soa::getLabelFromType(), o2::soa::getMatcherFromTypeForKey(key), key}}); auto s = atscache.updateCacheEntry(0, refs); SliceCache cache{&atscache}; @@ -1238,7 +1239,7 @@ TEST_CASE("TestSliceByCachedMismatched") J rr{{refs, refs2}}; auto key = "fIndex" + o2::framework::cutString(o2::soa::getLabelFromType()) + "_alt"; - ArrowTableSlicingCache atscache({{o2::soa::getLabelFromTypeForKey(key), key}}); + ArrowTableSlicingCache atscache({{o2::soa::getLabelFromTypeForKey(key), o2::soa::getMatcherFromTypeForKey(key), key}}); auto s = atscache.updateCacheEntry(0, refs2); SliceCache cache{&atscache}; diff --git a/Framework/Core/test/test_GroupSlicer.cxx b/Framework/Core/test/test_GroupSlicer.cxx index 2f21d7dd17975..71360f736c3fb 100644 --- a/Framework/Core/test/test_GroupSlicer.cxx +++ b/Framework/Core/test/test_GroupSlicer.cxx @@ -117,7 +117,8 @@ TEST_CASE("GroupSlicerOneAssociated") REQUIRE(t.size() == 10 * 20); auto tt = std::make_tuple(t); - ArrowTableSlicingCache slices({{soa::getLabelFromType(), "fIndex" + o2::framework::cutString(soa::getLabelFromType())}}); + std::string key = "fIndex" + o2::framework::cutString(soa::getLabelFromType()); + ArrowTableSlicingCache slices({{soa::getLabelFromType(), soa::getMatcherFromTypeForKey(key), key}}); auto s = slices.updateCacheEntry(0, trkTable); o2::framework::GroupSlicer g(e, tt, slices); @@ -191,9 +192,9 @@ TEST_CASE("GroupSlicerSeveralAssociated") auto tt = std::make_tuple(tx, ty, tz, tu); auto key = "fIndex" + o2::framework::cutString(soa::getLabelFromType()); - ArrowTableSlicingCache slices({{soa::getLabelFromType(), key}, - {soa::getLabelFromType(), key}, - {soa::getLabelFromType(), key}}); + ArrowTableSlicingCache slices({{soa::getLabelFromType(), soa::getMatcherFromTypeForKey(key), key}, + {soa::getLabelFromType(), soa::getMatcherFromTypeForKey(key), key}, + {soa::getLabelFromType(), soa::getMatcherFromTypeForKey(key), key}}); auto s = slices.updateCacheEntry(0, {trkTableX}); s = slices.updateCacheEntry(1, {trkTableY}); s = slices.updateCacheEntry(2, {trkTableZ}); @@ -256,7 +257,8 @@ TEST_CASE("GroupSlicerMismatchedGroups") REQUIRE(t.size() == 10 * (20 - 5)); auto tt = std::make_tuple(t); - ArrowTableSlicingCache slices({{soa::getLabelFromType(), "fIndex" + o2::framework::cutString(soa::getLabelFromType())}}); + std::string key = "fIndex" + o2::framework::cutString(soa::getLabelFromType()); + ArrowTableSlicingCache slices({{soa::getLabelFromType(), soa::getMatcherFromTypeForKey(key), key}}); auto s = slices.updateCacheEntry(0, trkTable); o2::framework::GroupSlicer g(e, tt, slices); @@ -312,7 +314,8 @@ TEST_CASE("GroupSlicerMismatchedUnassignedGroups") REQUIRE(t.size() == (30 + 10 * (20 - 5))); auto tt = std::make_tuple(t); - ArrowTableSlicingCache slices({{soa::getLabelFromType(), "fIndex" + o2::framework::cutString(soa::getLabelFromType())}}); + std::string key = "fIndex" + o2::framework::cutString(soa::getLabelFromType()); + ArrowTableSlicingCache slices({{soa::getLabelFromType(), soa::getMatcherFromTypeForKey(key), key}}); auto s = slices.updateCacheEntry(0, trkTable); o2::framework::GroupSlicer g(e, tt, slices); @@ -362,7 +365,8 @@ TEST_CASE("GroupSlicerMismatchedFilteredGroups") REQUIRE(t.size() == 10 * (20 - 4)); auto tt = std::make_tuple(t); - ArrowTableSlicingCache slices({{soa::getLabelFromType(), "fIndex" + o2::framework::cutString(soa::getLabelFromType())}}); + std::string key = "fIndex" + o2::framework::cutString(soa::getLabelFromType()); + ArrowTableSlicingCache slices({{soa::getLabelFromType(), soa::getMatcherFromTypeForKey(key), key}}); auto s = slices.updateCacheEntry(0, trkTable); o2::framework::GroupSlicer g(e, tt, slices); @@ -423,7 +427,8 @@ TEST_CASE("GroupSlicerMismatchedUnsortedFilteredGroups") REQUIRE(t.size() == 10 * (20 - 4)); auto tt = std::make_tuple(t); - ArrowTableSlicingCache slices({}, {{soa::getLabelFromType(), "fIndex" + o2::framework::cutString(soa::getLabelFromType())}}); + std::string key = "fIndex" + o2::framework::cutString(soa::getLabelFromType()); + ArrowTableSlicingCache slices({}, {{soa::getLabelFromType(), soa::getMatcherFromTypeForKey(key), key}}); auto s = slices.updateCacheEntryUnsorted(0, trkTable); o2::framework::GroupSlicer g(e, tt, slices); @@ -547,8 +552,9 @@ TEST_CASE("GroupSlicerMismatchedUnsortedFilteredGroupsWithSelfIndex") } FilteredParts fp{{partsTable}, rows}; auto associatedTuple = std::make_tuple(fp, t); - ArrowTableSlicingCache slices({{soa::getLabelFromType(), "fIndex" + o2::framework::cutString(soa::getLabelFromType())}, - {soa::getLabelFromType(), "fIndex" + o2::framework::cutString(soa::getLabelFromType())}}); + std::string key = "fIndex" + o2::framework::cutString(soa::getLabelFromType()); + ArrowTableSlicingCache slices({{soa::getLabelFromType(), soa::getMatcherFromTypeForKey(key), key}, + {soa::getLabelFromType(), soa::getMatcherFromTypeForKey(key), key}}); auto s0 = slices.updateCacheEntry(0, partsTable); auto s1 = slices.updateCacheEntry(1, thingsTable); o2::framework::GroupSlicer g(e, associatedTuple, slices); @@ -607,7 +613,8 @@ TEST_CASE("EmptySliceables") REQUIRE(t.size() == 0); auto tt = std::make_tuple(t); - ArrowTableSlicingCache slices({{soa::getLabelFromType(), "fIndex" + o2::framework::cutString(soa::getLabelFromType())}}); + std::string key = "fIndex" + o2::framework::cutString(soa::getLabelFromType()); + ArrowTableSlicingCache slices({{soa::getLabelFromType(), soa::getMatcherFromTypeForKey(key), key}}); auto s = slices.updateCacheEntry(0, trkTable); o2::framework::GroupSlicer g(e, tt, slices); @@ -679,7 +686,7 @@ TEST_CASE("ArrowDirectSlicing") std::vector slices; std::vector offsts; - auto bk = Entry(soa::getLabelFromType(), "fID"); + auto bk = Entry(soa::getLabelFromType(), soa::getMatcherFromTypeForKey("fID"), "fID"); ArrowTableSlicingCache cache({bk}); auto s = cache.updateCacheEntry(0, {evtTable}); auto lcache = cache.getCacheFor(bk); @@ -737,7 +744,7 @@ TEST_CASE("TestSlicingException") } auto evtTable = builderE.finalize(); - auto bk = Entry(soa::getLabelFromType(), "fID"); + auto bk = Entry(soa::getLabelFromType(), soa::getMatcherFromTypeForKey("fID"), "fID"); ArrowTableSlicingCache cache({bk}); try {