diff --git a/Framework/Core/include/Framework/DataModelViews.h b/Framework/Core/include/Framework/DataModelViews.h index 7c39a94950e9c..285f5ef15154e 100644 --- a/Framework/Core/include/Framework/DataModelViews.h +++ b/Framework/Core/include/Framework/DataModelViews.h @@ -16,7 +16,9 @@ #include "DomainInfoHeader.h" #include "SourceInfoHeader.h" #include "Headers/DataHeader.h" +#include "Framework/TimesliceSlot.h" #include +#include namespace o2::framework { @@ -213,13 +215,11 @@ struct get_num_payloads { } }; -struct MessageSet; - struct inputs_for_slot { TimesliceSlot slot; template requires requires(R r) { requires std::ranges::random_access_range; } - friend std::span operator|(R&& r, inputs_for_slot self) + friend auto operator|(R&& r, inputs_for_slot self) { return std::span(r.sets[self.slot.index * r.inputsPerSlot]); } @@ -231,7 +231,7 @@ struct messages_for_input { requires std::ranges::random_access_range friend std::span operator|(R&& r, messages_for_input self) { - return r[self.inputIdx].messages; + return std::span(r[self.inputIdx]); } }; diff --git a/Framework/Core/include/Framework/DataProcessingHelpers.h b/Framework/Core/include/Framework/DataProcessingHelpers.h index 87aeeb8922da3..f414e3aa4ae00 100644 --- a/Framework/Core/include/Framework/DataProcessingHelpers.h +++ b/Framework/Core/include/Framework/DataProcessingHelpers.h @@ -15,6 +15,7 @@ #include "Framework/TimesliceSlot.h" #include "Framework/TimesliceIndex.h" #include +#include #include #include @@ -29,7 +30,6 @@ struct OutputChannelState; struct ProcessingPolicies; struct DeviceSpec; struct FairMQDeviceProxy; -struct MessageSet; struct ChannelIndex; enum struct StreamingState; enum struct TransitionHandlingState; @@ -54,7 +54,7 @@ struct DataProcessingHelpers { /// starts the EoS timers and returns the new TransitionHandlingState in case as new state is requested static TransitionHandlingState updateStateTransition(ServiceRegistryRef const& ref, ProcessingPolicies const& policies); /// Helper to route messages for forwarding - static std::vector routeForwardedMessageSet(FairMQDeviceProxy& proxy, std::vector& currentSetOfInputs, + static std::vector routeForwardedMessageSet(FairMQDeviceProxy& proxy, std::vector>& currentSetOfInputs, bool copy, bool consume); /// Helper to route messages for forwarding static void routeForwardedMessages(FairMQDeviceProxy& proxy, std::span& currentSetOfInputs, std::vector& forwardedParts, diff --git a/Framework/Core/include/Framework/DataRelayer.h b/Framework/Core/include/Framework/DataRelayer.h index e5a2aecea1de4..b56a2cb59ff10 100644 --- a/Framework/Core/include/Framework/DataRelayer.h +++ b/Framework/Core/include/Framework/DataRelayer.h @@ -16,7 +16,7 @@ #include "Framework/DataDescriptorMatcher.h" #include "Framework/ForwardRoute.h" #include "Framework/CompletionPolicy.h" -#include "Framework/MessageSet.h" +#include #include "Framework/TimesliceIndex.h" #include "Framework/Tracing.h" #include "Framework/TimesliceSlot.h" @@ -113,7 +113,7 @@ class DataRelayer ActivityStats processDanglingInputs(std::vector const&, ServiceRegistryRef context, bool createNew); - using OnDropCallback = std::function&, TimesliceIndex::OldestOutputInfo info)>; + using OnDropCallback = std::function>&, TimesliceIndex::OldestOutputInfo info)>; // Callback for when some messages are about to be owned by the the DataRelayer using OnInsertionCallback = std::function&)>; @@ -156,8 +156,8 @@ class DataRelayer /// Returns an input registry associated to the given timeslice and gives /// ownership to the caller. This is because once the inputs are out of the /// DataRelayer they need to be deleted once the processing is concluded. - std::vector consumeAllInputsForTimeslice(TimesliceSlot id); - std::vector consumeExistingInputsForTimeslice(TimesliceSlot id); + std::vector> consumeAllInputsForTimeslice(TimesliceSlot id); + std::vector> consumeExistingInputsForTimeslice(TimesliceSlot id); /// Returns how many timeslices we can handle in parallel [[nodiscard]] size_t getParallelTimeslices() const; @@ -203,7 +203,7 @@ class DataRelayer /// Notice that we store them as a NxM sized vector, where /// N is the maximum number of inflight timeslices, while /// M is the number of inputs which are requested. - std::vector mCache; + std::vector> mCache; /// This is the index which maps a given timestamp to the associated /// cacheline. diff --git a/Framework/Core/include/Framework/MessageSet.h b/Framework/Core/include/Framework/MessageSet.h deleted file mode 100644 index 166934238d647..0000000000000 --- a/Framework/Core/include/Framework/MessageSet.h +++ /dev/null @@ -1,178 +0,0 @@ -// Copyright 2019-2020 CERN and copyright holders of ALICE O2. -// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders. -// All rights not expressly granted are reserved. -// -// This software is distributed under the terms of the GNU General Public -// License v3 (GPL Version 3), copied verbatim in the file "COPYING". -// -// In applying this license CERN does not waive the privileges and immunities -// granted to it by virtue of its status as an Intergovernmental Organization -// or submit itself to any jurisdiction. -#ifndef FRAMEWORK_MESSAGESET_H -#define FRAMEWORK_MESSAGESET_H - -#include "Framework/PartRef.h" -#include -#include "Framework/DataModelViews.h" -#include -#include -#include - -namespace o2::framework -{ - -/// A set of inflight messages. -/// The messages are stored in a linear vector. Originally, an O2 message was -/// comprised of a header-payload pair which makes indexing of pairs in the -/// storage simple. To support O2 messages with multiple payloads in a future -/// update of the data model, a message index is needed to store position in the -/// linear storage and number of messages. -/// DPL InputRecord API is providing refs of header-payload pairs, the original -/// O2 message model. For this purpose, also the pair index is filled and can -/// be used to access header and payload associated with a pair -struct MessageSet { - struct Index { - Index(size_t p, size_t s) : position(p), size(s) {} - size_t position = 0; - size_t size = 0; - }; - // linear storage of messages - std::vector messages; - // message map describes O2 messages consisting of a header message and - // payload message(s), index describes position in the linear storage - std::vector messageMap; - // pair map describes all messages in one sequence of header-payload pairs and - // where in the message index the associated header and payload can be found - struct PairMapping { - PairMapping(size_t partId, size_t payloadId) : partIndex(partId), payloadIndex(payloadId) {} - // O2 message where the pair is located in - size_t partIndex = 0; - // payload index within the O2 message - size_t payloadIndex = 0; - }; - std::vector pairMap; - - MessageSet() - : messages(), messageMap(), pairMap() - { - } - - template - MessageSet(F getter, size_t size) - : messages(), messageMap(), pairMap() - { - add(std::forward(getter), size); - } - - MessageSet(MessageSet&& other) - : messages(std::move(other.messages)), messageMap(std::move(other.messageMap)), pairMap(std::move(other.pairMap)) - { - other.clear(); - } - - MessageSet& operator=(MessageSet&& other) - { - if (&other == this) { - return *this; - } - messages = std::move(other.messages); - messageMap = std::move(other.messageMap); - pairMap = std::move(other.pairMap); - other.clear(); - return *this; - } - - /// get number of header-payload pairs - [[nodiscard]] size_t getNumberOfPairs() const - { - return messages | count_payloads{}; - } - - /// get number of payloads for an in-flight message - [[nodiscard]] size_t getNumberOfPayloads(size_t mi) const - { - return messages | get_num_payloads{mi}; - } - - /// clear the set - void clear() - { - messages.clear(); - messageMap.clear(); - pairMap.clear(); - } - - // this is more or less legacy - // PartRef has been earlier used to store fixed header-payload pairs - // reset the set and store content of the part ref - void reset(PartRef&& ref) - { - clear(); - add(std::move(ref)); - } - - // this is more or less legacy - // PartRef has been earlier used to store fixed header-payload pairs - // add content of the part ref - void add(PartRef&& ref) - { - pairMap.emplace_back(messageMap.size(), 0); - messageMap.emplace_back(messages.size(), 1); - messages.emplace_back(std::move(ref.header)); - messages.emplace_back(std::move(ref.payload)); - } - - /// add an O2 message - template - void add(F getter, size_t size) - { - auto partid = messageMap.size(); - messageMap.emplace_back(messages.size(), size - 1); - for (size_t i = 0; i < size; ++i) { - if (i > 0) { - pairMap.emplace_back(partid, i - 1); - } - messages.emplace_back(std::move(getter(i))); - } - } - - fair::mq::MessagePtr& header(size_t partIndex) - { - return messages[messageMap[partIndex].position]; - } - - fair::mq::MessagePtr& payload(size_t partIndex, size_t payloadIndex = 0) - { - assert(partIndex < messageMap.size()); - assert(messageMap[partIndex].position + payloadIndex + 1 < messages.size()); - return messages[messageMap[partIndex].position + payloadIndex + 1]; - } - - fair::mq::MessagePtr const& header(size_t partIndex) const - { - return messages[messageMap[partIndex].position]; - } - - fair::mq::MessagePtr const& payload(size_t partIndex, size_t payloadIndex = 0) const - { - assert(partIndex < messageMap.size()); - assert(messageMap[partIndex].position + payloadIndex + 1 < messages.size()); - return messages[messageMap[partIndex].position + payloadIndex + 1]; - } - - fair::mq::MessagePtr const& associatedHeader(size_t pos) const - { - return messages[messageMap[pairMap[pos].partIndex].position]; - } - - fair::mq::MessagePtr const& associatedPayload(size_t pos) const - { - auto partIndex = pairMap[pos].partIndex; - auto payloadIndex = pairMap[pos].payloadIndex; - return messages[messageMap[partIndex].position + payloadIndex + 1]; - } -}; - -} // namespace o2::framework - -#endif // FRAMEWORK_MESSAGESET_H diff --git a/Framework/Core/src/DataProcessingDevice.cxx b/Framework/Core/src/DataProcessingDevice.cxx index 0fa70947bf18c..4f37a427abbea 100644 --- a/Framework/Core/src/DataProcessingDevice.cxx +++ b/Framework/Core/src/DataProcessingDevice.cxx @@ -50,6 +50,7 @@ #include "DecongestionService.h" #include "Framework/DataProcessingHelpers.h" +#include "Framework/DataModelViews.h" #include "DataRelayerHelpers.h" #include "Headers/DataHeader.h" #include "Headers/DataHeaderHelpers.h" @@ -585,7 +586,7 @@ auto decongestionCallbackLate = [](AsyncTask& task, size_t aid) -> void { // the inputs which are shared between this device and others // to the next one in the daisy chain. // FIXME: do it in a smarter way than O(N^2) -static auto forwardInputs = [](ServiceRegistryRef registry, TimesliceSlot slot, std::vector& currentSetOfInputs, +static auto forwardInputs = [](ServiceRegistryRef registry, TimesliceSlot slot, std::vector>& currentSetOfInputs, TimesliceIndex::OldestOutputInfo oldestTimeslice, bool copy, bool consume = true) { auto& proxy = registry.get(); @@ -617,7 +618,7 @@ static auto forwardInputs = [](ServiceRegistryRef registry, TimesliceSlot slot, O2_SIGNPOST_END(forwarding, sid, "forwardInputs", "Forwarding done"); }; -static auto cleanEarlyForward = [](ServiceRegistryRef registry, TimesliceSlot slot, std::vector& currentSetOfInputs, +static auto cleanEarlyForward = [](ServiceRegistryRef registry, TimesliceSlot slot, std::vector>& currentSetOfInputs, TimesliceIndex::OldestOutputInfo oldestTimeslice, bool copy, bool consume = true) { auto& proxy = registry.get(); @@ -627,7 +628,7 @@ static auto cleanEarlyForward = [](ServiceRegistryRef registry, TimesliceSlot sl // Always copy them, because we do not want to actually send them. // We merely need the side effect of the consume, if applicable. for (size_t ii = 0, ie = currentSetOfInputs.size(); ii < ie; ++ii) { - auto span = std::span(currentSetOfInputs[ii].messages); + auto span = std::span(currentSetOfInputs[ii]); DataProcessingHelpers::cleanForwardedMessages(span, consume); } @@ -1278,7 +1279,7 @@ void DataProcessingDevice::Run() // - we can trigger further events from the queue // - we can guarantee this is the last thing we do in the loop ( // assuming no one else is adding to the queue before this point). - auto onDrop = [®istry = mServiceRegistry, lid](TimesliceSlot slot, std::vector& dropped, TimesliceIndex::OldestOutputInfo oldestOutputInfo) { + auto onDrop = [®istry = mServiceRegistry, lid](TimesliceSlot slot, std::vector>& dropped, TimesliceIndex::OldestOutputInfo oldestOutputInfo) { O2_SIGNPOST_START(device, lid, "run_loop", "Dropping message from slot %" PRIu64 ". Forwarding as needed.", (uint64_t)slot.index); ServiceRegistryRef ref{registry}; ref.get(); @@ -1942,7 +1943,7 @@ void DataProcessingDevice::handleData(ServiceRegistryRef ref, InputChannelInfo& nPayloadsPerHeader = 1; ii += (nMessages / 2) - 1; } - auto onDrop = [ref](TimesliceSlot slot, std::vector& dropped, TimesliceIndex::OldestOutputInfo oldestOutputInfo) { + auto onDrop = [ref](TimesliceSlot slot, std::vector>& dropped, TimesliceIndex::OldestOutputInfo oldestOutputInfo) { O2_SIGNPOST_ID_GENERATE(cid, async_queue); O2_SIGNPOST_EVENT_EMIT(async_queue, cid, "onDrop", "Dropping message from slot %zu. Forwarding as needed. Timeslice %zu", slot.index, oldestOutputInfo.timeslice.value); @@ -2120,7 +2121,7 @@ bool DataProcessingDevice::tryDispatchComputation(ServiceRegistryRef ref, std::v // want to support multithreaded dispatching of operations, I can simply // move these to some thread local store and the rest of the lambdas // should work just fine. - std::vector currentSetOfInputs; + std::vector> currentSetOfInputs; // auto getInputSpan = [ref, ¤tSetOfInputs](TimesliceSlot slot, bool consume = true) { @@ -2131,7 +2132,7 @@ bool DataProcessingDevice::tryDispatchComputation(ServiceRegistryRef ref, std::v currentSetOfInputs = relayer.consumeExistingInputsForTimeslice(slot); } auto getter = [¤tSetOfInputs](size_t i, size_t partindex) -> DataRef { - if (currentSetOfInputs[i].getNumberOfPairs() > partindex) { + if ((currentSetOfInputs[i] | count_payloads{}) > partindex) { const char* headerptr = nullptr; const char* payloadptr = nullptr; size_t payloadSize = 0; @@ -2140,8 +2141,9 @@ bool DataProcessingDevice::tryDispatchComputation(ServiceRegistryRef ref, std::v // sequence is the header message // - each part has one or more payload messages // - InputRecord provides all payloads as header-payload pair - auto const& headerMsg = currentSetOfInputs[i].associatedHeader(partindex); - auto const& payloadMsg = currentSetOfInputs[i].associatedPayload(partindex); + auto const indices = currentSetOfInputs[i] | get_pair{partindex}; + auto const& headerMsg = currentSetOfInputs[i][indices.headerIdx]; + auto const& payloadMsg = currentSetOfInputs[i][indices.payloadIdx]; headerptr = static_cast(headerMsg->GetData()); payloadptr = payloadMsg ? static_cast(payloadMsg->GetData()) : nullptr; payloadSize = payloadMsg ? payloadMsg->GetSize() : 0; @@ -2150,10 +2152,10 @@ bool DataProcessingDevice::tryDispatchComputation(ServiceRegistryRef ref, std::v return DataRef{}; }; auto nofPartsGetter = [¤tSetOfInputs](size_t i) -> size_t { - return currentSetOfInputs[i].getNumberOfPairs(); + return (currentSetOfInputs[i] | count_payloads{}); }; auto refCountGetter = [¤tSetOfInputs](size_t idx) -> int { - auto& header = static_cast(*(currentSetOfInputs[idx].messages | get_header{0})); + auto& header = static_cast(*(currentSetOfInputs[idx] | get_header{0})); return header.GetRefCount(); }; return InputSpan{getter, nofPartsGetter, refCountGetter, currentSetOfInputs.size()}; diff --git a/Framework/Core/src/DataProcessingHelpers.cxx b/Framework/Core/src/DataProcessingHelpers.cxx index 334a0fc6045f6..b8399a4c591e7 100644 --- a/Framework/Core/src/DataProcessingHelpers.cxx +++ b/Framework/Core/src/DataProcessingHelpers.cxx @@ -393,14 +393,14 @@ void DataProcessingHelpers::cleanForwardedMessages(std::span& currentSetOfInputs, + std::vector>& currentSetOfInputs, const bool copyByDefault, bool consume) -> std::vector { // we collect all messages per forward in a map and send them together std::vector forwardedParts(proxy.getNumForwardChannels()); for (size_t ii = 0, ie = currentSetOfInputs.size(); ii < ie; ++ii) { - auto span = std::span(currentSetOfInputs[ii].messages); + auto span = std::span(currentSetOfInputs[ii]); routeForwardedMessages(proxy, span, forwardedParts, copyByDefault, consume); } return forwardedParts; diff --git a/Framework/Core/src/DataRelayer.cxx b/Framework/Core/src/DataRelayer.cxx index 7eb851e2aadd8..fc9966ffad643 100644 --- a/Framework/Core/src/DataRelayer.cxx +++ b/Framework/Core/src/DataRelayer.cxx @@ -38,6 +38,7 @@ #include "Framework/DataTakingContext.h" #include "Framework/DefaultsHelpers.h" #include "Framework/RawDeviceService.h" +#include "Framework/DataModelViews.h" #include "Headers/DataHeaderHelpers.h" #include "Framework/Formatters.h" @@ -184,11 +185,11 @@ DataRelayer::ActivityStats DataRelayer::processDanglingInputs(std::vector std::span { + auto getPartialRecord = [&cache = mCache, numInputTypes = mDistinctRoutesIndex.size()](int li) -> std::span const> { auto offset = li * numInputTypes; assert(cache.size() >= offset + numInputTypes); auto const start = cache.data() + offset; @@ -213,9 +214,9 @@ DataRelayer::ActivityStats DataRelayer::processDanglingInputs(std::vector(header->GetData()), reinterpret_cast(payload ? payload->GetData() : nullptr), @@ -224,10 +225,10 @@ DataRelayer::ActivityStats DataRelayer::processDanglingInputs(std::vector int { - auto& header = static_cast(*(partial[idx].messages | get_header{0})); + auto& header = static_cast(*(partial[idx] | get_header{0})); return header.GetRefCount(); }; InputSpan span{getter, nPartsGetter, refCountGetter, static_cast(partial.size())}; @@ -242,12 +243,14 @@ DataRelayer::ActivityStats DataRelayer::processDanglingInputs(std::vector(); @@ -353,7 +356,7 @@ void DataRelayer::setOldestPossibleInput(TimesliceId proposed, ChannelIndex chan continue; } auto& element = mCache[si * mInputs.size() + mi]; - if (element.messages.empty()) { + if (element.empty()) { auto& state = mContext.get(); if (state.transitionHandling != TransitionHandlingState::NoTransition && DefaultsHelpers::onlineDeploymentMode()) { if (state.allowedProcessing == DeviceState::CalibrationOnly) { @@ -405,17 +408,17 @@ void DataRelayer::pruneCache(TimesliceSlot slot, OnDropCallback onDrop) if (onDrop) { auto oldestPossibleTimeslice = index.getOldestPossibleOutput(); // State of the computation - std::vector dropped(numInputTypes); + std::vector> dropped(numInputTypes); for (size_t ai = 0, ae = numInputTypes; ai != ae; ++ai) { auto cacheId = slot.index * numInputTypes + ai; cachedStateMetrics[cacheId] = CacheEntryStatus::RUNNING; // TODO: in the original implementation of the cache, there have been only two messages per entry, // check if the 2 above corresponds to the number of messages. - if (!cache[cacheId].messages.empty()) { + if (!cache[cacheId].empty()) { dropped[ai] = std::move(cache[cacheId]); } } - bool anyDropped = std::any_of(dropped.begin(), dropped.end(), [](auto& m) { return !m.messages.empty(); }); + bool anyDropped = std::any_of(dropped.begin(), dropped.end(), [](auto& m) { return !m.empty(); }); if (anyDropped) { O2_SIGNPOST_ID_GENERATE(aid, data_relayer); O2_SIGNPOST_EVENT_EMIT(data_relayer, aid, "pruneCache", "Dropping stuff from slot %zu with timeslice %zu", slot.index, oldestPossibleTimeslice.timeslice.value); @@ -506,7 +509,7 @@ DataRelayer::RelayChoice timeslice.value, slot.index, info.index.value == ChannelIndex::INVALID ? "invalid" : services.get().getInputChannel(info.index)->GetName().c_str()); auto cacheIdx = numInputTypes * slot.index + input; - MessageSet& target = cache[cacheIdx]; + auto& target = cache[cacheIdx]; cachedStateMetrics[cacheIdx] = CacheEntryStatus::PENDING; // TODO: make sure that multiple parts can only be added within the same call of // DataRelayer::relay @@ -536,7 +539,9 @@ DataRelayer::RelayChoice auto span = std::span(messages + mi, messages + mi + nPayloads + 1); // Notice this will split [(header, payload), (header, payload)] multiparts // in N different subParts for the message spec. - target.add([&span](size_t i) -> fair::mq::MessagePtr& { return span[i]; }, nPayloads + 1); + for (size_t i = 0; i < nPayloads + 1; ++i) { + target.emplace_back(std::move(span[i])); + } mi += nPayloads; saved += nPayloads; } @@ -728,7 +733,7 @@ void DataRelayer::getReadyToProcess(std::vector& comp // // We use this to bail out early from the check as soon as we find something // which we know is not complete. - auto getPartialRecord = [&cache, &numInputTypes](int li) -> std::span { + auto getPartialRecord = [&cache, &numInputTypes](int li) -> std::span const> { auto offset = li * numInputTypes; assert(cache.size() >= offset + numInputTypes); auto const start = cache.data() + offset; @@ -786,9 +791,9 @@ void DataRelayer::getReadyToProcess(std::vector& comp auto partial = getPartialRecord(li); // TODO: get the data ref from message model auto getter = [&partial](size_t idx, size_t part) { - if (!partial[idx].messages.empty() && partial[idx].header(part).get()) { - auto header = partial[idx].header(part).get(); - auto payload = partial[idx].payload(part).get(); + if (!partial[idx].empty() && (partial[idx] | get_header{part}).get()) { + auto header = (partial[idx] | get_header{part}).get(); + auto payload = (partial[idx] | get_payload{part, 0}).get(); return DataRef{nullptr, reinterpret_cast(header->GetData()), reinterpret_cast(payload ? payload->GetData() : nullptr), @@ -797,10 +802,10 @@ void DataRelayer::getReadyToProcess(std::vector& comp return DataRef{}; }; auto nPartsGetter = [&partial](size_t idx) { - return partial[idx].messages | count_parts{}; + return partial[idx] | count_parts{}; }; auto refCountGetter = [&partial](size_t idx) -> int { - auto& header = static_cast(*(partial[idx].messages | get_header{0})); + auto& header = static_cast(*(partial[idx] | get_header{0})); return header.GetRefCount(); }; InputSpan span{getter, nPartsGetter, refCountGetter, static_cast(partial.size())}; @@ -871,13 +876,13 @@ void DataRelayer::updateCacheStatus(TimesliceSlot slot, CacheEntryStatus oldStat } } -std::vector DataRelayer::consumeAllInputsForTimeslice(TimesliceSlot slot) +std::vector> DataRelayer::consumeAllInputsForTimeslice(TimesliceSlot slot) { std::scoped_lock lock(mMutex); const auto numInputTypes = mDistinctRoutesIndex.size(); // State of the computation - std::vector messages(numInputTypes); + std::vector> messages(numInputTypes); auto& cache = mCache; auto& index = mTimesliceIndex; @@ -897,7 +902,7 @@ std::vector DataRelayer::consumeAllInputsForTimeslice cachedStateMetrics[cacheId] = CacheEntryStatus::RUNNING; // TODO: in the original implementation of the cache, there have been only two messages per entry, // check if the 2 above corresponds to the number of messages. - if (!cache[cacheId].messages.empty()) { + if (!cache[cacheId].empty()) { messages[arg] = std::move(cache[cacheId]); } index.markAsInvalid(s); @@ -909,7 +914,7 @@ std::vector DataRelayer::consumeAllInputsForTimeslice // FIXME: what happens when we have enough timeslices to hit the invalid one? auto invalidateCacheFor = [&numInputTypes, &index, &cache](TimesliceSlot s) { for (size_t ai = s.index * numInputTypes, ae = ai + numInputTypes; ai != ae; ++ai) { - assert(std::accumulate(cache[ai].messages.begin(), cache[ai].messages.end(), true, [](bool result, auto const& element) { return result && element.get() == nullptr; })); + assert(std::accumulate(cache[ai].begin(), cache[ai].end(), true, [](bool result, auto const& element) { return result && element.get() == nullptr; })); cache[ai].clear(); } index.markAsInvalid(s); @@ -925,13 +930,13 @@ std::vector DataRelayer::consumeAllInputsForTimeslice return messages; } -std::vector DataRelayer::consumeExistingInputsForTimeslice(TimesliceSlot slot) +std::vector> DataRelayer::consumeExistingInputsForTimeslice(TimesliceSlot slot) { std::scoped_lock lock(mMutex); const auto numInputTypes = mDistinctRoutesIndex.size(); // State of the computation - std::vector messages(numInputTypes); + std::vector> messages(numInputTypes); auto& cache = mCache; auto& index = mTimesliceIndex; @@ -951,11 +956,12 @@ std::vector DataRelayer::consumeExistingInputsForTime cachedStateMetrics[cacheId] = CacheEntryStatus::RUNNING; // TODO: in the original implementation of the cache, there have been only two messages per entry, // check if the 2 above corresponds to the number of messages. - for (size_t pi = 0; pi < (cache[cacheId].messages | count_parts{}); pi++) { - auto& header = cache[cacheId].header(pi); + for (size_t pi = 0; pi < (cache[cacheId] | count_parts{}); pi++) { + auto& header = cache[cacheId] | get_header{pi}; auto&& newHeader = header->GetTransport()->CreateMessage(); newHeader->Copy(*header); - messages[arg].add(PartRef{std::move(newHeader), std::move(cache[cacheId].payload(pi))}); + messages[arg].emplace_back(std::move(newHeader)); + messages[arg].emplace_back(std::move(cache[cacheId] | get_payload{pi, 0})); } }; diff --git a/Framework/Core/test/benchmark_DataRelayer.cxx b/Framework/Core/test/benchmark_DataRelayer.cxx index 312711d73e95e..e7df8fbb2fe9b 100644 --- a/Framework/Core/test/benchmark_DataRelayer.cxx +++ b/Framework/Core/test/benchmark_DataRelayer.cxx @@ -14,6 +14,7 @@ #include "Headers/Stack.h" #include "Framework/CompletionPolicyHelpers.h" #include "Framework/DataRelayer.h" +#include "Framework/DataModelViews.h" #include "Framework/DataProcessingHeader.h" #include "Framework/DataProcessingStates.h" #include "Framework/DataProcessingStats.h" @@ -138,8 +139,8 @@ static void BM_RelaySingleSlot(benchmark::State& state) assert(ready[0].op == CompletionPolicy::CompletionOp::Consume); auto result = relayer.consumeAllInputsForTimeslice(ready[0].slot); assert(result.size() == 1); - assert((result.at(0).messages | count_parts{}) == 1); - inflightMessages = std::move(result[0].messages); + assert((result.at(0) | count_parts{}) == 1); + inflightMessages = std::move(result[0]); } } @@ -194,8 +195,8 @@ static void BM_RelayMultipleSlots(benchmark::State& state) assert(ready[0].op == CompletionPolicy::CompletionOp::Consume); auto result = relayer.consumeAllInputsForTimeslice(ready[0].slot); assert(result.size() == 1); - assert((result.at(0).messages | count_parts{}) == 1); - inflightMessages = std::move(result[0].messages); + assert((result.at(0) | count_parts{}) == 1); + inflightMessages = std::move(result[0]); } } @@ -268,11 +269,11 @@ static void BM_RelayMultipleRoutes(benchmark::State& state) assert(ready[0].op == CompletionPolicy::CompletionOp::Consume); auto result = relayer.consumeAllInputsForTimeslice(ready[0].slot); assert(result.size() == 2); - assert((result.at(0).messages | count_parts{}) == 1); - assert((result.at(1).messages | count_parts{}) == 1); - inflightMessages = std::move(result[0].messages); - inflightMessages.emplace_back(std::move(result[1].messages[0])); - inflightMessages.emplace_back(std::move(result[1].messages[1])); + assert((result.at(0) | count_parts{}) == 1); + assert((result.at(1) | count_parts{}) == 1); + inflightMessages = std::move(result[0]); + inflightMessages.emplace_back(std::move(result[1][0])); + inflightMessages.emplace_back(std::move(result[1][1])); } } @@ -332,7 +333,7 @@ static void BM_RelaySplitParts(benchmark::State& state) relayer.getReadyToProcess(ready); assert(ready.size() == 1); assert(ready[0].op == CompletionPolicy::CompletionOp::Consume); - inflightMessages = std::move(relayer.consumeAllInputsForTimeslice(ready[0].slot)[0].messages); + inflightMessages = std::move(relayer.consumeAllInputsForTimeslice(ready[0].slot)[0]); } } @@ -386,7 +387,7 @@ static void BM_RelayMultiplePayloads(benchmark::State& state) relayer.getReadyToProcess(ready); assert(ready.size() == 1); assert(ready[0].op == CompletionPolicy::CompletionOp::Consume); - inflightMessages = std::move(relayer.consumeAllInputsForTimeslice(ready[0].slot)[0].messages); + inflightMessages = std::move(relayer.consumeAllInputsForTimeslice(ready[0].slot)[0]); } } diff --git a/Framework/Core/test/test_DataRelayer.cxx b/Framework/Core/test/test_DataRelayer.cxx index 1f7518860bf57..5606d01619bbb 100644 --- a/Framework/Core/test/test_DataRelayer.cxx +++ b/Framework/Core/test/test_DataRelayer.cxx @@ -16,6 +16,7 @@ #include "MemoryResources/MemoryResources.h" #include "Framework/CompletionPolicyHelpers.h" #include "Framework/DataRelayer.h" +#include "Framework/DataModelViews.h" #include "Framework/DataProcessingStats.h" #include "Framework/DataProcessingStates.h" #include "Framework/DriverConfig.h" @@ -119,7 +120,7 @@ TEST_CASE("DataRelayer") auto result = relayer.consumeAllInputsForTimeslice(ready[0].slot); // one MessageSet with one PartRef with header and payload REQUIRE(result.size() == 1); - REQUIRE((result.at(0).messages | count_parts{}) == 1); + REQUIRE((result.at(0) | count_parts{}) == 1); } // @@ -169,7 +170,7 @@ TEST_CASE("DataRelayer") auto result = relayer.consumeAllInputsForTimeslice(ready[0].slot); // one MessageSet with one PartRef with header and payload REQUIRE(result.size() == 1); - REQUIRE((result.at(0).messages | count_parts{}) == 1); + REQUIRE((result.at(0) | count_parts{}) == 1); } // This test a more complicated set of inputs, and verifies that data is @@ -249,8 +250,8 @@ TEST_CASE("DataRelayer") auto result = relayer.consumeAllInputsForTimeslice(ready[0].slot); // two MessageSets, each with one PartRef REQUIRE(result.size() == 2); - REQUIRE((result.at(0).messages | count_parts{}) == 1); - REQUIRE((result.at(1).messages | count_parts{}) == 1); + REQUIRE((result.at(0) | count_parts{}) == 1); + REQUIRE((result.at(1) | count_parts{}) == 1); } // This test a more complicated set of inputs, and verifies that data is @@ -737,8 +738,8 @@ TEST_CASE("DataRelayer") // we have one input route and thus one message set containing pairs for all // payloads REQUIRE(messageSet.size() == 1); - REQUIRE((messageSet[0].messages | count_parts{}) == nSplitParts); - REQUIRE(messageSet[0].getNumberOfPayloads(0) == 1); + REQUIRE((messageSet[0] | count_parts{}) == nSplitParts); + REQUIRE((messageSet[0] | get_num_payloads{0}) == 1); } SECTION("SplitPayloadSequence") @@ -800,13 +801,13 @@ TEST_CASE("DataRelayer") // we have one input route REQUIRE(messageSet.size() == 1); // one message set containing number of added sequences of messages - REQUIRE((messageSet[0].messages | count_parts{}) == sequenceSize.size()); + REQUIRE((messageSet[0] | count_parts{}) == sequenceSize.size()); size_t counter = 0; for (size_t seqid = 0; seqid < sequenceSize.size(); ++seqid) { - REQUIRE(messageSet[0].getNumberOfPayloads(seqid) == sequenceSize[seqid]); - for (size_t pi = 0; pi < messageSet[0].getNumberOfPayloads(seqid); ++pi) { - REQUIRE((messageSet[0].messages | get_payload{seqid, pi})); - auto const* data = (messageSet[0].messages | get_payload{seqid, pi})->GetData(); + REQUIRE((messageSet[0] | get_num_payloads{seqid}) == sequenceSize[seqid]); + for (size_t pi = 0; pi < (messageSet[0] | get_num_payloads{seqid}); ++pi) { + REQUIRE((messageSet[0] | get_payload{seqid, pi})); + auto const* data = (messageSet[0] | get_payload{seqid, pi})->GetData(); REQUIRE(*(reinterpret_cast(data)) == counter); ++counter; } diff --git a/Framework/Core/test/test_ForwardInputs.cxx b/Framework/Core/test/test_ForwardInputs.cxx index e3031b7e72a69..0263158ee0f9b 100644 --- a/Framework/Core/test/test_ForwardInputs.cxx +++ b/Framework/Core/test/test_ForwardInputs.cxx @@ -16,7 +16,7 @@ #include "Framework/SourceInfoHeader.h" #include "Framework/DomainInfoHeader.h" #include "Framework/Signpost.h" -#include "Framework/MessageSet.h" +#include "Framework/DataModelViews.h" #include "Framework/FairMQDeviceProxy.h" #include "Headers/Stack.h" #include "MemoryResources/MemoryResources.h" @@ -43,7 +43,7 @@ TEST_CASE("ForwardInputsEmpty") bool copyByDefault = true; FairMQDeviceProxy proxy; - std::vector currentSetOfInputs; + std::vector> currentSetOfInputs; auto result = o2::framework::DataProcessingHelpers::routeForwardedMessageSet(proxy, currentSetOfInputs, copyByDefault, consume); REQUIRE(result.empty()); @@ -84,15 +84,16 @@ TEST_CASE("ForwardInputsSingleMessageSingleRoute") proxy.bind({}, {}, routes, findChannelByName, nullptr); - std::vector currentSetOfInputs; - MessageSet messageSet; + std::vector> currentSetOfInputs; + std::vector messageSet; auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); fair::mq::MessagePtr payload(transport->CreateMessage()); auto channelAlloc = o2::pmr::getTransportAllocator(transport.get()); auto header = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh, dph}); - messageSet.add(PartRef{std::move(header), std::move(payload)}); - REQUIRE((messageSet.messages | count_parts{}) == 1); + messageSet.emplace_back(std::move(header)); + messageSet.emplace_back(std::move(payload)); + REQUIRE((messageSet | count_parts{}) == 1); currentSetOfInputs.emplace_back(std::move(messageSet)); auto result = o2::framework::DataProcessingHelpers::routeForwardedMessageSet(proxy, currentSetOfInputs, copyByDefault, consume); @@ -134,16 +135,17 @@ TEST_CASE("ForwardInputsSingleMessageSingleRouteNoConsume") proxy.bind({}, {}, routes, findChannelByName, nullptr); - std::vector currentSetOfInputs; - MessageSet messageSet; + std::vector> currentSetOfInputs; + std::vector messageSet; auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); fair::mq::MessagePtr payload(nullptr); REQUIRE(payload.get() == nullptr); auto channelAlloc = o2::pmr::getTransportAllocator(transport.get()); auto header = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh, dph}); - messageSet.add(PartRef{std::move(header), std::move(payload)}); - REQUIRE((messageSet.messages | count_parts{}) == 1); + messageSet.emplace_back(std::move(header)); + messageSet.emplace_back(std::move(payload)); + REQUIRE((messageSet | count_parts{}) == 1); currentSetOfInputs.emplace_back(std::move(messageSet)); auto result = o2::framework::DataProcessingHelpers::routeForwardedMessageSet(proxy, currentSetOfInputs, copyByDefault, true); @@ -189,16 +191,17 @@ TEST_CASE("ForwardInputsSingleMessageSingleRouteAtEOS") proxy.bind({}, {}, routes, findChannelByName, nullptr); - std::vector currentSetOfInputs; - MessageSet messageSet; + std::vector> currentSetOfInputs; + std::vector messageSet; auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); fair::mq::MessagePtr payload(transport->CreateMessage()); auto channelAlloc = o2::pmr::getTransportAllocator(transport.get()); auto header = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh, dph, sih}); REQUIRE(o2::header::get(header->GetData())); - messageSet.add(PartRef{std::move(header), std::move(payload)}); - REQUIRE((messageSet.messages | count_parts{}) == 1); + messageSet.emplace_back(std::move(header)); + messageSet.emplace_back(std::move(payload)); + REQUIRE((messageSet | count_parts{}) == 1); currentSetOfInputs.emplace_back(std::move(messageSet)); auto result = o2::framework::DataProcessingHelpers::routeForwardedMessageSet(proxy, currentSetOfInputs, copyByDefault, consume); @@ -247,16 +250,17 @@ TEST_CASE("ForwardInputsSingleMessageSingleRouteWithOldestPossible") proxy.bind({}, {}, routes, findChannelByName, nullptr); - std::vector currentSetOfInputs; - MessageSet messageSet; + std::vector> currentSetOfInputs; + std::vector messageSet; auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); fair::mq::MessagePtr payload(transport->CreateMessage()); auto channelAlloc = o2::pmr::getTransportAllocator(transport.get()); auto header = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh, dph, dih}); REQUIRE(o2::header::get(header->GetData())); - messageSet.add(PartRef{std::move(header), std::move(payload)}); - REQUIRE((messageSet.messages | count_parts{}) == 1); + messageSet.emplace_back(std::move(header)); + messageSet.emplace_back(std::move(payload)); + REQUIRE((messageSet | count_parts{}) == 1); currentSetOfInputs.emplace_back(std::move(messageSet)); auto result = o2::framework::DataProcessingHelpers::routeForwardedMessageSet(proxy, currentSetOfInputs, copyByDefault, consume); @@ -313,15 +317,16 @@ TEST_CASE("ForwardInputsSingleMessageMultipleRoutes") proxy.bind({}, {}, routes, findChannelByName, nullptr); - std::vector currentSetOfInputs; - MessageSet messageSet; + std::vector> currentSetOfInputs; + std::vector messageSet; auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); fair::mq::MessagePtr payload(transport->CreateMessage()); auto channelAlloc = o2::pmr::getTransportAllocator(transport.get()); auto header = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh, dph}); - messageSet.add(PartRef{std::move(header), std::move(payload)}); - REQUIRE((messageSet.messages | count_parts{}) == 1); + messageSet.emplace_back(std::move(header)); + messageSet.emplace_back(std::move(payload)); + REQUIRE((messageSet | count_parts{}) == 1); currentSetOfInputs.emplace_back(std::move(messageSet)); auto result = o2::framework::DataProcessingHelpers::routeForwardedMessageSet(proxy, currentSetOfInputs, copyByDefault, consume); @@ -376,15 +381,16 @@ TEST_CASE("ForwardInputsSingleMessageMultipleRoutesExternals") proxy.bind({}, {}, routes, findChannelByName, nullptr); - std::vector currentSetOfInputs; - MessageSet messageSet; + std::vector> currentSetOfInputs; + std::vector messageSet; auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); fair::mq::MessagePtr payload(transport->CreateMessage()); auto channelAlloc = o2::pmr::getTransportAllocator(transport.get()); auto header = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh, dph}); - messageSet.add(PartRef{std::move(header), std::move(payload)}); - REQUIRE((messageSet.messages | count_parts{}) == 1); + messageSet.emplace_back(std::move(header)); + messageSet.emplace_back(std::move(payload)); + REQUIRE((messageSet | count_parts{}) == 1); currentSetOfInputs.emplace_back(std::move(messageSet)); auto result = o2::framework::DataProcessingHelpers::routeForwardedMessageSet(proxy, currentSetOfInputs, copyByDefault, consume); @@ -446,21 +452,23 @@ TEST_CASE("ForwardInputsMultiMessageMultipleRoutes") proxy.bind({}, {}, routes, findChannelByName, nullptr); - std::vector currentSetOfInputs; + std::vector> currentSetOfInputs; auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); fair::mq::MessagePtr payload1(transport->CreateMessage()); fair::mq::MessagePtr payload2(transport->CreateMessage()); auto channelAlloc = o2::pmr::getTransportAllocator(transport.get()); auto header1 = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh1, dph}); - MessageSet messageSet1; - messageSet1.add(PartRef{std::move(header1), std::move(payload1)}); - REQUIRE((messageSet1.messages | count_parts{}) == 1); + std::vector messageSet1; + messageSet1.emplace_back(std::move(header1)); + messageSet1.emplace_back(std::move(payload1)); + REQUIRE((messageSet1 | count_parts{}) == 1); auto header2 = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh2, dph}); - MessageSet messageSet2; - messageSet2.add(PartRef{std::move(header2), std::move(payload2)}); - REQUIRE((messageSet2.messages | count_parts{}) == 1); + std::vector messageSet2; + messageSet2.emplace_back(std::move(header2)); + messageSet2.emplace_back(std::move(payload2)); + REQUIRE((messageSet2 | count_parts{}) == 1); currentSetOfInputs.emplace_back(std::move(messageSet1)); currentSetOfInputs.emplace_back(std::move(messageSet2)); REQUIRE(currentSetOfInputs.size() == 2); @@ -517,15 +525,16 @@ TEST_CASE("ForwardInputsSingleMessageMultipleRoutesOnlyOneMatches") proxy.bind({}, {}, routes, findChannelByName, nullptr); - std::vector currentSetOfInputs; - MessageSet messageSet; + std::vector> currentSetOfInputs; + std::vector messageSet; auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); fair::mq::MessagePtr payload(transport->CreateMessage()); auto channelAlloc = o2::pmr::getTransportAllocator(transport.get()); auto header = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh, dph}); - messageSet.add(PartRef{std::move(header), std::move(payload)}); - REQUIRE((messageSet.messages | count_parts{}) == 1); + messageSet.emplace_back(std::move(header)); + messageSet.emplace_back(std::move(payload)); + REQUIRE((messageSet | count_parts{}) == 1); currentSetOfInputs.emplace_back(std::move(messageSet)); auto result = o2::framework::DataProcessingHelpers::routeForwardedMessageSet(proxy, currentSetOfInputs, copyByDefault, consume); @@ -587,8 +596,8 @@ TEST_CASE("ForwardInputsSplitPayload") proxy.bind({}, {}, routes, findChannelByName, nullptr); - std::vector currentSetOfInputs; - MessageSet messageSet; + std::vector> currentSetOfInputs; + std::vector messageSet; auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); fair::mq::MessagePtr payload1(transport->CreateMessage()); @@ -602,12 +611,14 @@ TEST_CASE("ForwardInputsSplitPayload") auto fillMessages = [&messages](size_t t) -> fair::mq::MessagePtr { return std::move(messages[t]); }; - messageSet.add(fillMessages, 3); + for (size_t i = 0; i < 3; ++i) { + messageSet.emplace_back(fillMessages(i)); + } auto header2 = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh2, dph}); - PartRef part{std::move(header2), transport->CreateMessage()}; - messageSet.add(std::move(part)); + messageSet.emplace_back(std::move(header2)); + messageSet.emplace_back(transport->CreateMessage()); - REQUIRE((messageSet.messages | count_parts{}) == 2); + REQUIRE((messageSet | count_parts{}) == 2); currentSetOfInputs.emplace_back(std::move(messageSet)); auto result = o2::framework::DataProcessingHelpers::routeForwardedMessageSet(proxy, currentSetOfInputs, copyByDefault, consume); @@ -719,15 +730,16 @@ TEST_CASE("ForwardInputEOSSingleRoute") proxy.bind({}, {}, routes, findChannelByName, nullptr); - std::vector currentSetOfInputs; - MessageSet messageSet; + std::vector> currentSetOfInputs; + std::vector messageSet; auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); fair::mq::MessagePtr payload(transport->CreateMessage()); auto channelAlloc = o2::pmr::getTransportAllocator(transport.get()); auto header = o2::pmr::getMessage(o2::header::Stack{channelAlloc, sih}); - messageSet.add(PartRef{std::move(header), std::move(payload)}); - REQUIRE((messageSet.messages | count_parts{}) == 1); + messageSet.emplace_back(std::move(header)); + messageSet.emplace_back(std::move(payload)); + REQUIRE((messageSet | count_parts{}) == 1); currentSetOfInputs.emplace_back(std::move(messageSet)); auto result = o2::framework::DataProcessingHelpers::routeForwardedMessageSet(proxy, currentSetOfInputs, copyByDefault, consume); @@ -764,15 +776,16 @@ TEST_CASE("ForwardInputOldestPossibleSingleRoute") proxy.bind({}, {}, routes, findChannelByName, nullptr); - std::vector currentSetOfInputs; - MessageSet messageSet; + std::vector> currentSetOfInputs; + std::vector messageSet; auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); fair::mq::MessagePtr payload(transport->CreateMessage()); auto channelAlloc = o2::pmr::getTransportAllocator(transport.get()); auto header = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dih}); - messageSet.add(PartRef{std::move(header), std::move(payload)}); - REQUIRE((messageSet.messages | count_parts{}) == 1); + messageSet.emplace_back(std::move(header)); + messageSet.emplace_back(std::move(payload)); + REQUIRE((messageSet | count_parts{}) == 1); currentSetOfInputs.emplace_back(std::move(messageSet)); auto result = o2::framework::DataProcessingHelpers::routeForwardedMessageSet(proxy, currentSetOfInputs, copyByDefault, consume); diff --git a/Framework/Core/test/test_MessageSet.cxx b/Framework/Core/test/test_MessageSet.cxx index 290e55220d6cb..a9750ce5ca249 100644 --- a/Framework/Core/test/test_MessageSet.cxx +++ b/Framework/Core/test/test_MessageSet.cxx @@ -14,6 +14,7 @@ #include "Framework/MessageSet.h" #include "Framework/DataModelViews.h" #include "Framework/DataProcessingHeader.h" +#include "Framework/PartRef.h" #include "Headers/Stack.h" #include "Headers/DataHeader.h" #include "MemoryResources/MemoryResources.h" @@ -23,7 +24,7 @@ using namespace o2::framework; TEST_CASE("MessageSet") { - o2::framework::MessageSet msgSet; + std::vector messages; o2::header::DataHeader dh{}; dh.splitPayloadParts = 0; dh.splitPayloadIndex = 0; @@ -36,7 +37,9 @@ TEST_CASE("MessageSet") std::vector ptrs; ptrs.emplace_back(std::move(header)); ptrs.emplace_back(std::move(msg2)); - msgSet.add([&ptrs](size_t i) -> fair::mq::MessagePtr& { return ptrs[i]; }, 2); + for (size_t i = 0; i < 2; ++i) { + messages.emplace_back(std::move(ptrs[i])); + } REQUIRE(msgSet.messages.size() == 2); REQUIRE((msgSet.messages | count_payloads{}) == 1); @@ -45,12 +48,11 @@ TEST_CASE("MessageSet") REQUIRE((msgSet.messages | get_pair{0}).headerIdx == 0); REQUIRE((msgSet.messages | get_pair{0}).payloadIdx == 1); CHECK_THROWS((msgSet.messages | get_pair{1})); - // Validate pipe operators match old API - REQUIRE(&(msgSet.messages | get_header{0}) == &msgSet.header(0)); - REQUIRE(&(msgSet.messages | get_payload{0, 0}) == &msgSet.payload(0)); - REQUIRE((msgSet.messages | get_num_payloads{0}) == msgSet.messageMap[0].size); - REQUIRE((msgSet.messages | count_parts{}) == msgSet.messageMap.size()); - REQUIRE((msgSet.messages | count_payloads{}) == msgSet.pairMap.size()); + REQUIRE((msgSet.messages | get_num_payloads{0}) == 1); + REQUIRE((msgSet.messages | count_parts{}) == 1); + // messages: [hdr, pl] — one pair + REQUIRE((msgSet.messages | get_pair{0}).headerIdx == 0); + REQUIRE((msgSet.messages | get_pair{0}).payloadIdx == 1); } TEST_CASE("MessageSetWithFunction") @@ -67,7 +69,10 @@ TEST_CASE("MessageSetWithFunction") std::unique_ptr msg2(nullptr); ptrs.emplace_back(std::move(header)); ptrs.emplace_back(std::move(msg2)); - o2::framework::MessageSet msgSet([&ptrs](size_t i) -> fair::mq::MessagePtr& { return ptrs[i]; }, 2); + std::vector messages; + for (size_t i = 0; i < 2; ++i) { + messages.emplace_back(std::move(ptrs[i])); + } REQUIRE(msgSet.messages.size() == 2); REQUIRE((msgSet.messages | count_payloads{}) == 1); @@ -76,11 +81,8 @@ TEST_CASE("MessageSetWithFunction") REQUIRE((msgSet.messages | get_pair{0}).headerIdx == 0); REQUIRE((msgSet.messages | get_pair{0}).payloadIdx == 1); CHECK_THROWS((msgSet.messages | get_pair{1})); - REQUIRE(&(msgSet.messages | get_header{0}) == &msgSet.header(0)); - REQUIRE(&(msgSet.messages | get_payload{0, 0}) == &msgSet.payload(0)); - REQUIRE((msgSet.messages | get_num_payloads{0}) == msgSet.messageMap[0].size); - REQUIRE((msgSet.messages | count_parts{}) == msgSet.messageMap.size()); - REQUIRE((msgSet.messages | count_payloads{}) == msgSet.pairMap.size()); + REQUIRE((msgSet.messages | get_num_payloads{0}) == 1); + REQUIRE((msgSet.messages | count_parts{}) == 1); } TEST_CASE("MessageSetWithMultipart") @@ -99,7 +101,10 @@ TEST_CASE("MessageSetWithMultipart") ptrs.emplace_back(std::move(header)); ptrs.emplace_back(std::move(msg2)); ptrs.emplace_back(std::move(msg3)); - o2::framework::MessageSet msgSet([&ptrs](size_t i) -> fair::mq::MessagePtr& { return ptrs[i]; }, 3); + std::vector messages; + for (size_t i = 0; i < 3; ++i) { + messages.emplace_back(std::move(ptrs[i])); + } REQUIRE(msgSet.messages.size() == 3); REQUIRE((msgSet.messages | count_payloads{}) == 2); @@ -112,13 +117,13 @@ TEST_CASE("MessageSetWithMultipart") REQUIRE((msgSet.messages | get_pair{1}).headerIdx == 0); REQUIRE((msgSet.messages | get_pair{1}).payloadIdx == 2); CHECK_THROWS((msgSet.messages | get_pair{2})); - // Validate pipe operators match old API for multi-payload - REQUIRE(&(msgSet.messages | get_header{0}) == &msgSet.header(0)); - REQUIRE(&(msgSet.messages | get_payload{0, 0}) == &msgSet.payload(0, 0)); - REQUIRE(&(msgSet.messages | get_payload{0, 1}) == &msgSet.payload(0, 1)); - REQUIRE((msgSet.messages | get_num_payloads{0}) == msgSet.messageMap[0].size); - REQUIRE((msgSet.messages | count_parts{}) == msgSet.messageMap.size()); - REQUIRE((msgSet.messages | count_payloads{}) == msgSet.pairMap.size()); + REQUIRE((msgSet.messages | get_num_payloads{0}) == 2); + REQUIRE((msgSet.messages | count_parts{}) == 1); + // messages: [hdr, pl0, pl1] — one header, two payloads + REQUIRE((msgSet.messages | get_pair{0}).headerIdx == 0); + REQUIRE((msgSet.messages | get_pair{0}).payloadIdx == 1); + REQUIRE((msgSet.messages | get_pair{1}).headerIdx == 0); + REQUIRE((msgSet.messages | get_pair{1}).payloadIdx == 2); } TEST_CASE("MessageSetAddPartRef") @@ -129,10 +134,11 @@ TEST_CASE("MessageSetAddPartRef") ptrs.emplace_back(std::move(msg)); ptrs.emplace_back(std::move(msg2)); PartRef ref{std::move(msg), std::move(msg2)}; - o2::framework::MessageSet msgSet; - msgSet.add(std::move(ref)); + std::vector messages; + messages.emplace_back(std::move(ref.header)); + messages.emplace_back(std::move(ref.payload)); - REQUIRE(msgSet.messages.size() == 2); + REQUIRE(messages.size() == 2); } TEST_CASE("MessageSetAddMultiple") @@ -158,20 +164,21 @@ TEST_CASE("MessageSetAddMultiple") std::unique_ptr msg2(nullptr); std::unique_ptr msg3(nullptr); PartRef ref{std::move(header1), std::move(msg2)}; - o2::framework::MessageSet msgSet; - msgSet.add(std::move(ref)); + std::vector messages; + messages.emplace_back(std::move(ref.header)); + messages.emplace_back(std::move(ref.payload)); PartRef ref2{std::move(header2), std::move(msg2)}; - msgSet.add(std::move(ref2)); + messages.emplace_back(std::move(ref2.header)); + messages.emplace_back(std::move(ref2.payload)); std::vector msgs; msgs.push_back(std::move(header3)); msgs.push_back(std::unique_ptr(nullptr)); msgs.push_back(std::unique_ptr(nullptr)); - msgSet.add([&msgs](size_t i) { - return std::move(msgs[i]); - }, - 3); + for (size_t i = 0; i < 3; ++i) { + messages.emplace_back(std::move(msgs[i])); + } - REQUIRE(msgSet.messages.size() == 7); + REQUIRE(messages.size() == 7); REQUIRE((msgSet.messages | count_payloads{}) == 4); REQUIRE((msgSet.messages | get_dataref_indices{0, 0}).headerIdx == 0); @@ -190,18 +197,11 @@ TEST_CASE("MessageSetAddMultiple") REQUIRE((msgSet.messages | get_pair{2}).payloadIdx == 5); REQUIRE((msgSet.messages | get_pair{3}).headerIdx == 4); REQUIRE((msgSet.messages | get_pair{3}).payloadIdx == 6); - // Validate pipe operators match old API for mixed modes - for (size_t i = 0; i < 3; ++i) { - REQUIRE(&(msgSet.messages | get_header{i}) == &msgSet.header(i)); - REQUIRE(&(msgSet.messages | get_payload{i, 0}) == &msgSet.payload(i, 0)); - } - // Part 2 has a second payload (multi-payload with splitPayloadParts=2, splitPayloadIndex=2) - REQUIRE(&(msgSet.messages | get_payload{2, 1}) == &msgSet.payload(2, 1)); - for (size_t i = 0; i < 3; ++i) { - REQUIRE((msgSet.messages | get_num_payloads{i}) == msgSet.messageMap[i].size); - } - REQUIRE((msgSet.messages | count_parts{}) == msgSet.messageMap.size()); - REQUIRE((msgSet.messages | count_payloads{}) == msgSet.pairMap.size()); + REQUIRE((msgSet.messages | get_num_payloads{0}) == 1); + REQUIRE((msgSet.messages | get_num_payloads{1}) == 1); + REQUIRE((msgSet.messages | get_num_payloads{2}) == 2); + REQUIRE((msgSet.messages | count_parts{}) == 3); + REQUIRE((msgSet.messages | count_payloads{}) == 4); } TEST_CASE("GetHeaderPayloadOperators") @@ -250,6 +250,14 @@ TEST_CASE("GetHeaderPayloadOperators") auto& pl1 = msgSet.messages | get_payload{1, 0}; REQUIRE(pl1.get() != nullptr); REQUIRE(pl1->GetSize() == 200); + + REQUIRE((msgSet.messages | count_parts{}) == 2); + REQUIRE((msgSet.messages | count_payloads{}) == 2); + // messages: [hdr0, pl0, hdr1, pl1] — two standard pairs + REQUIRE((msgSet.messages | get_pair{0}).headerIdx == 0); + REQUIRE((msgSet.messages | get_pair{0}).payloadIdx == 1); + REQUIRE((msgSet.messages | get_pair{1}).headerIdx == 2); + REQUIRE((msgSet.messages | get_pair{1}).payloadIdx == 3); } TEST_CASE("GetHeaderPayloadMultiPayload") @@ -335,18 +343,19 @@ TEST_CASE("GetHeaderPayloadMultiPayload") // get_num_payloads for part 1 should be 3 REQUIRE((msgSet.messages | get_num_payloads{1}) == 3); - // Validate pipe operators match old API for multi-payload (header, pl, pl, pl) - REQUIRE(&(msgSet.messages | get_header{0}) == &msgSet.header(0)); - REQUIRE(&(msgSet.messages | get_header{1}) == &msgSet.header(1)); - REQUIRE(&(msgSet.messages | get_payload{0, 0}) == &msgSet.payload(0, 0)); - REQUIRE(&(msgSet.messages | get_payload{1, 0}) == &msgSet.payload(1, 0)); - REQUIRE(&(msgSet.messages | get_payload{1, 1}) == &msgSet.payload(1, 1)); - REQUIRE(&(msgSet.messages | get_payload{1, 2}) == &msgSet.payload(1, 2)); - for (size_t i = 0; i < 2; ++i) { - REQUIRE((msgSet.messages | get_num_payloads{i}) == msgSet.messageMap[i].size); - } - REQUIRE((msgSet.messages | count_parts{}) == msgSet.messageMap.size()); - REQUIRE((msgSet.messages | count_payloads{}) == msgSet.pairMap.size()); + REQUIRE((msgSet.messages | get_num_payloads{0}) == 1); + REQUIRE((msgSet.messages | get_num_payloads{1}) == 3); + REQUIRE((msgSet.messages | count_parts{}) == 2); + REQUIRE((msgSet.messages | count_payloads{}) == 4); + // messages: [hdr0, pl0, hdr1, pl1_0, pl1_1, pl1_2] + REQUIRE((msgSet.messages | get_pair{0}).headerIdx == 0); + REQUIRE((msgSet.messages | get_pair{0}).payloadIdx == 1); + REQUIRE((msgSet.messages | get_pair{1}).headerIdx == 2); + REQUIRE((msgSet.messages | get_pair{1}).payloadIdx == 3); + REQUIRE((msgSet.messages | get_pair{2}).headerIdx == 2); + REQUIRE((msgSet.messages | get_pair{2}).payloadIdx == 4); + REQUIRE((msgSet.messages | get_pair{3}).headerIdx == 2); + REQUIRE((msgSet.messages | get_pair{3}).payloadIdx == 5); } TEST_CASE("TraditionalSplitParts") @@ -410,14 +419,15 @@ TEST_CASE("TraditionalSplitParts") // get_num_payloads: each traditional split pair has 1 payload for (size_t i = 0; i < 3; ++i) { - REQUIRE((msgSet.messages | get_num_payloads{i}) == msgSet.messageMap[i].size); + REQUIRE((msgSet.messages | get_num_payloads{i}) == 1); } - - // Validate pipe operators match old MessageSet::header()/payload() API - for (size_t i = 0; i < 3; ++i) { - REQUIRE(&(msgSet.messages | get_header{i}) == &msgSet.header(i)); - REQUIRE(&(msgSet.messages | get_payload{i, 0}) == &msgSet.payload(i)); - } - REQUIRE((msgSet.messages | count_parts{}) == msgSet.messageMap.size()); - REQUIRE((msgSet.messages | count_payloads{}) == msgSet.pairMap.size()); + REQUIRE((msgSet.messages | count_parts{}) == 3); + REQUIRE((msgSet.messages | count_payloads{}) == 3); + // messages: [hdr0, pl0, hdr1, pl1, hdr2, pl2] — three traditional split pairs + REQUIRE((msgSet.messages | get_pair{0}).headerIdx == 0); + REQUIRE((msgSet.messages | get_pair{0}).payloadIdx == 1); + REQUIRE((msgSet.messages | get_pair{1}).headerIdx == 2); + REQUIRE((msgSet.messages | get_pair{1}).payloadIdx == 3); + REQUIRE((msgSet.messages | get_pair{2}).headerIdx == 4); + REQUIRE((msgSet.messages | get_pair{2}).payloadIdx == 5); }